# Example no. 1
def draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt):
    """Push ``in_s`` through the already-trained generators ``Gs`` in order.

    For every stage ``k`` the running image is cropped to the spatial size of
    ``reals[k]``, passed through ``m_image``, combined with
    ``noise_amp * noise`` and fed to ``Gs[k]``.  The generator output is
    resized by ``1 / opt.scale_factor`` and cropped to ``reals[k + 1]``.

    Args:
        Gs: generators to apply, in pyramid order.
        Zs: stored noise maps, one per generator.
        reals: image pyramid; only the spatial sizes of its entries are used.
        NoiseAmp: noise amplitude per generator.
        in_s: starting image.
        mode: ``'rand'`` draws fresh noise with ``functions.generate_noise``;
            ``'rec'`` reuses the maps in ``Zs``.  Any other value leaves
            ``in_s`` untouched.
        m_noise: callable applied to freshly drawn noise (``'rand'`` only),
            presumably a padding layer.
        m_image: callable applied to the cropped image before each generator,
            presumably a padding layer.
        opt: options; ``ker_size``, ``num_layer``, ``mode``, ``nc_z``,
            ``device`` and ``scale_factor`` are read.

    Returns:
        The image produced by the last generator.  Returns ``in_s`` itself
        when ``Gs`` is empty or ``mode`` is neither ``'rand'`` nor ``'rec'``.
    """
    G_z = in_s
    if len(Gs) == 0:
        return G_z

    if mode == 'rand':
        # Fresh noise is drawn smaller than Z_opt by the generator padding
        # (presumably restored by m_noise); animation training skips this.
        pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
        if opt.mode == 'animation_train':
            pad_noise = 0
        stages = zip(Gs, Zs, reals, reals[1:], NoiseAmp)
        for stage, (G, Z_opt, real_curr, real_next, noise_amp) in enumerate(stages):
            noise_h = Z_opt.shape[2] - 2 * pad_noise
            noise_w = Z_opt.shape[3] - 2 * pad_noise
            if stage == 0:
                # First stage: a single noise channel broadcast to opt.nc_z channels.
                z = functions.generate_noise([1, noise_h, noise_w], device=opt.device)
                z = z.expand(1, opt.nc_z, z.shape[2], z.shape[3])
            else:
                z = functions.generate_noise([opt.nc_z, noise_h, noise_w],
                                             device=opt.device)
            z = m_noise(z)
            G_z = m_image(G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]])
            G_z = G((noise_amp * z + G_z).detach(), G_z)
            G_z = imresize(G_z, 1 / opt.scale_factor, opt)
            G_z = G_z[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
    elif mode == 'rec':
        # Reconstruction: reuse the stored noise maps instead of sampling.
        for G, Z_opt, real_curr, real_next, noise_amp in zip(
                Gs, Zs, reals, reals[1:], NoiseAmp):
            G_z = m_image(G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]])
            G_z = G((noise_amp * Z_opt + G_z).detach(), G_z)
            G_z = imresize(G_z, 1 / opt.scale_factor, opt)
            G_z = G_z[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
    return G_z
# Example no. 2
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one generator/discriminator pair per scale, coarsest scale first.

    The image returned by ``functions.read_image(opt)`` is resized by
    ``opt.scale1`` and turned into a pyramid with
    ``functions.creat_reals_pyramid``.  Then, for every scale
    ``0 .. opt.stop_scale``:

    * ``opt.nfc`` / ``opt.min_nfc`` are set (doubled every 4 scales, capped
      at 128);
    * ``<opt.out_>/<scale>/`` is created and the real image of that scale is
      saved there as ``real_scale.png``;
    * when the width did not change since the previous scale, the new models
      are initialised from that scale's ``netG.pth`` / ``netD.pth``;
    * the pair is trained with ``train_single_scale``, passed through
      ``functions.reset_grads(..., False)`` and switched to eval mode;
    * ``Gs``, ``Zs`` and ``NoiseAmp`` are appended to (in place), and these
      lists plus ``reals`` are saved as ``.pth`` files under ``opt.out_``.

    Returns None.
    """
    source_img = functions.read_image(opt)
    prev_in_s = 0
    level = 0
    real = imresize(source_img, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    while level < opt.stop_scale + 1:
        # Channel width doubles every 4 scales, capped at 128.
        width_mult = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * width_mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # best effort: the directory may already exist

        plt.imsave('%s/real_scale.png' % opt.outf,
                   functions.convert_image_np(reals[level]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            # Same width as the previous scale: warm-start from its checkpoints.
            prev_dir = '%s/%d' % (opt.out_, level - 1)
            G_curr.load_state_dict(torch.load('%s/netG.pth' % prev_dir))
            D_curr.load_state_dict(torch.load('%s/netD.pth' % prev_dir))

        z_curr, prev_in_s, G_curr = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, prev_in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        # Checkpoint the accumulated pyramid after every scale.
        for payload, stem in ((Zs, 'Zs'), (Gs, 'Gs'), (reals, 'reals'),
                              (NoiseAmp, 'NoiseAmp')):
            torch.save(payload, '%s/%s.pth' % (opt.out_, stem))

        level += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
def creat_reals_pyramid(real, reals, opt):
    """Append the multi-scale pyramid of ``real`` to ``reals`` and return it.

    Keeps the first 4 channels when ``real`` has exactly 4 on dim 1
    (presumably image + mask), otherwise the first 3.  Level ``i`` for
    ``i = 0 .. opt.stop_scale`` is ``real`` resized by
    ``opt.scale_factor ** (opt.stop_scale - i)``, so the last level uses
    factor 1.0.

    ``reals`` is mutated in place and also returned.
    """
    keep = 4 if real.shape[1] == 4 else 3
    real = real[:, 0:keep, :, :]
    for level in range(opt.stop_scale + 1):
        exponent = opt.stop_scale - level
        reals.append(imresize(real, math.pow(opt.scale_factor, exponent), opt))
    return reals
def adjust_scales2image_SR(real_, opt):
    """Configure the scale-pyramid options on ``opt`` for super-resolution.

    Sets ``opt.min_size`` (fixed to 18), ``opt.num_scales``,
    ``opt.stop_scale``, ``opt.scale1`` and ``opt.scale_factor`` from the
    spatial size of ``real_``.

    Args:
        real_: input image.  Dims 2 and 3 are used as the spatial dims
            (presumably an NCHW tensor; TODO confirm).
        opt: options namespace.  ``max_size`` and ``scale_factor_init`` are
            read; the attributes listed above are written.

    Returns:
        ``real_`` resized with ``imresize`` by ``opt.scale1``.  That factor
        is capped at 1, so the image is never enlarged; otherwise its longer
        side is brought to ``opt.max_size``.

    Raises:
        ZeroDivisionError: if the computed ``opt.stop_scale`` is 0.
    """
    height, width = real_.shape[2], real_.shape[3]
    longest = max(height, width)

    opt.min_size = 18
    opt.num_scales = int(math.log(opt.min_size / min(height, width),
                                  opt.scale_factor_init)) + 1
    # Scales skipped because the image is first shrunk to fit opt.max_size.
    # Both operands of the ratio are spatial sizes. Previously the divisor
    # was max(shape[0], shape[3]), which pulled in dim 0 and gave a wrong
    # result for images taller than they are wide.
    scale2stop = int(math.log(min(opt.max_size, longest) / longest,
                              opt.scale_factor_init))
    opt.stop_scale = opt.num_scales - scale2stop
    opt.scale1 = min(opt.max_size / longest, 1)
    real = imresize(real_, opt.scale1, opt)
    # Per-level factor that takes the resized image's shorter side down to
    # opt.min_size in opt.stop_scale steps.
    opt.scale_factor = math.pow(opt.min_size / min(real.shape[2], real.shape[3]),
                                1 / opt.stop_scale)
    return real
def _gif_noise(Z_opt, coarsest, opt):
    """Draw one random noise map with the spatial size of ``Z_opt``.

    On the coarsest scale a single-channel map is drawn and expanded to
    ``opt.nc_z`` channels; on finer scales ``opt.nc_z`` channels are drawn.
    """
    nzx, nzy = Z_opt.shape[2], Z_opt.shape[3]
    if not coarsest:
        return functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)
    noise = functions.generate_noise([1, nzx, nzy], device=opt.device)
    return noise.expand(1, opt.nc_z, nzx, nzy)


def generate_gif(Gs,
                 Zs,
                 reals,
                 NoiseAmp,
                 opt,
                 alpha=0.1,
                 beta=0.9,
                 start_scale=2,
                 fps=10):
    """Animate the generator pyramid with a random walk in noise space; save a GIF.

    For every scale, 100 noise maps are produced by the recursion
    ``diff = beta * (z_prev1 - z_prev2) + (1 - beta) * noise`` and
    ``z = alpha * Z_opt + (1 - alpha) * (z_prev1 + diff)``.  ``alpha``
    therefore pulls the walk back toward the stored map ``Z_opt``, and
    ``beta`` weights the previous step's direction.  Scales whose index is
    below ``start_scale`` feed ``Z_opt`` itself to the generator.

    Frame ``k`` of one scale is resized by ``1 / opt.scale_factor``, cropped
    to ``real`` and padded, and becomes the input of frame ``k`` at the next
    scale.  The first scale starts from an all-zero bool tensor shaped like
    ``Zs[0]``.  Frames of the last scale are converted to uint8 numpy arrays
    and written with ``imageio.mimsave`` (``fps`` passed through) to
    ``<generate_dir2save(opt)>/start_scale=<start_scale>/alpha=<alpha>_beta=<beta>.gif``.
    """
    in_s = torch.full(Zs[0].shape, 0, device=opt.device, dtype=torch.bool)
    frames = []

    for level, (G, Z_opt, noise_amp, real) in enumerate(zip(Gs, Zs, NoiseAmp, reals)):
        pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
        m_image = nn.ZeroPad2d(int(pad_image))
        coarsest = level == 0
        prev_frames, frames = frames, []

        # Seed the walk: z_prev1 is a slight perturbation of Z_opt, z_prev2 is Z_opt.
        z_prev1 = 0.95 * Z_opt + 0.05 * _gif_noise(Z_opt, coarsest, opt)
        z_prev2 = Z_opt

        for step in range(100):
            momentum = z_prev1 - z_prev2
            diff_curr = beta * momentum + (1 - beta) * _gif_noise(Z_opt, coarsest, opt)
            z_curr = alpha * Z_opt + (1 - alpha) * (z_prev1 + diff_curr)
            z_prev2, z_prev1 = z_prev1, z_curr

            if prev_frames:
                I_prev = imresize(prev_frames[step], 1 / opt.scale_factor, opt)
                I_prev = m_image(I_prev[:, :, 0:real.shape[2], 0:real.shape[3]])
            else:
                I_prev = in_s

            z_used = Z_opt if level < start_scale else z_curr
            I_curr = G((noise_amp * z_used + I_prev).detach(), I_prev)

            if level == len(Gs) - 1:
                # Last scale: turn the first batch element into a uint8 frame;
                # transpose(1, 2, 0) moves the leading axis to the end.
                I_curr = functions.denorm(I_curr).detach()
                I_curr = (I_curr[0, :, :, :].cpu().numpy().transpose(1, 2, 0) *
                          255).astype(np.uint8)

            frames.append(I_curr)

    out_dir = '%s/start_scale=%d' % (functions.generate_dir2save(opt), start_scale)
    try:
        os.makedirs(out_dir)
    except OSError:
        pass  # best effort: the directory may already exist
    imageio.mimsave('%s/alpha=%f_beta=%f.gif' % (out_dir, alpha, beta),
                    frames,
                    fps=fps)
def _save_image_and_mask(I_curr, opt, i):
    """Write sample ``i`` to ``opt.dir2save`` as an image PNG and a mask PNG.

    Channels 0-2 of ``functions.convert_image_np(I_curr)`` are saved as the
    image and channel 3 as the mask.  With ``opt.mask_post_processing`` the
    mask is thresholded at 0.5, scaled to {0, 255} as uint8 and stacked into
    3 identical channels, and ``vmax`` becomes 255.
    """
    # NOTE(review): file names use opt.gen_start_scale rather than the
    # gen_start_scale argument of SinGAN_generate_clean; confirm which is intended.
    prefix = '%s/chk_id_%s_gen_scale_%d_%d' % (
        opt.dir2save, opt.checkpoint_id, opt.gen_start_scale, i)

    sample = functions.convert_image_np(I_curr.detach())
    mask = sample[:, :, 3]
    vmax = 1
    if opt.mask_post_processing:
        mask = ((mask > 0.5) * 255).astype(np.uint8)
        mask = np.dstack([mask] * 3)  # make sure all mask channels are the same
        vmax = 255

    plt.imsave(prefix + '_img.png', sample[:, :, 0:3], vmin=0, vmax=vmax)
    plt.imsave(prefix + '_mask.png', mask, vmin=0, vmax=vmax)


def SinGAN_generate_clean(Gs,
                          Zs,
                          reals,
                          NoiseAmp,
                          opt,
                          in_s=None,
                          scale_v=1,
                          scale_h=1,
                          n=0,
                          gen_start_scale=0,
                          num_samples=50):
    """Generate ``num_samples`` samples through the generator pyramid.

    Each generator in ``Gs`` is run ``num_samples`` times; ``n`` is the index
    of the first scale and is incremented per generator.  Sample ``i`` of one
    scale is resized by ``1 / opt.scale_factor``, cropped and padded, and
    becomes the input of sample ``i`` at the next scale.  Scales whose index
    is below ``gen_start_scale`` use the stored ``Zs`` map instead of fresh
    noise.

    At the last scale (``n == len(reals) - 1``) ``opt.dir2save`` is created
    (best effort).  Unless ``opt.mode`` is one of harmonization, editing, SR
    or paint2image, every sample is also written there as an image PNG and a
    mask PNG.

    Args:
        Gs, Zs, NoiseAmp: per-scale generators, stored noise maps and noise
            amplitudes.
        reals: image pyramid; used for the default ``in_s`` and crop sizes.
        opt: options namespace.
        in_s: input of the first scale.  Defaults to an all-zero bool tensor
            shaped like ``reals[0]``.
        scale_v, scale_h: vertical / horizontal scaling of the noise size.
        n: index of the first scale.
        gen_start_scale: scales below this index use ``Zs`` as the noise.
        num_samples: samples generated per scale.

    Returns:
        The last sample of the last scale, detached.

    Raises:
        ValueError: if nothing was generated (empty ``Gs``/``Zs``/``NoiseAmp``
            or ``num_samples == 0``).
    """
    if in_s is None:
        in_s = torch.full(reals[0].shape,
                          0,
                          device=opt.device,
                          dtype=torch.bool)
    # Modes in which the final samples are not written to disk here.
    no_save_modes = ("harmonization", "editing", "SR", "paint2image")

    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(num_samples):
            if n == 0:
                # Coarsest scale: one noise channel broadcast to opt.nc_z channels.
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, opt.nc_z, z_curr.shape[2],
                                       z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                I_prev = m(in_s)
            else:
                I_prev = imresize(images_prev[i], 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt

            z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if n == len(reals) - 1:
                try:
                    os.makedirs(opt.dir2save)
                except OSError:
                    pass  # best effort: the directory may already exist
                if opt.mode not in no_save_modes:
                    _save_image_and_mask(I_curr, opt, i)
            images_cur.append(I_curr)
        n += 1

    if not images_cur:
        raise ValueError('SinGAN_generate_clean: nothing was generated '
                         '(Gs/Zs/NoiseAmp empty or num_samples == 0)')
    return I_curr.detach()