1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > TensorFlow TFRecord数据集的生成与显示

TensorFlow TFRecord数据集的生成与显示

时间:2020-07-15 08:30:05

相关推荐

TensorFlow TFRecord数据集的生成与显示

##TFRecord##

TensorFlow提供了TFRecord的格式来统一存储数据,TFRecord格式是一种将图像数据和标签放在一起的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储 等等。

TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为Tensor。

##Image to TFRecord##

首先我们使用TensorFlow提供的Flowers数据集做这个实验,数据集在我本地的路径为:

这是一个五分类的数据,以类别的形式组织数据,这非常符合我们自己组织数据集的习惯。其中一个分类中大概有700张左右的图片:

现在我们就把上面的数据制作出TFRecord,在这里需要说明下,TFRecord的生成要注意两点:

1.很多时候,我们的图片尺寸并不是统一的,所以在生成的TFRecord中需要包含图像的width和height这两个信息,这样在解析图片的时候,我们才能把二进制的数据重新reshape成图片;

2.TensorFlow官方的建议是一个TFRecord中最好图片的数量为1000张左右,这个很好理解,如果我们有上万张图片,却只打成一个包,这样是很不利于多线程读取的。所以我们需要根据图像数据自动去选择到底打包几个TFRecord出来。

我们可以用下面的代码实现这两个目的:

import os import tensorflow as tf from PIL import Image #图片路径cwd = 'F:\\flowersdata\\trainimages\\'#文件路径filepath = 'F:\\flowersdata\\tfrecord\\'#存放图片个数bestnum = 1000#第几个图片num = 0#第几个TFRecord文件recordfilenum = 0#类别classes=['daisy','dandelion','roses','sunflowers','tulips']#tfrecords格式文件名ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename)#类别和路径for index,name in enumerate(classes):print(index)print(name)class_path=cwd+name+'\\'for img_name in os.listdir(class_path): num=num+1if num>bestnum:num = 1recordfilenum = recordfilenum + 1#tfrecords格式文件名ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename)#print('路径',class_path)#print('第几个图片:',num)#print('文件的个数',recordfilenum)#print('图片名:',img_name)img_path = class_path+img_name #每一个图片的地址img=Image.open(img_path,'r')size = img.sizeprint(size[1],size[0])print(size)#print(img.mode)img_raw=img.tobytes()#将图片转化为二进制格式example = tf.train.Example(features=tf.train.Features(feature={'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),'img_width':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[0]])),'img_height':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[1]]))})) writer.write(example.SerializeToString()) #序列化为字符串writer.close()

在上面的代码中,我们规定了一个TFRecord中只放1000张图:

bestnum = 1000

并且将一张图的4个信息打包到TFRecord中,分别是:

example = tf.train.Example(features=tf.train.Features(feature={'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),'img_width':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[0]])),'img_height':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[1]]))}))

##TFRecord to Image##

在上面我们打包了四个TFRecord文件,下面我们把这些数据读取并显示出来,看看制作的效果,这个过程很大一部分是和TensorFlow组织batch是一样的了。

import tensorflow as tf from PIL import Image import matplotlib.pyplot as plt#写入图片路径swd = 'F:\\flowersdata\\show\\'#TFRecord文件路径data_path = 'F:\\flowersdata\\tfrecord\\traindata.tfrecords-003'# 获取文件名列表data_files = tf.gfile.Glob(data_path)print(data_files)# 文件名列表生成器filename_queue = tf.train.string_input_producer(data_files,shuffle=True) reader = tf.TFRecordReader()_, serialized_example = reader.read(filename_queue) #返回文件名和文件features = tf.parse_single_example(serialized_example,features={'label': tf.FixedLenFeature([], tf.int64),'img_raw' : tf.FixedLenFeature([], tf.string),'img_width': tf.FixedLenFeature([], tf.int64),'img_height': tf.FixedLenFeature([], tf.int64),}) #取出包含image和label的feature对象#tf.decode_raw可以将字符串解析成图像对应的像素数组image = tf.decode_raw(features['img_raw'], tf.uint8)height = tf.cast(features['img_height'],tf.int32)width = tf.cast(features['img_width'],tf.int32)label = tf.cast(features['label'], tf.int32)channel = 3image = tf.reshape(image, [height,width,channel])with tf.Session() as sess: #开始一个会话init_op = tf.initialize_all_variables()sess.run(init_op)#启动多线程coord=tf.train.Coordinator()threads= tf.train.start_queue_runners(coord=coord)for i in range(15):#image_down = np.asarray(image_down.eval(), dtype='uint8')plt.imshow(image.eval())plt.show()single,l = sess.run([image,label])#在会话中取出image和labelimg=Image.fromarray(single, 'RGB')#这里Image是之前提到的img.save(swd+str(i)+'_''Label_'+str(l)+'.jpg')#存下图片#print(single,l)coord.request_stop()coord.join(threads)

注意:

1.我们在使用reshape去将二进制数据重新变成图片的时候,用的就是之前打包进去的width和height,否则程序会出错;

image = tf.reshape(image, [height,width,channel])

2.在图片存储时的命名方式为:mun_Label_calss id

3.代码也可以实时show出当前的图片:

完整代码也可以点击这里下载。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。