Exemplo n.º 1
0
def train_paint(opt, Gs, Zs, reals, NoiseAmp, centers, paint_inject_scale):
    """Retrain a single scale of an existing pyramid (paint-to-image).

    Only the scale at index ``paint_inject_scale`` is trained. Its new
    generator, noise map and noise amplitude overwrite that entry of ``Gs``,
    ``Zs`` and ``NoiseAmp`` in place. The four lists are then saved as
    ``*.pth`` files under ``opt.out_``. Coarser scales are passed to
    ``train_single_scale`` as-is; finer scales are not touched.

    Args:
        opt: run configuration. Reads ``stop_scale``, ``nfc_init``,
            ``min_nfc_init``, ``device`` and, after training, ``noise_amp``.
            Sets ``nfc``, ``min_nfc``, ``out_`` and ``outf``.
        Gs, Zs, NoiseAmp: per-scale generators, noise maps and noise
            amplitudes. They must already hold an entry at
            ``paint_inject_scale``, because it is assigned by index.
        reals: per-scale real images.
        centers: forwarded to ``train_single_scale``.
        paint_inject_scale: integer index of the scale to retrain. Nothing is
            trained or saved when it is outside ``0 .. opt.stop_scale``.
    """
    scale_num = paint_inject_scale
    # A loop over every scale that skips all but this one is equivalent to a
    # single range check.
    if not 0 <= scale_num < opt.stop_scale + 1:
        return

    in_s = torch.full(reals[0].shape, 0, device=opt.device, dtype=torch.bool)

    # Channel widths double every 4 scales, capped at 128 (same schedule as train()).
    width_mult = pow(2, math.floor(scale_num / 4))
    opt.nfc = min(opt.nfc_init * width_mult, 128)
    opt.min_nfc = min(opt.min_nfc_init * width_mult, 128)

    opt.out_ = functions.generate_dir2save(opt)
    opt.outf = '%s/%d' % (opt.out_, scale_num)
    os.makedirs(opt.outf, exist_ok=True)

    plt.imsave('%s/in_scale.png' % (opt.outf),
               functions.convert_image_np(reals[scale_num]),
               vmin=0,
               vmax=1)

    D_curr, G_curr = init_models(opt)
    # Train this scale on top of the coarser scales' generators and noise.
    z_curr, _, G_curr = train_single_scale(D_curr,
                                           G_curr,
                                           reals[:scale_num + 1],
                                           Gs[:scale_num],
                                           Zs[:scale_num],
                                           in_s,
                                           NoiseAmp[:scale_num],
                                           opt,
                                           centers=centers)

    G_curr = functions.reset_grads(G_curr, False)
    G_curr.eval()
    D_curr = functions.reset_grads(D_curr, False)
    D_curr.eval()

    Gs[scale_num] = G_curr
    Zs[scale_num] = z_curr
    NoiseAmp[scale_num] = opt.noise_amp

    torch.save(Zs, '%s/Zs.pth' % (opt.out_))
    torch.save(Gs, '%s/Gs.pth' % (opt.out_))
    torch.save(reals, '%s/reals.pth' % (opt.out_))
    torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
Exemplo n.º 2
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train every scale of the pyramid in order, starting at scale 0.

    The input image is read, resized by ``opt.scale1`` and expanded into the
    ``reals`` pyramid. Each scale up to ``opt.stop_scale`` then goes through
    these steps:

    - A discriminator/generator pair is built.
    - If the channel width did not change, both networks are warm-started
      from the previous scale's ``netG.pth`` / ``netD.pth``.
    - The pair is trained with ``train_single_scale``.
    - Both networks are passed through ``functions.reset_grads(..., False)``
      and put in eval mode.
    - The generator, its noise map and ``opt.noise_amp`` are appended to
      ``Gs``, ``Zs`` and ``NoiseAmp``.

    All four lists are re-saved under ``opt.out_`` after every scale.
    """
    source_img = functions.read_image(opt)
    reals = functions.creat_reals_pyramid(imresize(source_img, opt.scale1, opt),
                                          reals, opt)

    in_s = 0
    prev_nfc = 0
    level = 0
    while level < opt.stop_scale + 1:
        # Widths double every 4 scales, capped at 128.
        width_mult = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * width_mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # existing (or uncreatable) directory is ignored

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if opt.nfc == prev_nfc:
            # Same width as the previous scale: reuse the weight files stored
            # in that scale's directory as a warm start.
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        # Checkpoint the whole pyramid after every scale.
        checkpoints = ((Zs, '%s/Zs.pth'), (Gs, '%s/Gs.pth'),
                       (reals, '%s/reals.pth'), (NoiseAmp, '%s/NoiseAmp.pth'))
        for payload, path_fmt in checkpoints:
            torch.save(payload, path_fmt % (opt.out_))

        prev_nfc = opt.nfc
        level += 1
        del D_curr, G_curr
Exemplo n.º 3
0
def _gif_noise(opt, nzx, nzy, shared_channels):
    """Draw one noise map of spatial size ``nzx`` x ``nzy`` for ``generate_gif``.

    With ``shared_channels``, a single-channel map is drawn and expanded (as a
    view) across ``opt.nc_z`` channels. Otherwise a map of size
    ``[opt.nc_z, nzx, nzy]`` is requested from ``functions.generate_noise``.
    """
    if shared_channels:
        noise = functions.generate_noise([1, nzx, nzy], device=opt.device)
        return noise.expand(1, opt.nc_z, nzx, nzy)
    return functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)


def generate_gif(Gs,
                 Zs,
                 reals,
                 NoiseAmp,
                 opt,
                 alpha=0.1,
                 beta=0.9,
                 start_scale=2,
                 fps=10,
                 num_frames=100):
    """Render a random-walk animation through the pyramid and save it as a GIF.

    At every scale, ``num_frames`` noise maps are produced by a momentum
    random walk around that scale's noise map ``Z_opt``::

        diff = beta * (z[t-1] - z[t-2]) + (1 - beta) * fresh_noise
        z[t] = alpha * Z_opt + (1 - alpha) * (z[t-1] + diff)

    Frame ``i`` of a scale is generated from ``z[t]`` plus frame ``i`` of the
    previous scale. That previous frame is resized by ``1 / opt.scale_factor``,
    cropped to ``real`` and zero-padded. Scales with index below
    ``start_scale`` still advance the walk, but feed the fixed ``Z_opt`` to
    the generator.

    Frames of the last scale are denormalised, reordered with
    ``transpose(1, 2, 0)`` (presumably to HxWxC), scaled by 255 and cast to
    ``uint8``. They are then written with ``imageio.mimsave`` to
    ``<dir2save>/start_scale=<start_scale>/alpha=<alpha>_beta=<beta>.gif``.

    Args:
        Gs, Zs, NoiseAmp, reals: per-scale generators, noise maps, noise
            amplitudes and real images, iterated in lockstep.
        opt: configuration (``ker_size``, ``num_layer``, ``nc_z``, ``device``,
            ``scale_factor``, plus whatever ``generate_dir2save`` reads).
        alpha: pull of each step back towards ``Z_opt``.
        beta: weight of the previous step (momentum).
        start_scale: first scale index whose generator receives the walk.
        fps: frame rate of the written GIF.
        num_frames: frames per scale (the previous implementation hard-coded 100).
    """
    in_s = torch.full(Zs[0].shape, 0, device=opt.device, dtype=torch.bool)
    # Zero padding for the upsampled previous frame. It depends only on opt,
    # so one module serves every scale.
    m_image = nn.ZeroPad2d(int(((opt.ker_size - 1) * opt.num_layer) / 2))
    last_scale = len(Gs) - 1
    images_cur = []

    for count, (G, Z_opt, noise_amp, real) in enumerate(
            zip(Gs, Zs, NoiseAmp, reals)):
        nzx, nzy = Z_opt.shape[2], Z_opt.shape[3]
        # The coarsest scale uses one noise channel broadcast over all nc_z.
        shared = count == 0
        images_prev, images_cur = images_cur, []

        # Start slightly off Z_opt so the walk has non-zero initial momentum.
        z_prev1 = 0.95 * Z_opt + 0.05 * _gif_noise(opt, nzx, nzy, shared)
        z_prev2 = Z_opt

        for i in range(num_frames):
            noise = _gif_noise(opt, nzx, nzy, shared)
            diff_curr = beta * (z_prev1 - z_prev2) + (1 - beta) * noise
            z_curr = alpha * Z_opt + (1 - alpha) * (z_prev1 + diff_curr)
            z_prev2, z_prev1 = z_prev1, z_curr

            if not images_prev:
                I_prev = in_s
            else:
                I_prev = imresize(images_prev[i], 1 / opt.scale_factor, opt)
                I_prev = I_prev[:, :, 0:real.shape[2], 0:real.shape[3]]
                I_prev = m_image(I_prev)

            # Below start_scale the walk keeps advancing, but the generator
            # is fed the fixed Z_opt.
            z_gen = Z_opt if count < start_scale else z_curr
            z_in = noise_amp * z_gen + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if count == last_scale:
                I_curr = functions.denorm(I_curr).detach()
                I_curr = I_curr[0, :, :, :].cpu().numpy()
                I_curr = (I_curr.transpose(1, 2, 0) * 255).astype(np.uint8)

            images_cur.append(I_curr)

    dir2save = functions.generate_dir2save(opt)
    os.makedirs('%s/start_scale=%d' % (dir2save, start_scale), exist_ok=True)
    imageio.mimsave('%s/start_scale=%d/alpha=%f_beta=%f.gif' %
                    (dir2save, start_scale, alpha, beta),
                    images_cur,
                    fps=fps)
Exemplo n.º 4
0
def _save_image_and_mask(opt, I_curr, i):
    """Write sample ``i`` as an image PNG (channels 0-2) and a mask PNG (channel 3).

    Files go to ``opt.dir2save`` and are named
    ``chk_id_<opt.checkpoint_id>_gen_scale_<opt.gen_start_scale>_<i>_{img,mask}.png``.
    With ``opt.mask_post_processing``, the mask is thresholded at 0.5 to
    {0, 255}, stacked into 3 identical uint8 channels, and both files are
    written with ``vmax=255``. Otherwise ``vmax=1`` is used.
    """
    # NOTE(review): the file name uses opt.gen_start_scale, not the
    # gen_start_scale argument of SinGAN_generate_clean. These are presumably
    # equal for all callers; confirm.
    img_path = '%s/chk_id_%s_gen_scale_%d_%d_img.png' % (
        opt.dir2save, opt.checkpoint_id, opt.gen_start_scale, i)
    mask_path = '%s/chk_id_%s_gen_scale_%d_%d_mask.png' % (
        opt.dir2save, opt.checkpoint_id, opt.gen_start_scale, i)

    # Convert once and slice both outputs from the same array.
    sample_np = functions.convert_image_np(I_curr.detach())
    mask = sample_np[:, :, 3]
    vmax = 1
    if opt.mask_post_processing:
        mask = ((mask > 0.5) * 255).astype(np.uint8)
        mask = np.dstack([mask] * 3)  # make sure all mask channels are identical
        vmax = 255

    plt.imsave(img_path, sample_np[:, :, 0:3], vmin=0, vmax=vmax)
    plt.imsave(mask_path, mask, vmin=0, vmax=vmax)


def SinGAN_generate_clean(Gs,
                          Zs,
                          reals,
                          NoiseAmp,
                          opt,
                          in_s=None,
                          scale_v=1,
                          scale_h=1,
                          n=0,
                          gen_start_scale=0,
                          num_samples=50):
    """Generate random samples coarse-to-fine and save image/mask pairs.

    For every scale, ``num_samples`` samples are produced. Each is built from
    fresh noise plus the matching sample of the previous scale. At the scale
    where ``n == len(reals) - 1``, ``opt.dir2save`` is created. Unless
    ``opt.mode`` is "harmonization", "editing", "SR" or "paint2image", each
    sample's first three channels are saved as ``..._img.png`` and its fourth
    channel as ``..._mask.png``. This assumes the generator emits at least
    4 channels (RGB + mask) — TODO confirm.

    Args:
        Gs, Zs, NoiseAmp: per-scale generators, stored noise maps and noise
            amplitudes.
        reals: per-scale real images. ``reals[0]`` sets the default ``in_s``
            shape, and ``len(reals)`` identifies the last scale.
        opt: configuration (``ker_size``, ``num_layer``, ``nc_z``, ``device``,
            ``scale_factor``, ``mode``, ``dir2save``, ``checkpoint_id``,
            ``gen_start_scale``, ``mask_post_processing``).
        in_s: input to the first generator. Defaults to zeros shaped like
            ``reals[0]``.
        scale_v, scale_h: vertical / horizontal stretch applied to the noise
            size and to the crop of the previous scale's output.
        n: index of the scale that ``Gs[0]`` corresponds to; advanced per scale.
        gen_start_scale: scales with index below this use the stored ``Zs``
            map instead of fresh noise.
        num_samples: samples generated per scale.

    Returns:
        The last generated sample, detached.
    """
    if in_s is None:
        in_s = torch.full(reals[0].shape,
                          0,
                          device=opt.device,
                          dtype=torch.bool)

    # The padding depends only on opt, so build it once for every scale.
    pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
    m = nn.ZeroPad2d(int(pad1))
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        # Unpadded noise size for this scale, stretched by scale_v / scale_h.
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h
        images_prev, images_cur = images_cur, []

        is_last_scale = n == len(reals) - 1
        save_samples = is_last_scale and opt.mode not in (
            "harmonization", "editing", "SR", "paint2image")
        if is_last_scale:
            # Created once per call rather than once per sample.
            os.makedirs(opt.dir2save, exist_ok=True)

        for i in range(num_samples):
            if n == 0:
                # One noise channel broadcast across all nc_z channels.
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, opt.nc_z, z_curr.shape[2],
                                       z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                I_prev = m(in_s)
            else:
                I_prev = imresize(images_prev[i], 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    # Crop to the stretched real size, pad, then force the
                    # result to the noise map's spatial size.
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                # Use the stored Zs map instead of fresh noise at this scale.
                z_curr = Z_opt

            z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if save_samples:
                _save_image_and_mask(opt, I_curr, i)
            images_cur.append(I_curr)
        n += 1
    return I_curr.detach()