def generate(model_name, anchor_image=None, direction=None, transfer=None,
             noise_solutions=None, factor=0.25, base=None, insert_limit=4,
             scale_factor=None):
    """Load the trained pyramid for *model_name* and run anchored generation.

    An options namespace is built from ``get_arguments()`` plus ``--input_dir``,
    ``--mode`` and ``--gen_start_scale`` (all left at their defaults; the real
    command line is ignored). It is pointed at *model_name*, post-configured,
    sized to the training image, and the trained generators / noise maps are
    loaded with ``functions.load_trained_pyramid``. Everything is then
    forwarded to ``SinGAN_anchor_generate``.

    Args:
        model_name: file name of the training image; used as ``opt.input_name``.
        anchor_image, direction, transfer, noise_solutions, factor, base,
        insert_limit: passed unchanged to ``SinGAN_anchor_generate``.
            The original note on ``direction`` lists the values 'L, R, T, B'.
        scale_factor: optional override for ``opt.scale_factor``, applied
            before ``functions.post_config``. When None, the historical
            hard-coded override (0.6 for 'islands2_basis_2.jpg') still applies.

    Returns:
        Whatever ``SinGAN_anchor_generate`` returns.
    """
    parser = get_arguments()
    parser.add_argument('--input_dir', help='input image dir', default='Input/Images')
    parser.add_argument('--mode', help='random_samples | random_samples_arbitrary_sizes', default='random_samples')
    # for random_samples:
    parser.add_argument('--gen_start_scale', type=int, help='generation start scale', default=0)
    # Parse an empty argument list so the host process's argv is never consulted.
    opt = parser.parse_args([])
    opt.input_name = model_name
    if scale_factor is not None:
        opt.scale_factor = scale_factor
    elif model_name == 'islands2_basis_2.jpg':
        # HARDCODED per-model override kept for backward compatibility.
        opt.scale_factor = 0.6
    opt = functions.post_config(opt)

    real = functions.read_image(opt)
    functions.adjust_scales2image(real, opt)
    Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)

    # The anchor comes from the caller via ``anchor_image``; it is not read
    # from disk here.
    return SinGAN_anchor_generate(Gs, Zs, reals, NoiseAmp, opt,
                                  gen_start_scale=opt.gen_start_scale,
                                  anchor_image=anchor_image,
                                  direction=direction,
                                  transfer=transfer,
                                  noise_solutions=noise_solutions,
                                  factor=factor,
                                  base=base,
                                  insert_limit=insert_limit)
def train_model(input_name):
    """Train a SinGAN pyramid on *input_name* and draw random samples from it.

    Builds the options namespace from ``get_arguments()`` plus ``--input_dir``,
    ``--input_name`` and ``--mode`` (default ``'train'``). It then creates the
    output directory (printing a notice if it already exists), sizes the
    pyramid to the image, trains every scale with ``train`` and finally calls
    ``SinGAN_generate``.

    Args:
        input_name: image file name inside ``Input/Images``.
    """
    parser = get_arguments()
    parser.add_argument('--input_dir', help='input image dir', default='Input/Images')
    parser.add_argument('--input_name', help='input image name', required=True)
    parser.add_argument('--mode', help='task to be done', default='train')
    # --input_name is required, so it must actually be supplied: parsing an
    # empty list made argparse exit with "the following arguments are
    # required", and the parameter was never used. The `--opt=value` form
    # keeps names that start with '-' from being read as options.
    opt = parser.parse_args([f'--input_name={input_name}'])
    opt = functions.post_config(opt)
    Gs = []
    Zs = []
    reals = []
    NoiseAmp = []
    dir2save = functions.generate_dir2save(opt)

    if os.path.exists(dir2save):
        print('trained model already exist')
    else:
        try:
            os.makedirs(dir2save)
        except OSError:
            pass
    real = functions.read_image(opt)
    functions.adjust_scales2image(real, opt)
    train(opt, Gs, Zs, reals, NoiseAmp)
    SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)
def __init__(self, input_name, load_existing_model):
    """Prepare config and image for *input_name*; load or reset its model folder.

    Args:
        input_name: training image name passed to ``Config``.
        load_existing_model: if true, load ``Gs``/``Zs``/``reals``/``NoiseAmp``
            from ``<dir2save>/*.pth``. If false, prepare an empty output folder,
            asking interactively before deleting an existing one.

    Raises:
        AssertionError: if loading is requested but the model folder does not
            exist, or if the user declines to overwrite an existing folder.
    """
    self.input_name = input_name
    self.load_existing_model = load_existing_model
    self.opt = Config(input_name)
    self.dir2save = functions.generate_dir2save(self.opt)
    self.real = functions.read_image(self.opt)
    functions.adjust_scales2image(self.real, self.opt)
    dir_exists = os.path.exists(self.dir2save)
    # Explicit raise instead of `assert` so the check survives `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    if load_existing_model and not dir_exists:
        raise AssertionError("cannot find trained model")

    # Always define the flag so callers can test it on either path.
    self.is_loaded = False
    if load_existing_model:
        self.Gs = torch.load(f'{self.dir2save}/Gs.pth')
        self.Zs = torch.load(f'{self.dir2save}/Zs.pth')
        self.reals = torch.load(f'{self.dir2save}/reals.pth')
        self.NoiseAmp = torch.load(f'{self.dir2save}/NoiseAmp.pth')
        self.is_loaded = True
        print("Trained model has been loaded")
    else:
        if dir_exists:
            user_input = input("Trained model has been found, type \"yes\" to overwrite: ")
            if user_input != 'yes':
                raise AssertionError("train aborted")
            rmtree(self.dir2save)
            print("train directory has been deleted")
        try:
            os.makedirs(self.dir2save)
        except OSError:
            pass
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train every scale of the pyramid, coarsest first, up to ``opt.stop_scale``.

    The input image is read, resized by ``opt.scale1`` and expanded with
    ``functions.creat_reals_pyramid``. Each scale gets a fresh
    discriminator/generator from ``init_models``. When the channel width
    ``opt.nfc`` equals the previous scale's, both networks start from the
    previous scale's ``netG.pth``/``netD.pth``. After ``train_single_scale`` the
    networks are frozen, the results are appended to ``Gs``, ``Zs`` and
    ``NoiseAmp``, and ``Zs``/``Gs``/``reals``/``NoiseAmp`` are saved under
    ``opt.out_``.
    """
    source = functions.read_image(opt)
    prev_signal = 0
    level = 0
    reals = functions.creat_reals_pyramid(imresize(source, opt.scale1, opt), reals, opt)
    last_width = 0

    while level < opt.stop_scale + 1:
        # Channel width doubles every 4 levels, capped at 128.
        growth = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = f'{opt.out_}/{level}'
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # per-scale folder may already exist
        plt.imsave(f'{opt.outf}/in_scale.png', functions.convert_image_np(reals[level]), vmin=0, vmax=1)

        netD, netG = init_models(opt)
        if last_width == opt.nfc:
            # Same width as the previous scale: warm-start from its weights.
            prev_dir = f'{opt.out_}/{level - 1}'
            netG.load_state_dict(torch.load(f'{prev_dir}/netG.pth'))
            netD.load_state_dict(torch.load(f'{prev_dir}/netD.pth'))

        z_opt, prev_signal, netG = train_single_scale(netD, netG, reals, Gs, Zs, prev_signal, NoiseAmp, opt)

        # Freeze both networks; only the generator is kept.
        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(z_opt)
        NoiseAmp.append(opt.noise_amp)
        for label, payload in (('Zs', Zs), ('Gs', Gs), ('reals', reals), ('NoiseAmp', NoiseAmp)):
            torch.save(payload, f'{opt.out_}/{label}.pth')

        level += 1
        last_width = opt.nfc
        del netD, netG
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the pyramid level by level, from the coarsest to ``opt.stop_scale``.

    Unlike the ``scale1`` variants, the pyramid is built directly from the
    image returned by ``functions.read_image`` (no extra resize here). Each
    level trains a new D/G pair, warm-started from the previous level's saved
    weights when the channel width is unchanged. The frozen generator, noise
    map and noise amplitude are appended to ``Gs``/``Zs``/``NoiseAmp`` and
    checkpointed under ``opt.out_``. The CUDA cache is emptied after every
    level.
    """
    image = functions.read_image(opt)
    coarse_input = 0
    level = 0
    reals = functions.creat_reals_pyramid(image, reals, opt)
    width_before = 0

    # Levels 0 .. opt.stop_scale inclusive.
    while level < opt.stop_scale + 1:
        # nfc: number of output channels per conv block (doubles every 4 levels, max 128).
        factor = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        # out_: output root; outf: per-level folder.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = f'{opt.out_}/{level}'
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave(f'{opt.outf}/real_scale.png', functions.convert_image_np(reals[level]), vmin=0, vmax=1)

        disc, gen = init_models(opt)
        # The block architecture changes whenever nfc changes, so weights can
        # only be reused when the width matches the previous level's.
        if width_before == opt.nfc:
            gen.load_state_dict(torch.load(f'{opt.out_}/{level - 1}/netG.pth'))
            disc.load_state_dict(torch.load(f'{opt.out_}/{level - 1}/netD.pth'))

        z_level, coarse_input, gen = train_single_scale(disc, gen, reals, Gs, Zs, coarse_input, NoiseAmp, opt)

        gen = functions.reset_grads(gen, False)
        gen.eval()
        disc = functions.reset_grads(disc, False)
        disc.eval()

        Gs.append(gen)
        Zs.append(z_level)
        NoiseAmp.append(opt.noise_amp)
        for label, payload in (('Zs', Zs), ('Gs', Gs), ('reals', reals), ('NoiseAmp', NoiseAmp)):
            torch.save(payload, f'{opt.out_}/{label}.pth')

        level += 1
        width_before = opt.nfc
        del disc, gen
        torch.cuda.empty_cache()
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the TensorFlow SinGAN pyramid one scale at a time, coarsest first.

    For every scale ``0 .. opt.stop_scale``:

    * a new discriminator/generator pair is created with ``init_models``;
    * when the channel width is unchanged, the pair is warm-started from the
      previous scale's saved weights;
    * the pair is passed to ``train_single_scale``;
    * the generator, its noise map and the noise amplitude are appended to
      ``Gs``/``Zs``/``NoiseAmp``;
    * ``Zs``, ``reals`` and ``NoiseAmp`` are pickled to ``opt.out_``.

    The Adam optimizers are created afresh for every scale. Each scale trains
    brand-new models, and a Keras optimizer keeps per-variable slot state.
    The TF >= 2.11 optimizer also refuses to update variables it was not
    built with, so one optimizer shared across scales is incorrect.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            D_curr.load_weights('%s/%d/netD' % (opt.out_, scale_num - 1))
            G_curr.load_weights('%s/%d/netG' % (opt.out_, scale_num - 1))

        # Fresh optimizer state for this scale's fresh models.
        netD_optimizer = tf.keras.optimizers.Adam(learning_rate=opt.lr_d, beta_1=opt.beta1, beta_2=0.999)
        netG_optimizer = tf.keras.optimizers.Adam(learning_rate=opt.lr_g, beta_1=opt.beta1, beta_2=0.999)

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt,
                                                  scale_num, netG_optimizer, netD_optimizer)
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        # NOTE(review): Gs is not pickled here. Generator weights are reloaded
        # from '<out_>/<scale>/netG' above, so presumably train_single_scale
        # writes them; confirm.
        with open('%s/Zs.pkl' % (opt.out_), 'wb') as f:
            pickle.dump(Zs, f)
        with open('%s/reals.pkl' % (opt.out_), 'wb') as f:
            pickle.dump(reals, f)
        with open('%s/NoiseAmp.pkl' % (opt.out_), 'wb') as f:
            pickle.dump(NoiseAmp, f)

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr, netD_optimizer, netG_optimizer
    return None
def main(opt, generate=True):
    """Train a two-image SinGAN model and optionally generate random samples.

    Creates the output directory and configures logging into it. It then
    dumps ``opt`` (every value stringified) to ``config.json``, reads and
    size-adjusts ``opt.input_name1``/``opt.input_name2``, trains, and, if
    *generate* is true, calls ``SinGAN_generate``. Any failure after logging is
    configured is logged with its traceback and re-raised. The logger is always
    cleaned up.

    Args:
        opt: parsed options namespace.
        generate: whether to produce random samples after training.
    """
    Gs = []
    Zs: List[Tuple] = []
    reals1 = []
    reals2 = []
    NoiseAmp = []
    dir2save = functions.generate_dir2save(opt)
    if os.path.exists(dir2save):
        print('trained model already exist')
    else:
        try:
            os.makedirs(dir2save)
        except OSError:
            pass

    _configure_logger(dir2save)
    # Everything after _configure_logger runs inside the try so that
    # _cleanup_logger always runs, including when the config dump fails.
    try:
        # dump configuration file to json
        with open(os.path.join(str(dir2save), "config.json"), "w") as fp:
            config_dict = {k: str(v) for k, v in opt.__dict__.items()}
            json.dump(config_dict, fp)

        real1 = functions.read_image(opt, image_name=opt.input_name1)
        real2 = functions.read_image(opt, image_name=opt.input_name2)
        functions.adjust_scales2image(real1, opt)
        functions.adjust_scales2image(real2, opt)
        train(opt, Gs, Zs, reals1, reals2, NoiseAmp)
        logger.info("Done training")
        if generate:
            logger.info("Generating random samples")
            SinGAN_generate(Gs, Zs, reals1, reals2, NoiseAmp, opt)
    except Exception:
        logger.exception("Failed")
        raise
    finally:
        logger.info("Cleaning logger")
        _cleanup_logger()
def preprocess_content_image(opt, reals, scale):
    """Build the injected input for pyramid level *scale* from the reference image.

    Steps:

    1. Read the training image; ``functions.adjust_scales2image`` is applied
       to it, which updates ``opt``.
    2. Read ``opt.ref_dir/opt.ref_name`` and bring it to the training image's
       size.
    3. Downsample it to the size of level ``scale - 1``.
    4. Upsample by ``1 / opt.scale_factor`` and crop to ``reals[scale]``'s
       size.

    Args:
        opt: options namespace (``ref_dir``, ``ref_name``, ``scale_factor``
            are read).
        reals: image pyramid, coarsest first (index ``len(reals) - 1`` is
            reached by downsampling by ``opt.scale_factor ** 1``).
        scale: level to inject at. It must be at least 1, because the result
            is derived from level ``scale - 1``.

    Returns:
        The prepared input tensor for level *scale*.

    Raises:
        ValueError: if *scale* is less than 1. Previously ``reals[scale - 1]``
            silently wrapped to the finest level.
    """
    if scale < 1:
        raise ValueError('scale must be >= 1 (level scale - 1 is used), got %r' % (scale,))

    real = functions.read_image(opt)
    functions.adjust_scales2image(real, opt)
    ref = functions.read_image_dir('%s/%s' % (opt.ref_dir, opt.ref_name), opt)
    # Resize when the width differs (as before) and also when the reference is
    # too short: the crop below cannot enlarge it, and a short reference would
    # give an input smaller than reals[scale].
    if ref.shape[3] != real.shape[3] or ref.shape[2] < real.shape[2]:
        ref = imresize_to_shape(ref, [real.shape[2], real.shape[3]], opt)
    ref = ref[:, :, :real.shape[2], :real.shape[3]]

    N = len(reals) - 1
    n = scale
    # Down to level n-1, then one upsampling step to level n.
    in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt)
    in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
    in_s = imresize(in_s, 1 / opt.scale_factor, opt)
    in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
    return in_s
def main(opt):
    """Train a SinGAN pyramid on ``opt``'s input image, then sample from it.

    Makes sure the output directory exists, logging (rather than failing) when
    it is already present. It then sizes the pyramid to the image, trains it,
    and calls ``SinGAN_generate`` with the resulting lists.
    """
    Gs, Zs, reals, NoiseAmp = [], [], [], []
    dir2save = functions.generate_dir2save(opt)
    if not os.path.exists(dir2save):
        try:
            os.makedirs(dir2save)
        except OSError:
            pass
    else:
        logger.info("Trained model directory already exists")

    image = functions.read_image(opt)
    functions.adjust_scales2image(image, opt)
    train(opt, Gs, Zs, reals, NoiseAmp)
    SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the pyramid scale by scale (coarsest first), printing progress.

    The image is resized by ``opt.scale1`` and expanded into ``reals``; all
    shapes are printed. Every scale ``0 .. opt.stop_scale`` trains a new D/G
    pair. The pair is fine-tuned from the previous scale's weights when the
    channel width is unchanged. Results are appended to
    ``Gs``/``Zs``/``NoiseAmp`` and checkpointed to ``opt.out_``.

    With ``opt.fast_training`` set, ``opt.niter`` is halved at scales 4, 8, ...
    The change is made on ``opt`` itself, so it accumulates.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    print("real.shape:", real.shape)  # e.g. real.shape: torch.Size([1, 3, 186, 248])
    reals = functions.creat_reals_pyramid(real, reals, opt)
    for level_img in reals:
        print(level_img.shape)
    nfc_prev = 0
    # Scales 0 .. opt.stop_scale are trained, i.e. opt.stop_scale + 1 in total
    # (the previous message printed opt.stop_scale, one too few).
    print("total %d scales.." % (opt.stop_scale + 1))
    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            opt.niter = opt.niter // 2

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        print("out dir:", opt.outf)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)

        # Create D and G. Per the original note, they are fully convolutional,
        # so the input size does not need to be fixed.
        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            # Same channel width as the previous scale: fine-tune from its weights.
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        # Train this scale. The networks are rebuilt for every scale, which the
        # original author notes keeps GPU memory usage low.
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        # Keep the trained generator.
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        # Drop local references to D and G (intended to reduce GPU memory).
        del D_curr, G_curr
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the pyramid scale by scale. In ``inpainting`` mode, first build a mask pyramid.

    In inpainting mode the mask image ``opt.mask_name`` is read from
    ``opt.input_dir``, prefixed by ``opt.on_drive`` when that is set. It is
    converted with ``1 - mask / 255`` (per the original note, 0 then marks the
    masked-out area). It is moved to ``(1, 3, H, W)`` layout, resized by
    ``opt.scale1``, and expanded into ``opt.masks`` with
    ``functions.creat_reals_pyramid``.

    Training then proceeds as usual: one D/G pair per scale, warm-started
    when the channel width is unchanged, results appended to
    ``Gs``/``Zs``/``NoiseAmp`` and checkpointed to ``opt.out_``.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    # If training for inpainting
    if opt.mode == "inpainting":
        # Import the mask image (values assumed in [0, 255])
        if opt.on_drive is not None:
            mask = img.imread('%s/%s/%s' % (opt.on_drive, opt.input_dir, opt.mask_name))
        else:
            mask = img.imread('%s/%s' % (opt.input_dir, opt.mask_name))
        # Convert mask to [0, 1]: 0 is the masked-out area, 1 everywhere else
        mask = 1 - (mask / 255)
        mask = torch.from_numpy(mask)
        # The array is (H, W, C). Move the channel axis to the front with
        # permute: the previous `.view([1, 3, H, W])` reinterpreted the HWC
        # memory as CHW and scrambled the mask. reshape still requires exactly
        # 3 channels, as the old view did.
        height, width = mask.shape[0], mask.shape[1]
        mask = mask.permute(2, 0, 1).reshape(1, 3, height, width)
        mask = imresize(mask, opt.scale1, opt)
        # Create the mask pyramid
        opt.masks = functions.creat_reals_pyramid(mask, [], opt)

    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
def train(opt, Gs, Zs, reals, crops, masks, NoiseAmp):
    """Train the pyramid with per-scale image, mask and crop pyramids.

    The image (resized by ``opt.scale1``) and the mask from
    ``functions.read_mask`` are expanded with ``functions.create_pyramid``. A
    zero ``crop_size`` x ``crop_size`` tensor is pyramided only to obtain
    per-scale crop sizes. The colour returned by ``functions.get_eye_color``
    is stored on ``opt.eye_color`` and passed to ``train_single_scale``. Each
    scale's frozen generator, noise map and amplitude are appended and
    checkpointed to ``opt.out_``.
    """
    base_img = imresize(functions.read_image(opt), opt.scale1, opt)
    mask_img = functions.read_mask(opt)
    # Used just for size reference when downsizing
    crop_template = torch.zeros((1, 1, opt.crop_size, opt.crop_size))
    eye_color = functions.get_eye_color(base_img)
    opt.eye_color = eye_color

    carry = 0
    step = 0
    reals = functions.create_pyramid(base_img, reals, opt)
    masks = functions.create_pyramid(mask_img, masks, opt, mode="mask")
    # Shortcut to get sizes of corresponding crops for each scale
    crops = functions.create_pyramid(crop_template, crops, opt, mode="mask")
    width_prev = 0

    while step < opt.stop_scale + 1:
        mult = 2 ** (step // 4)
        opt.nfc = min(opt.nfc_init * mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = f'{opt.out_}/{step}'
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave(f'{opt.outf}/real_scale.png', functions.convert_image_np(reals[step]), vmin=0, vmax=1)

        disc, gen = init_models(opt)
        if width_prev == opt.nfc:
            gen.load_state_dict(torch.load(f'{opt.out_}/{step - 1}/netG.pth'))
            disc.load_state_dict(torch.load(f'{opt.out_}/{step - 1}/netD.pth'))

        z_step, carry, gen = train_single_scale(disc, gen, reals, crops, masks, eye_color,
                                                Gs, Zs, carry, NoiseAmp, opt)
        gen = functions.reset_grads(gen, False)
        gen.eval()
        disc = functions.reset_grads(disc, False)
        disc.eval()

        Gs.append(gen)
        Zs.append(z_step)
        NoiseAmp.append(opt.noise_amp)
        for label, payload in (('Zs', Zs), ('Gs', Gs), ('reals', reals), ('NoiseAmp', NoiseAmp)):
            torch.save(payload, f'{opt.out_}/{label}.pth')

        step += 1
        width_prev = opt.nfc
        del disc, gen
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the pyramid scale by scale (coarsest first), printing debug output.

    The image is resized by ``opt.scale1`` and expanded into ``reals``. For
    every scale ``0 .. opt.stop_scale`` a new D/G pair is created. It is
    warm-started from the previous scale's weights when ``opt.nfc`` is
    unchanged, then trained with ``train_single_scale``. The frozen generator,
    noise map and amplitude are appended and checkpointed to ``opt.out_``.
    With ``opt.fast_training``, ``opt.niter`` is halved (in place) at scales
    4, 8, ...
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    # List of the training image at every pyramid scale.
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0
    # Trains scales 0 .. opt.stop_scale, i.e. opt.stop_scale + 1 iterations.
    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        print('opt.nfc', opt.nfc)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        print('opt.min_nfc', opt.min_nfc)
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            opt.niter = opt.niter // 2

        # out_ is the output root directory; outf is the per-scale directory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        if scale_num == 0:
            # The original image is identical on every scale; save it once
            # instead of rewriting it on each iteration.
            plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        # Save this scale's training image.
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)

        # Current discriminator and generator (init_models returns netD, netG).
        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        # train_single_scale returns (z_opt, in_s, netG).
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        G_curr = functions.reset_grads(G_curr, False)
        print('G_curr', G_curr)
        G_curr.eval()
        # eval() returns the module itself, so print it directly rather than
        # calling eval() a second time.
        print(G_curr)
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the scale pyramid, or (experimentally) resume an interrupted run.

    Progress lives on ``opt`` itself (``opt.scale_num``, ``opt.nfc_prev``).
    A call where ``opt.scale_num > 0`` resumes from that scale. In that case
    ``reals`` must already hold the pyramid, and the coarse input starts as a
    zero tensor shaped like ``reals[0]``. Otherwise the pyramid is built from
    the image resized by ``opt.scale1``. Every trained scale is frozen,
    appended to ``Gs``/``Zs``/``NoiseAmp`` and checkpointed to ``opt.out_``.
    """
    print('train() current parameters')
    print(opt)
    source = functions.read_image(opt)
    carry = 0
    resuming = 'scale_num' in opt and opt.scale_num > 0
    if resuming:
        # EXPERIMENTAL: continue mode, reusing the caller-supplied pyramid.
        carry = torch.full(reals[0].shape, 0, device=opt.device)
    else:
        opt.scale_num = 0
        reals = functions.create_reals_pyramid(imresize(source, opt.scale1, opt), reals, opt)
    if 'nfc_prev' not in opt:
        opt.nfc_prev = 0

    while opt.scale_num < opt.stop_scale + 1:
        doubling = pow(2, math.floor(opt.scale_num / 4))
        opt.nfc = min(opt.nfc_init * doubling, 128)
        opt.min_nfc = min(opt.min_nfc_init * doubling, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, opt.scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            print('directory %s already exists' % opt.outf)
        plt.imsave('%s/real_scale.png' % opt.outf, functions.convert_image_np(reals[opt.scale_num]), vmin=0, vmax=1)

        disc, gen = init_models(opt)
        if opt.nfc_prev == opt.nfc:
            prev_scale = opt.scale_num - 1
            gen.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, prev_scale)))
            disc.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, prev_scale)))

        z_scale, carry, gen = train_single_scale(disc, gen, reals, Gs, Zs, carry, NoiseAmp, opt)
        gen = functions.reset_grads(gen, False)
        gen.eval()
        disc = functions.reset_grads(disc, False)
        disc.eval()

        Gs.append(gen)
        Zs.append(z_scale)
        NoiseAmp.append(opt.noise_amp)
        for label, payload in (('Zs', Zs), ('Gs', Gs), ('reals', reals), ('NoiseAmp', NoiseAmp)):
            torch.save(payload, '%s/%s.pth' % (opt.out_, label))

        opt.scale_num += 1
        opt.nfc_prev = opt.nfc
        del disc, gen
    return
# Script-level fragment; the lines that define dir2save, dir_name and opt are
# not visible here. It reports existing random-sample output, or creates the
# output folder and generates samples for a two-image model.
if dir2save is None:
    print('task does not exist')
elif (os.path.exists(dir2save)):
    # Output folder already present: only report it, do not regenerate.
    # NOTE(review): dir_name is not defined in this fragment; confirm it is set earlier.
    if opt.mode == 'random_samples':
        print(
            'random samples for image %s, start scale=%d, already exist' % (dir_name, opt.gen_start_scale))
    elif opt.mode == 'random_samples_arbitrary_sizes':
        print(
            'random samples for image %s at size: scale_h=%f, scale_v=%f, already exist' % (dir_name, opt.scale_h, opt.scale_v))
else:
    try:
        os.makedirs(dir2save)
    except OSError:
        pass
    if opt.mode == 'random_samples':
        # Read both training images and let each one adjust opt's scale settings.
        real1 = functions.read_image(opt, opt.input_name1)
        real2 = functions.read_image(opt, opt.input_name2)
        functions.adjust_scales2image(real1, opt)
        functions.adjust_scales2image(real2, opt)
        # Load the trained generators, noise maps, both real pyramids and noise amplitudes.
        Gs, Zs, reals1, reals2, NoiseAmp = functions.load_trained_pyramid(
            opt)
        SinGAN_generate(Gs, Zs, reals1, reals2, NoiseAmp, opt, gen_start_scale=opt.gen_start_scale)
# Fragment of a random-sampling script; the opening `if` and the end of the
# last statement are not visible here.
elif (os.path.exists(dir2save)):  # the output folder already exists
    if opt.mode == 'random_samples':
        print(
            'random samples for image %s, start scale=%d, already exist' % (opt.input_name, opt.gen_start_scale))
    elif opt.mode == 'random_samples_arbitrary_sizes':
        print(
            'random samples for image %s at size: scale_h=%f, scale_v=%f, already exist' % (opt.input_name, opt.scale_h, opt.scale_v))
else:
    try:
        os.makedirs(dir2save)  # create the output folder
    except OSError:
        pass
    if opt.mode == 'random_samples':
        real = functions.read_image(
            opt)  # read the image identified by opt.input_dir and opt.input_name
        functions.adjust_scales2image(real, opt)
        Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
        # NOTE(review): in_s is computed but not passed to SinGAN_generate below.
        in_s = functions.generate_in2coarsest(reals, 1, 1, opt)
        SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, gen_start_scale=opt.gen_start_scale)
    elif opt.mode == 'random_samples_arbitrary_sizes':
        real = functions.read_image(
            opt)  # read the image identified by opt.input_dir and opt.input_name
        functions.adjust_scales2image(real, opt)  # set up opt's scale settings from the image
        # NOTE(review): the fragment is truncated inside the next call.
        Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(
# Fragment of a random-sampling script; the opening `if` is not visible here.
elif (os.path.exists(dir2save)):
    # Output folder already present: only report it, do not regenerate.
    if opt.mode == 'random_samples':
        print(
            'random samples for image %s, start scale=%d, already exist' % (opt.input_name, opt.gen_start_scale))
    elif opt.mode == 'random_samples_arbitrary_sizes':
        print(
            'random samples for image %s at size: scale_h=%f, scale_v=%f, already exist' % (opt.input_name, opt.scale_h, opt.scale_v))
else:
    try:
        os.makedirs(dir2save)
    except OSError:
        pass
    if opt.mode == 'random_samples':
        real = functions.read_image(opt)
        functions.adjust_scales2image(real, opt)
        Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
        # Build the input for the coarsest scale.
        # NOTE(review): in_s is computed but not passed to SinGAN_generate below.
        in_s = functions.generate_in2coarsest(reals, 1, 1, opt)
        SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, gen_start_scale=opt.gen_start_scale)
    # elif opt.mode == 'random_samples_arbitrary_sizes':
    #     real = functions.read_image(opt)
    #     functions.adjust_scales2image(real, opt)
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train each scale independently, conditioned on upsampled real images.

    The pyramid is built directly from the image returned by
    ``functions.read_image``. For every level ``k >= 1``, level ``k - 1`` is
    resized to level ``k``'s shape (``upsamples``), and ``|real - upsampled| - 1``
    is recorded (``diffs``). Level 0 uses the coarsest real image for both.
    Each level saves its real, diff and upsampled images. It then trains a new
    D/G pair via ``train_single_scale(..., upsamples, level, ...)``, with no
    warm-start from the previous level. The results are appended to
    ``Gs``/``Zs``/``NoiseAmp`` and checkpointed to ``opt.out_``. The CUDA cache
    is emptied after each level.
    """
    source = functions.read_image(opt)
    carry = 0
    level = 0
    reals = functions.creat_reals_pyramid(source, reals, opt)

    upsamples = [reals[0]]
    diffs = [reals[0]]
    # Levels 1 .. opt.stop_scale each need the previous level upsampled.
    for k in range(opt.stop_scale):
        coarse, fine = reals[k], reals[k + 1]
        _, channels, height, width = fine.shape
        lifted = imresize_to_shape(coarse, (height, width, channels), opt)
        upsamples.append(lifted)
        diffs.append((fine - lifted).abs() - 1)  # in [-1, 1]

    while level < opt.stop_scale + 1:
        # nfc: output channels per conv block (doubles every 4 levels, max 128).
        mult = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = f'{opt.out_}/{level}'
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave(f'{opt.outf}/real_scale.png', functions.convert_image_np(reals[level]), vmin=0, vmax=1)
        plt.imsave(f'{opt.outf}/diff.png', functions.convert_image_np(diffs[level].detach()), vmin=0, vmax=1)
        plt.imsave(f'{opt.outf}/upsampled.png', functions.convert_image_np(upsamples[level].detach()), vmin=0, vmax=1)

        # Levels are trained independently, so no weights are reloaded.
        disc, gen = init_models(opt)
        z_level, carry, gen = train_single_scale(disc, gen, reals, upsamples, level, carry, NoiseAmp, opt)
        gen = functions.reset_grads(gen, False)
        gen.eval()
        disc = functions.reset_grads(disc, False)
        disc.eval()

        Gs.append(gen)
        Zs.append(z_level)
        NoiseAmp.append(opt.noise_amp)
        for label, payload in (('Zs', Zs), ('Gs', Gs), ('reals', reals), ('NoiseAmp', NoiseAmp)):
            torch.save(payload, f'{opt.out_}/{label}.pth')

        level += 1
        del disc, gen
        torch.cuda.empty_cache()
    return
def get_reals(reals, opt, image_name):
    """Read *image_name*, resize it by ``opt.scale1`` and return its scale pyramid.

    *reals* is handed to ``functions.creat_reals_pyramid`` as the list to build on.
    """
    return functions.creat_reals_pyramid(
        imresize(functions.read_image(opt, image_name), opt.scale1, opt),
        reals,
        opt,
    )
def SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1, scale_h=1, n=0, gen_start_scale=0, num_samples=50):
    """Generate ``num_samples`` random samples, re-deriving the real pyramid first.

    Before sampling:

    * the input image is re-read and resized to the shape of ``reals[-1]``;
    * it is expanded with ``creat_reals_pyramid``;
    * each level is resized to the corresponding ``reals`` entry's shape;
    * those images replace *reals* and are saved as ``real_<level>.png`` in
      ``<opt.out>/RandomSamples/<name>/gen_start_scale=<k>``.

    Sampling then walks ``Gs``/``Zs``/``NoiseAmp`` from scale *n* upward. Each
    sample's previous-scale image is resized by ``1 / opt.scale_factor``,
    padded and combined with noise. The noise is replaced by the stored
    ``Z_opt`` below *gen_start_scale*. Finest-scale samples are saved as
    ``<i>.png`` unless ``opt.mode`` is harmonization, editing, SR or
    paint2image.

    Args:
        in_s: input to the first scale (zeros shaped like ``reals[0]`` if None).
        scale_v, scale_h: vertical / horizontal size multipliers for the noise.
        n: index of the first scale being generated.
        gen_start_scale: scales below this use ``Z_opt`` instead of new noise.
        num_samples: number of samples per scale.

    Returns:
        The last generated sample, detached.
    """
    real = functions.read_image(opt)
    real = torch.from_numpy(resize(real.numpy(), reals[-1].shape))
    new_reals = creat_reals_pyramid(real, [], opt)
    buffer = []
    for new_real, ref_real in zip(new_reals, reals):
        buffer.append(torch.from_numpy(resize(new_real.numpy(), ref_real.shape)))
    reals = buffer

    sample_dir = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
    # The folder must exist before the first imsave below. Previously it was
    # only created once the finest scale was reached, so a fresh run failed here.
    try:
        os.makedirs(sample_dir)
    except OSError:
        pass
    for level, real_img in enumerate(reals):
        plt.imsave('%s/%s_%d.png' % (sample_dir, "real", level), functions.convert_image_np(real_img.detach()), vmin=0, vmax=1)

    if in_s is None:
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []
        for i in range(num_samples):
            if n == 0:
                # Coarsest scale: one noise channel expanded to 3 channels.
                z_curr = functions.generate_noise([1, nzx, nzy], device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                I_prev = m(in_s)
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]), 0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2], z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                # Use this scale's stored Z_opt instead of the fresh noise.
                z_curr = Z_opt
            z_in = noise_amp * z_curr + I_prev
            if opt.skip != '' and int(opt.skip) == n:
                I_curr = I_prev
            else:
                I_curr = G(z_in.detach(), I_prev)

            if n == len(reals) - 1:
                if opt.mode == 'train':
                    dir2save = sample_dir
                else:
                    dir2save = functions.generate_dir2save(opt)
                try:
                    os.makedirs(dir2save)
                except OSError:
                    pass
                if opt.mode not in ("harmonization", "editing", "SR", "paint2image"):
                    plt.imsave('%s/%d.png' % (dir2save, i), functions.convert_image_np(I_curr.detach()), vmin=0, vmax=1)
            images_cur.append(I_curr)
        n += 1
    return I_curr.detach()
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid one scale at a time, coarsest scale first.

    The input image is resized by ``opt.scale1`` and expanded into a pyramid
    with ``functions.creat_reals_pyramid`` (``reals`` is passed through to it).
    For each scale:

    - the channel widths ``opt.nfc`` / ``opt.min_nfc`` double every 4 scales,
      capped at 128;
    - with ``opt.fast_training``, ``opt.niter`` is halved at every 4th scale
      (after scale 0);
    - the models are warm-started from the previous scale's checkpoints when
      the width is unchanged.

    The trained generator, noise map and noise amplitude are appended to
    ``Gs`` / ``Zs`` / ``NoiseAmp``. All lists plus the pyramid are checkpointed
    to ``opt.out_`` after every scale.
    """
    full_res = functions.read_image(opt)
    in_s = 0
    level = 0
    pyramid = functions.creat_reals_pyramid(imresize(full_res, opt.scale1, opt), reals, opt)
    prev_width = 0

    while level < opt.stop_scale + 1:
        widen = 2 ** (level // 4)  # double the filter count every 4 scales
        opt.nfc = min(opt.nfc_init * widen, 128)
        opt.min_nfc = min(opt.min_nfc_init * widen, 128)
        if opt.fast_training and level > 0 and level % 4 == 0:
            # Finer scales get fewer iterations when fast training is enabled.
            opt.niter = opt.niter // 2

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = f'{opt.out_}/{level}'
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave(f'{opt.outf}/real_scale.png',
                   functions.convert_image_np(pyramid[level]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if prev_width == opt.nfc:
            # Same architecture as the previous scale: reuse its weights.
            G_curr.load_state_dict(torch.load(f'{opt.out_}/{level - 1}/netG.pth'))
            D_curr.load_state_dict(torch.load(f'{opt.out_}/{level - 1}/netD.pth'))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, pyramid, Gs, Zs,
                                                  in_s, NoiseAmp, opt)

        # Freeze both networks once this scale is done.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        for payload, stem in ((Zs, 'Zs'), (Gs, 'Gs'), (pyramid, 'reals'), (NoiseAmp, 'NoiseAmp')):
            torch.save(payload, f'{opt.out_}/{stem}.pth')

        level += 1
        prev_width = opt.nfc
        del D_curr, G_curr
def train(opt,Gs,Zs,reals,NoiseAmp):
    """Train the pyramid scale by scale, pruning each scale's generator.

    Per scale:

    - train G/D with ``train_single_scale`` (``warmup_steps`` iterations when
      ``fine_tune`` is set, otherwise ``opt.niter``);
    - prune the generator's weights: global L1 unstructured over conv and norm
      weights/biases, or L1 structured (``dim=0``) over the conv weights only;
    - optionally fine-tune the pruned generator;
    - append the results to ``Gs`` / ``Zs`` / ``NoiseAmp`` and checkpoint them.
      Generators are saved as ``pruned_Gs.pth``.

    NOTE(review): ``fine_tune``, ``warmup_steps``, ``structured`` and
    ``pruning_amount`` are not parameters or locals; presumably module-level
    globals defined elsewhere in the file. Confirm they exist.
    NOTE(review): unlike the other ``train`` variants, the pyramid is built
    from ``real_`` directly, without ``imresize(real_, opt.scale1, opt)``.
    NOTE(review): the nesting below was reconstructed from a collapsed source
    line. In particular, confirm that the ``prune.remove`` loop runs whether
    or not ``fine_tune`` is set.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level from coarsest to finest.
    cur_scale_level = 0
    # scale1: for the largest patch size, what ratio wrt the image shape
    reals = functions.creat_reals_pyramid(real_,reals,opt)
    nfc_prev = 0
    # Train including opt.stop_scale
    while cur_scale_level < opt.stop_scale+1:
        # nfc: number of out channels in conv block (doubles every 4 scales, capped at 128)
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        # out_: output directory
        # outf: output folder, with scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)
        D_curr,G_curr = init_models(opt)
        # As the level increases, the CNN block width may change (every 4 levels per the paper);
        # weights are only reused from the previous level when the width is unchanged.
        if (nfc_prev==opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,cur_scale_level-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,cur_scale_level-1)))
        # in_s: guess: initial signal? it doesn't change during the training, and is a zero tensor.
        # With fine_tune, only warmup_steps iterations run here; the rest run after pruning.
        if fine_tune:
            z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, warmup_steps)
        else:
            z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter)
        G_curr = functions.reset_grads(G_curr,False)
        # D_curr = functions.reset_grads(D_curr,False)
        G_curr.eval()
        # D_curr.eval()
        # NOTE(review): D_curr is left trainable here (the freeze is commented out),
        # presumably so it can be reused for fine-tuning below.
        #################################################################################
        # Visualize weights: histogram of the non-zero weights of `modules`, saved to
        # <opt.outf>/<fig_name>.png, and print the fraction of exactly-zero weights.
        # NOTE(review): hard-codes .cuda(), so this requires a CUDA device.
        def visualize_weights(modules, fig_name):
            ori_weights = torch.tensor([]).cuda()
            for m in modules:
                cur_params = m.weight.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
                # cur_params = m.bias.data.flatten()
                # ori_weights = torch.cat((ori_weights, cur_params))
            sparsity = torch.sum(ori_weights == 0) * 1.0 / (ori_weights.nelement())
            print(sparsity, ori_weights.nelement())
            ori_weights = ori_weights.cpu().numpy()
            ori_weights = plt.hist(ori_weights[ori_weights != 0], bins=100)
            plt.savefig("%s/%s.png" % (opt.outf, fig_name))
            plt.close()
        # Pruning weights: structured or non-structured
        if not structured:
            # Unstructured: prune weights and biases of every conv/norm layer jointly,
            # with a single global L1 threshold across all listed tensors.
            modules = [G_curr.head.conv, G_curr.head.norm, G_curr.body.block1.conv, G_curr.body.block1.norm,
                       G_curr.body.block2.conv, G_curr.body.block2.norm, G_curr.body.block3.conv,
                       G_curr.body.block3.norm, G_curr.tail[0]]
            parameters_to_prune = (
                (G_curr.head.conv, 'weight'),
                (G_curr.head.norm, 'weight'),
                (G_curr.body.block1.conv, 'weight'),
                (G_curr.body.block1.norm, 'weight'),
                (G_curr.body.block2.conv, 'weight'),
                (G_curr.body.block2.norm, 'weight'),
                (G_curr.body.block3.conv, 'weight'),
                (G_curr.body.block3.norm, 'weight'),
                (G_curr.tail[0], 'weight'),
                (G_curr.head.conv, 'bias'),
                (G_curr.head.norm, 'bias'),
                (G_curr.body.block1.conv, 'bias'),
                (G_curr.body.block1.norm, 'bias'),
                (G_curr.body.block2.conv, 'bias'),
                (G_curr.body.block2.norm, 'bias'),
                (G_curr.body.block3.conv, 'bias'),
                (G_curr.body.block3.norm, 'bias'),
                (G_curr.tail[0], 'bias'),
            )
            visualize_weights(modules, 'ori')
            # Prune weights
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=pruning_amount,
            )
        else:
            # Structured: L1-norm (n=1) pruning along dim 0 of each body/head conv weight.
            modules = [G_curr.head.conv, G_curr.body.block1.conv, G_curr.body.block2.conv, G_curr.body.block3.conv]
            visualize_weights(modules, 'ori')
            # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
            # print(pytorch_total_params)
            for module in modules:
                m = prune.ln_structured(module, name="weight", amount=pruning_amount, n=1, dim=0)
                # m = prune.ln_structured(module, name="bias", amount=pruning_amount, n=1, dim=0)
        # Checkpoint of the generator right after pruning, before any fine-tuning.
        torch.save(G_curr.state_dict(), '%s/raw_prune_netG.pth' % (opt.outf))
        visualize_weights(modules, 'raw-prune')
        if cur_scale_level > 0:
            # Preview samples with the pruned generator appended to copies of the
            # pyramid lists; the real Gs/Zs/NoiseAmp are left untouched here.
            fake_Gs = Gs.copy()
            fake_Gs.append(G_curr)
            fake_Zs = Zs.copy()
            fake_Zs.append(z_curr)
            fake_noise = NoiseAmp.copy()
            fake_noise.append(opt.noise_amp)
            fake_reals = reals[:cur_scale_level+1].copy()
            prune_SinGAN_generate(fake_Gs, fake_Zs, fake_reals, fake_noise, opt, gen_start_scale=0, num_samples=1, level=cur_scale_level)
        # Fine-tuning
        if fine_tune:
            G_curr = functions.reset_grads(G_curr, True)
            G_curr.train()
            if not structured:
                # Keep training using inherited weights for the remaining iterations
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter - warmup_steps, prune=True)
            else:
                # Training from scratch (re-initialisation is currently disabled)
                # G_curr.apply(models.weights_init)
                # D_curr.apply(models.weights_init)
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter, prune=True)
            G_curr = functions.reset_grads(G_curr,False)
            G_curr.eval()
            visualize_weights(modules, 'fine-tune')
        # Make the pruning permanent (prune.remove drops the reparametrization and
        # keeps the masked tensor). Biases were only pruned in the unstructured case.
        for m in modules:
            prune.remove(m, 'weight')
            if not structured:
                prune.remove(m, 'bias')
        # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
        # print(pytorch_total_params)
        #################################################################################
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        cur_scale_level+=1
        nfc_prev = opt.nfc
        del D_curr,G_curr
    return
# Paint-to-image driver fragment (top-level statements).
# NOTE(review): relies on `opt` having been configured by code outside this fragment.
# NOTE(review): this variant's functions.read_image returns two images and
# load_trained_pyramid returns two reals pyramids. Confirm against the functions
# module in use.
Gs = []
Zs = []
reals = []
NoiseAmp = []
dir2save = functions.generate_dir2save(opt)
if dir2save is None:
    print('task does not exist')
else:
    try:
        os.makedirs(dir2save)
    except OSError:
        pass
    #real = functions.read_image(opt)
    #real = functions.adjust_scales2image(real, opt)
    #Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
    real1, real2 = functions.read_image(opt)
    real1 = functions.adjust_scales2image(real1, opt)
    real2 = functions.adjust_scales2image(real2, opt)
    Gs, Zs, reals1, reals2, NoiseAmp = functions.load_trained_pyramid(opt)
    # The paint image is injected at opt.paint_start_scale, which must be a valid
    # non-coarsest scale index (1 .. len(Gs)-1).
    if (opt.paint_start_scale < 1) | (opt.paint_start_scale > (len(Gs)-1)):
        print("injection scale should be between 1 and %d" % (len(Gs)-1))
    else:
        ref = functions.read_image_dir('%s/%s' % (opt.ref_dir, opt.ref_name), opt)
        # Bring the reference (paint) image to the same size as real1 when widths differ.
        if ref.shape[3] != real1.shape[3]:
            ref = imresize_to_shape(ref, [real1.shape[2], real1.shape[3]], opt)
            ref = ref[:, :, :real1.shape[2], :real1.shape[3]]
        N = len(reals1) - 1
        n = opt.paint_start_scale
        # Rescale the reference by scale_factor**(N-n+1) and crop to the shape of
        # pyramid level n-1 (presumably scale_factor < 1, i.e. a downscale).
        in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt)
        in_s = in_s[:, :, :reals1[n - 1].shape[2], :reals1[n - 1].shape[3]]
        # NOTE(review): in_s, Zs, real2 and reals2 are not used within this visible fragment.
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one generator/discriminator pair per pyramid scale, coarsest first.

    When ``opt.inpainting`` is set, a mask image named ``<stem>_mask<ext>``
    (``<stem><ext>`` being ``opt.input_name``) is read from ``opt.ref_dir``.
    It is resized by ``opt.scale1`` and expanded into its own pyramid, which is
    stored on ``opt.m_s`` so training can be restricted to valid pixels at
    every scale.

    ``Gs`` / ``Zs`` / ``NoiseAmp`` are extended in place with each scale's
    results. They and the image pyramid are checkpointed to ``opt.out_`` after
    every scale.
    """
    source = functions.read_image(opt)
    in_s = 0
    level = 0
    pyramid = functions.creat_reals_pyramid(imresize(source, opt.scale1, opt), reals, opt)
    width_prev = 0

    if opt.inpainting:
        stem, ext = opt.input_name[:-4], opt.input_name[-4:]
        mask = functions.read_image_dir(f'{opt.ref_dir}/{stem}_mask{ext}', opt)
        mask = imresize(mask, opt.scale1, opt)
        # Mask pyramid built exactly like the image pyramid.
        opt.m_s = functions.creat_reals_pyramid(mask, [], opt)

    while level < opt.stop_scale + 1:
        growth = 2 ** (level // 4)  # widen the conv blocks every 4 scales
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = f'{opt.out_}/{level}'
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave(f'{opt.outf}/real_scale.png',
                   functions.convert_image_np(pyramid[level]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if width_prev == opt.nfc:
            # Width unchanged: start from the previous scale's trained weights.
            previous = f'{opt.out_}/{level - 1}'
            G_curr.load_state_dict(torch.load(f'{previous}/netG.pth'))
            D_curr.load_state_dict(torch.load(f'{previous}/netD.pth'))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, pyramid, Gs, Zs,
                                                  in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, f'{opt.out_}/Zs.pth')
        torch.save(Gs, f'{opt.out_}/Gs.pth')
        torch.save(pyramid, f'{opt.out_}/reals.pth')
        torch.save(NoiseAmp, f'{opt.out_}/NoiseAmp.pth')

        level += 1
        width_prev = opt.nfc
        del D_curr, G_curr
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid scale by scale with a notebook progress bar.

    The image named by ``opt`` is resized by ``opt.scale1`` (computed earlier
    by ``adjust_scales2image``) and turned into a pyramid of resized copies.
    Scales ``0 .. opt.stop_scale`` are then trained in order; progress is shown
    with ``tqdm_notebook``, labelled with ``opt.input_name``.

    ``Gs`` / ``Zs`` / ``NoiseAmp`` are extended in place. They and the pyramid
    are checkpointed to ``opt.out_`` after every scale.
    """
    image = functions.read_image(opt)
    in_s = 0
    pyramid = functions.creat_reals_pyramid(imresize(image, opt.scale1, opt), reals, opt)
    prev_nfc = 0

    progress = tqdm_notebook(range(opt.stop_scale + 1), desc=opt.input_name, leave=True)
    for scale_num in progress:
        # Channel counts double every 4 scales, never exceeding 128.
        doubling = 2 ** (scale_num // 4)
        opt.nfc = min(opt.nfc_init * doubling, 128)
        opt.min_nfc = min(opt.min_nfc_init * doubling, 128)

        # Top-level output directory and the per-scale sub-directory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = f'{opt.out_}/{scale_num}'
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave(f'{opt.outf}/real_scale.png',
                   functions.convert_image_np(pyramid[scale_num]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if prev_nfc == opt.nfc:
            # Same width as the previous scale, so its checkpoints load directly.
            G_curr.load_state_dict(torch.load(f'{opt.out_}/{scale_num - 1}/netG.pth'))
            D_curr.load_state_dict(torch.load(f'{opt.out_}/{scale_num - 1}/netD.pth'))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, pyramid, Gs, Zs,
                                                  in_s, NoiseAmp, opt)

        # Freeze the trained pair and switch both to eval mode.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        checkpoints = [(Zs, 'Zs'), (Gs, 'Gs'), (pyramid, 'reals'), (NoiseAmp, 'NoiseAmp')]
        for obj, label in checkpoints:
            torch.save(obj, f'{opt.out_}/{label}.pth')

        prev_nfc = opt.nfc
        # Release the per-scale models before the next iteration.
        del D_curr, G_curr
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train a multi-scale SinGAN, one generator/discriminator pair per scale.

    The input image is loaded with ``functions.read_image``; the original
    author notes it comes back normalised to [-1, 1] (not verified here). It
    is resized by ``opt.scale1`` and expanded into a pyramid of progressively
    resized images. The pyramid is built by passing ``reals`` through to
    ``functions.creat_reals_pyramid``.

    ``Gs`` / ``Zs`` (the per-scale noise maps) / ``NoiseAmp`` are extended in
    place. They and the pyramid are checkpointed to ``opt.out_`` after every
    scale.
    """
    loaded = functions.read_image(opt)
    in_s = 0
    current = 0
    scaled = imresize(loaded, opt.scale1, opt)  # resize to the working resolution
    pyramid = functions.creat_reals_pyramid(scaled, reals, opt)
    last_nfc = 0

    while current < opt.stop_scale + 1:
        factor = 2 ** (current // 4)
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = f'{opt.out_}/{current}'
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave(f'{opt.outf}/real_scale.png',
                   functions.convert_image_np(pyramid[current]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if last_nfc == opt.nfc:
            prior = current - 1
            G_curr.load_state_dict(torch.load(f'{opt.out_}/{prior}/netG.pth'))
            D_curr.load_state_dict(torch.load(f'{opt.out_}/{prior}/netD.pth'))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, pyramid, Gs, Zs,
                                                  in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)  # noise map for this scale
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, f'{opt.out_}/Zs.pth')
        torch.save(Gs, f'{opt.out_}/Gs.pth')
        torch.save(pyramid, f'{opt.out_}/reals.pth')
        torch.save(NoiseAmp, f'{opt.out_}/NoiseAmp.pth')

        current += 1
        last_nfc = opt.nfc
        del D_curr, G_curr
def test_generate(model_name, anchor_image=None, direction=None, transfer=None,
                  noise_solutions=None, factor=0.25, base=None, insert_limit=4,
                  reference_image='island_basis_0.jpg', pyramid_name='test1.jpg'):
    """Sample with ``SinGAN_anchor_generate`` from a pyramid trained under another name.

    Test/debug variant of ``generate``. The scale configuration and the dummy
    ``reals`` pyramid are derived from ``reference_image``. The trained
    generators, noise maps and amplitudes are loaded from ``pyramid_name``.
    Both names were previously hard-coded; the defaults keep that behaviour.

    Args:
        model_name: set as ``opt.input_name`` while ``post_config`` runs.
        anchor_image, direction, transfer, noise_solutions, factor, base,
        insert_limit: forwarded unchanged to ``SinGAN_anchor_generate``
            (``direction`` is one of 'L', 'R', 'T', 'B' per the original note).
        reference_image: existing input image, used only for its dimensions.
        pyramid_name: name of the trained model whose pyramid is loaded.

    Returns:
        Whatever ``SinGAN_anchor_generate`` returns.
    """
    parser = get_arguments()
    parser.add_argument('--input_dir', help='input image dir', default='Input/Images')
    parser.add_argument('--mode', help='random_samples | random_samples_arbitrary_sizes',
                        default='random_samples')
    # for random_samples:
    parser.add_argument('--gen_start_scale', type=int, help='generation start scale', default=0)
    opt = parser.parse_args("")
    opt.input_name = model_name
    opt = functions.post_config(opt)

    # An existing image provides the dimensions from which adjust_scales2image
    # derives the scale settings (e.g. opt.scale1, used below).
    opt.input_name = reference_image
    real = functions.read_image(opt)
    functions.adjust_scales2image(real, opt)

    # Switch to the trained model we actually want to sample from. Its stored
    # reals are discarded in favour of the dummy pyramid built below.
    opt.input_name = pyramid_name
    Gs, Zs, _, NoiseAmp = functions.load_trained_pyramid(opt)

    # Dummy reals pyramid from the reference image, used for dimensions only.
    reals = functions.creat_reals_pyramid(imresize(real, opt.scale1, opt), [], opt)

    return SinGAN_anchor_generate(Gs, Zs, reals, NoiseAmp, opt,
                                  gen_start_scale=opt.gen_start_scale,
                                  anchor_image=anchor_image, direction=direction,
                                  transfer=transfer, noise_solutions=noise_solutions,
                                  factor=factor, base=base, insert_limit=insert_limit)
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid scale by scale while recording resource usage.

    Each scale's ``train_single_scale`` call is timed with ``datetime``. Its
    extra ``(mbs, percent)`` results are collected per scale and printed at
    the end, together with the elapsed times.

    ``Gs`` / ``Zs`` / ``NoiseAmp`` are extended in place. They, the pyramid and
    ``full_memory`` / ``full_time`` are checkpointed to ``opt.out_`` after
    every scale.

    NOTE(review): ``full_memory`` and ``full_time`` are not defined in this
    function; presumably module-level globals. Confirm, otherwise this raises
    NameError after the first scale.
    """
    image = functions.read_image(opt)
    in_s = 0
    stage = 0
    pyramid = functions.creat_reals_pyramid(imresize(image, opt.scale1, opt), reals, opt)
    width_before = 0
    memory_log = []  # [mbs, percent] per scale
    durations = []  # elapsed timedelta per scale

    while stage < opt.stop_scale + 1:
        mult = 2 ** (stage // 4)
        opt.nfc = min(opt.nfc_init * mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = f'{opt.out_}/{stage}'
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave(f'{opt.outf}/real_scale.png',
                   functions.convert_image_np(pyramid[stage]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if width_before == opt.nfc:
            G_curr.load_state_dict(torch.load(f'{opt.out_}/{stage - 1}/netG.pth'))
            D_curr.load_state_dict(torch.load(f'{opt.out_}/{stage - 1}/netD.pth'))

        began = datetime.datetime.now()
        z_curr, in_s, G_curr, mbs, percent = train_single_scale(
            D_curr, G_curr, pyramid, Gs, Zs, in_s, NoiseAmp, opt)
        memory_log.append([mbs, percent])
        took = datetime.datetime.now() - began
        durations.append(took)
        print(f'time: {took}')

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        for blob, tag in ((Zs, 'Zs'), (Gs, 'Gs'), (pyramid, 'reals'), (NoiseAmp, 'NoiseAmp'),
                          (full_memory, 'full_memory'), (full_time, 'full_time')):
            torch.save(blob, f'{opt.out_}/{tag}.pth')

        stage += 1
        width_before = opt.nfc
        del D_curr, G_curr

    print(memory_log)
    print(durations)
    print(full_memory)
    print(full_time)
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the pyramid scale by scale, then prune and halve each generator.

    Per scale:

    - train G/D with ``train_single_scale``;
    - freeze G and apply global L1 unstructured pruning (``amount=0.2``) over
      the weights and biases of every conv/norm layer;
    - make the pruning permanent with ``prune.remove``;
    - convert G to float16 with ``G_curr.half()``;
    - append to ``Gs`` / ``Zs`` / ``NoiseAmp`` and checkpoint. Generators are
      saved as ``pruned_Gs.pth``.

    Weight histograms before and after pruning are saved to ``opt.outf``.

    NOTE(review): unlike the other ``train`` variants, the pyramid is built
    from ``real_`` directly, without ``imresize(real_, opt.scale1, opt)``.
    NOTE(review): generators stored in ``Gs`` are float16, and later scales
    hand ``Gs`` to ``train_single_scale``. Confirm that function handles the
    dtype mismatch with float32 inputs.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level from coarsest to finest.
    cur_scale_level = 0
    # scale1: for the largest patch size, what ratio wrt the image shape
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    nfc_prev = 0
    # Train including opt.stop_scale
    while cur_scale_level < opt.stop_scale + 1:
        # nfc: number of out channels in conv block (doubles every 4 scales, capped at 128)
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(
            opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        # out_: output directory
        # outf: output folder, with scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[cur_scale_level]),
                   vmin=0,
                   vmax=1)
        D_curr, G_curr = init_models(opt)
        # As the level increases, the CNN block width may change (every 4 levels per the paper);
        # weights are only reused from the previous level when the width is unchanged.
        # NOTE(review): netG.pth / netD.pth are presumably written by train_single_scale.
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, cur_scale_level - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, cur_scale_level - 1)))
        # in_s: guess: initial signal? it doesn't change during the training, and is a zero tensor.
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs,
                                                  in_s, NoiseAmp, opt)
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        #################################################################################
        # Visualize weights: histogram of the non-zero weights and biases of `modules`,
        # saved to <opt.outf>/<fig_name>.png.
        # NOTE(review): hard-codes .cuda(), so this requires a CUDA device.
        def visualize_weights(modules, fig_name):
            ori_weights = torch.tensor([]).cuda()
            for m in modules:
                cur_params = m.weight.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
                cur_params = m.bias.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
            # sparsity = torch.sum(ori_weights == 0) * 1.0 / (ori_weights.nelement())
            ori_weights = ori_weights.cpu().numpy()
            ori_weights = plt.hist(ori_weights[ori_weights != 0], bins=100)
            plt.savefig("%s/%s.png" % (opt.outf, fig_name))
            plt.close()

        # Pruning all weights: every conv/norm layer of the generator, weight and bias.
        modules = [
            G_curr.head.conv, G_curr.head.norm, G_curr.body.block1.conv,
            G_curr.body.block1.norm, G_curr.body.block2.conv,
            G_curr.body.block2.norm, G_curr.body.block3.conv,
            G_curr.body.block3.norm, G_curr.tail[0]
        ]
        parameters_to_prune = ((G_curr.head.conv, 'weight'),
                               (G_curr.head.conv, 'bias'),
                               (G_curr.head.norm, 'weight'),
                               (G_curr.head.norm, 'bias'),
                               (G_curr.body.block1.conv, 'weight'),
                               (G_curr.body.block1.conv, 'bias'),
                               (G_curr.body.block1.norm, 'weight'),
                               (G_curr.body.block1.norm, 'bias'),
                               (G_curr.body.block2.conv, 'weight'),
                               (G_curr.body.block2.conv, 'bias'),
                               (G_curr.body.block2.norm, 'weight'),
                               (G_curr.body.block2.norm, 'bias'),
                               (G_curr.body.block3.conv, 'weight'),
                               (G_curr.body.block3.conv, 'bias'),
                               (G_curr.body.block3.norm, 'weight'),
                               (G_curr.body.block3.norm, 'bias'),
                               (G_curr.tail[0], 'weight'), (G_curr.tail[0], 'bias'))
        visualize_weights(modules, 'ori')
        # Prune weights: one global L1 threshold removes 20% of all listed entries.
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=0.2,
        )
        # Make the pruning permanent (drops the weight_orig/weight_mask reparametrization).
        for m in modules:
            prune.remove(m, 'weight')
            prune.remove(m, 'bias')
        visualize_weights(modules, 'prune')
        # Convert the pruned generator's parameters to float16 before storing it.
        G_curr.half()
        #################################################################################
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        cur_scale_level += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
        torch.cuda.empty_cache()
    return