1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > 使用生成对抗网络(GAN)实现对图像的生成

使用生成对抗网络(GAN)实现对图像的生成

时间:2024-05-29 13:11:27

相关推荐

使用生成对抗网络(GAN)实现对图像的生成

目录

前言

一、GAN模型简介

二、Fashion MNIST数据集简介

三、算法实现

1.导入必要的库

2.下载并展示数据集

3.数据的预处理

4.定义生成器

5.定义判别器

6.构建模型

7.训练模型

四、总结

参考资料:

前言

生成对抗网络(GAN)是一种无监督学习模型,它可以生成与真实数据相似的假数据,其应用非常广泛。本文基于python,使用生成对抗网络(GAN模型)对Fashion MNIST数据集中的图像,进行了生成。

一、GAN模型简介

GAN的英文全称为:GenerativeAdversarialNetworks,这是一种生成模型,它由Goodfellow等人于提出。

GAN由两个神经网络组成:生成器(G)和判别器(D)。生成器用于生成假数据;判别器用于判断数据的真假。两个网络相互对抗又彼此促进,生成器生成的假数据越来越逼真,而判别器的判断能力也越来越强。最终,生成器生成的假数据足以骗过判别器,达到了生成真实数据的目的。就像在草原上,狮子为了生存,需要捕捉到斑马,就要跑得比斑马更快;而斑马为了生存,需要逃避狮子的追捕,就要跑得比狮子更快,所以狮子和斑马都会跑得越来越快。

二、Fashion MNIST数据集简介

Fashion-MNIST是一个服装分类数据集,有如下表所示的10个类别,每个类别都包含训练集(6k个图像)和测试集(1k个图像),故训练集与测试集的图像分别共有6万张和1万张。

三、算法实现

1.导入必要的库

import numpy as npimport matplotlib.pyplot as pltimport tensorflow as tffrom tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropoutfrom tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2Dfrom tensorflow.keras.layers import LeakyReLUfrom tensorflow.keras.models import Sequential, Modelfrom tensorflow.keras.optimizers import Adam

2.下载并展示数据集

# 下载数据集(X_train, y_train), (_, _) = tf.keras.datasets.fashion_mnist.load_data()# 定义类别的名字class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']# 创建图片fig, axes = plt.subplots(3, 3, figsize=(8, 8))axes = axes.ravel()# 随机选择9张图片for i in np.arange(0, 9):index = np.random.randint(0, len(X_train))axes[i].imshow(X_train[index], cmap='gray')axes[i].set_title(class_names[y_train[index]])axes[i].axis('off')plt.savefig("服装分类数据集示例.png")# 显示图片plt.show()

3.数据的预处理

对训练数据进行归一化处理,将像素值缩放到了[-1,1],并将图像的通道数从1变为3,以便与模型的输入形状匹配。

# 归一化数据X_train = X_train / 127.5 - 1.X_train = np.expand_dims(X_train, axis=3)

4.定义生成器

该模型输入一个形状为 (100,) 的噪声向量,并输出一个形状为 (28, 28, 1) 的图像。包含了四个全连接层,前三个全连接层后面都跟着一个斜率为 0.2 的 LeakyReLU 激活函数和一个批量归一化层,最后一个全连接层具有 tanh 激活函数,输出一个范围在 -1 到 1 之间的值(生成图像的像素值)。最终,输出的图像形状被重塑为 (28, 28, 1)。

def build_generator():model = Sequential()# 创建了一个序列模型model.add(Dense(256, input_dim=100))# 添加全连接层,输入维度为100,输出维度为256model.add(LeakyReLU(alpha=0.2))# 添加LeakyReLU激活函数层model.add(BatchNormalization(momentum=0.8))# 添加批量归一化层model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(784, activation='tanh'))model.add(Reshape((28, 28, 1)))noise = Input(shape=(100,))# 定义输入层,维度为100img = model(noise)# 生成图像return Model(noise, img)

5.定义判别器

该模型的输入是一个(28, 28, 1)的图像,输出一个在[0,1]区间的概率值。先将输入的图像在Flatten()层将其展平,然后通过三个全连接层,最后输出一个概率分数,来判别输入图像的真假。

def build_discriminator():model = Sequential()model.add(Flatten(input_shape=(28, 28, 1)))# 将输入的28*28*1的图像展平为一维向量model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))model.add(Dense(1, activation='sigmoid'))img = Input(shape=(28, 28, 1))validity = model(img)return Model(img, validity)

6.构建模型

生成器模型生成噪声,然后用于生成假图像。判别器模型随后对真实和假图像进行训练,以区分它们。组合模型用于训练生成器生成更逼真的图像。

# 构建生成器generator = build_generator()z = Input(shape=(latent_dim,))# 生成噪声img = generator(z)# 构建判别器discriminator = build_discriminator()pile(loss='binary_crossentropy',optimizer=Adam(0.0002, 0.5),metrics=['accuracy'])discriminator.trainable = False# 固定判别器的权重# 判别器判断真假valid = discriminator(img)# 构建组合模型combined = Model(z, valid)pile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

7.训练模型

定义以下四个超参数:

latent_dim = 100 # 噪声向量的维度epochs = 10001 # 训练的轮数。为后续方便显示第一个生成图像和最后一个生成图像,故在训练1w轮后,再训练了一轮batch_size = 128 # 每个训练批次的大小sam_inter = 1000 # 图像展示频率。即每隔多少轮训练,就展示一次生成器生成的图像。

训练过程大致分为以下三个部分:

①从FashionMNIST数据集中随机选择一批真实数据,生成一批噪声向量,用生成器生成一 批假数据。

②判别器分别判断这些真实数据和假数据的真假,并计算出它们的损失值。

③根据损失值更新判别器和生成器的权重。

这个过程不断重复,直到达到指定的训练轮数(epoch)。

for epoch in range(epochs):'''训练判别器'''# 随机选择一批真实图片idx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]# 生成一批假图片noise = np.random.normal(0, 1, (batch_size, latent_dim))gen_imgs = generator.predict(noise)# 训练判别器d_loss_real = discriminator.train_on_batch(imgs, np.ones((batch_size, 1)))# 真照片的损失值d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((batch_size, 1)))# 假照片的损失值d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 取平均值,其将被用作反向传播的损失值,用于更新判别器的权重。'''训练生成器'''# 生成一批噪声noise = np.random.normal(0, 1, (batch_size, latent_dim))# 训练生成器g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))# 打印损失#print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))'''展示生成的图片'''if epoch % sam_inter == 0:# 每1000轮展示一次r, c = 3, 3noise = np.random.normal(0, 1, (r * c, latent_dim))gen_imgs = generator.predict(noise)# 将图片像素值调整到0-1之间gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1# 保存部分最终生成的图片if epoch == epochs-1:plt.savefig("最终生成效果.png")plt.show()

将此生成图与前面所展示的样本图片比较,可以发现部分图片已经不易通过肉眼识别出,其为真图片还是假图片了。例如:前面示例的真图片最中间那个Shirt,与此生成图的最右边中间的Shirt。

四、总结

通过上述例子,我们可以发现在仅通过1W轮的训练,所生成的图片,就已经与真实的图片十分相似。倘若经过1亿轮呢?估计已与真实图片别无二致了吧。可以预见,AI绘图定会引发众多行业的变革。

参考资料:

[1406.2661] 生成对抗网络 ()/abs/1406.2661

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。