1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > tensorflow--制作数据集tfrecords文件

tensorflow--制作数据集tfrecords文件

时间:2019-02-22 22:51:49

相关推荐

tensorflow--制作数据集tfrecords文件

利用tensorflow生成tfrecords文件,制作自己的数据集。

目录

1. tfrecords文件2. 利用图像数据集生成tfrecords文件3. 读取解析tfrecords文件

1. tfrecords文件

tfrecords是一种二进制文件,对内存较为友好,在使用tensorflow训练过程中可以多线程获取数据,放入内存,而不是一次性全部读取,这样可以提高内存的利用率。

Tips:可以将一些图片和标签制作成这种格式

一般来说需要两步:

第一:用tf.train.Example协议存储训练数据,数据的特征用键值对应形式表示;比如"img_raw"表示图像数据,"label"表示标签数据。它们值的类型可以是BytesList/FloatList/Int64List等。

第二:用SerializeToString()把数据序列化成字符串存储

例:将图像和标签保存为tfrecords:

writer_train= tf.python_io.TFRecordWriter(tfrecords_path)example = tf.train.Example(features=tf.train.Features(feature={'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[value])),'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))}))writer_train.write(example.SerializeToString())writer_train.close()

2. 利用图像数据集生成tfrecords文件

深度学习中,图像是常用的数据集,制作图像数据集前需要进行尺寸缩放,再进行二进制转换:

import tensorflow as tffrom PIL import Imageimport ostfrecords_path = "./train.tfrecords" #生成tfrecords的位置object_path = "./" #图像数据集的位置dim = 224#缩放尺寸writer_train= tf.python_io.TFRecordWriter(tfrecords_path)total = os.listdir(object_path)num = len(total)num_i = 1value = 0 #这里标签的值设为零,可以根据图片名来自己设定标签值for index in total:img_path=os.path.join(object_path,index)img=Image.open(img_path) #打开图像img=img.resize((dim,dim)) #图像尺寸变换img_raw=img.tobytes() #将数据转换为二进制example = tf.train.Example(features=tf.train.Features(feature={'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[value])),'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))}))writer_train.write(example.SerializeToString()) #序列化为字符串sys.stdout.write('--------%.4f%%-----'%(num_i/float(num)*100)) #查看制作进度sys.stdout.write('\r')sys.stdout.flush()num_i = num_i +1writer_train.close()

3. 读取解析tfrecords文件

(1)利用tf.train.shuffle_batch或者tf.train.batch来取数据,前者可以打乱数据顺序

#根据文件名生成一个队列,如果tfrecords文件较多时filename_queue = tf.train.string_input_producer([tfrecords_path],shuffle=True) #打乱reader = tf.TFRecordReader()_, serialized_example = reader.read(filename_queue)features = tf.parse_single_example(serialized_example,features={'label': tf.FixedLenFeature([], tf.int64),'img_raw' : tf.FixedLenFeature([], tf.string)})image = tf.decode_raw(features['img_raw'], tf.uint8) #利用decode_raw解码图像image = tf.reshape(image,[dim,dim,3])#reshape 224*224*3 #reshape到一定维度image = tf.cast(image,tf.float32)*(1./255)#image张量可以除以255,*(1./255)label = tf.cast(features['label'], tf.int32)#利用cast解码标签img_batch, label_batch = tf.train.shuffle_batch([image,label],batch_size=batch_size, #批量数batch sizenum_threads=4,#启用的线程数capacity= 640, #容器大小min_after_dequeue=5)#打乱的最小数

另外需要用sess执行会话,这样便可喂入神经网络:(需要打开线程协调器否则无法工作

coord = tf.train.Coordinator() threads=tf.train.start_queue_runners(sess=sess,coord=coord) #打开线程协调器for 循环体:...img_,label_ = sess.run([img_batch, label_batch])sess.run([train_step,accuracy,loss],feed_dict={x:img_,y_:label_})...coord.request_stop()coord.join(threads)#关闭线程协调器

(2)利用tf.data.TFRecordDataset读取数据

用法:首先读取tfrecords文件,然后对数据更改:进行打乱,训练轮数,批量数等,需要注意的是,batch一定要放在shuffle后面

dataset = tf.data.TFRecordDataset([tfrecords_path])dataset = dataset\.repeat(epochs_data)\.shuffle(1000)\.batch(batch_size)\.map(load_image,num_parallel_calls = 8)#注意一定要将shuffle放在batch前iter = dataset.make_initializable_iterator()#得到迭代器,make_one_shot_iterator()train_datas = iter.get_next() #用train_datas[0],[1]的方式得到值

另外还需要定义一个map,使图像和标签解码出来:

def load_image(serialized_example): features={'label': tf.io.FixedLenFeature([], tf.int64),'img_raw' : tf.io.FixedLenFeature([], tf.string)}parsed_example = tf.io.parse_example(serialized_example,features)#同上,也需要对图像和标签数据进行解码image = tf.decode_raw(parsed_example['img_raw'],tf.uint8)image = tf.reshape(image,[-1,dim,dim,3])image = tf.cast(image,tf.float32)*(1./255)label = tf.cast(parsed_example['label'], tf.int32)label = tf.reshape(label,[-1,1])return image,labeldef dataset_tfrecords(tfrecords_path): dataset = tf.data.TFRecordDataset([tfrecords_path])'''这个可以有多个组成[tfrecords_name1,tfrecords_name2,...],可以用os.listdir(tfrecords_path):但这里只有一个tfrecords文件,所以直接将tfrecords_path设置到文件而非文件夹'''dataset = dataset\.repeat(epochs_data)\.shuffle(1000)\.batch(batch_size)\.map(load_image,num_parallel_calls = 8)#注意一定要将shuffle放在batch前iter = dataset.make_initializable_iterator()#得到迭代器,make_one_shot_iterator()train_datas = iter.get_next() #用train_datas[0],[1]的方式得到值return train_datas,iter

在之后的语句中,需要使用sess来执行会话,并需要对迭代器初始化:

sess = tf.Session()train_datas,iter = dataset_tfrecords(tfrecords_path)sess.run(iter.initializer)for i in numsteps:...train_datas_ = sess.run(train_datas) #执行会话以后获得图像数据...

需要注意的是,当所有的数据遍历完,相当于迭代器超过了范围的时候,会tensorflow会抛出一个错误:tf.errors.OutOfRangeError,可以把这个错误作为退出循环的条件break;

for i in numsteps:try :......except tf.errors.OutOfRangeError: break


本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。