def train_paint(opt, Gs, Zs, reals, NoiseAmp, centers, paint_inject_scale):
    in_s = torch.full(reals[0].shape, 0, device=opt.device)
    cur_scale_level = 0
    nfc_prev = 0
    while cur_scale_level < opt.stop_scale + 1:
        if cur_scale_level != paint_inject_scale:
            cur_scale_level += 1
            nfc_prev = opt.nfc
            continue
        else:
            opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
            opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

            opt.out_ = functions.generate_dir2save(opt)
            opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
            try:
                os.makedirs(opt.outf)
            except OSError:
                pass

            #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
            #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
            plt.imsave('%s/in_scale.png' % (opt.outf),
                       functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)

            D_curr, G_curr = init_models(opt)

            z_curr, in_s, G_curr = train_single_scale(
                D_curr, G_curr, reals[:cur_scale_level + 1],
                Gs[:cur_scale_level], Zs[:cur_scale_level], in_s,
                NoiseAmp[:cur_scale_level], opt, centers=centers)

            G_curr = functions.reset_grads(G_curr, False)
            G_curr.eval()
            D_curr = functions.reset_grads(D_curr, False)
            D_curr.eval()

            Gs[cur_scale_level] = G_curr
            Zs[cur_scale_level] = z_curr
            NoiseAmp[cur_scale_level] = opt.noise_amp

            torch.save(Zs, '%s/Zs.pth' % (opt.out_))
            torch.save(Gs, '%s/Gs.pth' % (opt.out_))
            torch.save(reals, '%s/reals.pth' % (opt.out_))
            torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

            cur_scale_level += 1
            nfc_prev = opt.nfc
        del D_curr, G_curr
    return
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # generate_noise(size, num_samp=1, device='cuda', type='gaussian', scale=1)
    # size: [opt.nc_z, opt.nzx, opt.nzy]
    # z_opt: input noise used for the reconstruction loss of the generator
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizers
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    # TODO: these plots are not visualized
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        start_time = time.time()
        if (Gs == []) & (opt.mode != 'SR_train'):
            # Bottom generator, here without zero init
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            # noise_: input noise for the discriminator
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)
            #D_real_map = output.detach()
            errD_real = -output.mean()  #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                # Initialize prev and z_prev
                # prev: image output from the previous level for random noise
                # z_prev: image output from the previous level for the fixed noise z_opt
                if (Gs == []) & (opt.mode != 'SR_train'):
                    # in_s, prev and z_prev are all zero tensors at the coarsest scale
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            #D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch > 1 and (epoch % 100 == 0 or epoch == (opt.niter - 1)):
            total_time = time.time() - start_time
            start_time = time.time()
            print('scale %d:[%d/%d], total time: %f' % (len(Gs), epoch, opt.niter, total_time))
            memory = torch.cuda.max_memory_allocated()
            # print('allocated memory: %dG %dM %dk %d' %
            #       (memory // (1024*1024*1024),
            #        (memory // (1024*1024)) % 1024,
            #        (memory // 1024) % 1024,
            #        memory % 1024))
            print('allocated memory: %.03f GB' % (memory / (1024 * 1024 * 1024 * 1.0)))

        # if epoch % 500 == 0 or epoch == (opt.niter-1):
        if epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            # plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            # plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            # plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            plt.imsave('%s/prev_plus_noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
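# For reference, a condensed sketch (not part of the code above) of the losses
# optimized per scale by train_single_scale: the critic follows the WGAN-GP form
# (errD_real + errD_fake + gradient_penalty), and the generator adds an
# alpha-weighted reconstruction term on the fixed noise z_opt. Variable names
# mirror the function above; this is only an illustration.
#
#   D loss:  E[D(G(z))] - E[D(x)] + lambda_grad * GP(D)
#   G loss: -E[D(G(z))] + alpha * || G(noise_amp * z_opt + z_prev, z_prev) - x ||_2^2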
def train(opt, Gs, Zs, reals, NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    # cur_scale_level: current level, from coarsest to finest
    cur_scale_level = 0
    # scale1: ratio of the largest patch size w.r.t. the image shape
    reals = functions.creat_reals_pyramid(real_, reals, opt)
    nfc_prev = 0

    # Train up to and including opt.stop_scale
    while cur_scale_level < opt.stop_scale + 1:
        # nfc: number of output channels in the conv blocks
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(cur_scale_level / 4)), 128)

        # out_: output directory
        # outf: output folder for the current scale
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, cur_scale_level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[cur_scale_level]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        # Note: as the level increases, the CNN block architecture can change
        # (every 4 levels, according to the paper), so weights are only reused
        # when the channel count is unchanged.
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, cur_scale_level - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, cur_scale_level - 1)))

        # in_s: initial signal; it does not change during training and is a zero tensor
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()

        #################################################################################
        # Visualize weights
        def visualize_weights(modules, fig_name):
            ori_weights = torch.tensor([]).cuda()
            for m in modules:
                cur_params = m.weight.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
                cur_params = m.bias.data.flatten()
                ori_weights = torch.cat((ori_weights, cur_params))
            # sparsity = torch.sum(ori_weights == 0) * 1.0 / (ori_weights.nelement())
            ori_weights = ori_weights.cpu().numpy()
            plt.hist(ori_weights[ori_weights != 0], bins=100)
            plt.savefig("%s/%s.png" % (opt.outf, fig_name))
            plt.close()

        # Prune all weights
        modules = [
            G_curr.head.conv, G_curr.head.norm,
            G_curr.body.block1.conv, G_curr.body.block1.norm,
            G_curr.body.block2.conv, G_curr.body.block2.norm,
            G_curr.body.block3.conv, G_curr.body.block3.norm,
            G_curr.tail[0]
        ]
        parameters_to_prune = (
            (G_curr.head.conv, 'weight'), (G_curr.head.conv, 'bias'),
            (G_curr.head.norm, 'weight'), (G_curr.head.norm, 'bias'),
            (G_curr.body.block1.conv, 'weight'), (G_curr.body.block1.conv, 'bias'),
            (G_curr.body.block1.norm, 'weight'), (G_curr.body.block1.norm, 'bias'),
            (G_curr.body.block2.conv, 'weight'), (G_curr.body.block2.conv, 'bias'),
            (G_curr.body.block2.norm, 'weight'), (G_curr.body.block2.norm, 'bias'),
            (G_curr.body.block3.conv, 'weight'), (G_curr.body.block3.conv, 'bias'),
            (G_curr.body.block3.norm, 'weight'), (G_curr.body.block3.norm, 'bias'),
            (G_curr.tail[0], 'weight'), (G_curr.tail[0], 'bias'))

        visualize_weights(modules, 'ori')

        # Prune weights
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=0.2,
        )
        for m in modules:
            prune.remove(m, 'weight')
            prune.remove(m, 'bias')

        visualize_weights(modules, 'prune')

        G_curr.half()
        #################################################################################

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/pruned_Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        cur_scale_level += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
        torch.cuda.empty_cache()
    return
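# A minimal sketch (hypothetical helper, not called by train() above): once
# prune.remove() has been applied, the pruned entries are plain zeros in the
# weight/bias tensors, so global sparsity can be checked directly. `modules`
# is assumed to be the same list of generator layers built in train().
def report_sparsity(modules):
    total, zeros = 0, 0
    for m in modules:
        for t in (m.weight.data, m.bias.data):
            total += t.nelement()
            zeros += int(torch.sum(t == 0).item())
    print('global sparsity: %.2f%% (%d / %d zeroed parameters)' %
          (100.0 * zeros / max(total, 1), zeros, total))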
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    real = reals[len(Gs)]
    #if opt.input_type == 'audio':
    #    real = real.permute((0, 2, 1))
    print("@ train_single_scale: real.shape =", real.shape, "| opt.mode =", opt.mode)
    if opt.input_type == 'image':
        opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
        opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer)
    else:
        opt.nzx = real.shape[1]  #+(opt.ker_size-1)*(opt.num_layer)
        opt.nzy = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    if opt.input_type == 'image':
        m_noise = nn.ZeroPad2d(int(pad_noise))
        m_image = nn.ZeroPad2d(int(pad_image))
    else:
        m_noise = nn.ConstantPad1d(int(pad_noise), 0)
        m_image = nn.ConstantPad1d(int(pad_image), 0)
    print("m_noise")

    alpha = opt.alpha

    if opt.input_type == 'image':
        fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    else:
        fixed_noise = functions.generate_noise([opt.nzx, opt.nzy], device=opt.device)
    print("fixed_noise.shape =", fixed_noise.shape)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    #z_opt = torch.full(fixed_noise.shape, 0, device=opt.device, dtype=int)
    z_opt = m_noise(z_opt)

    # setup optimizers
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            if opt.input_type == 'image':
                z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
                if opt.conv_spectrogram == True:
                    z_opt = m_noise(z_opt.expand(1, 2, opt.nzx, opt.nzy))
                else:
                    z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            else:
                z_opt = functions.generate_noise([opt.nzx, opt.nzy], device=opt.device)
                z_opt = m_noise(z_opt)
            if opt.input_type == 'image':
                noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
                if opt.conv_spectrogram == True:
                    noise_ = m_noise(noise_.expand(1, 2, opt.nzx, opt.nzy))
                else:
                    noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
            else:
                noise_ = functions.generate_noise([opt.nzx, opt.nzy], device=opt.device)
                noise_ = m_noise(noise_)
        else:
            if opt.input_type == 'image':
                noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
                noise_ = m_noise(noise_)
            else:
                noise_ = functions.generate_noise([opt.nzx, opt.nzy], device=opt.device)
                noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            if epoch % 100 == 0:
                print("@ train_single_scale: epoch =", epoch, "| real.shape =", real.shape)
            # if opt.input_type == 'audio':
            #     real = real.permute((0, 2, 1))
            #     print("@ train_single_scale: real.shape =", real.shape)
            output = netD(real).to(opt.device)
            # if opt.input_type == 'audio':
            #     real = real.permute((0, 2, 1))
            #D_real_map = output.detach()
            errD_real = -output.mean()  #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    if opt.input_type == 'image':
                        prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    else:
                        prev = torch.full([1, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    if opt.input_type == 'image':
                        z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    else:
                        z_prev = torch.full([1, opt.nzx, opt.nzy], 0, device=opt.device)
                    print("@ train_single_scale: z_prev.shape =", z_prev.shape)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    print('@ train_single_scale: real.shape =', real.shape, '| z_prev.shape =', z_prev.shape)
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    print("@ train_single_scale: RMSE.shape =", RMSE.shape)
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    print("@ train_single_scale: z_prev.shape =", z_prev.shape)
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    print('@ train_single_scale: real.shape =', real.shape, '| z_prev.shape =', z_prev.shape)
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            if epoch % 100 == 0:
                print("@ train_single_scale: noise.detach().shape =", noise.detach().shape,
                      "prev.shape =", prev.shape, "epoch =", epoch, "j =", j)
            # if opt.input_type == 'audio':
            #     noise = noise.permute((0, 2, 1))
            fake = netG(noise.detach(), prev)
            # if opt.input_type == 'audio':
            #     noise = noise.permute((0, 2, 1))
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            #D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            if opt.input_type == 'image':
                if opt.conv_spectrogram == True:
                    write('%s/fake_sample.wav' % (opt.outf), opt.sample_rate,
                          functions.convert_spectrogram_np(fake.detach(), opt))
                    write('%s/G(z_opt).wav' % (opt.outf), opt.sample_rate,
                          functions.convert_spectrogram_np(netG(Z_opt.detach(), z_prev).detach(), opt))
                else:
                    plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
                    plt.imsave('%s/G(z_opt).png' % (opt.outf),
                               functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
                    #plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
                    #plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
                    #plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
                    #plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
                    #plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
                    #plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            else:
                write('%s/fake_sample.wav' % (opt.outf), opt.sample_rate,
                      functions.convert_audio_np(fake.detach(), opt))
                write('%s/G(z_opt).wav' % (opt.outf), opt.sample_rate,
                      functions.convert_audio_np(netG(Z_opt.detach(), z_prev).detach(), opt))
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    #if opt.input_type == 'audio':
    #    real = real.permute((0, 2, 1))
    return z_opt, in_s, netG
in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
#opt.gen_start_scale = 0
#print(in_s.shape)
#in_s = torch.full(reals[0].shape, 0, device=opt.device)
#in_s[0, :, :20, :] = 0.2
if opt.quantization_flag:
    opt.mode = 'paint_train'
    dir2trained_model = functions.generate_dir2save(opt)
    # N = len(reals) - 1
    # n = opt.paint_start_scale
    real_s = imresize(real, pow(opt.scale_factor, (N - n)), opt)
    real_s = real_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
    real_quant, centers = functions.quant(real_s, opt.device)
    plt.imsave('%s/real_quant.png' % dir2save, functions.convert_image_np(real_quant), vmin=0, vmax=1)
    plt.imsave('%s/in_paint.png' % dir2save, functions.convert_image_np(in_s), vmin=0, vmax=1)
    in_s = functions.quant2centers(ref, centers)
    in_s = imresize(in_s, pow(opt.scale_factor, (N - n)), opt)
    # in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
    # in_s = imresize(in_s, 1 / opt.scale_factor, opt)
    in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
    plt.imsave('%s/in_paint_quant.png' % dir2save, functions.convert_image_np(in_s), vmin=0, vmax=1)
def SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1, scale_h=1, n=0, gen_start_scale=0, num_samples=50):
    #if torch.is_tensor(in_s) == False:
    if in_s is None:
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        nzh = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzw = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(0, num_samples, 1):
            if n == 0:
                z_curr = functions.generate_noise([1, nzh, nzw], device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzh, nzw], device=opt.device)
                z_curr = m(z_curr)

            if images_prev == []:
                I_prev = m(in_s)
                # I_prev = m(I_prev)
                # I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                # I_prev = functions.upsampling(I_prev, z_curr.shape[2], z_curr.shape[3])
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]), 0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2], z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt

            # print('z_curr:', z_curr.size())
            # print('I_prev:', I_prev.size())
            z_in = noise_amp * (z_curr) + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if n == len(reals) - 1:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                try:
                    os.makedirs(dir2save)
                except OSError:
                    pass
                if (opt.mode != "harmonization") & (opt.mode != "editing") & (opt.mode != "SR") & (opt.mode != "paint2image"):
                    plt.imsave('%s/%d.png' % (dir2save, i), functions.convert_image_np(I_curr.detach()), vmin=0, vmax=1)
                    #plt.imsave('%s/%d_%d.png' % (dir2save, i, n), functions.convert_image_np(I_curr.detach()), vmin=0, vmax=1)
                    #plt.imsave('%s/in_s.png' % (dir2save), functions.convert_image_np(in_s), vmin=0, vmax=1)

            images_cur.append(I_curr)
        n += 1
    return I_curr.detach()
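# A minimal usage sketch (assumes `opt` comes from the project's option parser
# and that the pyramid artifacts were saved by the train() functions above under
# opt.out_). It reloads the per-scale generators, fixed noise maps, real pyramid
# and noise amplitudes, then draws random samples at the finest scale.
Gs = torch.load('%s/Gs.pth' % opt.out_)
Zs = torch.load('%s/Zs.pth' % opt.out_)
reals = torch.load('%s/reals.pth' % opt.out_)
NoiseAmp = torch.load('%s/NoiseAmp.pth' % opt.out_)
SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, num_samples=10)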
def train(opt, Gs, Zs, reals, NoiseAmp):
    if opt.input_type == 'image':
        real_ = functions.read_image(opt)
    else:
        real_ = functions.read_audio(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        if opt.input_type == 'image':
            if opt.conv_spectrogram == True:
                write('%s/real_scale.wav' % (opt.outf), opt.sample_rate,
                      functions.convert_spectrogram_np(reals[scale_num], opt))
            else:
                plt.imsave('%s/real_scale.png' % (opt.outf),
                           functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
        else:
            write('%s/real_scale.wav' % (opt.outf), opt.sample_rate,
                  functions.convert_audio_np(reals[scale_num], opt))

        D_curr, G_curr = init_models(opt)
        if (nfc_prev == opt.nfc):
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)

        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
N = len(reals) - 1
n = opt.paint_start_scale
in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt)
in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
in_s = imresize(in_s, 1 / opt.scale_factor, opt)
in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
if opt.quantization_flag:
    opt.mode = 'paint_train'
    dir2trained_model = functions.generate_dir2save(opt)
    # N = len(reals) - 1
    # n = opt.paint_start_scale
    real_s = imresize(real, pow(opt.scale_factor, (N - n)), opt)
    real_s = real_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
    real_quant, centers = functions.quant(real_s)
    plt.imsave('%s/real_quant.png' % dir2save, functions.convert_image_np(real_quant), vmin=0, vmax=1)
    plt.imsave('%s/in_paint.png' % dir2save, functions.convert_image_np(in_s), vmin=0, vmax=1)
    in_s = functions.quant2centers(ref, centers)
    in_s = imresize(in_s, pow(opt.scale_factor, (N - n)), opt)
    # in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
    # in_s = imresize(in_s, 1 / opt.scale_factor, opt)
    in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
    plt.imsave('%s/in_paint_quant.png' % dir2save, functions.convert_image_np(in_s), vmin=0, vmax=1)
    if (os.path.exists(dir2trained_model)):
        # a trained paint model already exists; reuse it instead of retraining
        Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
        opt.mode = 'paint2image'
    else:
        train_paint(opt, Gs, Zs, reals, NoiseAmp, centers, opt.paint_start_scale)
        opt.mode = 'paint2image'
out = SinGAN_generate(Gs[n:], Zs[n:], reals, NoiseAmp[n:], opt, in_s, n=n, num_samples=1)
mask = functions.read_image_dir('%s/%s_mask%s' % (opt.ref_dir, opt.ref_name[:-4], opt.ref_name[-4:]), opt)
if ref.shape[3] != real.shape[3]:
    mask = imresize_to_shape(mask, [real.shape[2], real.shape[3]], opt)
    mask = mask[:, :, :real.shape[2], :real.shape[3]]
    ref = imresize_to_shape(ref, [real.shape[2], real.shape[3]], opt)
    ref = ref[:, :, :real.shape[2], :real.shape[3]]

mask = functions.dilate_mask(mask, opt)

N = len(reals) - 1
n = opt.harmonization_start_scale
in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt)
in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
in_s = imresize(in_s, 1 / opt.scale_factor, opt)
in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]

out = SinGAN_generate(Gs[n:], Zs[n:], reals, NoiseAmp[n:], opt, in_s, n=n, num_samples=1)
out = (1 - mask) * real + mask * out
plt.imsave('%s/start_scale=%d.png' % (dir2save, opt.harmonization_start_scale),
           functions.convert_image_np(out.detach()), vmin=0, vmax=1)
ref = functions.read_image_dir('%s/%s' % (opt.ref_dir, opt.ref_name), opt)
mask = functions.read_image_dir('%s/%s_mask%s' % (opt.ref_dir, opt.ref_name[:-4], opt.ref_name[-4:]), opt)
if ref.shape[3] != real.shape[3]:
    '''
    mask = imresize(mask, real.shape[3] / ref.shape[3], opt)
    mask = mask[:, :, :real.shape[2], :real.shape[3]]
    ref = imresize(ref, real.shape[3] / ref.shape[3], opt)
    ref = ref[:, :, :real.shape[2], :real.shape[3]]
    '''
    mask = imresize_to_shape(mask, [real.shape[2], real.shape[3]], opt)
    mask = mask[:, :, :real.shape[2], :real.shape[3]]
    ref = imresize_to_shape(ref, [real.shape[2], real.shape[3]], opt)
    ref = ref[:, :, :real.shape[2], :real.shape[3]]

mask = functions.dilate_mask(mask, opt)

N = len(reals) - 1
n = opt.editing_start_scale
in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt)
in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
in_s = imresize(in_s, 1 / opt.scale_factor, opt)
in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]

out = SinGAN_generate(Gs[n:], Zs[n:], reals, NoiseAmp[n:], opt, in_s, n=n, num_samples=1)
plt.imsave('%s/start_scale=%d.png' % (dir2save, opt.editing_start_scale),
           functions.convert_image_np(out.detach()), vmin=0, vmax=1)
out = (1 - mask) * real + mask * out
plt.imsave('%s/start_scale=%d_masked.png' % (dir2save, opt.editing_start_scale),
           functions.convert_image_np(out.detach()), vmin=0, vmax=1)
def train(opt, Gs, Zs, reals, NoiseAmp):
    real_ = functions.read_images(opt)
    in_s = 0
    scale_num = 0
    real = [imresize(real_[0], opt.scale1, opt), imresize(real_[1], opt.scale1, opt)]
    reals = [
        functions.creat_reals_pyramid(real[0], reals, opt),
        functions.creat_reals_pyramid(real[1], reals, opt)
    ]
    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        #plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        #plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale1.png' % (opt.outf),
                   functions.convert_image_np(reals[0][scale_num]), vmin=0, vmax=1)
        plt.imsave('%s/real_scale2.png' % (opt.outf),
                   functions.convert_image_np(reals[1][scale_num]), vmin=0, vmax=1)

        D_curr1, G_curr1 = init_models(opt)
        D_curr2, G_curr2 = init_models(opt)
        D_curr = [D_curr1, D_curr2]
        G_curr = [G_curr1, G_curr2]
        if (nfc_prev == opt.nfc):
            G_curr[0].load_state_dict(torch.load('%s/%d/netG1.pth' % (opt.out_, scale_num - 1)))
            D_curr[0].load_state_dict(torch.load('%s/%d/netD1.pth' % (opt.out_, scale_num - 1)))
            G_curr[1].load_state_dict(torch.load('%s/%d/netG2.pth' % (opt.out_, scale_num - 1)))
            D_curr[1].load_state_dict(torch.load('%s/%d/netD2.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)

        G_curr[0] = functions.reset_grads(G_curr[0], False)
        G_curr[0].eval()
        D_curr[0] = functions.reset_grads(D_curr[0], False)
        D_curr[0].eval()
        G_curr[1] = functions.reset_grads(G_curr[1], False)
        G_curr[1].eval()
        D_curr[1] = functions.reset_grads(D_curr[1], False)
        D_curr[1].eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return
def SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1, scale_h=1, n=0, gen_start_scale=0, num_samples=50):
    if in_s is None:
        # torch uses NCHW; tf uses NHWC
        in_s = tf.zeros_like(reals[0])
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = tf.keras.layers.ZeroPadding2D(padding=(int(pad1), int(pad1)))
        nzx = (Z_opt.shape[1] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[2] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(0, num_samples, 1):
            if n == 0:
                z_curr = functions.generate_noise([1, nzx, nzy])
                z_curr = tf.tile(z_curr, multiples=(1, 1, 1, 3))
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy])
                z_curr = m(z_curr)

            if images_prev == []:
                I_prev = m(in_s)
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, 0:round(scale_v * reals[n].shape[1]), 0:round(scale_h * reals[n].shape[2]), :]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, 0:z_curr.shape[1], 0:z_curr.shape[2], :]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[1], z_curr.shape[2])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt

            z_in = noise_amp * (z_curr) + I_prev
            I_curr = G(z_in, I_prev, train=True)

            if n == len(reals) - 1:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                try:
                    os.makedirs(dir2save)
                except OSError:
                    pass
                if (opt.mode != "harmonization") & (opt.mode != "editing") & (opt.mode != "SR") & (opt.mode != "paint2image"):
                    plt.imsave('%s/%d.png' % (dir2save, i), functions.convert_image_np(I_curr))

            images_cur.append(I_curr)
        n += 1
    return I_curr