Pytorch搭建SRGAN平台提升图片超分辨率_第1页
Pytorch搭建SRGAN平台提升图片超分辨率_第2页
Pytorch搭建SRGAN平台提升图片超分辨率_第3页
Pytorch搭建SRGAN平台提升图片超分辨率_第4页
Pytorch搭建SRGAN平台提升图片超分辨率_第5页
已阅读5页,还剩3页未读 继续免费阅读

下载本文档

版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领

文档简介

第Pytorch搭建SRGAN平台提升图片超分辨率目录网络构建一、什么是SRGAN二、生成网络的构建三、判别网络的构建训练思路一、判别器的训练二、生成器的训练利用SRGAN生成图片一、数据集的准备二、数据集的处理三、模型训练源码下载地址

网络构建

一、什么是SRGAN

SRGAN出自论文Photo-RealisticSingleImageSuper-ResolutionUsingaGenerativeAdversarialNetwork。

如果将SRGAN看作一个黑匣子,其主要的功能就是输入一张低分辨率图片,生成高分辨率图片。

该文章提到,普通的超分辨率模型训练网络时只用到了均方差作为损失函数,虽然能够获得很高的峰值信噪比,但是恢复出来的图像通常会丢失高频细节。

SRGAN利用感知损失(perceptualloss)和对抗损失(adversarialloss)来提升恢复出的图片的真实感。

二、生成网络的构建

生成网络的构成如上图所示,生成网络的作用是输入一张低分辨率图片,生成高分辨率图片。:

SRGAN的生成网络由三个部分组成。

1、低分辨率图像进入后会经过一个卷积+RELU函数。

2、然后经过B个残差网络结构,每个残差结构都包含两个卷积+标准化+RELU,还有一个残差边。

3、然后进入上采样部分,在经过两次上采样后,原图的高宽变为原来的4倍,实现分辨率的提升。

前两个部分用于特征提取,第三部分用于提高分辨率。

importmath

importtorch

fromtorchimportnn

classResidualBlock(nn.Module):

def__init__(self,channels):

super(ResidualBlock,self).__init__()

self.conv1=nn.Conv2d(channels,channels,kernel_size=3,padding=1)

self.bn1=nn.BatchNorm2d(channels)

self.prelu=nn.PReLU(channels)

self.conv2=nn.Conv2d(channels,channels,kernel_size=3,padding=1)

self.bn2=nn.BatchNorm2d(channels)

defforward(self,x):

short_cut=x

x=self.conv1(x)

x=self.bn1(x)

x=self.prelu(x)

x=self.conv2(x)

x=self.bn2(x)

returnx+short_cut

classUpsampleBLock(nn.Module):

def__init__(self,in_channels,up_scale):

super(UpsampleBLock,self).__init__()

self.conv=nn.Conv2d(in_channels,in_channels*up_scale**2,kernel_size=3,padding=1)

self.pixel_shuffle=nn.PixelShuffle(up_scale)

self.prelu=nn.PReLU(in_channels)

defforward(self,x):

x=self.conv(x)

x=self.pixel_shuffle(x)

x=self.prelu(x)

returnx

classGenerator(nn.Module):

def__init__(self,scale_factor,num_residual=16):

upsample_block_num=int(math.log(scale_factor,2))

super(Generator,self).__init__()

self.block_in=nn.Sequential(

nn.Conv2d(3,64,kernel_size=9,padding=4),

nn.PReLU(64)

self.blocks=[]

for_inrange(num_residual):

self.blocks.append(ResidualBlock(64))

self.blocks=nn.Sequential(*self.blocks)

self.block_out=nn.Sequential(

nn.Conv2d(64,64,kernel_size=3,padding=1),

nn.BatchNorm2d(64)

self.upsample=[UpsampleBLock(64,2)for_inrange(upsample_block_num)]

self.upsample.append(nn.Conv2d(64,3,kernel_size=9,padding=4))

self.upsample=nn.Sequential(*self.upsample)

defforward(self,x):

x=self.block_in(x)

short_cut=x

x=self.blocks(x)

x=self.block_out(x)

upsample=self.upsample(x+short_cut)

returntorch.tanh(upsample)

三、判别网络的构建

判别网络的构成如上图所示:

SRGAN的判别网络由不断重复的卷积+LeakyRELU和标准化组成。

对于判断网络来讲,它的目的是判断输入图片的真假,它的输入是图片,输出是判断结果。

判断结果处于0-1之间,利用接近1代表判断为真图片,接近0代表判断为假图片。

判断网络的构建和普通卷积网络差距不大,都是不断的卷积对图片进行下采用,在多次卷积后,最终接一次全连接判断结果。

实现代码如下:

classDiscriminator(nn.Module):

def__init__(self):

super(Discriminator,self).__init__()

=nn.Sequential(

nn.Conv2d(3,64,kernel_size=3,padding=1),

nn.LeakyReLU(0.2),

nn.Conv2d(64,64,kernel_size=3,stride=2,padding=1),

nn.BatchNorm2d(64),

nn.LeakyReLU(0.2),

nn.Conv2d(64,128,kernel_size=3,padding=1),

nn.BatchNorm2d(128),

nn.LeakyReLU(0.2),

nn.Conv2d(128,128,kernel_size=3,stride=2,padding=1),

nn.BatchNorm2d(128),

nn.LeakyReLU(0.2),

nn.Conv2d(128,256,kernel_size=3,padding=1),

nn.BatchNorm2d(256),

nn.LeakyReLU(0.2),

nn.Conv2d(256,256,kernel_size=3,stride=2,padding=1),

nn.BatchNorm2d(256),

nn.LeakyReLU(0.2),

nn.Conv2d(256,512,kernel_size=3,padding=1),

nn.BatchNorm2d(512),

nn.LeakyReLU(0.2),

nn.Conv2d(512,512,kernel_size=3,stride=2,padding=1),

nn.BatchNorm2d(512),

nn.LeakyReLU(0.2),

nn.AdaptiveAvgPool2d(1),

nn.Conv2d(512,1024,kernel_size=1),

nn.LeakyReLU(0.2),

nn.Conv2d(1024,1,kernel_size=1)

defforward(self,x):

batch_size=x.size(0)

returntorch.sigmoid((x).view(batch_size))

训练思路

SRGAN的训练可以分为生成器训练和判别器训练:

每一个step中一般先训练判别器,然后训练生成器。

一、判别器的训练

在训练判别器的时候我们希望判别器可以判断输入图片的真伪,因此我们的输入就是真图片、假图片和它们对应的标签。

因此判别器的训练步骤如下:

1、随机选取batch_size个真实高分辨率图片。

2、利用resize后的低分辨率图片,传入到Generator中生成batch_size个虚假高分辨率图片。

3、真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练。

二、生成器的训练

在训练生成器的时候我们希望生成器可以生成极为真实的假图片。因此我们在训练生成器需要知道判别器认为什么图片是真图片。

因此生成器的训练步骤如下:

1、将低分辨率图像传入生成模型,得到虚假高分辨率图像,将虚假高分辨率图像获得判别结果与1进行对比得到loss。(与1对比的意思是,让生成器根据判别器判别的结果进行训练)。

2、将真实高分辨率图像和虚假高分辨率图像传入VGG网络,获得两个图像的特征,通过这两个图像的特征进行比较获得loss

利用SRGAN生成图片

SRGAN的库整体结构如下:

一、数据集的准备

在训练前需要准备好数据集,数据集保存在datasets文件夹里面。

二、数据集的处理

打开txt_anno

温馨提示

  • 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
  • 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
  • 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
  • 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
  • 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
  • 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
  • 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。

最新文档

评论

0/150

提交评论