1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > 【DCGAN】生成对抗网络 手写数字识别

【DCGAN】生成对抗网络 手写数字识别

时间:2020-10-04 17:04:06

相关推荐

【DCGAN】生成对抗网络 手写数字识别

基于paddle,aistudio的DCGAN

主要用于记录自己学习经历。

1 导入必要的包

import osimport randomimport paddleimport paddle.nn as nnimport paddle.optimizer as optimimport paddle.vision.datasets as dsetimport paddle.vision.transforms as transformsimport numpy as npimport matplotlib.pyplot as pltimport matplotlib.animation as animation

2 定义数据集

demo_dataset = paddle.vision.datasets.MNIST(mode='train')

3 查看数据集

demo_dataset[5][0]

4 查看数据集维度

for data in dataloader:breakdata[0].shape

5 参数初始化模块

@paddle.no_grad()def normal_(x, mean=0., std=1.):temp_value = paddle.normal(mean, std, shape=x.shape) # 该op返回符合正态分布(均值为mean,标准差为std的正态随机分布)的随机Tensor。x.set_value(temp_value)return x@paddle.no_grad()def uniform_(x, a=-1., b=1.):temp_value = paddle.uniform(min=a, max=b, shape=x.shape)# 该op返回值服从范围[min,max]内均值分布的随机Tensor,性状为shape,数据类型为dtypex.set_value(temp_value)return x@paddle.no_grad()def constant_(x, value):temp_value = paddle.full(x.shape, value, x.dtype)# 该op创造形状大小为shape并且数据类型为dtype的Tensor,其中元素值均为fill_value。x.set_value(temp_value)return xdef weights_init(m):classname = m.__class__.__name__if hasattr(m, 'weight') and classname.find('Conv') != -1:normal_(m.weight, 0.0, 0.02)elif classname.find('BatchNorm') != -1:normal_(m.weight, 1.0, 0.02)constant_(m.bias, 0)

6 生成器代码

# Layer 给予OOD实现的动态图Layer,包含该Layer的参数,前序运动的结构等信息

class Generator(nn.Layer):def __init__(self, ):super(Generator, self).__init__()# 顺序容器。子Layer将按构造函数参数的顺序添加到此容器中。# 传递给构造函数的参数可以Layers或可迭代的name Layer元组。self.gen = nn.Sequential(# input is Z, [B, 100, 1, 1] -> [B, 64 * 4, 4, 4]nn.Conv2DTranspose(100, 64 * 4, 4, 1, 0, bias_attr=False),# 二维转置神经层# 改层根据输入(input),卷积核(kernel)和空洞大小(dilations),步长(stride)# 填充(padding)来计算输出特征层大小或者通过output_size指定输出特征层大小。# 输入(Input)和输出(Output)为NCHW或NHWC格式,其中N是批尺寸,C为通道数(channel)# H为特征层高度,W为特征层宽度。卷积核是MCHW格式,M是输出图像通道数,# C是输入图像通道数,H是卷积核高度,W是卷积核宽度。# 如果组数大于1,C等于输入图像通道数除以组数的结果。# 转置卷积的计算过程相当于卷积的反向计算。# 转置卷积又被称为反卷积(但其实并不是真正的反卷积)。nn.BatchNorm2D(64 * 4),# 该接口用于构建 BatchNorm2D 类的一个可调用对象。# 可以处理4D的Tensor, 实现了批归一化层(Batch Normalization Layer)的功能,# 可用作卷积和全连接操作的批归一化函数,# 根据当前批次数据按通道计算的均值和方差进行归一化。nn.ReLU(True),# state size. [B, 64 * 4, 4, 4] -> [B, 64 * 2, 8, 8]nn.Conv2DTranspose(64 * 4, 64 * 2, 4, 2, 1, bias_attr=False),nn.BatchNorm2D(64 * 2),nn.ReLU(True),# state size. [B, 64 * 2, 8, 8] -> [B, 64, 16, 16]nn.Conv2DTranspose( 64 * 2, 64, 4, 2, 1, bias_attr=False),nn.BatchNorm2D(64),nn.ReLU(True),# state size. [B, 64, 16, 16] -> [B, 1, 32, 32]nn.Conv2DTranspose( 64, 1, 4, 2, 1, bias_attr=False),nn.Tanh()# Tanh激活层)def forward(self, x):return self.gen(x)netG = Generator()# Apply the weights_init function to randomly initialize all weights# to mean=0, stdev=G.apply(weights_init)# 用来对模型的参数进行初始化 将netG中参数都过weights_init进行初始化# Print the modelprint(netG)

7 判别器代码

class Discriminator(nn.Layer):def __init__(self,):super(Discriminator, self).__init__()self.dis = nn.Sequential(# input [B, 1, 32, 32] -> [B, 64, 16, 16]nn.Conv2D(1, 64, 4, 2, 1, bias_attr=False),nn.LeakyReLU(0.2),# state size. [B, 64, 16, 16] -> [B, 128, 8, 8]nn.Conv2D(64, 64 * 2, 4, 2, 1, bias_attr=False),nn.BatchNorm2D(64 * 2),nn.LeakyReLU(0.2),# state size. [B, 128, 8, 8] -> [B, 256, 4, 4]nn.Conv2D(64 * 2, 64 * 4, 4, 2, 1, bias_attr=False),nn.BatchNorm2D(64 * 4),nn.LeakyReLU(0.2),# state size. [B, 256, 4, 4] -> [B, 1, 1, 1] -> [B, 1]nn.Conv2D(64 * 4, 1, 4, 1, 0, bias_attr=False),nn.Sigmoid())def forward(self, x):return self.dis(x)netD = Discriminator()netD.apply(weights_init)print(netD)

8 二分类的交叉熵损失函数

loss = nn.BCELoss()# 该接口用于创建一个BCELoss的可调用类,# 用于计算输入input和标签label之间的二值交叉熵损失值fixed_noise = paddle.rand([32, 100, 1, 1],dtype='float32')real_label = 1.fake_label = 0.optimizerD = optim.Adam(parameters=netD.parameters(),learning_rate=0.0002,beta1=0.5,beta2=0.999)optimizerG = optim.Adam(parameters=netG.parameters(),learning_rate=0.0002,beta1=0.5,beta2=0.999)# Adam优化器出自 Adam论文 的第二节。#能够利用梯度的一阶矩估计和二阶矩估计动态调整每个参数的学习率。

9 进行训练,先训练生成器,后训练判别器

losses = [[],[]]now = 0for pass_id in range(10):for batch_id,(data,target) in enumerate (dataloader):# enumerate 一般用于for中 获取字典索引和元素值optimizerD.clear_grad()# 梯度清空 防止积累real_img = databs_size = real_img.shape[0]label = paddle.full((bs_size,1,1,1),real_label,dtype='float32')real_out = netD(real_img)errD_real = loss(real_out,label)errD_real.backward()noise = paddle.randn([bs_size,100,1,1],'float32')fake_img = netG(noise)label = paddle.full((bs_size,1,1,1),fake_label,dtype='float32')fake_out = netD(fake_img.detach())errD_fake = loss(fake_out,label)errD_fake.backward()optimizerD.step()optimizerD.clear_grad()errD = errD_real+errD_fakelosses[0].append(errD.numpy()[0])optimizerG.clear_grad()noise = paddle.randn([bs_size,100,1,1],'float32')fake = netG(noise)label = paddle.full((bs_size,1,1,1),real_label,dtype=np.float32,)output = netD(fake)errG = loss(output,label)errG.backward()optimizerG.step()optimizerG.clear_grad()losses[1].append(errG.numpy()[0])if batch_id % 100 == 0:generated_image = netG(noise).numpy()imgs = []plt.figure(figsize=(15,15))try:for i in range(10):image = generated_image[i].transpose()image = np.where(image > 0, image, 0)image = image.transpose((1,0,2))plt.subplot(10, 10, i + 1)plt.imshow(image[...,0], vmin=-1, vmax=1)plt.axis('off')plt.xticks([])plt.yticks([])plt.subplots_adjust(wspace=0.1, hspace=0.1)msg = 'Epoch ID={0} Batch ID={1} \n\n D-Loss={2} G-Loss={3}'.format(pass_id, batch_id, errD.numpy()[0], errG.numpy()[0])print(msg)plt.suptitle(msg,fontsize=20)plt.draw()plt.savefig('{}/{:04d}_{:04d}.png'.format('work', pass_id, batch_id), bbox_inches='tight')plt.pause(0.01)except IOError:print(IOError)paddle.save(netG.state_dict(), "work/generator.params")

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。