1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > Tensorflow2.0(1):加载图片数据集--TFRecord

Tensorflow2.0(1):加载图片数据集--TFRecord

时间:2019-07-27 19:02:14

相关推荐

Tensorflow2.0(1):加载图片数据集--TFRecord

目录

1、TFRecord介绍

2、TFRecord格式数据文件处理过程

3、TFRecord格式

4、生成TFRecord格式数据

5、TFRecord数据文件解码

6、解码并生成Dataset数据集

7、查看第一批元素

1、TFRecord介绍

TFRecord 是 TensorFlow 中的数据集中存储格式,TFRecord是一种二进制文件。

将数据集整理成 TFRecord 格式后,TensorFlow 就可以高效地读取和处理这些数据集,从而更高效地进行大规模的模型训练。

TFRecord 内部使用了二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可。简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个 TFRecord 文件,来提高处理效率。

2、TFRecord格式数据文件处理过程

将形式各样的数据集整理为 TFRecord 格式,可以对数据集中的每个元素进行以下步骤:

(1)读取该数据元素到内存;

(2)将该元素转换为 tf.train.Example 对象(每一个 tf.train.Example 由若干个 tf.train.Feature 的字典组成,因此需要先建立 Feature 的字典);

(3)将该 tf.train.Example 对象序列化为字符串,并通过一个预先定义的 tf.io.TFRecordWriter 写入 TFRecord 文件。

读取 TFRecord 数据可按照以下步骤:

(1)通过 tf.data.TFRecordDataset 读入原始的 TFRecord 文件(此时文件中的 tf.train.Example 对象尚未被反序列化),获得一个 tf.data.Dataset数据集对象;

(2)通过 Dataset.map 方法,对该数据集对象中的每一个序列化的 tf.train.Example 字符串执行 tf.io.parse_single_example 函数,从而实现反序列化。

3、TFRecord格式

TFRecord内部包含了多个tf.train.Example, 而Exampleprotocol buffer(protobuf) 数据标准的实现,在一个Example消息体中包含了一系列的tf.train.feature属性,而每一个feature是一个key-value的键值对,其中,key是string类型,而value的取值有三种:

bytes_list:可以存储stringbyte两种数据类型。float_list:可以存储float(float32)double(float64)两种数据类型 。int64_list:可以存储bool, enum, int32, uint32, int64, uint64

tf.train.Feature 支持三种数据格式:

tf.train.BytesList :字符串或原始 Byte 文件(如图片),通过 bytes_list 参数传入一个由字符串数组初始化的 tf.train.BytesList 对象;tf.train.FloatList :浮点数,通过 float_list 参数传入一个由浮点数数组初始化的 tf.train.FloatList 对象;tf.train.Int64List :整数,通过 int64_list 参数传入一个由整数数组初始化的 tf.train.Int64List 对象。

如果只希望保存一个元素而非数组,传入一个只有一个元素的数组即可

4、生成TFRecord格式数据

import osimport tensorflow as tf

# 读取数据集中图片文件名和标签def read_image_filenames (data_dir) :cat_dir = data_dir + "cat/"dog_dir = data_dir + "dog/"cat_filenames = [cat_dir + fn for fn in os.listdir(cat_dir)]dog_filenames = [dog_dir + fn for fn in os.listdir(dog_dir)]filenames = cat_filenames + dog_filenames# 将cat类的标签设为0, dog类的标签设为1labels = [0]* len(cat_filenames) + [1] *len(dog_filenames)return filenames,labels

# 定义生成TFRecord格式数据文件函数def write_TFRecord_file(filenames,labels,tfrecord_file):with tf.io.TFRecordWriter(tfrecord_file) as writer:for filename,label in zip(filenames,labels) :# 读取数据集图片到内存,image 为一个 Byte类型的字符串image = open(filename,"rb").read()# 建立tf.train.Feature字典feature = {# 图片是一个Bytes对象'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),# 标签是一个Int对象'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}# 通过feature字典建立Exampleexample = tf.train.Example(features=tf.train.Features(feature=feature))# 将Example序列化并写入TFRecord 文件writer.write(example.SerializeToString())

train_data_dir = './data_small/train/' # 数据集路径tfrecord_file = train_data_dir + 'train.tfrecords' # 生成的tfrecord路径 if not os.path.isfile(tfrecord_file): # 判断train.tfrecord是否存在train_filenames,train_labels = read_image_filenames(train_data_dir)write_TFRecord_file(train_filenames,train_labels,tfrecord_file)print('write TFRecord file:',tfrecord_file)else:print(tfrecord_file,'already exists.')

5、TFRecord数据文件解码

1、定义TFRecord数据文件解码函数

# 定义Feature结构,告诉解码器每个Feature的类型是什么,要与生成的TFrecord的类型一致feature_description = {"image":tf.io.FixedLenFeature([],tf.string),"label":tf.io.FixedLenFeature([],tf.int64)}# 将TFRecord 文件中的每一个序列化的 tf.train.Example 解码def parse_example(example_string):feature_dict = tf.io.parse_single_example(example_string,feature_description)feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image']) # 解码JPEG图片feature_dict['image'] = tf.image.resize(feature_dict['image'],[224,224])/ 255.0 # 改变图片尺寸并进行归一化return feature_dict['image'],feature_dict['label']

2、定义读取TFRecord文件,解码并生成Dataset数据集的函数

def read_TFRecond_file(tfrecord_file):# 读取TFRecord 文件raw_dataset = tf.data.TFRecordDataset(tfrecord_file)# 解码dataset = raw_dataset.map(parse_example)return dataset

3、tf.data.TFRecordDataset

tfrecord文件创建一个TFRecordDataset类的实例对象

参数:tf.data.TFRecordDataset(filenames,compression_type=None,

buffer_size=None,num_parallel_reads=None)

一般只传第一个参数filenames即可 ,生成的tfrecord文件

6、解码并生成Dataset数据集

# Dataset的数据缓冲器大小,和数据集大小及规律有关buffer_size = 20000# Dataset的数据批次大小,每批次多少个样本数batch_size = 8

dataset_train = read_TFRecond_file(tfrecord_file) # 解码dataset_train = dataset_train.shuffle(buffer_size) # 打乱数据dataset_train = dataset_train.batch(batch_size) # 分批次进行读取

7、查看第一批元素

import matplotlib.pyplot as pltsub_dataset = dataset_train.take(1) # 读取第一个批次for images,labels in sub_dataset:fig,axs = plt.subplots(1, batch_size)for i in range(batch_size):axs[i].set_title(labels.numpy()[i])axs[i].imshow(images.numpy()[i])axs[i].set_xticks([])axs[i].set_yticks([])plt.show()

案例实例地址:Tfrecord介绍以及实例· GitHub

链接:猫狗大战数据集

提取码:kqgt

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。