def draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt):
    """Chain the generators in ``Gs`` and return the image for the next scale.

    Starting from ``in_s``, the running image is cropped to the matching
    entry of ``reals``, padded with ``m_image``, combined with scaled noise,
    passed through the generator, resized by ``1 / opt.scale_factor`` and
    finally cropped to the following entry of ``reals``.

    Args:
        Gs: generators, called as ``G(noisy_input, previous_image)``.
        Zs: stored noise maps, one per generator. Used as the noise itself
            in 'rec' mode and only for their spatial size in 'rand' mode.
        reals: image pyramid; entry ``k`` bounds the input of generator
            ``k`` and entry ``k + 1`` bounds its resized output.
        NoiseAmp: per-generator noise multipliers.
        in_s: starting image.
        mode: 'rand' draws new noise through ``functions.generate_noise``;
            'rec' reuses ``Zs``. Any other value leaves ``in_s`` unchanged.
        m_noise: callable applied to freshly drawn noise (padding).
        m_image: callable applied to the running image (padding).
        opt: options object; reads ``ker_size``, ``num_layer``, ``mode``,
            ``nc_z``, ``device`` and ``scale_factor``.

    Returns:
        The image produced after the last generator, or ``in_s`` itself when
        ``Gs`` is empty.
    """

    def crop_like(tensor, reference):
        # Trim the two trailing (spatial) dims to the reference's extent.
        return tensor[:, :, 0:reference.shape[2], 0:reference.shape[3]]

    running = in_s
    if len(Gs) == 0:
        return running

    if mode == 'rand':
        pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
        if opt.mode == 'animation_train':
            pad_noise = 0
        for idx, (G, Z_opt, real_curr, real_next, noise_amp) in enumerate(
                zip(Gs, Zs, reals, reals[1:], NoiseAmp)):
            noise_h = Z_opt.shape[2] - 2 * pad_noise
            noise_w = Z_opt.shape[3] - 2 * pad_noise
            if idx == 0:
                # First generator: draw one channel and expand it to
                # opt.nc_z channels (this used to be a hard-coded 3).
                z = functions.generate_noise([1, noise_h, noise_w],
                                             device=opt.device)
                z = z.expand(1, opt.nc_z, z.shape[2], z.shape[3])
            else:
                z = functions.generate_noise([opt.nc_z, noise_h, noise_w],
                                             device=opt.device)
            z = m_noise(z)
            running = m_image(crop_like(running, real_curr))
            z_in = noise_amp * z + running
            running = G(z_in.detach(), running)
            running = imresize(running, 1 / opt.scale_factor, opt)
            running = crop_like(running, real_next)
    elif mode == 'rec':
        for G, Z_opt, real_curr, real_next, noise_amp in zip(
                Gs, Zs, reals, reals[1:], NoiseAmp):
            running = m_image(crop_like(running, real_curr))
            z_in = noise_amp * Z_opt + running
            running = G(z_in.detach(), running)
            running = imresize(running, 1 / opt.scale_factor, opt)
            running = crop_like(running, real_next)

    return running
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one discriminator/generator pair per scale, scale by scale.

    For every scale from 0 up to and including ``opt.stop_scale`` this
    builds fresh models with ``init_models``, trains them with
    ``train_single_scale``, freezes them and appends the results to the
    caller's lists. After each scale ``Zs``, ``Gs``, ``reals`` and
    ``NoiseAmp`` are written to ``opt.out_`` with ``torch.save``.

    Args:
        opt: options object. Reads ``scale1``, ``stop_scale``, ``nfc_init``,
            ``min_nfc_init`` and ``noise_amp``; writes ``nfc``, ``min_nfc``,
            ``out_`` and ``outf``.
        Gs: list that receives the trained generator of every scale.
        Zs: list that receives the noise map returned for every scale.
        reals: list filled with the image pyramid by
            ``functions.creat_reals_pyramid``.
        NoiseAmp: list that receives ``opt.noise_amp`` after every scale.

    Returns:
        None. Results are delivered through the mutated lists and the files
        written to disk.
    """
    source_img = functions.read_image(opt)
    in_s = 0
    resized = imresize(source_img, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(resized, reals, opt)

    prev_nfc = 0
    scale_num = 0
    while scale_num < opt.stop_scale + 1:
        # The channel widths double every four scales and are capped at 128.
        width_mult = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * width_mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            # Best effort: creation failures (e.g. the directory already
            # exists) are ignored here.
            pass

        real_np = functions.convert_image_np(reals[scale_num])
        plt.imsave('%s/real_scale.png' % (opt.outf), real_np, vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if prev_nfc == opt.nfc:
            # Same width as the previous scale: start from its saved weights.
            g_state = torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1))
            G_curr.load_state_dict(g_state)
            d_state = torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1))
            D_curr.load_state_dict(d_state)

        z_curr, in_s, G_curr = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)

        # Freeze both networks before storing them.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        prev_nfc = opt.nfc
        scale_num += 1
        del D_curr, G_curr
def creat_reals_pyramid(real, reals, opt):
    """Append resized copies of ``real`` to ``reals`` and return ``reals``.

    Args:
        real: image indexed as ``[:, channel, :, :]``. All four channels are
            kept when there are exactly four; otherwise only the first three.
        reals: list that receives one resized image per scale (mutated).
        opt: options object; reads ``stop_scale`` and ``scale_factor`` and is
            passed on to ``imresize``.

    Returns:
        The same ``reals`` list. Entry ``i`` is ``real`` resized by
        ``opt.scale_factor ** (opt.stop_scale - i)``, so the last entry uses
        a factor of 1.
    """
    kept_channels = 4 if real.shape[1] == 4 else 3
    real = real[:, 0:kept_channels, :, :]

    for level in range(opt.stop_scale + 1):
        factor = math.pow(opt.scale_factor, opt.stop_scale - level)
        reals.append(imresize(real, factor, opt))
    return reals
def adjust_scales2image_SR(real_, opt):
    """Set the pyramid parameters on ``opt`` and return the resized image.

    Writes ``opt.min_size`` (fixed to 18), ``opt.num_scales``,
    ``opt.stop_scale``, ``opt.scale1`` and ``opt.scale_factor``.

    Args:
        real_: input image; dims 2 and 3 are used as its two spatial sides.
        opt: options object; reads ``max_size`` and ``scale_factor_init`` and
            is passed on to ``imresize``.

    Returns:
        ``real_`` resized by ``opt.scale1`` through ``imresize``.

    Raises:
        ZeroDivisionError: if the computed ``opt.stop_scale`` is 0, because
            ``opt.scale_factor`` is derived from ``1 / opt.stop_scale``.
    """
    opt.min_size = 18
    short_side = min(real_.shape[2], real_.shape[3])
    long_side = max(real_.shape[2], real_.shape[3])

    opt.num_scales = int(
        math.log(opt.min_size / short_side, opt.scale_factor_init)) + 1

    # Normalise by the longest spatial side (dims 2 and 3), the same quantity
    # used in the numerator. Indexing dim 0 here would compare against a
    # non-spatial dimension and give a wrong ratio whenever dim 2 is the
    # longer side.
    scale2stop = int(
        math.log(min(opt.max_size, long_side) / long_side,
                 opt.scale_factor_init))
    opt.stop_scale = opt.num_scales - scale2stop

    # Only ever shrink: the factor is capped at 1.
    opt.scale1 = min(opt.max_size / long_side, 1)
    real = imresize(real_, opt.scale1, opt)

    # Per-scale factor chosen so that stop_scale steps take the shorter side
    # of the resized image down to min_size.
    opt.scale_factor = math.pow(
        opt.min_size / min(real.shape[2], real.shape[3]),
        1 / opt.stop_scale)
    return real
def generate_gif(Gs, Zs, reals, NoiseAmp, opt, alpha=0.1, beta=0.9,
                 start_scale=2, fps=10, num_frames=100):
    """Render an animated GIF by random-walking the noise at every scale.

    For each scale the noise map performs a walk around the stored map
    ``Z_opt``: ``beta`` weighs the previous step direction against fresh
    noise, and ``alpha`` pulls the walk back towards ``Z_opt``. Scales below
    ``start_scale`` ignore the walk and use ``Z_opt`` unchanged. Only the
    frames of the last generator are converted to ``uint8`` arrays; those
    are what is written to disk.

    Args:
        Gs: generators, called as ``G(noisy_input, previous_image)``.
        Zs: stored noise maps, one per generator; must be non-empty.
        reals: image pyramid; frames coming from the previous scale are
            cropped to the matching entry.
        NoiseAmp: per-generator noise multipliers.
        opt: options object; reads ``device``, ``ker_size``, ``num_layer``,
            ``nc_z`` and ``scale_factor``.
        alpha: weight of ``Z_opt`` in every new noise map.
        beta: weight of the previous step direction.
        start_scale: first scale index at which the walked noise is used.
        fps: frame rate handed to ``imageio.mimsave``.
        num_frames: number of frames generated per scale.

    Raises:
        OSError: if the output directory cannot be created.
    """
    # Float zero image used as the "previous image" of the first scale. A
    # float tensor adds to the noise directly instead of relying on implicit
    # type promotion.
    in_s = torch.zeros(Zs[0].shape, device=opt.device)
    images_cur = []
    last_scale = len(Gs) - 1

    # Padding depends only on the options, so it is built once.
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_image = nn.ZeroPad2d(pad_image)

    for count, (G, Z_opt, noise_amp, real) in enumerate(
            zip(Gs, Zs, NoiseAmp, reals)):
        images_prev, images_cur = images_cur, []

        z_prev1 = 0.95 * Z_opt + 0.05 * _gif_noise(count == 0, Z_opt, opt)
        z_prev2 = Z_opt

        for i in range(num_frames):
            z_rand = _gif_noise(count == 0, Z_opt, opt)
            diff_curr = beta * (z_prev1 - z_prev2) + (1 - beta) * z_rand
            z_curr = alpha * Z_opt + (1 - alpha) * (z_prev1 + diff_curr)
            z_prev2 = z_prev1
            z_prev1 = z_curr

            if not images_prev:
                I_prev = in_s
            else:
                # Bring the matching frame of the previous scale up to this
                # scale, crop it to this scale's image and pad it.
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                I_prev = I_prev[:, :, 0:real.shape[2], 0:real.shape[3]]
                I_prev = m_image(I_prev)

            if count < start_scale:
                # Below start_scale the stored noise is used unchanged.
                z_curr = Z_opt

            z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if count == last_scale:
                # Last generator: convert to a channel-last uint8 array.
                I_curr = functions.denorm(I_curr).detach()
                I_curr = I_curr[0, :, :, :].cpu().numpy()
                I_curr = I_curr.transpose(1, 2, 0) * 255
                I_curr = I_curr.astype(np.uint8)

            images_cur.append(I_curr)

    dir2save = functions.generate_dir2save(opt)
    out_dir = '%s/start_scale=%d' % (dir2save, start_scale)
    os.makedirs(out_dir, exist_ok=True)
    imageio.mimsave('%s/alpha=%f_beta=%f.gif' % (out_dir, alpha, beta),
                    images_cur, fps=fps)


def _gif_noise(first_scale, Z_opt, opt):
    """Draw one noise map with the spatial size of ``Z_opt``."""
    nzx, nzy = Z_opt.shape[2], Z_opt.shape[3]
    if first_scale:
        # First scale: a single channel expanded to opt.nc_z channels.
        z_rand = functions.generate_noise([1, nzx, nzy], device=opt.device)
        return z_rand.expand(1, opt.nc_z, Z_opt.shape[2], Z_opt.shape[3])
    return functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)
def SinGAN_generate_clean(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1,
                          scale_h=1, n=0, gen_start_scale=0, num_samples=50):
    """Generate ``num_samples`` samples through all generators in ``Gs``.

    At every scale each sample of the previous scale is resized by
    ``1 / opt.scale_factor``, padded, combined with scaled noise and fed to
    the generator. When the scale counter reaches ``len(reals) - 1`` the
    directory ``opt.dir2save`` is created, and every sample is saved as an
    image plus a mask unless ``opt.mode`` is "harmonization", "editing",
    "SR" or "paint2image".

    Args:
        Gs: generators, called as ``G(noisy_input, previous_image)``.
        Zs: stored noise maps, one per generator.
        reals: image pyramid; ``reals[0]`` sizes the default ``in_s`` and
            ``reals[n]`` bounds the resized previous sample.
        NoiseAmp: per-generator noise multipliers.
        opt: options object; reads ``device``, ``ker_size``, ``num_layer``,
            ``nc_z``, ``scale_factor``, ``mode``, ``dir2save``,
            ``checkpoint_id``, ``gen_start_scale`` and
            ``mask_post_processing``.
        in_s: image fed to the first generator; defaults to float zeros
            shaped like ``reals[0]``.
        scale_v: multiplier applied to the noise size along dim 2.
        scale_h: multiplier applied to the noise size along dim 3.
        n: scale index assigned to the first generator in ``Gs``.
        gen_start_scale: scales below this index use the stored ``Z_opt``
            instead of freshly drawn noise.
        num_samples: number of samples produced per scale.

    Returns:
        The last generated sample, detached.

    Raises:
        ValueError: if nothing was generated (an empty generator, noise or
            amplitude list, or ``num_samples`` < 1).
        OSError: if ``opt.dir2save`` cannot be created.
    """
    if in_s is None:
        # A float zero image pads and adds to the noise directly, instead of
        # relying on implicit type promotion.
        in_s = torch.zeros(reals[0].shape, device=opt.device)

    images_cur = []
    I_curr = None
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        # Noise is drawn without the padding border and padded afterwards.
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev, images_cur = images_cur, []

        for i in range(num_samples):
            if n == 0:
                # Scale 0: a single channel expanded to opt.nc_z channels.
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, opt.nc_z, z_curr.shape[2],
                                       z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                I_prev = m(in_s)
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :,
                                    0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2],
                                    0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt

            z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if n == len(reals) - 1:
                os.makedirs(opt.dir2save, exist_ok=True)
                if opt.mode not in ("harmonization", "editing", "SR",
                                    "paint2image"):
                    _save_sample_and_mask(I_curr, i, opt)

            images_cur.append(I_curr)
        n += 1

    if I_curr is None:
        raise ValueError(
            "SinGAN_generate_clean produced no samples: Gs, Zs and NoiseAmp "
            "must be non-empty and num_samples must be at least 1")
    return I_curr.detach()


def _save_sample_and_mask(I_curr, sample_idx, opt):
    """Save channels 0-2 of a sample as an image and channel 3 as its mask."""
    # NOTE(review): the file names use opt.gen_start_scale, not the
    # gen_start_scale argument of SinGAN_generate_clean; confirm that callers
    # keep the two in sync.
    img_path = '%s/chk_id_%s_gen_scale_%d_%d_img.png' % (
        opt.dir2save, opt.checkpoint_id, opt.gen_start_scale, sample_idx)
    mask_path = '%s/chk_id_%s_gen_scale_%d_%d_mask.png' % (
        opt.dir2save, opt.checkpoint_id, opt.gen_start_scale, sample_idx)

    # The converted array is indexed as (row, column, channel) below.
    sample_np = functions.convert_image_np(I_curr.detach())
    mask = sample_np[:, :, 3]
    vmax = 1
    if opt.mask_post_processing:
        # Binarise at 0.5 and replicate to three identical channels.
        mask = ((mask > 0.5) * 255).astype(np.uint8)
        mask = np.dstack([mask] * 3)
        vmax = 255

    plt.imsave(img_path, sample_np[:, :, 0:3], vmin=0, vmax=vmax)
    plt.imsave(mask_path, mask, vmin=0, vmax=vmax)