def train_paint(opt, Gs, Zs, reals, NoiseAmp, centers, paint_inject_scale):
    """Re-train the GAN pair at a single pyramid scale (paint-to-image).

    Only scale ``paint_inject_scale`` is trained. The new generator, its
    reconstruction noise and ``opt.noise_amp`` overwrite the entries at that
    index of ``Gs``, ``Zs`` and ``NoiseAmp``, so those lists must already
    contain that index. Afterwards the four pyramids (plus ``reals``) are
    saved under ``opt.out_``.

    Args:
        opt: configuration namespace. ``nfc``, ``min_nfc``, ``out_`` and
            ``outf`` are overwritten here.
        Gs, Zs, NoiseAmp: per-scale generators, noise maps and amplitudes.
            They are modified in place at ``paint_inject_scale``.
        reals: image pyramid, indexed coarsest first.
        centers: forwarded unchanged to ``train_single_scale``.
        paint_inject_scale: index of the scale to re-train. Values outside
            ``0..opt.stop_scale`` make this call a no-op.
    """
    # The previous implementation iterated over 0..stop_scale and skipped
    # every scale except this one; going straight to it is equivalent.
    if paint_inject_scale not in range(opt.stop_scale + 1):
        return
    scale_num = int(paint_inject_scale)

    # Explicit float zeros: torch.full(shape, 0) infers int64 from the integer
    # fill value on PyTorch >= 1.7, which is not a valid image tensor here.
    in_s = torch.zeros(reals[0].shape, device=opt.device)

    # Channel width doubles every 4 scales, capped at 128.
    opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
    opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)

    opt.out_ = functions.generate_dir2save(opt)
    opt.outf = '%s/%d' % (opt.out_, scale_num)
    try:
        os.makedirs(opt.outf)
    except OSError:
        pass  # best effort: the directory may already exist
    plt.imsave('%s/in_scale.png' % (opt.outf),
               functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)

    D_curr, G_curr = init_models(opt)
    # Only the coarser scales (and the reals up to this one) are visible to
    # the trainer, exactly as in the original per-scale loop.
    z_curr, in_s, G_curr = train_single_scale(
        D_curr, G_curr, reals[:scale_num + 1], Gs[:scale_num], Zs[:scale_num],
        in_s, NoiseAmp[:scale_num], opt, centers=centers)

    # reset_grads(..., False) is expected to disable requires_grad; both nets
    # are then switched to eval mode.
    G_curr = functions.reset_grads(G_curr, False)
    G_curr.eval()
    D_curr = functions.reset_grads(D_curr, False)
    D_curr.eval()

    Gs[scale_num] = G_curr
    Zs[scale_num] = z_curr
    NoiseAmp[scale_num] = opt.noise_amp
    torch.save(Zs, '%s/Zs.pth' % (opt.out_))
    torch.save(Gs, '%s/Gs.pth' % (opt.out_))
    torch.save(reals, '%s/reals.pth' % (opt.out_))
    torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
    del D_curr, G_curr
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the multi-scale GAN pyramid, coarsest scale first.

    The input image is resized by ``opt.scale1`` and turned into a pyramid.
    For every level ``0..opt.stop_scale`` a fresh discriminator/generator pair
    is built, warm-started from the previous level's saved weights when the
    channel width is unchanged, trained, frozen and appended to ``Gs``. Its
    noise map goes to ``Zs`` and ``opt.noise_amp`` to ``NoiseAmp``. All
    pyramids are re-saved under ``opt.out_`` after each level.
    """
    source = functions.read_image(opt)
    resized = imresize(source, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(resized, reals, opt)

    carry = 0        # passed into / returned by train_single_scale
    last_width = 0   # opt.nfc of the previous level (0 = no previous level)
    level = 0
    while level < opt.stop_scale + 1:
        # Width doubles every 4 levels, capped at 128 channels.
        factor = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/in_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]), vmin=0, vmax=1)

        netD, netG = init_models(opt)
        if last_width == opt.nfc:
            # Identical architecture to the previous level: reuse its weights.
            netG.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            netD.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))

        z_opt, carry, netG = train_single_scale(netD, netG, reals, Gs, Zs, carry, NoiseAmp, opt)

        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(z_opt)
        NoiseAmp.append(opt.noise_amp)
        for payload, pattern in ((Zs, '%s/Zs.pth'), (Gs, '%s/Gs.pth'),
                                 (reals, '%s/reals.pth'), (NoiseAmp, '%s/NoiseAmp.pth')):
            torch.save(payload, pattern % (opt.out_))

        level += 1
        last_width = opt.nfc
        del netD, netG
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one GAN per pyramid level, from the coarsest level to the finest.

    The pyramid is built directly from ``functions.read_image(opt)``; no
    ``imresize`` by ``opt.scale1`` is applied in this variant. Levels
    ``0..opt.stop_scale`` (inclusive) are trained. The channel width
    (``opt.nfc``) changes every 4 levels, and when it matches the previous
    level both networks are warm-started from that level's saved weights.
    CUDA's cache is emptied after every level.
    """
    source = functions.read_image(opt)
    reals = functions.creat_reals_pyramid(source, reals, opt)

    # Initial value handed to train_single_scale; it returns the value to
    # carry into the next level.
    carry = 0
    last_width = 0
    level = 0  # current level, coarsest = 0
    while level < opt.stop_scale + 1:
        # nfc: number of output channels of a conv block (capped at 128).
        factor = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        # out_: run directory; outf: per-level sub-directory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]), vmin=0, vmax=1)

        netD, netG = init_models(opt)
        if last_width == opt.nfc:
            netG.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            netD.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))

        z_opt, carry, netG = train_single_scale(netD, netG, reals, Gs, Zs, carry, NoiseAmp, opt)

        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(z_opt)
        NoiseAmp.append(opt.noise_amp)
        for payload, pattern in ((Zs, '%s/Zs.pth'), (Gs, '%s/Gs.pth'),
                                 (reals, '%s/reals.pth'), (NoiseAmp, '%s/NoiseAmp.pth')):
            torch.save(payload, pattern % (opt.out_))

        level += 1
        last_width = opt.nfc
        del netD, netG
        torch.cuda.empty_cache()
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the GAN pyramid scale by scale, coarsest first.

    ``functions.read_image`` loads the input image, which is then resized by
    ``opt.scale1`` and expanded into an image pyramid. Each level
    ``0..opt.stop_scale`` gets its own discriminator/generator pair.
    Generators, noise maps and noise amplitudes accumulate in
    ``Gs``/``Zs``/``NoiseAmp`` and are saved after each level.
    """
    source = functions.read_image(opt)
    resized = imresize(source, opt.scale1, opt)  # rescale before building the pyramid
    reals = functions.creat_reals_pyramid(resized, reals, opt)

    carry = 0
    last_width = 0
    level = 0
    while level < opt.stop_scale + 1:
        factor = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]), vmin=0, vmax=1)

        netD, netG = init_models(opt)
        if last_width == opt.nfc:
            # Same width as the previous level, so its weights are loadable.
            netG.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            netD.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))

        z_opt, carry, netG = train_single_scale(netD, netG, reals, Gs, Zs, carry, NoiseAmp, opt)

        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(z_opt)  # reconstruction noise map for this level
        NoiseAmp.append(opt.noise_amp)
        for payload, pattern in ((Zs, '%s/Zs.pth'), (Gs, '%s/Gs.pth'),
                                 (reals, '%s/reals.pth'), (NoiseAmp, '%s/NoiseAmp.pth')):
            torch.save(payload, pattern % (opt.out_))

        level += 1
        last_width = opt.nfc
        del netD, netG
    return
def train(opt, Gs, Zs, reals1, reals2, NoiseAmp):
    """Train a two-image pyramid with one generator and three discriminators.

    ``get_reals`` builds the pyramids for ``opt.input_name1`` and
    ``opt.input_name2``. Each level ``0..opt.stop_scale`` trains one generator
    against a main discriminator and two mask discriminators. A tuple of
    noise maps goes to ``Zs`` and ``(opt.noise_amp1, opt.noise_amp2)`` to
    ``NoiseAmp``.
    """
    logger.info("Starting to train...")
    reals1 = get_reals(reals1, opt, opt.input_name1)
    reals2 = get_reals(reals2, opt, opt.input_name2)

    carry1, carry2 = 0, 0
    last_width = 0
    level = 0
    while level < opt.stop_scale + 1:
        factor = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale1.png' % (opt.outf),
                   functions.convert_image_np(reals1[level]), vmin=0, vmax=1)
        plt.imsave('%s/real_scale2.png' % (opt.outf),
                   functions.convert_image_np(reals2[level]), vmin=0, vmax=1)

        # The model shape is taken from the entry matching the number of
        # generators trained so far.
        netD, netD_m1, netD_m2, netG = init_models(opt, reals1[len(Gs)].shape)
        if last_width == opt.nfc:
            netG.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            netD.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))
            netD_m1.load_state_dict(torch.load('%s/%d/netD_mask1.pth' % (opt.out_, level - 1)))
            netD_m2.load_state_dict(torch.load('%s/%d/netD_mask2.pth' % (opt.out_, level - 1)))

        logger.info(f"Starting to train scale {level}")
        z_pair, carry_pair, netG = train_single_scale(
            netD, netD_m1, netD_m2, netG, reals1, reals2, Gs, Zs,
            carry1, carry2, NoiseAmp, opt)
        carry1, carry2 = carry_pair
        logger.info(f"Done training scale {level}")

        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()
        netD_m1 = functions.reset_grads(netD_m1, False)
        netD_m1.eval()
        netD_m2 = functions.reset_grads(netD_m2, False)
        netD_m2.eval()

        Gs.append(netG)
        Zs.append(z_pair)
        NoiseAmp.append((opt.noise_amp1, opt.noise_amp2))
        for payload, pattern in ((Zs, '%s/Zs.pth'), (Gs, '%s/Gs.pth'),
                                 (reals1, '%s/reals1.pth'), (reals2, '%s/reals2.pth'),
                                 (NoiseAmp, '%s/NoiseAmp.pth')):
            torch.save(payload, pattern % (opt.out_))

        level += 1
        last_width = opt.nfc
        del netD, netD_m1, netD_m2, netG
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the GAN pyramid, optionally building a matching mask pyramid.

    When ``opt.inpainting`` is set, the mask image ``<input>_mask<ext>`` is
    read from ``opt.ref_dir``, resized by ``opt.scale1`` and turned into a
    pyramid stored on ``opt.m_s``. It is built the same way as the image
    pyramid, so each scale has a mask of the same size. Training then
    proceeds level by level as in the standard loop.
    """
    source = functions.read_image(opt)
    resized = imresize(source, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(resized, reals, opt)

    carry = 0
    last_width = 0
    if opt.inpainting:
        stem, ext = opt.input_name[:-4], opt.input_name[-4:]
        mask = functions.read_image_dir('%s/%s_mask%s' % (opt.ref_dir, stem, ext), opt)
        mask = imresize(mask, opt.scale1, opt)
        opt.m_s = functions.creat_reals_pyramid(mask, [], opt)  # pyramid of masks

    level = 0
    while level < opt.stop_scale + 1:
        factor = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]), vmin=0, vmax=1)

        netD, netG = init_models(opt)
        if last_width == opt.nfc:
            netG.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            netD.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))

        z_opt, carry, netG = train_single_scale(netD, netG, reals, Gs, Zs, carry, NoiseAmp, opt)

        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(z_opt)
        NoiseAmp.append(opt.noise_amp)
        for payload, pattern in ((Zs, '%s/Zs.pth'), (Gs, '%s/Gs.pth'),
                                 (reals, '%s/reals.pth'), (NoiseAmp, '%s/NoiseAmp.pth')):
            torch.save(payload, pattern % (opt.out_))

        level += 1
        last_width = opt.nfc
        del netD, netG
    return
def train(opt,Gs,Zs,reals,NoiseAmp):
    """Train the GAN pyramid, pruning (and optionally fine-tuning) each generator.

    For every level 0..opt.stop_scale, the steps are:
      1. Train the level; only ``warmup_steps`` iterations if ``fine_tune``,
         else ``opt.niter``.
      2. Prune the generator, either with global L1-unstructured pruning over
         conv/norm weights and biases or with L1-structured pruning over conv
         weights (``structured``).
      3. Save the raw pruned state dict and weight histograms.
      4. Optionally fine-tune the pruned generator.
      5. Make the pruning permanent (``prune.remove``) and append the result to
         ``Gs``/``Zs``/``NoiseAmp``.

    NOTE(review): ``fine_tune``, ``warmup_steps``, ``structured``,
    ``pruning_amount``, ``prune`` and ``prune_SinGAN_generate`` are not
    defined in this function. They are presumably module-level names
    (``prune`` likely being ``torch.nn.utils.prune``); confirm.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level from coarsest to finest.
    cur_scale_level = 0
    # NOTE(review): the pyramid is built from real_ directly; no imresize by
    # opt.scale1 is applied in this variant.
    reals = functions.creat_reals_pyramid(real_,reals,opt)
    nfc_prev = 0
    # Train including opt.stop_scale
    while cur_scale_level < opt.stop_scale+1:
        # nfc: number of out channels in conv block (doubles every 4 levels, max 128)
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        # out_: output directory
        # outf: output folder, with scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_,cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)
        D_curr,G_curr = init_models(opt)
        # Notice, as the level increases, the architecture of the CNN block may differ
        # (every 4 levels). Weights are only reused when the width is unchanged.
        if (nfc_prev==opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_,cur_scale_level-1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_,cur_scale_level-1)))
        # in_s starts as 0 and is replaced by whatever train_single_scale returns.
        # With fine-tuning, only a warm-up budget is spent before pruning; the
        # remainder (opt.niter - warmup_steps) is spent after pruning below.
        if fine_tune:
            z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, warmup_steps)
        else:
            z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter)
        G_curr = functions.reset_grads(G_curr,False)
        # D_curr = functions.reset_grads(D_curr,False)
        G_curr.eval()
        # D_curr.eval()
        # NOTE(review): D_curr is deliberately left un-frozen (lines above are
        # commented out) so it can keep training during fine-tuning.
        #################################################################################
        # Visualize weights: print the fraction of exactly-zero weights across
        # `modules` and save a histogram of the non-zero ones to
        # <opt.outf>/<fig_name>.png. Only `.weight` is included, not `.bias`.
        # NOTE(review): torch.tensor([]).cuda() hard-codes a CUDA device.
        def visualize_weights(modules, fig_name):
            ori_weights = torch.tensor([]).cuda()
            for m in modules:
                cur_params = m.weight.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
                # cur_params = m.bias.data.flatten()
                # ori_weights = torch.cat((ori_weights, cur_params))
            sparsity = torch.sum(ori_weights == 0) * 1.0 / (ori_weights.nelement())
            print(sparsity, ori_weights.nelement())
            ori_weights = ori_weights.cpu().numpy()
            ori_weights = plt.hist(ori_weights[ori_weights != 0], bins=100)
            plt.savefig("%s/%s.png" % (opt.outf, fig_name))
            plt.close()
        # Pruning weights Structured or Non-structured
        if not structured:
            # Unstructured: one global L1 threshold across the weights AND
            # biases of every conv/norm layer plus the tail conv.
            modules = [G_curr.head.conv, G_curr.head.norm, G_curr.body.block1.conv, G_curr.body.block1.norm, G_curr.body.block2.conv, G_curr.body.block2.norm, G_curr.body.block3.conv, G_curr.body.block3.norm, G_curr.tail[0]]
            parameters_to_prune = (
                (G_curr.head.conv, 'weight'),
                (G_curr.head.norm, 'weight'),
                (G_curr.body.block1.conv, 'weight'),
                (G_curr.body.block1.norm, 'weight'),
                (G_curr.body.block2.conv, 'weight'),
                (G_curr.body.block2.norm, 'weight'),
                (G_curr.body.block3.conv, 'weight'),
                (G_curr.body.block3.norm, 'weight'),
                (G_curr.tail[0], 'weight'),
                (G_curr.head.conv, 'bias'),
                (G_curr.head.norm, 'bias'),
                (G_curr.body.block1.conv, 'bias'),
                (G_curr.body.block1.norm, 'bias'),
                (G_curr.body.block2.conv, 'bias'),
                (G_curr.body.block2.norm, 'bias'),
                (G_curr.body.block3.conv, 'bias'),
                (G_curr.body.block3.norm, 'bias'),
                (G_curr.tail[0], 'bias'),
            )
            visualize_weights(modules, 'ori')
            # Prune weights
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=pruning_amount,
            )
        else:
            # Structured: prune whole output channels (dim=0) of each body/head
            # conv by L1 norm (n=1); the tail conv is not pruned here.
            modules = [G_curr.head.conv, G_curr.body.block1.conv, G_curr.body.block2.conv, G_curr.body.block3.conv]
            visualize_weights(modules, 'ori')
            # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
            # print(pytorch_total_params)
            for module in modules:
                m = prune.ln_structured(module, name="weight", amount=pruning_amount, n=1, dim=0)
                # m = prune.ln_structured(module, name="bias", amount=pruning_amount, n=1, dim=0)
        # The state dict saved here still carries the pruning reparametrization
        # (prune.remove has not been called yet).
        torch.save(G_curr.state_dict(), '%s/raw_prune_netG.pth' % (opt.outf))
        visualize_weights(modules, 'raw-prune')
        # Generate a sample with the just-pruned generator appended to copies
        # of the already-trained pyramid (the real lists are not modified).
        if cur_scale_level > 0:
            fake_Gs = Gs.copy()
            fake_Gs.append(G_curr)
            fake_Zs = Zs.copy()
            fake_Zs.append(z_curr)
            fake_noise = NoiseAmp.copy()
            fake_noise.append(opt.noise_amp)
            fake_reals = reals[:cur_scale_level+1].copy()
            prune_SinGAN_generate(fake_Gs, fake_Zs, fake_reals, fake_noise, opt, gen_start_scale=0, num_samples=1, level=cur_scale_level)
        # Fine-tuning
        if fine_tune:
            G_curr = functions.reset_grads(G_curr, True)
            G_curr.train()
            if not structured:
                # Keep training using inherited weights
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter - warmup_steps, prune=True)
            else:
                # "Training from scratch": NOTE(review) the re-initialisation
                # lines below are commented out, so the pruned weights are in
                # fact kept; only the iteration budget is the full opt.niter.
                # G_curr.apply(models.weights_init)
                # D_curr.apply(models.weights_init)
                z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt, opt.niter, prune=True)
            G_curr = functions.reset_grads(G_curr,False)
            G_curr.eval()
            visualize_weights(modules, 'fine-tune')
        # Make pruning permanent: fold masks into the tensors and drop the
        # *_orig / *_mask reparametrization. Biases were only pruned in the
        # unstructured branch.
        for m in modules:
            prune.remove(m, 'weight')
            if not structured:
                prune.remove(m, 'bias')
        # pytorch_total_params = sum(p.numel() for p in G_curr.parameters())
        # print(pytorch_total_params)
        #################################################################################
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        cur_scale_level+=1
        nfc_prev = opt.nfc
        del D_curr,G_curr
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the GAN pyramid scale by scale, coarsest first, with diagnostics.

    Prints the resized input shape, every pyramid shape, the number of scales
    and each scale's output directory. When ``opt.fast_training`` is set,
    ``opt.niter`` is halved (persistently, on ``opt``) at scales 4, 8, 12, ...
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    print("real.shape:", real.shape)  # e.g. torch.Size([1, 3, 186, 248])
    reals = functions.creat_reals_pyramid(real, reals, opt)
    # For the example above the pyramid spatial sizes are:
    # 20x27, 27x36, 35x47, 47x62, 61x82, 81x108, 107x143, 141x188, 186x248
    for i in reals:
        print(i.shape)
    nfc_prev = 0
    # Scales 0..stop_scale are all trained, i.e. stop_scale + 1 of them
    # (the message previously under-reported by one).
    print("total %d scales.." % (opt.stop_scale + 1))
    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            opt.niter = opt.niter // 2

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        print("out dir:", opt.outf)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)

        # Build D and G. The networks are fully convolutional, so the input
        # size does not need to be fixed.
        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            # Same layer widths as the previous scale: fine-tune from its weights.
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        # Train this scale's networks. A fresh pair is created per scale,
        # which keeps GPU memory usage low.
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        # Keep the trained G for the coarser-to-finer chain.
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        # Drop local references to D and G before the next scale.
        del D_curr, G_curr
    return
def train(opt, Gs, Zs, reals, crops, masks, NoiseAmp):
    """Train the GAN pyramid with per-scale masks, crop sizes and an eye colour.

    The image is read and resized by ``opt.scale1``; its mask is read with
    ``functions.read_mask``. ``functions.get_eye_color`` picks a colour from
    the resized image, which is also stored on ``opt.eye_color``. Pyramids
    are built for the image, the mask, and a zero tensor of size
    ``opt.crop_size`` (used only to obtain the crop size at every scale).
    Levels ``0..opt.stop_scale`` are then trained in order.
    """
    full_image = functions.read_image(opt)
    scaled_image = imresize(full_image, opt.scale1, opt)
    mask_image = functions.read_mask(opt)
    # Size reference only: its pyramid gives the crop size at each scale.
    crop_template = torch.zeros((1, 1, opt.crop_size, opt.crop_size))
    eye_color = functions.get_eye_color(scaled_image)
    opt.eye_color = eye_color

    carry = 0
    level = 0
    reals = functions.create_pyramid(scaled_image, reals, opt)
    masks = functions.create_pyramid(mask_image, masks, opt, mode="mask")
    crops = functions.create_pyramid(crop_template, crops, opt, mode="mask")

    last_width = 0
    while level < opt.stop_scale + 1:
        factor = pow(2, math.floor(level / 4))
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]), vmin=0, vmax=1)

        netD, netG = init_models(opt)
        if last_width == opt.nfc:
            netG.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, level - 1)))
            netD.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, level - 1)))

        z_opt, carry, netG = train_single_scale(netD, netG, reals, crops, masks, eye_color,
                                                Gs, Zs, carry, NoiseAmp, opt)

        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(z_opt)
        NoiseAmp.append(opt.noise_amp)
        for payload, pattern in ((Zs, '%s/Zs.pth'), (Gs, '%s/Gs.pth'),
                                 (reals, '%s/reals.pth'), (NoiseAmp, '%s/NoiseAmp.pth')):
            torch.save(payload, pattern % (opt.out_))

        level += 1
        last_width = opt.nfc
        del netD, netG
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train every pyramid level independently against upsampled coarser levels.

    The pyramid is built directly from ``functions.read_image(opt)``. Two
    auxiliary lists are precomputed:

    - ``upsamples[k]`` (k >= 1) is ``reals[k-1]`` resized to the shape of
      ``reals[k]``.
    - ``diffs[k]`` is ``|reals[k] - upsamples[k]| - 1``.

    Entry 0 of both lists is ``reals[0]``. Each level is trained without
    warm-starting from the previous one, and CUDA's cache is emptied after
    each level.
    """
    source = functions.read_image(opt)
    reals = functions.creat_reals_pyramid(source, reals, opt)

    upsamples = [reals[0]]
    diffs = [reals[0]]
    # stop_scale levels above level 0 need an upsampled predecessor.
    for k in range(opt.stop_scale):
        coarse, fine = reals[k], reals[k + 1]
        _, channels, height, width = fine.shape
        lifted = imresize_to_shape(coarse, (height, width, channels), opt)
        upsamples.append(lifted)
        # Maps to [-1, 1], assuming pixel values lie in [-1, 1] — TODO confirm.
        diffs.append((fine - lifted).abs() - 1)

    carry = 0
    level = 0
    while level < opt.stop_scale + 1:
        factor = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]), vmin=0, vmax=1)
        plt.imsave('%s/diff.png' % (opt.outf),
                   functions.convert_image_np(diffs[level].detach()), vmin=0, vmax=1)
        plt.imsave('%s/upsampled.png' % (opt.outf),
                   functions.convert_image_np(upsamples[level].detach()), vmin=0, vmax=1)

        # Levels are trained independently, so no weights are reloaded here.
        netD, netG = init_models(opt)
        z_opt, carry, netG = train_single_scale(netD, netG, reals, upsamples, level,
                                                carry, NoiseAmp, opt)

        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(z_opt)
        NoiseAmp.append(opt.noise_amp)
        for payload, pattern in ((Zs, '%s/Zs.pth'), (Gs, '%s/Gs.pth'),
                                 (reals, '%s/reals.pth'), (NoiseAmp, '%s/NoiseAmp.pth')):
            torch.save(payload, pattern % (opt.out_))

        level += 1
        del netD, netG
        torch.cuda.empty_cache()
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the GAN pyramid scale by scale, coarsest first, with verbose output.

    The unresized input image is saved once as ``<opt.out_>/original.png``;
    previously it was rewritten with identical content at every scale. Each
    scale's real image is saved to its own directory. ``opt.nfc``,
    ``opt.min_nfc`` and each trained generator are printed. When
    ``opt.fast_training`` is set, ``opt.niter`` is halved (persistently) at
    scales 4, 8, 12, ...
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    # List of the image at every scale of the pyramid (coarsest first).
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0
    # Scales 0..opt.stop_scale are trained: stop_scale + 1 iterations
    # (e.g. stop_scale = 9 gives 10 scales).
    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        print('opt.nfc', opt.nfc)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        print('opt.min_nfc', opt.min_nfc)
        if opt.fast_training and scale_num > 0 and scale_num % 4 == 0:
            opt.niter = opt.niter // 2

        # out_ is the run's root output path; outf is this scale's sub-directory.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        if scale_num == 0:
            # Save the original (unresized) input once; the root directory
            # exists now that the first scale directory has been created.
            plt.imsave('%s/original.png' % (opt.out_),
                       functions.convert_image_np(real_), vmin=0, vmax=1)
        # Save this scale's real image.
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)

        # init_models returns (netD, netG) for the current scale.
        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        # train_single_scale returns (z_opt, in_s, netG).
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        G_curr = functions.reset_grads(G_curr, False)
        print('G_curr', G_curr)
        G_curr.eval()
        print(G_curr)  # eval() returns the module itself; print it once more as before
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
def invert_model(test_image, model_name, scales2invert=None, penalty=1e-3, show=True):
    '''Recover per-scale noise maps that make a trained pyramid reproduce an image.

    The pyramid trained for ``model_name`` is loaded and an image pyramid is
    built from ``test_image``. For each scale (coarsest first), a noise map is
    optimised with RMSprop to minimise
    ``mse(G(noise_amp * z + I_prev, I_prev), X) - penalty * mean(log N(0,1)(z))``.
    The previous scale's output is then upsampled to become the next
    ``I_prev``.

    Args:
        test_image: image array, converted with ``functions.np2torch``.
        model_name: name of the trained model (used as ``opt.input_name``).
        scales2invert: number of scales to invert, starting from the
            coarsest. Defaults to all scales (``opt.stop_scale + 1``).
        penalty: weight of the Gaussian prior term (stored in ``opt.reg``).
        show: whether to plot loss curves and display images.

    Returns:
        ``(Noise_Solutions, reconstructed_image, REC_ERROR)``:

        - ``Noise_Solutions``: the padded noise input of every inverted scale.
        - ``reconstructed_image``: the ``show_image`` result for the last
          scale's output, or ``None`` if no scale was inverted.
        - ``REC_ERROR``: the last MSE as a detached tensor, or ``0`` if no
          scale was inverted.
    '''
    Noise_Solutions = []
    parser = get_arguments()
    parser.add_argument('--input_dir', help='input image dir', default='Input/Images')
    parser.add_argument('--mode', default='RandomSamples')
    opt = parser.parse_args("")
    opt.input_name = model_name
    opt.reg = penalty
    if model_name == 'islands2_basis_2.jpg':  # HARDCODED
        opt.scale_factor = 0.6
    opt = functions.post_config(opt)

    ### Loading in Generators
    Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
    for G in Gs:
        G = functions.reset_grads(G, False)
        G.eval()

    ### Loading in Ground Truth Test Images (replaces the training pyramid)
    real = functions.np2torch(test_image, opt)
    functions.adjust_scales2image(real, opt)
    real_ = functions.np2torch(test_image, opt)
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, [], opt)

    ### General Padding
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(int(pad_noise))
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_image = nn.ZeroPad2d(int(pad_image))
    # m_image pads both sides, so padded images are 2 * pad_image larger
    # (this was a hard-coded 10, which only holds for the default config).
    pad_total = 2 * pad_image

    # Per-scale RMSprop learning rates and iteration counts. Scales beyond
    # the end of a table reuse its last entry instead of raising IndexError.
    lr_table = [2e-3, 2e-2, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1]
    niter_table = [200, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400]

    # Prior distribution for the noise penalty.
    pdf = torch.distributions.Normal(0, 1)
    alpha = opt.reg

    I_prev = None
    REC_ERROR = 0
    reconstructed_image = None
    if scales2invert is None:
        scales2invert = opt.stop_scale + 1
    for scale in range(scales2invert):
        # Target image, generator and noise amplitude for this scale.
        X = reals[scale]
        G = Gs[scale]
        noise_amp = NoiseAmp[scale]

        # Defining Dimensions
        opt.nc_z = X.shape[1]
        opt.nzx = X.shape[2]
        opt.nzy = X.shape[3]

        # The coarsest scale optimises single-channel noise (expanded to 3
        # channels below); finer scales optimise 3-channel noise.
        n_channels = 1 if scale == 0 else 3
        z_init = functions.generate_noise([n_channels, opt.nzx, opt.nzy], device=opt.device)
        # Leaf tensor to optimise, on opt.device (was hard-coded .cuda()).
        z_init = z_init.detach().to(opt.device).requires_grad_(True)

        # Building I_prev
        if I_prev is None:
            # First scale: an all-zero float image. torch.full(shape, 0) would
            # yield int64 on PyTorch >= 1.7.
            I_prev = m_image(torch.zeros(reals[0].shape, device=opt.device))
        else:
            # Otherwise upsample the previous scale's output.
            I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
            I_prev = m_image(I_prev)
            target_h = X.shape[2] + pad_total
            target_w = X.shape[3] + pad_total
            # Crop, then resize to the exact padded size so resize rounding
            # cannot leave an off-by-one mismatch.
            I_prev = I_prev[:, :, 0:target_h, 0:target_w]
            I_prev = functions.upsampling(I_prev, target_h, target_w)

        lr = lr_table[min(scale, len(lr_table) - 1)]
        Zoptimizer = torch.optim.RMSprop([z_init], lr=lr)
        x_loss = []  # for plotting
        epochs = []  # for plotting
        for epoch in range(niter_table[min(scale, len(niter_table) - 1)]):
            if scale == 0:
                noise_input = m_noise(z_init.expand(1, 3, opt.nzx, opt.nzy))  # expand and pad
            else:
                noise_input = m_noise(z_init)  # pad
            z_in = noise_amp * noise_input + I_prev
            G_z = G(z_in, I_prev)
            x_recLoss = F.mse_loss(G_z, X)  # reconstruction (MSE) loss
            logProb = pdf.log_prob(z_init).mean()  # Gaussian prior term (scalar)
            loss = x_recLoss - alpha * logProb
            Zoptimizer.zero_grad()
            loss.backward()
            Zoptimizer.step()
            x_loss.append(loss.item())
            epochs.append(epoch)
            # Detach so the returned value does not keep the autograd graph alive.
            REC_ERROR = x_recLoss.detach()

        if show:
            plt.plot(epochs, x_loss, label='x_loss')
            plt.legend()
            plt.show()
        # The final output of this scale seeds the next one.
        I_prev = G_z.detach()
        _ = show_image(X, show, 'target')
        reconstructed_image = show_image(I_prev, show, 'output')
        _ = show_image(noise_input.detach().cpu(), show, 'noise')
        Noise_Solutions.append(noise_input.detach())
    return Noise_Solutions, reconstructed_image, REC_ERROR
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the multi-scale pyramid, coarsest scale first (fast-training variant).

    Per scale:
    - A fresh discriminator/generator pair is built with ``init_models``.
    - When the channel width ``opt.nfc`` equals the previous scale's width,
      the pair is initialised from ``netG.pth`` / ``netD.pth`` in the previous
      scale's output folder.
    - With ``opt.fast_training`` set, ``opt.niter`` is halved at scales 4, 8,
      12, ... so finer scales get fewer iterations.
    - After ``train_single_scale``, both networks are frozen. The generator,
      its noise map and ``opt.noise_amp`` are appended to ``Gs``, ``Zs`` and
      ``NoiseAmp``, and the pyramid state is saved under ``opt.out_``.

    Mutates ``opt`` and the ``Gs``/``Zs``/``NoiseAmp`` lists in place.
    """
    source = functions.read_image(opt)
    base = imresize(source, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(base, reals, opt)

    carry = 0       # passed to and returned from train_single_scale each scale
    prev_width = 0  # opt.nfc of the previous scale (0 before the first one)
    level = 0
    while level < opt.stop_scale + 1:
        # Filter counts double every 4 scales, capped at 128.
        growth = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * growth, 128)
        opt.min_nfc = min(opt.min_nfc_init * growth, 128)
        if opt.fast_training and level > 0 and level % 4 == 0:
            opt.niter = opt.niter // 2

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[level]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if prev_width == opt.nfc:
            # Same architecture as the previous scale: start from its weights.
            prev_dir = '%s/%d' % (opt.out_, level - 1)
            G_curr.load_state_dict(torch.load('%s/netG.pth' % prev_dir))
            D_curr.load_state_dict(torch.load('%s/netD.pth' % prev_dir))

        z_curr, carry, G_curr = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, carry, NoiseAmp, opt)

        # Freeze both networks; only the generator is kept in the pyramid.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        # opt.noise_amp is read after training; presumably set by train_single_scale.
        NoiseAmp.append(opt.noise_amp)

        for obj, stem in ((Zs, 'Zs'), (Gs, 'Gs'), (reals, 'reals'), (NoiseAmp, 'NoiseAmp')):
            torch.save(obj, '%s/%s.pth' % (opt.out_, stem))

        level += 1
        prev_width = opt.nfc
        del D_curr, G_curr
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the multi-scale pyramid with a notebook progress bar over scales.

    Reads the input image, resizes it by ``opt.scale1`` and builds the reals
    pyramid. For each scale 0..``opt.stop_scale`` (shown via ``tqdm_notebook``):
    - Set the filter counts and create the output folder ``opt.outf``.
    - Save that scale's real image.
    - Build fresh models, loading the previous scale's ``netG.pth`` /
      ``netD.pth`` when ``opt.nfc`` is unchanged.
    - Train with ``train_single_scale``, freeze both networks and append the
      generator, noise map and ``opt.noise_amp`` to the pyramid lists.
    - Save the pyramid state under ``opt.out_``.

    Mutates ``opt`` and the ``Gs``/``Zs``/``NoiseAmp`` lists in place.
    """
    image = imresize(functions.read_image(opt), opt.scale1, opt)
    reals = functions.creat_reals_pyramid(image, reals, opt)

    carry = 0     # round-tripped through train_single_scale
    last_nfc = 0  # width of the previous scale's networks
    scales = tqdm_notebook(range(opt.stop_scale + 1), desc=opt.input_name, leave=True)
    for scale in scales:
        # Filter counts double every 4 scales, capped at 128.
        multiplier = 2 ** (scale // 4)
        opt.nfc = min(opt.nfc_init * multiplier, 128)
        opt.min_nfc = min(opt.min_nfc_init * multiplier, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = f'{opt.out_}/{scale}'
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave(f'{opt.outf}/real_scale.png',
                   functions.convert_image_np(reals[scale]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if last_nfc == opt.nfc:
            # Unchanged architecture: initialise from the previous scale's weights.
            previous = f'{opt.out_}/{scale - 1}'
            G_curr.load_state_dict(torch.load(f'{previous}/netG.pth'))
            D_curr.load_state_dict(torch.load(f'{previous}/netD.pth'))

        z_curr, carry, G_curr = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, carry, NoiseAmp, opt)

        # Freeze and switch to eval mode before storing.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        # Checkpoint the pyramid so far.
        for obj, stem in ((Zs, 'Zs'), (Gs, 'Gs'), (reals, 'reals'), (NoiseAmp, 'NoiseAmp')):
            torch.save(obj, f'{opt.out_}/{stem}.pth')

        last_nfc = opt.nfc
        # Drop references so the networks can be garbage-collected.
        del D_curr, G_curr
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the multi-scale pyramid; in inpainting mode also build a mask pyramid.

    When ``opt.mode == "inpainting"``:
    - The mask image ``opt.mask_name`` is read from ``opt.input_dir``
      (prefixed by ``opt.on_drive`` when that is not None).
    - It is mapped so that masked-out pixels become 0 and all others 1.
    - It is resized by ``opt.scale1`` and stored as a pyramid in
      ``opt.masks``.

    Then, for each scale:
    - Build fresh models, loading the previous scale's ``netG.pth`` /
      ``netD.pth`` when the width ``opt.nfc`` is unchanged.
    - Train with ``train_single_scale`` and freeze both networks.
    - Append the generator, its noise map and ``opt.noise_amp`` to
      ``Gs``/``Zs``/``NoiseAmp``.
    - Save the pyramid state under ``opt.out_``.

    Mutates ``opt`` and the lists in place; returns None.
    """
    real_ = functions.read_image(opt)
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)

    if opt.mode == "inpainting":
        if opt.on_drive is not None:
            mask_path = '%s/%s/%s' % (opt.on_drive, opt.input_dir, opt.mask_name)
        else:
            mask_path = '%s/%s' % (opt.input_dir, opt.mask_name)
        # Expected as an (H, W, 3) array with values in [0, 255].
        # NOTE(review): if `img` is matplotlib.image, PNG files load as floats
        # in [0, 1] rather than [0, 255] — confirm the mask file format.
        mask = img.imread(mask_path)
        # Map to [0, 1]: 0 is the masked-out area, 1 everywhere else.
        mask = 1 - (mask / 255)
        # (H, W, C) -> (1, C, H, W). This must be a permute: a plain
        # .view([1, 3, H, W]) reinterprets the channel-interleaved memory and
        # scrambles the mask spatially.
        mask = torch.from_numpy(mask).permute(2, 0, 1).unsqueeze(0).contiguous()
        mask = imresize(mask, opt.scale1, opt)
        opt.masks = functions.creat_reals_pyramid(mask, [], opt)

    in_s = 0
    nfc_prev = 0
    scale_num = 0
    while scale_num < opt.stop_scale + 1:
        # Filter counts double every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        os.makedirs(opt.outf, exist_ok=True)

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            # Same architecture as the previous scale: start from its weights.
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the multi-scale pyramid while recording memory use and time per scale.

    Follows the usual per-scale procedure, with one difference:
    ``train_single_scale`` here returns two extra values (``mbs``,
    ``percent``). These are collected per scale, together with the wall-clock
    ``datetime.timedelta`` each call took. The elapsed time is printed after
    every scale, and the collected records are printed at the end.

    NOTE(review): ``full_memory`` and ``full_time`` are not defined in this
    function. They are saved every scale and printed at the end, so they must
    be module-level globals (presumably filled elsewhere), or this raises
    NameError on the first scale — confirm.
    """
    raw = functions.read_image(opt)
    pyramid_base = imresize(raw, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(pyramid_base, reals, opt)

    carried = 0
    previous_nfc = 0
    memory_log = []  # [mbs, percent] per scale, as returned by train_single_scale
    durations = []   # timedelta spent inside train_single_scale, per scale

    idx = 0
    while idx < opt.stop_scale + 1:
        factor = 2 ** (idx // 4)
        opt.nfc = min(opt.nfc_init * factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, idx)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[idx]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if previous_nfc == opt.nfc:
            earlier = '%s/%d' % (opt.out_, idx - 1)
            G_curr.load_state_dict(torch.load('%s/netG.pth' % earlier))
            D_curr.load_state_dict(torch.load('%s/netD.pth' % earlier))

        t0 = datetime.datetime.now()
        z_curr, carried, G_curr, mbs, percent = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, carried, NoiseAmp, opt)
        memory_log.append([mbs, percent])
        elapsed = datetime.datetime.now() - t0
        durations.append(elapsed)
        print(f'time: {elapsed}')

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        for obj, stem in ((Zs, 'Zs'), (Gs, 'Gs'), (reals, 'reals'),
                          (NoiseAmp, 'NoiseAmp'), (full_memory, 'full_memory'),
                          (full_time, 'full_time')):
            torch.save(obj, '%s/%s.pth' % (opt.out_, stem))

        idx += 1
        previous_nfc = opt.nfc
        del D_curr, G_curr

    print(memory_log)
    print(durations)
    print(full_memory)
    print(full_time)
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the multi-scale pyramid, optionally resuming at ``opt.scale_num``.

    Progress is stored on ``opt`` (``opt.scale_num``, ``opt.nfc_prev``), so a
    later call can continue where an earlier one stopped.

    EXPERIMENTAL continue mode (``opt.scale_num > 0``):
    - The caller-supplied ``reals`` are used as-is.
    - Training starts from an all-zero image shaped like ``reals[0]``.

    Otherwise:
    - Training starts at scale 0.
    - The reals pyramid is built from the input image resized by
      ``opt.scale1``.

    Each scale:
    - Builds fresh models, loading the previous scale's ``netG.pth`` /
      ``netD.pth`` when the width is unchanged.
    - Trains with ``train_single_scale`` and freezes both networks.
    - Appends the generator, noise map and ``opt.noise_amp`` to the pyramid
      lists and saves them under ``opt.out_``.
    """
    print('train() current parameters')
    print(opt)
    real_ = functions.read_image(opt)
    in_s = 0
    if 'scale_num' in opt and opt.scale_num > 0:
        # EXPERIMENTAL: 'continue' mode. Explicit float zeros: torch.full with
        # an integer fill value infers an integer dtype in recent PyTorch.
        in_s = torch.zeros(reals[0].shape, device=opt.device)
    else:
        opt.scale_num = 0
        real = imresize(real_, opt.scale1, opt)
        reals = functions.create_reals_pyramid(real, reals, opt)
    if 'nfc_prev' not in opt:
        opt.nfc_prev = 0

    while opt.scale_num < opt.stop_scale + 1:
        # Filter counts double every 4 scales, capped at 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(opt.scale_num / 4)), 128)
        opt.min_nfc = min(
            opt.min_nfc_init * pow(2, math.floor(opt.scale_num / 4)), 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, opt.scale_num)
        try:
            os.makedirs(opt.outf)
        except FileExistsError:
            # Only an existing directory is expected here; other OS errors
            # (e.g. permissions) propagate instead of being misreported.
            print('directory %s already exists' % opt.outf)
        plt.imsave('%s/real_scale.png' % opt.outf,
                   functions.convert_image_np(reals[opt.scale_num]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if opt.nfc_prev == opt.nfc:
            # Same architecture as the previous scale: start from its weights.
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, opt.scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, opt.scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % opt.out_)
        torch.save(Gs, '%s/Gs.pth' % opt.out_)
        torch.save(reals, '%s/reals.pth' % opt.out_)
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % opt.out_)

        opt.scale_num += 1
        opt.nfc_prev = opt.nfc
        del D_curr, G_curr
    return
def _prunable_generator_layers(G):
    """Return the generator layers (head to tail) whose weight and bias get pruned."""
    return [
        G.head.conv, G.head.norm,
        G.body.block1.conv, G.body.block1.norm,
        G.body.block2.conv, G.body.block2.norm,
        G.body.block3.conv, G.body.block3.norm,
        G.tail[0],
    ]


def _save_weight_histogram(modules, path):
    """Save a 100-bin histogram of the non-zero weights and biases of ``modules`` to ``path``."""
    # Concatenate on whatever device the parameters live on (no hard-coded CUDA).
    values = torch.cat(
        [p.detach().flatten() for m in modules for p in (m.weight, m.bias)])
    values = values.cpu().numpy()
    # Draw on a dedicated figure so nothing already open is overlaid or closed.
    fig, ax = plt.subplots()
    ax.hist(values[values != 0], bins=100)
    fig.savefig(path)
    plt.close(fig)


def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the multi-scale pyramid, pruning and half-casting each generator.

    Each scale, coarsest to finest:
    - Builds fresh models, loading the previous scale's ``netG.pth`` /
      ``netD.pth`` when the width ``opt.nfc`` is unchanged.
    - Trains them with ``train_single_scale`` and freezes the generator.
    - Globally prunes 20% of the generator's conv/norm weights and biases by
      L1 magnitude and makes the pruning permanent.
    - Saves weight histograms before (``ori.png``) and after (``prune.png``)
      pruning in ``opt.outf``.
    - Casts the generator to fp16 and appends it, its noise map and
      ``opt.noise_amp`` to ``Gs``/``Zs``/``NoiseAmp``.
    - Saves the pyramid (generators as ``pruned_Gs.pth``) under ``opt.out_``.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    # NOTE(review): the raw image is used without imresize(real_, opt.scale1, opt)
    # before building the pyramid — confirm this is intended.
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    nfc_prev = 0
    # cur_scale_level: current level, from coarsest to finest (stop_scale included).
    cur_scale_level = 0
    while cur_scale_level < opt.stop_scale + 1:
        # nfc: number of output channels per conv block; doubles every 4 levels, max 128.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(
            opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

        # out_: output directory; outf: per-level output folder.
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        os.makedirs(opt.outf, exist_ok=True)
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[cur_scale_level]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        # The architecture changes when nfc changes, so weights are only reused at equal width.
        if nfc_prev == opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, cur_scale_level - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, cur_scale_level - 1)))

        z_curr, in_s, G_curr = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()

        # Global L1 pruning over every (layer, weight/bias) pair, made permanent.
        layers = _prunable_generator_layers(G_curr)
        _save_weight_histogram(layers, '%s/ori.png' % opt.outf)
        prune.global_unstructured(
            tuple((layer, name) for layer in layers for name in ('weight', 'bias')),
            pruning_method=prune.L1Unstructured,
            amount=0.2,
        )
        for layer in layers:
            prune.remove(layer, 'weight')
            prune.remove(layer, 'bias')
        _save_weight_histogram(layers, '%s/prune.png' % opt.outf)

        # NOTE(review): the fp16 generator is appended to Gs, and Gs is passed
        # to train_single_scale for the next level. If that path feeds float32
        # tensors through earlier generators, the conv layers will reject the
        # dtype mismatch — confirm it casts inputs.
        G_curr.half()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        cur_scale_level += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
        torch.cuda.empty_cache()
    return