def train(opt,Gs,Zs,reals1,reals2,NoiseAmp):
    """Train a SinGAN pyramid scale-by-scale on TWO aligned input images.

    Args:
        opt: options namespace (paths, scale counts, channel configs, ...).
        Gs: list collecting the trained generator of each scale.
        Zs: list collecting the optimal fixed noise map of each scale.
        reals1, reals2: lists to be filled with the two image pyramids.
        NoiseAmp: list collecting the noise amplitude used at each scale.

    Side effects: creates per-scale output folders, saves scale images and
    checkpoints (Zs/Gs/reals/NoiseAmp) under opt.out_.
    """
    # Read the pair of training images and build one pyramid per image.
    real1_, real2_ = functions.read_image(opt)
    in_s = 0  # initial input signal passed through train_single_scale
    scale_num = 0
    real1 = imresize(real1_,opt.scale1,opt)
    real2 = imresize(real2_, opt.scale1, opt)
    reals1 = functions.creat_reals_pyramid(real1,reals1,opt)
    reals2 = functions.creat_reals_pyramid(real2, reals2, opt)
    nfc_prev = 0  # channel width of the previous scale (0 = no previous)
    # Train scales 0 .. opt.stop_scale inclusive, coarsest to finest.
    while scale_num<opt.stop_scale+1:
        # Channel count doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        # out_: run root directory; outf: per-scale subdirectory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # directory already exists
        # Save the downscaled reals of this scale for inspection.
        plt.imsave('%s/real1_scale.png' % (opt.outf), functions.convert_image_np(reals1[scale_num]), vmin=0, vmax=1)
        plt.imsave('%s/real2_scale.png' % (opt.outf), functions.convert_image_np(reals2[scale_num]), vmin=0, vmax=1)
        D_curr,G_curr = init_models(opt)
        # If the architecture did not change (same nfc), warm-start from the
        # previous scale's weights.
        if (nfc_prev==opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,scale_num-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,scale_num-1)))
        z_curr,in_s,G_curr = train_single_scale(D_curr,G_curr,reals1,reals2,Gs,Zs,in_s,NoiseAmp,opt)
        # Freeze the trained networks before appending them to the pyramid.
        G_curr = functions.reset_grads(G_curr,False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr,False)
        D_curr.eval()
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        # Checkpoint the whole pyramid after every scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals1, '%s/reals1.pth' % (opt.out_))
        torch.save(reals2, '%s/reals2.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num+=1
        nfc_prev = opt.nfc
        del D_curr,G_curr  # free GPU memory before the next scale
    return
def train(opt,Gs,Zs,reals,NoiseAmp):
    """Train the SinGAN pyramid coarsest-to-finest on a single image.

    Fills Gs/Zs/NoiseAmp in place and checkpoints them (plus the real-image
    pyramid) under opt.out_ after each scale.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level from coarest to finest.
    cur_scale_level = 0
    # scale1: for the largest patch size, what ratio wrt the image shape
    reals = functions.creat_reals_pyramid(real_,reals,opt)
    nfc_prev = 0
    # Train including opt.stop_scale
    while cur_scale_level < opt.stop_scale+1:
        # nfc: number of out channels in conv block (doubles every 4 levels,
        # capped at 128).
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        # out_: output directory
        # outf: output folder, with scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # folder already exists
        #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)
        D_curr,G_curr = init_models(opt)
        # Notice, as the level increases, the architecture of CNN block might differ. (every 4 levels according to the paper)
        # If the width is unchanged, warm-start from the previous scale.
        if (nfc_prev==opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,cur_scale_level-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,cur_scale_level-1)))
        # in_s: guess: initial signal? it doesn't change during the training, and is a zero tensor.
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        # Freeze both networks; only inference is needed from here on.
        G_curr = functions.reset_grads(G_curr,False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr,False)
        D_curr.eval()
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        # Checkpoint the pyramid after every finished scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        cur_scale_level+=1
        nfc_prev = opt.nfc
        del D_curr,G_curr
        torch.cuda.empty_cache()  # release cached GPU memory between scales
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid (TensorFlow/Keras port) scale by scale.

    Uses shared Adam optimizers for G and D across scales; checkpoints
    Zs/reals/NoiseAmp as pickles under opt.out_.
    """
    real_ = functions.read_image(opt)
    in_s = 0  # initial input signal threaded through train_single_scale
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0  # conv width of the previous scale (0 = none yet)
    # One optimizer pair reused for every scale.
    netD_optimizer = tf.keras.optimizers.Adam(learning_rate=opt.lr_d, beta_1=opt.beta1, beta_2=0.999)
    netG_optimizer = tf.keras.optimizers.Adam(learning_rate=opt.lr_g, beta_1=opt.beta1, beta_2=0.999)
    while scale_num < opt.stop_scale + 1:
        # Conv width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        # out_: run root directory; outf: per-scale subdirectory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # already exists
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
        D_curr, G_curr = init_models(opt)
        # Same architecture as previous scale -> warm-start the weights.
        if nfc_prev == opt.nfc:
            D_curr.load_weights('%s/%d/netD' % (opt.out_, scale_num - 1))
            G_curr.load_weights('%s/%d/netG' % (opt.out_, scale_num - 1))
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, scale_num, netG_optimizer, netD_optimizer)
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        # Checkpoint pyramid state (Keras models themselves are saved via
        # save_weights inside train_single_scale; only tensors pickled here).
        with open('%s/Zs.pkl' % (opt.out_), 'wb') as f:
            pickle.dump(Zs, f)
        with open('%s/reals.pkl' % (opt.out_), 'wb') as f:
            pickle.dump(reals, f)
        with open('%s/NoiseAmp.pkl' % (opt.out_), 'wb') as f:
            pickle.dump(NoiseAmp, f)
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return None
def test_pyramid(images):
    """Build a real-image pyramid for each image in *images*.

    Args:
        images: sequence of numpy images; the first one is used to calibrate
            the scale settings (adjust_scales2image) for all of them.

    Returns:
        The per-image pyramids collected into a transposed numpy array.
    """
    parser = get_arguments()
    parser.add_argument('--input_dir', help='input image dir', default='Input/Images')
    #parser.add_argument('--input_name', help='input image name', required=True)
    parser.add_argument('--mode', help='task to be done', default='train')
    # Parse with an empty argv so defaults are used (no CLI involvement).
    opt = parser.parse_args("")
    opt.input_name = 'blank'
    opt = functions.post_config(opt)
    # Calibrate scale factors from the first image only — assumes all images
    # share the same dimensions; TODO confirm with callers.
    real = functions.np2torch(images[0], opt)
    functions.adjust_scales2image(real, opt)
    all_reals = []
    for image in images:
        reals = []
        real_ = functions.np2torch(image, opt)
        real = imresize(real_, opt.scale1, opt)
        reals = functions.creat_reals_pyramid(real, reals, opt)
        all_reals.append(reals)
    # NOTE(review): np.array over lists of torch tensors yields an object
    # array; .T transposes it to (scale, image) order — verify this is the
    # intended layout.
    return np.array(all_reals).T
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid coarsest-to-finest on a single image.

    Args:
        opt: options namespace (scales, channel widths, output dirs, ...).
        Gs: list collecting the trained (frozen) generator of each scale.
        Zs: list collecting the optimal fixed noise map of each scale.
        reals: list to be filled with the real-image pyramid.
        NoiseAmp: list collecting the noise amplitude per scale.

    Side effects: creates per-scale output folders and checkpoints
    Zs/Gs/reals/NoiseAmp under opt.out_ after every scale.
    """
    real_ = functions.read_image(opt)
    in_s = 0  # initial input signal threaded through train_single_scale
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    # Pyramid of the real image at every training scale.
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0  # conv width of the previous scale (0 = none yet)
    # opt.stop_scale is the index of the finest scale; loop is inclusive.
    while scale_num < opt.stop_scale + 1:
        # Conv width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        if opt.fast_training:
            # Halve the iteration budget every 4th scale to speed training.
            if scale_num > 0 and scale_num % 4 == 0:
                opt.niter = opt.niter // 2
        # out_: run root directory; outf: per-scale subdirectory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # already exists
        # Save the original image once per run and the downscaled real for
        # this scale, for visual inspection.
        plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
        # Fresh discriminator/generator for this scale.
        D_curr, G_curr = init_models(opt)
        # If the architecture is unchanged, warm-start from the previous
        # scale's weights (fine-tuning instead of training from scratch).
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
        # train_single_scale() returns: z_opt, in_s, netG
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)
        # Freeze both networks; only inference is needed from here on.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        # Checkpoint the whole pyramid after every scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        # Drop the per-scale networks to keep GPU memory usage low.
        del D_curr, G_curr
    return
def test_generate(model_name, anchor_image=None, direction=None, transfer=None, noise_solutions=None, factor=0.25, base=None, insert_limit=4):
    """Generate samples from a trained SinGAN pyramid with anchoring options.

    Args:
        model_name: name of the trained model to load (set as opt.input_name).
        anchor_image: optional image to anchor generation to.
        direction: anchoring side — 'L, R, T, B'.
        transfer, noise_solutions, factor, base, insert_limit: forwarded to
            SinGAN_anchor_generate; semantics defined there.

    Returns:
        The array produced by SinGAN_anchor_generate.
    """
    #direction = 'L, R, T, B'
    parser = get_arguments()
    parser.add_argument('--input_dir', help='input image dir', default='Input/Images')
    parser.add_argument('--mode', help='random_samples | random_samples_arbitrary_sizes', default='random_samples')
    # for random_samples:
    parser.add_argument('--gen_start_scale', type=int, help='generation start scale', default=0)
    # Parse with empty argv so only the defaults above are used.
    opt = parser.parse_args("")
    opt.input_name = model_name
    opt = functions.post_config(opt)
    Gs = []
    Zs = []
    reals = []
    NoiseAmp = []
    opt.input_name = 'island_basis_0.jpg' #grabbing image that exists...
    real = functions.read_image(opt)
    #opt.input_name = anchor #CHANGE TO ANCHOR HERE
    #anchor = functions.read_image(opt)
    # Calibrate scale settings from the stand-in image's dimensions.
    functions.adjust_scales2image(real, opt)
    opt.input_name = 'test1.jpg' #grabbing model that we want
    Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
    #dummy stuff for dimensions
    # Rebuild the real-image pyramid only to establish per-scale shapes.
    reals = []
    real_ = real
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    in_s = functions.generate_in2coarsest(reals, 1, 1, opt)
    array = SinGAN_anchor_generate(Gs, Zs, reals, NoiseAmp, opt, gen_start_scale=opt.gen_start_scale, anchor_image=anchor_image, direction=direction, transfer=transfer, noise_solutions=noise_solutions, factor=factor, base=base, insert_limit=insert_limit)
    return array
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid and globally prune each generator's weights.

    After each scale is trained, 20% of the generator's weights (by global L1
    magnitude) are pruned, the model is converted to half precision, and the
    pruned pyramid is checkpointed as 'pruned_Gs.pth'.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level from coarest to finest.
    cur_scale_level = 0
    # scale1: for the largest patch size, what ratio wrt the image shape
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    nfc_prev = 0
    # Train including opt.stop_scale
    while cur_scale_level < opt.stop_scale + 1:
        # nfc: number of out channels in conv block
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(
            opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        # out_: output directory
        # outf: output folder, with scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # already exists
        #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[cur_scale_level]),
                   vmin=0,
                   vmax=1)
        D_curr, G_curr = init_models(opt)
        # Notice, as the level increases, the architecture of CNN block might differ. (every 4 levels according to the paper)
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, cur_scale_level - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, cur_scale_level - 1)))
        # in_s: guess: initial signal? it doesn't change during the training, and is a zero tensor.
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        #################################################################################
        # Visualzie weights
        def visualize_weights(modules, fig_name):
            # Histogram of all non-zero weights/biases of *modules*, saved
            # under the current scale's output folder.
            ori_weights = torch.tensor([]).cuda()
            for m in modules:
                cur_params = m.weight.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
                cur_params = m.bias.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
            # sparsity = torch.sum(ori_weights == 0) * 1.0 / (ori_weights.nelement())
            ori_weights = ori_weights.cpu().numpy()
            ori_weights = plt.hist(ori_weights[ori_weights != 0], bins=100)
            plt.savefig("%s/%s.png" % (opt.outf, fig_name))
            plt.close()

        # Pruning all weights
        # All prunable layers of the generator (head, 3 body blocks, tail).
        modules = [
            G_curr.head.conv, G_curr.head.norm, G_curr.body.block1.conv,
            G_curr.body.block1.norm, G_curr.body.block2.conv,
            G_curr.body.block2.norm, G_curr.body.block3.conv,
            G_curr.body.block3.norm, G_curr.tail[0]
        ]
        parameters_to_prune = ((G_curr.head.conv, 'weight'),
                               (G_curr.head.conv, 'bias'),
                               (G_curr.head.norm, 'weight'),
                               (G_curr.head.norm, 'bias'),
                               (G_curr.body.block1.conv, 'weight'),
                               (G_curr.body.block1.conv, 'bias'),
                               (G_curr.body.block1.norm, 'weight'),
                               (G_curr.body.block1.norm, 'bias'),
                               (G_curr.body.block2.conv, 'weight'),
                               (G_curr.body.block2.conv, 'bias'),
                               (G_curr.body.block2.norm, 'weight'),
                               (G_curr.body.block2.norm, 'bias'),
                               (G_curr.body.block3.conv, 'weight'),
                               (G_curr.body.block3.conv, 'bias'),
                               (G_curr.body.block3.norm, 'weight'),
                               (G_curr.body.block3.norm, 'bias'),
                               (G_curr.tail[0], 'weight'),
                               (G_curr.tail[0], 'bias'))
        visualize_weights(modules, 'ori')
        # Prune weights: zero out the globally-smallest 20% by L1 magnitude.
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=0.2,
        )
        # Make the pruning permanent (remove reparametrization hooks).
        for m in modules:
            prune.remove(m, 'weight')
            prune.remove(m, 'bias')
        visualize_weights(modules, 'prune')
        # Convert the pruned generator to fp16 to shrink the checkpoint.
        G_curr.half()
        #################################################################################
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        cur_scale_level += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
        torch.cuda.empty_cache()
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid coarsest-to-finest on a single image.

    Args:
        opt: options namespace (scales, channel widths, output dirs, ...).
        Gs: list collecting the trained (frozen) generator of each scale.
        Zs: list collecting the optimal fixed noise map of each scale.
        reals: list to be filled with the real-image pyramid.
        NoiseAmp: list collecting the noise amplitude per scale.

    Side effects: creates per-scale output folders and checkpoints
    Zs/Gs/reals/NoiseAmp under opt.out_ after every scale.
    """
    real_ = functions.read_image(opt)
    in_s = 0  # initial input signal threaded through train_single_scale
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0  # conv width of the previous scale (0 = none yet)
    while scale_num < opt.stop_scale + 1:
        # Conv width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        if opt.fast_training:
            # Halve the iteration budget every 4th scale to speed training.
            if scale_num > 0 and scale_num % 4 == 0:
                opt.niter = opt.niter // 2
        # out_: run root directory; outf: per-scale subdirectory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # already exists
        # Save this scale's downscaled real for visual inspection.
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
        D_curr, G_curr = init_models(opt)
        # Same architecture as previous scale -> warm-start its weights.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)
        # Freeze both networks; only inference is needed from here on.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        # Checkpoint the whole pyramid after every scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr  # free GPU memory before the next scale
    return
def get_reals(reals, opt, image_name):
    """Return the multi-scale pyramid of real images for *image_name*.

    Loads the image, resizes it to the largest working scale (opt.scale1),
    and expands it into the pyramid used by training.
    """
    loaded = functions.read_image(opt, image_name)
    resized = imresize(loaded, opt.scale1, opt)
    return functions.creat_reals_pyramid(resized, reals, opt)
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid while collecting per-iteration loss curves.

    The errD/errG/D_real/D_fake/z_opt lists are shared across all scales and
    plotted once at the end via functions.my_plot.
    """
    real_ = functions.read_image(opt)
    in_s = 0  # initial input signal threaded through train_single_scale
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    print('scale_num:', len(reals))
    for _reals in reals:
        print('image_size:', _reals.size())
    nfc_prev = 0  # conv width of the previous scale (0 = none yet)
    # Accumulators for loss curves, appended to inside train_single_scale.
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []
    while scale_num < opt.stop_scale + 1:
        # Conv width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        # out_: run root directory; outf: per-scale subdirectory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # already exists
        #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
        D_curr, G_curr = init_models(opt)
        # Same architecture as previous scale -> warm-start its weights.
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp,
                                                  errD2plot, errG2plot,
                                                  D_real2plot, D_fake2plot,
                                                  z_opt2plot, opt)
        # Freeze both networks; only inference is needed from here on.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        # Checkpoint the whole pyramid after every scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    # Plot the accumulated loss curves for the whole run.
    functions.my_plot(errD2plot, errG2plot, z_opt2plot, opt)
    return
def SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1, scale_h=1, n=0, gen_start_scale=0, num_samples=50):
    """Generate random samples by pushing noise up the trained pyramid.

    Args:
        Gs, Zs, NoiseAmp: per-scale generators, fixed noise maps, amplitudes.
        reals: real-image pyramid (rebuilt here from the current opt image).
        opt: options namespace; opt.skip may name a scale whose generator is
            bypassed (its input is passed through unchanged).
        in_s: starting signal for the coarsest scale; zeros if None.
        scale_v, scale_h: vertical/horizontal output size multipliers.
        n: starting scale index.
        gen_start_scale: scales below this use the fixed Z_opt instead of
            fresh noise.
        num_samples: number of samples generated per scale.

    Returns:
        The last generated sample at the finest scale (detached tensor).
    """
    # Rebuild the pyramid from the image currently named in opt, resized to
    # match the shapes of the trained pyramid scale-by-scale.
    real = functions.read_image(opt)
    real = real.numpy()
    real = resize(real, reals[-1].shape)
    real = torch.from_numpy(real)
    new_reals = creat_reals_pyramid(real, [], opt)
    buffer = []
    for new_real, real in zip(new_reals, reals):
        ele = new_real.numpy()
        ele = resize(ele, real.shape)
        ele = torch.from_numpy(ele)
        buffer.append(ele)
    reals = buffer
    # Dump each scale's real image for reference.
    # NOTE(review): dir2save is used before the makedirs below — assumes the
    # output directory already exists at this point; verify.
    for i, real_img in enumerate(reals):
        dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (
            opt.out, opt.input_name[:-4], gen_start_scale)
        plt.imsave('%s/%s_%d.png' % (dir2save, "real", i),
                   functions.convert_image_np(real_img.detach()),
                   vmin=0,
                   vmax=1)
    if in_s is None:
        # Zero starting signal at the coarsest scale.
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    images_cur = []
    # Walk the pyramid coarsest-to-finest, one generator per scale.
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        # Padding added around inputs to compensate for the conv stack's
        # receptive field.
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        # Noise spatial extent at this scale (scaled by the size multipliers).
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h
        # For Section IV
        # if n == 0:
        #     images_prev = images_cur
        # else:
        #     new_img_prev = []
        #     for img in images_cur:
        #         ele = reals[n].numpy()
        #         ele = resize(ele, img.shape)
        #         ele = torch.from_numpy(ele)
        #         new_img_prev.append(ele)
        #     images_prev = new_img_prev
        images_prev = images_cur
        # if n != 0:
        #     dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
        #     plt.imsave('%s/%s_%d.png' % (dir2save, "img_cur", n), functions.convert_image_np(images_prev[0].detach()), vmin=0,vmax=1)
        #     plt.imsave('%s/%s_%d.png' % (dir2save, "img_prev", n), functions.convert_image_np(images_cur[0].detach()), vmin=0,vmax=1)
        images_cur = []
        for i in range(0, num_samples, 1):
            # Draw this sample's noise; the coarsest scale uses a single
            # channel expanded to 3 so all color channels share the noise.
            if n == 0:
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
                z_curr = m(z_curr)
            # Previous-scale output (or the zero signal at the first scale),
            # upsampled and cropped to the current working size.
            if images_prev == []:
                I_prev = m(in_s)
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2],
                                    0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)
            # Below the requested start scale, use the memorized noise map so
            # the reconstruction is deterministic.
            if n < gen_start_scale:
                z_curr = Z_opt
            z_in = noise_amp * (z_curr) + I_prev
            # opt.skip may name one scale whose generator is bypassed.
            if opt.skip != '' and int(opt.skip) == n:
                I_curr = I_prev
            else:
                I_curr = G(z_in.detach(), I_prev)
            # At the finest scale, write the finished sample to disk.
            if n == len(reals) - 1:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (
                        opt.out, opt.input_name[:-4], gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                try:
                    os.makedirs(dir2save)
                except OSError:
                    pass  # already exists
                if (opt.mode != "harmonization") & (opt.mode != "editing") & (
                        opt.mode != "SR") & (opt.mode != "paint2image"):
                    plt.imsave('%s/%d.png' % (dir2save, i),
                               functions.convert_image_np(I_curr.detach()),
                               vmin=0,
                               vmax=1)
                # plt.imsave('%s/%d_%d.png' % (dir2save, i, n), functions.convert_image_np(I_curr.detach()), vmin=0,vmax=1)
                # For Section VI
                # if opt.mode == 'train':
                #     dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
                # else:
                #     dir2save = functions.generate_dir2save(opt)
                # try:
                #     os.makedirs(dir2save)
                # except OSError:
                #     pass
                # if (opt.mode != "harmonization") & (opt.mode != "editing") & (opt.mode != "SR") & (opt.mode != "paint2image"):
                #     plt.imsave('%s/%d_%d.png' % (dir2save, i, n), functions.convert_image_np(I_curr.detach()), vmin=0,vmax=1)
            images_cur.append(I_curr)
        n += 1
    return I_curr.detach()
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid, with optional inpainting-mask pyramid.

    In "inpainting" mode a binary mask image is loaded, inverted into
    [0, 1] (0 = masked-out region), resized like the real image, and its
    pyramid stored on opt.masks for use inside train_single_scale.
    """
    real_ = functions.read_image(opt)
    in_s = 0  # initial input signal threaded through train_single_scale
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0  # conv width of the previous scale (0 = none yet)
    #If training for inpainting
    if opt.mode == "inpainting":
        #Importing mask image in space [0,255]
        if opt.on_drive != None:
            mask = img.imread('%s/%s/%s' %
                              (opt.on_drive, opt.input_dir, opt.mask_name))
        else:
            mask = img.imread('%s/%s' % (opt.input_dir, opt.mask_name))
        #Convert mask to [O,1] space, 0 is masked out area, 1 everywhere else
        mask = 1 - (mask / 255)
        #Loading mask to torch tensor
        mask = torch.from_numpy(mask)
        #Resizing the initial mask
        # NOTE(review): .view() reinterprets (H, W, C, 1) memory as
        # (1, 3, H, W) without permuting axes — looks like it should be a
        # permute; confirm mask channel layout before relying on this.
        mask = mask[:, :, :, None].view([1, 3, mask.shape[0], mask.shape[1]])
        mask = imresize(mask, opt.scale1, opt)
        #Creating mask pyramid
        opt.masks = functions.creat_reals_pyramid(mask, [], opt)
    while scale_num < opt.stop_scale + 1:
        # Conv width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        # out_: run root directory; outf: per-scale subdirectory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # already exists
        #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        #plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
        D_curr, G_curr = init_models(opt)
        # Same architecture as previous scale -> warm-start its weights.
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)
        # Freeze both networks; only inference is needed from here on.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        # Checkpoint the whole pyramid after every scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train SinGAN scales independently on upsample/difference pyramids.

    Besides the real pyramid, builds an `upsamples` pyramid (each level's
    image upscaled to the next level's size) and a `diffs` pyramid
    (|next - upsampled| - 1, mapped to [-1, 1]); both are saved per scale.
    Scales are trained in parallel (no warm-start between scales).
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level from coarest to finest.
    cur_scale_level = 0
    # real = imresize(real_,opt.scale1,opt)
    # scale1: for the largest patch size, what ratio wrt the image shape
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    upsamples = [reals[0]]
    diffs = [reals[0]]
    # Need to generate opt.stop_scale, thus upsample opt.stop_scale-1
    for i in range(opt.stop_scale):
        cur_img = reals[i]
        next_img = reals[i + 1]
        _, b, c, d = next_img.shape
        # Upscale this level to the next level's spatial size.
        upsampled_real = imresize_to_shape(cur_img, (c, d, b), opt)
        upsamples.append(upsampled_real)
        diff = (next_img - upsampled_real).abs() - 1  # [-1, 1]
        diffs.append(diff)
    # Train including opt.stop_scale
    while cur_scale_level < opt.stop_scale + 1:
        # nfc: number of out channels in conv block
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)),
                      128)
        opt.min_nfc = min(
            opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        # out_: output directory
        # outf: output folder, with scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # already exists
        # Save this scale's real, diff, and upsampled images for inspection.
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[cur_scale_level]),
                   vmin=0,
                   vmax=1)
        plt.imsave('%s/diff.png' % (opt.outf),
                   functions.convert_image_np(diffs[cur_scale_level].detach()),
                   vmin=0,
                   vmax=1)
        plt.imsave('%s/upsampled.png' % (opt.outf),
                   functions.convert_image_np(
                       upsamples[cur_scale_level].detach()),
                   vmin=0,
                   vmax=1)
        D_curr, G_curr = init_models(opt)
        # Notice, as the level increases, the architecture of CNN block might differ. (every 4 levels according to the paper)
        # No need to reload, since training in parallel
        # if (nfc_prev==opt.nfc):
        #     G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,cur_scale_level-1)))
        #     D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,cur_scale_level-1)))
        # in_s: guess: initial signal? it doesn't change during the training, and is a zero tensor.
        # z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals,
                                                  upsamples, cur_scale_level,
                                                  in_s, NoiseAmp, opt)
        # Freeze both networks; only inference is needed from here on.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        # Checkpoint the pyramid after every finished scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        cur_scale_level += 1
        # nfc_prev = opt.nfc
        del D_curr, G_curr
        torch.cuda.empty_cache()
    return
def train(opt,Gs,Zs,reals,NoiseAmp):
    """Train the SinGAN pyramid coarsest-to-finest on a single image.

    Args:
        opt: options namespace (scales, channel widths, output dirs, ...).
        Gs: list collecting the trained (frozen) generator of each scale.
        Zs: list collecting the optimal fixed noise map of each scale.
        reals: list to be filled with the real-image pyramid.
        NoiseAmp: list collecting the noise amplitude per scale.

    Side effects: creates per-scale output folders and checkpoints
    Zs/Gs/reals/NoiseAmp under opt.out_ after every scale.
    """
    real_ = functions.read_image(opt)
    in_s = 0  # initial input signal threaded through train_single_scale
    scale_num = 0
    real = imresize(real_,opt.scale1,opt)
    # Pyramid of the real image at every training scale, coarsest first.
    reals = functions.creat_reals_pyramid(real,reals,opt)
    nfc_prev = 0  # conv width of the previous scale (0 = none yet)
    while scale_num<opt.stop_scale+1:
        # Conv width doubles every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        if opt.fast_training:
            # Halve the iteration budget every 4th scale to speed training.
            if scale_num > 0 and scale_num % 4 == 0:
                opt.niter = opt.niter//2
        # out_: run root directory; outf: per-scale subdirectory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # already exists
        # Save this scale's downscaled real for visual inspection.
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
        # Create D and G; fully convolutional, so input size is irrelevant.
        D_curr,G_curr = init_models(opt)
        # If both scales use the same layer widths, fine-tune the previous
        # scale's parameters instead of starting from scratch.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,scale_num-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,scale_num-1)))
        # Train this scale's networks; each scale's D/G is discarded after
        # training, keeping GPU memory usage low throughout the run.
        z_curr,in_s,G_curr = train_single_scale(D_curr,G_curr,reals,Gs,Zs,in_s,NoiseAmp,opt)
        # Freeze both networks; only inference is needed from here on.
        G_curr = functions.reset_grads(G_curr,False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr,False)
        D_curr.eval()
        # Keep the trained generator for sampling.
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        # Checkpoint the whole pyramid after every scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num+=1
        nfc_prev = opt.nfc
        # Drop the per-scale networks to release GPU memory.
        del D_curr,G_curr
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the SinGAN pyramid scale-by-scale with a tqdm progress bar."""
    #from the name get the picture
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    #scale1 is defined from adjust2scale, saved in opt
    real = imresize(real_, opt.scale1, opt)
    # a list of resized images, coarsest first
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0
    #for scale 0 to stop scale (inclusive), with a notebook progress bar
    for scale_num in tqdm_notebook(range(opt.stop_scale + 1),
                                   desc=opt.input_name,
                                   leave=True):
        #define the number of channels in this scale (doubles every 4 scales,
        #capped at 128)
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        #define the minimum number of channels in this scale
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
                          128)
        #the output main directory
        opt.out_ = functions.generate_dir2save(opt)
        #the output sub directory for each scale
        opt.outf = f'{opt.out_}/{scale_num}'
        #if need create the directory
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # already exists
        #save the resized original image for this scale
        plt.imsave(f'{opt.outf}/real_scale.png',
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)
        #create the generator and discriminator
        D_curr, G_curr = init_models(opt)
        #if the number of channel of previous layer = current nfc
        if (nfc_prev == opt.nfc):
            #direct load the weightfrom last model (warm start)
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
        #train a single scale, get the current z, in_s, generator
        z_curr, in_s, G_curr = train_single_scale(
            D_curr,  #current discriminator
            G_curr,  #current generator
            reals,  #the list of all resized data
            Gs,  #generator list
            Zs,  # a list initialized as [] (fixed noise maps)
            in_s,  # initial input signal threaded through the scales
            NoiseAmp,  # per-scale noise amplitudes
            opt  #parameters
        )
        #make current G and D untrainable,set it into eval mode
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()
        # save them into the list
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        #save the checkpoints
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        nfc_prev = opt.nfc
        #delete the D and G for memory
        del D_curr, G_curr
    return
def train(opt,Gs,Zs,reals,NoiseAmp):
    """Train the SinGAN pyramid scale by scale, pruning each generator after training.

    NOTE(review): relies on module-level names not visible here — `fine_tune`,
    `structured`, `warmup_steps`, `pruning_amount`, `prune`
    (presumably torch.nn.utils.prune) and `prune_SinGAN_generate` — confirm
    they are defined at file scope.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level from coarest to finest.
    cur_scale_level = 0
    # scale1: for the largest patch size, what ratio wrt the image shape
    reals = functions.creat_reals_pyramid(real_,reals,opt)
    nfc_prev = 0
    # Train including opt.stop_scale
    while cur_scale_level < opt.stop_scale+1:
        # nfc: number of out channels in conv block
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        # out_: output directory
        # outf: output folder, with scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)
        D_curr,G_curr = init_models(opt)
        # Notice, as the level increases, the architecture of CNN block might differ. (every 4 levels according to the paper)
        if (nfc_prev==opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,cur_scale_level-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,cur_scale_level-1)))
        # in_s: guess: initial signal? it doesn't change during the training, and is a zero tensor.
        # When fine-tuning later, only warm up for `warmup_steps` iterations here;
        # otherwise run the full opt.niter schedule.
        if fine_tune:
            z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, warmup_steps)
        else:
            z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter)
        G_curr = functions.reset_grads(G_curr,False)
        # D_curr = functions.reset_grads(D_curr,False)
        G_curr.eval()
        # D_curr.eval()
        #################################################################################
        # Visualzie weights
        def visualize_weights(modules, fig_name):
            # Histogram all non-zero weights of `modules` and save to outf/fig_name.png;
            # also prints the current sparsity ratio.
            ori_weights = torch.tensor([]).cuda()
            for m in modules:
                cur_params = m.weight.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
                # cur_params = m.bias.data.flatten()
                # ori_weights = torch.cat((ori_weights, cur_params))
            sparsity = torch.sum(ori_weights == 0) * 1.0 / (ori_weights.nelement())
            print(sparsity, ori_weights.nelement())
            ori_weights = ori_weights.cpu().numpy()
            ori_weights = plt.hist(ori_weights[ori_weights != 0], bins=100)
            plt.savefig("%s/%s.png" % (opt.outf, fig_name))
            plt.close()
        # Pruning weights Structured or Non-structured
        if not structured:
            # Unstructured: prune weights AND biases of every conv/norm layer globally by L1.
            modules = [G_curr.head.conv, G_curr.head.norm,
                       G_curr.body.block1.conv, G_curr.body.block1.norm,
                       G_curr.body.block2.conv, G_curr.body.block2.norm,
                       G_curr.body.block3.conv, G_curr.body.block3.norm,
                       G_curr.tail[0]]
            parameters_to_prune = (
                (G_curr.head.conv, 'weight'),
                (G_curr.head.norm, 'weight'),
                (G_curr.body.block1.conv, 'weight'),
                (G_curr.body.block1.norm, 'weight'),
                (G_curr.body.block2.conv, 'weight'),
                (G_curr.body.block2.norm, 'weight'),
                (G_curr.body.block3.conv, 'weight'),
                (G_curr.body.block3.norm, 'weight'),
                (G_curr.tail[0], 'weight'),
                (G_curr.head.conv, 'bias'),
                (G_curr.head.norm, 'bias'),
                (G_curr.body.block1.conv, 'bias'),
                (G_curr.body.block1.norm, 'bias'),
                (G_curr.body.block2.conv, 'bias'),
                (G_curr.body.block2.norm, 'bias'),
                (G_curr.body.block3.conv, 'bias'),
                (G_curr.body.block3.norm, 'bias'),
                (G_curr.tail[0], 'bias'),
            )
            visualize_weights(modules, 'ori')
            # Prune weights
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=pruning_amount,
            )
        else:
            # Structured: prune whole output channels (L1 norm over dim 0) of conv layers only.
            modules = [G_curr.head.conv, G_curr.body.block1.conv, G_curr.body.block2.conv, G_curr.body.block3.conv]
            visualize_weights(modules, 'ori')
            # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
            # print(pytorch_total_params)
            for module in modules:
                m = prune.ln_structured(module, name="weight", amount=pruning_amount, n=1, dim=0)
                # m = prune.ln_structured(module, name="bias", amount=pruning_amount, n=1, dim=0)
        # Snapshot the just-pruned (not yet fine-tuned) generator.
        torch.save(G_curr.state_dict(), '%s/raw_prune_netG.pth' % (opt.outf))
        visualize_weights(modules, 'raw-prune')
        # Generate a sample with the pruned generator appended to the frozen pyramid
        # (copies so the real lists are untouched).
        if cur_scale_level > 0:
            fake_Gs = Gs.copy()
            fake_Gs.append(G_curr)
            fake_Zs = Zs.copy()
            fake_Zs.append(z_curr)
            fake_noise = NoiseAmp.copy()
            fake_noise.append(opt.noise_amp)
            fake_reals = reals[:cur_scale_level+1].copy()
            prune_SinGAN_generate(fake_Gs, fake_Zs, fake_reals, fake_noise, opt, gen_start_scale=0, num_samples=1, level=cur_scale_level)
        # Fine-tuning
        if fine_tune:
            # Re-enable gradients and continue training the pruned generator.
            G_curr = functions.reset_grads(G_curr, True)
            G_curr.train()
            if not structured:
                # Keep training using inherited weights
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter - warmup_steps, prune=True)
            else:
                # Training from scratch
                # G_curr.apply(models.weights_init)
                # D_curr.apply(models.weights_init)
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter, prune=True)
            G_curr = functions.reset_grads(G_curr,False)
            G_curr.eval()
            visualize_weights(modules, 'fine-tune')
        # Make the pruning permanent (remove reparametrization masks) so the
        # saved state_dict holds plain tensors.
        for m in modules:
            prune.remove(m, 'weight')
            if not structured:
                prune.remove(m, 'bias')
        # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
        # print(pytorch_total_params)
        #################################################################################
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        cur_scale_level+=1
        nfc_prev = opt.nfc
        del D_curr,G_curr
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Coarse-to-fine SinGAN training with an optional fast-training schedule.

    With opt.fast_training, the per-scale iteration count opt.niter is halved
    every 4 scales (finer scales train for fewer iterations).

    Args:
        opt: options namespace.
        Gs, Zs, NoiseAmp: output lists collecting the per-scale generator,
            fixed reconstruction noise, and noise amplitude.
        reals: list that will hold the image pyramid (coarsest first).
    """
    real_ = functions.read_image(opt)
    in_s = 0  # initial signal for the coarsest scale
    scale_num = 0  # iterator through the pyramid
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0
    while scale_num < opt.stop_scale + 1:
        # Every 4 levels in the pyramid, double the filter count, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        if opt.fast_training:
            # Fix: use boolean `and` instead of bitwise `&` for the condition.
            if (scale_num > 0) and (scale_num % 4 == 0):
                # Every 4 scales, halve the iteration count (train less on fine details).
                opt.niter = opt.niter // 2
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # directory already exists
        # Save this scale's real image for inspection.
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
        D_curr, G_curr = init_models(opt)
        # If the channel count matches the previous scale, warm-start from its weights.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs,
                                                  in_s, NoiseAmp, opt)
        # Freeze weights and switch to eval mode before storing the nets.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()
        Gs.append(G_curr)    # frozen generator of this scale
        Zs.append(z_curr)    # fixed reconstruction noise of this scale
        NoiseAmp.append(opt.noise_amp)  # amplitude set inside train_single_scale
        # Checkpoint the pyramid state.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr  # release before the next scale
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Coarse-to-fine SinGAN training; also builds a mask pyramid for inpainting.

    When opt.inpainting is set, a pyramid of masks (matching the image pyramid)
    is stored on opt.m_s so every scale can train on the valid pixels only.
    """
    real_ = functions.read_image(opt)
    in_s = 0  # initial signal for the coarsest scale
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0
    if opt.inpainting:
        # Build the mask pyramid exactly the same way as the image pyramid.
        mask = functions.read_image_dir(
            '%s/%s_mask%s' % (opt.ref_dir, opt.input_name[:-4], opt.input_name[-4:]),
            opt)
        mask = imresize(mask, opt.scale1, opt)
        opt.m_s = functions.creat_reals_pyramid(mask, [], opt)  # pyramid of masks
    for level in range(opt.stop_scale + 1):
        # Channel widths double every 4 levels, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(level / 4)), 128)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # already exists
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]), vmin=0, vmax=1)
        D_curr, G_curr = init_models(opt)
        # Same width as the previous scale -> warm-start from its checkpoints.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs,
                                                  in_s, NoiseAmp, opt)
        # Freeze and switch both networks to eval mode.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        # Checkpoint the pyramid state after every scale.
        for tag, payload in (('Zs', Zs), ('Gs', Gs),
                             ('reals', reals), ('NoiseAmp', NoiseAmp)):
            torch.save(payload, '%s/%s.pth' % (opt.out_, tag))
        nfc_prev = opt.nfc
        del D_curr, G_curr  # free memory before the next level
    return
def SinGAN_anchor_generate(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1, scale_h=1,
                           n=0, gen_start_scale=0, num_samples=1, anchor_image=None,
                           direction=None, transfer=None, noise_solutions=None,
                           factor=None, base=None, insert_limit=0):
    """Generate an image through the trained pyramid, optionally steered by an
    anchor image, a direction map, a base image, and/or precomputed noise.

    Returns the numpy array of the final-scale output (`transfer` is accepted
    but unused in this body).
    NOTE(review): control-flow nesting reconstructed from flattened source —
    confirm against the original file.
    """
    #### Loading in Anchor if Needed #####
    anchor = anchor_image
    if anchor is not None:
        anchors = []
        anchor = functions.np2torch(anchor_image, opt)
        anchor_ = imresize(anchor, opt.scale1, opt)
        anchors = functions.creat_reals_pyramid(anchor_, anchors, opt)  # high key hacky code
    if direction is not None:
        directions = []
        direction = functions.np2torch(direction, opt)
        direction_ = imresize(direction, opt.scale1, opt)
        directions = functions.creat_reals_pyramid(direction_, directions, opt)  # high key hacky code
    if base is not None:
        bases = []
        base = functions.np2torch(base, opt)
        base_ = imresize(base, opt.scale1, opt)
        bases = functions.creat_reals_pyramid(base_, bases, opt)  # high key hacky code
    #### MY CODE ####
    # if torch.is_tensor(in_s) == False:
    if in_s is None:
        # Zero tensor at the coarsest-scale resolution seeds the pyramid.
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        # Padding so the generator's receptive field covers the border.
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        # Noise spatial size at this scale (scaled by the requested v/h ratios).
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h
        images_prev = images_cur
        images_cur = []
        for i in range(0, num_samples, 1):
            if n == 0:  # COARSEST SCALE: single-channel noise expanded to 3 channels
                z_curr = functions.generate_noise([1, nzx, nzy], device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)
                z_curr = m(z_curr)
            z_orig = z_curr  # keep the raw sample; blended back in when noise_solutions is used
            if images_prev == []:  # FIRST GENERATION IN COARSEST SCALE
                I_prev = m(in_s)
            else:  # NOT FIRST GENERATION, BUT AT COARSEST SCALE
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)  # upscale
                # print(n)
                if opt.mode != "SR":
                    # Crop/pad/resample so the upscaled image matches the padded noise.
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(
                        I_prev, z_curr.shape[2], z_curr.shape[3])  # make it fit padded noise
                else:
                    # prev_before = I_prev #MY ADDITION
                    I_prev = m(I_prev)
            if n < gen_start_scale:  # anything less than final
                z_curr = Z_opt  # Z_opt comes from trained pyramid....
            z_in = noise_amp * (z_curr) + I_prev
            if noise_solutions is not None:
                # Blend the solved noise with the fresh sample, weighted by `factor`.
                z_curr = noise_solutions[n]
                z_in = (1 - factor) * noise_amp * (z_curr) + I_prev \
                       + factor * noise_amp * z_orig  # adds in previous image to z_opt'''
            I_curr = G(z_in.detach(), I_prev)
            if base is not None:
                if n == insert_limit:
                    # Mix the base image into the output at the chosen scale.
                    I_curr = bases[n] * factor + I_curr * (1 - factor)
            if anchor is not None and direction is not None:
                anchor_curr = anchors[n]
                I_curr = reinforcement(anchor_curr, I_curr, directions[n])
                # I_curr = reinforcement_sigmoid(anchor_curr, I_curr, direction, n)
            ###### ENFORCE LH = ANCHOR FOR IMAGE #######
            if n == opt.stop_scale:  # hacky code
                if anchor is not None and direction is not None:
                    anchor_curr = anchors[n]
                    I_curr = reinforcement(anchor_curr, I_curr, direction)
                    # I_curr = reinforcement_sigmoid(anchor_curr, I_curr, direction, n)
                # Final-scale output converted to a numpy image for the caller.
                array = functions.convert_image_np(I_curr.detach())
            images_cur.append(I_curr)
        n += 1
    return array
def train(opt, Gs, Zs, reals, NoiseAmp): real_ = functions.read_image( opt) #将输入的png图像转变为行 列 通道 像素这样的tensor之后,作归一化,值都在[-1,1]之间 #通过norm的操作和clamp函数的功能 #print(real_) in_s = 0 scale_num = 0 real = imresize(real_, opt.scale1, opt) #开始对图像作变形 reals = functions.creat_reals_pyramid(real, reals, opt) #这个金字塔开始对图像的通道数作变化,以适应不同特征塔 nfc_prev = 0 while scale_num < opt.stop_scale + 1: #print('real_ value is {}'.format(real_)) #print('reals value is {}'.format(reals)) #print('scale_num value is {}'.format(scale_num)) #print('stop_scale value is {}'.format(opt.stop_scale)) opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128) opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128) #print('opt.nfc value is {}'.format(opt.nfc)) #print('opt.min_nfc value is {}'.format(opt.min_nfc)) opt.out_ = functions.generate_dir2save(opt) opt.outf = '%s/%d' % (opt.out_, scale_num) try: os.makedirs(opt.outf) except OSError: pass #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1) #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1) plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1) D_curr, G_curr = init_models(opt) if (nfc_prev == opt.nfc): G_curr.load_state_dict( torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1))) D_curr.load_state_dict( torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1))) z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt) G_curr = functions.reset_grads(G_curr, False) G_curr.eval() D_curr = functions.reset_grads(D_curr, False) D_curr.eval() Gs.append(G_curr) Zs.append(z_curr) #ZS噪声图 NoiseAmp.append(opt.noise_amp) torch.save(Zs, '%s/Zs.pth' % (opt.out_)) torch.save(Gs, '%s/Gs.pth' % (opt.out_)) torch.save(reals, '%s/reals.pth' % (opt.out_)) torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_)) scale_num += 1 nfc_prev = opt.nfc del D_curr, G_curr return
def invert_model(test_image, model_name, scales2invert=None, penalty=1e-3, show=True):
    """Invert a trained SinGAN pyramid: optimize per-scale noise so the
    generators reproduce `test_image`.

    Args:
        test_image: numpy array of the target image.
        model_name: name of the trained model to load.
        scales2invert: how many scales to invert (default: all).
        penalty: weight of the Gaussian prior regularizer on the noise.
        show: if True, plot losses and intermediate images.

    Returns:
        (Noise_Solutions, reconstructed_image, REC_ERROR): the optimized noise
        per scale, the final reconstruction, and the last MSE loss tensor.
    """
    Noise_Solutions = []
    parser = get_arguments()
    parser.add_argument('--input_dir', help='input image dir', default='Input/Images')
    parser.add_argument('--mode', default='RandomSamples')
    opt = parser.parse_args("")
    opt.input_name = model_name
    opt.reg = penalty
    if model_name == 'islands2_basis_2.jpg':  # HARDCODED special-case scale factor
        opt.scale_factor = 0.6
    opt = functions.post_config(opt)
    ### Loading in Generators (frozen, eval mode)
    Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
    for G in Gs:
        G = functions.reset_grads(G, False)
        G.eval()
    ### Loading in Ground Truth Test Images
    reals = []  # discard the training pyramid; rebuild from the test image
    real = functions.np2torch(test_image, opt)
    functions.adjust_scales2image(real, opt)
    real_ = functions.np2torch(test_image, opt)
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    ### General Padding (matches the generator's receptive field)
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(int(pad_noise))
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_image = nn.ZeroPad2d(int(pad_image))
    I_prev = None
    REC_ERROR = 0
    if scales2invert is None:
        scales2invert = opt.stop_scale + 1
    for scale in range(scales2invert):
        # Target, generator and noise amplitude for this scale.
        X = reals[scale]
        G = Gs[scale]
        noise_amp = NoiseAmp[scale]
        # Defining Dimensions
        opt.nc_z = X.shape[1]
        opt.nzx = X.shape[2]
        opt.nzy = X.shape[3]
        # Standard-normal prior used to regularize the optimized noise.
        pdf = torch.distributions.Normal(0, 1)
        alpha = opt.reg
        # Defining Z: 1 channel at the coarsest scale, 3 channels afterwards.
        if scale == 0:
            z_init = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
        else:
            z_init = functions.generate_noise([3, opt.nzx, opt.nzy], device=opt.device)
        # NOTE(review): hard-codes .cuda() instead of opt.device — confirm GPU-only use.
        z_init = Variable(z_init.cuda(), requires_grad=True)  # variable to optimize
        # Building I_prev.
        # FIX: use `is None` — once I_prev is a tensor, `== None` would be an
        # elementwise comparison instead of an identity check.
        if I_prev is None:  # first scale: start from an all-zero image
            in_s = torch.full(reals[0].shape, 0, device=opt.device)
            I_prev = in_s
            I_prev = m_image(I_prev)  # padding
        else:
            # Otherwise upsample the previous scale's output to this scale.
            I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
            I_prev = m_image(I_prev)
            # Crop/resample so precision errors cannot change the size.
            I_prev = I_prev[:, :, 0:X.shape[2] + 10, 0:X.shape[3] + 10]
            I_prev = functions.upsampling(I_prev, X.shape[2] + 10, X.shape[3] + 10)
        # Per-scale learning rates and iteration budgets (hand-tuned tables).
        LR = [2e-3, 2e-2, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1]
        Zoptimizer = torch.optim.RMSprop([z_init], lr=LR[scale])
        x_loss = []   # for plotting
        epochs = []   # for plotting
        niter = [200, 400, 400, 400, 400, 400, 400, 400,
                 400, 400, 400, 400, 400, 400, 400]
        for epoch in range(niter[scale]):
            # Gradient descent on Z.
            if scale == 0:
                noise_input = m_noise(z_init.expand(1, 3, opt.nzx, opt.nzy))  # expand and pad
            else:
                noise_input = m_noise(z_init)  # padding
            z_in = noise_amp * noise_input + I_prev
            G_z = G(z_in, I_prev)
            x_recLoss = F.mse_loss(G_z, X)          # reconstruction (MSE) loss
            logProb = pdf.log_prob(z_init).mean()   # Gaussian prior loss
            loss = x_recLoss - (alpha * logProb.mean())
            Zoptimizer.zero_grad()
            loss.backward()
            Zoptimizer.step()
            x_loss.append(loss.item())
            epochs.append(epoch)
            REC_ERROR = x_recLoss
        if show:
            plt.plot(epochs, x_loss, label='x_loss')
            plt.legend()
            plt.show()
        # Carry the final output of this scale into the next one.
        I_prev = G_z.detach()
        _ = show_image(X, show, 'target')
        reconstructed_image = show_image(I_prev, show, 'output')
        _ = show_image(noise_input.detach().cpu(), show, 'noise')
        Noise_Solutions.append(noise_input.detach())
    return Noise_Solutions, reconstructed_image, REC_ERROR
def train(opt, Gs, Zs, reals, NoiseAmp):
    """SinGAN training loop that also records per-scale memory use and wall time.

    FIX: the original saved and printed `full_memory`/`full_time`, which were
    never defined (NameError); the collected lists are saved/printed instead.
    The local list names avoid shadowing the stdlib `time` module.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0
    mem_log = []   # per-scale [mbs, percent] as reported by train_single_scale
    time_log = []  # per-scale training duration (timedelta)
    while scale_num < opt.stop_scale + 1:
        # Channel widths double every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # directory already exists
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
        D_curr, G_curr = init_models(opt)
        # Same channel width as the previous scale -> warm-start from its weights.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
        # Time the per-scale training and collect its memory report.
        start = datetime.datetime.now()
        z_curr, in_s, G_curr, mbs, percent = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        mem_log.append([mbs, percent])
        elapsed = datetime.datetime.now() - start
        time_log.append(elapsed)
        print(f'time: {elapsed}')
        # Freeze both networks and switch to eval mode.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        # Checkpoint pyramid state plus the profiling logs.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        torch.save(mem_log, '%s/full_memory.pth' % (opt.out_))
        torch.save(time_log, '%s/full_time.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr  # free memory before the next scale
    print(mem_log)
    print(time_log)
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Jointly train two SinGAN pyramids (one per input image), scale by scale.

    Each scale trains a pair of generators/discriminators; Gs and the
    checkpoints store the pairs as two-element lists.
    """
    real_ = functions.read_images(opt)
    in_s = 0
    scale_num = 0
    real = [
        imresize(real_[0], opt.scale1, opt),
        imresize(real_[1], opt.scale1, opt)
    ]
    # FIX: the original passed the SAME `reals` list to both calls; since
    # creat_reals_pyramid appends to (and returns) the list it is given, both
    # entries aliased one interleaved pyramid. Build each pyramid in its own
    # list instead. (Confirm creat_reals_pyramid mutates its argument.)
    reals = [
        functions.creat_reals_pyramid(real[0], [], opt),
        functions.creat_reals_pyramid(real[1], [], opt)
    ]
    nfc_prev = 0
    while scale_num < opt.stop_scale + 1:
        # Channel widths double every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # directory already exists
        # Save both real images at this scale for inspection.
        plt.imsave('%s/real_scale1.png' % (opt.outf),
                   functions.convert_image_np(reals[0][scale_num]), vmin=0, vmax=1)
        plt.imsave('%s/real_scale2.png' % (opt.outf),
                   functions.convert_image_np(reals[1][scale_num]), vmin=0, vmax=1)
        # One G/D pair per image.
        D_curr1, G_curr1 = init_models(opt)
        D_curr2, G_curr2 = init_models(opt)
        D_curr = [D_curr1, D_curr2]
        G_curr = [G_curr1, G_curr2]
        # Same channel width as the previous scale -> warm-start both pairs.
        if nfc_prev == opt.nfc:
            G_curr[0].load_state_dict(
                torch.load('%s/%d/netG1.pth' % (opt.out_, scale_num - 1)))
            D_curr[0].load_state_dict(
                torch.load('%s/%d/netD1.pth' % (opt.out_, scale_num - 1)))
            G_curr[1].load_state_dict(
                torch.load('%s/%d/netG2.pth' % (opt.out_, scale_num - 1)))
            D_curr[1].load_state_dict(
                torch.load('%s/%d/netD2.pth' % (opt.out_, scale_num - 1)))
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs,
                                                  in_s, NoiseAmp, opt)
        # Freeze both pairs and switch them to eval mode.
        for k in (0, 1):
            G_curr[k] = functions.reset_grads(G_curr[k], False)
            G_curr[k].eval()
            D_curr[k] = functions.reset_grads(D_curr[k], False)
            D_curr[k].eval()
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        # Checkpoint the pyramid state.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr  # free memory before the next scale
    return