Exemple #1
0
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train one discriminator/generator pair on a single pyramid scale.

    The scale index is ``len(Gs)``; the training target is
    ``reals[len(Gs)]``.  Each epoch runs ``opt.Dsteps`` discriminator
    updates (real term, fake term and gradient penalty) followed by
    ``opt.Gsteps`` generator updates (adversarial term plus, when
    ``opt.alpha != 0``, an MSE reconstruction term weighted by
    ``opt.alpha``).

    Args:
        netD: Discriminator, called as ``netD(image)``.
        netG: Generator, called as ``netG(noise, prev)``.
        reals: Indexable collection of real images; only
            ``reals[len(Gs)]`` is read here, the whole object is forwarded
            to ``draw_concat``.
        Gs: List of generators from earlier scales (forwarded to
            ``draw_concat``).  An empty list selects the bottom-scale code
            path unless ``opt.mode == 'SR_train'``.
        Zs: Forwarded unchanged to ``draw_concat``.
        in_s: Input for ``draw_concat``; on the bottom scale it is replaced
            by an all-zero tensor of shape ``[1, nc_z, nzx, nzy]``.
        NoiseAmp: Forwarded unchanged to ``draw_concat``.
        opt: Options object.  Reads ``ker_size``, ``num_layer``, ``stride``,
            ``mode``, ``alpha``, ``nc_z``, ``device``, ``lr_d``, ``lr_g``,
            ``beta1``, ``gamma``, ``niter``, ``Dsteps``, ``Gsteps``,
            ``lambda_grad``, ``noise_amp_init`` and ``outf``; writes
            ``nzx``, ``nzy``, ``receptive_field`` and ``noise_amp``.
        centers: Only used when ``opt.mode == 'paint_train'``, where it is
            passed to ``functions.quant2centers``.

    Returns:
        Tuple ``(z_opt, in_s, netG)``.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Noise is generated at the already-padded size, so it gets no
        # extra zero padding in this mode.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # z_opt: input noise for calculating reconstruction loss in Generator.
    # fixed_noise is only used for its shape; the call is kept so the random
    # number stream is consumed exactly as before.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    # torch.zeros instead of torch.full(..., 0): an integer fill value
    # without an explicit dtype is rejected / inferred as int64 by newer
    # PyTorch releases, whereas a floating-point tensor is needed here.
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    # TODO: these plots are not visualized
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    # Started once, outside the loop, so the periodic report below covers
    # the whole interval since the previous report (it is reset there).
    start_time = time.time()

    for epoch in range(opt.niter):
        if (Gs == []) and (opt.mode != 'SR_train'):
            # Bottom generator, here without zero init: z_opt is re-sampled
            # every epoch as one noise channel expanded to 3 channels.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            # noise_: random noise used to build this epoch's fake sample
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) - D(G(z)) (minus the
        #     gradient penalty)
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) and (epoch == 0):
                # Initialize prev and z_prev
                # prev: image outputs from previous level
                # z_prev: image outputs from previous level of fixed noise z_opt
                if (Gs == []) and (opt.mode != 'SR_train'):
                    # Bottom scale: there is no previous level, so in_s,
                    # prev and z_prev are all-zero tensors.
                    prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                       device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                         device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    # Noise amplitude scales with the RMSE between the real
                    # image and the 'rec' output of the previous levels.
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if (Gs == []) and (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            # fake is detached so this backward pass only reaches netD.
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = errD_fake.item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        # NOTE(review): `fake` comes from the last D step and is not
        # regenerated between G steps, so from the second G step on the
        # backward pass runs through a graph built before optimizerG.step().
        # Recent PyTorch versions may reject that as an in-place
        # modification -- confirm against the targeted version.
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch > 1 and (epoch % 100 == 0 or epoch == (opt.niter - 1)):
            total_time = time.time() - start_time
            start_time = time.time()
            print('scale %d:[%d/%d], total time: %f' %
                  (len(Gs), epoch, opt.niter, total_time))
            memory = torch.cuda.max_memory_allocated()
            print('allocated memory: %.03f GB' % (memory / 1024 ** 3))

        if epoch == (opt.niter - 1):
            # Final epoch: dump sample images and the reconstruction noise.
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/prev.png' % (opt.outf),
                       functions.convert_image_np(prev),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/prev_plus_noise.png' % (opt.outf),
                       functions.convert_image_np(noise),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/z_prev.png' % (opt.outf),
                       functions.convert_image_np(z_prev),
                       vmin=0,
                       vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        # Schedulers advance after the optimizers have stepped this epoch.
        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
Exemple #2
0
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train one discriminator/generator pair on a single pyramid scale.

    The scale index is ``len(Gs)``; the training target is
    ``reals[len(Gs)]``.  Each epoch runs ``opt.Dsteps`` discriminator
    updates (real term, fake term and gradient penalty) followed by
    ``opt.Gsteps`` generator updates (adversarial term plus, when
    ``opt.alpha != 0``, an MSE reconstruction term weighted by
    ``opt.alpha``).  Progress is printed every 25 epochs and sample images
    plus ``z_opt`` are written to ``opt.outf`` every 500 epochs and on the
    last epoch.

    Args:
        netD: Discriminator, called as ``netD(image)``.
        netG: Generator, called as ``netG(noise, prev)``.
        reals: Indexable collection of real images; only
            ``reals[len(Gs)]`` is read here, the whole object is forwarded
            to ``draw_concat``.
        Gs: List of generators from earlier scales (forwarded to
            ``draw_concat``).  An empty list selects the bottom-scale code
            path unless ``opt.mode == 'SR_train'``.
        Zs: Forwarded unchanged to ``draw_concat``.
        in_s: Input for ``draw_concat``; on the bottom scale it is replaced
            by an all-zero tensor of shape ``[1, nc_z, nzx, nzy]``.
        NoiseAmp: Forwarded unchanged to ``draw_concat``.
        opt: Options object.  Reads ``ker_size``, ``num_layer``, ``stride``,
            ``mode``, ``alpha``, ``nc_z``, ``device``, ``lr_d``, ``lr_g``,
            ``beta1``, ``gamma``, ``niter``, ``Dsteps``, ``Gsteps``,
            ``lambda_grad``, ``noise_amp_init`` and ``outf``; writes
            ``nzx``, ``nzy``, ``receptive_field`` and ``noise_amp``.
        centers: Only used when ``opt.mode == 'paint_train'``, where it is
            passed to ``functions.quant2centers``.

    Returns:
        Tuple ``(z_opt, in_s, netG)``.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Noise is generated at the already-padded size, so it gets no
        # extra zero padding in this mode.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # fixed_noise is only used for its shape; the call is kept so the random
    # number stream is consumed exactly as before.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
    # torch.zeros instead of torch.full(..., 0): an integer fill value
    # without an explicit dtype is rejected / inferred as int64 by newer
    # PyTorch releases, whereas a floating-point tensor is needed here.
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    # Per-epoch histories; collected but not consumed in this function.
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if (Gs == []) and (opt.mode != 'SR_train'):
            # Bottom scale: z_opt and noise_ are re-sampled every epoch as
            # one noise channel expanded to 3 channels.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy])
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) - D(G(z)) (minus the
        #     gradient penalty)
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) and (epoch == 0):
                # First D step of the first epoch: initialise prev / z_prev
                # and the noise amplitude for this scale.
                if (Gs == []) and (opt.mode != 'SR_train'):
                    # Bottom scale: there is no previous level, so in_s,
                    # prev and z_prev are all-zero tensors.
                    prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                       device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                         device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    # Noise amplitude scales with the RMSE between the real
                    # image and the 'rec' output of the previous levels.
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if (Gs == []) and (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            # fake is detached so this backward pass only reaches netD.
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = errD_fake.item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        # NOTE(review): `fake` comes from the last D step and is not
        # regenerated between G steps, so from the second G step on the
        # backward pass runs through a graph built before optimizerG.step().
        # Recent PyTorch versions may reject that as an in-place
        # modification -- confirm against the targeted version.
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        # Step the schedulers only after the optimizers have stepped for
        # this epoch; stepping them first shifts the LR schedule by one
        # epoch on PyTorch >= 1.1.
        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
Exemple #3
0
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train one discriminator/generator pair on a single pyramid scale.

    The scale index is ``len(Gs)``; the training target is
    ``reals[len(Gs)]``.  Each epoch runs ``opt.Dsteps`` discriminator
    updates (real term, fake term and gradient penalty) followed by
    ``opt.Gsteps`` generator updates (adversarial term plus, when
    ``opt.alpha != 0``, an MSE reconstruction term weighted by
    ``opt.alpha``).

    Every 25 epochs (and on the last one) a timestamp and the result of
    ``memory_check()`` are appended to the module-level ``timestamp``,
    ``full_memory`` and ``full_time`` lists.  After training, GPU memory
    usage of device 0 is queried through ``nvidia_smi`` and returned.

    Args:
        netD: Discriminator, called as ``netD(image)``.
        netG: Generator, called as ``netG(noise, prev)``.
        reals: Indexable collection of real images; only
            ``reals[len(Gs)]`` is read here, the whole object is forwarded
            to ``draw_concat``.
        Gs: List of generators from earlier scales (forwarded to
            ``draw_concat``).  An empty list selects the bottom-scale code
            path unless ``opt.mode == 'SR_train'``.
        Zs: Forwarded unchanged to ``draw_concat``.
        in_s: Input for ``draw_concat``; on the bottom scale it is replaced
            by an all-zero tensor of shape ``[1, nc_z, nzx, nzy]``.
        NoiseAmp: Forwarded unchanged to ``draw_concat``.
        opt: Options object.  Reads ``ker_size``, ``num_layer``, ``stride``,
            ``mode``, ``alpha``, ``nc_z``, ``device``, ``lr_d``, ``lr_g``,
            ``beta1``, ``gamma``, ``niter``, ``Dsteps``, ``Gsteps``,
            ``lambda_grad``, ``noise_amp_init`` and ``outf``; writes
            ``nzx``, ``nzy``, ``receptive_field`` and ``noise_amp``.
        centers: Only used when ``opt.mode == 'paint_train'``, where it is
            passed to ``functions.quant2centers``.

    Returns:
        Tuple ``(z_opt, in_s, netG, mbs, percent)`` where ``mbs`` is the
        used memory of GPU 0 in MiB and ``percent`` is the used fraction
        (0..1) of its total memory.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Noise is generated at the already-padded size, so it gets no
        # extra zero padding in this mode.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # fixed_noise is only used for its shape; the call is kept so the random
    # number stream is consumed exactly as before.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    # torch.zeros instead of torch.full(..., 0): an integer fill value
    # without an explicit dtype is rejected / inferred as int64 by newer
    # PyTorch releases, whereas a floating-point tensor is needed here.
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    # Per-epoch histories; collected but not consumed in this function.
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):

        if (Gs == []) and (opt.mode != 'SR_train'):
            # Bottom scale: z_opt and noise_ are re-sampled every epoch as
            # one noise channel expanded to 3 channels.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) - D(G(z)) (minus the
        #     gradient penalty)
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) and (epoch == 0):
                # First D step of the first epoch: initialise prev / z_prev
                # and the noise amplitude for this scale.
                if (Gs == []) and (opt.mode != 'SR_train'):
                    # Bottom scale: there is no previous level, so in_s,
                    # prev and z_prev are all-zero tensors.
                    prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                       device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                         device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    # Noise amplitude scales with the RMSE between the real
                    # image and the 'rec' output of the previous levels.
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if (Gs == []) and (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            # fake is detached so this backward pass only reaches netD.
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = errD_fake.item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        # NOTE(review): `fake` comes from the last D step and is not
        # regenerated between G steps, so from the second G step on the
        # backward pass runs through a graph built before optimizerG.step().
        # Recent PyTorch versions may reject that as an in-place
        # modification -- confirm against the targeted version.
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):

            stamp = datetime.datetime.now()
            timestamp.append(stamp)

            # NOTE(review): timestamp[-2] requires the module-level list to
            # already hold an entry before the first report -- confirm the
            # caller seeds it.  `.seconds` is the whole-seconds component
            # of the timedelta, so sub-second precision is dropped.
            delta = (timestamp[-1] - timestamp[-2]).seconds
            mbs, percent = memory_check()
            print('scale %d:[%d/%d] | Mb: %.3f | Percent %.3f | secs %.3f ' %
                  (len(Gs), epoch, opt.niter, mbs, 100 * percent, delta))
            full_memory.append([mbs, percent])
            full_time.append(delta)

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        # Schedulers advance after the optimizers have stepped this epoch.
        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)

    nvidia_smi.nvmlInit()
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
    # card id 0 hardcoded here, there is also a call to get all available card ids, so we could iterate

    mem_res = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    # used / 1024**2 converts bytes to MiB (the label previously said GiB).
    mbs = mem_res.used / (1024**2)
    percent = mem_res.used / mem_res.total
    print(f'mem: {mbs} (MiB)')  # usage in MiB
    print(f'mem: {100 * percent:.6f}%')  # percentage
    return z_opt, in_s, netG, mbs, percent
Exemple #4
0
 opt.mode = 'paint_train'
 dir2trained_model = functions.generate_dir2save(opt)
 # N = len(reals) - 1
 # n = opt.paint_start_scale
 real_s = imresize(real, pow(opt.scale_factor, (N - n)), opt)
 real_s = real_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
 real_quant, centers = functions.quant(real_s, opt.device)
 plt.imsave('%s/real_quant.png' % dir2save,
            functions.convert_image_np(real_quant),
            vmin=0,
            vmax=1)
 plt.imsave('%s/in_paint.png' % dir2save,
            functions.convert_image_np(in_s),
            vmin=0,
            vmax=1)
 in_s = functions.quant2centers(ref, centers)
 in_s = imresize(in_s, pow(opt.scale_factor, (N - n)), opt)
 # in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
 # in_s = imresize(in_s, 1 / opt.scale_factor, opt)
 in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
 plt.imsave('%s/in_paint_quant.png' % dir2save,
            functions.convert_image_np(in_s),
            vmin=0,
            vmax=1)
 if (os.path.exists(dir2trained_model)):
     # print('Trained model does not exist, training SinGAN for SR')
     Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(
         opt)
     opt.mode = 'paint2image'
 else:
     train_paint(opt, Gs, Zs, reals, NoiseAmp, centers,
Exemple #5
0
def train_single_scale(netD,
                       netG,
                       reals,
                       crops,
                       masks,
                       eye_color,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train the generator/discriminator pair of one pyramid scale.

    The scale being trained is ``len(Gs)``.  For ``opt.niter`` epochs the
    function runs ``opt.Dsteps`` discriminator updates followed by
    ``opt.Gsteps`` generator updates, periodically writing debug images and
    ``z_opt.pth`` to ``opt.outf``, and finally calls
    ``functions.save_networks``.

    Args:
        netD: Discriminator of this scale (optimised with Adam, ``opt.lr_d``).
        netG: Generator of this scale, called as ``netG(input, prev)``.
        reals: Per-scale real images; ``reals[len(Gs)]`` is used here.
        crops: Per-scale crops; ``crops[len(Gs)].size()[2]`` is the crop size.
        masks: Per-scale masks; ``masks[len(Gs)]`` is this scale's mask and
            ``masks[-1]`` is passed to ``functions.generate_eye_mask``.
        eye_color: Forwarded to ``functions.gen_fake``; replaced by
            ``functions.get_eye_color(real)`` when ``opt.random_eye_color``.
        Gs, Zs, NoiseAmp: Forwarded to ``functions.draw_concat``.
        in_s: Forwarded to ``functions.draw_concat``; replaced by a zero
            tensor when training the first scale outside ``'SR_train'``.
        opt: Options object.  This function also *writes* ``opt.nzx``,
            ``opt.nzy``, ``opt.receptive_field`` and ``opt.noise_amp``.
        centers: Only used in ``'paint_train'`` mode (``quant2centers``).

    Returns:
        Tuple ``(z_opt, in_s, netG)``.  When ``len(Gs) == opt.stop_scale``
        the returned generator is a snapshot taken at the last discriminator
        step of the last epoch, i.e. before that epoch's generator updates.
    """
    # Function-scope import: only the end-of-training snapshot needs it.
    import copy

    real_fullsize = reals[len(Gs)]

    crop_size = crops[len(Gs)].size()[2]
    if opt.random_crop:
        real, _, _ = functions.random_crop(real_fullsize.clone(), crop_size)
    else:
        real = real_fullsize.clone()
    mask = masks[len(Gs)]

    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer) width
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer) height
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Noise maps are generated at the enlarged size instead of padded.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # True when this is the first scale and we are not in 'SR_train' mode.
    first_scale = (Gs == []) and (opt.mode != 'SR_train')

    # fixed_noise is only used for its shape; z_opt starts as padded zeros.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    # Snapshot of netG taken once, at the last D step of the last epoch.
    # (The original deep-copied the generator on *every* D step and kept only
    # the last copy; taking it once yields the same state at far lower cost.)
    netG_copy = None

    eye = functions.generate_eye_mask(opt, masks[-1], opt.stop_scale - len(Gs))

    for epoch in range(opt.niter):

        if opt.resize:
            # Resample mask and eye mask to a random square patch size
            # between 0.75x and 1.25x of the mask height, capped by the
            # image dimensions.
            max_patch_size = int(
                min(real.size()[2],
                    real.size()[3],
                    mask.size()[2] * 1.25))
            min_patch_size = int(max(mask.size()[2] * 0.75, 1))
            patch_size = random.randint(min_patch_size, max_patch_size)
            mask_in = nn.functional.interpolate(mask.clone(), size=patch_size)
            eye_in = nn.functional.interpolate(eye.clone(), size=patch_size)
        else:
            mask_in = mask.clone()
            eye_in = eye.clone()

        eye_colored = eye_in.clone()
        if opt.random_eye_color:
            # Scale each of the first three channels by the sampled colour.
            eye_color = functions.get_eye_color(real)
            eye_colored[:, 0, :, :] *= (eye_color[0] / 255)
            eye_colored[:, 1, :, :] *= (eye_color[1] / 255)
            eye_colored[:, 2, :, :] *= (eye_color[2] / 255)

        if first_scale:
            # Single-channel noise expanded to 3 channels; z_opt is redrawn
            # every epoch at the first scale.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            errD_real = -output.mean()  #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) and (epoch == 0):
                # First D step overall: also initialise z_prev and
                # opt.noise_amp, which stay fixed for the rest of training.
                if first_scale:
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                      0,
                                      device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                        0,
                                        device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = functions.draw_concat(Gs, Zs, reals, crops, masks,
                                                 eye_colored, NoiseAmp, in_s,
                                                 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = functions.draw_concat(Gs, Zs, reals, crops, masks,
                                                   eye_colored, NoiseAmp, in_s,
                                                   'rec', m_noise, m_image,
                                                   opt)
                    criterion = nn.MSELoss()
                    # Noise amplitude is proportional to the RMSE between the
                    # real image and the 'rec' output of the coarser scales.
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = functions.draw_concat(Gs, Zs, reals, crops, masks,
                                             eye_colored, NoiseAmp, in_s,
                                             'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if first_scale:
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # Stacking masks and noise to make input
            G_input = functions.make_input(noise, mask_in, eye_colored)
            fake_background = netG(G_input.detach(), prev)

            if epoch == (opt.niter - 1) and j == (opt.Dsteps - 1):
                netG_copy = copy.deepcopy(netG)

            # Cropping mask shape from generated image and putting on top of real image at random location
            fake, fake_ind, eye_ind = functions.gen_fake(
                real, fake_background, mask_in, eye_in, eye_color, opt)

            output = netD(fake.detach())
            errD_fake = output.mean()

            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):

            netG.zero_grad()
            # `fake` is the (non-detached) sample from the last D step.
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                # Reconstruction loss: G fed with z_opt should reproduce real.
                Z_opt = opt.noise_amp * z_opt + z_prev
                input_opt = functions.make_input(Z_opt, mask_in, eye_in)
                rec_loss = alpha * loss(netG(input_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()))
            plt.imsave('%s/fake_indicator.png' % (opt.outf),
                       functions.convert_image_np(fake_ind.detach()))
            plt.imsave('%s/eye_indicator.png' % (opt.outf),
                       functions.convert_image_np(eye_ind.detach()))
            plt.imsave('%s/background.png' % (opt.outf),
                       functions.convert_image_np(fake_background.detach()))

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

        # Draw fresh training data for the next epoch.
        if opt.random_crop:
            real, _, _ = functions.random_crop(
                real_fullsize, crop_size)  #randomly find crop in image
        if opt.random_eye:
            eye = functions.generate_eye_mask(opt, masks[-1],
                                              opt.stop_scale - len(Gs))

    functions.save_networks(netG, netD, z_opt, opt)

    # At the final scale, hand back the pre-update snapshot (if one exists;
    # none is taken when no discriminator step ran).
    if len(Gs) == (opt.stop_scale) and netG_copy is not None:
        netG = netG_copy

    return z_opt, in_s, netG
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train one pyramid scale and plot its loss curves.

    The scale being trained is ``len(Gs)``.  Each of the ``opt.niter``
    epochs runs ``opt.Dsteps`` discriminator updates and ``opt.Gsteps``
    generator updates, then advances both learning-rate schedulers.  Debug
    images and ``z_opt.pth`` are written to ``opt.outf`` every 500 epochs and
    on the last epoch.  After training, loss curves are plotted for scales
    0, 1 and 8 and the networks are saved via ``functions.save_networks``.

    Args:
        netD: Discriminator of this scale.
        netG: Generator of this scale, called as ``netG(noise, prev)``.
        reals: Per-scale real images; ``reals[len(Gs)]`` is used here.
        Gs, Zs, NoiseAmp: Forwarded to ``draw_concat``.
        in_s: Forwarded to ``draw_concat``; replaced by a zero tensor when
            training the first scale outside ``'SR_train'`` mode.
        opt: Options object.  This function also *writes* ``opt.nzx``,
            ``opt.nzy``, ``opt.receptive_field`` and ``opt.noise_amp``.
        centers: Only used in ``'paint_train'`` mode (``quant2centers``).

    Returns:
        Tuple ``(z_opt, in_s, netG)``.
    """
    print("len(Gs):", len(Gs))

    # Real image at the current scale.
    real = reals[len(Gs)]

    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    # Receptive field of the network.
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    print(opt.receptive_field)
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    print(pad_noise)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    # ZeroPad2d zero-pads the borders of its input.
    m_noise = nn.ZeroPad2d(int(pad_noise))
    print('m_noise', m_noise)
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha
    print(alpha)

    # True when this is the first scale and we are not in 'SR_train' mode.
    first_scale = (Gs == []) and (opt.mode != 'SR_train')

    # fixed_noise is only used for its shape; z_opt starts as padded zeros.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # Set up the optimizers and their LR schedules.
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    # Per-epoch loss history.
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if first_scale:
            # opt.nzx / opt.nzy are the spatial size of the current scale.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy])
            # Expand the single channel to (1, 3, opt.nzx, opt.nzy).
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)

            errD_real = -output.mean()  # -a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            # z_prev and opt.noise_amp are only initialised on the very
            # first D step and stay fixed afterwards.
            if (j == 0) and (epoch == 0):
                if first_scale:
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                      0,
                                      device=opt.device)
                    in_s = prev
                    prev = m_image(prev)

                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                        0,
                                        device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    # Mean-squared-error criterion.
                    criterion = nn.MSELoss()

                    # Root-mean-square error.
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    # Root-mean-square error.
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if first_scale:
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # .detach() cuts back-propagation into the noise input.
            fake = netG(noise.detach(), prev)

            # NOTE: the original also evaluated an SSIM term here whose
            # result was never used in any loss; it has been dropped.

            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            # Gradient penalty.
            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad)
            gradient_penalty.backward()

            # Discriminator loss: adversarial terms plus gradient penalty.
            errD = errD_real + errD_fake + gradient_penalty

            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):
            netG.zero_grad()
            # `fake` is the (non-detached) sample from the last D step.
            output = netD(fake)

            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                # Reconstruction loss: G fed with z_opt should reproduce real.
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)

                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 100 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/noise.png' % (opt.outf),
                       functions.convert_image_np(noise),
                       vmin=0,
                       vmax=1)

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        # Step the schedulers only after the optimizers have stepped, as
        # PyTorch requires; stepping first shifts the LR milestone by one.
        schedulerD.step()
        schedulerG.step()

    # Convert the loss history to NumPy once, after training, instead of
    # rebuilding the full lists on every epoch.
    errDint = [err.cpu().numpy() for err in errD2plot]
    errGint = [err.cpu().numpy() for err in errG2plot]

    # Loss curves are only plotted for scales 0, 1 and 8 (plots for the
    # intermediate scales were disabled in the original).
    if len(Gs) == 0:
        name = 'G_loos & G_loos'
        functions.plot_learning_curves(errGint, errDint, opt.niter,
                                       'Generator', 'Discriminator', name)
    elif len(Gs) == 1:
        name = 'G_loos & G_loos_1'
        functions.plot_learning_curves_1(errGint, errDint, opt.niter,
                                         'Generator', 'Discriminator', name)
    elif len(Gs) == 8:
        name = 'G_loos & G_loos_8'
        functions.plot_learning_curves_8(errGint, errDint, opt.niter,
                                         'Generator', 'Discriminator', name)

    functions.save_networks(netG, netD, z_opt, opt)

    return z_opt, in_s, netG
Exemple #7
0
def train_single_scale(netD,
                       netG,
                       reals,
                       masks,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train one pyramid scale with a spatially masked discriminator loss.

    The scale being trained is ``len(Gs)``.  The discriminator's output map
    for the real image is multiplied by a cropped copy of this scale's mask
    before averaging, and the reconstruction loss compares masked generator
    output with the masked real image.  Every 500 epochs (and on the last
    epoch) debug images, ``z_opt.pth`` and pickled loss histories are written
    to ``opt.outf``.

    Args:
        netD: Discriminator of this scale.
        netG: Generator of this scale, called as ``netG(noise, prev)``.
        reals: Per-scale real images; ``reals[len(Gs)]`` is used here.
        masks: Per-scale masks; ``masks[len(Gs)]`` is used here.
        Gs, Zs, NoiseAmp: Forwarded to ``draw_concat``.
        in_s: Forwarded to ``draw_concat``; replaced by a zero tensor when
            training the first scale outside ``'SR_train'`` mode.
        opt: Options object.  This function also *writes* ``opt.nzx``,
            ``opt.nzy``, ``opt.receptive_field`` and ``opt.noise_amp``;
            ``opt.norm`` is used as an index (0 or 1) into ``norm``.
        centers: Only used in ``'paint_train'`` mode (``quant2centers``).

    Returns:
        Tuple ``(z_opt, in_s, netG)``.
    """

    real = reals[len(Gs)]
    mask = masks[len(Gs)]

    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Noise maps are generated at the enlarged size instead of padded.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # Here we calculate the decrease in output size due to conv layers.
    # required for setting the size of discriminators map.
    # TODO: currently, calculation doesn't consider opt.stride
    if (opt.ker_size % 2 == 0):
        r = (opt.num_layer) * (opt.ker_size - 1)
    else:  #(opt.ker_size % 2 != 0):
        r = (opt.num_layer) * (opt.ker_size - 1) / 2
    r = int(r)

    # Crop r pixels from every border of the mask and keep only channel 0,
    # re-adding a leading dimension so it broadcasts against netD's output.
    _, _, h, w = mask.size()
    discriminators_mask = mask.detach()[:, :, r:h - r,
                                        r:w - r][:, 0, :, :].unsqueeze(0)
    # h, w now refer to the cropped discriminator mask.
    _, _, h, w = discriminators_mask.size()

    # fixed_noise is only used for its shape; z_opt starts as padded zeros.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []
    norm = []

    # norm[0] = 1 (no rescaling); norm[1] = mask area / mask sum, which
    # rescales a mean over the whole map to a mean over the masked-in part.
    # NOTE(review): norm[1] divides by zero if the cropped mask sums to 0.
    norm.append(1)
    norm.append((h * w) / discriminators_mask.sum().item())

    plt.imsave('%s/mask.png' % (opt.outf),
               functions.convert_image_np(real * mask))

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            # At the first scale z_opt is drawn once (epoch 0) and then kept
            # fixed; noise_ is redrawn every epoch.
            if (epoch == 0):
                z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                                 device=opt.device)
                z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            # Zero the discriminator response outside the mask.
            output = output * discriminators_mask
            D_real_map = output.detach()
            errD_real = -(output.mean()) * norm[opt.norm]  #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                # First D step overall: also initialise z_prev and
                # opt.noise_amp, which stay fixed for the rest of training.
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                      0,
                                      device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                        0,
                                        device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    # Noise amplitude is proportional to the RMSE between the
                    # real image and the 'rec' output of the coarser scales.
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            # NOTE(review): unlike the real branch, the fake branch's output
            # is not multiplied by discriminators_mask -- confirm intended.
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device,
                discriminators_mask * norm[opt.norm])
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):
            netG.zero_grad()
            # `fake` is the (non-detached) sample from the last D step.
            output = netD(fake)
            D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                # Reconstruction loss on the masked region only.
                Z_opt = opt.noise_amp * z_opt + z_prev
                netG_out = netG(Z_opt.detach(), z_prev)
                netG_out = netG_out * mask
                # NOTE(review): this rebinds `real` for the rest of training,
                # so every later netD(real) / gradient-penalty call sees the
                # masked image, and a non-binary mask would be applied again
                # on each G step. Confirm this is intended; a separate local
                # would keep `real` intact.
                real = real * mask
                rec_loss = alpha * loss(netG_out, real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errD2plot.append(errD.detach())
        errG2plot.append(errG.detach())
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/D_fake.png' % (opt.outf),
                       functions.convert_image_np(D_fake_map))
            plt.imsave('%s/D_real.png' % (opt.outf),
                       functions.convert_image_np(D_real_map))
            #plt.imsave('%s/z_opt.png'    % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            #plt.imsave('%s/prev.png'     %  (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            #plt.imsave('%s/noise.png'    %  (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            #plt.imsave('%s/z_prev.png'   % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

            # Persist the loss histories collected so far (overwritten at
            # each checkpoint).
            with open('%s/D_real2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(D_real2plot, fp)
            with open('%s/D_fake2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(D_fake2plot, fp)
            # with open('%s/D_real2plot' % (opt.outf), "rb") as fp:  # Unpickling
            #     D_fake2plot = pickle.load(fp)

            with open('%s/errD2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(errD2plot, fp)
            with open('%s/errG2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(errG2plot, fp)
            with open('%s/z_opt2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(z_opt2plot, fp)

        # Schedulers advance once per epoch, after the optimizer steps.
        schedulerD.step()
        schedulerG.step()

    # plt.imsave('%s/masked_img.png'   % (opt.outf), functions.convert_image_np(real*mask))
    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
Exemple #8
0
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train the generator/discriminator pair of one pyramid scale.

    The scale being trained is ``len(Gs)``: the real image used is
    ``reals[len(Gs)]``.  Each of the ``opt.niter`` epochs runs ``opt.Dsteps``
    discriminator updates (real score, fake score, gradient penalty) followed
    by ``opt.Gsteps`` generator updates (adversarial loss plus, when
    ``opt.alpha != 0``, an MSE reconstruction loss weighted by ``opt.alpha``).
    Loss traces and preview images are written under ``opt.outf`` every 500
    epochs and on the last epoch.

    Args:
        netD: discriminator of this scale; called as ``netD(image)``.
        netG: generator of this scale; called as ``netG(noise, prev)``.
        reals: sequence of real images indexed by scale.
        Gs, Zs, NoiseAmp: state of the already trained scales, forwarded
            unchanged to ``draw_concat``.
        in_s: input for the coarsest scale, forwarded to ``draw_concat``.
            It is replaced by an all-zero tensor when this is the first
            scale and ``opt.mode != 'SR_train'``.
        opt: options object.  Read: ``ker_size``, ``num_layer``, ``stride``,
            ``mode``, ``alpha``, ``nc_z``, ``device``, ``lr_d``, ``lr_g``,
            ``beta1``, ``gamma``, ``niter``, ``Dsteps``, ``Gsteps``,
            ``noise_amp_init``, ``lambda_grad``, ``outf``.  Written: ``nzx``,
            ``nzy``, ``receptive_field``, ``noise_amp``.
        centers: forwarded to ``functions.quant2centers``; only used when
            ``opt.mode == 'paint_train'``.

    Returns:
        Tuple ``(z_opt, in_s, netG)``: the reconstruction noise of this
        scale, the (possibly re-initialised) ``in_s`` and the trained
        generator.
    """
    real = reals[len(Gs)]  # real image at the current scale
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # In animation mode the noise map itself is enlarged instead of
        # being zero padded.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # Gs and opt.mode do not change inside this function, so decide once
    # whether this is the coarsest scale of a non-SR run.
    first_scale = (not Gs) and (opt.mode != 'SR_train')

    # fixed_noise is only used for its shape; it is created on opt.device so
    # the call does not depend on generate_noise's default device.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    # torch.zeros gives a floating-point tensor; torch.full with an integer
    # fill value would produce an integer tensor on recent PyTorch versions.
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)  # zero padding keeps it all zero

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    mse = nn.MSELoss()  # stateless, shared by noise-amplitude and rec. loss

    errD2plot = []  # total discriminator loss per epoch
    errG2plot = []  # generator loss (adversarial + reconstruction) per epoch
    D_real2plot = []  # mean discriminator output on the real image
    D_fake2plot = []  # mean discriminator output on the fake image
    z_opt2plot = []  # reconstruction loss per epoch

    for epoch in range(opt.niter):
        # One noise map is drawn per epoch and reused by every D step.
        if first_scale:
            # Single-channel noise expanded to 3 channels, so all channels
            # share the same values.  Note z_opt is redrawn every epoch here.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)
            errD_real = -output.mean()  # raise D's output on the real image
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake: build the input coming from the coarser scales
            if (j == 0) and (epoch == 0):
                # Very first D step of this scale: also initialise z_prev and
                # opt.noise_amp, which are kept for the rest of the training.
                if first_scale:
                    prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                       device=opt.device)
                    in_s = prev  # in_s stays unpadded
                    prev = m_image(prev)
                    z_prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                         device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    RMSE = torch.sqrt(mse(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    # Noise amplitude is scaled by the RMSE between the real
                    # image and the 'rec' output of the previous scales.
                    RMSE = torch.sqrt(mse(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if first_scale:
                noise = noise_  # pure noise at the coarsest scale
            else:
                noise = opt.noise_amp * noise_ + prev  # scaled noise + prev

            # Generate one fake image and score it.  fake is detached for D,
            # so this backward pass does not reach the generator.
            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()  # lower D's output on the fake image
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach().item())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            # The fake image of the last D step is reused for every G step,
            # which is why its graph is retained below.
            output = netD(fake)
            errG = -output.mean()  # adversarial part of the generator loss
            errG.backward(retain_graph=True)
            if alpha != 0:  # reconstruction loss only when it has weight
                if opt.mode == 'paint_train':
                    # z_prev was set during the first D step of this scale
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * mse(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        # rec_loss is a tensor when alpha != 0 and the plain int 0 otherwise;
        # float() handles both, whereas .item() fails on the int.
        rec_loss_value = float(rec_loss)
        errG2plot.append(errG.detach().item() + rec_loss_value)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss_value)

        last_epoch = epoch == (opt.niter - 1)

        if epoch % 25 == 0 or last_epoch:
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or last_epoch:
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            # noise is what was fed to the generator in the last D step
            plt.imsave('%s/noise.png' % (opt.outf),
                       functions.convert_image_np(noise),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/Z_opt.png' % (opt.outf),
                       functions.convert_image_np(Z_opt),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/z_prev.png' % (opt.outf),
                       functions.convert_image_np(z_prev),
                       vmin=0,
                       vmax=1)

            torch.save(
                {
                    "errD": errD2plot,
                    "errG": errG2plot,
                    "D_real": D_real2plot,
                    "D_fake": D_fake2plot,
                    "recons": z_opt2plot
                }, '%s/loss_trace.pth' % (opt.outf))
            plt.figure()
            plt.plot(errD2plot, label="errD")
            plt.plot(errG2plot, label="errG")
            plt.plot(D_real2plot, label="Dreal")
            plt.plot(D_fake2plot, label="Dfake")
            plt.plot(z_opt2plot, label="recons")
            plt.legend()
            plt.savefig('%s/loss_trace.png' % (opt.outf))
            plt.close()
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG