目录
头文件
一、读取数据集(图片名)
二、将数据集图片、标签写入TFRecord
三、从TFRecord中读取数据集
四、构建模型
五、训练模型
实验结果
头文件
import os

import tensorflow as tf
一、读取数据集(图片名)
# Root of the "cats_and_dogs_filtered" dataset; adjust to your machine.
data_dir = "D:/dataset/cats_and_dogs_filtered"

train_cat_dir = data_dir + "/train/cats/"
train_dog_dir = data_dir + "/train/dogs/"
train_tfrecord_file = data_dir + "/train/train.tfrecords"

test_cat_dir = data_dir + "/validation/cats/"
test_dog_dir = data_dir + "/validation/dogs/"
test_tfrecord_file = data_dir + "/validation/test.tfrecords"

# Only JPEG files are collected: the images are later decoded with
# tf.io.decode_jpeg, so stray non-image files (e.g. Thumbs.db) would fail there.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg")


def _list_images(directory):
    """Return the full paths of the JPEG files directly inside *directory*.

    The listing is sorted because os.listdir returns entries in arbitrary
    order; sorting makes the generated TFRecord files reproducible.
    """
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]


# Label convention used throughout: 0 = cat, 1 = dog.
train_cat_filenames = _list_images(train_cat_dir)
train_dog_filenames = _list_images(train_dog_dir)
train_filenames = train_cat_filenames + train_dog_filenames
train_labels = [0] * len(train_cat_filenames) + [1] * len(train_dog_filenames)

test_cat_filenames = _list_images(test_cat_dir)
test_dog_filenames = _list_images(test_dog_dir)
test_filenames = test_cat_filenames + test_dog_filenames
test_labels = [0] * len(test_cat_filenames) + [1] * len(test_dog_filenames)
二、将数据集图片、标签写入TFRecord
def encoder(filenames, labels, tfrecord_file):
    """Serialize image files and their labels into a single TFRecord file.

    Each record is a ``tf.train.Example`` with two features:
      - ``'image'``: the raw (still JPEG-encoded) bytes of the image file;
      - ``'label'``: the integer class label.

    Args:
        filenames: paths of the image files to store.
        labels: integer labels, aligned one-to-one with ``filenames``.
        tfrecord_file: path of the TFRecord file to (over)write.

    Raises:
        ValueError: if ``filenames`` and ``labels`` have different lengths
            (``zip`` would otherwise silently drop the unmatched tail).
    """
    if len(filenames) != len(labels):
        raise ValueError(
            'got {} filenames but {} labels'.format(len(filenames), len(labels))
        )
    with tf.io.TFRecordWriter(tfrecord_file) as writer:
        for filename, label in zip(filenames, labels):
            # Read inside a context manager so each file handle is closed
            # immediately rather than whenever the garbage collector runs.
            with open(filename, 'rb') as image_file:
                image = image_file.read()
            # Feature name -> tf.train.Feature.
            feature = {
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            }
            # Build an Example from the feature dict.
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            # Serialize the Example and append it to the TFRecord file.
            writer.write(example.SerializeToString())


encoder(train_filenames, train_labels, train_tfrecord_file)
encoder(test_filenames, test_labels, test_tfrecord_file)
三、从TFRecord中读取数据集
def decoder(tfrecord_file, is_train_dataset=None, batch_size=32, image_size=(256, 256)):
    """Build a batched ``tf.data`` pipeline from a TFRecord file written by ``encoder``.

    Each record is parsed into ``(image, label)``. ``image`` is a float32 RGB
    tensor of shape ``(image_size[0], image_size[1], 3)`` scaled to [0, 1].
    ``label`` is an int64 scalar.

    Args:
        tfrecord_file: path (or list of paths) of the TFRecord file(s).
        is_train_dataset: truthy for the training set, whose records are
            shuffled. Falsy (default ``None``) keeps the file order.
        batch_size: number of examples per batch.
        image_size: ``(height, width)`` that every image is resized to.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
    """
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }

    def _parse_example(example_string):
        """Decode one serialized Example into an (image, label) pair."""
        feature_dict = tf.io.parse_single_example(example_string, feature_description)
        # channels=3 forces RGB output. Without it, grayscale JPEGs decode to a
        # single channel and the channel dimension is statically unknown, which
        # breaks batching and Conv2D (it needs a defined channel dimension).
        image = tf.io.decode_jpeg(feature_dict['image'], channels=3)
        image = tf.image.resize(image, list(image_size)) / 255.0
        return image, feature_dict['label']

    autotune = tf.data.experimental.AUTOTUNE
    dataset = tf.data.TFRecordDataset(tfrecord_file)
    if is_train_dataset:
        # Shuffle the small serialized records *before* decoding. Shuffling
        # decoded float images would keep buffer_size * H * W * 3 floats in
        # memory (about 1.5 GB for 2000 images at 256x256).
        dataset = dataset.shuffle(buffer_size=2000)
    dataset = dataset.map(_parse_example, num_parallel_calls=autotune)
    return dataset.batch(batch_size).prefetch(autotune)


train_data = decoder(train_tfrecord_file, 1)
test_data = decoder(test_tfrecord_file)
四、构建模型
class CNNModel(tf.keras.models.Model):
    """Small convolutional classifier with a two-class softmax head.

    Layer stack: Conv2D(12 filters, 3x3, ReLU) -> MaxPooling2D ->
    Conv2D(12 filters, 5x5, ReLU) -> MaxPooling2D -> Flatten ->
    Dense(64, ReLU) -> Dense(2, softmax).
    The output is a per-class probability vector, not logits.
    """

    def __init__(self):
        super().__init__()
        # Creation order determines Keras' auto-generated layer names, so the
        # layers are built in a fixed order matching the forward pass.
        self.conv1 = tf.keras.layers.Conv2D(filters=12, kernel_size=3, activation='relu')
        self.maxpool1 = tf.keras.layers.MaxPooling2D()
        self.conv2 = tf.keras.layers.Conv2D(filters=12, kernel_size=5, activation='relu')
        self.maxpool2 = tf.keras.layers.MaxPooling2D()
        self.flatten = tf.keras.layers.Flatten()
        self.d1 = tf.keras.layers.Dense(units=64, activation='relu')
        self.d2 = tf.keras.layers.Dense(units=2, activation='softmax')

    def call(self, inputs):
        """Run the forward pass and return softmax class probabilities."""
        pipeline = (
            self.conv1, self.maxpool1,
            self.conv2, self.maxpool2,
            self.flatten,
            self.d1, self.d2,
        )
        outputs = inputs
        for layer in pipeline:
            outputs = layer(outputs)
        return outputs
五、训练模型
def train_CNNModel(epochs=5, learning_rate=0.001, train_dataset=None, test_dataset=None):
    """Train a fresh ``CNNModel`` and print train/test accuracy after each epoch.

    Args:
        epochs: number of passes over the training data.
        learning_rate: learning rate for the Adam optimizer.
        train_dataset: batched ``(images, labels)`` dataset to train on;
            defaults to the module-level ``train_data``.
        test_dataset: batched ``(images, labels)`` dataset to evaluate on;
            defaults to the module-level ``test_data``.

    Returns:
        The trained ``CNNModel`` instance.
    """
    if train_dataset is None:
        train_dataset = train_data
    if test_dataset is None:
        test_dataset = test_data

    model = CNNModel()
    # The model ends in a softmax, so the loss consumes probabilities
    # (from_logits=False, the default).
    loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_acc')
    test_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='test_acc')

    @tf.function
    def train_step(images, labels):
        """Apply one gradient update on a batch and accumulate train accuracy."""
        with tf.GradientTape() as tape:
            probs = model(images)
            loss = loss_obj(labels, probs)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_acc(labels, probs)

    @tf.function
    def test_step(images, labels):
        """Accumulate test accuracy on a batch (no weight updates)."""
        probs = model(images)
        test_acc(labels, probs)

    for epoch in range(epochs):
        # Metrics accumulate across calls; clear them so each epoch reports
        # only its own accuracy.
        train_acc.reset_states()
        test_acc.reset_states()
        for images, labels in train_dataset:
            train_step(images, labels)
        for images, labels in test_dataset:
            test_step(images, labels)
        tmp = 'Epoch {}, Acc {}, Test Acc {}'
        print(tmp.format(epoch + 1,
                         train_acc.result() * 100,
                         test_acc.result() * 100))
    return model


train_CNNModel()
实验结果
Epoch 1, Acc 51.45000076293945, Test Acc 51.70000076293945
Epoch 2, Acc 60.650001525878906, Test Acc 58.099998474121094
Epoch 3, Acc 70.5, Test Acc 63.30000305175781
Epoch 4, Acc 78.05000305175781, Test Acc 69.30000305175781
Epoch 5, Acc 87.4000015258789, Test Acc 69.19999694824219