1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > Tensorflow—tfrecord数据集生成与使用

Tensorflow—tfrecord数据集生成与使用

时间:2023-08-28 01:27:23

相关推荐

Tensorflow—tfrecord数据集生成与使用

参考内容:

数据读取的官方教程:Tensorflow导入数据以及使用数据

tfrecord数据集生成:

数据准备:图片数据+图片目录与label一一对应的的txt

先读取图片信息的txt文件,得到每个图片的路径以及它们的标签,然后对这个图片作一些预处理,最后将图片以及它对应的标签序列化,并建立图片和标签的索引(即以下代码的”img_raw”, “label”)。详见代码。

import randomimport tensorflow as tffrom PIL import Imagedef create_record(records_path, data_path, img_txt):# 声明一个TFRecordWriterwriter = tf.python_io.TFRecordWriter(records_path)# 读取图片信息,并且将读入的图片顺序打乱img_list = []with open(img_txt, 'r') as fr:img_list = fr.readlines()random.shuffle(img_list)cnt = 0# 遍历每一张图片信息for img_info in img_list:# 图片相对路径img_name = img_info.split(' ')[0]# 图片类别img_cls = int(img_info.split(' ')[1])img_path = data_path + img_nameimg = Image.open(img_path)# 对图片进行预处理(缩放,减去均值,二值化等等)img = img.resize((128, 128))img_raw = img.tobytes()# 声明将要写入tfrecord的key值(即图片,标签)example = tf.train.Example(features=tf.train.Features(feature={"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[img_cls])),'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))}))# 将信息写入指定路径writer.write(example.SerializeToString())# 打印一些提示信息~cnt += 1if cnt % 1000 == 0:print "processed %d images" % cntwriter.close()# 指定你想要生成tfrecord名称,图片文件夹路径,含有图片信息的txt文件records_path = '/the/name/of/your/haha.tfrecords'data_path = '/the/root/of/your/image_folder/'img_txt = '/image/labels/list.txt'create_record(records_path, data_path, img_txt)

tfrecord数据集使用:

目前为止,使用TFrecord最方便的方式是用TensorFlow的Dataset ApI。在这里,劝大家千万千万不要用queue的方式读取数据(麻烦且已经过时)。

首先,我们定义好_parse_function,这个函数是用来指定TFrecord中索引的(即上文中的”img_raw”, “label”)。然后我们定义一个TFRecordDataset,并借助_parse_function来读取数据。最后,为了得到每一轮的训练数据,我们只需要再额外声明一个iterator,每次调用get_next()就可以啦。

# 定义如何解析TFrecord数据def _parse_function(example_proto):features = tf.parse_single_example(example_proto,features={'label': tf.FixedLenFeature([], tf.int64),'img_raw': tf.FixedLenFeature([], tf.string)})# 取出我们需要的数据(标签,图片)label = features['label']img = features['img_raw']img = tf.decode_raw(img, tf.uint8)# 对标签以及图片作预处理img = tf.reshape(img, [128, 128, 3])img = tf.cast(img, tf.float32) * (1. / 255) - 0.5label = tf.cast(label, tf.int32)return img, label# 得到获取data batch的迭代器def data_iterator(tfrecords):# 声明TFRecordDatasetdataset = tf.contrib.data.TFRecordDataset(tfrecords)dataset = dataset.map(_parse_function)# 打乱顺序,无限重复训练数据,定义好batch sizedataset = dataset.shuffle(buffer_size=1000).repeat().batch(128)# 定义one_shot_iterator。官方上有许多类型的iterrator,这种是最简单的iterator = dataset.make_one_shot_iterator()return iterator# 指定TFrecords路径,得到training iterator。train_tfrecords = '/your/path/to/haha.tfrecords'train_iterator = data_iterator(train_tfrecords)# 使用方式举例with tf.Session(config= tfconfig) as sess:tf.initialize_all_variables().run()train_batch = train_iterator.get_next()for step in xrange(50000):train_x, train_y = sess.run(train_batch)

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。