
Anime Avatar Generation with a GAN in PyTorch (Full Code Included)

Date: 2020-05-21 10:17:52


Introduction

This post is a PyTorch version of homework HW3_1 (anime avatar generation, originally implemented in Keras) from Hung-yi Lee's GAN course. I wrote it because I have only just started working with GANs and I am more comfortable with PyTorch, so I ported the Keras version to PyTorch.

Resources needed for the implementation:

Link: /s/1cLmFNQpJe1DOI96IVuvVyQ

Extraction code: nha2

The one change I made is using kernel_size=3 instead of 4, since convolution kernels are usually odd-sized (an odd kernel with padding=kernel_size//2 keeps the spatial size, as the quick check below shows). Everything else is essentially the same as the original network.
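As a sanity check of that padding choice (this snippet is my own illustration, not part of the original code, and the channel counts are arbitrary), padding=kernel_size//2 with stride 1 preserves the spatial size only for odd kernels:

import torch
import torch.nn as nn

x = torch.randn(1, 3, 64, 64)
# with stride 1: out = in + 2*padding - kernel_size + 1
conv3 = nn.Conv2d(3, 8, kernel_size=3, padding=3 // 2)  # padding 1
print(conv3(x).shape)  # torch.Size([1, 8, 64, 64]), size preserved
conv4 = nn.Conv2d(3, 8, kernel_size=4, padding=4 // 2)  # padding 2
print(conv4(x).shape)  # torch.Size([1, 8, 65, 65]), an even kernel is off by one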

Below is the main part of the code, consisting of two modules: the network module and the training/validation/testing module.

The full code is available at /AsajuHuishi/Generate_a_quadratic_image_with_GAN

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn as nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import time
import visdom

1. Network module

Generator

## define the convolution layer
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=kernel_size // 2,  # keep the spatial size
                     bias=bias)

## define ReLU
def default_relu():
    return nn.ReLU(inplace=True)

## reshape
def get_feature(x):
    return x.reshape(x.size()[0], 128, 16, 16)

class Generator(nn.Module):
    def __init__(self, input_dim=100, conv=default_conv, relu=default_relu, reshape=get_feature):
        super(Generator, self).__init__()
        head = [nn.Linear(input_dim, 128*16*16), relu()]
        self.reshape = reshape  # 16x16
        body = [nn.Upsample(scale_factor=2, mode='nearest'),  # 32x32
                conv(128, 128, 3), relu(),
                nn.Upsample(scale_factor=2, mode='nearest'),  # 64x64
                conv(128, 64, 3), relu(),
                conv(64, 3, 3), nn.Tanh()]
        self.head = nn.Sequential(*head)
        self.body = nn.Sequential(*body)

    def forward(self, x):  # x: (batchsize, input_dim)
        x = self.head(x)
        x = self.reshape(x)
        x = self.body(x)
        return x  # (batchsize, 3, 64, 64)

    def name(self):
        return 'Generator'
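A quick shape check for the generator (my own sketch for illustration, assuming the imports and the Generator class above):

g = Generator(input_dim=100)
z = torch.randn(4, 100)  # a batch of 4 latent vectors
out = g(z)
print(out.shape)         # torch.Size([4, 3, 64, 64]), values in (-1, 1) because of the Tanh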

Discriminator

class Discriminator(nn.Module):
    def __init__(self, conv=default_conv, relu=default_relu):
        super(Discriminator, self).__init__()
        main = [conv(3, 32, 3), relu(),
                conv(32, 64, 3), relu(),
                conv(64, 128, 3), relu(),
                conv(128, 256, 3), relu()]
        self.main = nn.Sequential(*main)
        self.fc = nn.Linear(256*64*64, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):  # x: (batchsize, 3, 64, 64)
        x = self.main(x)             # (b, 256, 64, 64)
        x = x.view(x.size()[0], -1)  # (b, 256*64*64)
        x = self.fc(x)               # (b, 1)
        x = self.sigmoid(x)
        return x

    def name(self):
        return 'Discriminator'
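And the same kind of check for the discriminator (again an illustrative sketch I added, not part of the original code):

d = Discriminator()
imgs = torch.randn(4, 3, 64, 64)  # stand-in for a batch of real or generated 64x64 images
scores = d(imgs)
print(scores.shape)               # torch.Size([4, 1]), each score in (0, 1) because of the Sigmoid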

2. Training/validation/testing module

Parameters and model initialization

class GAN(nn.Module):
    def __init__(self, args):
        super(GAN, self).__init__()
        self.img_size = 64
        self.channels = 3
        self.latent_dim = args.latent_dim
        self.num_epoch = args.num_epoch
        self.batch_size = args.batch_size
        self.cuda = args.cuda
        self.interval = 20  # validate once every 20 epochs
        self.continue_training = args.continue_training  # whether to resume training
        ## initialize the generator
        self.generator = Generator(self.latent_dim)
        ## initialize the discriminator
        self.discriminator = Discriminator()
        self.testmodelpath = args.testmodelpath
        self.datapath = args.datapath
        if self.cuda:
            self.generator.cuda()
            self.discriminator.cuda()
        self.continue_training_isrequired()
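The GAN class expects an args namespace carrying the fields used above. The actual parser lives in the full repository; the sketch below is only my guess at what it roughly looks like, and all default values and paths here are made up:

parser = argparse.ArgumentParser()                                 # assumed parser, not the author's exact one
parser.add_argument('--latent_dim', type=int, default=100)         # size of the noise vector
parser.add_argument('--num_epoch', type=int, default=200)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--cuda', action='store_true')
parser.add_argument('--continue_training', action='store_true')    # resume from a saved checkpoint
parser.add_argument('--testmodelpath', type=str, default='checkpoint/Generator')
parser.add_argument('--datapath', type=str, default='./faces')     # ImageFolder root (hypothetical path)
args = parser.parse_args()
gan = GAN(args)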

Training and the dataset dataloader

    def trainer(self):
        ## load the image data and split it into batches
        print('===> Data preparing...')
        import torchvision.transforms as transforms
        from torch.utils.data import DataLoader
        from torchvision.datasets import ImageFolder
        transform = transforms.ToTensor()  # the DataLoader must output tensors; without this it raises an error
        dataset = ImageFolder(self.datapath, transform=transform)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
                                num_workers=0, drop_last=True)
        # drop_last: the dataset size may not be a multiple of batch_size;
        # drop_last=True discards the leftover samples that do not fill a full batch
        num_batch = len(dataloader)  # number of batches = len(dataloader) = total images // batch_size
        print('num_batch:', num_batch)
        # dataloader output: (batchsize, 3, 64, 64), values in [0, 1]
        ## discriminator target labels
        target_real = Variable(torch.ones(self.batch_size, 1))
        target_false = Variable(torch.zeros(self.batch_size, 1))
        one_const = Variable(torch.ones(self.batch_size, 1))
        if self.cuda:
            target_real = target_real.cuda()
            target_false = target_false.cuda()
            one_const = one_const.cuda()
        ## optimizers
        optim_generator = optim.Adam(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        optim_discriminator = optim.Adam(self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        ## loss functions
        # content_criterion = nn.MSELoss()
        adversarial_criterion = nn.BCELoss()
        ## start training
        for epoch in range(self.start_epoch, self.num_epoch):
            ## track the mean loss over the batches of one epoch
            mean_dis_loss = 0.0
            mean_gen_con_loss = 0.0
            mean_gen_adv_loss = 0.0
            mean_gen_total_loss = 0.0
            for i, data in enumerate(dataloader):  # number of iterations = len(dataloader) = total images // batch_size
                if epoch < 3 and i % 10 == 0:
                    print('epoch%d: %d/%d' % (epoch, i, len(dataloader)))
                ## 1.1 generate noise
                gen_input = np.random.normal(0, 1, (self.batch_size, self.latent_dim)).astype(np.float32)
                gen_input = torch.from_numpy(gen_input)
                gen_input = torch.autograd.Variable(gen_input, requires_grad=True)
                if self.cuda:
                    gen_input = gen_input.cuda()
                fake = self.generator(gen_input)  # generated images (batchsize, 3, 64, 64)
                real, _ = data  # data is a list [tensor, tensor]; take the first, real: (batchsize, 3, 64, 64)
                if self.cuda:
                    real = real.cuda()
                    fake = fake.cuda()
                ## 1. fix G and train the discriminator D
                self.discriminator.zero_grad()
                dis_loss1 = adversarial_criterion(self.discriminator(real), target_real)
                dis_loss2 = adversarial_criterion(self.discriminator(fake.detach()), target_false)
                # note: the output of G must be detach()-ed before it is fed into D here
                dis_loss = 0.5 * (dis_loss1 + dis_loss2)
                # print('epoch:%d--%d, discriminator loss:%.6f' % (epoch, i, dis_loss))
                dis_loss.backward()
                optim_discriminator.step()
                mean_dis_loss += dis_loss
                ## 2. fix D and train the generator G
                self.generator.zero_grad()
                ## generate noise
                gen_input = np.random.normal(0, 1, (self.batch_size, self.latent_dim)).astype(np.float32)
                gen_input = torch.from_numpy(gen_input)
                gen_input = torch.autograd.Variable(gen_input, requires_grad=True)
                if self.cuda:
                    gen_input = gen_input.cuda()
                fake = self.generator(gen_input)  # generated images (batchsize, 3, 64, 64)
                gen_con_loss = 0
                gen_adv_loss = adversarial_criterion(self.discriminator(fake), one_const)  # fix D, update G
                gen_total_loss = gen_con_loss + gen_adv_loss
                # print('epoch:%d--%d, generator loss:%.6f' % (epoch, i, gen_total_loss))
                gen_total_loss.backward()
                optim_generator.step()
                mean_gen_con_loss += gen_con_loss
                mean_gen_adv_loss += gen_adv_loss
                mean_gen_total_loss += gen_total_loss
            ## print once per epoch
            print('epoch:%d/%d' % (epoch, self.num_epoch))
            print('Discriminator_Loss: %.4f' % (mean_dis_loss / num_batch))
            print('Generator_total_Loss:%.4f' % (mean_gen_total_loss / num_batch))
            ## save the models
            state_dis = {'dis_model': self.discriminator.state_dict(), 'epoch': epoch}
            state_gen = {'gen_model': self.generator.state_dict(), 'epoch': epoch}
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state_dis, 'checkpoint/'+self.discriminator.name()+'__'+str(epoch+1))  # each epoch
            torch.save(state_gen, 'checkpoint/'+self.generator.name()+'__'+str(epoch+1))      # each epoch
            torch.save(state_dis, 'checkpoint/'+self.discriminator.name())  # final
            torch.save(state_gen, 'checkpoint/'+self.generator.name())      # final
            ## validate the model
            if epoch < 45 or epoch % self.interval == 0:
                self.validater(epoch)
            print('--'.center(12, '-'))
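The continue_training_isrequired method called in __init__ is not reproduced in this excerpt (it is in the full repository). Purely as an illustration of what such a resume helper could do, here is a rough sketch that follows the checkpoint names and keys used in trainer above; it should not be read as the author's exact code:

    def continue_training_isrequired(self):
        # hypothetical sketch: resume from the latest "final" checkpoints if requested
        self.start_epoch = 0
        if self.continue_training and os.path.isdir('checkpoint'):
            state_gen = torch.load('checkpoint/' + self.generator.name())
            state_dis = torch.load('checkpoint/' + self.discriminator.name())
            self.generator.load_state_dict(state_gen['gen_model'])
            self.discriminator.load_state_dict(state_dis['dis_model'])
            self.start_epoch = state_gen['epoch'] + 1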

Validation

    def validater(self, epoch):
        vis = visdom.Visdom(env='generate_girl_epoch%d' % (epoch))
        r, c = 3, 3
        gen_input_val = np.random.normal(0, 1, (r*c, self.latent_dim)).astype(np.float32)
        gen_input_val = torch.from_numpy(gen_input_val)
        gen_input_val = torch.autograd.Variable(gen_input_val)
        if self.cuda:
            gen_input_val = gen_input_val.cuda()
        output_val = self.generator(gen_input_val)     # (r*c, 3, 64, 64)
        output_val = output_val.cpu()
        output_val = output_val.data.numpy()           # (r*c, 3, 64, 64)
        img = np.transpose(output_val, (0, 2, 3, 1))   # (r*c, 64, 64, 3)
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                vis.image(output_val[cnt], opts={'title': 'epoch%d_cnt%d' % (epoch, cnt)})
                axs[i, j].imshow(img[cnt, :, :, :])
                axs[i, j].axis('off')
                cnt += 1
        if not os.path.isdir('images'):
            os.mkdir('images')
        fig.savefig('images/val_%d.png' % (epoch+1))   # save the validation result
        plt.close()

Testing

    def tester(self, gen_input_test):  # input: (N, latent_dim)
        assert gen_input_test.shape[1] == self.latent_dim, \
            "dimension 1's size expected %d, but input %d" % (self.latent_dim, gen_input_test.shape[1])
        gen_input_test = gen_input_test.astype(np.float32)
        gen_input_test = torch.from_numpy(gen_input_test)
        gen_input_test = torch.autograd.Variable(gen_input_test)
        if self.cuda:
            gen_input_test = gen_input_test.cuda()
        ## load the saved generator
        if os.path.isdir('checkpoint'):
            try:
                checkpoint_gen = torch.load(self.testmodelpath)
                self.generator.load_state_dict(checkpoint_gen['gen_model'])
            except FileNotFoundError:
                print("Can't find dict")
        output_test = self.generator(gen_input_test)
        output_test = output_test.cpu()
        output_test = output_test.data.numpy()          # (N, 3, 64, 64)
        img = np.transpose(output_test, (0, 2, 3, 1))   # (N, 64, 64, 3)
        if not os.path.isdir('images'):
            os.mkdir('images')
        N = img.shape[0]  # number of images
        for i in range(N):
            plt.imshow(img[i, :, :, :])
            plt.axis('off')
            plt.savefig('images/test_%d.png' % (i+1))   # save the result
            plt.close()

Compared with the original Keras version, the results are essentially the same, which is expected since the networks are almost identical. Don't set expectations too high either: the network is small, and generating a good-looking face is challenging because it requires the facial features to be well proportioned.

Input:

np.random.normal(0,1,(1,self.latent_dim)).astype(np.float32)
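For example, feeding that noise into tester generates and saves a single test image (a usage sketch, assuming args as sketched earlier and a trained generator checkpoint at testmodelpath):

gan = GAN(args)                                                          # args as sketched earlier
noise = np.random.normal(0, 1, (1, args.latent_dim)).astype(np.float32)  # one latent vector
gan.tester(noise)                                                        # writes images/test_1.png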

Some of the results:
