
Implementing a GAN (Generative Adversarial Network) in PyTorch, with Code

Date: 2021-12-02 05:16:16


GAN

Contents: Network Structure · Understanding the GAN Objective · Simple Linear GAN Code · Convolutional GAN Code · Ref

Network Structure

A GAN consists of two networks trained against each other: a generator $G$ that maps a random noise vector $z$ to a fake sample $G(z)$, and a discriminator $D$ that takes either a real sample $x$ or a generated sample $G(z)$ and outputs the probability that its input is real.

Understanding the GAN Objective

$\min_G \max_D V(D,G) = \mathbb{E}_{x\sim P_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim P_z(z)}[\log(1-D(G(z)))]$

Understanding the GAN objective is a necessary step toward understanding GANs themselves, so let's walk through the formula. We first define the two players: $D$ denotes the discriminator and $G$ the generator. The first job is to train the discriminator so that it can recognize real data, which gives the first half of the objective:

$\mathbb{E}_{x\sim P_{data}(x)}[\log D(x)]$

Here $\mathbb{E}_{x\sim P_{data}(x)}$ denotes an expectation over $x$ drawn from the distribution $P_{data}$; $x$ is a real sample and $P_{data}$ is the distribution of the real data.

This first term is the (log of the) score the discriminator assigns to real data, and the discriminator's goal is to maximize it. Put simply, for $x$ drawn from $P_{data}$, the discriminator should output $D(x) \approx 1$.
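To connect this term with the code later in the post, here is a tiny numeric sketch (the probabilities 0.9 and 0.8 are made-up stand-ins, not outputs of a trained model): binary cross-entropy with target label 1 is exactly $-\log D(x)$, so maximizing $\log D(x)$ on real data is the same as minimizing the real-data BCE term used in the training scripts below.

import torch
import torch.nn as nn

criterion = nn.BCELoss()
real_out = torch.tensor([0.9, 0.8])        # stand-in discriminator outputs D(x)
real_label = torch.ones_like(real_out)     # real samples are labeled 1
print(criterion(real_out, real_label))     # = -mean(log D(x))
print(-torch.log(real_out).mean())         # same value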

Now for the slightly more involved second half of the objective.

$\mathbb{E}_{z\sim P_z(z)}[\log(1-D(G(z)))]$

Here $\mathbb{E}_{z\sim P_z(z)}$ denotes an expectation over $z$ drawn from the distribution $P_z(z)$; $z$ is the noise vector fed into the generator, $P_z(z)$ is its distribution, and $G(z)$ is the generated sample.

For the discriminator $D$, when the input is generated data, i.e. the score is $D(G(z))$, its goal is to drive that value down: it wants $D(G(z)) \approx 0$, which is equivalent to maximizing $\log(1-D(G(z)))$.

The generator's goal is the opposite. It wants its samples to receive a high score from the discriminator, i.e. $D(G(z)) \approx 1$, which means minimizing $\log(1-D(G(z)))$. Note that the generator only influences the second half of the objective; the first term does not depend on $G$ at all.
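As another hedged numeric sketch (0.1 and 0.3 are made-up values of $D(G(z))$), both roles map onto binary cross-entropy with different target labels. One detail worth flagging: for the generator, the full scripts below use the widely used "non-saturating" trick, g_loss = criterion(D(G(z)), real_label), i.e. they maximize $\log D(G(z))$ rather than literally minimizing $\log(1-D(G(z)))$; both push $D(G(z))$ toward 1.

import torch
import torch.nn as nn

criterion = nn.BCELoss()
fake_out = torch.tensor([0.1, 0.3])        # stand-in values of D(G(z))
fake_label = torch.zeros_like(fake_out)
real_label = torch.ones_like(fake_out)

# Discriminator side: BCE with target 0 equals -log(1 - D(G(z))),
# so minimizing it is the same as maximizing log(1 - D(G(z))).
print(criterion(fake_out, fake_label))
print(-torch.log(1 - fake_out).mean())     # same value

# Generator side (non-saturating form used in the scripts below):
# BCE with target 1 equals -log D(G(z)), so minimizing it maximizes log D(G(z)).
print(criterion(fake_out, real_label))
print(-torch.log(fake_out).mean())         # same value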

At this point the value function $V(D,G) = \mathbb{E}_{x\sim P_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim P_z(z)}[\log(1-D(G(z)))]$ should make sense. But why does the GAN objective also carry the $\min_G \max_D$ in front?

To understand $\min_G \max_D$, recall how a GAN is trained. At the beginning, the generator $G$'s parameters are frozen and only the discriminator $D$ is trained. The formula expresses the same idea: first optimize with respect to the discriminator $D$, i.e. maximize $\log D(x)$ and $\log(1-D(G(z)))$, and thereby maximize $V(D,G)$:

$D_G^\star = \arg\max_D V(D,G)$

Once the discriminator $D$ has been trained, its parameters are frozen and the generator $G$ is trained. Because the discriminator has already had a round of training, the generator's goal becomes: with $D = D_G^\star$ fixed, minimize $\log(1-D(G(z)))$ and thereby minimize $V(D,G)$:

$G^\star = \arg\min_G V(G, D_G^\star)$

Splitting the procedure into these two steps makes the meaning of $\min_G \max_D$ clear: first maximize $V(D,G)$ from the discriminator's point of view, then minimize $V(D,G)$ from the generator's point of view, and keep alternating between the two.
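Written as a schematic training-loop sketch (not a runnable script on its own: D, G, criterion, z_dimension, the optimizers and the dataloader are all defined in the full scripts below), one pass over the data alternates the two updates like this; the .detach() in the D step is one common way to keep the discriminator update from backpropagating into the generator:

for real_img, _ in dataloader:
    n = real_img.size(0)

    # Step 1: freeze G, train D to maximize V(D, G)
    # (equivalently: minimize BCE with label 1 on real data and label 0 on fakes)
    z = torch.randn(n, z_dimension)
    d_loss = criterion(D(real_img), torch.ones(n, 1)) + \
             criterion(D(G(z).detach()), torch.zeros(n, 1))
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # Step 2: hold D fixed (only g_optimizer steps), train G to fool it
    z = torch.randn(n, z_dimension)
    g_loss = criterion(D(G(z)), torch.ones(n, 1))
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()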

The derivation above uses logarithms throughout. The logarithm is monotonically increasing on its domain, so taking logs does not change the relative ordering of the values involved; it simply makes the objective more convenient to work with (in the code below these log terms are computed by nn.BCELoss, the binary cross-entropy loss).

Ref: 《深入浅出GAN生成对抗网络》, 廖茂文

Simple Linear GAN Code

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import os

if not os.path.exists('./img'):
    os.mkdir('./img')


def to_img(x):
    # map generator output from [-1, 1] back to [0, 1] and reshape to images
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


batch_size = 128
num_epoch = 100
z_dimension = 100

# Image processing
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
# MNIST dataset
mnist = datasets.MNIST(
    root='./data/', train=True, transform=img_transform, download=True)
# Data loader
dataloader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=batch_size, shuffle=True)


# Discriminator
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid())

    def forward(self, x):
        x = self.dis(x)
        return x


# Generator
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh())

    def forward(self, x):
        x = self.gen(x)
        return x


D = discriminator()
G = generator()
if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

# Start training
# NOTE: the loop below moves tensors with .cuda() and therefore assumes a GPU is available.
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # =================train discriminator
        img = img.view(num_img, -1)
        real_img = Variable(img).cuda()
        real_label = Variable(torch.ones(num_img, 1)).cuda()
        fake_label = Variable(torch.zeros(num_img, 1)).cuda()

        # compute loss of real_img
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better

        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better

        # bp and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ===============train generator
        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)

        # bp and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'.format(
                      epoch, num_epoch, d_loss.item(), g_loss.item(),
                      real_scores.data.mean(), fake_scores.data.mean()))
    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './img/real_images.png')

    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))

torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')
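After training, the saved weights can be reloaded to sample new digits. This is a minimal sketch, assuming the generator class, z_dimension and to_img from the script above are in scope and that ./generator.pth was produced by the training loop; the output file name sampled_images.png is just an illustrative choice:

import torch
from torchvision.utils import save_image

G = generator()
G.load_state_dict(torch.load('./generator.pth', map_location='cpu'))
G.eval()

with torch.no_grad():
    z = torch.randn(64, z_dimension)   # a batch of 64 noise vectors
    samples = to_img(G(z))             # map from [-1, 1] back to [0, 1] images
save_image(samples, './img/sampled_images.png')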

Convolutional GAN Code

__author__ = 'ShelockLiao'

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
import os

if not os.path.exists('./dc_img'):
    os.mkdir('./dc_img')


def to_img(x):
    # map generator output from [-1, 1] back to [0, 1] and reshape to images
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


batch_size = 128
num_epoch = 100
z_dimension = 100  # noise dimension

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist = datasets.MNIST('./data', transform=img_transform, download=True)
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True,
                        num_workers=4)


class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),   # batch, 32, 28, 28
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2),        # batch, 32, 14, 14
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),  # batch, 64, 14, 14
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2)         # batch, 64, 7, 7
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid())

    def forward(self, x):
        '''
        x: batch, channel=1, height=28, width=28
        '''
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


class generator(nn.Module):
    def __init__(self, input_size, num_feature):
        super(generator, self).__init__()
        self.fc = nn.Linear(input_size, num_feature)  # batch, 3136=1x56x56
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True))
        self.downsample1 = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),   # batch, 50, 56, 56
            nn.BatchNorm2d(50),
            nn.ReLU(True))
        self.downsample2 = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1),  # batch, 25, 56, 56
            nn.BatchNorm2d(25),
            nn.ReLU(True))
        self.downsample3 = nn.Sequential(
            nn.Conv2d(25, 1, 2, stride=2),              # batch, 1, 28, 28
            nn.Tanh())

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 1, 56, 56)
        x = self.br(x)
        x = self.downsample1(x)
        x = self.downsample2(x)
        x = self.downsample3(x)
        return x


D = discriminator().cuda()               # discriminator model
G = generator(z_dimension, 3136).cuda()  # generator model

criterion = nn.BCELoss()  # binary cross entropy
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

# train
# NOTE: this script assumes a CUDA-capable GPU (models and tensors are moved with .cuda()).
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # =================train discriminator
        real_img = Variable(img).cuda()
        real_label = Variable(torch.ones(num_img, 1)).cuda()
        fake_label = Variable(torch.zeros(num_img, 1)).cuda()

        # compute loss of real_img
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better

        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better

        # bp and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ===============train generator
        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)

        # bp and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'.format(
                      epoch, num_epoch, d_loss.item(), g_loss.item(),
                      real_scores.data.mean(), fake_scores.data.mean()))
    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './dc_img/real_images.png')

    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, './dc_img/fake_images-{}.png'.format(epoch + 1))

torch.save(G.state_dict(), './generatorConv.pth')
torch.save(D.state_dict(), './discriminatorConv.pth')
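As a quick sanity check on the shapes in the convolutional version (a small sketch assuming the generator and discriminator classes from the script above are already defined and running on CPU), the generator maps a 100-dimensional noise vector to a 1×56×56 feature map and then downsamples it to a 1×28×28 image, which the discriminator reduces to a single probability:

import torch

G = generator(100, 3136)   # 3136 = 1 * 56 * 56
D = discriminator()

z = torch.randn(4, 100)    # a small batch of noise vectors
img = G(z)
print(img.shape)           # expected: torch.Size([4, 1, 28, 28])
print(D(img).shape)        # expected: torch.Size([4, 1])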

Ref

https://github.com/L1aoXingyu/pytorch-beginner/tree/master/09-Generative%20Adversarial%20network
