
PANet: Image Super-Resolution Reconstruction with a Pyramid Attention Network (PyTorch Implementation)


[!] To improve code readability, the implementation in this article differs somewhat from the original paper's, which will cause some difference in performance.

Table of Contents

PANet: Image Super-Resolution Reconstruction with a Pyramid Attention Network
1. Related Resources
2. Introduction
3. Model Architecture
4. Project Practice
  4.1 Preparation
  4.2 Implementation
    4.2.1 Importing the Required Libraries
    4.2.2 Building the Dataset
    4.2.3 Building the Network Model
      # Feature Pyramid Part
      # Model Part
    4.2.4 Preparing the Training Components
      # Optimizer
      # Loss Function
      # Evaluation Metrics
        ## PSNR
        ## SSIM
    4.2.5 Building the Training Loop
    4.2.6 Training Results

1. Related Resources

- Paper download: [link]
- Original author's code: [link]
- Complete code: [link]

2. Introduction

PANet (Pyramid Attention with Simple Network Backbones) is an image restoration model built around a pyramid attention module. It extracts both long-range and short-range feature relationships from a multi-scale feature pyramid. Inspired by the observation that downsampling effectively suppresses compression artifacts and other image noise, the proposed pyramid lets feature maps at different sampling scales pass attention signals to one another, borrowing "clean" information across feature scales in a more flexible way. By adding just one pyramid attention module to a simple feed-forward network, the authors reached SOTA on most image restoration tasks. (Which says a lot about how powerful the module is.)

3. Model Architecture

Let's go straight to the figure.

The upper part of the figure is the pyramid attention module itself; the lower part is the overall PANet structure (which looks a lot like SRResNet; see my related article on SRResNet and SRGAN). The pyramid attention module consists of two parts: the pyramid sampling stage and S-A Attention. The pyramid sampling stage is plain downsampling; judging from the source code, the author uses bicubic downsampling. S-A Attention borrows the classic attention structure from NLP, building Q, K, and V feature maps to capture information at different image scales. Unlike other attention mechanisms, S-A Attention replaces the element-wise multiplication step: the Q and K feature maps (the two maps coming out of the light-blue feature layers in the figure) are used as convolution kernels that are convolved / transpose-convolved with the V feature map.
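To make the patch-as-kernel idea concrete, here is a heavily simplified, single-scale, single-image sketch (my own illustration, not the paper's code; the full module in section 4.2.3 adds the multi-scale pyramid, normalization, and patch pasting):

import torch
import torch.nn.functional as F

# Toy, single-scale version of the patch-as-kernel idea.
v = torch.randn(1, 16, 32, 32)  # "V" feature map
k = torch.randn(1, 16, 32, 32)  # "K" feature map (normally a 1x1-conv transform of the input)

# Cut K into 3x3 patches and treat each patch as a convolution kernel.
patches = F.unfold(k, kernel_size=3, padding=1)          # [1, 16*3*3, 32*32]
kernels = patches.transpose(1, 2).reshape(-1, 16, 3, 3)  # [1024, 16, 3, 3], one kernel per position

# Correlate every K patch with V: channel i of the output is patch i's matching-score map.
scores = F.conv2d(v, kernels, padding=1)                 # [1, 1024, 32, 32]
attn = F.softmax(scores * 10, dim=1)                     # softmax over patches, as in the module
print(attn.shape)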

4. Project Practice

Here I will walk you step by step through building a PANet that actually runs; the complete code will be released soon.

4.1 Preparation

My working environment is as follows:

- OS: Windows 10
- CPU: Intel Core i9-10850K
- GPU: GeForce RTX 3090

The libraries required for the implementation are:

- PyTorch
- OpenCV
- NumPy
- Torchvision

This article uses the COCO dataset, which contains 123,403 photos; feel free to substitute your own dataset as needed.

4.2 Implementation

For ease of reading, parts of the code are annotated with comments, and everything is placed in a single code file.

The full version of the code can automatically resume from the last training session when reopened; follow the author to get it: [link]

4.2.1 Importing the Required Libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
import torch.optim as optim
from torchvision import utils as vutils
from torchvision.utils import save_image
import os
import cv2
import random as ra
import numpy as np
import math

4.2.2 Building the Dataset

class PreprocessDataset(Dataset):
    def __init__(self, path, size=96):
        super().__init__()
        self.size = size  # side length of the HR crop, 96x96 by default
        self.allImgs = list()
        for root, dirs, files in os.walk(path):
            # collect image paths (extend rather than overwrite, so nested folders all count)
            self.allImgs += [os.path.join(root, file) for file in files]

    def __len__(self):
        return len(self.allImgs)

    def __getitem__(self, index):
        img = self.allImgs[index]
        img = cv2.imread(img)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        height, width, _ = img.shape
        xStart = ra.randint(0, width - self.size - 1)
        yStart = ra.randint(0, height - self.size - 1)
        img = img[yStart:self.size + yStart, xStart:self.size + xStart, :]  # random crop
        if ra.random() > 0.5:
            img = cv2.flip(img, 1)  # flip the image with 50% probability
        hr = torch.tensor(np.transpose(img, (2, 0, 1))) / 255.0
        hr = (hr - 0.5) / 0.5  # normalize pixels to [-1, 1]
        lr = F.max_pool2d(hr, 2)  # max pooling as the downsampling operator
        return hr, lr

With the dataset class in place, building the corresponding DataLoader is straightforward. Here I only build a training set, not a test set.

path = 'your dataset path'
dataset = PreprocessDataset(path, size=96)
trainData = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)
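As a quick sanity check (my own addition, not part of the original walkthrough), you can pull one batch and confirm the shapes: HR crops are 3x96x96 and their max-pooled LR counterparts are 3x48x48. On Windows, run this under the if __name__ == '__main__': guard, since the loader spawns worker processes.

hr, lr = next(iter(trainData))
print(hr.shape)  # torch.Size([32, 3, 96, 96])
print(lr.shape)  # torch.Size([32, 3, 48, 48])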

4.2.3 Building the Network Model

# Feature Pyramid Part

The pyramid attention module here is adapted directly from the original author's code, so its style differs somewhat from the rest.

def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.
    :param padding:
    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
     each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :return: A Tensor
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']
    batch_size, channel, height, width = images.size()

    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding == 'valid':
        pass
    else:
        raise NotImplementedError('Unsupported padding type: {}. '
                                  'Only "same" or "valid" are supported.'.format(padding))

    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    patches = unfold(images)
    return patches  # [N, C*k*k, L], L is the total number of such blocks


def reduce_sum(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x


def same_padding(images, ksizes, strides, rates):
    assert len(images.size()) == 4
    batch_size, channel, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
    padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
    padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
    # Pad the input
    padding_top = int(padding_rows / 2.)
    padding_left = int(padding_cols / 2.)
    padding_bottom = padding_rows - padding_top
    padding_right = padding_cols - padding_left
    paddings = (padding_left, padding_right, padding_top, padding_bottom)
    images = torch.nn.ZeroPad2d(paddings)(images)
    return images


def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=(kernel_size // 2), stride=stride, bias=bias)


class BasicBlock(nn.Sequential):
    def __init__(self, conv, in_channels, out_channels, kernel_size,
                 stride=1, bias=True, bn=False, act=nn.PReLU()):
        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)


class PyramidAttention(nn.Module):
    def __init__(self, level=5, res_scale=1, channel=64, reduction=2,
                 ksize=3, stride=1, softmax_scale=10, average=True, conv=default_conv):
        super(PyramidAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.res_scale = res_scale
        self.softmax_scale = softmax_scale
        self.scale = [1 - i / 10 for i in range(level)]
        self.average = average
        escape_NaN = torch.FloatTensor([1e-4])
        self.register_buffer('escape_NaN', escape_NaN)
        self.conv_match_L_base = BasicBlock(conv, channel, channel // reduction, 1, bn=False, act=nn.PReLU())
        self.conv_match = BasicBlock(conv, channel, channel // reduction, 1, bn=False, act=nn.PReLU())
        self.conv_assembly = BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())

    def forward(self, input):
        res = input
        # theta
        match_base = self.conv_match_L_base(input)
        shape_base = list(res.size())
        input_groups = torch.split(match_base, 1, dim=0)
        # patch size for matching
        kernel = self.ksize
        # raw_w is for reconstruction
        raw_w = []
        # w is for matching
        w = []
        # build feature pyramid
        for i in range(len(self.scale)):
            ref = input
            if self.scale[i] != 1:
                ref = F.interpolate(input, scale_factor=self.scale[i],
                                    mode='bicubic', align_corners=True,
                                    recompute_scale_factor=True)
            # feature transformation function f
            base = self.conv_assembly(ref)
            shape_input = base.shape
            # sampling
            raw_w_i = extract_image_patches(base, ksizes=[kernel, kernel],
                                            strides=[self.stride, self.stride],
                                            rates=[1, 1],
                                            padding='same')  # [N, C*k*k, L]
            raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel, kernel, -1)
            raw_w_i = raw_w_i.permute(0, 4, 1, 2, 3)  # raw_shape: [N, L, C, k, k]
            raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
            raw_w.append(raw_w_i_groups)
            # feature transformation function g
            ref_i = self.conv_match(ref)
            shape_ref = ref_i.shape
            # sampling
            w_i = extract_image_patches(ref_i, ksizes=[self.ksize, self.ksize],
                                        strides=[self.stride, self.stride],
                                        rates=[1, 1],
                                        padding='same')
            w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
            w_i = w_i.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
            w_i_groups = torch.split(w_i, 1, dim=0)
            w.append(w_i_groups)

        y = []
        for idx, xi in enumerate(input_groups):
            # group in a filter
            wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))], dim=0)  # [L, C, k, k]
            # normalize
            max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
                                                     axis=[1, 2, 3],
                                                     keepdim=True)),
                               self.escape_NaN)
            wi_normed = wi / max_wi
            # matching
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)  # [1, L, H, W] L = shape_ref[2]*shape_ref[3]
            yi = yi.view(1, wi.shape[0], shape_base[2], shape_base[3])  # (B=1, C=32*32, H=32, W=32)
            # softmax matching score
            yi = F.softmax(yi * self.softmax_scale, dim=1)
            if not self.average:
                yi = (yi == yi.max(dim=1, keepdim=True)[0]).float()
            # deconv for patch pasting
            raw_wi = torch.cat([raw_w[i][idx][0] for i in range(len(self.scale))], dim=0)
            yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride, padding=1) / 4.
            y.append(yi)
        y = torch.cat(y, dim=0) + res * self.res_scale  # back to the mini-batch
        return y
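If you want to sanity-check the module in isolation, a quick shape test like this should work (my own addition; the sizes are arbitrary, as long as the input is large enough to survive every downsampling level):

pa = PyramidAttention(channel=64, level=4)
x = torch.randn(2, 64, 24, 24)
print(pa(x).shape)  # torch.Size([2, 64, 24, 24]) -- same shape as the input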

# Model Part

PANet uses the SRResNet backbone.

class ResBlock(nn.Module):
    def __init__(self, inChannals):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(inChannals, inChannals, kernel_size=1, bias=False),
            nn.BatchNorm2d(inChannals),
            nn.ReLU(inplace=True),
            nn.Conv2d(inChannals, inChannals, kernel_size=3, stride=1,
                      padding=1, bias=False, padding_mode='reflect'),
            nn.BatchNorm2d(inChannals)
        )

    def forward(self, input):
        return F.relu(input + self.model(input), inplace=True)


class Sequential(nn.Sequential):
    def __init__(self, inChannals, blockNum=8):
        seq = [ResBlock(inChannals) for _ in range(blockNum)]
        seq.insert(int(blockNum / 2), PyramidAttention(channel=inChannals, level=4))
        super().__init__(*seq)


class Model(nn.Module):
    def __init__(self, channals=64, blockNum=6):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, channals, kernel_size=7, padding=3, stride=1,
                      padding_mode='reflect', bias=False),
            nn.BatchNorm2d(channals),
            nn.ReLU(inplace=True),
            nn.Conv2d(channals, channals, kernel_size=3, padding=1, stride=1,
                      padding_mode='reflect', bias=False),
            nn.BatchNorm2d(channals),
            nn.ReLU(inplace=True)
        )
        self.sequential = Sequential(channals, blockNum)
        self.upSample = nn.Sequential(
            nn.Conv2d(channals, channals * 4, kernel_size=3, padding=1, stride=1,
                      padding_mode='reflect'),
            nn.PixelShuffle(2),
            nn.Conv2d(channals, channals, kernel_size=3, padding=1, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channals, 3, kernel_size=1, stride=1),
            nn.Tanh()
        )

    def forward(self, input):
        features = self.features(input)
        output = self.sequential(features)
        output = features + output
        output = self.upSample(output)
        return output

Finally, a model can be constructed very simply:

# Automatically use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create the network model
net = Model(channals=64, blockNum=24).to(device)
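As a quick check of the full network (my own addition), a single forward pass confirms the 2x upscaling produced by the PixelShuffle stage:

with torch.no_grad():
    dummy = torch.randn(1, 3, 48, 48).to(device)
    print(net(dummy).shape)  # torch.Size([1, 3, 96, 96])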

4.2.4 Preparing the Training Components

To train and validate the model, we need the following components:

- an optimizer (Optimizer)
- a loss function (Criterion)
- evaluation metrics

# Optimizer

For the optimizer we use AdamW:

optimizer = optim.AdamW(net.parameters(), lr=1e-4)
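The training loop in section 4.2.5 also steps the learning rate down at fixed epochs via an update_lr helper, which this excerpt does not define. A minimal version consistent with how it is called might look like this (my own sketch, not the original's code):

def update_lr(optimizer, multiplier=0.1):
    # Scale the learning rate of every parameter group in place
    for param_group in optimizer.param_groups:
        param_group['lr'] *= multiplier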

# Loss Function

Following the original authors, we use the L1 loss:

criteria = nn.L1Loss()

# Evaluation Metrics

We use two metrics, SSIM and PSNR; their code is shown below.
## PSNR
The code is as follows:

def PSNR(img1, img2):
    # assumes pixel values on a 0-255 scale (see the note below)
    mse = torch.mean((img1 - img2) ** 2)
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(255.0 ** 2 / mse)
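Note that this formula assumes 8-bit pixel values in [0, 255], while the tensors in this project are normalized to [-1, 1], so the printed values come out much higher than typical PSNR figures. If you want PSNR on the usual 8-bit scale, one option (my own addition, not part of the original walkthrough) is to map the tensors back first:

def denorm_to_255(t):
    # Undo the (x - 0.5) / 0.5 normalization and rescale to [0, 255]
    return (t * 0.5 + 0.5).clamp(0, 1) * 255.0

# Usage: PSNR(denorm_to_255(output), denorm_to_255(hr))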

## SSIM
The code is as follows:

def gaussian(window_size, sigma):
    gauss = torch.Tensor([math.exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2))
                          for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
               ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    return _ssim(img1, img2, window, window_size, channel, size_average)
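A minimal usage check (my own addition): identical inputs should score 1, unrelated inputs much lower.

x = torch.rand(1, 3, 96, 96)
print(ssim(x, x))                         # ~1.0 for identical images
print(ssim(x, torch.rand(1, 3, 96, 96)))  # much lower for unrelated images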

Admittedly, both of these are borrowed from other people's implementations (lazy, I know).

4.2.5 Building the Training Loop

The training loop is shown below:

if __name__ == '__main__':
    path = 'your dataset path'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = PreprocessDataset(path, size=96)
    trainData = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)
    net = Model(channals=64, blockNum=24).to(device)
    print(net)
    criteria = nn.L1Loss()
    optimizer = optim.AdamW(net.parameters(), lr=1e-4)
    totalStep = len(trainData)
    startEpoch = 0  # the full version restores this from a saved checkpoint

    # Create the directory for saving visual results
    if not os.path.exists('./img'):
        os.mkdir('./img')

    for epoch in range(startEpoch, 10000):
        if epoch == 20 or epoch == 40:
            update_lr(optimizer, multiplier=.1)  # decay the learning rate (helper from 4.2.4)
        totalSSIM = 0.0
        totalPSNR = 0.0
        totalLoss = 0.0
        for step, (hr, lr) in enumerate(trainData, 1):
            net.train(True)
            hr, lr = hr.to(device), lr.to(device)
            net.zero_grad()
            output = net(lr)
            loss = criteria(output, hr)
            loss.backward()
            optimizer.step()
            totalLoss += loss.item()
            totalSSIM += ssim(output, hr).item()
            totalPSNR += PSNR(output, hr)
            print("[Epoch %d] Step: %d/%d Loss: %.4f|ssim: %.4f|psnr: %.4f" %
                  (epoch, step, totalStep, totalLoss / step, totalSSIM / step, totalPSNR / step))
            if step >= 100:  # visualize the results
                net.train(False)
                outputs = net(lr)
                outputs = torch.cat([hr, outputs], dim=0)
                save_image(outputs, './img/Result_epoch_%08d.jpg' % epoch,
                           nrow=8, normalize=True)
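The full version of the code can automatically resume from the last training session when the script is reopened. A minimal sketch of what such checkpointing could look like (the file name checkpoint.pth and the dictionary keys are my own choices, not the original's):

# Save at the end of each epoch
torch.save({'epoch': epoch,
            'model': net.state_dict(),
            'optimizer': optimizer.state_dict()},
           'checkpoint.pth')

# Restore before the training loop starts
if os.path.exists('checkpoint.pth'):
    ckpt = torch.load('checkpoint.pth')
    net.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    startEpoch = ckpt['epoch'] + 1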


4.2.6 Training Results

Result after 100 epochs:

Result after 10,000 epochs:

At this point, SSIM: 0.6710, PSNR: 65.6384.

Because the content of the COCO dataset is so varied, more training is needed to reach better results.

