
TensorFlow 2.0: Storing and Reading a Dataset with TFRecord and Training in Batches

Date: 2020-06-01 00:19:01

Contents

Imports

1. Reading the dataset (image filenames)

2. Writing the dataset images and labels to a TFRecord file

3. Reading the dataset back from the TFRecord file

4. Building the model

5. Training the model

Experimental results

Imports

import tensorflow as tf
import os

1. Reading the dataset (image filenames)

data_dir = "D:/dataset/cats_and_dogs_filtered"
train_cat_dir = data_dir + "/train/cats/"
train_dog_dir = data_dir + "/train/dogs/"
train_tfrecord_file = data_dir + "/train/train.tfrecords"
test_cat_dir = data_dir + "/validation/cats/"
test_dog_dir = data_dir + "/validation/dogs/"
test_tfrecord_file = data_dir + "/validation/test.tfrecords"

# Build the full file paths; cats are labeled 0 and dogs 1.
train_cat_filenames = [train_cat_dir + filename for filename in os.listdir(train_cat_dir)]
train_dog_filenames = [train_dog_dir + filename for filename in os.listdir(train_dog_dir)]
train_filenames = train_cat_filenames + train_dog_filenames
train_labels = [0] * len(train_cat_filenames) + [1] * len(train_dog_filenames)

test_cat_filenames = [test_cat_dir + filename for filename in os.listdir(test_cat_dir)]
test_dog_filenames = [test_dog_dir + filename for filename in os.listdir(test_dog_dir)]
test_filenames = test_cat_filenames + test_dog_filenames
test_labels = [0] * len(test_cat_filenames) + [1] * len(test_dog_filenames)
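
Before writing anything to disk, a quick sanity check (not part of the original code) can confirm that the filename lists and label lists line up:

# Illustrative check only: filenames and labels must stay aligned,
# and both classes (0 = cat, 1 = dog) should appear in each split.
assert len(train_filenames) == len(train_labels)
assert len(test_filenames) == len(test_labels)
assert set(train_labels) == {0, 1}
print("train samples:", len(train_filenames), "test samples:", len(test_filenames))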

2. Writing the dataset images and labels to a TFRecord file

def encoder(filenames, labels, tfrecord_file):
    with tf.io.TFRecordWriter(tfrecord_file) as writer:
        for filename, label in zip(filenames, labels):
            image = open(filename, 'rb').read()
            # build the feature dict
            feature = {
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }
            # create an Example from the feature dict
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            # serialize the Example and write it to the TFRecord file
            writer.write(example.SerializeToString())

encoder(train_filenames, train_labels, train_tfrecord_file)
encoder(test_filenames, test_labels, test_tfrecord_file)
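
If you want to verify what was written (an optional step, not in the original article), you can read the first serialized record back and inspect its raw features through the tf.train.Example protocol buffer:

# Optional verification sketch: parse the first record of the new file
# and print the stored label and the size of the encoded image bytes.
raw_dataset = tf.data.TFRecordDataset(train_tfrecord_file)
for raw_record in raw_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    features = example.features.feature
    print("label:", features['label'].int64_list.value[0])
    print("encoded image bytes:", len(features['image'].bytes_list.value[0]))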

3. Reading the dataset back from the TFRecord file

def decoder(tfrecord_file, is_train_dataset=None):
    dataset = tf.data.TFRecordDataset(tfrecord_file)
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64)
    }

    def _parse_example(example_string):  # decode each serialized Example
        feature_dic = tf.io.parse_single_example(example_string, feature_description)
        feature_dic['image'] = tf.io.decode_jpeg(feature_dic['image'])
        feature_dic['image'] = tf.image.resize(feature_dic['image'], [256, 256]) / 255.0
        return feature_dic['image'], feature_dic['label']

    batch_size = 32
    if is_train_dataset is not None:
        # training set: shuffle, batch and prefetch
        dataset = dataset.map(_parse_example).shuffle(buffer_size=2000).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    else:
        dataset = dataset.map(_parse_example)
        dataset = dataset.batch(batch_size)
    return dataset

train_data = decoder(train_tfrecord_file, 1)
test_data = decoder(test_tfrecord_file)
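
As a quick check (illustrative, not from the original code), pulling one batch from the pipeline shows the shapes the model in the next section will receive:

# Take a single batch and print its shapes; a full training batch
# should be (32, 256, 256, 3) for images and (32,) for labels.
for images, labels in train_data.take(1):
    print(images.shape, labels.shape)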

4. Building the model

class CNNModel(tf.keras.models.Model):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(12, 3, activation='relu')
        self.maxpool1 = tf.keras.layers.MaxPooling2D()
        self.conv2 = tf.keras.layers.Conv2D(12, 5, activation='relu')
        self.maxpool2 = tf.keras.layers.MaxPooling2D()
        self.flatten = tf.keras.layers.Flatten()
        self.d1 = tf.keras.layers.Dense(64, activation='relu')
        self.d2 = tf.keras.layers.Dense(2, activation='softmax')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.flatten(x)
        x = self.d1(x)
        x = self.d2(x)
        return x
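
A forward pass on a dummy batch (an optional check, not in the original article) confirms the model builds and produces a two-class probability vector per image:

# Illustrative check: build the model with one all-zeros image
# and confirm the output shape is (batch_size, 2).
check_model = CNNModel()
dummy_images = tf.zeros([1, 256, 256, 3])
print(check_model(dummy_images).shape)  # expected: (1, 2)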

5. Training the model

def train_CNNModel():
    model = CNNModel()
    loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adam(0.001)
    train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_acc')
    test_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='test_acc')

    @tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            logits = model(images)
            loss = loss_obj(labels, logits)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_acc(labels, logits)

    @tf.function
    def test_step(images, labels):
        logits = model(images)
        test_acc(labels, logits)

    Epochs = 5
    for epoch in range(Epochs):
        train_acc.reset_states()
        test_acc.reset_states()
        for images, labels in train_data:
            train_step(images, labels)
        for images, labels in test_data:
            test_step(images, labels)
        tmp = 'Epoch {}, Acc {}, Test Acc {}'
        print(tmp.format(epoch + 1,
                         train_acc.result() * 100,
                         test_acc.result() * 100))

train_CNNModel()
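
To run the trained model on a new image, the same preprocessing used in _parse_example has to be applied. The sketch below assumes two small changes to the code above, namely that train_CNNModel() ends with return model and that its final call is captured in a variable; "my_image.jpg" is a placeholder path.

# Inference sketch (assumes train_CNNModel is modified to `return model`;
# "my_image.jpg" is a hypothetical file path).
trained_model = train_CNNModel()
image = tf.io.decode_jpeg(tf.io.read_file("my_image.jpg"))
image = tf.image.resize(image, [256, 256]) / 255.0      # same preprocessing as _parse_example
probs = trained_model(tf.expand_dims(image, axis=0))    # add a batch dimension
print("cat:", float(probs[0][0]), "dog:", float(probs[0][1]))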

Experimental results

Epoch 1, Acc 51.45000076293945, Test Acc 51.70000076293945
Epoch 2, Acc 60.650001525878906, Test Acc 58.099998474121094
Epoch 3, Acc 70.5, Test Acc 63.30000305175781
Epoch 4, Acc 78.05000305175781, Test Acc 69.30000305175781
Epoch 5, Acc 87.4000015258789, Test Acc 69.19999694824219
