def __init__(self, input_name, load_existing_model):
    """Prepare a single-image SinGAN experiment for ``input_name``.

    Builds the configuration, reads the input image and adjusts the scale
    parameters to it. With ``load_existing_model`` the trained pyramid
    (``Gs``, ``Zs``, ``reals``, ``NoiseAmp``) is loaded from the output
    directory; otherwise the output directory is prepared for training,
    asking the user before deleting an existing one.

    Args:
        input_name: Input image name, forwarded to ``Config``.
        load_existing_model: Load a previously trained pyramid instead of
            preparing a fresh training directory.

    Raises:
        AssertionError: If loading is requested but the output directory
            does not exist, or the user declines to overwrite it.
    """
    self.input_name = input_name
    self.load_existing_model = load_existing_model
    self.opt = Config(input_name)
    self.dir2save = functions.generate_dir2save(self.opt)
    self.real = functions.read_image(self.opt)
    functions.adjust_scales2image(self.real, self.opt)

    dir_exists = os.path.exists(self.dir2save)
    # Explicit raises (not `assert`) so the checks survive `python -O`;
    # AssertionError is kept so callers catching it keep working.
    if load_existing_model and not dir_exists:
        raise AssertionError("cannot find trained model")

    if load_existing_model:
        self.Gs = torch.load(f'{self.dir2save}/Gs.pth')
        self.Zs = torch.load(f'{self.dir2save}/Zs.pth')
        self.reals = torch.load(f'{self.dir2save}/reals.pth')
        self.NoiseAmp = torch.load(f'{self.dir2save}/NoiseAmp.pth')
        self.is_loaded = True
        # Report only after the files were actually read.
        print("Trained model has been loaded")
    else:
        # Set on both paths so reading `is_loaded` never raises AttributeError.
        self.is_loaded = False
        if dir_exists:
            user_input = input("Trained model has been found, type \"yes\" to overwrite: ")
            # Must not be an `assert`: under -O the model would be deleted
            # even though the user refused.
            if user_input != 'yes':
                raise AssertionError("train aborted")
            rmtree(self.dir2save)
            print("train directory has been deleted")
        os.makedirs(self.dir2save, exist_ok=True)
def train_paint(opt, Gs, Zs, reals, NoiseAmp, centers, paint_inject_scale):
    """Retrain only scale ``paint_inject_scale`` of a pyramid for paint-to-image.

    The coarser scales of ``Gs``/``Zs``/``NoiseAmp`` are reused as given; the
    new generator, noise map and noise amplitude replace the entries at
    ``paint_inject_scale`` in place, and the whole pyramid is saved to
    ``opt.out_``. Nothing happens when ``paint_inject_scale`` is not one of
    the scales ``0..opt.stop_scale``.

    Args:
        opt: Training options; ``nfc``, ``min_nfc``, ``out_`` and ``outf``
            are (re)set here.
        Gs, Zs, reals, NoiseAmp: Per-scale pyramids, coarsest first.
        centers: Colour quantization centres forwarded to ``train_single_scale``.
        paint_inject_scale: Index of the scale to retrain.
    """
    # The original loop walked scales 0..stop_scale but skipped every one
    # except paint_inject_scale; this range check is equivalent.
    if paint_inject_scale not in range(opt.stop_scale + 1):
        return
    scale_num = int(paint_inject_scale)

    # 0.0 (not 0): an integer fill value makes torch.full return int64 on
    # PyTorch >= 1.7, which does not match the float image tensors.
    in_s = torch.full(reals[0].shape, 0.0, device=opt.device)

    # Channel width doubles every 4 scales, capped at 128.
    opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
    opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
    opt.out_ = functions.generate_dir2save(opt)
    opt.outf = '%s/%d' % (opt.out_, scale_num)
    os.makedirs(opt.outf, exist_ok=True)
    plt.imsave('%s/in_scale.png' % (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)

    D_curr, G_curr = init_models(opt)
    z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals[:scale_num + 1], Gs[:scale_num],
                                              Zs[:scale_num], in_s, NoiseAmp[:scale_num], opt,
                                              centers=centers)
    # Freeze the trained networks.
    G_curr = functions.reset_grads(G_curr, False)
    G_curr.eval()
    D_curr = functions.reset_grads(D_curr, False)
    D_curr.eval()

    Gs[scale_num] = G_curr
    Zs[scale_num] = z_curr
    NoiseAmp[scale_num] = opt.noise_amp
    torch.save(Zs, '%s/Zs.pth' % (opt.out_))
    torch.save(Gs, '%s/Gs.pth' % (opt.out_))
    torch.save(reals, '%s/reals.pth' % (opt.out_))
    torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
    del D_curr, G_curr
    return
def train_model(input_name):
    """Train a SinGAN pyramid on ``input_name`` and write random samples.

    Options are parsed from the project's argument parser as if the command
    line were ``--input_name <input_name>`` (other arguments use defaults).
    If the output directory already exists, the model is treated as trained
    and nothing else is done.

    Args:
        input_name: File name of the training image inside ``--input_dir``.
    """
    parser = get_arguments()
    parser.add_argument('--input_dir', help='input image dir', default='Input/Images')
    parser.add_argument('--input_name', help='input image name', required=True)
    parser.add_argument('--mode', help='task to be done', default='train')
    # Previously parse_args(""): --input_name is required, so argparse exited
    # and the input_name parameter was never used.
    opt = parser.parse_args(['--input_name', input_name])
    opt = functions.post_config(opt)
    Gs = []
    Zs = []
    reals = []
    NoiseAmp = []
    dir2save = functions.generate_dir2save(opt)

    # NOTE(review): indentation was lost in this file; this follows upstream
    # SinGAN, where training only runs when the output directory is new.
    if os.path.exists(dir2save):
        print('trained model already exist')
        return
    os.makedirs(dir2save, exist_ok=True)
    real = functions.read_image(opt)
    functions.adjust_scales2image(real, opt)
    train(opt, Gs, Zs, reals, NoiseAmp)
    SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)
def SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1, scale_h=1, n=0, gen_start_scale=0, num_samples=50):
    """Draw samples from a trained SinGAN pyramid, coarse to fine.

    Args:
        Gs, Zs, NoiseAmp: Per-scale generators, reconstruction noise maps and
            noise amplitudes, coarsest first.
        reals: Per-scale real images; their sizes bound the crops.
        opt: Options (device, ker_size, num_layer, nc_z, scale_factor, mode,
            out, input_name, ...).
        in_s: Input to the coarsest scale; zeros shaped like ``reals[0]`` if None.
        scale_v, scale_h: Vertical / horizontal output scaling factors.
        n: Index of the first scale in ``Gs``. At scale 0 the noise is one
            channel expanded to 3.
        gen_start_scale: Scales below this use the fixed ``Z_opt`` instead of
            fresh noise.
        num_samples: Number of samples generated per scale.

    Returns:
        The last sample of the finest processed scale, detached. Unless
        ``opt.mode`` is harmonization/editing/SR/paint2image, each sample of
        that scale is also written as ``<i>.png``.
    """
    if in_s is None:
        # 0.0 keeps in_s floating point (integer fill gives int64 on PyTorch >= 1.7).
        in_s = torch.full(reals[0].shape, 0.0, device=opt.device)
    images_cur = []
    # Value of `n` at the last scale zip() below will visit.
    last_scale = n + min(len(Gs), len(Zs), len(NoiseAmp)) - 1
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h
        images_prev = images_cur
        images_cur = []
        for i in range(num_samples):
            if n == 0:
                z_curr = functions.generate_noise([1, nzx, nzy], device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                I_prev = m(in_s)
            else:
                # Upsample the previous scale's sample to this scale.
                I_prev = imresize(images_prev[i], 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]), 0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2], z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt

            z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            # File names depend only on i, so writing at every scale just
            # overwrote the same files; write once, at the last scale.
            if n == last_scale:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                os.makedirs(dir2save, exist_ok=True)
                if opt.mode not in ("harmonization", "editing", "SR", "paint2image"):
                    plt.imsave('%s/%d.png' % (dir2save, i), functions.convert_image_np(I_curr.detach()), vmin=0, vmax=1)
            images_cur.append(I_curr)
        n += 1
    return I_curr.detach()
def SinGAN_SR(opt, Gs, Zs, reals, NoiseAmp):
    """Super-resolve the training image with a single-scale SinGAN.

    Trains (mode ``'SR_train'``, one scale) or loads the SR model, then
    repeatedly applies the finest generator, each time enlarging by
    ``1 / opt.scale_factor``, for ``iter_num`` steps. The result is saved as
    ``<dir2save>.png``. ``opt.mode``, ``opt.scale_factor``,
    ``opt.scale_factor_init`` and ``opt.stop_scale`` are modified; ``opt.mode``
    is restored after training/loading.
    """
    mode = opt.mode
    in_scale, iter_num = functions.calc_init_scale(opt)
    opt.scale_factor = 1 / in_scale
    opt.scale_factor_init = 1 / in_scale
    opt.mode = 'SR_train'
    opt.stop_scale = 0
    dir2trained_model = functions.generate_dir2save(opt)
    if os.path.exists(dir2trained_model):
        Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
    else:
        SR_train(opt, Gs, Zs, reals, NoiseAmp)
    opt.mode = mode
    print('%f' % pow(in_scale, iter_num))

    # SinGAN_generate subtracts 2 * ((ker_size - 1) * num_layer) / 2 from each
    # noise map's size, so pad by the same amount (was hard-coded to 5, which
    # only matches the default ker_size=3, num_layer=5).
    pad = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m = nn.ZeroPad2d(pad)
    Zs_sr = []
    reals_sr = []
    NoiseAmp_sr = []
    Gs_sr = []
    real = reals[-1]
    for j in range(1, iter_num + 1):
        factor = pow(1 / opt.scale_factor, j)
        real_ = imresize(real, factor, opt)
        real_ = real_[:, :, 0:int(factor * real.shape[2]), 0:int(factor * real.shape[3])]
        reals_sr.append(real_)
        Gs_sr.append(Gs[-1])
        NoiseAmp_sr.append(NoiseAmp[-1])
        # Float zeros (an integer fill gives int64 on PyTorch >= 1.7).
        z_opt = torch.full(real_.shape, 0.0, device=opt.device)
        Zs_sr.append(m(z_opt))
    out = SinGAN_generate(Gs_sr, Zs_sr, reals_sr, NoiseAmp_sr, opt, in_s=reals_sr[0], num_samples=1)
    dir2save = functions.generate_dir2save(opt)
    plt.imsave('%s.png' % (dir2save), functions.convert_image_np(out.detach()), vmin=0, vmax=1)
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid one scale at a time, coarsest first.

    The input image is resized by ``opt.scale1`` and turned into a pyramid of
    real images. For every scale ``0..opt.stop_scale`` a new D/G pair is
    built. It is warm-started from the previous scale's saved weights when
    the channel width did not change, then trained and frozen. ``Gs``, ``Zs``
    and ``NoiseAmp`` are extended in place, and the whole pyramid is
    checkpointed to ``opt.out_`` after every scale.
    """
    full_res = functions.read_image(opt)
    in_s = 0  # replaced by train_single_scale's returned value each scale
    resized = imresize(full_res, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(resized, reals, opt)
    level = 0
    prev_width = 0
    while level < opt.stop_scale + 1:
        # Channel width doubles every 4 scales, capped at 128.
        growth = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/in_scale.png' % (opt.outf), functions.convert_image_np(reals[level]), vmin=0, vmax=1)

        netD, netG = init_models(opt)
        if prev_width == opt.nfc:
            prev_dir = '%s/%d' % (opt.out_, level - 1)
            netG.load_state_dict(torch.load('%s/netG.pth' % prev_dir))
            netD.load_state_dict(torch.load('%s/netD.pth' % prev_dir))

        z_curr, in_s, netG = train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt)

        # Freeze both networks of this scale.
        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        for payload, stem in ((Zs, 'Zs'), (Gs, 'Gs'), (reals, 'reals'), (NoiseAmp, 'NoiseAmp')):
            torch.save(payload, '%s/%s.pth' % (opt.out_, stem))

        level += 1
        prev_width = opt.nfc
        del netD, netG
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one SinGAN D/G pair per scale, from the coarsest level to the finest.

    Unlike variants that resize the input by ``opt.scale1`` first, this one
    builds the real-image pyramid straight from ``functions.read_image``.
    ``Gs``, ``Zs`` and ``NoiseAmp`` are extended in place. The pyramid is
    checkpointed to ``opt.out_`` after every scale, and the CUDA cache is
    emptied between scales.
    """
    source_img = functions.read_image(opt)
    in_s = 0  # replaced by the value train_single_scale returns for each scale
    reals = functions.creat_reals_pyramid(source_img, reals, opt)
    level = 0  # current pyramid level, 0 = coarsest
    last_width = 0
    # Levels 0..opt.stop_scale inclusive.
    while level < opt.stop_scale + 1:
        # nfc = output channels of the conv blocks; doubles every 4 levels, max 128.
        widen = pow(2, level // 4)
        opt.nfc = min(opt.nfc_init * widen, 128)
        opt.min_nfc = min(opt.min_nfc_init * widen, 128)

        # out_ = experiment directory; outf = per-level sub-directory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[level]), vmin=0, vmax=1)

        D_net, G_net = init_models(opt)
        # The architecture only matches the previous level when the width is
        # unchanged, so only then can its weights be reused.
        if last_width == opt.nfc:
            G_net.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            D_net.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))

        z_curr, in_s, G_net = train_single_scale(D_net, G_net, reals, Gs, Zs, in_s, NoiseAmp, opt)
        G_net = functions.reset_grads(G_net, False)
        G_net.eval()
        D_net = functions.reset_grads(D_net, False)
        D_net.eval()

        Gs.append(G_net)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        level += 1
        last_width = opt.nfc
        del D_net, G_net
        torch.cuda.empty_cache()
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """TensorFlow variant: train the SinGAN pyramid scale by scale.

    For each scale ``0..opt.stop_scale`` new D/G models are built. They are
    warm-started from the previous scale's saved weights when the channel
    width is unchanged, then trained with their own Adam optimizers.
    ``Gs``, ``Zs`` and ``NoiseAmp`` are extended in place. ``Zs``, ``reals``
    and ``NoiseAmp`` are pickled to ``opt.out_`` after every scale.

    Returns:
        None.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0
    while scale_num < opt.stop_scale + 1:
        # Channel width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            D_curr.load_weights('%s/%d/netD' % (opt.out_, scale_num - 1))
            G_curr.load_weights('%s/%d/netG' % (opt.out_, scale_num - 1))

        # Fresh optimizers per scale. Every scale trains new models (new
        # variables), and a Keras optimizer (TF >= 2.11) is built for the
        # variables it first sees, so one shared instance fails or carries
        # stale Adam moments into the next scale.
        netD_optimizer = tf.keras.optimizers.Adam(learning_rate=opt.lr_d, beta_1=opt.beta1, beta_2=0.999)
        netG_optimizer = tf.keras.optimizers.Adam(learning_rate=opt.lr_g, beta_1=opt.beta1, beta_2=0.999)

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt,
                                                  scale_num, netG_optimizer, netD_optimizer)
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        with open('%s/Zs.pkl' % (opt.out_), 'wb') as f:
            pickle.dump(Zs, f)
        with open('%s/reals.pkl' % (opt.out_), 'wb') as f:
            pickle.dump(reals, f)
        with open('%s/NoiseAmp.pkl' % (opt.out_), 'wb') as f:
            pickle.dump(NoiseAmp, f)
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return None
def save_gif(opt, images_cur, alpha, beta, start_scale=2, fps=10):
    """Write a list of same-scale frames as an animated GIF.

    The file is written to
    ``<dir2save>/start_scale=<start_scale>/alpha=<alpha>_beta=<beta>__.gif``,
    the same directory layout ``generate_gif`` uses.

    Args:
        opt: Options used to compute the output directory.
        images_cur: Time-ordered frames, all at the same scale.
        alpha, beta: Interpolation parameters; only used in the file name.
        start_scale: Only used in the directory name. Previously read from an
            undefined name with an invalid ``.2d`` format spec, so the
            function could not run at all.
        fps: Frames per second of the GIF.
    """
    dir2save = functions.generate_dir2save(opt)
    save_dir = os.path.join(f'{dir2save}', 'start_scale=%d' % start_scale)
    os.makedirs(save_dir, exist_ok=True)
    gif_path = os.path.join(save_dir, f'alpha={alpha:.2f}_beta={beta:.2f}__.gif')
    imageio.mimsave(gif_path, images_cur, fps=fps)
def main(opt):
    """Train SinGAN on ``opt``'s input image and write random samples.

    Does nothing beyond logging when the output directory already exists.
    """
    Gs, Zs, reals, NoiseAmp = [], [], [], []
    dir2save = functions.generate_dir2save(opt)
    # NOTE(review): indentation was lost in this file; this follows upstream
    # SinGAN, where training only runs when the output directory is new.
    if os.path.exists(dir2save):
        logger.info("Trained model directory already exists")
        return
    try:
        os.makedirs(dir2save)
    except OSError:
        pass
    image = functions.read_image(opt)
    functions.adjust_scales2image(image, opt)
    train(opt, Gs, Zs, reals, NoiseAmp)
    SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)
def main(opt, generate=True):
    """Train a two-image SinGAN and optionally write random samples.

    Sets up a logger in the output directory and dumps every option (as a
    string) to ``config.json``. It then trains on ``opt.input_name1`` /
    ``opt.input_name2`` and, when ``generate`` is true, samples from the
    result. Failures are logged and re-raised; the logger is always cleaned
    up.
    """
    Gs = []
    Zs: List[Tuple] = []
    reals1, reals2, NoiseAmp = [], [], []
    dir2save = functions.generate_dir2save(opt)
    # NOTE(review): indentation was lost in this file; this follows upstream
    # SinGAN, where training only runs when the output directory is new.
    if os.path.exists(dir2save):
        print('trained model already exist')
        return
    try:
        os.makedirs(dir2save)
    except OSError:
        pass

    _configure_logger(dir2save)
    config_path = os.path.join(f"{dir2save}", "config.json")
    with open(config_path, "w") as fp:
        json.dump({key: str(value) for key, value in opt.__dict__.items()}, fp)

    try:
        first = functions.read_image(opt, image_name=opt.input_name1)
        second = functions.read_image(opt, image_name=opt.input_name2)
        functions.adjust_scales2image(first, opt)
        functions.adjust_scales2image(second, opt)
        train(opt, Gs, Zs, reals1, reals2, NoiseAmp)
        logger.info("Done training")
        if generate:
            logger.info("Generating random samples")
            SinGAN_generate(Gs, Zs, reals1, reals2, NoiseAmp, opt)
    except Exception:
        logger.exception("Failed")
        raise
    finally:
        logger.info("Cleaning logger")
        _cleanup_logger()
def train(opt, Gs, Zs, reals1, reals2, NoiseAmp):
    """Train a SinGAN pyramid jointly on two images, one scale at a time.

    Each scale gets a generator plus three discriminators (``D``,
    ``D_mask1``, ``D_mask2``) from ``init_models``. When the channel width is
    unchanged, all four are warm-started from the previous scale's ``.pth``
    files. After training, the networks are frozen. The generator, the pair
    of noise maps and the pair ``(opt.noise_amp1, opt.noise_amp2)`` are
    appended to ``Gs``, ``Zs`` and ``NoiseAmp``, and everything is
    checkpointed to ``opt.out_``.
    """
    logger.info("Starting to train...")
    reals1 = get_reals(reals1, opt, opt.input_name1)
    reals2 = get_reals(reals2, opt, opt.input_name2)
    in_s1 = in_s2 = 0
    level = 0
    width_before = 0
    while level < opt.stop_scale + 1:
        # Channel width doubles every 4 scales, capped at 128.
        multiplier = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * multiplier, 128)
        opt.min_nfc = min(opt.min_nfc_init * multiplier, 128)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        for idx, pyramid in ((1, reals1), (2, reals2)):
            plt.imsave('%s/real_scale%d.png' % (opt.outf, idx), functions.convert_image_np(pyramid[level]), vmin=0, vmax=1)

        netD, netD_mask1, netD_mask2, netG = init_models(opt, reals1[len(Gs)].shape)
        if width_before == opt.nfc:
            prev_dir = '%s/%d' % (opt.out_, level - 1)
            netG.load_state_dict(torch.load('%s/netG.pth' % prev_dir))
            netD.load_state_dict(torch.load('%s/netD.pth' % prev_dir))
            netD_mask1.load_state_dict(torch.load('%s/netD_mask1.pth' % prev_dir))
            netD_mask2.load_state_dict(torch.load('%s/netD_mask2.pth' % prev_dir))

        logger.info(f"Starting to train scale {level}")
        z_pair, (in_s1, in_s2), netG = train_single_scale(netD, netD_mask1, netD_mask2, netG, reals1, reals2,
                                                          Gs, Zs, in_s1, in_s2, NoiseAmp, opt)
        logger.info(f"Done training scale {level}")

        # Freeze all four networks, in the same order as before.
        frozen = []
        for net in (netG, netD, netD_mask1, netD_mask2):
            net = functions.reset_grads(net, False)
            net.eval()
            frozen.append(net)
        netG, netD, netD_mask1, netD_mask2 = frozen

        Gs.append(netG)
        Zs.append(z_pair)
        NoiseAmp.append((opt.noise_amp1, opt.noise_amp2))
        for payload, stem in ((Zs, 'Zs'), (Gs, 'Gs'), (reals1, 'reals1'), (reals2, 'reals2'), (NoiseAmp, 'NoiseAmp')):
            torch.save(payload, '%s/%s.pth' % (opt.out_, stem))
        level += 1
        width_before = opt.nfc
        del netD, netD_mask1, netD_mask2, netG
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid, optionally building a mask pyramid for inpainting.

    With ``opt.inpainting`` set, the mask ``<stem>_mask<ext>`` is read from
    ``opt.ref_dir``, resized by ``opt.scale1`` and turned into a pyramid in
    the same way as the image, then stored on ``opt.m_s``. The original
    comment's intent is to train only on valid pixels at every scale; the
    masking itself happens outside this function. Otherwise this is the
    standard per-scale loop: ``Gs``, ``Zs`` and ``NoiseAmp`` are extended in
    place and checkpointed to ``opt.out_``.
    """
    image = functions.read_image(opt)
    in_s = 0
    level = 0
    reals = functions.creat_reals_pyramid(imresize(image, opt.scale1, opt), reals, opt)
    nfc_before = 0
    if opt.inpainting:
        mask_path = '%s/%s_mask%s' % (opt.ref_dir, opt.input_name[:-4], opt.input_name[-4:])
        mask = imresize(functions.read_image_dir(mask_path, opt), opt.scale1, opt)
        opt.m_s = functions.creat_reals_pyramid(mask, [], opt)

    while level < opt.stop_scale + 1:
        width_scale = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * width_scale, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_scale, 128)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[level]), vmin=0, vmax=1)

        discriminator, generator = init_models(opt)
        if nfc_before == opt.nfc:
            generator.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            discriminator.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))

        z_curr, in_s, generator = train_single_scale(discriminator, generator, reals, Gs, Zs, in_s, NoiseAmp, opt)
        generator = functions.reset_grads(generator, False)
        generator.eval()
        discriminator = functions.reset_grads(discriminator, False)
        discriminator.eval()

        Gs.append(generator)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        level += 1
        nfc_before = opt.nfc
        del discriminator, generator
    return
def SinGAN_generate(Gs, Zs, reals1, reals2, NoiseAmp, opt, in_s1=None, in_s2=None, scale_v=1, scale_h=1, n=0, gen_start_scale=0, num_samples=100):
    """Sample from a SinGAN pyramid trained on two images.

    Each sample draws one noise mode up front (``NoiseMode.Z1`` or
    ``NoiseMode.Z2``, probability 1/2 each) and keeps it at every scale. At
    the finest scale (``n == len(reals1) - 1``) samples are written as
    ``<i>_<mode>.png`` unless ``opt.mode`` is harmonization/editing/SR/paint2image.

    Args:
        Gs: Per-scale generators, coarsest first.
        Zs: Per-scale ``(Z_opt1, Z_opt2)`` reconstruction noise maps.
        NoiseAmp: Per-scale ``(noise_amp1, noise_amp2)``.
        reals1, reals2: Per-scale real images; ``reals1`` fixes crop sizes.
        opt: Options (device, ker_size, num_layer, scale_factor, mode, ...).
        in_s1, in_s2: Accepted for API compatibility but currently ignored;
            the coarsest scale starts from zeros shaped like ``reals1[0]``.
        scale_v, scale_h: Vertical / horizontal output scaling factors.
        n: Index of the first scale in ``Gs``.
        gen_start_scale: Scales below this use the merged ``Z_opt`` maps.
        num_samples: Number of samples per scale.

    Returns:
        The last generated sample, detached.
    """
    # 0.0 keeps the tensor floating point (integer fill gives int64 on PyTorch >= 1.7).
    in_s = torch.full(reals1[0].shape, 0.0, device=opt.device)
    images_cur = []

    def random_noise_mode():
        # Z1 or Z2 with equal probability.
        if torch.rand(1) < 0.5:
            return NoiseMode.Z1
        return NoiseMode.Z2

    noise_modes = [random_noise_mode() for _ in range(num_samples)]
    for G, (Z_opt1, Z_opt2), (noise_amp1, noise_amp2) in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        # Assumes Z_opt1 and Z_opt2 have the same size at each scale.
        nzx = (Z_opt1.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt1.shape[3] - pad1 * 2) * scale_h
        images_prev = images_cur
        images_cur = []
        for i in range(num_samples):
            z_curr = _generate_noise_for_sampling(m, n, nzx, nzy, opt, noise_modes[i])
            if not images_prev:
                I_prev = m(in_s)
            else:
                I_prev = imresize(images_prev[i], 1 / opt.scale_factor, opt)
                I_prev = I_prev[:, :, 0:round(scale_v * reals1[n].shape[2]), 0:round(scale_h * reals1[n].shape[3])]
                I_prev = m(I_prev)
                I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                I_prev = functions.upsampling(I_prev, z_curr.shape[2], z_curr.shape[3])

            if n < gen_start_scale:
                # zeros_like keeps Z_opt1's device and dtype; torch.zeros(shape)
                # was always on the CPU and broke merging with CUDA noise maps.
                zero = torch.zeros_like(Z_opt1)
                if noise_modes[i] == NoiseMode.Z1:
                    z_curr = functions.merge_noise_vectors(Z_opt1, zero, opt.noise_vectors_merge_method)
                elif noise_modes[i] == NoiseMode.Z2:
                    z_curr = functions.merge_noise_vectors(zero, Z_opt2, opt.noise_vectors_merge_method)
                else:
                    z_curr = functions.merge_noise_vectors(Z_opt1, Z_opt2, opt.noise_vectors_merge_method)

            noise_amp = noise_amp1 if noise_modes[i] == NoiseMode.Z1 else noise_amp2
            z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)[0]

            if n == len(reals1) - 1:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.exp_name, gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                try:
                    os.makedirs(dir2save)
                except OSError:
                    pass
                if opt.mode not in ("harmonization", "editing", "SR", "paint2image"):
                    print(f"Saving image: {i}")
                    plt.imsave(f'%s/%d_{noise_modes[i].name}.png' % (dir2save, i),
                               functions.convert_image_np(I_curr.detach()), vmin=0, vmax=1)
            images_cur.append(I_curr)
        print(f"Done Generating level: {n}")
        n += 1
    return I_curr.detach()
def _visualize_weights(modules, fig_name, outf):
    """Print weight sparsity of ``modules`` and save a histogram of non-zero weights to ``<outf>/<fig_name>.png``."""
    # Concatenate on the weights' own device (previously forced onto CUDA,
    # which crashed CPU-only runs).
    all_weights = torch.cat([mod.weight.data.flatten() for mod in modules])
    sparsity = torch.sum(all_weights == 0) * 1.0 / (all_weights.nelement())
    print(sparsity, all_weights.nelement())
    all_weights = all_weights.cpu().numpy()
    plt.hist(all_weights[all_weights != 0], bins=100)
    plt.savefig("%s/%s.png" % (outf, fig_name))
    plt.close()


def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid and prune each scale's generator.

    Per scale the generator is trained, pruned and optionally fine-tuned.
    Training runs ``warmup_steps`` iterations when fine-tuning, else
    ``opt.niter``. Pruning is global L1 unstructured over the conv/norm
    weights and biases, or L1-norm structured over conv output channels.
    Fine-tuning continues for ``opt.niter - warmup_steps`` iterations
    (unstructured) or ``opt.niter`` (structured). The pruning is then made
    permanent and the pyramid is checkpointed (generators as
    ``pruned_Gs.pth``).

    NOTE(review): ``fine_tune``, ``warmup_steps``, ``structured`` and
    ``pruning_amount`` are not defined here; presumably module-level
    settings elsewhere in the file — confirm.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    cur_scale_level = 0  # pyramid level, 0 = coarsest
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    nfc_prev = 0
    while cur_scale_level < opt.stop_scale + 1:
        # Conv block width doubles every 4 levels, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        # Weights can only be reused when the architecture (width) is unchanged.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, cur_scale_level - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, cur_scale_level - 1)))

        niter = warmup_steps if fine_tune else opt.niter
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, niter)
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        # D_curr is left trainable; it is reused if fine-tuning follows.

        if not structured:
            modules = [G_curr.head.conv, G_curr.head.norm,
                       G_curr.body.block1.conv, G_curr.body.block1.norm,
                       G_curr.body.block2.conv, G_curr.body.block2.norm,
                       G_curr.body.block3.conv, G_curr.body.block3.norm,
                       G_curr.tail[0]]
            # All weights first, then all biases, in module order.
            parameters_to_prune = (tuple((mod, 'weight') for mod in modules)
                                   + tuple((mod, 'bias') for mod in modules))
            _visualize_weights(modules, 'ori', opt.outf)
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=pruning_amount,
            )
        else:
            modules = [G_curr.head.conv, G_curr.body.block1.conv,
                       G_curr.body.block2.conv, G_curr.body.block3.conv]
            _visualize_weights(modules, 'ori', opt.outf)
            for module in modules:
                prune.ln_structured(module, name="weight", amount=pruning_amount, n=1, dim=0)
        torch.save(G_curr.state_dict(), '%s/raw_prune_netG.pth' % (opt.outf))
        _visualize_weights(modules, 'raw-prune', opt.outf)

        if cur_scale_level > 0:
            # Preview samples with the pruned generator appended to the pyramid.
            fake_Gs = Gs.copy()
            fake_Gs.append(G_curr)
            fake_Zs = Zs.copy()
            fake_Zs.append(z_curr)
            fake_noise = NoiseAmp.copy()
            fake_noise.append(opt.noise_amp)
            fake_reals = reals[:cur_scale_level + 1].copy()
            prune_SinGAN_generate(fake_Gs, fake_Zs, fake_reals, fake_noise, opt,
                                  gen_start_scale=0, num_samples=1, level=cur_scale_level)

        if fine_tune:
            G_curr = functions.reset_grads(G_curr, True)
            G_curr.train()
            if not structured:
                # Continue from the inherited (pruned) weights.
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt,
                                                          opt.niter - warmup_steps, prune=True)
            else:
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt,
                                                          opt.niter, prune=True)
            G_curr = functions.reset_grads(G_curr, False)
            G_curr.eval()
            _visualize_weights(modules, 'fine-tune', opt.outf)

        # Make the pruning permanent (drop the mask re-parametrization).
        for mod in modules:
            prune.remove(mod, 'weight')
            if not structured:
                prune.remove(mod, 'bias')

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        cur_scale_level += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid while recording memory use and wall time per scale.

    ``train_single_scale`` here also returns a memory figure and a
    percentage; both are collected per scale, along with the
    ``datetime.timedelta`` each scale's training call took. Both lists are
    printed at the end.

    NOTE(review): ``full_memory`` and ``full_time`` are not defined in this
    function; presumably module-level collections filled elsewhere — verify.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    level = 0
    reals = functions.creat_reals_pyramid(imresize(real_, opt.scale1, opt), reals, opt)
    prev_nfc = 0
    scale_memory = []     # [mbs, percent] per scale
    scale_durations = []  # elapsed timedelta per scale (avoids shadowing the `time` module)
    while level < opt.stop_scale + 1:
        doubling = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * doubling, 128)
        opt.min_nfc = min(opt.min_nfc_init * doubling, 128)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[level]), vmin=0, vmax=1)

        netD, netG = init_models(opt)
        if prev_nfc == opt.nfc:
            netG.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            netD.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))

        started = datetime.datetime.now()
        z_curr, in_s, netG, mbs, percent = train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt)
        scale_memory.append([mbs, percent])
        elapsed = datetime.datetime.now() - started
        scale_durations.append(elapsed)
        print(f'time: {elapsed}')

        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()
        Gs.append(netG)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        torch.save(full_memory, '%s/full_memory.pth' % (opt.out_))
        torch.save(full_time, '%s/full_time.pth' % (opt.out_))
        level += 1
        prev_nfc = opt.nfc
        del netD, netG
    print(scale_memory)
    print(scale_durations)
    print(full_memory)
    print(full_time)
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid scale by scale, with optional fast training.

    With ``opt.fast_training``, ``opt.niter`` is halved (permanently, on
    ``opt``) at every scale that is a positive multiple of 4, so finer scales
    train for fewer iterations. Otherwise this is the standard loop:
    ``Gs``, ``Zs`` and ``NoiseAmp`` are extended in place and the pyramid is
    checkpointed to ``opt.out_`` after every scale.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    stage = 0  # index into the pyramid, 0 = coarsest
    reals = functions.creat_reals_pyramid(imresize(real_, opt.scale1, opt), reals, opt)
    nfc_prev = 0
    while stage < opt.stop_scale + 1:
        # Double the filter count every 4 levels, at most 128.
        opt.nfc = min(opt.nfc_init * 2 ** (stage // 4), 128)
        opt.min_nfc = min(opt.min_nfc_init * 2 ** (stage // 4), 128)
        if opt.fast_training and stage > 0 and stage % 4 == 0:
            opt.niter = opt.niter // 2

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, stage)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[stage]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        # Same channel count as the previous scale: initialise from its weights.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, stage - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, stage - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        # Freeze: no gradients, eval mode. D_curr is the object passed to
        # train_single_scale (presumably trained in place — confirm).
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        # Zs collects the per-scale noise map returned by train_single_scale;
        # opt.noise_amp is presumably updated inside it — confirm.
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        stage += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid one scale at a time, coarsest first.

    The image from ``functions.read_image`` is resized by ``opt.scale1`` and
    expanded into a pyramid of real images via ``creat_reals_pyramid``. For
    each scale a D/G pair is trained, reusing the previous scale's weights
    when the width matches. ``Gs``, ``Zs`` (per-scale noise maps) and
    ``NoiseAmp`` are extended in place and saved to ``opt.out_``.
    """
    # Presumably a normalized image tensor (original comment: values in
    # [-1, 1]) — confirm in functions.read_image.
    loaded = functions.read_image(opt)
    in_s = 0
    idx = 0
    scaled = imresize(loaded, opt.scale1, opt)  # resize by opt.scale1 before building the pyramid
    reals = functions.creat_reals_pyramid(scaled, reals, opt)
    previous_nfc = 0
    while idx < opt.stop_scale + 1:
        factor = 2 ** (idx // 4)
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, idx)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[idx]), vmin=0, vmax=1)

        d_model, g_model = init_models(opt)
        if previous_nfc == opt.nfc:
            checkpoint_dir = '%s/%d' % (opt.out_, idx - 1)
            g_model.load_state_dict(torch.load('%s/netG.pth' % checkpoint_dir))
            d_model.load_state_dict(torch.load('%s/netD.pth' % checkpoint_dir))

        z_curr, in_s, g_model = train_single_scale(d_model, g_model, reals, Gs, Zs, in_s, NoiseAmp, opt)
        g_model = functions.reset_grads(g_model, False)
        g_model.eval()
        d_model = functions.reset_grads(d_model, False)
        d_model.eval()

        Gs.append(g_model)
        Zs.append(z_curr)  # this scale's noise map
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        idx += 1
        previous_nfc = opt.nfc
        del d_model, g_model
    return
def generate_gif(Gs, Zs, reals, NoiseAmp, opt, alpha=0.1, beta=0.9, start_scale=2, fps=10, num_frames=100):
    """Render an animated GIF by random-walking the noise maps of a trained SinGAN.

    At every scale a sequence of ``num_frames`` noise maps is produced by a
    random walk:
    ``z_curr = alpha * Z_opt + (1 - alpha) * (z_prev1 + diff)``, where
    ``diff = beta * (z_prev1 - z_prev2) + (1 - beta) * fresh_noise``.
    So ``alpha`` pulls each step back towards the stored ``Z_opt``, and
    ``beta`` weights the previous step against new noise.

    For scales ``count < start_scale`` the walk still advances, but the stored
    ``Z_opt`` is fed to the generator instead. Frame ``i`` at each scale is
    conditioned on frame ``i`` of the previous scale, upsampled by
    ``1 / opt.scale_factor``. Frames of the last scale are converted to
    ``uint8`` HxWxC arrays and written to
    ``<dir2save>/start_scale=<start_scale>/alpha=<alpha>_beta=<beta>.gif``.

    Args:
        Gs, Zs, reals, NoiseAmp: Per-scale generators, noise maps, real
            images and noise amplitudes (zipped together).
        opt: Option namespace (device, ker_size, num_layer, nc_z,
            scale_factor, output paths).
        alpha, beta: Random-walk coefficients (see above).
        start_scale: First scale at which the random walk is actually used.
        fps: Frames per second of the written GIF.
        num_frames: Number of frames per scale. Defaults to 100, the value
            that used to be hard-coded.

    Raises:
        ValueError: if ``num_frames`` is smaller than 1.
    """
    if num_frames < 1:
        raise ValueError('num_frames must be at least 1, got %r' % (num_frames,))

    def fresh_noise(first_scale, nzx, nzy):
        # Coarsest scale: single-channel noise broadcast to 3 channels.
        # Finer scales: noise with opt.nc_z channels.
        if first_scale:
            return functions.generate_noise([1, nzx, nzy], device=opt.device).expand(1, 3, nzx, nzy)
        return functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)

    in_s = torch.full(Zs[0].shape, 0, device=opt.device)
    # The padding depends only on opt, so build the padding layer once.
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_image = nn.ZeroPad2d(pad_image)

    images_cur = []
    for count, (G, Z_opt, noise_amp, real) in enumerate(zip(Gs, Zs, NoiseAmp, reals)):
        first_scale = count == 0
        nzx, nzy = Z_opt.shape[2], Z_opt.shape[3]
        images_prev, images_cur = images_cur, []

        z_prev1 = 0.95 * Z_opt + 0.05 * fresh_noise(first_scale, nzx, nzy)
        z_prev2 = Z_opt
        for i in range(num_frames):
            diff_curr = beta * (z_prev1 - z_prev2) + (1 - beta) * fresh_noise(first_scale, nzx, nzy)
            z_curr = alpha * Z_opt + (1 - alpha) * (z_prev1 + diff_curr)
            z_prev2, z_prev1 = z_prev1, z_curr

            if not images_prev:
                I_prev = in_s
            else:
                # Frame i of the previous scale, upsampled and cropped to this scale.
                I_prev = imresize(images_prev[i], 1 / opt.scale_factor, opt)
                I_prev = I_prev[:, :, 0:real.shape[2], 0:real.shape[3]]
                I_prev = m_image(I_prev)

            if count < start_scale:
                z_curr = Z_opt

            z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if count == len(Gs) - 1:
                # Last scale: convert to a uint8 HxWxC frame for the GIF writer.
                I_curr = functions.denorm(I_curr).detach()
                I_curr = I_curr[0, :, :, :].cpu().numpy()
                I_curr = (I_curr.transpose(1, 2, 0) * 255).astype(np.uint8)
            images_cur.append(I_curr)

    dir2save = functions.generate_dir2save(opt)
    try:
        os.makedirs('%s/start_scale=%d' % (dir2save, start_scale))
    except OSError:
        pass  # best effort: the folder may already exist
    imageio.mimsave('%s/start_scale=%d/alpha=%f_beta=%f.gif' % (dir2save, start_scale, alpha, beta),
                    images_cur, fps=fps)
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train SinGAN scale by scale, optionally resuming from ``opt.scale_num``.

    Resume mode (marked EXPERIMENTAL by the author) is used when ``opt``
    already has ``scale_num > 0``. In that case the pyramid passed in
    ``reals`` is used as-is, and ``in_s`` starts as a zero tensor shaped like
    ``reals[0]``. Otherwise ``opt.scale_num`` is reset to 0 and a fresh
    pyramid is built from the training image.

    The scale counter and the previous channel width are stored on ``opt``
    (``opt.scale_num`` / ``opt.nfc_prev``) instead of in locals. Presumably
    this is so that a later call can pick up where this one stopped.

    For each scale:
    - models are built;
    - they are warm-started from the previous scale's checkpoint when the
      width is unchanged;
    - they are trained, frozen and appended to ``Gs``/``Zs``/``NoiseAmp``;
    - all lists are saved under ``opt.out_``.

    Args:
        opt: Option namespace. ``scale_num``, ``nfc_prev``, ``nfc``,
            ``min_nfc``, ``out_`` and ``outf`` are written on it.
        Gs, Zs, NoiseAmp: Lists extended in place, one entry per scale.
        reals: Existing pyramid (resume mode) or list to build it into.
    """
    print('train() current parameters')
    print(opt)
    source_img = functions.read_image(opt)

    resuming = 'scale_num' in opt and opt.scale_num > 0
    if resuming:
        # EXPERIMENTAL 'continue' mode: reuse the pyramid that was passed in.
        carry = torch.full(reals[0].shape, 0, device=opt.device)
    else:
        carry = 0
        opt.scale_num = 0
        resized = imresize(source_img, opt.scale1, opt)
        reals = functions.create_reals_pyramid(resized, reals, opt)
    if 'nfc_prev' not in opt:
        opt.nfc_prev = 0

    while opt.scale_num < opt.stop_scale + 1:
        level = opt.scale_num
        # Channel widths double every 4 scales and are capped at 128.
        growth = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            print('directory %s already exists' % opt.outf)
        plt.imsave('%s/real_scale.png' % opt.outf,
                   functions.convert_image_np(reals[level]),
                   vmin=0, vmax=1)

        netD, netG = init_models(opt)
        if opt.nfc_prev == opt.nfc:
            # Same width as the previous scale: start from its trained weights.
            prev_dir = '%s/%d' % (opt.out_, level - 1)
            netG.load_state_dict(torch.load('%s/netG.pth' % prev_dir))
            netD.load_state_dict(torch.load('%s/netD.pth' % prev_dir))

        noise_map, carry, netG = train_single_scale(
            netD, netG, reals, Gs, Zs, carry, NoiseAmp, opt)

        # Freeze the trained pair and switch it to inference mode.
        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(noise_map)
        NoiseAmp.append(opt.noise_amp)
        for stem, payload in (('Zs', Zs), ('Gs', Gs), ('reals', reals), ('NoiseAmp', NoiseAmp)):
            torch.save(payload, '%s/%s.pth' % (opt.out_, stem))

        opt.scale_num += 1
        opt.nfc_prev = opt.nfc
        del netD, netG
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid coarse-to-fine, printing widths and models.

    The training image is read, resized by ``opt.scale1`` and turned into a
    per-scale list with ``functions.creat_reals_pyramid``. The loop trains
    scales ``0..opt.stop_scale``, i.e. ``opt.stop_scale + 1`` scales in total.

    Per scale, the function:
    - sets ``opt.nfc``/``opt.min_nfc`` (doubling every 4 scales, capped at 128);
    - halves ``opt.niter`` at scales 4, 8, ... when ``opt.fast_training`` is set;
    - saves ``original.png`` (the unresized image) into ``opt.out_`` and
      ``real_scale.png`` into the per-scale folder ``opt.outf``;
    - builds D/G, warm-starting from the previous scale when the width is
      unchanged;
    - trains the pair, freezes it and checkpoints ``Gs``/``Zs``/``reals``/
      ``NoiseAmp``.

    Args:
        opt: Option namespace. ``nfc``, ``min_nfc``, ``out_``, ``outf`` (and
            ``niter`` under fast training) are written on it.
        Gs, Zs, NoiseAmp: Lists extended in place, one entry per scale.
        reals: List handed to ``creat_reals_pyramid``. Its return value is the
            pyramid that is trained on and saved.
    """
    source_img = functions.read_image(opt)
    carry = 0  # in_s, replaced by train_single_scale's return value
    level = 0
    resized = imresize(source_img, opt.scale1, opt)
    # The training image at every scale, indexed by scale number.
    pyramid = functions.creat_reals_pyramid(resized, reals, opt)
    prev_width = 0

    while level < opt.stop_scale + 1:
        growth = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * growth, 128)
        print('opt.nfc', opt.nfc)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)
        print('opt.min_nfc', opt.min_nfc)
        if opt.fast_training and level > 0 and level % 4 == 0:
            # Fast training: halve the iteration budget every fourth scale.
            opt.niter = opt.niter // 2

        opt.out_ = functions.generate_dir2save(opt)  # run-level output root
        opt.outf = '%s/%d' % (opt.out_, level)  # per-scale output folder
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        # The unresized input is re-saved on every scale; the resized copy goes per scale.
        plt.imsave('%s/original.png' % opt.out_,
                   functions.convert_image_np(source_img), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % opt.outf,
                   functions.convert_image_np(pyramid[level]), vmin=0, vmax=1)

        netD, netG = init_models(opt)  # returns (netD, netG) for this scale
        if prev_width == opt.nfc:
            prev_dir = '%s/%d' % (opt.out_, level - 1)
            netG.load_state_dict(torch.load('%s/netG.pth' % prev_dir))
            netD.load_state_dict(torch.load('%s/netD.pth' % prev_dir))

        # train_single_scale returns (z_opt, in_s, netG).
        noise_map, carry, netG = train_single_scale(
            netD, netG, pyramid, Gs, Zs, carry, NoiseAmp, opt)

        netG = functions.reset_grads(netG, False)
        print('G_curr', netG)
        netG.eval()
        print(netG.eval())
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(noise_map)
        NoiseAmp.append(opt.noise_amp)
        for stem, payload in (('Zs', Zs), ('Gs', Gs), ('reals', pyramid), ('NoiseAmp', NoiseAmp)):
            torch.save(payload, '%s/%s.pth' % (opt.out_, stem))

        level += 1
        prev_width = opt.nfc
        del netD, netG
    return
# for random_samples_arbitrary_sizes: parser.add_argument('--scale_h', type=float, help='horizontal resize factor for random samples', default=1.5) parser.add_argument('--scale_v', type=float, help='vertical resize factor for random samples', default=1) opt = parser.parse_args() opt = functions.post_config(opt) Gs = [] Zs = [] reals = [] NoiseAmp = [] dir2save = functions.generate_dir2save(opt) #내가 만들어야할 폴더 이름을 반환해줌 if dir2save is None: #opt.mode가 잘못된 경우 print('task does not exist') elif (os.path.exists(dir2save)): # 이미 폴더가 있는 경우 if opt.mode == 'random_samples': print( 'random samples for image %s, start scale=%d, already exist' % (opt.input_name, opt.gen_start_scale)) elif opt.mode == 'random_samples_arbitrary_sizes': print( 'random samples for image %s at size: scale_h=%f, scale_v=%f, already exist' % (opt.input_name, opt.scale_h, opt.scale_v)) else: try: os.makedirs(dir2save) # 폴더를 만들어줌 except OSError:
def SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1, scale_h=1, n=0,
                    gen_start_scale=0, num_samples=50,
                    reference_path=r"D:\MVA\CompVision\Project\SinGAN-master\Input\Images/Noisy_Golden_Bridge_by_night.jpg"):
    """Run a trained SinGAN pyramid with a fixed reference image injected at coarse scales.

    This is an experimental (denoising) variant of the stock sampler:

    * For scales ``n < gen_start_scale``, the noise is replaced by the stored
      ``Z_opt``. The previous-scale image ``I_prev`` is replaced by the image
      at ``reference_path``, normalised to [-1, 1], zero-padded and upsampled
      to the noise size.
    * The noise term is disabled at every scale (``z_in = 0 * z_curr + I_prev``),
      so each generator only refines ``I_prev``.

    At the last scale, each sample ``i`` is written to ``<dir2save>/<i>.png``,
    unless ``opt.mode`` is harmonization, editing, SR or paint2image.

    Args:
        Gs, Zs, reals, NoiseAmp: Per-scale generators, noise maps, real images
            and noise amplitudes.
        opt: Option namespace.
        in_s: Coarsest input. Defaults to zeros shaped like ``reals[0]``.
        scale_v, scale_h: Vertical/horizontal size multipliers for the noise.
        n: Index of the first scale being generated.
        gen_start_scale: Scales below this use ``Z_opt`` and the reference image.
        num_samples: Number of samples generated per scale.
        reference_path: Image injected at the coarse scales. Defaults to the
            path that used to be hard-coded.

    Returns:
        The last generated image, detached.
    """
    if in_s is None:
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    no_save_modes = ("harmonization", "editing", "SR", "paint2image")
    # Normalised reference image. Loaded once, on first use, instead of
    # re-reading the file for every sample at every coarse scale.
    reference = None
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h
        images_prev = images_cur
        images_cur = []
        for i in range(num_samples):
            if n == 0:
                # Coarsest scale: one noise channel broadcast to 3 channels.
                z_curr = functions.generate_noise([1, nzx, nzy], device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                print("in_s shape before padding with m", in_s.shape)
                I_prev = m(in_s)
                print("in_s shape after padding with m now I_prev", I_prev.shape)
                I_prev = functions.upsampling(I_prev, z_curr.shape[2], z_curr.shape[3])
                print("I_prev shape after upsampling using noise shape", I_prev.shape)
            else:
                I_prev = imresize(images_prev[i], 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]), 0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2], z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt
                if reference is None:
                    # (H, W, C) -> (1, C, H, W), scaled by 1/255 and mapped to [-1, 1].
                    # NOTE(review): the /255 assumes imread returns 0-255 values
                    # (JPEG); matplotlib loads PNGs as 0-1 floats -- confirm for
                    # other reference images.
                    reference = img.imread(reference_path)
                    reference = reference[:, :, :, None].transpose((3, 2, 0, 1)) / 255
                    reference = torch.from_numpy(reference)
                    reference = move_to_gpu(reference)
                    reference = reference.type(torch.cuda.FloatTensor)
                    reference = ((reference - 0.5) * 2).clamp(-1, 1)
                    reference = reference[:, 0:3, :, :]
                I_prev = functions.upsampling(m(reference), z_curr.shape[2], z_curr.shape[3])
                print("I_prev", I_prev.shape)
                print("z_curr", z_curr.shape)
                print('---')

            # Noise deliberately zeroed for this experiment
            # (the stock sampler uses noise_amp * z_curr + I_prev).
            z_in = 0 * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if n == len(reals) - 1:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                try:
                    os.makedirs(dir2save)
                except OSError:
                    pass  # best effort: the folder may already exist
                if opt.mode not in no_save_modes:
                    plt.imsave('%s/%d.png' % (dir2save, i), functions.convert_image_np(I_curr.detach()), vmin=0, vmax=1)
            images_cur.append(I_curr)
        n += 1
    return I_curr.detach()
'--scale_h', "1", '--gen_start_scale', '1', '--scale_factor', '0.75'
])
# NOTE(review): this fragment begins mid-call; the opening of the parse_args([...])
# above is not part of this copy.
opt = functions.post_config(opt)
#%%
# Interactive cell: parse a fixed argument list instead of sys.argv.
opt = parser.parse_args([
    '--mode', "random_samples", "--input_name", "mountains.jpg",
    '--scale_h', "1", '--gen_start_scale', '1', '--scale_factor', '0.75'
])
opt = functions.post_config(opt)
dirlab = ",sf_0.75"  # suffix appended to the output folder name
# Sweep gen_start_scale over 0..8; the loop target writes straight into opt.
for opt.gen_start_scale in range(0, 9):
    Gs = []
    Zs = []
    reals = []
    NoiseAmp = []
    dir2orig = functions.generate_dir2save(opt)
    dir2save = dir2orig + dirlab
    # NOTE(review): if generate_dir2save returned None, the concatenation above
    # would already raise TypeError, so this None check can never fire --
    # presumably dir2orig should be checked instead.
    if dir2save is None:
        print('task does not exist')
    elif (os.path.exists(dir2save)):
        if opt.mode == 'random_samples':
            print(
                'random samples for image %s, start scale=%d, already exist' %
                (opt.input_name, opt.gen_start_scale))
        elif opt.mode == 'random_samples_arbitrary_sizes':
            print(
                'random samples for image %s at size: scale_h=%f, scale_v=%f, already exist' %
                (opt.input_name, opt.scale_h, opt.scale_v))
    else:
        try:
            # NOTE(review): creates dir2orig, not the suffixed dir2save tested above -- confirm intended.
            os.makedirs(dir2orig)
# NOTE(review): the fragment is cut off here; the except clause is not in this copy.
default='Input/Images')
# NOTE(review): this fragment begins mid-call (tail of an add_argument for the input folder).
parser.add_argument('--input_name', help='training image name', default="33039_LR.png")  #required=True)
parser.add_argument('--sr_factor', help='super resolution factor', type=float, default=4)
parser.add_argument('--mode', help='task to be done', default='SR')
opt = parser.parse_args()
opt = functions.post_config(opt)
# Per-scale containers expected by the SinGAN helpers.
Gs = []
Zs = []
reals = []
NoiseAmp = []
dir2save = functions.generate_dir2save(opt)
if dir2save is None:
    print('task does not exist')
# The "output already exists" check is disabled, so an existing folder is reused.
#elif (os.path.exists(dir2save)):
#    print("output already exist")
else:
    try:
        os.makedirs(dir2save)
    except OSError:
        pass  # best effort: the folder may already exist
    # Remember the requested mode, then switch opt into training mode below.
    mode = opt.mode
    # NOTE(review): calc_init_scale presumably derives the per-scale factor from
    # opt.sr_factor -- confirm against its definition.
    in_scale, iter_num = functions.calc_init_scale(opt)
    opt.scale_factor = 1 / in_scale
    opt.scale_factor_init = 1 / in_scale
    opt.mode = 'train'
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid scale by scale, with a notebook progress bar.

    The training image named by ``opt`` is read and resized by
    ``opt.scale1``. The original author notes this value is computed earlier
    by the scale-adjustment step. The resized image is expanded into a
    per-scale list with ``functions.creat_reals_pyramid``. The scale loop is
    wrapped in ``tqdm_notebook``, labelled with ``opt.input_name``.

    For every scale:
    - the channel widths are set (doubling every 4 scales, capped at 128);
    - a fresh D/G pair is built and, if its width equals the previous scale's,
      warm-started from that scale's ``netG.pth``/``netD.pth``;
    - the pair is trained with ``train_single_scale`` and frozen;
    - it is appended to ``Gs``/``Zs``/``NoiseAmp`` and all lists are
      checkpointed under ``opt.out_``.

    Args:
        opt: Option namespace. ``nfc``, ``min_nfc``, ``out_`` and ``outf`` are
            written on it.
        Gs, Zs, NoiseAmp: Lists extended in place, one entry per scale.
        reals: List handed to ``creat_reals_pyramid``. Its return value is the
            pyramid that is trained on and saved.
    """
    source_img = functions.read_image(opt)
    resized = imresize(source_img, opt.scale1, opt)
    pyramid = functions.creat_reals_pyramid(resized, reals, opt)  # one image per scale

    carry = 0  # in_s, replaced by train_single_scale's return value
    prev_width = 0
    progress = tqdm_notebook(range(opt.stop_scale + 1), desc=opt.input_name, leave=True)
    for level in progress:
        growth = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * growth, 128)  # channel count at this scale
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)  # minimum channel count

        opt.out_ = functions.generate_dir2save(opt)  # run-level output directory
        opt.outf = f'{opt.out_}/{level}'  # per-scale sub-directory
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        # Save the resized training image for this scale.
        plt.imsave(f'{opt.outf}/real_scale.png',
                   functions.convert_image_np(pyramid[level]),
                   vmin=0, vmax=1)

        netD, netG = init_models(opt)
        if prev_width == opt.nfc:
            # Same width as the previous scale: load its weights directly.
            prev_dir = f'{opt.out_}/{level - 1}'
            netG.load_state_dict(torch.load(prev_dir + '/netG.pth'))
            netD.load_state_dict(torch.load(prev_dir + '/netD.pth'))

        noise_map, carry, netG = train_single_scale(
            netD, netG, pyramid, Gs, Zs, carry, NoiseAmp, opt)

        # Make both networks untrainable and switch them to eval mode.
        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(noise_map)
        NoiseAmp.append(opt.noise_amp)
        for stem, payload in (('Zs', Zs), ('Gs', Gs), ('reals', pyramid), ('NoiseAmp', NoiseAmp)):
            torch.save(payload, f'{opt.out_}/{stem}.pth')

        prev_width = opt.nfc
        del netD, netG  # drop the references before the next scale
    return
import SinGAN.functions as functions
if __name__ == '__main__':
    # Shared SinGAN options plus the ones specific to this training script.
    parser = get_arguments()
    parser.add_argument('--input_dir', help='input image dir', default='Input/Images')
    parser.add_argument('--model_name', help='input image name -1', required=True)
    parser.add_argument('--mode', help='task to be done', default='train')
    opt = functions.post_config(parser.parse_args())

    # Per-scale containers filled in by train().
    Gs, Zs, reals, NoiseAmp = [], [], [], []

    dir2save = functions.generate_dir2save(opt)
    if not os.path.exists(dir2save):
        try:
            os.makedirs(dir2save)
        except OSError:
            pass
        real = functions.read_images(opt)
        functions.adjust_scales2image(real, opt)
        train(opt, Gs, Zs, reals, NoiseAmp)
        SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)
    else:
        # Refuse to overwrite an existing run.
        print('trained model already exist')
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid coarse-to-fine, printing shapes and paths as it goes.

    The training image is read, resized by ``opt.scale1`` and expanded into a
    per-scale list with ``functions.creat_reals_pyramid``. The shape of every
    level is printed. Scales ``0..opt.stop_scale`` are then trained in order,
    which is ``opt.stop_scale + 1`` scales in total.

    Per scale, the function:
    - sets channel widths (doubling every 4 scales, capped at 128);
    - with ``opt.fast_training``, halves ``opt.niter`` at scales 4, 8, ...;
    - builds D/G, warm-starting from the previous scale when the width is
      unchanged;
    - trains, freezes and appends the generator, noise map and amplitude, and
      checkpoints all lists under ``opt.out_``.

    Args:
        opt: Option namespace. ``nfc``, ``min_nfc``, ``out_``, ``outf`` (and
            ``niter`` under fast training) are written on it.
        Gs, Zs, NoiseAmp: Lists extended in place, one entry per scale.
        reals: List handed to ``creat_reals_pyramid``. Its return value is the
            pyramid that is trained on and saved.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    print("real.shape:", real.shape)  # author's example: torch.Size([1, 3, 186, 248])
    reals = functions.creat_reals_pyramid(real, reals, opt)
    # Print every pyramid level. Author's example run, one line per scale:
    #   [1, 3, 20, 27]  [1, 3, 27, 36]  [1, 3, 35, 47]  [1, 3, 47, 62]  [1, 3, 61, 82]
    #   [1, 3, 81, 108] [1, 3, 107, 143] [1, 3, 141, 188] [1, 3, 186, 248]
    for level_img in reals:
        print(level_img.shape)
    nfc_prev = 0
    # The loop below trains scales 0..stop_scale, i.e. stop_scale + 1 of them.
    print("total %d scales.." % (opt.stop_scale + 1))
    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            # Fast training: halve the iteration budget every fourth scale.
            opt.niter = opt.niter // 2
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        print("out dir:", opt.outf)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % opt.outf,
                   functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
        # Build D and G. Author's note: they are fully convolutional, so the
        # input size does not need to be configured.
        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            # Same channel width as the previous scale: fine-tune from its weights.
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
        # Train this scale. A fresh D/G pair is built for every scale; per the
        # author this keeps GPU memory low (under ~3 GB in their runs).
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()
        # Keep the trained generator, its noise map and its noise amplitude.
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % opt.out_)
        torch.save(Gs, '%s/Gs.pth' % opt.out_)
        torch.save(reals, '%s/reals.pth' % opt.out_)
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % opt.out_)
        scale_num += 1
        nfc_prev = opt.nfc
        # Drop the local references to D and G (the author's aim: reduce GPU memory).
        del D_curr, G_curr
    return
def train(opt, Gs, Zs, reals, crops, masks, NoiseAmp):
    """Train the SinGAN pyramid with per-scale masks, crop sizes and an eye color.

    Besides the image pyramid, this variant builds:
    - a mask pyramid from ``functions.read_mask``;
    - a "crop" pyramid from an all-zeros ``(1, 1, crop_size, crop_size)``
      tensor, which is only used to get the crop size at each scale.

    ``functions.get_eye_color`` is evaluated once on the resized image. The
    result is stored on ``opt.eye_color`` and passed to every
    ``train_single_scale`` call. The per-scale loop otherwise matches the
    standard trainer: widths double every 4 scales (capped at 128), and a
    warm start from the previous scale is used when the width is unchanged.
    After each scale the models are frozen and ``Gs``/``Zs``/``reals``/
    ``NoiseAmp`` are checkpointed under ``opt.out_``.

    Args:
        opt: Option namespace. ``eye_color``, ``nfc``, ``min_nfc``, ``out_``
            and ``outf`` are written on it.
        Gs, Zs, NoiseAmp: Lists extended in place, one entry per scale.
        reals, crops, masks: Lists handed to ``functions.create_pyramid``. The
            returned pyramids are the ones used for training.
    """
    source_img = functions.read_image(opt)
    resized = imresize(source_img, opt.scale1, opt)
    mask_img = functions.read_mask(opt)
    # All-zeros placeholder: only its spatial size matters when downsizing.
    crop_ref = torch.zeros((1, 1, opt.crop_size, opt.crop_size))
    eye_color = functions.get_eye_color(resized)
    opt.eye_color = eye_color

    carry = 0  # in_s, replaced by train_single_scale's return value
    pyramid = functions.create_pyramid(resized, reals, opt)
    mask_pyramid = functions.create_pyramid(mask_img, masks, opt, mode="mask")
    # Shortcut: the crop size at each scale comes from downsizing crop_ref.
    crop_pyramid = functions.create_pyramid(crop_ref, crops, opt, mode="mask")

    prev_width = 0
    level = 0
    while level < opt.stop_scale + 1:
        growth = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % opt.outf,
                   functions.convert_image_np(pyramid[level]),
                   vmin=0, vmax=1)

        netD, netG = init_models(opt)
        if prev_width == opt.nfc:
            prev_dir = '%s/%d' % (opt.out_, level - 1)
            netG.load_state_dict(torch.load('%s/netG.pth' % prev_dir))
            netD.load_state_dict(torch.load('%s/netD.pth' % prev_dir))

        noise_map, carry, netG = train_single_scale(
            netD, netG, pyramid, crop_pyramid, mask_pyramid, eye_color,
            Gs, Zs, carry, NoiseAmp, opt)

        # Freeze the trained pair and switch it to inference mode.
        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(noise_map)
        NoiseAmp.append(opt.noise_amp)
        for stem, payload in (('Zs', Zs), ('Gs', Gs), ('reals', pyramid), ('NoiseAmp', NoiseAmp)):
            torch.save(payload, '%s/%s.pth' % (opt.out_, stem))

        level += 1
        prev_width = opt.nfc
        del netD, netG
    return
def SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1, scale_h=1, n=0,
                    gen_start_scale=0, num_samples=50, output_image=False):
    """Sample images from a trained SinGAN pyramid, coarse to fine.

    At each scale, ``num_samples`` samples are produced. Sample ``i`` is
    conditioned on sample ``i`` of the previous scale, which is upsampled by
    ``1 / opt.scale_factor``, cropped, zero-padded and resized to the noise
    size. The generator input is ``noise_amp * z + I_prev``. For scales
    ``n < gen_start_scale``, the stored per-scale noise map ``Z_opt`` replaces
    the random noise, as in the stock SinGAN sampler.

    Args:
        Gs, Zs, reals, NoiseAmp: Per-scale generators, noise maps, real images
            and noise amplitudes.
        opt: Option namespace.
        in_s: Coarsest input. Defaults to zeros shaped like ``reals[0]``.
        scale_v, scale_h: Vertical/horizontal size multipliers for the noise.
        n: Index of the first scale being generated.
        gen_start_scale: Scales below this reuse ``Z_opt`` instead of new noise.
        num_samples: Number of samples per scale.
        output_image: If true, each last-scale sample is also written to
            ``<generate_dir2save(opt)>/<i>.png``.

    Returns:
        Tuple ``(last_image, output_list)``. ``last_image`` is the final
        generated tensor (detached). ``output_list`` holds the
        ``functions.convert_image_np`` result of every last-scale sample.
    """
    if in_s is None:
        # Coarsest input: zeros shaped like the first real image.
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        # Border width ((ker_size - 1) * num_layer) / 2; presumably compensates
        # for what the generator's convolution stack trims.
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        # Unpadded noise size at this scale, stretched by scale_v / scale_h.
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h
        images_prev = images_cur
        images_cur = []
        output_list = []
        for i in range(num_samples):
            if n == 0:
                # Coarsest scale: one noise channel broadcast to 3 channels.
                z_curr = functions.generate_noise([1, nzx, nzy], device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                # First generated scale: start from the padded in_s.
                I_prev = m(in_s)
            else:
                # Previous-scale sample, resized by 1/scale_factor.
                I_prev = imresize(images_prev[i], 1 / opt.scale_factor, opt)
                # Crop to the (scaled) size of this scale's real image.
                I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]), 0:round(scale_h * reals[n].shape[3])]
                I_prev = m(I_prev)
                # Crop to the noise size, then resize (bilinear) to exactly match it.
                I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                I_prev = functions.upsampling(I_prev, z_curr.shape[2], z_curr.shape[3])

            if n < gen_start_scale:
                # Below gen_start_scale, reuse the stored noise map for this scale.
                z_curr = Z_opt

            # Scale the noise by this scale's amplitude and add the previous image.
            z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if n == len(reals) - 1:
                dir2save = functions.generate_dir2save(opt)
                try:
                    os.makedirs(dir2save)
                except OSError:
                    pass  # best effort: the folder may already exist
                if output_image:
                    plt.imsave(f'{dir2save}/{i}.png', functions.convert_image_np(I_curr.detach()), vmin=0, vmax=1)
                # Collect every last-scale sample for the caller.
                output_list.append(functions.convert_image_np(I_curr.detach()))
            images_cur.append(I_curr)
        n += 1
    return I_curr.detach(), output_list