
[TensorFlow] Reading Data with TFRecord (Part 1)

Date: 2019-10-25 07:06:47


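In a deep learning project, the training and validation images usually have to be loaded before the model is trained. My earlier post "TensorFlow Convolutional Neural Networks: Cat vs. Dog Recognition" loaded them with OpenCV. Reading images into matrices with OpenCV does meet the needs of model training, and it is fine for small sets of images, but it becomes slow for large image collections.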

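For large projects with many images, data is usually read in the TFRecord format. TFRecord is a format natively supported by TensorFlow and is fast to read; it is recommended once a dataset exceeds roughly 10,000 samples. A TFRecord file stores data in binary form and is designed for reading large amounts of data sequentially. Its advantages are better memory utilization and files that are easier to copy and move, and it fits the way the TensorFlow execution engine processes data. Converting data to TFRecord usually means writing a small program that packs each sample into an Example object (defined as a protocol buffer), serializes it to a string, and writes it to a file with tf.python_io.TFRecordWriter.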
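To make "binary, read sequentially" concrete, here is a pure-Python sketch of the record framing that the TensorFlow documentation describes for TFRecord files: each record is an 8-byte little-endian length, a masked CRC-32C of that length, the payload bytes (for example a serialized Example), and a masked CRC-32C of the payload. It is an illustration under that assumption only; the helper names (crc32c, masked_crc, write_record, read_records) are hypothetical, and real pipelines should use TensorFlow's own writer and reader.

```python
import io
import struct

# Reflected CRC-32C (Castagnoli) polynomial.
_CRC32C_POLY = 0x82F63B78


def _make_table():
    """Precompute the 256-entry lookup table for CRC-32C."""
    table = []
    for byte in range(256):
        crc = byte
        for _ in range(8):
            crc = (crc >> 1) ^ _CRC32C_POLY if crc & 1 else crc >> 1
        table.append(crc)
    return table


_TABLE = _make_table()


def crc32c(data: bytes) -> int:
    """Table-driven CRC-32C of data."""
    crc = 0xFFFFFFFF
    for b in data:
        crc = _TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(f, payload: bytes) -> None:
    """Append one record: length, CRC of length, payload, CRC of payload."""
    header = struct.pack("<Q", len(payload))
    f.write(header)
    f.write(struct.pack("<I", masked_crc(header)))
    f.write(payload)
    f.write(struct.pack("<I", masked_crc(payload)))


def read_records(f):
    """Yield payloads one after another, checking both CRCs."""
    while True:
        header = f.read(8)
        if not header:
            return  # clean end of file
        if len(header) != 8:
            raise ValueError("truncated length field")
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", f.read(4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupted length field")
        payload = f.read(length)
        (payload_crc,) = struct.unpack("<I", f.read(4))
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupted payload")
        yield payload


# Usage: round-trip two records through an in-memory file.
buf = io.BytesIO()
write_record(buf, b"first example")
write_record(buf, b"second example")
buf.seek(0)
print(list(read_records(buf)))  # [b'first example', b'second example']
```

Because every record carries its own length, a reader can only walk the file from front to back, which is why the format suits streaming large datasets rather than random access.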

Before reading data with TFRecord, images of the same class are usually placed in the same folder, for example:

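In the figure above, "flower_photos" is the top-level folder and contains 5 subfolders: all rose images go into "roses", all sunflower images into "sunflowers", and so on. This makes it easy to build the mapping between image path, image label (e.g. 1, 2, 3) and class name (e.g. daisy, dandelion, roses).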
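The folder-to-label mapping can be sketched in plain Python. This is an illustration, not the author's code: build_index is a hypothetical name, the class names are passed in explicitly (the script below reads them from flower_label.txt), and files are listed with os.listdir and sorted, whereas the script uses tf.gfile.Glob. Labels start at 1, matching the script, which reserves 0 for a background class.

```python
import os
import tempfile


def build_index(root, class_names):
    """Return parallel lists (paths, labels, texts) for root/<class_name>/<file>.

    Label 1 goes to the first class name, 2 to the second, and so on
    (0 is left free, like the script's background class).
    """
    paths, labels, texts = [], [], []
    for label, name in enumerate(class_names, start=1):
        folder = os.path.join(root, name)
        # Sorted for a deterministic order; only regular files are kept.
        files = sorted(
            os.path.join(folder, f)
            for f in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, f))
        )
        paths.extend(files)
        labels.extend([label] * len(files))
        texts.extend([name] * len(files))
    return paths, labels, texts


# Usage: a tiny throwaway layout with two classes.
with tempfile.TemporaryDirectory() as demo_root:
    for name, fname in [("roses", "a.jpg"), ("roses", "b.jpg"), ("tulips", "c.jpg")]:
        os.makedirs(os.path.join(demo_root, name), exist_ok=True)
        open(os.path.join(demo_root, name, fname), "wb").close()
    _, demo_labels, demo_texts = build_index(demo_root, ["roses", "tulips"])
    print(demo_labels, demo_texts)  # [1, 1, 2] ['roses', 'roses', 'tulips']
```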

Images in the roses folder:

Implementation

Directory structure:

flower_label.txt:

This file lists the names of the 5 subfolders of ./flower_photos, one per line, so that the program can locate the images.

```
daisy
dandelion
roses
sunflowers
tulips
```
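Reading this file amounts to numbering its lines in order. A minimal sketch, assuming one class name per line; parse_labels is a hypothetical helper, and it skips blank lines so that a trailing empty line cannot turn into an empty class name.

```python
def parse_labels(lines):
    """Map each non-blank line (a class name) to 1, 2, 3, ... in file order."""
    names = [line.strip() for line in lines if line.strip()]
    return {name: index for index, name in enumerate(names, start=1)}


# With the real file this would be:
#   with open("flower_label.txt", encoding="utf-8") as f:
#       label_map = parse_labels(f)
print(parse_labels(["daisy\n", "dandelion\n", "roses\n", "sunflowers\n", "tulips\n", "\n"]))
# {'daisy': 1, 'dandelion': 2, 'roses': 3, 'sunflowers': 4, 'tulips': 5}
```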

build_image_data.py:

```python
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os
import random
import sys
# Build the files with several threads for speed; preprocessing and writing are done together.
import threading

import numpy as np
import tensorflow as tf

# String and integer command-line flags.
# Only a training set is built here; a validation directory can be added the same way.
tf.app.flags.DEFINE_string('train_directory', './flower_photos/',
                           'Training data directory')
# Validation set: no separate directory is provided, so it simply points at the training data.
tf.app.flags.DEFINE_string('validation_directory', './flower_photos/',
                           'Validation data directory')
# Output directory for the TFRecord files.
tf.app.flags.DEFINE_string('output_directory', './data/',
                           'Output data directory')
# Number of TFRecord files to produce. train_shards must be divisible by
# num_threads so the shards can be split evenly between the threads.
tf.app.flags.DEFINE_integer('train_shards', 2,
                            'Number of shards in training TFRecord files.')
# As above; no validation set is built, only the training set.
tf.app.flags.DEFINE_integer('validation_shards', 0,
                            'Number of shards in validation TFRecord files.')
# Number of threads to launch.
tf.app.flags.DEFINE_integer('num_threads', 2,
                            'Number of threads to preprocess the images.')

# The labels file contains the list of valid labels, one per line, e.g.:
#   dog
#   cat
#   flower
# Each label is mapped to an integer by its line number, starting from 1
# (0 is reserved for a background class).
# The entries of flower_label.txt must match the subfolder names one-to-one.
tf.app.flags.DEFINE_string('labels_file', './flower_label.txt', 'labels file')

# Access the flags defined above.
FLAGS = tf.app.flags.FLAGS


def _int64_feature(value):
    """Wrapper for inserting int64 features into Example proto."""
    # isinstance() checks whether an object is of a given type. Unlike type(),
    # it takes inheritance into account: an instance of a subclass also
    # counts as an instance of its parent class.
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _find_image_files(data_dir, labels_file):
    """Build a list of all image files and labels in the data set.

    :param data_dir: string, path to the root directory of images.
    :param labels_file: string, path to the labels file. The list of valid
        labels is held in this file, one per line, e.g.:
            dog
            cat
            flower
        Each label is mapped to an integer, starting with 1 for the label on
        the first line (0 is reserved for a background class).
    :return:
        filenames: list of strings; each string is a path to an image file.
        texts: list of strings; each string is the class, e.g. 'dog'.
        labels: list of integers; each integer identifies the ground truth.
    """
    print('Target directory: %s.' % data_dir)

    # Read flower_label.txt. tf.gfile.FastGFile(path, mode) opens a file through
    # TensorFlow's file API: mode 'r' reads text, 'rb' reads raw bytes.
    # Blank lines are skipped so they cannot become an empty class name.
    with tf.gfile.FastGFile(labels_file, 'r') as f:
        unique_labels = [l.strip() for l in f.readlines() if l.strip()]

    labels = []
    filenames = []
    texts = []

    # Leave label index 0 empty as a background class.
    label_index = 1

    # Construct the list of JPEG files and labels.
    for text in unique_labels:
        jpeg_file_path = '%s/%s/*' % (data_dir, text)
        try:
            # tf.gfile.Glob() returns the list of files matching the pattern.
            matching_files = tf.gfile.Glob(jpeg_file_path)
        except Exception:
            print(jpeg_file_path)
            continue

        # Extend labels with this class's index (starting from 1).
        labels.extend([label_index] * len(matching_files))
        # Extend texts with the class name from flower_label.txt.
        texts.extend([text] * len(matching_files))
        filenames.extend(matching_files)

        label_index += 1

    # Shuffle the ordering of all image files so that labels appear in random
    # order in the saved TFRecord files. A fixed seed makes it repeatable.
    shuffled_index = list(range(len(filenames)))
    random.seed(12345)
    random.shuffle(shuffled_index)

    # Reorder all three lists with the same permutation so they stay aligned.
    filenames = [filenames[i] for i in shuffled_index]
    texts = [texts[i] for i in shuffled_index]
    labels = [labels[i] for i in shuffled_index]

    print('Found %d JPEG files across %d labels inside %s.' %
          (len(filenames), len(unique_labels), data_dir))
    return filenames, texts, labels


class ImageCoder(object):
    """Helper class that provides TensorFlow image coding utilities.

    Used to turn every image into an RGB JPEG.
    """

    def __init__(self):
        # Create a single Session to run all image coding calls.
        self._sess = tf.Session()

        # Initializes function that converts PNG to JPEG data,
        # so that all images end up in the same format.
        self._png_data = tf.placeholder(dtype=tf.string)
        # Decode to 3 channels.
        image = tf.image.decode_png(self._png_data, channels=3)
        # Re-encode as a JPEG with RGB pixel format.
        self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)

        # Initializes function that decodes RGB JPEG data.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def png_to_jpeg(self, image_data):
        return self._sess.run(self._png_to_jpeg,
                              feed_dict={self._png_data: image_data})

    def decode_jpeg(self, image_data):
        image = self._sess.run(self._decode_jpeg,
                               feed_dict={self._decode_jpeg_data: image_data})
        assert len(image.shape) == 3
        assert image.shape[2] == 3
        return image


def _process_image(filename, coder):
    """Process a single image file.

    :param filename: string, path to an image file e.g., '/path/to/example.JPG'.
    :param coder: instance of ImageCoder to provide TensorFlow image coding utils.
    :return:
        image_buffer: string, JPEG encoding of RGB image.
        height: integer, image height in pixels.
        width: integer, image width in pixels.
    """
    # Read the image file.
    with tf.gfile.FastGFile(filename, 'rb') as f:
        image_data = f.read()

    # Convert any PNG to JPEG for consistency.
    if _is_png(filename):
        print('Converting PNG to JPEG for %s' % filename)
        image_data = coder.png_to_jpeg(image_data)

    # Decode the RGB JPEG.
    image = coder.decode_jpeg(image_data)

    # Check that the image converted to RGB: shape is (height, width, channels).
    assert len(image.shape) == 3
    height = image.shape[0]
    width = image.shape[1]
    # Make sure the image has 3 channels.
    assert image.shape[2] == 3

    return image_data, height, width


def _is_png(filename):
    """Determine if a file contains a PNG format image.

    :param filename: string, path of the image file.
    :return: boolean indicating if the image is a PNG.
    """
    return filename.lower().endswith('.png')


def _convert_to_example(filename, image_buffer, label, text, height, width):
    """Build an Example proto for an example.

    :param filename: string, path to an image file, e.g., '/path/to/example.JPG'
    :param image_buffer: string, JPEG encoding of RGB image
    :param label: integer, identifier for the ground truth for the network
    :param text: string, unique human-readable, e.g. 'dog'
    :param height: integer, image height in pixels
    :param width: integer, image width in pixels
    :return: Example proto
    """
    colorspace = 'RGB'
    channels = 3
    image_format = 'JPEG'

    # tf.compat.as_bytes() converts bytes or unicode to bytes, encoding text as UTF-8.
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/colorspace': _bytes_feature(tf.compat.as_bytes(colorspace)),
        'image/channels': _int64_feature(channels),
        'image/class/label': _int64_feature(label),
        'image/class/text': _bytes_feature(tf.compat.as_bytes(text)),
        'image/format': _bytes_feature(tf.compat.as_bytes(image_format)),
        'image/filename': _bytes_feature(tf.compat.as_bytes(os.path.basename(filename))),
        'image/encoded': _bytes_feature(tf.compat.as_bytes(image_buffer))}))
    return example


def _process_image_files_batch(coder, thread_index, ranges, name, filenames,
                               texts, labels, num_shards):
    """Processes and saves list of images as TFRecord in 1 thread.

    :param coder: instance of ImageCoder to provide TensorFlow image coding utils.
    :param thread_index: integer, unique batch to run, index is within [0, len(ranges)).
    :param ranges: list of pairs of integers specifying ranges of each batch to analyze in parallel.
    :param name: string, unique identifier specifying the data set.
    :param filenames: list of strings; each string is a path to an image file.
    :param texts: list of strings; each string is human readable, e.g. 'dog'.
    :param labels: list of integers; each integer identifies the ground truth.
    :param num_shards: integer number of shards for this data set.
    """
    # Each thread produces N shards where N = num_shards / num_threads. For
    # instance, if num_shards=128 and num_threads=2, the first thread produces
    # shards [0, 64).
    num_threads = len(ranges)
    assert not num_shards % num_threads
    num_shards_per_batch = num_shards // num_threads

    shard_ranges = np.linspace(ranges[thread_index][0], ranges[thread_index][1],
                               num_shards_per_batch + 1).astype(int)
    num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]

    counter = 0
    for s in range(num_shards_per_batch):
        # Generate a sharded file name, e.g. 'train-00001-of-00002.tfrecord'.
        shard = thread_index * num_shards_per_batch + s
        output_filename = '%s-%.5d-of-%.5d.tfrecord' % (name, shard, num_shards)
        output_file = os.path.join(FLAGS.output_directory, output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)

        shard_counter = 0
        files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
        for i in files_in_shard:
            filename = filenames[i]  # full path
            label = labels[i]        # integer label
            text = texts[i]          # folder (class) name

            image_buffer, height, width = _process_image(filename, coder)

            example = _convert_to_example(filename, image_buffer, label, text,
                                          height, width)
            writer.write(example.SerializeToString())
            shard_counter += 1
            counter += 1

            if not counter % 1000:
                print('%s [thread %d]: Processed %d of %d image in thread batch.' %
                      (datetime.now(), thread_index, counter, num_files_in_thread))
                sys.stdout.flush()

        writer.close()
        print('%s [thread %d]: Wrote %d images to %s' %
              (datetime.now(), thread_index, shard_counter, output_file))
        # Flush stdout so that messages from different threads show up promptly.
        sys.stdout.flush()
        shard_counter = 0
    print('%s [thread %d]: Wrote %d images to %d shards.' %
          (datetime.now(), thread_index, counter, num_shards_per_batch))
    sys.stdout.flush()


def _process_image_files(name, filenames, texts, labels, num_shards):
    """Process and save list of images as TFRecord of Example protos.

    :param name: string, unique identifier specifying the data set
    :param filenames: list of strings; each string is a path to an image file
    :param texts: list of strings; each string is human readable, e.g. 'dog'
    :param labels: list of integers; each integer identifies the ground truth
    :param num_shards: integer number of shards for this data set.
    """
    # filenames, texts and labels must have the same length.
    assert len(filenames) == len(texts)
    assert len(filenames) == len(labels)

    # Break all images into batches [ranges[i][0], ranges[i][1]).
    # With 3670 images and 2 threads, spacing = [0, 1835, 3670]: images
    # 0-1834 go to one thread and 1835-3669 to the other.
    spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(int)
    # Turn spacing into ranges, e.g. [[0, 1835], [1835, 3670]].
    ranges = []
    for i in range(len(spacing) - 1):
        ranges.append([spacing[i], spacing[i + 1]])

    # Launch a thread for each batch.
    print('launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges))
    sys.stdout.flush()

    # TensorFlow's thread coordinator, used to wait until all threads finish.
    coord = tf.train.Coordinator()

    # Create a generic TensorFlow-based utility for converting all image coding.
    coder = ImageCoder()

    threads = []
    for thread_index in range(len(ranges)):
        args = (coder, thread_index, ranges, name, filenames, texts, labels, num_shards)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    print('%s: Finished writing all %d images in data set.' %
          (datetime.now(), len(filenames)))
    sys.stdout.flush()


def _process_dataset(name, directory, num_shards, labels_file):
    """Process a complete data set and save it as a TFRecord.

    Args:
      name: string, unique identifier specifying the data set.
      directory: string, root path to the data set.
      num_shards: integer number of shards for this data set.
      labels_file: string, path to the labels file.
    """
    filenames, texts, labels = _find_image_files(directory, labels_file)
    _process_image_files(name, filenames, texts, labels, num_shards)


def main(unused_argv):
    assert not FLAGS.train_shards % FLAGS.num_threads, (
        'Training set: train_shards must be divisible by num_threads.')
    assert not FLAGS.validation_shards % FLAGS.num_threads, (
        'Validation set: validation_shards must be divisible by num_threads.')
    print('Output directory: %s' % FLAGS.output_directory)
    # TFRecordWriter cannot create missing directories, so create it first.
    if not tf.gfile.Exists(FLAGS.output_directory):
        tf.gfile.MakeDirs(FLAGS.output_directory)

    # Run it!
    # Training set.
    _process_dataset('train', FLAGS.train_directory, FLAGS.train_shards, FLAGS.labels_file)
    # Validation set.
    # _process_dataset('validation', FLAGS.validation_directory,
    #                  FLAGS.validation_shards, FLAGS.labels_file)


if __name__ == '__main__':
    tf.app.run()
```
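The shuffle at the end of _find_image_files works because one permutation of indices is applied to all three lists, so each file keeps its own label and class name. A minimal sketch of that idea; shuffle_together is a hypothetical name, and unlike the script it uses a private random.Random(seed) instance instead of reseeding the global generator.

```python
import random


def shuffle_together(filenames, texts, labels, seed=12345):
    """Shuffle three parallel lists with one permutation so they stay aligned."""
    assert len(filenames) == len(texts) == len(labels)
    order = list(range(len(filenames)))
    # A private generator: same seed gives the same order, and no global state is touched.
    random.Random(seed).shuffle(order)
    return ([filenames[i] for i in order],
            [texts[i] for i in order],
            [labels[i] for i in order])


files = ["r1.jpg", "r2.jpg", "t1.jpg", "t2.jpg"]
names = ["roses", "roses", "tulips", "tulips"]
ids = [1, 1, 2, 2]
f2, n2, i2 = shuffle_together(files, names, ids)
# Every file still carries its own class name and label.
print(sorted(zip(f2, n2, i2)) == sorted(zip(files, names, ids)))  # True
```

Shuffling the three lists separately, even with the same seed each time, is easy to get wrong; permuting one index list and reusing it keeps the alignment explicit.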

Output:

```
Output directory: ./data/
Target directory: ./flower_photos/.
Instructions for updating:
Use tf.gfile.GFile.
Found 3670 JPEG files across 5 labels inside ./flower_photos/.
launching 2 threads for spacings: [[0, 1835], [1835, 3670]]
-08-28 12:49:17.142402 [thread 0]: Processed 1000 of 1835 image in thread batch.
-08-28 12:49:17.362402 [thread 1]: Processed 1000 of 1835 image in thread batch.
-08-28 12:49:25.261402 [thread 0]: Wrote 1835 images to ./data/train-00000-of-00002.tfrecord
-08-28 12:49:25.261402 [thread 0]: Wrote 1835 images to 1835 shards.
-08-28 12:49:25.810402 [thread 1]: Wrote 1835 images to ./data/train-00001-of-00002.tfrecord
-08-28 12:49:25.810402 [thread 1]: Wrote 1835 images to 1835 shards.
-08-28 12:49:26.274402: Finished writing all 3670 images in data set.
```
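The log shows how the work was split: 3670 files, 2 threads and train_shards=2 give the ranges [0, 1835) and [1835, 3670), each written by one thread into one shard. The hypothetical helper plan_shards below reproduces that bookkeeping in pure Python with the script's file-name pattern; it uses integer division in place of np.linspace(...).astype(int), which agrees here but could differ from NumPy's float arithmetic at rounding edges.

```python
def plan_shards(num_files, num_threads, num_shards, name="train"):
    """Return (thread, shard file name, first index, end index) for every shard."""
    assert num_shards % num_threads == 0, "num_shards must be divisible by num_threads"
    per_thread = num_shards // num_threads
    # Thread boundaries: [0, ..., num_files] split into num_threads pieces.
    spacing = [i * num_files // num_threads for i in range(num_threads + 1)]
    plan = []
    for t in range(num_threads):
        start, end = spacing[t], spacing[t + 1]
        # Shard boundaries inside this thread's range.
        bounds = [start + s * (end - start) // per_thread for s in range(per_thread + 1)]
        for s in range(per_thread):
            shard = t * per_thread + s
            filename = "%s-%.5d-of-%.5d.tfrecord" % (name, shard, num_shards)
            plan.append((t, filename, bounds[s], bounds[s + 1]))
    return plan


for entry in plan_shards(3670, num_threads=2, num_shards=2):
    print(entry)
# (0, 'train-00000-of-00002.tfrecord', 0, 1835)
# (1, 'train-00001-of-00002.tfrecord', 1835, 3670)
```

Raising num_shards to 4 with 2 threads would give each thread two output files, which is why the script insists that the shard count be divisible by the thread count.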

Generated TFRecord files:

Reference:

/moyu123456789/article/details/83956366
