最近AIGC的爆火,不管是AI绘图还是ChatGPT,都让生成式模型成为了大家关注的焦点。而在目前主流图像生成模型DiffusionNet之前,相信没有人不承认GAN(Generative Adversarial Nets)是生成模型中划时代的作品,以至于当时GAN的衍生模型异常之多。这篇文章就来介绍一个较为著名的GAN的衍生模型——pix2pix[1]。
本文将以以下几个方面来对模型进行介绍:
GAN系列入门介绍
模型结构及基本原理
patchGAN
一些消融实验
一、GAN系列入门介绍
1. GAN:GAN[2]的中文全称叫做生成对抗模型,模型包含两个部分,生成器(Generator)与鉴别器(Discriminator)。简单来说,生成器的作用是生成假图像,鉴别器的作用是来辨别图像真伪,通过两者的对抗,鉴别器不断提高自己的鉴别能力,而生成器不断提高自己的生成能力,最终当鉴别器无法“有信心”地判断输入的真假时,我们也就可以认为生成器已经学会了以假乱真的生成能力。
鉴别器的输入是从数据集中取出的真实图像与生成器生成的图像,输出则是鉴别器认为该输入是真实数据的概率。也就是说,当此输入稳定在0.5附近时,我们可以认为鉴别器难以判断输入是真是假。
而生成器的本质就是一个解码器(decoder),他的输入是一个来自n维标准正态分布的n维向量,通过解码器解码得到与真实数据维度相等的图片,即生成图片。而一旦这个生成器通过对抗的方式训练完成,我们即可以随便选取一个n维标准正态分布向量,解码得到新的图片(数据集中未曾出现过的)。
GAN模型不仅结构简单,原理不复杂,并且可以生成数据集中未曾出现过的图片,因此在很长一段时间(实际上可以说时至今日)都成为了生成模型的主流研究/使用对象,并出现了一系列的变种来解决不同的下游问题。
2. conditional GAN:上面讲到,GAN的输入是一个n维向量,而输出是某类图片,即使这些图片都属于同一类别,有着相似的风格,但我们却无法控制生成出来的数据长什么样子。而如果我们想要让生成出来的数据可控,我们通常需要给他一个额外的输入标签作为指导条件。这类模型一般称作contional model,而基于这种思想衍生出来的GAN模型,被称为conditional GAN(cGAN)[3]。简单来说,在cGAN中,指导条件(称作y)也会编码成向量形式,通过concatenate的方式与随机向量z融合,并放入生成器中生成图像G(z,y)。在鉴别阶段,y依然会作为额外信息,通过多层映射与真实数据x、生成数据G(z,y)融合,形成新的向量,送入鉴别器进行判断。
3. pix2pix:有一类任务叫做image-to-image translation。也就是输入和输出是来自两个不同集合(设为A和B)的图片,且我们一般认为它们是有对应关系的。比如输入黑白照片(A)输出彩色照片(B),输入轮廓照片(A)输出色彩填充照片(B)等(如图1),本文介绍的pix2pix模型所处理的就是这类任务。并且原文作者通过一系列实验,证明了conditional GAN在这类问题上的有效性,也就是说,pix2pix本质上是一种特殊的conditional GAN。
二、模型结构及基本原理
图2给出了模型的基本结构图,其中G为生成器,D为鉴别器。由于我们的输入是图像而非低维向量,因此G不再是一个简单的解码器,而是一个编码-解码的结构(encoder-decoder)。近些年来,编码-解码结构用的最多的就是U-Net[4],在传统的编码-解码结构上添加了skip-connection结构,将encode过程中卷积得到的不同尺寸的特征图,直接concatenate到decode过程中相应尺寸的特征图上,这样避免了一些特征在下采样过程中的损失,尽可能的保留了原始图像在不同尺寸上的特征信息。
鉴别器D也区别于传统GAN的鉴别器,使用的叫做Patch Discriminator,这个部分将在第三节进行详细讲解。注意,在pix2pix模型中,G与D都会看到输入x(即图2中的轮廓图),在G中,x作为输入来通过编码-解码结构获得G(x),而在D中,x作为指导条件(conditions)来辅助鉴别器进行判断。所以pix2pix本质上就是一个cGAN。
有人可能会问,如果我们不让鉴别器看到x,只让x作为输入进行编码,模型会变差吗?由于鉴别器中不加入x,更像传统GAN(虽然生成器但从解码器变成了编码解码结构),这个问题也可以转换成:在这个任务中,cGAN真的要强于GAN吗,加入condition真的有提升吗?
关于这个问题,作者做了消融实验,并验证了cGAN相比于GAN确实表现更出色。本文将在第四节给出消融实验的结果。
这里面还有一个问题,由于我们使用了encoder-decoder结构的生成器,这样的话,由于训练完成后,模型参数不再变化,这会使得任一确定的输入图像都会被编码成对应的确定的向量,再通过参数固定的解码器,输出图像也将确定(deterministic)。这样就失去了生成模型的随机性。因此参考传统GAN模型,我们需要引进服从标准正态分布的随机向量z来增添其随机性。于是我们看到的损失函数公式如下
其中生成器生成的图像是G(x,z)而非G(x),这里面的z就是一个随机向量。另外,在过去使用的cGAN中,z一般是作为额外的输入向量输入到模型中,但pix2pix的作者通过实验发现,在pix2pix模型里z作为输入向量效果并不好,模型会很容易地学会如何忽视掉这个随机向量。因此,作者将z作为每层网络的dropout的形式加到了模型中来增加随机性。作者提到,即使这么做,模型的随机性依然不好,因此作者认为如何通过cGAN生成随机性很强的输出,将是未来的研究方向之一。
最后,pix2pix模型的损失函数共有两部分组成,上面列出的只是GAN loss这个部分,由于我们不仅希望输出的图片“看起来真”,还要让输出G(x,z)在构图结构及细节上更贴近目标图像y。因此,我们还需要引入像素级别的损失函数,来让对应像素的值尽可能接近。这类损失函数使用最多的就是L1和L2损失。于是最终损失函数如下:
对于损失函数,这里面有一点需要简要拓展一下:无论单独使用是L1损失还是L2损失,都会使得结果偏向模糊。这是因为这两种损失函数均是对对应像素差取均值,这样的话会使得输出的像素分布更加平缓,从而只能很好的保留低频信息,却无法生成准确的高频信息,因此从视觉感受上会明显感觉出差异(一眼模糊)。但是好在我们有GAN损失函数,他是专门处理“看着不像”的问题的,因此L1+cGAN的损失函数可以最大程度还原我们想要的图像。更多对比实验将在第四节展示。
三、patchGAN
上一节提到,与传统GAN中的鉴别器将整张图片映射到一个标量概率值不同,pix2pix是先将图像打成N×N的patches,再将每个patch送到鉴别器中进行判别,最后取得判别的均值作为最终结果。这种方法并非pix2pix首创,而在一种名叫Markovian GAN[5]的模型中已经开始使用。下面我们通过两个方面来对这个方法的细节做一些粗浅的解释:
1. 为什么要用patch discriminator?
其实图3一幅图就可以很简单的阐明这个方法的原理。假设图3中坐标系上的每一个点表示一个图像,蓝色点表示输入数据点,红色点表示输出数据点(也就是来自我们想要得到的图像空间)。由于一般统计算法做了图像分布多为标准正态分布的假设前提,那么就如图3中的第一个图所示,当整体分布已经拟合的很好的时候(可以看到红色圈与蓝色圈基本重合),模型就会停止学习。但实际上,现实中大多数数据分布并不是服从正态分布,而是更为复杂的分布,于是对抗学习就起了大作用,它会不断的让红点与蓝点重合,让输出分布尽量拟合输入分布,这也就是图3第二第三张图所示部分。
而[5]作者将图像分成patches再去做鉴别,相当于是对图像空间/分布进行进一步的细化,这使得输入与输出的图像分布可以更进一步、更细化地拟合,从而得到更好的效果。
另外,由于鉴别器对patches进行判别,输入尺寸大幅减小,因此参数量也大幅降低,模型运行速度也随之大幅提高,这样我们就可以在保证效率的前提下处理任意大的图片。
2. 为什么敢用patch discriminator?
有人可能会问,把整图打成patch,会不会影响模型对整体的把握,会不会丢失全局信息?
[1]中作者提出,由于L1损失已经能够很好地保留低频信息,也就是说,就算不加cGAN损失,我们也已经能够得到很好的色块分布与结构相似度。因此,我们可以大胆的使用patch discriminator,并且把patch当鉴别器输入后,模型可以学习更清晰、精确的高频特征,与L1损失达成互补,使得输出更加精确。
如果想更形象地表示,你可以认为使用L1损失生成的模糊图像,是老花眼人士眼中的图像,虽然非常模糊,但基本可以看清大体轮廓和色彩分布。而patchGAN在这里起到的是放大镜的作用,不断移动放大镜,可以看到不同位置的细节。由于虽然没有放大镜看不清,但我们已经有了大体轮廓,因此当使用放大镜看到每一处局部细节后,我们就可以想象出图像整体清晰的样子。
另外,形成共识的是,图像可以看作Markovian Random Field(MRF),也就是某点像素只与其边上的像素强关联,与远处的像素没有很强的关联性。基于这个先验知识,我们便可以大胆的将图像打成patch进行学习/鉴别,从而不会影响结果。关于MRF的相关知识,请读者自行学习。
四、一些消融实验
1. 使用L1+cGAN损失比单独使用L1/cGAN都要好
图4中,从左至右分别代表:输入,目标图像,只用L1,只用cGAN,L1+cGAN。可以看到只用L1的话,图像确实十分模糊,但是大体的色块分布与结构信息已经学到了。而单独使用cGAN,高频信号非常多(即颜色突变多),图片整体锐化程度过大。
2. cGAN比GAN(不让鉴别器看到输入)要好
图5中可以看到,即使单独使用L1,也要比单独使用GAN好很多,而cGAN更是明显优于GAN。
3. U-Net比传统的encoder-decoder结构要好
图6的结果表明,无论使用哪种损失函数,Unet的结果均要好于传统的encoder-decoder模型,这说明Unet中的skip-connection确实将下采样过程中丢失的信息保存了下来并提高了上采样的精度。另外原文也给了数值指标结果,以L1+cGAN损失为例,U-net的per-pixel,per-class的准确度分别为0.55与0.20,均优于encoder-decoder的0.29与0.09。
4. 不同尺寸的patch下,patch GAN学习的能力表现
图7列出了不同patch尺寸下的学习结果,第一张图只使用了L1,结果相当模糊,后四张均用的是L1+patchGAN损失。第二张图使用的是1×1大小的patch,其实就是一个像素,因此也叫pixelGAN,可以看到车的色彩发生了变化,说明即使看不到邻近信息,GAN依然学会了加强色彩变幻。16×16与70×70都或多或少加入了一些细节,无论从数值结果还是视觉效果,都看得出来70×70效果最优。286×286相当于把整图送进去学习,因此也称作ImageGAN,可以看到视觉效果依然不如70×70,主要由于特征学习的过于全局化,导致局部细节并不和谐。
五、写在最后
pix2pix是我研究生毕业课题中使用的模型,用在了医学影像相关领域的研究上,这也侧面说明了这个模型的泛用性。正如作者所说,在pix2pix模型之前,cGAN其实已经广泛用于各种生成任务中,比如图像修复(inpaiting)、风格迁移(style transfer)、提高分辨率(superresolution)等任务上。但作者认为他们最大的贡献,是构建了一个通用模型,可以在多种任务上取得优异成绩。其实这个模型就是U-Net、MGAN、cGAN的一个整合,但是却有将近2w的引用量。这也印证了那句话,成功有时候真的可能只是因为站在巨人的肩膀上。
引用:
[1] https://arxiv.org/abs/1611.07004
[2] https://arxiv.org/abs/1406.2661
[3] https://arxiv.org/abs/1411.1784
[4] https://arxiv.org/abs/1505.04597
[5] https://arxiv.org/abs/1604.04382