1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > Tensorflow2.0实现对抗生成网络(GAN)

Tensorflow2.0实现对抗生成网络(GAN)

时间:2021-02-14 15:33:09

相关推荐

Tensorflow2.0实现对抗生成网络(GAN)

在这篇文章中,我们使用Tensorflow2.0来实现GAN,使用的数据集是手写数字数据集。

引入需要的库

import tensorflow as tffrom tensorflow import kerasfrom tensorflow.keras import layersimport matplotlib.pyplot as plt%matplotlib inline

导入数据,归一化数据

(train_images, train_labels), (_, _) = keras.datasets.mnist.load_data()train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')train_images = (train_images-127.5)/127.5BATCH_SIZE = 256BUFFER_SIZE = 60000datasets = tf.data.Dataset.from_tensor_slices(train_images)datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

建立生成器

def generator_model(): # 用100个随机数(噪音)生成手写数据集model = keras.Sequential()model.add(layers.Dense(256, input_shape=(100,), use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(512, use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(28*28*1, use_bias=False, activation='tanh'))model.add(layers.BatchNormalization())model.add(layers.Reshape((28, 28, 1)))return model

建立判别器

def discriminator_model(): # 识别输入的图片model = keras.Sequential()model.add(layers.Flatten())model.add(layers.Dense(512, use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(256, use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(1))return model

分别定义判别器和生成器的损失函数

对于判别器来说,我们需要将导入的原始图片识别为真(1),将生成器胜场的图像识别为假(0)。

对于生成器来说,我们需要使得生成的图片无限接近于真实图片。

cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)def discriminator_loss(real_out, fake_out):real_loss = cross_entropy(tf.ones_like(real_out), real_out)fake_loss = cross_entropy(tf.zeros_like(fake_out), fake_out)return real_loss + fake_lossdef generator_loss(fake_out):return cross_entropy(tf.ones_like(fake_out), fake_out)

在以上代码中,real_out是指向判别器输入原始图像得到的结果;fake_out是指向判别器输入生成图像得到的结果。

所以对于判别器的损失函数来说,real_out应该无限接近于1;fake_out应该无限接近于0。即我们想训练出的判别器应该对图片有很高的识别能力。

但对于生成器的损失函数来说,fake_out应该无限接近于1,也就是令判别器很难分辨出生成的图片。

【注】keras.losses.BinaryCrossentropy(from_logits=True)的用法可以参考:tensorflow2.0中损失函数tf.keras.losses.BinaryCrossentropy()的用法。

分别定义生成器和判别器的优化函数

generator_opt = keras.optimizers.Adam(1e-4)discriminator_opt = keras.optimizers.Adam(1e-4)

实例化生成器和判别器

generator = generator_model()discriminator = discriminator_model()

定义训练过程

noise_dim = 100 # 即用100个随机数生成图片def train_step(images):noise = tf.random.normal([BATCH_SIZE, noise_dim])with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:real_out = discriminator(images, training=True)gen_image = generator(noise, training=True)fake_out = discriminator(gen_image, training=True)gen_loss = generator_loss(fake_out)disc_loss = discriminator_loss(real_out, fake_out)gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)gradient_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)generator_opt.apply_gradients(zip(gradient_gen, generator.trainable_variables))discriminator_opt.apply_gradients(zip(gradient_disc, discriminator.trainable_variables))

gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)表示计算gen_loss对于generator的所有变量的梯度。

generator_opt.apply_gradients(zip(gradient_gen, generator.trainable_variables))表示根据gradient_gen来优化generator的变量。

【注】梯度带及梯度更新的用法参考:Tensorflow中的梯度带(GradientTape)以及梯度更新。

定义绘图函数

def generate_plot_image(gen_model, test_noise):pre_images = gen_model(test_noise, training=False)fig = plt.figure(figsize=(4, 4))for i in range(pre_images.shape[0]):plt.subplot(4, 4, i+1)plt.imshow((pre_images[i, :, :, 0] + 1)/2, cmap='gray')plt.axis('off')plt.show()

plt.imshow((pre_images[i, :, :, 0] + 1)/2, cmap=‘gray’)

这里是因为我们使用tanh激活函数之后会将结果限制在-1到1之间,而我们需要将其转化到0到1之间。

定义训练函数

EPOCHS = 100 # 训练100次num_exp_to_generate = 16 # 生成16张图片seed = tf.random.normal([num_exp_to_generate, noise_dim]) # 16组随机数组,每组含100个随机数,用来生成16张图片。def train(dataset, epochs):for epoch in range(epochs):for image_batch in dataset:train_step(image_batch)print('.', end='')generate_plot_image(generator, seed)train(datasets, EPOCHS)

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。