DCGAN Tutorial

Translator: wangshuai9517

Author: Nathan Inkawhich

Introduction

This tutorial introduces DCGANs through an example. We will train a generative adversarial network (GAN) on many photos of real celebrities to generate new, fake celebrity photos. Most of the code here comes from the DCGAN implementation in pytorch/examples, and this document explains that implementation thoroughly, clarifying how the model works and why it works. Don't worry: no prior knowledge of GANs is required, but it may take some time to reason about what is actually happening under the hood. Also, having one or two GPUs will help a lot with running time. Let's start from the beginning.

Generative Adversarial Networks

What is a GAN?

A generative adversarial network (GAN) is a framework for teaching a deep learning model to capture the distribution of the training data, so that we can generate new data from that same distribution. GANs were invented by Ian Goodfellow in 2014 and first described in the paper Generative Adversarial Nets. They consist of two distinct models: a generator and a discriminator. The generator's job is to produce "fake" images that look like the training images. The discriminator's job is to look at an image and output whether it is a real training image or a fake image from the generator. During training, the generator constantly tries to outsmart the discriminator by producing better and better fakes, while the discriminator gradually gets better at detecting and correctly classifying real and fake images. The equilibrium of this game is reached when the generator produces perfect fakes that look as if they came directly from the training data, and the discriminator is left always guessing at 50% confidence whether the generator's output is real or fake.

Now, let's define some notation to be used throughout the tutorial, starting with the discriminator. Let \(x\) be data representing an image. \(D(x)\) is the discriminator network, which outputs the (scalar) probability that \(x\) came from the training data rather than the generator. Here, the input image has size 3x64x64. Intuitively, \(D(x)\) should be HIGH when \(x\) comes from the training data and LOW when \(x\) comes from the generator.

For the generator's notation, let \(z\) be a latent vector sampled from a standard normal distribution. \(G(z)\) is the generator function, which maps the latent vector \(z\) to data space. The goal of \(G\) is to estimate the distribution that the training data comes from (\(p_{data}\)) so it can generate fake samples from that estimated distribution (\(p_g\)).

\(D(G(z))\) is then the (scalar) probability that the output of the generator \(G\) is a real image. As described in Goodfellow's paper, \(D\) and \(G\) play a minimax game in which \(D\) tries to maximize the probability that it correctly classifies reals and fakes (\(logD(x)\)), and \(G\) tries to minimize the probability that \(D\) predicts its outputs are fake (\(log(1-D(G(z)))\)). From the paper, the GAN loss function is

\[\underset{G}{\text{min}}\ \underset{D}{\text{max}}\ V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[logD(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[log(1-D(G(z)))\big]\]

In theory, the solution to this minimax game is \(p_g = p_{data}\), at which point the discriminator guesses randomly whether its inputs are real or fake. In practice, however, GAN convergence theory is still actively researched and models do not always train to this point.

With \(D\) and \(G\) defined, we can specify how they learn through loss functions and optimizers. We will use the Binary Cross Entropy loss (BCELoss), which PyTorch defines as

\[\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]\]

Notice how this function computes both log components of the objective above (i.e. \(log(D(x))\) and \(log(1-D(G(z)))\)). We can select which part of the BCE equation is used simply through the label \(y\) (i.e. the ground-truth labels); this happens in the training loop coming up.

Next, we define the real label as 1 and the fake label as 0. These labels are used when computing the losses of \(D\) and \(G\), and this is also the convention used in the original GAN paper. Finally, we set up two separate optimizers, one for the discriminator and one for the generator. As specified in the DCGAN paper, both are Adam optimizers with learning rate 0.0002 and Beta1 = 0.5. To keep track of the generator's learning progression, we generate a fixed batch of latent vectors drawn from a Gaussian distribution (i.e. fixed_noise). In the training loop, we will periodically feed this fixed_noise into the generator, and after training is done we will look at the images generated from fixed_noise.
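
As a minimal, self-contained sketch (not part of the original code; the probability value 0.9 is only illustrative), here is how the label picks out the BCE term we want:

import torch
import torch.nn as nn

criterion = nn.BCELoss()
p = torch.tensor([0.9])                              # stand-in for a discriminator output
loss_real = criterion(p, torch.ones(1))              # y = 1 selects -log(p)
loss_fake = criterion(p, torch.zeros(1))             # y = 0 selects -log(1 - p)
print(loss_real.item(), (-torch.log(p)).item())      # both ~0.1054
print(loss_fake.item(), (-torch.log(1 - p)).item())  # both ~2.3026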

# Initialize the BCELoss function
criterion = nn.BCELoss()

# Create a batch of latent vectors that we will use to visualize the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish the convention for real and fake labels to use during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

Training

Finally, now that all parts of the GAN framework are defined, we can train it. Be mindful that training GANs is somewhat of an art form, since incorrect hyperparameter settings lead to mode collapse with little explanation of what went wrong. Here, we will closely follow Algorithm 1 from Goodfellow's paper, while abiding by some of the best practices shown in ganhacks. Namely, we will "construct different mini-batches for real and fake" images, and also adjust G's objective function to maximize \(log(D(G(z)))\). Training is split into two main parts: Part 1 updates the Discriminator and Part 2 updates the Generator.

Part 1 - Train the Discriminator

Recall that the goal of training the discriminator is to maximize the probability of correctly classifying a given input as real or fake. In Goodfellow's words, we wish to "update the discriminator by ascending its stochastic gradient". Practically, we want to maximize \(log(D(x)) + log(1-D(G(z)))\). Following the separate mini-batch suggestion from ganhacks, we compute this in two steps. First, we construct a batch of real samples from the training set, do a forward pass through \(D\), calculate the loss (\(log(D(x))\)), and then compute the gradients in a backward pass. Second, we construct a batch of fake samples with the current generator, forward this batch through \(D\), calculate the loss (\(log(1-D(G(z)))\)), and accumulate the gradients with another backward pass. Now, with the gradients accumulated from both the all-real and all-fake batches, we call a step of the discriminator's optimizer.
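
Written out with the notation above (matching the errD_real and errD_fake terms in the training loop below, before averaging over the mini-batch), the quantity these two backward passes accumulate is

\[\text{errD} = \underbrace{-\log D(x)}_{\text{errD\_real}} + \underbrace{-\log\big(1 - D(G(z))\big)}_{\text{errD\_fake}}\]

so stepping optimizerD on these gradients is a stochastic ascent step on \(log(D(x)) + log(1-D(G(z)))\).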

Part 2 - Train the Generator

As stated in the original paper, we want to train the Generator by minimizing \(log(1-D(G(z)))\) in an effort to generate better fakes. As mentioned, Goodfellow showed this does not provide sufficient gradients, especially early in the learning process. As a fix, we instead want to maximize \(log(D(G(z)))\). In the code we accomplish this by: classifying the Generator output from Part 1 with the Discriminator, computing G's loss using real labels as the ground truth, computing G's gradients in a backward pass, and finally updating G's parameters with an optimizer step. It may seem counter-intuitive to use the real labels as the ground-truth labels for the loss function, but this lets us use the \(log(x)\) part of BCELoss (rather than the \(log(1-x)\) part), which is exactly what we want.
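
In the same notation, the loss actually minimized for the generator in this step is

\[\text{errG} = -\log D(G(z))\]

i.e. the non-saturating objective: minimizing it maximizes \(log(D(G(z)))\) instead of minimizing \(log(1-D(G(z)))\) directly.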

Finally, we will do some statistic reporting, and at the end of each epoch we will push our fixed_noise batch through the generator to visually track the progress of G's training. The training statistics reported are:

  • Loss_D - discriminator loss, computed as the sum of the losses for the all-real and all-fake batches.
  • Loss_G - generator loss.
  • D(x) - the average output of the discriminator for the all-real batch (over the whole batch). This should start close to 1, then theoretically converge to 0.5 as G gets better. Think about why this is.
  • D(G(z)) - the average discriminator output for the all-fake batch. The first number is before D is updated and the second is after D is updated. These numbers should start near 0 and converge to 0.5 as G gets better. Think about why this is.

Note: This step may take a while, depending on how many epochs you run and whether you removed some data from the dataset.

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on the all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of the all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

Out:

Starting Training Loop...
[0/5][0/1583]   Loss_D: 1.7410  Loss_G: 4.7761  D(x): 0.5343    D(G(z)): 0.5771 / 0.0136
[0/5][50/1583]  Loss_D: 1.7332  Loss_G: 25.4829 D(x): 0.9774    D(G(z)): 0.7441 / 0.0000
[0/5][100/1583] Loss_D: 1.6841  Loss_G: 11.6585 D(x): 0.4728    D(G(z)): 0.0000 / 0.0000
[0/5][150/1583] Loss_D: 1.2547  Loss_G: 8.7245  D(x): 0.9286    D(G(z)): 0.5209 / 0.0044
[0/5][200/1583] Loss_D: 0.7563  Loss_G: 8.9600  D(x): 0.9525    D(G(z)): 0.4514 / 0.0003
[0/5][250/1583] Loss_D: 1.0221  Loss_G: 2.5713  D(x): 0.5274    D(G(z)): 0.0474 / 0.1177
[0/5][300/1583] Loss_D: 0.3387  Loss_G: 3.8185  D(x): 0.8431    D(G(z)): 0.1066 / 0.0461
[0/5][350/1583] Loss_D: 0.5054  Loss_G: 3.6141  D(x): 0.7289    D(G(z)): 0.0758 / 0.0535
[0/5][400/1583] Loss_D: 0.8758  Loss_G: 6.5680  D(x): 0.8097    D(G(z)): 0.4017 / 0.0031
[0/5][450/1583] Loss_D: 0.2486  Loss_G: 3.5121  D(x): 0.9035    D(G(z)): 0.1054 / 0.0717
[0/5][500/1583] Loss_D: 1.5792  Loss_G: 4.3590  D(x): 0.3457    D(G(z)): 0.0053 / 0.0379
[0/5][550/1583] Loss_D: 0.8897  Loss_G: 3.9447  D(x): 0.5350    D(G(z)): 0.0349 / 0.0386
[0/5][600/1583] Loss_D: 0.5292  Loss_G: 4.4346  D(x): 0.8914    D(G(z)): 0.2768 / 0.0233
[0/5][650/1583] Loss_D: 0.3779  Loss_G: 4.7253  D(x): 0.7868    D(G(z)): 0.0627 / 0.0174
[0/5][700/1583] Loss_D: 0.7512  Loss_G: 2.6246  D(x): 0.6112    D(G(z)): 0.0244 / 0.1493
[0/5][750/1583] Loss_D: 0.4378  Loss_G: 5.0045  D(x): 0.8614    D(G(z)): 0.2028 / 0.0108
[0/5][800/1583] Loss_D: 0.5795  Loss_G: 6.0537  D(x): 0.8693    D(G(z)): 0.2732 / 0.0066
[0/5][850/1583] Loss_D: 0.8980  Loss_G: 6.5355  D(x): 0.8465    D(G(z)): 0.4226 / 0.0048
[0/5][900/1583] Loss_D: 0.5776  Loss_G: 7.7162  D(x): 0.9756    D(G(z)): 0.3707 / 0.0009
[0/5][950/1583] Loss_D: 0.5593  Loss_G: 5.6692  D(x): 0.9560    D(G(z)): 0.3494 / 0.0080
[0/5][1000/1583]        Loss_D: 0.5036  Loss_G: 5.1312  D(x): 0.7775    D(G(z)): 0.0959 / 0.0178
[0/5][1050/1583]        Loss_D: 0.5192  Loss_G: 4.5706  D(x): 0.8578    D(G(z)): 0.2605 / 0.0222
[0/5][1100/1583]        Loss_D: 0.5645  Loss_G: 3.1618  D(x): 0.7133    D(G(z)): 0.1138 / 0.0768
[0/5][1150/1583]        Loss_D: 0.2790  Loss_G: 4.5294  D(x): 0.8541    D(G(z)): 0.0909 / 0.0207
[0/5][1200/1583]        Loss_D: 0.5334  Loss_G: 4.3445  D(x): 0.8567    D(G(z)): 0.2457 / 0.0245
[0/5][1250/1583]        Loss_D: 0.7318  Loss_G: 2.2779  D(x): 0.6846    D(G(z)): 0.1485 / 0.1497
[0/5][1300/1583]        Loss_D: 0.6939  Loss_G: 6.1172  D(x): 0.9123    D(G(z)): 0.3853 / 0.0041
[0/5][1350/1583]        Loss_D: 0.4653  Loss_G: 3.7054  D(x): 0.8208    D(G(z)): 0.1774 / 0.0404
[0/5][1400/1583]        Loss_D: 1.9711  Loss_G: 3.1569  D(x): 0.2704    D(G(z)): 0.0108 / 0.1390
[0/5][1450/1583]        Loss_D: 0.4427  Loss_G: 5.8683  D(x): 0.9230    D(G(z)): 0.2600 / 0.0056
[0/5][1500/1583]        Loss_D: 0.4432  Loss_G: 3.3681  D(x): 0.8001    D(G(z)): 0.1510 / 0.0633
[0/5][1550/1583]        Loss_D: 0.4852  Loss_G: 3.2790  D(x): 0.7532    D(G(z)): 0.1100 / 0.0661
[1/5][0/1583]   Loss_D: 0.3536  Loss_G: 4.5358  D(x): 0.8829    D(G(z)): 0.1714 / 0.0173
[1/5][50/1583]  Loss_D: 0.4717  Loss_G: 4.7728  D(x): 0.8973    D(G(z)): 0.2750 / 0.0142
[1/5][100/1583] Loss_D: 0.4702  Loss_G: 2.3528  D(x): 0.7847    D(G(z)): 0.1468 / 0.1385
[1/5][150/1583] Loss_D: 0.4833  Loss_G: 2.9645  D(x): 0.7893    D(G(z)): 0.1607 / 0.0867
[1/5][200/1583] Loss_D: 0.6035  Loss_G: 2.0728  D(x): 0.6646    D(G(z)): 0.0852 / 0.1806
[1/5][250/1583] Loss_D: 0.3822  Loss_G: 3.1946  D(x): 0.7969    D(G(z)): 0.1024 / 0.0656
[1/5][300/1583] Loss_D: 0.3892  Loss_G: 3.3337  D(x): 0.7848    D(G(z)): 0.0969 / 0.0525
[1/5][350/1583] Loss_D: 1.7989  Loss_G: 7.5798  D(x): 0.9449    D(G(z)): 0.7273 / 0.0011
[1/5][400/1583] Loss_D: 0.4765  Loss_G: 3.0655  D(x): 0.7479    D(G(z)): 0.1116 / 0.0687
[1/5][450/1583] Loss_D: 0.3649  Loss_G: 3.1674  D(x): 0.8603    D(G(z)): 0.1619 / 0.0627
[1/5][500/1583] Loss_D: 0.6922  Loss_G: 4.5841  D(x): 0.9235    D(G(z)): 0.4003 / 0.0175
[1/5][550/1583] Loss_D: 0.6126  Loss_G: 4.6642  D(x): 0.8761    D(G(z)): 0.3199 / 0.0180
[1/5][600/1583] Loss_D: 0.7032  Loss_G: 4.6221  D(x): 0.9463    D(G(z)): 0.4365 / 0.0154
[1/5][650/1583] Loss_D: 0.4707  Loss_G: 3.3616  D(x): 0.7664    D(G(z)): 0.1280 / 0.0617
[1/5][700/1583] Loss_D: 0.3393  Loss_G: 2.4236  D(x): 0.9120    D(G(z)): 0.1771 / 0.1280
[1/5][750/1583] Loss_D: 0.6828  Loss_G: 4.4585  D(x): 0.8647    D(G(z)): 0.3546 / 0.0191
[1/5][800/1583] Loss_D: 0.7958  Loss_G: 3.6708  D(x): 0.8386    D(G(z)): 0.3987 / 0.0403
[1/5][850/1583] Loss_D: 0.4651  Loss_G: 2.7477  D(x): 0.7602    D(G(z)): 0.1334 / 0.0900
[1/5][900/1583] Loss_D: 0.8799  Loss_G: 4.7930  D(x): 0.9050    D(G(z)): 0.4710 / 0.0201
[1/5][950/1583] Loss_D: 0.3909  Loss_G: 2.7973  D(x): 0.7730    D(G(z)): 0.0902 / 0.0838
[1/5][1000/1583]        Loss_D: 0.3822  Loss_G: 3.0223  D(x): 0.8699    D(G(z)): 0.1837 / 0.0709
[1/5][1050/1583]        Loss_D: 0.4689  Loss_G: 2.2831  D(x): 0.7096    D(G(z)): 0.0536 / 0.1448
[1/5][1100/1583]        Loss_D: 0.6676  Loss_G: 2.2773  D(x): 0.6669    D(G(z)): 0.1386 / 0.1443
[1/5][1150/1583]        Loss_D: 0.5970  Loss_G: 4.1558  D(x): 0.9166    D(G(z)): 0.3554 / 0.0240
[1/5][1200/1583]        Loss_D: 0.3622  Loss_G: 3.5782  D(x): 0.8590    D(G(z)): 0.1547 / 0.0481
[1/5][1250/1583]        Loss_D: 0.5234  Loss_G: 2.5915  D(x): 0.7811    D(G(z)): 0.1990 / 0.1037
[1/5][1300/1583]        Loss_D: 1.3243  Loss_G: 5.5428  D(x): 0.9882    D(G(z)): 0.6572 / 0.0088
[1/5][1350/1583]        Loss_D: 0.4891  Loss_G: 1.9552  D(x): 0.7686    D(G(z)): 0.1540 / 0.1910
[1/5][1400/1583]        Loss_D: 0.5639  Loss_G: 3.7796  D(x): 0.9137    D(G(z)): 0.3390 / 0.0343
[1/5][1450/1583]        Loss_D: 1.7329  Loss_G: 5.0373  D(x): 0.9760    D(G(z)): 0.7332 / 0.0161
[1/5][1500/1583]        Loss_D: 0.7999  Loss_G: 3.7268  D(x): 0.9029    D(G(z)): 0.4550 / 0.0384
[1/5][1550/1583]        Loss_D: 0.4740  Loss_G: 2.3220  D(x): 0.7824    D(G(z)): 0.1625 / 0.1327
[2/5][0/1583]   Loss_D: 0.8693  Loss_G: 3.8890  D(x): 0.9376    D(G(z)): 0.4822 / 0.0339
[2/5][50/1583]  Loss_D: 0.3742  Loss_G: 2.5041  D(x): 0.8148    D(G(z)): 0.1310 / 0.1151
[2/5][100/1583] Loss_D: 1.1134  Loss_G: 1.5167  D(x): 0.4248    D(G(z)): 0.0335 / 0.3023
[2/5][150/1583] Loss_D: 0.5987  Loss_G: 3.2047  D(x): 0.8536    D(G(z)): 0.3121 / 0.0555
[2/5][200/1583] Loss_D: 2.0846  Loss_G: 1.5473  D(x): 0.1919    D(G(z)): 0.0054 / 0.2899
[2/5][250/1583] Loss_D: 0.5017  Loss_G: 3.0225  D(x): 0.8965    D(G(z)): 0.2986 / 0.0626
[2/5][300/1583] Loss_D: 1.3296  Loss_G: 4.1927  D(x): 0.9444    D(G(z)): 0.6574 / 0.0270
[2/5][350/1583] Loss_D: 0.4905  Loss_G: 2.7693  D(x): 0.8049    D(G(z)): 0.2090 / 0.0863
[2/5][400/1583] Loss_D: 0.4668  Loss_G: 2.1790  D(x): 0.7160    D(G(z)): 0.0815 / 0.1529
[2/5][450/1583] Loss_D: 0.4877  Loss_G: 2.4190  D(x): 0.6943    D(G(z)): 0.0693 / 0.1254
[2/5][500/1583] Loss_D: 0.7856  Loss_G: 2.2362  D(x): 0.6148    D(G(z)): 0.1698 / 0.1489
[2/5][550/1583] Loss_D: 0.6371  Loss_G: 1.3879  D(x): 0.6164    D(G(z)): 0.0852 / 0.3041
[2/5][600/1583] Loss_D: 0.6409  Loss_G: 2.8623  D(x): 0.7658    D(G(z)): 0.2684 / 0.0790
[2/5][650/1583] Loss_D: 0.6454  Loss_G: 1.5708  D(x): 0.6293    D(G(z)): 0.0944 / 0.2706
[2/5][700/1583] Loss_D: 0.8472  Loss_G: 2.0847  D(x): 0.5071    D(G(z)): 0.0181 / 0.1937
[2/5][750/1583] Loss_D: 1.2356  Loss_G: 0.3673  D(x): 0.3606    D(G(z)): 0.0328 / 0.7270
[2/5][800/1583] Loss_D: 0.4852  Loss_G: 2.7325  D(x): 0.8670    D(G(z)): 0.2630 / 0.0877
[2/5][850/1583] Loss_D: 0.6494  Loss_G: 4.5357  D(x): 0.8899    D(G(z)): 0.3756 / 0.0158
[2/5][900/1583] Loss_D: 0.5184  Loss_G: 2.7194  D(x): 0.8377    D(G(z)): 0.2540 / 0.0871
[2/5][950/1583] Loss_D: 0.9771  Loss_G: 4.6200  D(x): 0.9596    D(G(z)): 0.5432 / 0.0176
[2/5][1000/1583]        Loss_D: 0.7509  Loss_G: 2.2864  D(x): 0.5861    D(G(z)): 0.1021 / 0.1539
[2/5][1050/1583]        Loss_D: 0.4512  Loss_G: 3.2484  D(x): 0.8649    D(G(z)): 0.2313 / 0.0542
[2/5][1100/1583]        Loss_D: 0.6856  Loss_G: 2.2425  D(x): 0.6405    D(G(z)): 0.1333 / 0.1508
[2/5][1150/1583]        Loss_D: 0.5271  Loss_G: 3.0327  D(x): 0.8385    D(G(z)): 0.2552 / 0.0639
[2/5][1200/1583]        Loss_D: 0.4058  Loss_G: 2.9557  D(x): 0.8769    D(G(z)): 0.2169 / 0.0694
[2/5][1250/1583]        Loss_D: 0.5564  Loss_G: 2.9065  D(x): 0.8409    D(G(z)): 0.2835 / 0.0695
[2/5][1300/1583]        Loss_D: 0.4703  Loss_G: 2.7865  D(x): 0.7825    D(G(z)): 0.1680 / 0.0850
[2/5][1350/1583]        Loss_D: 0.5352  Loss_G: 3.1362  D(x): 0.8260    D(G(z)): 0.2582 / 0.0606
[2/5][1400/1583]        Loss_D: 0.5281  Loss_G: 2.7742  D(x): 0.7970    D(G(z)): 0.2275 / 0.0835
[2/5][1450/1583]        Loss_D: 0.6558  Loss_G: 1.8152  D(x): 0.6103    D(G(z)): 0.0795 / 0.2030
[2/5][1500/1583]        Loss_D: 0.9446  Loss_G: 1.1492  D(x): 0.4593    D(G(z)): 0.0356 / 0.3947
[2/5][1550/1583]        Loss_D: 0.9269  Loss_G: 0.7383  D(x): 0.5226    D(G(z)): 0.1333 / 0.5205
[3/5][0/1583]   Loss_D: 0.4855  Loss_G: 2.1548  D(x): 0.7157    D(G(z)): 0.1059 / 0.1568
[3/5][50/1583]  Loss_D: 0.7259  Loss_G: 1.1093  D(x): 0.5804    D(G(z)): 0.0797 / 0.3894
[3/5][100/1583] Loss_D: 0.7367  Loss_G: 1.0389  D(x): 0.5515    D(G(z)): 0.0405 / 0.4190
[3/5][150/1583] Loss_D: 0.5942  Loss_G: 3.4803  D(x): 0.9290    D(G(z)): 0.3709 / 0.0432
[3/5][200/1583] Loss_D: 1.3464  Loss_G: 0.6549  D(x): 0.3261    D(G(z)): 0.0242 / 0.5949
[3/5][250/1583] Loss_D: 0.5110  Loss_G: 2.2086  D(x): 0.7263    D(G(z)): 0.1327 / 0.1457
[3/5][300/1583] Loss_D: 1.4272  Loss_G: 3.3018  D(x): 0.9230    D(G(z)): 0.6654 / 0.0635
[3/5][350/1583] Loss_D: 0.6491  Loss_G: 3.0766  D(x): 0.8124    D(G(z)): 0.3127 / 0.0607
[3/5][400/1583] Loss_D: 0.5583  Loss_G: 2.9363  D(x): 0.8233    D(G(z)): 0.2759 / 0.0666
[3/5][450/1583] Loss_D: 0.9496  Loss_G: 0.6436  D(x): 0.4958    D(G(z)): 0.1367 / 0.5538
[3/5][500/1583] Loss_D: 0.4463  Loss_G: 2.2234  D(x): 0.7776    D(G(z)): 0.1545 / 0.1371
[3/5][550/1583] Loss_D: 0.5874  Loss_G: 3.6688  D(x): 0.8478    D(G(z)): 0.2930 / 0.0348
[3/5][600/1583] Loss_D: 0.3724  Loss_G: 2.6326  D(x): 0.8673    D(G(z)): 0.1854 / 0.0891
[3/5][650/1583] Loss_D: 0.7292  Loss_G: 4.4254  D(x): 0.9081    D(G(z)): 0.4234 / 0.0200
[3/5][700/1583] Loss_D: 0.4728  Loss_G: 2.8665  D(x): 0.8189    D(G(z)): 0.2115 / 0.0774
[3/5][750/1583] Loss_D: 0.5845  Loss_G: 3.3046  D(x): 0.8977    D(G(z)): 0.3490 / 0.0463
[3/5][800/1583] Loss_D: 0.5597  Loss_G: 2.2564  D(x): 0.7088    D(G(z)): 0.1497 / 0.1300
[3/5][850/1583] Loss_D: 0.6518  Loss_G: 2.5048  D(x): 0.7195    D(G(z)): 0.2183 / 0.1053
[3/5][900/1583] Loss_D: 0.7340  Loss_G: 1.4263  D(x): 0.6285    D(G(z)): 0.1806 / 0.2818
[3/5][950/1583] Loss_D: 1.4633  Loss_G: 4.9204  D(x): 0.9792    D(G(z)): 0.7093 / 0.0143
[3/5][1000/1583]        Loss_D: 0.6643  Loss_G: 2.8332  D(x): 0.8548    D(G(z)): 0.3597 / 0.0751
[3/5][1050/1583]        Loss_D: 0.7741  Loss_G: 2.9355  D(x): 0.7281    D(G(z)): 0.3064 / 0.0712
[3/5][1100/1583]        Loss_D: 0.7279  Loss_G: 3.2299  D(x): 0.8867    D(G(z)): 0.4193 / 0.0544
[3/5][1150/1583]        Loss_D: 0.6049  Loss_G: 1.9150  D(x): 0.6917    D(G(z)): 0.1645 / 0.1912
[3/5][1200/1583]        Loss_D: 0.7431  Loss_G: 3.8188  D(x): 0.9334    D(G(z)): 0.4500 / 0.0306
[3/5][1250/1583]        Loss_D: 0.5061  Loss_G: 1.9905  D(x): 0.7393    D(G(z)): 0.1531 / 0.1653
[3/5][1300/1583]        Loss_D: 0.6979  Loss_G: 3.0183  D(x): 0.8182    D(G(z)): 0.3421 / 0.0616
[3/5][1350/1583]        Loss_D: 0.9133  Loss_G: 4.0629  D(x): 0.9198    D(G(z)): 0.5131 / 0.0261
[3/5][1400/1583]        Loss_D: 0.7075  Loss_G: 4.0061  D(x): 0.9188    D(G(z)): 0.4216 / 0.0266
[3/5][1450/1583]        Loss_D: 0.7704  Loss_G: 2.3802  D(x): 0.7555    D(G(z)): 0.3348 / 0.1114
[3/5][1500/1583]        Loss_D: 0.6055  Loss_G: 1.8402  D(x): 0.7011    D(G(z)): 0.1643 / 0.1995
[3/5][1550/1583]        Loss_D: 0.7240  Loss_G: 3.2589  D(x): 0.8747    D(G(z)): 0.4069 / 0.0528
[4/5][0/1583]   Loss_D: 0.8162  Loss_G: 2.8040  D(x): 0.8827    D(G(z)): 0.4435 / 0.0870
[4/5][50/1583]  Loss_D: 0.5859  Loss_G: 2.2796  D(x): 0.6782    D(G(z)): 0.1312 / 0.1309
[4/5][100/1583] Loss_D: 0.6655  Loss_G: 3.5365  D(x): 0.8178    D(G(z)): 0.3262 / 0.0394
[4/5][150/1583] Loss_D: 1.8662  Loss_G: 5.4950  D(x): 0.9469    D(G(z)): 0.7590 / 0.0113
[4/5][200/1583] Loss_D: 0.7060  Loss_G: 3.6253  D(x): 0.9215    D(G(z)): 0.4316 / 0.0364
[4/5][250/1583] Loss_D: 0.5589  Loss_G: 2.1394  D(x): 0.7108    D(G(z)): 0.1513 / 0.1548
[4/5][300/1583] Loss_D: 0.7278  Loss_G: 1.2391  D(x): 0.5757    D(G(z)): 0.0987 / 0.3454
[4/5][350/1583] Loss_D: 0.7597  Loss_G: 2.8481  D(x): 0.7502    D(G(z)): 0.3094 / 0.0843
[4/5][400/1583] Loss_D: 0.6167  Loss_G: 2.2143  D(x): 0.6641    D(G(z)): 0.1315 / 0.1405
[4/5][450/1583] Loss_D: 0.6234  Loss_G: 1.7961  D(x): 0.7303    D(G(z)): 0.2208 / 0.2007
[4/5][500/1583] Loss_D: 0.6098  Loss_G: 4.9416  D(x): 0.9442    D(G(z)): 0.3978 / 0.0104
[4/5][550/1583] Loss_D: 0.6570  Loss_G: 3.6935  D(x): 0.9180    D(G(z)): 0.4015 / 0.0312
[4/5][600/1583] Loss_D: 0.4195  Loss_G: 2.3446  D(x): 0.7798    D(G(z)): 0.1319 / 0.1211
[4/5][650/1583] Loss_D: 0.5291  Loss_G: 2.5303  D(x): 0.7528    D(G(z)): 0.1875 / 0.1075
[4/5][700/1583] Loss_D: 0.5187  Loss_G: 2.0350  D(x): 0.7174    D(G(z)): 0.1431 / 0.1547
[4/5][750/1583] Loss_D: 0.8208  Loss_G: 1.0780  D(x): 0.5665    D(G(z)): 0.1128 / 0.3844
[4/5][800/1583] Loss_D: 0.5223  Loss_G: 3.0140  D(x): 0.8708    D(G(z)): 0.2871 / 0.0612
[4/5][850/1583] Loss_D: 2.9431  Loss_G: 1.0175  D(x): 0.0914    D(G(z)): 0.0162 / 0.4320
[4/5][900/1583] Loss_D: 0.5456  Loss_G: 1.7923  D(x): 0.7489    D(G(z)): 0.1972 / 0.2038
[4/5][950/1583] Loss_D: 0.4718  Loss_G: 2.3825  D(x): 0.7840    D(G(z)): 0.1772 / 0.1172
[4/5][1000/1583]        Loss_D: 0.5174  Loss_G: 2.5070  D(x): 0.8367    D(G(z)): 0.2556 / 0.1074
[4/5][1050/1583]        Loss_D: 0.8214  Loss_G: 0.8055  D(x): 0.5181    D(G(z)): 0.0694 / 0.4963
[4/5][1100/1583]        Loss_D: 1.3243  Loss_G: 0.7562  D(x): 0.3284    D(G(z)): 0.0218 / 0.5165
[4/5][1150/1583]        Loss_D: 0.9334  Loss_G: 5.1260  D(x): 0.8775    D(G(z)): 0.4817 / 0.0088
[4/5][1200/1583]        Loss_D: 0.5141  Loss_G: 2.7230  D(x): 0.8067    D(G(z)): 0.2188 / 0.0872
[4/5][1250/1583]        Loss_D: 0.6007  Loss_G: 1.9893  D(x): 0.6968    D(G(z)): 0.1667 / 0.1748
[4/5][1300/1583]        Loss_D: 0.4025  Loss_G: 2.3066  D(x): 0.8101    D(G(z)): 0.1471 / 0.1412
[4/5][1350/1583]        Loss_D: 0.5979  Loss_G: 3.2825  D(x): 0.8248    D(G(z)): 0.3003 / 0.0509
[4/5][1400/1583]        Loss_D: 0.7430  Loss_G: 3.6521  D(x): 0.8888    D(G(z)): 0.4243 / 0.0339
[4/5][1450/1583]        Loss_D: 1.0814  Loss_G: 5.4255  D(x): 0.9647    D(G(z)): 0.5842 / 0.0070
[4/5][1500/1583]        Loss_D: 1.7211  Loss_G: 0.7875  D(x): 0.2588    D(G(z)): 0.0389 / 0.5159
[4/5][1550/1583]        Loss_D: 0.5871  Loss_G: 2.1340  D(x): 0.7332    D(G(z)): 0.1982 / 0.1518

Results

Finally, let's check out how we did. Here we will look at three different results. First, we will see how the losses of D and G changed during training. Second, we will visualize G's output on the fixed_noise batch as training progressed. Third, we will look at a batch of real data next to a batch of fake data from G.

Loss versus training iteration

Below is a plot of the generator's and discriminator's losses versus training iteration.

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

https://pytorch.org/tutorials/_images/sphx_glr_dcgan_faces_tutorial_002.png

Visualization of G's progression

Remember how we saved the generator's output on the fixed_noise batch periodically during training. Now, we can visualize G's training progression with an animation. Press the play button to start the animation.

#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

https://pytorch.org/tutorials/_images/sphx_glr_dcgan_faces_tutorial_003.png

Real Images vs. Fake Images

Finally, let's take a look at some real images and fake images side by side.

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

https://pytorch.org/tutorials/_images/sphx_glr_dcgan_faces_tutorial_004.png

Where to Go Next

We have reached the end of the tutorial, but here are a few things you could explore from here:

  • Train for longer to see how good the results get (a checkpointing sketch follows this list)
  • Modify this model to take a different dataset and, if possible, change the size of the images and the model architecture
  • Check out some other cool GAN projects here
  • Create GANs that generate music
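
For the first item, one way to resume training in a later run is to save and reload the networks and optimizers. This is only a hedged sketch, not part of the tutorial's code; the file name dcgan_checkpoint.pth is arbitrary, and netG, netD, optimizerG, optimizerD are the objects defined above.

import torch

# Save the current state so a later run can pick up where this one left off
# ('dcgan_checkpoint.pth' is just an example path).
torch.save({
    'netG': netG.state_dict(),
    'netD': netD.state_dict(),
    'optG': optimizerG.state_dict(),
    'optD': optimizerD.state_dict(),
}, 'dcgan_checkpoint.pth')

# Later: restore the states before running more epochs of the training loop.
checkpoint = torch.load('dcgan_checkpoint.pth')
netG.load_state_dict(checkpoint['netG'])
netD.load_state_dict(checkpoint['netD'])
optimizerG.load_state_dict(checkpoint['optG'])
optimizerD.load_state_dict(checkpoint['optD'])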

