1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > 【万物皆可 GAN】生成对抗网络生成手写数字 Part 1

【万物皆可 GAN】生成对抗网络生成手写数字 Part 1

时间:2019-04-16 07:34:12

相关推荐

【万物皆可 GAN】生成对抗网络生成手写数字 Part 1

【万物皆可 GAN】生成对抗网络生成手写数字 Part 1

概述GAN 网络结构GAN 训练流程模型详解生成器判别器

概述

GAN (Generative Adversarial Network) 即生成对抗网络. GAN 网络包括一个生成器 (Generator) 和一个判别器 (Discriminator). GAN 可以自动提取特征, 并判断和优化.

GAN 网络结构

生成器 (Generator) 输入一个向量, 输出手写数字大小的像素图像.

判别器 (Discriminator) 输入图片, 判断图片是来自数据集还是来自生成器的, 输出标签 (Real / Fake)

GAN 训练流程

第一阶段:

固定判别器, 训练生成器: 使得生成器的技能不断提升, 骗过判别器

第二阶段:

固定生成器, 训练判别器: 使得判别器的技能不断提升, 生成器无法骗过判别器

然后:

循环第一阶段和第二阶段, 使得生成器和判别器都越来越强

模型详解

生成器

class Generator(nn.Module):"""生成器"""def __init__(self, latent_dim, img_shape):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):"""block:param in_feat: 输入的特征维度:param out_feat: 输出的特征维度:param normalize: 归一化:return: block"""layers = [nn.Linear(in_feat, out_feat)]# 归一化if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))# 激活layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(# [b, 100] => [b, 128]*block(latent_dim, 128, normalize=False),# [b, 128] => [b, 256]*block(128, 256),# [b, 256] => [b, 512]*block(256, 512),# [b, 512] => [b, 1024]*block(512, 1024),# [b, 1024] => [b, 28 * 28 * 1] => [b, 784]nn.Linear(1024, int(np.prod(img_shape))),# 激活nn.Tanh())def forward(self, z, img_shape):# [b, 100] => [b, 784]img = self.model(z)# [b, 784] => [b, 1, 28, 28]img = img.view(img.size(0), *img_shape)# 返回生成的图片return img

网络结构:

----------------------------------------------------------------Layer (type)Output Shape Param #================================================================Linear-1 [-1, 128]12,928LeakyReLU-2 [-1, 128]0Linear-3 [-1, 256]33,024BatchNorm1d-4 [-1, 256] 512LeakyReLU-5 [-1, 256]0Linear-6 [-1, 512] 131,584BatchNorm1d-7 [-1, 512] 1,024LeakyReLU-8 [-1, 512]0Linear-9 [-1, 1024] 525,312BatchNorm1d-10 [-1, 1024] 2,048LeakyReLU-11 [-1, 1024]0Linear-12 [-1, 784] 803,600Tanh-13 [-1, 784]0================================================================Total params: 1,510,032Trainable params: 1,510,032Non-trainable params: 0----------------------------------------------------------------Input size (MB): 0.00Forward/backward pass size (MB): 0.05Params size (MB): 5.76Estimated Total Size (MB): 5.82----------------------------------------------------------------

判别器

class Discriminator(nn.Module):"""判断器"""def __init__(self, img_shape):super(Discriminator, self).__init__()self.model = nn.Sequential(# 就是个线性回归nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):# 压平img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity

网络结构:

----------------------------------------------------------------Layer (type)Output Shape Param #================================================================Linear-1 [-1, 512] 401,920LeakyReLU-2 [-1, 512]0Linear-3 [-1, 256] 131,328LeakyReLU-4 [-1, 256]0Linear-5[-1, 1] 257Sigmoid-6[-1, 1]0================================================================Total params: 533,505Trainable params: 533,505Non-trainable params: 0----------------------------------------------------------------Input size (MB): 0.00Forward/backward pass size (MB): 0.01Params size (MB): 2.04Estimated Total Size (MB): 2.05

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。