TensorFlow 数据输入方式

最近使用TensorFlow时候发现读入图片数据的时候很难受,自己写的一些代码又感觉很蹩脚,官方的文档易读性太低,太费劲了.
打算好好整理一下数据输入这部分代码.

方式

直接加载到内存中

这种方式简直暴力,老早之前我数据没有很多,就是这么做的,把数据先用pickle,将其保存下来,运行的时候直接从保存的文件中获取,这种方式只适合一些小型的文本文件数据,图片数据还是算了,电脑会卡死的.

动态输入placeholder

在每次训练时加载一些数据到内存当中,以placeholder的形式feed到graph当中.
这种方式需要自己写一些代码,比如

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class ImageData:
def ___init__(self, batch_size):
# define some attr
self.batch_size = batch_size
def next_batch(self):
....
return np.array([imread('name%s.png'%i for i in range(self.batch_seze))]), labels
...

image_reader = ImageData(batch_size)
for step in range(10001):
batch = image_reader.next_batch()
sess.run(batch, feed_dict={x: batch[0], y: batch[1]})
...

在训练时实例化相应的类, 每次运行加载一次,需要说明的是,这样需要大量地重复读写多个小文件,所以把这些文件放到固态硬盘下面对提升读取速度是个不错的选择.

这个方法是比较类似于pytorch的datasetdataloader的,同时采用pytorch的这两个类来实现TF中的数据输入也是很可行的.

详情见Data Loading and Processing Tutorial.

tfrecord

TF推荐的一种文件格式为tfrecord,把多个零碎的原始文件编码为一个较大的tfrecord文件,运行时通过解码器将其转化为Tensors, 这篇tutorial解释了TF中读入数据的方法,这里代码已经写得很好了,但读完发现还是不会……

可能用到的TF对象和方法, (先大致看名字了解一下,详细后面会涉及到):

数据转化

首先通过一个tf.TFRecordWriter对象将原始零碎小文件转化为tfrecord格式文件,代码参考某个地方的:

1
2
3
4
5
6
7
8
9
10
11
12
writer = tf.python_io.TFRecordWriter('./data/test_deepmatching.tfrecords') # 初始化TFRecordWriter对象

# write exanple
img_raw = img.tobytes()
example = tf.train.Example(features=tf.train.Features(feature={
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())

....
writer.close()

读入tfrecord

TF将文件名生成队列,并行解码读取里面的Tensors.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def read_and_decode(filename):
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})

img = tf.decode_raw(features['img_raw'], tf.uint8)
print img.shape
img = tf.reshape(img, [65, 130, 1])
# img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
# pre-processing
label = tf.cast(features['label'], tf.int32)
return img, label

上述函数返回我们的img 和label ops,运行时候还需要run一下这几个ops.

为了实现并行读取,我们需要创建多个线程,使用QueueRunner兑现来运行.

需要一个Coordnator对象来管理这些线程.

1
2
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

训练时,

1
2
3
4
5
6
for step in range(10001):
batch = sess.run([img_batch, label_batch])
sess.run(train_op)
...

coord.request_stop()

上面这种读取不是很好,可能会遇到一些错误,导致线程无法正常关闭,还在等待读取队列,必须强行kill掉才可以,可以采用官方的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
try:
while not coord.should_stop():
# Run training steps or whatever
sess.run(train_op)

except tf.errors.OutOfRangeError:
print('Done training -- epoch limit reached')
finally:
# When done, ask the threads to stop.
coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()