我要努力工作,加油!

【第一节】初识GAN,GAN是什么意思?GAN的直觉理解,使用 PyTorch 零基础训练一个自己的GAN

		发表于: 2021-02-01 10:12:00 | 已被阅读: 6 | 分类于: GAN
		

GAN 的全称是“Generative Adversarial Networks”,翻译成中文,意思是“生成式对抗网络”。深度学习领域的同学即使没有深入研究过 GAN,应该大多数也听说过这个概念。GAN 非常有意思,举个狭义的例子,训练好的图像 GAN 看起来就像是一个能够作画的机器人,而且它可以根据不同的输入特征做出不同的画。有不少学者称,生成网络才是真正的人工智能网络,由此可见业界对 GAN 的高度评价。

GAN 的“烂大街”形象解释

这里还是啰嗦一下关于 GAN 的烂大街的形象解释:GAN 的训练过程是一个“对抗”过程,谁在对抗呢?可以想象是画作的鉴别者作假者在对抗。

刚开始时,鉴别者和作假者的水平都比较低。

于是,鉴别者开始提升自己的鉴别能力,并且能够轻易鉴别出作假者画出来的“假画”。当然,为了让作假者心服口服,鉴别者要把自己发现“假的地方”告诉作假者。作假者知道自己假画的“不足”后,开始改进,提升自己的作假能力。

在作假者提升自己作假能力期间,鉴别者没闲着,也在不断的提升自己鉴别能力。随着时间推移,鉴别者的鉴别能力越来越强,作假者的作假能力也越来越强,直到最后,作假者终于能够作出在一般人看来与真画无异的画。再回到 GAN 网络,此时我们便称 GAN 网络的训练收敛了。

到这里读者可以考虑两个“意外情况”:

  1. 如果鉴别者只要看到画的作者是作假者,就说它是假画;
  2. 鉴别者每次告诉作假者“假的地方”时,都撒谎了,瞎说的是别的地方假,但是鉴别者在业界很权威,作假者相信了。

此时,GAN 网络还能收敛吗?

事实上,这两个“意外情况”在实际的 GAN 训练过程中经常出现,那么应该怎么深入理解这种“意外情况”,以及如何解决呢?本系列文章在日后再详谈。

设计 GAN 网络

根据前面一小节 GAN 的形象解释,很容易就可以设计出 GAN 网络。以非常容易获得的 mnist 手写数字数据集为例,本文将使用 PyTorch,从头设计一个 GAN 网络并训练之,让机器学会“写数字”。

mnist 数据集示例

mnist 是一组 28x28 的手写数字数据集。

首先,前面提到 GAN 网络一般包括两个网络——鉴别者和作假者,先定义这两个网络:

class Generator(nn.Module):
    def __init__(self, nz=16):
        super(Generator, self).__init__()
        self.nz = nz
        self.fc_1 = nn.Linear(self.nz, 28 * 7, bias=False)
        self.leaky_ReLU_1 = nn.LeakyReLU(0.2, inplace=True)
        self.fc_2 = nn.Linear(28 * 7, 28 * 14, bias=False)
        self.leaky_ReLU_2 = nn.LeakyReLU(0.2, inplace=True)
        self.fc_3 = nn.Linear(28 * 14, 28 * 28, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, z):
        bs = z.size(0)
        z = self.leaky_ReLU_1(self.fc_1(z))
        z = self.leaky_ReLU_2(self.fc_2(z))
        z = self.tanh(self.fc_3(z))
        z = torch.reshape(z, (bs, 1, 28, 28))
        return z

这里称“作假者”为Generator,“鉴别者”为Discriminator。可以看出 Generator 的结构相当简单——由几个全连接层(fc)构成主干网络,最后通过 reshape 层将网路学习到的特征整形为 1x28x28 的手写数字图片,tanh 层的作用则是将特征缩放到 -1.0 到 1.0 的范围内,以保证生成的图像像素值都在合理范围内,如下图所示:

Generator 的结构

现在将注意力放到 Generator 的输入上,不难发现它是一组一维向量,这组向量通常从一个随机分布(比如高斯分布)中随机采样。

为了行文方便,下面称这组一维向量为隐变量

当然了,即使隐变量来自于随机数,它也有自己的具体含义:

我们可以将 Generator 的过程反向来看——即输入为 1x28x28 的手写数字图片,输出为一维向量,这就和通常的图像特征提取网络很像了.

不难理解,隐变量其实就是手写数字的抽象低维特征,它包含着手写数字诸如类别、粗细、字体、倾斜程度等特征信息。因此,Generator 其实就是根据输入的隐变量,生成对应特征的手写数字图片。

弄明白 Generator 后,我们接着定义 Discriminator 网络,一般来说,它的复杂度与 Generator 相当,所以 Discriminator 几乎与 Generator 是对称的,如下图所示:

Discriminator 结构

Discriminator 在最后将所有特征提取为 size=1,并且通过 sigmoid 算子将这个特征映射到 0~1.0 范围内,以方便判断输入的图片是“真”还是“假”。通常来说,我们可以认为 Discriminator 最终输出的值越接近 1.0,输入的图片越“真”,否则就越“假”。这就相当于 Discriminator 对输入图片打分,越“真”的图片得分越高。相应的 Python 代码:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc_3 = nn.Linear(28 * 28, 28 * 14, bias=False)
        self.leaky_ReLU_2 = nn.LeakyReLU(0.2, inplace=True)
        self.fc_2 = nn.Linear(28 * 14, 28 * 7, bias=False)
        self.leaky_ReLU_1 = nn.LeakyReLU(0.2, inplace=True)
        self.fc_1 = nn.Linear(28 * 7, 1, bias=False)

    def forward(self, x):
        bs = x.size(0)
        x = x.view(bs, -1)
        x = self.leaky_ReLU_2(self.fc_3(x))
        x = self.leaky_ReLU_1(self.fc_2(x))
        x = torch.sigmoid(self.fc_1(x))
        return x

准备 MNIST 数据集

在 PyTorch 中使用 MNIST 数据集是非常方便的,使用下面的 Python 代码,可以自动从网络下载到:

class GanDataSet(torch_data.Dataset):
    def __init__(self, nz=16):
        super(GanDataSet, self).__init__()
        if not os.path.exists('./tmp'):
            os.makedirs('./tmp')
        self.mnist = datasets.MNIST(root='./tmp', train=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]), download=True)
        self.nz = nz

    def __len__(self):
        return self.mnist.__len__()

    def __getitem__(self, item):
        real = self.mnist.__getitem__(item)[0]  # item consist of [data, label]
        z = torch.randn(self.nz).float()
        return real, z

为了方便,数据集的每一个 item 由真实的手写数字图片和隐变量组成,隐变量随机从 randn 分布中采样,随后输送给 Generator,用于生成“假的”手写数字图片。

训练 GAN 网络

GAN 的训练过程其实就是 Generator 和 Discriminator 对抗的过程,Generator 需要不断提升自己的生成能力,Discriminator 需要不断提升自己的鉴别能力。这个过程反映到工程上倒是非常简单了,无非就是交替训练 Generator 和 Discriminator 而已。

设真实手写数字图片样本 \( x \sim p(x) \),隐变量 \( z \sim q(z) \),将 Generator 网络看作函数 \( G \),则 \( G(z) \) 即为生成的“假”样本,同样也将 Discriminator 网络看作函数 \( D \),则 \( D(x) \)\( D(G(z)) \) 分别表示真实样本和假样本的得分值。

训练 Discriminator

在训练判别器 Discriminator 时,为了提升 Discriminator 的判别能力,我们需要站在 Discriminator 的角度,自然希望它对真实样本打的分越高越好,对假样本打的分越低越好,即期望下面这个目标越大越好:

\( $\underset{D}{argmax}{\mathbb{E}_{x \sim p(x)), z \sim q(z))}}log(D(x))-log(D(G(z))) \)$

注意,训练 Discriminator 时,我们并不需要作假者 Generator 的其他信息(梯度),只关注它生成的假样本本身。至此,便不难写出对应的 Discriminator 训练代码,请看:

    def train_discriminator(real_x, fake_z):
        for p in G.parameters():
            p.requires_grad = False
        for p in D.parameters():
            p.requires_grad = True
        with torch.no_grad():
            fake_x = G(fake_z)
        real_conf, fake_conf = D(real_x), D(fake_x)

        loss = (torch.log(fake_conf + 1e-6) - torch.log(real_conf + 1e-6)).mean()
        D_optimizer.zero_grad()
        loss.backward()
        D_optimizer.step()

        return loss

还需要注意的是,鉴于 PyTorch 优化器默认最小化 loss,而我们期望能够最大化上述目标,因此上述 Python 代码中的 loss 其实是期望目标的相反数。

训练 Generator

训练作假者 Generator 时,自然是要站在作假者的角度,我们希望生成的假样本越“真”越好,也即“假”样本在鉴别者 Discriminator 那边的得分值越高越好:

\( $\underset{G}{argmax}{\mathbb{E}_{z \sim q(z))}}log(D(G(z))) \)$

不难写出相应的 Python 代码,请看:

    def train_generator(real_x, fake_z):
        for p in D.parameters():
            p.requires_grad = False
        for p in G.parameters():
            p.requires_grad = True

        fake_x = G(fake_z)
        fake_conf = D(fake_x)

        loss = -torch.log(fake_conf + 1e-6).mean()
        G_optimizer.zero_grad()
        loss.backward()
        G_optimizer.step()

        return loss

至此,鉴别者 Discriminator 和作假者 Generator 的训练代码就完成了,结合 MNIST 数据集代码,将整个训练过程串起来即可完成生成模型,完整的代码可以参考:

todo: add link.

训练结果

我们已经知道,GAN 的训练过程是一个对抗过程,在刚开始时,鉴别者 Discriminator 和作假者 Generator 的水平都很差,生成的假样本明显很假,请看:

训练 1 个 epoch

上图为训练 1 个 epoch 后的生成“假”样本,虽然有一点手写数字的感觉,但和真实样本相比有着明显的区别。不过,随着训练的进行,Generator 生成的“假”样本将会越来越真实,下图是训练 50 个 epoch 后的生成“假”样本:

训练 50 个 epoch

可见,随着训练的进行,Generator 生成的“假”样本质量还是明显变好了的。