简介
TFRecord是一种二进制格式文件,理论上可以保存任何格式的信息,可以将任何类型数据转化为Tensorflow所支持的格式,这种方法可以让数据集和网络架构更容易相互匹配。TFRecord文件包含了tf.train.Example协议内存块(protocol buffer)。可以将数据填入Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串,并且通过tf.python_io.TFRecordWriter写入到TFRecord文件。
Protocol Buffer(protobuf)是Google公司出品的一种独立于开发语言、独立于平台的可扩展的结构化数据序列化机制。它与xml、json类似,是一种数据交互格式协议。
Example消息体:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
// An Example wraps a single Features message.
message Example {
  Features features = 1;
};

message Features {
  // Map from feature name to feature.
  map<string, Feature> feature = 1;
};

message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

// The three list-valued payload types a Feature can carry.
message BytesList {
  repeated bytes value = 1;
}
message FloatList {
  repeated float value = 1 [packed = true];
}
message Int64List {
  repeated int64 value = 1 [packed = true];
}
|
一个Example包含一系列的feature属性。每一个feature都以map(即key-value键值对)的形式保存:key是string类型,而value是Feature类型的消息体,它的取值类型有三种:
- BytesList
- FloatList
- Int64List
它们都是列表形式。
图片转TFRecord格式
python代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
| import os import numpy as np from PIL import Image import tensorflow as tf
# Text file listing one "<image_path> <integer_label>" pair per line.
filename = "./images/train.txt"
# Directory the TFRecord shards are written to.
file_path = r"/image/tfrecord/train/"
# Maximum number of examples stored in a single TFRecord shard.
bestnum = 2000
num = 0
# Index of the shard currently being written (used in the shard file name).
recordfilenum = 0


def load_images_labels(filename=filename):
    """Read the image list file and return shuffled, still-paired lists.

    Each non-blank line must contain an image path followed by exactly one
    integer label, separated by whitespace. Blank lines are skipped.

    Args:
        filename: path of the list file to read.

    Returns:
        A tuple ``(images_list, labels_list)`` where ``labels_list[i]`` is the
        label of ``images_list[i]``. Both lists are shuffled with the same
        permutation drawn from NumPy's global random state.

    Raises:
        ValueError: if a line does not hold exactly one path and one label,
            or if the label is not an integer.
    """
    images_list = []
    labels_list = []
    with open(filename) as f:
        for line_no, line in enumerate(f, start=1):
            content = line.split()
            if not content:
                # A blank (or trailing) line would otherwise add an empty
                # image name with no label and break the pairing.
                continue
            if len(content) != 2:
                # The lists are shuffled independently below, so they only
                # stay aligned when every image contributes exactly one label.
                raise ValueError(
                    "%s:%d: expected '<image_path> <label>', got %r"
                    % (filename, line_no, line.rstrip())
                )
            images_list.append(content[0])
            labels_list.append(int(content[1]))

    # Shuffle both lists from the same RNG state so they receive the same
    # permutation and image/label pairs stay together.
    state = np.random.get_state()
    np.random.shuffle(images_list)
    np.random.set_state(state)
    np.random.shuffle(labels_list)

    return images_list, labels_list
if __name__ == '__main__':
    # Serialize every listed image into tf.train.Example records, starting a
    # new "train-<n>.tfrecords" shard after every `bestnum` examples.
    images_list, labels_list = load_images_labels()

    ftrecordfilename = ("train-%d.tfrecords" % recordfilenum)
    writer = tf.python_io.TFRecordWriter(os.path.join(file_path, ftrecordfilename))
    try:
        for i, image_path in enumerate(images_list):
            if i != 0 and i % bestnum == 0:
                print("(%d/%d) 已完成" % (i, len(images_list)))
                # Close the finished shard before opening the next one;
                # otherwise its writer is leaked and never explicitly closed.
                writer.close()
                recordfilenum += 1
                ftrecordfilename = ("train-%d.tfrecords" % recordfilenum)
                writer = tf.python_io.TFRecordWriter(os.path.join(file_path, ftrecordfilename))

            # NOTE(review): pixels are stored at the image's native size and
            # mode; the reading example in this document reshapes records to
            # [224, 224, 3], so inputs are presumably pre-resized RGB images
            # -- confirm against the dataset.
            with Image.open(image_path, 'r') as img:
                img_raw = img.tobytes()

            example = tf.train.Example(
                features=tf.train.Features(feature={
                    'label': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[labels_list[i]])),
                    'image_raw': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[img_raw])),
                }))
            writer.write(example.SerializeToString())
    finally:
        # Always close the last open shard, even if an image fails to load.
        writer.close()
|
读取TFRecord格式文件
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
| import tensorflow as tf
# Height and width (in pixels) every stored image is reshaped to.
img_size = 224


def read_and_decode_tfrecord(filename, num_classes=8):
    """Build TF1 queue-based ops that decode one example from TFRecord files.

    Args:
        filename: list of TFRecord file paths, passed to
            ``tf.train.string_input_producer``.
        num_classes: depth of the one-hot label vector. Defaults to 8, the
            value that was previously hard-coded.

    Returns:
        A tuple ``(img, label)`` of tensors: ``img`` is the raw bytes decoded
        as uint8, reshaped to ``[img_size, img_size, 3]`` and converted to
        float32; ``label`` is the int64 label cast to int32 and one-hot
        encoded with depth ``num_classes``.
    """
    filename_deque = tf.train.string_input_producer(filename)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_deque)

    # Feature keys must match the ones used when the records were written.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })

    label = tf.cast(features['label'], tf.int32)
    img = tf.decode_raw(features['image_raw'], tf.uint8)
    img = tf.reshape(img, [img_size, img_size, 3])

    img = tf.image.convert_image_dtype(img, dtype=tf.float32)
    label = tf.one_hot(label, num_classes)
    return img, label
# Shards to read; the decode op yields one (image, label) pair at a time and
# tf.train.batch groups them into batches of 10.
train_list = ['./images/test/train-0.tfrecords', './images/test/train-1.tfrecords']
img, label = read_and_decode_tfrecord(train_list)
img_batch, label_batch = tf.train.batch([img, label], num_threads=1, batch_size=10, capacity=1000)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The input pipeline is queue based, so its queue-runner threads must be
    # started before any batch can be fetched.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(5):
            b_image, b_label = sess.run([img_batch, label_batch])
            print(b_label)
    finally:
        # Stop and join the queue threads even if fetching a batch fails.
        coord.request_stop()
        coord.join(threads)
|