1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > SRGAN loss部分的pytorch代码实现

SRGAN loss部分的pytorch代码实现

时间:2021-06-03 14:31:35

相关推荐

SRGAN loss部分的pytorch代码实现

转载地址:/forum/thread-137101-1-1.html

作者:雨丝儿

最近在参加华为与高校合做开发mindspore模型的活动,使用mindspore开发了SRGAN模型,下面几篇帖子想针对SRGAN做一些自己的经验分享。这篇帖子分享SRAGAN loss pytorch的实现。

Pytorch版本参考:/dongheehand/SRGAN-PyTorch

Paper中SRGAN的loss:

对于Discriminator:

就是基础GAN中Discriminator的loss

代码实现:

其中gt为原始高分辨率图像,lr为gt经过双三次插值缩小四倍的低分辨率图像,cross_ent为BCELoss()

对与Generator:

Generator的loss包含三部分,一是基础的MSELoss,二是adversarial loss,三是将生成的HR图像与原始高清分辨率图像分别经过预训练的vgg19提取特征后,计算MSELoss.

代码部分:

VGG_loss=perceptual_loss(vgg_net)

cross_ent=nn.BCELoss()

tv_loss=TVLoss()

real_label=torch.ones((args.batch_size,1)).to(device)

fake_label=torch.zeros((args.batch_size,1)).to(device)

fori,tr_datainenumerate(loader):

gt=tr_data['GT'].to(device)

lr=tr_data['LR'].to(device)

output,_=generator(lr)

fake_prob=discriminator(output)

# 第一部分

L2_loss=l2_loss(output,gt)

# 第二部分

adversarial_loss=args.adv_coeff*cross_ent(fake_prob,real_label)

# 第三部分

_percep_loss,hr_feat,sr_feat=VGG_loss((gt+1.0)/2.0,(output+1.0)/2.0,layer=args.feat_layer)

percep_loss=args.vgg_rescale_coeff*_percep_loss

g_loss=L2_loss+adversarial_loss+percep_loss

g_optim.zero_grad()

d_optim.zero_grad()

g_loss.backward()

g_optim.step()

其中vgg19是在imagenet上训练好的vgg19,选取其前37层,args.adv_coeff,args.vgg_rescale_coeff为loss的系数,分别取0.003和0.006。

以上就是srgan loss部分的pytorch代码实现,下篇帖子将分享srgan loss部分minspore代码的实现。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。