1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > pytorch复现经典生成对抗式的超分辨率网络

pytorch复现经典生成对抗式的超分辨率网络

时间:2022-03-27 19:53:39

相关推荐

pytorch复现经典生成对抗式的超分辨率网络

论文原文:Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

论文的中文翻译:Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

网络结构如下图所示:

上图中,上半部分是生成网络,下半部分是判别网络(原文配图未能保留):

废话不多说,直接看代码——对于复现类文章,代码本身才是最有用的部分。

代码实现参考了 PyTorch torchvision 中 ResNet 的写法(torchvision/models/resnet.py),优点是:模块化、简洁易读,并且容易修改网络的宽度和深度,方便调整网络架构来做对比实验和消融实验。

实现生成网络:

# -*- coding: utf-8 -*-
# @Use:
# @Time : /8/17 21:32
# @FileName: models.py
# @Software: PyCharm
# @inference: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
import torch
from torch import Tensor, nn

try:
    # Not used by the models below; kept because the original post imports
    # it, but made optional so this module still imports without torchvision.
    import torchvision  # noqa: F401
except ImportError:  # pragma: no cover - depends on the environment
    torchvision = None


class GeneratorBasicBlock(nn.Module):
    """Residual block repeated ``n_blocks`` times in the generator body.

    Computes ``x + BN(conv(PReLU(BN(conv(x)))))``; channel count and spatial
    size are preserved.

    :param channel: number of input and output channels.
    :param kernel_size: square kernel size. Padding is ``kernel_size // 2``,
        which preserves height/width for odd sizes (needed for the residual
        addition).
    """

    def __init__(self, channel, kernel_size) -> None:
        super().__init__()
        self.channel = channel
        # "Same" padding derived from the kernel; a hard-coded (1, 1) only
        # fits kernel_size == 3 and breaks the residual add otherwise.
        pad = kernel_size // 2
        self.conv1 = nn.Conv2d(in_channels=channel, out_channels=channel,
                               kernel_size=(kernel_size, kernel_size),
                               stride=(1, 1), padding=(pad, pad))
        self.bn1 = nn.BatchNorm2d(num_features=channel)
        self.p_relu1 = nn.PReLU()
        self.conv2 = nn.Conv2d(in_channels=channel, out_channels=channel,
                               kernel_size=(kernel_size, kernel_size),
                               stride=(1, 1), padding=(pad, pad))
        self.bn2 = nn.BatchNorm2d(num_features=channel)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to ``x`` and add the identity skip connection."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.p_relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        return out + x


class GeneratorDepthwiseBlock(nn.Module):
    """Residual block built from depthwise-separable convolutions.

    NOTE(review): the original post selects this class for
    ``base_block_type == 'depthwise_conv_residual'`` but never defines it, so
    that option raised NameError. This is a reconstruction with the same
    layout as ``GeneratorBasicBlock``, where each dense conv is replaced by a
    depthwise conv (``groups=channel``) followed by a 1x1 pointwise conv.
    Confirm against the author's version before loading their checkpoints.

    :param channel: number of input and output channels.
    :param kernel_size: square kernel size of the depthwise convs (odd sizes
        preserve height/width).
    """

    def __init__(self, channel, kernel_size) -> None:
        super().__init__()
        self.channel = channel
        self.conv1 = self._separable_conv(channel, kernel_size)
        self.bn1 = nn.BatchNorm2d(num_features=channel)
        self.p_relu1 = nn.PReLU()
        self.conv2 = self._separable_conv(channel, kernel_size)
        self.bn2 = nn.BatchNorm2d(num_features=channel)

    @staticmethod
    def _separable_conv(channel, kernel_size) -> nn.Sequential:
        """Depthwise k x k conv followed by a pointwise 1 x 1 conv."""
        pad = kernel_size // 2
        return nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=(kernel_size, kernel_size),
                      stride=(1, 1), padding=(pad, pad), groups=channel),
            nn.Conv2d(channel, channel, kernel_size=(1, 1)),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to ``x`` and add the identity skip connection."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.p_relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        return out + x


class PixelShufflerBlock(nn.Module):
    """Sub-pixel upsampling stage at the end of the generator.

    3x3 conv -> ``nn.PixelShuffle(upscale_factor)`` -> PReLU. The shuffle
    turns ``out_channel`` channels into ``out_channel // upscale_factor**2``
    channels at ``upscale_factor`` times the height and width.

    :param in_channel: channels of the input feature map.
    :param out_channel: channels produced by the conv; should be divisible by
        ``upscale_factor ** 2``.
    :param upscale_factor: spatial upscaling factor (default 2, as before).
    """

    def __init__(self, in_channel, out_channel, upscale_factor=2) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.pixels_shuffle = nn.PixelShuffle(upscale_factor=upscale_factor)
        self.prelu = nn.PReLU()

    def forward(self, x: Tensor) -> Tensor:
        """Upsample ``x`` by ``upscale_factor``."""
        out = self.conv1(x)
        out = self.pixels_shuffle(out)
        return self.prelu(out)


class Generator(nn.Module):
    """SRGAN generator: maps a 3-channel image to one with 4x height and width.

    Hyper-parameters are read from ``config.G`` (values used in the post in
    brackets):

    * ``large_kernel_size`` [9]: kernel of the first and last conv.
    * ``small_kernel_size`` [3]: kernel of the residual blocks and ``conv2``.
    * ``n_channels`` [64]: feature width.
    * ``n_blocks`` [16]: number of residual blocks.
    * ``base_block_type``: ``'conv_residual'`` or ``'depthwise_conv_residual'``.

    :raises ValueError: if ``base_block_type`` is not one of the above.
    """

    def __init__(self, config) -> None:
        super().__init__()
        large_kernel_size = config.G.large_kernel_size
        small_kernel_size = config.G.small_kernel_size
        n_channels = config.G.n_channels
        n_blocks = config.G.n_blocks
        base_block_type = config.G.base_block_type

        block_types = {
            'conv_residual': GeneratorBasicBlock,
            'depthwise_conv_residual': GeneratorDepthwiseBlock,
        }
        if base_block_type not in block_types:
            # Fail early with a clear message instead of leaving
            # self.repeat_block unset (which surfaced later as AttributeError).
            raise ValueError(
                f"unknown base_block_type {base_block_type!r}; "
                f"expected one of {sorted(block_types)}"
            )
        self.repeat_block = block_types[base_block_type]

        # "Same" padding for odd kernels (4 for 9x9, 1 for 3x3, as before).
        large_pad = large_kernel_size // 2
        small_pad = small_kernel_size // 2

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=n_channels,
                               kernel_size=(large_kernel_size, large_kernel_size),
                               stride=(1, 1), padding=(large_pad, large_pad))
        self.prelu1 = nn.PReLU()
        # Attribute name (including the "residul" spelling) is kept so that
        # existing state_dict keys still load.
        self.B_residul_block = self._make_layer(self.repeat_block, n_channels,
                                                n_blocks, small_kernel_size)
        self.conv2 = nn.Conv2d(in_channels=n_channels, out_channels=n_channels,
                               kernel_size=(small_kernel_size, small_kernel_size),
                               stride=(1, 1), padding=(small_pad, small_pad))
        self.bn1 = nn.BatchNorm2d(n_channels)
        # Each stage upsamples 2x: the conv emits 4 * n_channels, which
        # PixelShuffle(2) folds back to n_channels.
        self.pixel_shuffle_block1 = PixelShufflerBlock(n_channels, 4 * n_channels)
        self.pixel_shuffle_block2 = PixelShufflerBlock(n_channels, 4 * n_channels)
        self.conv3 = nn.Conv2d(in_channels=n_channels, out_channels=3,
                               kernel_size=(large_kernel_size, large_kernel_size),
                               stride=(1, 1), padding=(large_pad, large_pad))

    def _make_layer(self, base_block, n_channels, n_block, kernel_size) -> nn.Sequential:
        """Stack ``n_block`` instances of ``base_block(n_channels, kernel_size)``.

        :param base_block: residual block class to instantiate.
        :param n_channels: channel count inside each block.
        :param n_block: number of blocks.
        :param kernel_size: kernel size passed to each block.
        :return: the blocks wrapped in an ``nn.Sequential``.
        """
        return nn.Sequential(*(base_block(n_channels, kernel_size)
                               for _ in range(n_block)))

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run the generator on ``x``."""
        out = self.conv1(x)
        out = self.prelu1(out)
        identity = out
        out = self.B_residul_block(out)
        out = self.conv2(out)
        out = self.bn1(out)
        # Long skip connection around all residual blocks.
        out = out + identity
        out = self.pixel_shuffle_block1(out)
        out = self.pixel_shuffle_block2(out)
        return self.conv3(out)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass; see ``_forward_impl``."""
        return self._forward_impl(x)

判别网络实现:

class DiscriminatorBaseblock(nn.Module):
    """Discriminator building block: conv -> BatchNorm -> LeakyReLU(0.2).

    :param inchannel: input channels.
    :param out_chanel: output channels.
    :param kernel_size: int or ``(kh, kw)`` tuple. Padding is ``k // 2`` per
        dimension, which gives ``(1, 1)`` for the 3x3 kernels used here.
    :param stride: int or ``(sh, sw)`` tuple, passed straight to ``nn.Conv2d``.
    """

    def __init__(self, inchannel, out_chanel, kernel_size, stride):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        # Derive padding from the kernel instead of hard-coding (1, 1).
        padding = tuple(k // 2 for k in kernel_size)
        self.conv1 = nn.Conv2d(in_channels=inchannel, out_channels=out_chanel,
                               kernel_size=kernel_size, stride=stride,
                               padding=padding)
        self.bn1 = nn.BatchNorm2d(out_chanel)
        self.act1 = nn.LeakyReLU(0.2)

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv, batch norm and LeakyReLU to ``x``."""
        out = self.conv1(x)
        out = self.bn1(out)
        return self.act1(out)


class Discriminator(nn.Module):
    """SRGAN discriminator: scores a 3-channel image with a value in (0, 1).

    Hyper-parameters are read (not overwritten) from ``config.D``; the values
    used in the post are shown in brackets:

    * ``kernel_size`` [3]: conv kernel size.
    * ``n_channels`` [64]: width of the first conv; doubles every two blocks.
    * ``n_blocks`` [8]: total conv layers, including ``conv1``.
    * ``fc_size`` [1024]: hidden size of the dense head.

    With the bracketed values the layout is exactly
    ``[3, 64, 2], [3, 128, 1], [3, 128, 2], [3, 256, 1], [3, 256, 2],
    [3, 512, 1], [3, 512, 2]`` after ``conv1``. The final feature map is
    adaptively pooled to 6x6, which is a no-op for 96x96 inputs and lets other
    input sizes work.

    :raises ValueError: if ``n_blocks < 1``.
    """

    # Spatial size the last feature map is pooled to before the dense head.
    _POOL_SIZE = 6

    def __init__(self, config):
        super().__init__()
        # Read the config. The original chained assignment
        # ``kernel_size = config.D.kernel_size = 3`` replaced the caller's
        # values with constants instead.
        kernel_size = config.D.kernel_size
        n_channels = config.D.n_channels
        n_blocks = config.D.n_blocks
        fc_size = config.D.fc_size
        if n_blocks < 1:
            raise ValueError(f"n_blocks must be >= 1, got {n_blocks!r}")

        pad = kernel_size // 2
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=n_channels,
                               kernel_size=(kernel_size, kernel_size),
                               stride=(1, 1), padding=(pad, pad))
        self.leaky_relu1 = nn.LeakyReLU(0.2)

        conv_configs = self._default_conv_configs(kernel_size, n_channels, n_blocks)
        self.base_blocks = self._make_layer(n_channels, DiscriminatorBaseblock, conv_configs)
        last_channels = conv_configs[-1][1] if conv_configs else n_channels

        self.adaptive_pool = nn.AdaptiveAvgPool2d((self._POOL_SIZE, self._POOL_SIZE))
        self.dense1 = nn.Linear(last_channels * self._POOL_SIZE * self._POOL_SIZE, fc_size)
        self.leaky_relu2 = nn.LeakyReLU(0.2)
        self.dense2 = nn.Linear(fc_size, 1)
        # Attribute name (sic) kept for compatibility with existing code.
        self.sigmod1 = nn.Sigmoid()

    @staticmethod
    def _default_conv_configs(kernel_size, n_channels, n_blocks) -> list:
        """Return ``[kernel, channels, stride]`` rows for blocks 1..n_blocks-1.

        Block 0 is ``conv1``. The stride alternates 2, 1, 2, ... and the width
        doubles every second block, starting from ``n_channels``.
        """
        return [[kernel_size, n_channels * 2 ** (i // 2), 2 if i % 2 else 1]
                for i in range(1, n_blocks)]

    def _make_layer(self, in_channel, base_block, conv_configs: list) -> nn.Sequential:
        """Chain ``base_block`` instances described by ``conv_configs``.

        :param in_channel: channels entering the first block.
        :param base_block: block class, e.g. ``DiscriminatorBaseblock``.
        :param conv_configs: rows of ``(kernel, channel, stride)``; each block's
            output channels feed the next block.
        :return: the blocks wrapped in an ``nn.Sequential``.
        """
        layers = []
        for kernel, out_channel, stride in conv_configs:
            layers.append(base_block(in_channel, out_channel,
                                     (kernel, kernel), (stride, stride)))
            in_channel = out_channel
        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Score ``x``; returns a ``(batch, 1)`` tensor of sigmoid outputs."""
        out = self.conv1(x)
        out = self.leaky_relu1(out)
        out = self.base_blocks(out)
        out = self.adaptive_pool(out)
        out = torch.flatten(out, 1)
        out = self.dense1(out)
        out = self.leaky_relu2(out)
        out = self.dense2(out)
        return self.sigmod1(out)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass; see ``_forward_impl``."""
        return self._forward_impl(x)

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。