示例#1
0
def train_paint(opt, Gs, Zs, reals, NoiseAmp, centers, paint_inject_scale):
    """Retrain the single scale ``paint_inject_scale`` and store it in place.

    Every scale from 0 to ``opt.stop_scale`` other than
    ``paint_inject_scale`` is left untouched. For the selected scale a fresh
    discriminator/generator pair is created with ``init_models`` and trained
    with ``train_single_scale``; the results overwrite the entries at that
    index in ``Gs``, ``Zs`` and ``NoiseAmp``. The updated lists and ``reals``
    are then written to ``opt.out_`` with ``torch.save``.

    Args:
        opt: options object. ``nfc``, ``min_nfc``, ``out_`` and ``outf`` are
            overwritten here; ``device``, ``stop_scale``, ``nfc_init`` and
            ``min_nfc_init`` are read.
        Gs: list of generators, indexed by scale; modified in place.
        Zs: list of per-scale noise tensors; modified in place.
        reals: list of real images, indexed by scale.
        NoiseAmp: list of per-scale noise amplitudes; modified in place.
        centers: forwarded unchanged to ``train_single_scale``.
        paint_inject_scale: index of the scale to retrain. If it is outside
            ``0..opt.stop_scale`` nothing is trained or saved.

    Returns:
        None.
    """
    # torch.zeros always produces a floating-point tensor, whereas
    # torch.full(shape, 0) infers an integer dtype from the fill value on
    # recent PyTorch releases.
    in_s = torch.zeros(reals[0].shape, device=opt.device)

    for cur_scale_level in range(opt.stop_scale + 1):
        if cur_scale_level != paint_inject_scale:
            continue

        # The channel count doubles every four scales and is capped at 128.
        width_factor = pow(2, math.floor(cur_scale_level / 4))
        opt.nfc = min(opt.nfc_init * width_factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        # An existing directory is fine; any other OS error surfaces here
        # rather than at the first file write below.
        os.makedirs(opt.outf, exist_ok=True)

        plt.imsave('%s/in_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[cur_scale_level]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)

        # Only the scales up to and including the current one are passed on.
        z_curr, in_s, G_curr = train_single_scale(
            D_curr,
            G_curr,
            reals[:cur_scale_level + 1],
            Gs[:cur_scale_level],
            Zs[:cur_scale_level],
            in_s,
            NoiseAmp[:cur_scale_level],
            opt,
            centers=centers)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs[cur_scale_level] = G_curr
        Zs[cur_scale_level] = z_curr
        NoiseAmp[cur_scale_level] = opt.noise_amp

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        # No later scale can match, so stop scanning.
        break
    return
示例#2
0
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Adversarially train ``netG``/``netD`` for the scale ``len(Gs)``.

    The target image is ``reals[len(Gs)]``. Each of the ``opt.niter`` epochs
    runs ``opt.Dsteps`` discriminator updates (real term, fake term and the
    penalty from ``functions.calc_gradient_penalty``) followed by
    ``opt.Gsteps`` generator updates (adversarial term plus, when
    ``opt.alpha`` is non-zero, an MSE reconstruction term weighted by
    ``opt.alpha``). On the last epoch sample images and ``z_opt`` are written
    to ``opt.outf``; at the end ``functions.save_networks`` is called.

    Args:
        netD: discriminator for this scale.
        netG: generator for this scale, called as ``netG(noise, prev)``.
        reals: sequence of real images indexed by scale.
        Gs: generators of the already trained scales (empty for scale 0).
        Zs: noise tensors of the already trained scales; forwarded to
            ``draw_concat``.
        in_s: input signal forwarded to ``draw_concat``. It is replaced by a
            zero tensor when ``Gs`` is empty and the mode is not
            ``'SR_train'``; in ``'SR_train'`` mode it is used as ``z_prev``.
        NoiseAmp: noise amplitudes of the trained scales; forwarded to
            ``draw_concat``.
        opt: options object. ``nzx``, ``nzy``, ``receptive_field`` and
            ``noise_amp`` are written here.
        centers: passed to ``functions.quant2centers`` in ``'paint_train'``
            mode only.

    Returns:
        Tuple ``(z_opt, in_s, netG)``.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Noise is generated at the padded size directly, so it is not padded.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha
    # Neither Gs nor opt.mode changes in this function, so evaluate once.
    bottom_scale = (not Gs) and opt.mode != 'SR_train'

    # z_opt: noise fed to the generator for the reconstruction loss.
    # fixed_noise is only used for its shape; the call is kept so the random
    # number stream is unchanged.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    # torch.zeros keeps the tensor floating point; torch.full(shape, 0)
    # infers an integer dtype on recent PyTorch releases.
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    # TODO: these histories are collected but never plotted.
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    # Started once, then reset after every progress report, so each report
    # covers the time since the previous one rather than a single epoch.
    start_time = time.time()

    for epoch in range(opt.niter):
        if bottom_scale:
            # Bottom generator: z_opt is fresh noise instead of zeros. One
            # noise channel is expanded to three channels.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            # noise_: random noise for this epoch's generated sample.
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if j == 0 and epoch == 0:
                # First step only: initialise prev and z_prev.
                # prev: image produced by the previous scales (random noise).
                # z_prev: image produced by the previous scales from the
                # stored reconstruction noise.
                if bottom_scale:
                    prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                       device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                         device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    # Noise amplitude scales with the reconstruction error of
                    # the previous scales against the real image.
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if bottom_scale:
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):
            netG.zero_grad()
            # `fake` is the sample from the last discriminator step above.
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch > 1 and (epoch % 100 == 0 or epoch == (opt.niter - 1)):
            total_time = time.time() - start_time
            start_time = time.time()
            print('scale %d:[%d/%d], total time: %f' %
                  (len(Gs), epoch, opt.niter, total_time))
            memory = torch.cuda.max_memory_allocated()
            print('allocated memory: %.03f GB' % (memory /
                                                  (1024 * 1024 * 1024 * 1.0)))

        if epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/prev.png' % (opt.outf),
                       functions.convert_image_np(prev),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/prev_plus_noise.png' % (opt.outf),
                       functions.convert_image_np(noise),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/z_prev.png' % (opt.outf),
                       functions.convert_image_np(z_prev),
                       vmin=0,
                       vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
示例#3
0
def _visualize_weights(modules, fig_path):
    """Save a 100-bin histogram of the non-zero weights/biases of *modules*."""
    # Gather on the CPU so this works whatever device the modules live on.
    flat = torch.cat([
        tensor.detach().flatten().cpu()
        for m in modules
        for tensor in (m.weight, m.bias)
    ]).numpy()
    plt.hist(flat[flat != 0], bins=100)
    plt.savefig(fig_path)
    plt.close()


def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train every scale from 0 to ``opt.stop_scale``, pruning each generator.

    For each scale a discriminator/generator pair is created with
    ``init_models`` and trained with ``train_single_scale``. When the channel
    count equals that of the previous scale, both networks are first
    initialised from the previous scale's saved ``netG.pth``/``netD.pth``.
    After training, 20% of the generator's listed weights and biases are
    pruned globally by L1 magnitude, the generator is converted with
    ``.half()``, and the result is appended to ``Gs``. ``Zs``, ``Gs``,
    ``reals`` and ``NoiseAmp`` are saved to ``opt.out_`` after every scale.

    Args:
        opt: options object. ``nfc``, ``min_nfc``, ``out_`` and ``outf`` are
            overwritten for every scale.
        Gs: list that receives the trained generators (appended in place).
        Zs: list that receives the per-scale noise returned by
            ``train_single_scale`` (appended in place).
        reals: passed to ``functions.creat_reals_pyramid``, whose return
            value is used as the image pyramid.
        NoiseAmp: list that receives ``opt.noise_amp`` per scale.

    Returns:
        None.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    nfc_prev = 0

    # Train every scale up to and including opt.stop_scale.
    for cur_scale_level in range(opt.stop_scale + 1):
        # nfc: number of channels in the conv blocks. It doubles every four
        # scales and is capped at 128.
        width_factor = pow(2, math.floor(cur_scale_level / 4))
        opt.nfc = min(opt.nfc_init * width_factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_factor, 128)

        # out_: output directory of the run; outf: sub-folder of this scale.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        # An existing directory is fine; any other OS error surfaces here
        # rather than at the first file write below.
        os.makedirs(opt.outf, exist_ok=True)

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[cur_scale_level]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        # The architecture only matches the previous scale's checkpoint when
        # the channel count did not change, so warm-start only in that case.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, cur_scale_level - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, cur_scale_level - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()

        # Layers whose weight and bias are histogrammed and pruned.
        modules = [
            G_curr.head.conv, G_curr.head.norm, G_curr.body.block1.conv,
            G_curr.body.block1.norm, G_curr.body.block2.conv,
            G_curr.body.block2.norm, G_curr.body.block3.conv,
            G_curr.body.block3.norm, G_curr.tail[0]
        ]
        parameters_to_prune = tuple(
            (m, name) for m in modules for name in ('weight', 'bias'))

        _visualize_weights(modules, "%s/%s.png" % (opt.outf, 'ori'))

        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=0.2,
        )

        # Make the pruning permanent by dropping the mask reparametrisation.
        for m in modules:
            prune.remove(m, 'weight')
            prune.remove(m, 'bias')

        _visualize_weights(modules, "%s/%s.png" % (opt.outf, 'prune'))
        # NOTE(review): the half-precision generator is appended to Gs and so
        # reaches draw_concat on later scales -- confirm the inputs there are
        # cast to match.
        G_curr.half()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        nfc_prev = opt.nfc
        del D_curr, G_curr
        torch.cuda.empty_cache()
    return
示例#4
0
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Adversarially train ``netG``/``netD`` for the scale ``len(Gs)``.

    Variant that supports ``opt.input_type == 'image'`` (4-D tensors, 2-D
    zero padding) and any other input type (3-D tensors, 1-D constant
    padding). The target is ``reals[len(Gs)]``. Each of the ``opt.niter``
    epochs runs ``opt.Dsteps`` discriminator updates followed by
    ``opt.Gsteps`` generator updates. Every 500 epochs and on the last epoch
    samples are written to ``opt.outf`` (``.wav`` via ``write`` for non-image
    input or when ``opt.conv_spectrogram`` is set, ``.png`` otherwise)
    together with ``z_opt``. ``functions.save_networks`` is called at the end.

    Args:
        netD: discriminator for this scale.
        netG: generator for this scale, called as ``netG(noise, prev)``.
        reals: sequence of real samples indexed by scale.
        Gs: generators of the already trained scales (empty for scale 0).
        Zs: noise tensors of the trained scales; forwarded to ``draw_concat``.
        in_s: input signal forwarded to ``draw_concat``. It is replaced by a
            zero tensor when ``Gs`` is empty and the mode is not
            ``'SR_train'``; in ``'SR_train'`` mode it is used as ``z_prev``.
        NoiseAmp: noise amplitudes of the trained scales; forwarded to
            ``draw_concat``.
        opt: options object. ``nzx``, ``nzy``, ``receptive_field`` and
            ``noise_amp`` are written here.
        centers: passed to ``functions.quant2centers`` in ``'paint_train'``
            mode only.

    Returns:
        Tuple ``(z_opt, in_s, netG)``.
    """
    real = reals[len(Gs)]
    print("@ train_single_scale:real.shape = ", real.shape, "| opt.mode = ",
          opt.mode)
    is_image = opt.input_type == 'image'
    # Neither Gs nor opt.mode changes in this function, so evaluate once.
    bottom_scale = (not Gs) and opt.mode != 'SR_train'

    if is_image:
        opt.nzx = real.shape[2]
        opt.nzy = real.shape[3]
    else:
        opt.nzx = real.shape[1]
        opt.nzy = real.shape[2]
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # NOTE(review): indexes real.shape[2]/[3] whatever the input type --
        # confirm this mode is only used with image input.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    if is_image:
        m_noise = nn.ZeroPad2d(int(pad_noise))
        m_image = nn.ZeroPad2d(int(pad_image))
    else:
        m_noise = nn.ConstantPad1d(int(pad_noise), 0)
        m_image = nn.ConstantPad1d(int(pad_image), 0)
    print("m_noise")

    alpha = opt.alpha

    # Shape handed to functions.generate_noise for full-channel noise.
    if is_image:
        full_noise_shape = [opt.nc_z, opt.nzx, opt.nzy]
    else:
        full_noise_shape = [opt.nzx, opt.nzy]
    # Shape of the all-zero "previous scale" tensors used at the bottom scale.
    if is_image:
        zero_prev_shape = [1, opt.nc_z, opt.nzx, opt.nzy]
    else:
        zero_prev_shape = [1, opt.nzx, opt.nzy]

    # fixed_noise is only used for its shape; the call is kept so the random
    # number stream is unchanged.
    fixed_noise = functions.generate_noise(full_noise_shape,
                                           device=opt.device)
    print("fixed_noise.shape = ", fixed_noise.shape)
    # torch.zeros keeps the tensor floating point; torch.full(shape, 0)
    # infers an integer dtype on recent PyTorch releases.
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    # Loss histories; collected but not used further in this function.
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if bottom_scale:
            if is_image:
                # One noise channel is expanded to 2 channels in spectrogram
                # mode and to 3 channels otherwise.
                n_channels = 2 if opt.conv_spectrogram == True else 3
                z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                                 device=opt.device)
                z_opt = m_noise(
                    z_opt.expand(1, n_channels, opt.nzx, opt.nzy))
                noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                                  device=opt.device)
                noise_ = m_noise(
                    noise_.expand(1, n_channels, opt.nzx, opt.nzy))
            else:
                z_opt = functions.generate_noise([opt.nzx, opt.nzy],
                                                 device=opt.device)
                z_opt = m_noise(z_opt)
                noise_ = functions.generate_noise([opt.nzx, opt.nzy],
                                                  device=opt.device)
                noise_ = m_noise(noise_)
        else:
            noise_ = functions.generate_noise(full_noise_shape,
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            if epoch % 100 == 0:
                print("@ train_single_scale: epoch = ", epoch,
                      "| real.shape = ", real.shape)
            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if j == 0 and epoch == 0:
                # First step only: initialise prev and z_prev.
                if bottom_scale:
                    prev = torch.zeros(zero_prev_shape, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.zeros(zero_prev_shape, device=opt.device)
                    print("@ train_single_scale: z_prev.shape =", z_prev.shape)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    print('@ train_single_scale: real.shape = ', real.shape,
                          '| z_prev.shape = ', z_prev.shape)
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    print("@ train_single_scale: RMSE.shape = ", RMSE.shape)
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    print("@ train_single_scale: z_prev.shape = ",
                          z_prev.shape)
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    print('@ train_single_scale: real.shape = ', real.shape,
                          '| z_prev.shape = ', z_prev.shape)
                    # Noise amplitude scales with the reconstruction error of
                    # the previous scales against the real sample.
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if bottom_scale:
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            if epoch % 100 == 0:
                print("@ train_single_scale: noise.detach().shape = ",
                      noise.detach().shape, "prev.shape = ", prev.shape,
                      "epoch = ", epoch, "j = ", j)
            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):
            netG.zero_grad()
            # `fake` is the sample from the last discriminator step above.
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            if is_image:
                if opt.conv_spectrogram == True:
                    write('%s/fake_sample.wav' % (opt.outf), opt.sample_rate,
                          functions.convert_spectrogram_np(fake.detach(), opt))
                    write(
                        '%s/G(z_opt).wav' % (opt.outf), opt.sample_rate,
                        functions.convert_spectrogram_np(
                            netG(Z_opt.detach(), z_prev).detach(), opt))
                else:
                    plt.imsave('%s/fake_sample.png' % (opt.outf),
                               functions.convert_image_np(fake.detach()),
                               vmin=0,
                               vmax=1)
                    plt.imsave('%s/G(z_opt).png' % (opt.outf),
                               functions.convert_image_np(
                                   netG(Z_opt.detach(), z_prev).detach()),
                               vmin=0,
                               vmax=1)
            else:
                write('%s/fake_sample.wav' % (opt.outf), opt.sample_rate,
                      functions.convert_audio_np(fake.detach(), opt))
                write(
                    '%s/G(z_opt).wav' % (opt.outf), opt.sample_rate,
                    functions.convert_audio_np(
                        netG(Z_opt.detach(), z_prev).detach(), opt))

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
            in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
            #opt.gen_start_scale=0
            #print(in_s.shape)
            #in_s = torch.full(reals[0].shape, 0, device=opt.device)
            #in_s[0,:,:20,:]=0.2

            if opt.quantization_flag:
                opt.mode = 'paint_train'
                dir2trained_model = functions.generate_dir2save(opt)
                # N = len(reals) - 1
                # n = opt.paint_start_scale
                real_s = imresize(real, pow(opt.scale_factor, (N - n)), opt)
                real_s = real_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
                real_quant, centers = functions.quant(real_s, opt.device)
                plt.imsave('%s/real_quant.png' % dir2save,
                           functions.convert_image_np(real_quant),
                           vmin=0,
                           vmax=1)
                plt.imsave('%s/in_paint.png' % dir2save,
                           functions.convert_image_np(in_s),
                           vmin=0,
                           vmax=1)
                in_s = functions.quant2centers(ref, centers)
                in_s = imresize(in_s, pow(opt.scale_factor, (N - n)), opt)
                # in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
                # in_s = imresize(in_s, 1 / opt.scale_factor, opt)
                in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
                plt.imsave('%s/in_paint_quant.png' % dir2save,
                           functions.convert_image_np(in_s),
                           vmin=0,
                           vmax=1)
示例#6
0
def SinGAN_generate(Gs,
                    Zs,
                    reals,
                    NoiseAmp,
                    opt,
                    in_s=None,
                    scale_v=1,
                    scale_h=1,
                    n=0,
                    gen_start_scale=0,
                    num_samples=50):
    """Run the generator pyramid scale by scale and return the last output.

    ``Gs``, ``Zs`` and ``NoiseAmp`` are iterated in lock-step. At every scale,
    ``num_samples`` noise maps are drawn, scaled by that scale's noise
    amplitude, added to the (resized, padded) output of the previous scale
    and passed through the scale's generator. At the scale where
    ``n == len(reals) - 1`` the output directory is created and, unless the
    mode is harmonization / editing / SR / paint2image, every sample is
    written to ``<dir>/<i>.png``.

    Args:
        Gs: Generators, one per processed scale.
        Zs: Stored noise maps, one per processed scale; dims 2 and 3 are used
            to size the freshly drawn noise.
        reals: Real images per scale, indexed by the absolute scale ``n``.
        NoiseAmp: Noise amplitude per processed scale.
        opt: Options object; ``ker_size``, ``num_layer``, ``device``,
            ``nc_z``, ``scale_factor``, ``mode``, ``out`` and ``input_name``
            are read.
        in_s: Image fed to the first processed scale. When ``None`` a zero
            tensor shaped like ``reals[0]`` is used.
        scale_v: Multiplier applied to the noise size along dim 2.
        scale_h: Multiplier applied to the noise size along dim 3.
        n: Absolute index of the first processed scale (callers pass
            ``Gs[n:]`` together with ``n=n``).
        gen_start_scale: Scales with ``n < gen_start_scale`` use the stored
            ``Z_opt`` instead of the freshly drawn noise.
        num_samples: Number of samples generated per scale.

    Returns:
        The detached output of the last generator for the last sample.

    Raises:
        ValueError: If no generator was run (empty pyramid or
            ``num_samples < 1``), since there is no output to return.
    """
    if in_s is None:
        in_s = torch.full(reals[0].shape, 0, device=opt.device)

    # In these modes the result is only returned; no sample file is written.
    no_save_modes = ("harmonization", "editing", "SR", "paint2image")
    last_scale = len(reals) - 1

    I_curr = None
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        # Total border consumed by num_layer convolutions of size ker_size.
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        # Noise size without padding, optionally stretched per axis. These
        # are floats; generate_noise is presumably responsible for rounding.
        nzh = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzw = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(num_samples):
            if n == 0:
                # Scale 0 draws one channel and repeats it across 3 channels.
                z_curr = functions.generate_noise([1, nzh, nzw],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzh, nzw],
                                                  device=opt.device)
                z_curr = m(z_curr)

            if not images_prev:
                # First processed scale: start from the injected image.
                I_prev = m(in_s)
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    # Crop to the (stretched) real size, pad, then force an
                    # exact match with the noise map's dims 2 and 3.
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            # The noise is still drawn above even when it is replaced here,
            # which keeps random-number consumption unchanged.
            if n < gen_start_scale:
                z_curr = Z_opt

            z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if n == last_scale:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (
                        opt.out, opt.input_name[:-4], gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                # exist_ok tolerates reruns without hiding real failures
                # (e.g. permission errors) the way a blanket except would.
                os.makedirs(dir2save, exist_ok=True)
                if opt.mode not in no_save_modes:
                    plt.imsave('%s/%d.png' % (dir2save, i),
                               functions.convert_image_np(I_curr.detach()),
                               vmin=0,
                               vmax=1)
            images_cur.append(I_curr)
        n += 1

    if I_curr is None:
        raise ValueError(
            "SinGAN_generate produced no output: the generator pyramid is "
            "empty or num_samples < 1")
    return I_curr.detach()
示例#7
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one generator/discriminator pair per scale, coarse to fine.

    The input (image or audio, chosen by ``opt.input_type``) is read, resized
    by ``opt.scale1`` and turned into a pyramid via
    ``functions.creat_reals_pyramid``. For each scale in
    ``0 .. opt.stop_scale`` the function sets the channel widths on ``opt``,
    creates the scale's output directory, saves the scale's real sample,
    trains the scale with ``train_single_scale`` and appends the results to
    ``Gs``, ``Zs`` and ``NoiseAmp`` (mutated in place). After every scale the
    accumulated lists and ``reals`` are written to ``opt.out_`` with
    ``torch.save``.

    Args:
        opt: Options object; it is also mutated (``nfc``, ``min_nfc``,
            ``out_``, ``outf``).
        Gs: List receiving the trained generator of each scale.
        Zs: List receiving the noise returned by ``train_single_scale``.
        reals: List handed to ``functions.creat_reals_pyramid``; the local
            name is rebound to that call's return value.
        NoiseAmp: List receiving ``opt.noise_amp`` after each scale.

    Returns:
        None.

    Raises:
        OSError: If a scale's output directory cannot be created.
    """
    if opt.input_type == 'image':
        real_ = functions.read_image(opt)
    else:
        real_ = functions.read_audio(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        # Channel widths double every 4 scales and are capped at 128.
        width_mult = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * width_mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        # exist_ok tolerates reruns; genuine failures (permissions, a file
        # in the way) now surface here instead of being swallowed.
        os.makedirs(opt.outf, exist_ok=True)

        # Save this scale's real sample for reference.
        if opt.input_type == 'image':
            # Explicit comparison kept: the flag's type is not visible here.
            if opt.conv_spectrogram == True:
                write('%s/real_scale.wav' % (opt.outf), opt.sample_rate,
                      functions.convert_spectrogram_np(reals[scale_num], opt))
            else:
                plt.imsave('%s/real_scale.png' % (opt.outf),
                           functions.convert_image_np(reals[scale_num]),
                           vmin=0,
                           vmax=1)
        else:
            write('%s/real_scale.wav' % (opt.outf), opt.sample_rate,
                  functions.convert_audio_np(reals[scale_num], opt))

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            # Same width as the previous scale: warm-start from its weights.
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # Freeze the finished scale before it is stored.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        # Checkpoint the whole pyramid after every scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        # Drop the local references before the next scale's models are built.
        del D_curr, G_curr
示例#8
0
            # Fragment of a larger script body: ref, real, reals, Gs, Zs,
            # NoiseAmp and dir2save are defined outside the visible lines.
            N = len(reals) - 1
            n = opt.paint_start_scale
            # Build the image injected at scale n: resize the reference by
            # scale_factor ** (N - n + 1), crop to the real at scale n - 1, resize
            # once more by 1 / scale_factor and crop to the real at scale n.
            # NOTE(review): for n == 0, reals[n - 1] is reals[-1] (the last
            # pyramid entry) -- confirm the start scale is always >= 1.
            in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt)
            in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
            in_s = imresize(in_s, 1 / opt.scale_factor, opt)
            in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
            if opt.quantization_flag:
                # The mode is switched before asking for the save directory;
                # presumably generate_dir2save() is mode-dependent -- confirm.
                opt.mode = 'paint_train'
                dir2trained_model = functions.generate_dir2save(opt)
                # N = len(reals) - 1
                # n = opt.paint_start_scale
                real_s = imresize(real, pow(opt.scale_factor, (N - n)), opt)
                real_s = real_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
                # quant() returns a quantized image plus `centers`; the centers
                # are reused to map the reference image and to train_paint().
                real_quant, centers = functions.quant(real_s)
                plt.imsave('%s/real_quant.png' % dir2save, functions.convert_image_np(real_quant), vmin=0, vmax=1)
                plt.imsave('%s/in_paint.png' % dir2save, functions.convert_image_np(in_s), vmin=0, vmax=1)
                # in_s is rebuilt from `ref` here, replacing the value computed
                # above on this branch.
                in_s = functions.quant2centers(ref, centers)
                in_s = imresize(in_s, pow(opt.scale_factor, (N - n)), opt)
                # in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
                # in_s = imresize(in_s, 1 / opt.scale_factor, opt)
                in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
                plt.imsave('%s/in_paint_quant.png' % dir2save, functions.convert_image_np(in_s), vmin=0, vmax=1)
                # Reuse an existing paint-trained pyramid when its directory
                # exists; otherwise train one. opt.mode is still 'paint_train'
                # during either call and is reset to 'paint2image' afterwards.
                if (os.path.exists(dir2trained_model)):
                    # NOTE(review): the disabled print below is stale -- this is
                    # the branch where the trained model DOES exist.
                    # print('Trained model does not exist, training SinGAN for SR')
                    Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
                    opt.mode = 'paint2image'
                else:
                    train_paint(opt, Gs, Zs, reals, NoiseAmp, centers, opt.paint_start_scale)
                    opt.mode = 'paint2image'
            # Generate a single sample from scale n upwards, starting from in_s.
            out = SinGAN_generate(Gs[n:], Zs[n:], reals, NoiseAmp[n:], opt, in_s, n=n, num_samples=1)
示例#9
0
                (opt.ref_dir, opt.ref_name[:-4], opt.ref_name[-4:]), opt)
            # Fragment of a larger script body: ref, mask, real, reals, Gs, Zs,
            # NoiseAmp and dir2save are defined outside the visible lines.
            # When dim 3 of the reference differs from the real image, resize
            # mask and reference to the real image's dims 2 and 3, then crop.
            if ref.shape[3] != real.shape[3]:
                mask = imresize_to_shape(mask, [real.shape[2], real.shape[3]],
                                         opt)
                mask = mask[:, :, :real.shape[2], :real.shape[3]]
                ref = imresize_to_shape(ref, [real.shape[2], real.shape[3]],
                                        opt)
                ref = ref[:, :, :real.shape[2], :real.shape[3]]
            # Presumably widens the mask region -- confirm in functions.dilate_mask.
            mask = functions.dilate_mask(mask, opt)

            N = len(reals) - 1
            n = opt.harmonization_start_scale
            # Build the image injected at scale n: resize the reference by
            # scale_factor ** (N - n + 1), crop to the real at scale n - 1, resize
            # once more by 1 / scale_factor and crop to the real at scale n.
            # NOTE(review): for n == 0, reals[n - 1] is reals[-1] (the last
            # pyramid entry) -- confirm the start scale is always >= 1.
            in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt)
            in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
            in_s = imresize(in_s, 1 / opt.scale_factor, opt)
            in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
            # Generate a single sample from scale n upwards, starting from in_s.
            out = SinGAN_generate(Gs[n:],
                                  Zs[n:],
                                  reals,
                                  NoiseAmp[n:],
                                  opt,
                                  in_s,
                                  n=n,
                                  num_samples=1)
            # Blend: keep the real image where mask is 0 and the generated
            # output where mask is 1.
            out = (1 - mask) * real + mask * out
            plt.imsave('%s/start_scale=%d.png' %
                       (dir2save, opt.harmonization_start_scale),
                       functions.convert_image_np(out.detach()),
                       vmin=0,
                       vmax=1)
示例#10
0
            # Fragment of a larger script body: real, reals, Gs, Zs, NoiseAmp and
            # dir2save are defined outside the visible lines.
            # The mask file name is the reference name with '_mask' inserted
            # before its last four characters (presumably the extension).
            ref = functions.read_image_dir('%s/%s' % (opt.ref_dir, opt.ref_name), opt)
            mask = functions.read_image_dir('%s/%s_mask%s' % (opt.ref_dir,opt.ref_name[:-4],opt.ref_name[-4:]), opt)
            if ref.shape[3] != real.shape[3]:
                # The string literal below is an earlier, disabled variant of the
                # resize (by ratio instead of by target shape); it has no effect.
                '''
                mask = imresize(mask, real.shape[3]/ref.shape[3], opt)
                mask = mask[:, :, :real.shape[2], :real.shape[3]]
                ref = imresize(ref, real.shape[3] / ref.shape[3], opt)
                ref = ref[:, :, :real.shape[2], :real.shape[3]]
                '''
                # Resize mask and reference to the real image's dims 2 and 3,
                # then crop.
                mask = imresize_to_shape(mask, [real.shape[2],real.shape[3]], opt)
                mask = mask[:, :, :real.shape[2], :real.shape[3]]
                ref = imresize_to_shape(ref, [real.shape[2],real.shape[3]], opt)
                ref = ref[:, :, :real.shape[2], :real.shape[3]]

            # Presumably widens the mask region -- confirm in functions.dilate_mask.
            mask = functions.dilate_mask(mask, opt)

            N = len(reals) - 1
            n = opt.editing_start_scale
            # Build the image injected at scale n: resize the reference by
            # scale_factor ** (N - n + 1), crop to the real at scale n - 1, resize
            # once more by 1 / scale_factor and crop to the real at scale n.
            # NOTE(review): for n == 0, reals[n - 1] is reals[-1] (the last
            # pyramid entry) -- confirm the start scale is always >= 1.
            in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt)
            in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
            in_s = imresize(in_s, 1 / opt.scale_factor, opt)
            in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
            # Generate a single sample from scale n upwards; the raw output is
            # saved first, then the mask-blended result under a '_masked' name.
            out = SinGAN_generate(Gs[n:], Zs[n:], reals, NoiseAmp[n:], opt, in_s, n=n, num_samples=1)
            plt.imsave('%s/start_scale=%d.png' % (dir2save, opt.editing_start_scale), functions.convert_image_np(out.detach()), vmin=0, vmax=1)
            # Blend: real image where mask is 0, generated output where mask is 1.
            out = (1-mask)*real+mask*out
            plt.imsave('%s/start_scale=%d_masked.png' % (dir2save, opt.editing_start_scale), functions.convert_image_np(out.detach()), vmin=0, vmax=1)




示例#11
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train two generator/discriminator pairs per scale, one per input image.

    Two images are read with ``functions.read_images``, resized by
    ``opt.scale1`` and turned into two pyramids. For each scale in
    ``0 .. opt.stop_scale`` the function sets the channel widths on ``opt``,
    creates the scale's output directory, saves both real images of that
    scale, trains the scale with ``train_single_scale`` (which receives the
    models as two-element lists) and appends the results to ``Gs``, ``Zs``
    and ``NoiseAmp`` (mutated in place). After every scale the accumulated
    lists and ``reals`` are written to ``opt.out_`` with ``torch.save``.

    Args:
        opt: Options object; it is also mutated (``nfc``, ``min_nfc``,
            ``out_``, ``outf``).
        Gs: List receiving, per scale, the pair of trained generators.
        Zs: List receiving the noise returned by ``train_single_scale``.
        reals: List handed to ``functions.creat_reals_pyramid``; the local
            name is rebound to a two-element list of pyramids.
        NoiseAmp: List receiving ``opt.noise_amp`` after each scale.

    Returns:
        None.

    Raises:
        OSError: If a scale's output directory cannot be created.
    """
    real_ = functions.read_images(opt)
    in_s = 0
    scale_num = 0
    real = [
        imresize(real_[0], opt.scale1, opt),
        imresize(real_[1], opt.scale1, opt)
    ]
    # NOTE(review): both calls receive the SAME `reals` list. If
    # creat_reals_pyramid appends to its argument and returns it, the two
    # entries below alias one list holding both pyramids -- confirm.
    reals = [
        functions.creat_reals_pyramid(real[0], reals, opt),
        functions.creat_reals_pyramid(real[1], reals, opt)
    ]
    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        # Channel widths double every 4 scales and are capped at 128.
        width_mult = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * width_mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        # exist_ok tolerates reruns; genuine failures (permissions, a file
        # in the way) now surface here instead of being swallowed.
        os.makedirs(opt.outf, exist_ok=True)

        # Save both real images of this scale as real_scale1/2.png.
        for idx in range(2):
            plt.imsave('%s/real_scale%d.png' % (opt.outf, idx + 1),
                       functions.convert_image_np(reals[idx][scale_num]),
                       vmin=0,
                       vmax=1)

        # Two independent model pairs, created in the original order.
        D_curr1, G_curr1 = init_models(opt)
        D_curr2, G_curr2 = init_models(opt)
        D_curr = [D_curr1, D_curr2]
        G_curr = [G_curr1, G_curr2]

        if nfc_prev == opt.nfc:
            # Same width as the previous scale: warm-start from its weights.
            prev_dir = '%s/%d' % (opt.out_, scale_num - 1)
            for idx in range(2):
                G_curr[idx].load_state_dict(
                    torch.load('%s/netG%d.pth' % (prev_dir, idx + 1)))
                D_curr[idx].load_state_dict(
                    torch.load('%s/netD%d.pth' % (prev_dir, idx + 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # Freeze the finished scale before it is stored.
        for idx in range(2):
            G_curr[idx] = functions.reset_grads(G_curr[idx], False)
            G_curr[idx].eval()
            D_curr[idx] = functions.reset_grads(D_curr[idx], False)
            D_curr[idx].eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        # Checkpoint the whole pyramid after every scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        # Drop the local references before the next scale's models are built.
        del D_curr, G_curr
示例#12
0
def SinGAN_generate(Gs,
                    Zs,
                    reals,
                    NoiseAmp,
                    opt,
                    in_s=None,
                    scale_v=1,
                    scale_h=1,
                    n=0,
                    gen_start_scale=0,
                    num_samples=50):
    """Run the generator pyramid scale by scale and return the last output.

    TensorFlow variant: the code crops and sizes tensors along dims 1 and 2
    and tiles noise along the last dim (channels-last; the original comment
    reads "torch NCWH tf NWHC").

    ``Gs``, ``Zs`` and ``NoiseAmp`` are iterated in lock-step. At every scale,
    ``num_samples`` noise maps are drawn, scaled by that scale's noise
    amplitude, added to the (resized, padded) output of the previous scale
    and passed through the scale's generator. At the scale where
    ``n == len(reals) - 1`` the output directory is created and, unless the
    mode is harmonization / editing / SR / paint2image, every sample is
    written to ``<dir>/<i>.png``.

    Args:
        Gs: Generators, one per processed scale.
        Zs: Stored noise maps, one per processed scale; dims 1 and 2 are used
            to size the freshly drawn noise.
        reals: Real images per scale, indexed by the absolute scale ``n``.
        NoiseAmp: Noise amplitude per processed scale.
        opt: Options object; ``ker_size``, ``num_layer``, ``nc_z``,
            ``scale_factor``, ``mode``, ``out`` and ``input_name`` are read.
        in_s: Image fed to the first processed scale. When ``None`` a zero
            tensor shaped like ``reals[0]`` is used.
        scale_v: Multiplier applied to the noise size along dim 1.
        scale_h: Multiplier applied to the noise size along dim 2.
        n: Absolute index of the first processed scale.
        gen_start_scale: Scales with ``n < gen_start_scale`` use the stored
            ``Z_opt`` instead of the freshly drawn noise.
        num_samples: Number of samples generated per scale.

    Returns:
        The output of the last generator for the last sample.

    Raises:
        ValueError: If no generator was run (empty pyramid or
            ``num_samples < 1``), since there is no output to return.
    """
    if in_s is None:
        in_s = tf.zeros_like(reals[0])

    # In these modes the result is only returned; no sample file is written.
    no_save_modes = ("harmonization", "editing", "SR", "paint2image")
    last_scale = len(reals) - 1

    I_curr = None
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        # Total border consumed by num_layer convolutions of size ker_size.
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = tf.keras.layers.ZeroPadding2D(padding=(int(pad1), int(pad1)))
        # Noise size without padding, optionally stretched per axis. These
        # are floats; generate_noise is presumably responsible for rounding.
        nzx = (Z_opt.shape[1] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[2] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(num_samples):
            if n == 0:
                # Scale 0 draws one channel and repeats it across 3 channels.
                z_curr = functions.generate_noise([1, nzx, nzy])
                z_curr = tf.tile(z_curr, multiples=(1, 1, 1, 3))
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy])
                z_curr = m(z_curr)

            if not images_prev:
                # First processed scale: start from the injected image.
                I_prev = m(in_s)
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    # Crop to the (stretched) real size, pad, then force an
                    # exact match with the noise map's dims 1 and 2.
                    I_prev = I_prev[:, 0:round(scale_v * reals[n].shape[1]),
                                    0:round(scale_h * reals[n].shape[2]), :]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, 0:z_curr.shape[1], 0:z_curr.shape[2], :]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[1],
                                                  z_curr.shape[2])
                else:
                    I_prev = m(I_prev)

            # The noise is still drawn above even when it is replaced here,
            # which keeps random-number consumption unchanged.
            if n < gen_start_scale:
                z_curr = Z_opt

            z_in = noise_amp * z_curr + I_prev
            # NOTE(review): the generator is called with train=True during
            # sampling -- confirm this is intended.
            I_curr = G(z_in, I_prev, train=True)

            if n == last_scale:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (
                        opt.out, opt.input_name[:-4], gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                # exist_ok tolerates reruns without hiding real failures
                # (e.g. permission errors) the way a blanket except would.
                os.makedirs(dir2save, exist_ok=True)
                if opt.mode not in no_save_modes:
                    plt.imsave('%s/%d.png' % (dir2save, i),
                               functions.convert_image_np(I_curr))
            images_cur.append(I_curr)
        n += 1

    if I_curr is None:
        raise ValueError(
            "SinGAN_generate produced no output: the generator pyramid is "
            "empty or num_samples < 1")
    return I_curr