def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train the discriminator/generator pair of one SinGAN scale (WGAN-GP).

    Profiling variant: logs wall-clock deltas via ``datetime`` and GPU memory
    via ``nvidia_smi`` (NVML) at the end of training.

    Args:
        netD, netG: discriminator and generator for the current scale.
        reals: image pyramid; ``reals[len(Gs)]`` is this scale's real image.
        Gs, Zs: generators and fixed noise maps of the coarser, frozen scales.
        in_s: seed image for the coarsest scale's recursion (created here on
            the first scale, then passed through unchanged).
        NoiseAmp: per-scale noise amplitudes of the coarser scales.
        opt: hyper-parameter namespace; mutated in place (``nzx``, ``nzy``,
            ``receptive_field``, ``noise_amp``).
        centers: color-palette centers, used only in ``paint_train`` mode.

    Returns:
        (z_opt, in_s, netG, mbs, percent): fixed reconstruction noise, seed
        image, trained generator, and final NVML memory usage (MiB used,
        used/total fraction).

    NOTE(review): relies on module-level ``timestamp``, ``full_memory``,
    ``full_time`` and ``memory_check()`` not visible in this chunk — confirm
    ``timestamp`` holds at least one entry before the first logging epoch,
    otherwise ``timestamp[-2]`` raises IndexError.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2] #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3] #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Animation mode keeps the generated canvas larger than the real image
        # and drops the noise padding.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))
    alpha = opt.alpha  # weight of the reconstruction loss

    # z_opt: fixed noise used for the generator's reconstruction loss.
    # NOTE(review): torch.full with fill value 0 and no dtype infers an
    # integer tensor on recent PyTorch — confirm this still runs on the
    # pinned torch version, or that functions.generate_noise dtype is float.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    # Per-epoch loss traces (collected but not plotted in this chunk).
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            # Coarsest scale: resample z_opt and the discriminator input noise
            # each epoch as single-channel noise expanded to 3 channels.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            #D_real_map = output.detach()
            errD_real = -output.mean() #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                # First iteration only: initialise prev (upsampled random
                # output of coarser scales) and z_prev (reconstruction path).
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    # Scale injected noise by the residual error of the
                    # coarser reconstruction.
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            #D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                # Reconstruction loss: G(z_opt) should reproduce the real image.
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            # Time/memory logging every 25 epochs.
            stamp = datetime.datetime.now()
            timestamp.append(stamp)
            delta = (timestamp[-1] - timestamp[-2]).seconds
            mbs, percent = memory_check()
            print('scale %d:[%d/%d] | Mb: %.3f | Percent %.3f | secs %.3f ' %
                  (len(Gs), epoch, opt.niter, mbs, 100 * percent, delta))
            full_memory.append([mbs, percent])
            full_time.append(delta)

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            #plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            #plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            #plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            #plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            #plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            #plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)

    # Final NVML memory report.
    nvidia_smi.nvmlInit()
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
    # card id 0 hardcoded here, there is also a call to get all available card ids, so we could iterate
    mem_res = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    mbs = mem_res.used / (1024**2)
    percent = mem_res.used / mem_res.total
    # NOTE(review): the printed value is in MiB despite the '(GiB)' label
    # (mem_res.used is bytes; dividing by 1024**2 yields MiB).
    print(f'mem: {mem_res.used / (1024**2)} (GiB)')  # usage in GiB
    print(f'mem: {100 * (mem_res.used / mem_res.total):.6f}%')  # percentage
    return z_opt, in_s, netG, mbs, percent
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train the discriminator/generator pair of one SinGAN scale (WGAN-GP).

    Profiling variant: logs per-interval wall-clock time via ``time.time()``
    and CUDA peak memory via ``torch.cuda.max_memory_allocated()``.

    Args:
        netD, netG: discriminator and generator for the current scale.
        reals: image pyramid; ``reals[len(Gs)]`` is this scale's real image.
        Gs, Zs: generators and fixed noise maps of the coarser, frozen scales.
        in_s: seed image for the coarsest scale's recursion.
        NoiseAmp: per-scale noise amplitudes of the coarser scales.
        opt: hyper-parameter namespace; mutated in place (``nzx``, ``nzy``,
            ``receptive_field``, ``noise_amp``).
        centers: color-palette centers, used only in ``paint_train`` mode.

    Returns:
        (z_opt, in_s, netG): fixed reconstruction noise, seed image, and the
        trained generator.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2] #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3] #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))
    alpha = opt.alpha

    # generate_noise(size,num_samp=1,device='cuda',type='gaussian', scale=1)
    # size: [opt.nc_z, opt.nzx, opt.nzy]
    # z_opt: input noise for calculating reconstruction loss in Generator
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    # TODO: these plots are not visualized
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        start_time = time.time()
        if (Gs == []) & (opt.mode != 'SR_train'):
            # Bottom generator, here without zero init
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            # noise_: input noise for the discriminator
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            #D_real_map = output.detach()
            errD_real = -output.mean() #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                # Initialize prev and z_prev
                # prev: image outputs from previous level
                # z_prev: image outputs from previous level of fixed noise z_opt
                if (Gs == []) & (opt.mode != 'SR_train'):
                    # in_s and prev are both noise
                    # z_prev are also noise
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            #D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                # Reconstruction loss on the fixed-noise path.
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch > 1 and (epoch % 100 == 0 or epoch == (opt.niter - 1)):
            # NOTE(review): start_time is re-assigned at the top of every
            # epoch, so this inner reset is effectively dead code and
            # total_time measures only the current epoch.
            total_time = time.time() - start_time
            start_time = time.time()
            print('scale %d:[%d/%d], total time: %f' % (len(Gs), epoch, opt.niter, total_time))
            memory = torch.cuda.max_memory_allocated()
            # print('allocated memory: %dG %dM %dk %d' %
            #       ( memory // (1024*1024*1024),
            #         (memory // (1024*1024)) % 1024,
            #         (memory // 1024) % 1024,
            #         memory % 1024 ))
            print('allocated memory: %.03f GB' % (memory / (1024 * 1024 * 1024 * 1.0)))

        # if epoch % 500 == 0 or epoch == (opt.niter-1):
        if epoch == (opt.niter - 1):
            # Dump final samples and intermediate tensors as images.
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            # plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            # plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            # plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            plt.imsave('%s/prev_plus_noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
def train_single_scale(netD, netD_mask1, netD_mask2, netG, reals1, reals2, Gs, Zs, in_s1, in_s2, NoiseAmp, opt):
    """Train one scale of a dual-image SinGAN variant with optional mask critics.

    A single generator is trained against two real images (real1, real2) using
    two separate noise sources (NoiseMode.Z1 / NoiseMode.Z2), plus two extra
    discriminators that judge masked regions when ``opt.enable_mask`` is set.

    Args:
        netD: shared discriminator for both full images.
        netD_mask1, netD_mask2: discriminators for mask1/mask2 regions.
        netG: generator for the current scale.
        reals1, reals2: image pyramids of the two source images (assumed the
            same size per scale, see comment below).
        Gs, Zs: generators and fixed noise maps of coarser, frozen scales.
        in_s1, in_s2: recursion seed images for each branch.
        NoiseAmp: per-scale noise amplitudes of coarser scales.
        opt: hyper-parameter namespace; mutated in place (nzx/nzy/...).

    Returns:
        ((z_opt1, z_opt2), (in_s1, in_s2), netG).

    NOTE(review): helper semantics (_create_noise_for_iteration,
    _prepare_discriminator_train_with_fake_input, _generate_fake, ...) are
    defined elsewhere in this module — behavior described here is inferred
    from call sites only.
    """
    real1 = reals1[len(Gs)]
    real2 = reals2[len(Gs)]
    if opt.replace_background:
        # Paste image 2's foreground over image 1's synthesised background so
        # both reals share a background distribution.
        background_real1 = create_background(functions.convert_image_np(real1))
        real2 = create_img_over_background(functions.convert_image_np(real2), background_real1)
        plt.imsave('%s/background_real_scale1.png' % (opt.outf), background_real1, vmin=0, vmax=1)
        plt.imsave('%s/real_scale2_new.png' % (opt.outf), real2, vmin=0, vmax=1)
        real2 = functions.np2torch(real2, opt)

    # assumption: the images are the same size
    opt.nzx = real1.shape[2]#+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real1.shape[3]#+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size-1)*(opt.num_layer-1))*opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))
    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)
    # NOTE(review): l1_loss and zero_mask_tensor are created but not used in
    # this function body — possibly consumed by helpers or left over.
    l1_loss = nn.L1Loss()
    zero_mask_tensor = torch.zeros([1, opt.nzx, opt.nzy], device=opt.device)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay_d)
    optimizerD_masked1 = optim.Adam(netD_mask1.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay_d_mask1)
    optimizerD_masked2 = optim.Adam(netD_mask2.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay_d_mask2)
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))

    if opt.cyclic_lr:
        # NOTE(review): schedulerG cycles on opt.lr_d (not opt.lr_g) — verify
        # this is intentional. Also note the double-underscore typo in
        # schedulerD__masked2 (kept as-is).
        schedulerD = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizerD, base_lr=opt.lr_d*opt.gamma, max_lr=opt.lr_d,
                                                       step_size_up=opt.niter/10, mode="triangular2", cycle_momentum=False)
        schedulerD_masked1 = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizerD_masked1, base_lr=opt.lr_d*opt.gamma, max_lr=opt.lr_d,
                                                               step_size_up=opt.niter/10, mode="triangular2", cycle_momentum=False)
        schedulerD__masked2 = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizerD_masked2, base_lr=opt.lr_d*opt.gamma, max_lr=opt.lr_d,
                                                                step_size_up=opt.niter/10, mode="triangular2", cycle_momentum=False)
        schedulerG = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizerG, base_lr=opt.lr_d*opt.gamma, max_lr=opt.lr_d,
                                                       step_size_up=opt.niter/10, mode="triangular2", cycle_momentum=False)
    else:
        schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
        schedulerD_masked1 = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD_masked1, milestones=[1600], gamma=opt.gamma)
        schedulerD__masked2 = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD_masked2, milestones=[1600], gamma=opt.gamma)
        schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    discriminators = [netD, netD_mask1, netD_mask2]
    discriminators_optimizers = [optimizerD, optimizerD_masked1, optimizerD_masked2]
    discriminators_schedulers = [schedulerD, schedulerD_masked1, schedulerD__masked2]

    # Loss traces for plot_learning_curves below.
    # NOTE(review): l1_mask_loss2plot and mask_loss2plot are never appended to.
    err_D_img1_2plot = []
    err_D_img2_2plot = []
    err_D_mask1_2plot = []
    err_D_mask2_2plot = []
    errG_total_loss_2plot = []
    errG_total_loss1_2plot = []
    errG_total_loss2_2plot = []
    errG_fake1_2plot = []
    errG_fake2_2plot = []
    D1_real2plot = []
    D2_real2plot = []
    D1_fake2plot = []
    D2_fake2plot = []
    l1_mask_loss2plot = []
    mask_loss2plot = []
    reconstruction_loss1_2plot = []
    reconstruction_loss2_2plot = []

    for epoch in range(opt.niter):
        """ We want to ensure that there exists a specific set of input noise maps, which generates the original image x.
            We specifically choose {z*, 0, 0, ..., 0}, where z* is some fixed noise map.
            In the first scale, we create that z* (aka z_opt).
            On other scales, z_opt is just zeros (initialized above) """
        is_first_scale = len(Gs) == 0
        noise1_, z_opt1 = _create_noise_for_iteration(is_first_scale, m_noise, opt, z_opt, NoiseMode.Z1)
        noise2_, z_opt2 = _create_noise_for_iteration(is_first_scale, m_noise, opt, z_opt, NoiseMode.Z2)

        ############################
        # (1) Update D networks:
        # - netD: train with real on 2 input images (real1, real2)
        # - netD: train with fake on 2 fake images from different noise source (NoiseMode.Z1, NoiseMode.Z2)
        # if opt.enable_mask is ON, then:
        # - netD_mask1: train with real on real1
        # - netD_mask2: train with real on real2
        # - netD_mask1: train with fake on the generated fake image with mask1 applied on it
        # - netD_mask2: train with fake on the generated fake image with mask2 applied on it
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            for discriminator in discriminators:
                discriminator.zero_grad()
            errD_real1, D_x1 = discriminator_train_with_real(netD, opt, real1)
            errD_real2, D_x2 = discriminator_train_with_real(netD, opt, real2)
            # single discriminator for each image
            if opt.enable_mask:
                errD_mask1_real1, _ = discriminator_train_with_real(netD_mask1, opt, real1)
                errD_mask2_real2, _ = discriminator_train_with_real(netD_mask2, opt, real2)

            # train with fake
            in_s1, noise1, prev1, new_z_prev1 = _prepare_discriminator_train_with_fake_input(
                Gs, NoiseAmp, Zs, epoch, in_s1, is_first_scale, j, m_image, m_noise, noise1_, opt, real1, reals1, NoiseMode.Z1)
            in_s2, noise2, prev2, new_z_prev2 = _prepare_discriminator_train_with_fake_input(
                Gs, NoiseAmp, Zs, epoch, in_s2, is_first_scale, j, m_image, m_noise, noise2_, opt, real2, reals2, NoiseMode.Z2)
            # Helpers return a fresh z_prev only on the first D step of the
            # first epoch; keep the cached one otherwise.
            if new_z_prev1 is not None:
                z_prev1 = new_z_prev1
            if new_z_prev2 is not None:
                z_prev2 = new_z_prev2

            # Z1 only:
            mixed_noise1 = functions.merge_noise_vectors(noise1, torch.zeros(noise1.shape, device=opt.device),
                                                         opt.noise_vectors_merge_method)
            fake1_output = _generate_fake(netG, mixed_noise1, prev1)
            if opt.enable_mask:
                fake1, fake1_mask1, fake1_mask2 = fake1_output
            else:
                fake1 = fake1_output[0]
            D_G_z_1, errD_fake1, gradient_penalty1 = _train_discriminator_with_fake(netD, fake1, opt, real1)

            # Z2 only:
            mixed_noise2 = functions.merge_noise_vectors(torch.zeros(noise2.shape, device=opt.device), noise2,
                                                         opt.noise_vectors_merge_method)
            fake2_output = _generate_fake(netG, mixed_noise2, prev2)
            if opt.enable_mask:
                fake2, fake2_mask1, fake2_mask2 = fake2_output
            else:
                fake2 = fake2_output[0]
            D_G_z_2, errD_fake2, gradient_penalty2 = _train_discriminator_with_fake(netD, fake2, opt, real2)

            if opt.enable_mask:
                _, errD_mask1_fake1, _ = _train_discriminator_with_fake(netD_mask1, fake1_mask1, opt, real1)
                _, errD_mask2_fake2, _ = _train_discriminator_with_fake(netD_mask2, fake2_mask2, opt, real2)

            errD_image1 = errD_real1 + errD_fake1 + gradient_penalty1
            errD_image2 = errD_real2 + errD_fake2 + gradient_penalty2
            for discriminator_optimizer in discriminators_optimizers:
                discriminator_optimizer.step()

        err_D_img1_2plot.append(errD_image1.detach())
        err_D_img2_2plot.append(errD_image2.detach())

        ############################
        # (2) Update G network:
        # - netG: train with fake on 2 fake images from different noise source (NoiseMode.Z1, NoiseMode.Z2) against netD
        # - netG: reconstruction loss against 2 real images
        # if opt.enable_mask is ON, then:
        # - netD_mask1: train with fake on the generated fake image with mask1 applied on it against netD_mask1
        # - netD_mask2: train with fake on the generated fake image with mask2 applied on it against netD_mask2
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            errG_fake1, D_fake1_map = _generator_train_with_fake(fake1, netD)
            errG_fake2, D_fake2_map = _generator_train_with_fake(fake2, netD)
            rec_loss1, Z_opt1 = _reconstruction_loss(alpha, netG, opt, z_opt1, z_prev1, real1, NoiseMode.Z1, opt.noise_amp1)
            rec_loss2, Z_opt2 = _reconstruction_loss(alpha, netG, opt, z_opt2, z_prev2, real2, NoiseMode.Z2, opt.noise_amp2)
            if opt.enable_mask:
                mask_loss_fake1_mask1, D_mask1_fake1_mask1_map = _generator_train_with_fake(fake1_mask1, netD_mask1)
                mask_loss_fake2_mask1, D_mask1_fake2_mask1_map = _generator_train_with_fake(fake2_mask1, netD_mask1)
                mask_loss_fake1_mask2, D_mask2_fake1_mask2_map = _generator_train_with_fake(fake1_mask2, netD_mask2)
                mask_loss_fake2_mask2, D_mask2_fake2_mask2_map = _generator_train_with_fake(fake2_mask2, netD_mask2)
            optimizerG.step()

        errG_total_loss1_2plot.append(errG_fake1.detach() + rec_loss1)
        errG_total_loss2_2plot.append(errG_fake2.detach() + rec_loss2)
        errG_fake1_2plot.append(errG_fake1.detach())
        errG_fake2_2plot.append(errG_fake2.detach())
        G_total_loss = errG_fake1.detach() + rec_loss1 + errG_fake2.detach() + rec_loss2
        errG_total_loss_2plot.append(G_total_loss)
        D1_real2plot.append(D_x1)
        D2_real2plot.append(D_x2)
        D1_fake2plot.append(D_G_z_1)
        D2_fake2plot.append(D_G_z_2)
        reconstruction_loss1_2plot.append(rec_loss1)
        reconstruction_loss2_2plot.append(rec_loss2)

        if epoch % 25 == 0 or epoch == (opt.niter-1):
            logger.info('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter-1):
            plt.imsave('%s/fake_sample1.png' % (opt.outf), functions.convert_image_np(fake1.detach()), vmin=0, vmax=1)
            plt.imsave('%s/fake_sample2.png' % (opt.outf), functions.convert_image_np(fake2.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt1).png' % (opt.outf), functions.convert_image_np(netG(Z_opt1.detach(), z_prev1)[0].detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt2).png' % (opt.outf), functions.convert_image_np(netG(Z_opt2.detach(), z_prev2)[0].detach()), vmin=0, vmax=1)
            # torch.save((mixed_noise1, prev1), '%s/fake1_noise_source.pth' % (opt.outf))
            # torch.save((mixed_noise2, prev2), '%s/fake2_noise_source.pth' % (opt.outf))
            # torch.save((Z_opt1, z_prev1), '%s/G(z_opt1)_noise_source.pth' % (opt.outf))
            # torch.save((Z_opt2, z_prev2), '%s/G(z_opt2)_noise_source.pth' % (opt.outf))
            torch.save(z_opt1, '%s/z_opt1.pth' % (opt.outf))
            torch.save(z_opt2, '%s/z_opt2.pth' % (opt.outf))

        if epoch == (opt.niter-1) and opt.enable_mask:
            # Final epoch only: dump per-mask fakes and discriminator maps.
            plt.imsave('%s/fake1_mask1.png' % (opt.outf), functions.convert_image_np(fake1_mask1.detach()), vmin=0, vmax=1)
            plt.imsave('%s/fake2_mask1.png' % (opt.outf), functions.convert_image_np(fake2_mask1.detach()), vmin=0, vmax=1)
            plt.imsave('%s/fake1_mask2.png' % (opt.outf), functions.convert_image_np(fake1_mask2.detach()), vmin=0, vmax=1)
            plt.imsave('%s/fake2_mask2.png' % (opt.outf), functions.convert_image_np(fake2_mask2.detach()), vmin=0, vmax=1)
            _imsave_discriminator_map(D_fake1_map, "D_fake1_map", opt)
            _imsave_discriminator_map(D_fake2_map, "D_fake2_map", opt)
            _imsave_discriminator_map(D_mask1_fake1_mask1_map, "D_mask1_fake1_mask1_map", opt)
            _imsave_discriminator_map(D_mask1_fake2_mask1_map, "D_mask1_fake2_mask1_map", opt)
            _imsave_discriminator_map(D_mask2_fake1_mask2_map, "D_mask2_fake1_mask2_map", opt)
            _imsave_discriminator_map(D_mask2_fake2_mask2_map, "D_mask2_fake2_mask2_map", opt)

        for discriminator_scheduler in discriminators_schedulers:
            discriminator_scheduler.step()
        schedulerG.step()

    functions.save_networks(netG, netD, netD_mask1, netD_mask2, z_opt1, z_opt2, opt)
    functions.plot_learning_curves("G_loss", opt.niter,
                                   [errG_total_loss_2plot, errG_total_loss1_2plot, errG_total_loss2_2plot,
                                    errG_fake1_2plot, errG_fake2_2plot,
                                    reconstruction_loss1_2plot, reconstruction_loss2_2plot],
                                   ["G_total_loss", "G_total_loss1", "G_total_loss2",
                                    "G_fake1_loss", "G_fake2_loss",
                                    "G_recon_loss_1", "G_recon_loss_2"], opt.outf)
    d_plots = [err_D_img1_2plot, err_D_img2_2plot]
    d_labels = ["D1_total_loss", "D2_total_loss"]
    functions.plot_learning_curves("D_loss", opt.niter, d_plots, d_labels, opt.outf)
    functions.plot_learning_curves("G_vs_D_loss", opt.niter,
                                   [errG_total_loss_2plot, errG_total_loss1_2plot, errG_total_loss2_2plot,
                                    err_D_img1_2plot, err_D_img2_2plot],
                                   ["G_total_loss", "G_total_loss1", "G_total_loss2",
                                    "D1_total_loss", "D2_total_loss"], opt.outf)
    return (z_opt1, z_opt2), (in_s1, in_s2), netG
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train the discriminator/generator pair of one SinGAN scale (WGAN-GP).

    Classic variant: no profiling; learning-rate schedulers are stepped at the
    start of each epoch.

    Args:
        netD, netG: discriminator and generator for the current scale.
        reals: image pyramid; ``reals[len(Gs)]`` is this scale's real image.
        Gs, Zs: generators and fixed noise maps of the coarser, frozen scales.
        in_s: seed image for the coarsest scale's recursion.
        NoiseAmp: per-scale noise amplitudes of the coarser scales.
        opt: hyper-parameter namespace; mutated in place (``nzx``, ``nzy``,
            ``receptive_field``, ``noise_amp``).
        centers: color-palette centers, used only in ``paint_train`` mode.

    Returns:
        (z_opt, in_s, netG): fixed reconstruction noise, seed image, and the
        trained generator.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2] #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3] #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))
    alpha = opt.alpha

    # z_opt: fixed noise for the generator's reconstruction loss.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    # Per-epoch loss traces (collected but not plotted in this chunk).
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        # NOTE(review): stepping schedulers before the first optimizer.step()
        # triggers a warning on PyTorch >= 1.1 and shifts the LR decay by one
        # epoch; kept as-is to match original behavior.
        schedulerD.step()
        schedulerG.step()

        if (Gs == []) & (opt.mode != 'SR_train'):
            # Coarsest scale: resample z_opt and discriminator noise per epoch.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy])
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            #D_real_map = output.detach()
            errD_real = -output.mean() #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                # First iteration: initialise prev (random path) and z_prev
                # (reconstruction path) from the coarser scales.
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            #D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                # Reconstruction loss on the fixed-noise path.
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            #plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            #plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            #plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            #plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            #plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            #plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train the discriminator/generator pair of one SinGAN scale.

    Args:
        netD, netG: discriminator / generator for the current scale.
        reals: image pyramid, one tensor per scale; ``reals[len(Gs)]`` is
            the real image at the scale being trained.
        Gs, Zs: trained generators / reconstruction noise maps of the
            coarser scales.
        in_s: seed image used at the coarsest scale (set here on scale 0).
        NoiseAmp: per-scale noise-amplitude multipliers.
        opt: hyper-parameter namespace; mutated in place (``nzx``, ``nzy``,
            ``receptive_field``, ``noise_amp``).
        centers: colour-cluster centres, used only in 'paint_train' mode.

    Returns:
        (z_opt, in_s, netG): reconstruction noise map for this scale, the
        (possibly initialised) seed image, and the trained generator.
    """
    print("len(Gs):", len(Gs))
    real = reals[len(Gs)]  # real image at the current scale
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    print(opt.receptive_field)
    # pad so the generator output matches the real image size
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    print(pad_noise)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(int(pad_noise))
    print('m_noise', m_noise)
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha  # weight of the reconstruction (MSE) loss
    print(alpha)

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
    # 0.0 (not the int 0): torch.full without an explicit dtype infers the
    # dtype from the fill value, and an integer fill is rejected/deprecated
    # for float tensors in recent PyTorch releases.
    z_opt = torch.full(fixed_noise.shape, 0.0, device=opt.device)
    z_opt = m_noise(z_opt)

    # optimizers and LR schedulers
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    # loss histories for plotting
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            # coarsest scale: draw single-channel noise and broadcast it to
            # 3 channels so all RGB channels share the same values
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy])
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            errD_real = -output.mean()  # WGAN critic: maximise score on real
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake; z_prev and noise_amp are set once, on the
            # very first D step of the scale
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0.0, device=opt.device)
                    in_s = prev  # in_s is kept unpadded
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0.0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    # RMSE between real and the upsampled reconstruction
                    # calibrates the noise amplitude for this scale
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # detach: no gradient flows back through the noise input
            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            # WGAN-GP gradient penalty
            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            errG = -output.mean()  # adversarial generator loss
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 100 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            plt.imsave('%s/noise.png' % (opt.outf),
                       functions.convert_image_np(noise), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        # Step the LR schedulers AFTER this epoch's optimizer updates:
        # PyTorch >= 1.1 requires optimizer.step() before lr_scheduler.step().
        # (The original stepped them at the top of the epoch.)
        schedulerD.step()
        schedulerG.step()

    # Convert the collected losses to numpy once, after training, instead of
    # rebuilding the whole list inside every epoch (was O(niter^2)).
    errDint = [e.cpu().numpy() for e in errD2plot]
    errGint = [e.cpu().numpy() for e in errG2plot]

    # plot learning curves for the scales that are tracked
    if len(Gs) == 0:
        functions.plot_learning_curves(errGint, errDint, opt.niter,
                                       'Generator', 'Discriminator', 'G_loos & G_loos')
    elif len(Gs) == 1:
        functions.plot_learning_curves_1(errGint, errDint, opt.niter,
                                         'Generator', 'Discriminator', 'G_loos & G_loos_1')
    elif len(Gs) == 8:
        functions.plot_learning_curves_8(errGint, errDint, opt.niter,
                                         'Generator', 'Discriminator', 'G_loos & G_loos_8')

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
def train_single_scale(netD, netG, reals, crops, masks, eye_color, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train one SinGAN scale on randomly-cropped patches with mask/eye compositing.

    Differs from plain SinGAN training in that the generator output is
    composited onto the real image through a mask (functions.gen_fake) and
    the generator input is the noise stacked with the mask and a coloured
    eye mask (functions.make_input).

    Args:
        netD, netG: discriminator / generator for the current scale.
        reals: real image pyramid (one tensor per scale).
        crops: per-scale crop tensors; only the spatial size of the current
            one is used to pick the training crop size.
        masks: per-scale mask tensors.
        eye_color: RGB triple used to tint the eye mask (0-255 range,
            divided by 255 below); may be replaced per epoch when
            opt.random_eye_color is set.
        Gs, Zs, in_s, NoiseAmp, opt, centers: as in the other variants.

    Returns:
        (z_opt, in_s, netG).
    """
    real_fullsize = reals[len(Gs)]
    crop_size = crops[len(Gs)].size()[2]
    # fixed top-left crop; only referenced by a commented-out rec_loss variant
    fixed_crop = real_fullsize[:, :, 0:crop_size, 0:crop_size]
    if opt.random_crop:
        real, h_idx, w_idx = functions.random_crop(real_fullsize.clone(), crop_size)
    else:
        real = real_fullsize.clone()
    mask = masks[len(Gs)]
    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer) width
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer) height
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # animation mode generates noise on the border instead of padding
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha  # reconstruction-loss weight
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    # all-zero tensor shaped like the noise; overwritten below at the
    # coarsest scale
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    # loss histories for plotting
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    eye = functions.generate_eye_mask(opt, masks[-1], opt.stop_scale - len(Gs))

    for epoch in range(opt.niter):
        if opt.resize:
            # resize the mask/eye patch to a random size within +-25% of the
            # mask size (clamped to the image)
            max_patch_size = int(
                min(real.size()[2], real.size()[3], mask.size()[2] * 1.25))
            min_patch_size = int(max(mask.size()[2] * 0.75, 1))
            patch_size = random.randint(min_patch_size, max_patch_size)
            mask_in = nn.functional.interpolate(mask.clone(), size=patch_size)
            eye_in = nn.functional.interpolate(eye.clone(), size=patch_size)
        else:
            mask_in = mask.clone()
            eye_in = eye.clone()
        eye_colored = eye_in.clone()
        if opt.random_eye_color:
            eye_color = functions.get_eye_color(real)
        # tint the eye mask with the (0-255) RGB colour, normalised to [0,1]
        eye_colored[:, 0, :, :] *= (eye_color[0] / 255)
        eye_colored[:, 1, :, :] *= (eye_color[1] / 255)
        eye_colored[:, 2, :, :] *= (eye_color[2] / 255)
        if (Gs == []) & (opt.mode != 'SR_train'):
            # coarsest scale: single-channel noise broadcast to 3 channels
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            #D_real_map = output.detach()
            errD_real = -output.mean()  #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake; z_prev/noise_amp are initialised once on the
            # first D step of the scale
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev  # in_s stays unpadded
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    # RMSE calibrates this scale's noise amplitude
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = functions.draw_concat(Gs, Zs, reals, crops, masks, eye_colored,
                                                 NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = functions.draw_concat(Gs, Zs, reals, crops, masks, eye_colored,
                                                   NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    #print(z_prev.get_device())
                    #print(real.get_device())
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = functions.draw_concat(Gs, Zs, reals, crops, masks, eye_colored,
                                             NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # Stacking masks and noise to make input
            G_input = functions.make_input(noise, mask_in, eye_colored)
            fake_background = netG(G_input.detach(), prev)
            # NOTE(review): deepcopy on every D step is expensive; the last
            # copy (pre G-update snapshot of the final epoch) is returned
            # below when this is the last scale — presumably intentional,
            # confirm before hoisting.
            import copy
            netG_copy = copy.deepcopy(netG)
            # Cropping mask shape from generated image and putting on top of real image at random location
            fake, fake_ind, eye_ind = functions.gen_fake(
                real, fake_background, mask_in, eye_in, eye_color, opt)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            # WGAN-GP gradient penalty
            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()
            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            #D_fake_map = output.detach()
            errG = -output.mean()  # adversarial generator loss
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                # reconstruction input also carries the (uncoloured) eye mask
                input_opt = functions.make_input(Z_opt, mask_in, eye_in)
                rec_loss = alpha * loss(netG(input_opt.detach(), z_prev), real)
                #rec_loss = alpha*loss(netG(input_opt.detach(),z_prev),fixed_crop)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()))
            plt.imsave('%s/fake_indicator.png' % (opt.outf),
                       functions.convert_image_np(fake_ind.detach()))
            plt.imsave('%s/eye_indicator.png' % (opt.outf),
                       functions.convert_image_np(eye_ind.detach()))
            plt.imsave('%s/background.png' % (opt.outf),
                       functions.convert_image_np(fake_background.detach()))
            #plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(netG(input_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            #plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            #plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            #plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            #plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            #plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            #plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

        if opt.random_crop:
            real, h_idx, w_idx = functions.random_crop(
                real_fullsize, crop_size)  #randomly find crop in image
        if opt.random_eye:
            eye = functions.generate_eye_mask(opt, masks[-1],
                                              opt.stop_scale - len(Gs))

    functions.save_networks(netG, netD, z_opt, opt)
    # at the last scale, return the snapshot taken before the final G update
    if len(Gs) == (opt.stop_scale):
        netG = netG_copy
    return z_opt, in_s, netG
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, scale_num, netG_optimizer, netD_optimizer):
    """Train one SinGAN scale — TensorFlow 2 (GradientTape) port.

    Args:
        netD, netG: Keras discriminator / generator for the current scale.
        reals: real image pyramid (NHWC tensors, one per scale).
        Gs, Zs: generators / reconstruction noise maps of coarser scales.
        in_s: seed image for the coarsest scale (set here on scale 0).
        NoiseAmp: per-scale noise-amplitude multipliers.
        opt: hyper-parameter namespace; mutated in place (nzx, nzy,
            receptive_field, noise_amp).
        scale_num: index of the current scale, forwarded to save_networks.
        netG_optimizer, netD_optimizer: pre-built tf optimizers.

    Returns:
        (z_opt, in_s, netG).
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[1]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    # pad so the generator output matches the real image size
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = tf.keras.layers.ZeroPadding2D(padding=(int(pad_noise), int(pad_noise)))
    m_image = tf.keras.layers.ZeroPadding2D(padding=(int(pad_image), int(pad_image)))

    alpha = opt.alpha  # reconstruction-loss weight
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
    z_opt = tf.zeros_like(fixed_noise)
    z_opt = m_noise(z_opt)

    # loss histories for plotting
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        # persistent tapes: gradients are taken once per D/G step below
        with tf.GradientTape(persistent=True) as netD_tape, tf.GradientTape(persistent=True) as netG_tape:
            if (Gs == []) & (opt.mode != 'SR_train'):
                # coarsest scale: single-channel noise broadcast to 3 channels
                z_opt = functions.generate_noise([1, opt.nzx, opt.nzy])  # (1,33,25)
                z_opt = tf.broadcast_to(z_opt, [1, z_opt.shape[1], z_opt.shape[2], 3])  # (1,33,25,3)
                z_opt = m_noise(z_opt)
                noise_ = functions.generate_noise([1, opt.nzx, opt.nzy])
                noise_ = tf.broadcast_to(noise_, [1, noise_.shape[1], noise_.shape[2], 3])
                noise_ = m_noise(noise_)
            else:
                noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
                noise_ = m_noise(noise_)

            ############################
            # (1) Update D network: maximize D(x) + D(G(z))
            ###########################
            for j in range(opt.Dsteps):
                real = reals[len(Gs)]
                output_real = netD(real)
                errD_real = -tf.reduce_mean(output_real)  # WGAN critic on real
                D_x = float(-errD_real.numpy())  # Conversion to numpy required

                # train with fake; z_prev/noise_amp initialised once per scale
                if (j == 0) & (epoch == 0):
                    if (Gs == []) & (opt.mode != 'SR_train'):
                        prev = tf.zeros([1, opt.nzx, opt.nzy, opt.nc_z])
                        in_s = prev  # in_s stays unpadded
                        prev = m_image(prev)
                        z_prev = tf.zeros([1, opt.nzx, opt.nzy, opt.nc_z])
                        z_prev = m_noise(z_prev)
                        opt.noise_amp = 1
                    else:
                        prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                        prev = m_image(prev)
                        z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                        # RMSE calibrates this scale's noise amplitude
                        RMSE = tf.sqrt(tf.reduce_mean(tf.square(tf.subtract(real, z_prev))))
                        opt.noise_amp = opt.noise_amp_init * RMSE
                        z_prev = m_image(z_prev)
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)

                if (Gs == []) & (opt.mode != 'SR_train'):
                    noise = noise_
                else:
                    noise = opt.noise_amp * noise_ + prev

                fake = netG(noise, prev, training=True)
                output_fake = netD(fake)
                errD_fake = tf.reduce_mean(output_fake)
                D_G_z = float(output_fake.numpy().mean())

                # WGAN-GP penalty on a random real/fake interpolate
                fake_gp = functions.fake_gp_generator(real, fake)
                with tf.GradientTape() as gp_tape:
                    gp_tape.watch(fake_gp)
                    gp_D_src = netD(fake_gp)
                gp_D_grad = gp_tape.gradient(gp_D_src, fake_gp)
                # NOTE(review): the norm is taken over the channel axis only
                # (axis=3); classic WGAN-GP norms over all non-batch axes —
                # confirm this is intentional.
                gp = opt.lambda_grad * tf.reduce_mean(((tf.norm(gp_D_grad, ord=2, axis=3) - 1.0) ** 2))

                errD = errD_real + errD_fake + gp
                print('errD_real:', errD_real)
                print('errD_fake:', errD_fake)

                netD_gradients = netD_tape.gradient(errD, netD.trainable_variables)
                netD_optimizer.apply_gradients(zip(netD_gradients, netD.trainable_variables))

            ############################
            # (2) Update G network: maximize D(G(z))
            ###########################
            for j in range(opt.Gsteps):
                errG_fake = -tf.reduce_mean(output_fake)  # adversarial part
                if alpha != 0:
                    Z_opt = opt.noise_amp * z_opt + z_prev
                    rec_loss = alpha * tf.reduce_mean(
                        tf.square(tf.subtract(netG(Z_opt, z_prev, training=True), real)))
                else:
                    Z_opt = z_opt
                    rec_loss = 0
                errG = errG_fake + rec_loss
                print('errG_fake:', errG_fake)
                print('rec_loss:', rec_loss)
                # only netG is updated here
                netG_gradients = netG_tape.gradient(errG, netG.trainable_variables)
                netG_optimizer.apply_gradients(zip(netG_gradients, netG.trainable_variables))

        del netG_tape, netD_tape  # release the persistent tapes

        # BUGFIX: errG already includes rec_loss; the original appended
        # errG + rec_loss, double-counting the reconstruction term (the
        # PyTorch variants record adversarial + reconstruction exactly once).
        errG2plot.append(errG)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 1 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake))
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(netG(Z_opt, z_prev, training=False)))

    # save at the end of every single-scale training
    functions.save_networks(netD, netG, z_opt, opt, scale_num)
    return z_opt, in_s, netG
def train_single_scale(netD, netG, reals, masks, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train one SinGAN scale with a mask-weighted discriminator loss.

    The discriminator's patch map is multiplied by a cropped version of the
    per-scale mask, so only masked patches contribute to the real-image
    critic loss; the reconstruction loss is likewise restricted to the
    masked region. ``norm[opt.norm]`` selects between no rescaling (index 0)
    and rescaling by the inverse masked fraction (index 1).

    Args:
        netD, netG: discriminator / generator for the current scale.
        reals, masks: per-scale real images and masks.
        Gs, Zs, in_s, NoiseAmp, opt, centers: as in the other variants.

    Returns:
        (z_opt, in_s, netG).
    """
    real = reals[len(Gs)]
    mask = masks[len(Gs)]
    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # animation mode generates noise on the border instead of padding
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha  # reconstruction-loss weight

    # Here we calculate the decrease in output size due to conv layers.
    # required for setting the size of discriminators map.
    # TODO: currently, calculation doesn't consider opt.stride
    if (opt.ker_size % 2 == 0):
        r = (opt.num_layer) * (opt.ker_size - 1)
    else:  #(opt.ker_size % 2 != 0):
        r = (opt.num_layer) * (opt.ker_size - 1) / 2
    r = int(r)
    # crop the mask border by r on each side so it matches the size of the
    # discriminator's output map, keep a single channel
    _, _, h, w = mask.size()
    discriminators_mask = mask.detach()[:, :, r:h - r,
                                        r:w - r][:, 0, :, :].unsqueeze(0)
    _, _, h, w = discriminators_mask.size()

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    # loss histories, periodically pickled below for offline plotting
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    # norm[0] = 1 (no rescaling); norm[1] = inverse of the masked fraction,
    # compensating for the mean being taken over the whole map
    norm = []
    norm.append(1)
    norm.append((h * w) / discriminators_mask.sum().item())

    plt.imsave('%s/mask.png' % (opt.outf),
               functions.convert_image_np(real * mask))

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            if (epoch == 0):
                # fixed reconstruction noise, drawn once at the coarsest scale
                z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
                z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real; only masked patches contribute to the loss.
            # NOTE(review): the fake branch below is NOT mask-weighted —
            # confirm this asymmetry is intentional.
            netD.zero_grad()
            output = netD(real).to(opt.device)
            output = output * discriminators_mask
            D_real_map = output.detach()
            errD_real = -(output.mean()) * norm[opt.norm]  #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake; z_prev/noise_amp initialised once per scale
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev  # in_s stays unpadded
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    # RMSE calibrates this scale's noise amplitude
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # detach: no gradient flows back through the noise input
            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            # WGAN-GP penalty, weighted by the (rescaled) discriminator mask
            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device,
                discriminators_mask * norm[opt.norm])
            gradient_penalty.backward()
            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            D_fake_map = output.detach()
            errG = -output.mean()  # adversarial generator loss
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                # reconstruction loss restricted to the masked region.
                # NOTE(review): `real` is rebound to `real * mask` here, so
                # from the next iteration on the discriminator also sees the
                # masked image (idempotent only if the mask is binary) —
                # confirm this is intended.
                netG_out = netG(Z_opt.detach(), z_prev)
                netG_out = netG_out * mask
                real = real * mask
                rec_loss = alpha * loss(netG_out, real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errD2plot.append(errD.detach())
        errG2plot.append(errG.detach())
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            plt.imsave('%s/D_fake.png' % (opt.outf),
                       functions.convert_image_np(D_fake_map))
            plt.imsave('%s/D_real.png' % (opt.outf),
                       functions.convert_image_np(D_real_map))
            #plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            #plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            #plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            #plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))
            # dump the loss histories for offline analysis
            with open('%s/D_real2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(D_real2plot, fp)
            with open('%s/D_fake2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(D_fake2plot, fp)
            # with open('%s/D_real2plot' % (opt.outf), "rb") as fp:  # Unpickling
            #     D_fake2plot = pickle.load(fp)
            with open('%s/errD2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(errD2plot, fp)
            with open('%s/errG2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(errG2plot, fp)
            with open('%s/z_opt2plot' % (opt.outf), "wb") as fp:  # Pickling
                pickle.dump(z_opt2plot, fp)

        schedulerD.step()
        schedulerG.step()

    # plt.imsave('%s/masked_img.png' % (opt.outf), functions.convert_image_np(real*mask))
    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train the generator/discriminator pair of one pyramid scale (WGAN-GP).

    This is the core to understand it all.

    Args:
        netD: discriminator for the current scale.
        netG: generator for the current scale.
        reals: list of the real image resized to every scale; the one at
            index ``len(Gs)`` is used here.
        Gs: frozen generators of the coarser scales.
        Zs: fixed reconstruction noise maps of the coarser scales.
        in_s: input tensor of the coarsest scale (initialised here at scale 0).
        NoiseAmp: noise-amplitude multipliers of the coarser scales.
        opt: hyper-parameter namespace; mutated in place
            (``nzx``, ``nzy``, ``receptive_field``, ``noise_amp``).
        centers: palette centers, only used when ``opt.mode == 'paint_train'``.

    Returns:
        (z_opt, in_s, netG): the fixed reconstruction noise of this scale,
        the (possibly newly created) coarsest input, and the trained generator.
    """
    real = reals[len(Gs)]  # real image at the current scale
    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Supplementary says they generate noise on the border in animation mode
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha  # weight of the reconstruction loss

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy
                                            ])  # select a fixed noise at start
    # BUGFIX: fill with float 0. — an int fill makes torch.full infer int64 on
    # recent PyTorch, and the long tensor would later break float arithmetic.
    z_opt = torch.full(
        fixed_noise.shape, 0.,
        device=opt.device)  # they didn't use fixed_noise, just all zeros
    z_opt = m_noise(z_opt)  # zero-padded (still all 0)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    errD2plot = []  # collect the error in D training
    errG2plot = []  # collect the error in G training
    D_real2plot = []  # collect D_x error for real image
    D_fake2plot = []  # Discriminator error for fake image
    z_opt2plot = []  # collect reconstruction error

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):  # if it's the first scale
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(
                1, 3, opt.nzx, opt.nzy))  # chosen at the start of each epoch
            noise_ = functions.generate_noise(
                [1, opt.nzx, opt.nzy], device=opt.device)  # single channel noise
            # expand copies along the channel axis, so all RGB channels share values
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            # for each epoch only one noise_ is drawn and reused in all D steps
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):  # a few D steps first
            # train with real
            netD.zero_grad()
            output = netD(real).to(
                opt.device)  # patch scores; their mean is the image score
            #D_real_map = output.detach()
            errD_real = -output.mean(
            )  #-a  # maximize D output for patches of the real image
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()  # D score for the real image

            # train with fake: build the input through the previous generators
            if (j == 0) & (epoch == 0):  # first D step at this scale
                if (Gs == []) & (opt.mode != 'SR_train'):  # initial scale
                    # BUGFIX: float fill value (see z_opt above)
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                      0.,
                                      device=opt.device)
                    in_s = prev  # in_s doesn't get padded!
                    prev = m_image(prev)  # prev gets padded!
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                        0.,
                                        device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(
                        real, z_prev))  # MSE between real and z_prev
                    opt.noise_amp = opt.noise_amp_init * RMSE  # learn the noise amplitude
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)  # process the in_s
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):  # top level
                noise = noise_
            else:  # other levels: coarser output plus amplified fresh noise
                noise = opt.noise_amp * noise_ + prev

            # generate a single fake image through G and pass it to D
            fake = netG(
                noise.detach(), prev
            )  # netG takes 2 inputs; it processes noise and adds prev
            output = netD(
                fake.detach()
            )  # detached here, so the D error does not backprop into G
            errD_fake = output.mean()  # decrease the D output for fake
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach().item()
                         )  # loss combining D loss for real, fake and gradient

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):  # then a few G steps
            netG.zero_grad()
            output = netD(
                fake)  # note: the same fake tensor is reused across G steps
            #D_fake_map = output.detach()
            errG = -output.mean(
            )  # adversarial part: minimize it to fool the discriminator
            errG.backward(retain_graph=True)
            if alpha != 0:  # compute the reconstruction loss if alpha non-zero
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(
                        z_prev, centers
                    )  # z_prev here is inherited from the D training part
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                # BUGFIX: keep rec_loss a tensor; the original plain int 0
                # crashed on the rec_loss.item() calls below when alpha == 0.
                rec_loss = torch.zeros(())
            optimizerG.step()

        errG2plot.append(errG.detach().item() + rec_loss.item())
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss.item())

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave(
                '%s/noise.png' % (opt.outf),
                functions.convert_image_np(noise),
                vmin=0,
                vmax=1
            )  # this is the noise that goes into the generator: prev + amp * noise_
            plt.imsave('%s/Z_opt.png' % (opt.outf),
                       functions.convert_image_np(Z_opt),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/z_prev.png' % (opt.outf),
                       functions.convert_image_np(z_prev),
                       vmin=0,
                       vmax=1)
            #plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            #plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            #plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            #plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            #plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            #plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(
                {
                    "errD": errD2plot,
                    "errG": errG2plot,
                    "D_real": D_real2plot,
                    "D_fake": D_fake2plot,
                    "recons": z_opt2plot
                }, '%s/loss_trace.pth' % (opt.outf))
            plt.figure()
            plt.plot(errD2plot, label="errD")
            plt.plot(errG2plot, label="errG")
            plt.plot(D_real2plot, label="Dreal")
            plt.plot(D_fake2plot, label="Dfake")
            plt.plot(z_opt2plot, label="recons")
            plt.legend()
            plt.savefig('%s/loss_trace.png' % (opt.outf))
            plt.close()
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()
    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
def train_single_scale(
        netD,  # current discriminator
        netG,  # current generator
        reals,  # the list of all resized data
        Gs,  # generator list (coarser scales, frozen)
        Zs,  # fixed reconstruction noise of coarser scales
        in_s,  # input tensor of the coarsest scale
        NoiseAmp,  # noise-amplitude multipliers of coarser scales
        opt,  # parameters
        centers=None):
    """Train one scale's G/D pair; simplified variant with a tqdm progress bar.

    Unlike the full version, this variant only distinguishes "first scale"
    (``Gs == []``) from "later scales" and has no special training modes.
    Mutates ``opt`` in place (``nzx``, ``nzy``, ``receptive_field``,
    ``noise_amp``) and returns ``(z_opt, in_s, netG)``.
    """
    real = reals[len(Gs)]  # get the current resized real picture
    # spatial size of the noise at this scale
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    # receptive field of the discriminator
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    # padding width
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    # torch.nn modules that add zero padding (tf equivalent is slightly harder)
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))
    # reconstruction-loss weight from opt
    alpha = opt.alpha
    # generate a noise of the following size
    fixed_noise = functions.generate_noise(
        [
            opt.nc_z,  # noise # channels
            opt.nzx,
            opt.nzy
        ],
        device=opt.device)
    # BUGFIX: fill with float 0. — an int fill makes torch.full infer int64 on
    # recent PyTorch; the long tensor would later be fed into netG's convs.
    z_opt = torch.full(fixed_noise.shape, 0., device=opt.device)
    z_opt = m_noise(z_opt)  # zero pad with width int(pad_noise)

    # setup optimizer and learning rate
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    # loss-history lists for plotting
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    # main iteration loop
    for epoch in tqdm_notebook(range(opt.niter),
                               desc=f"scale {len(Gs)}",
                               leave=False):
        # first scale: G needs specially-constructed single-channel noise
        if (Gs == []):
            # generate a noise of size [1,opt.nzx,opt.nzy]
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            # expand to 3 channels (all channels share values), then zero-pad
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            # generate another noise the same way
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        # when it's not the first scale
        else:
            # nc_z is 'noise # channels'
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()  # reset gradients before this step
            output = netD(real).to(opt.device)  # patch scores for real image
            # D minimizes -(D(x) + D(G(z))), so negate the real-image score
            errD_real = -output.mean()  #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()  # scalar D score on the real image

            # train with fake
            # first D step of the first epoch: build prev/z_prev once
            if (j == 0) & (epoch == 0):
                # if it's the first scale
                if (Gs == []):
                    # prev starts as all zeros (BUGFIX: float fill, see above)
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                      0.,
                                      device=opt.device)
                    in_s = prev  # in_s stays unpadded
                    prev = m_image(prev)  # zero pad with width int(pad_image)
                    # z_prev starts as all zeros too
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                        0.,
                                        device=opt.device)
                    z_prev = m_noise(z_prev)  # pad with the noise padding
                    opt.noise_amp = 1  # fixed amplitude at the first scale
                else:
                    # generate prev through the coarser generators ('rand' mode)
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    # generate z_prev through the coarser generators ('rec' mode)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    # noise amplitude is opt.noise_amp_init * reconstruction RMSE
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                # every other step: redraw prev in 'rand' mode
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            # first scale feeds the raw (expanded) noise
            if (Gs == []):
                noise = noise_
            else:
                # later scales: amplified noise on top of the coarser output
                noise = opt.noise_amp * noise_ + prev

            # generate the fake image; detach() cuts the graph so no gradient
            # is backpropagated along the noise input
            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())  # score the fake without touching G
            # fake term of the D loss: positive mean score is penalized
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()  # scalar D score on the fake
            # WGAN-GP gradient penalty
            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()
            # total D loss (for logging only; gradients already accumulated)
            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        # record the D loss (NOTE(review): original collapsed text is ambiguous
        # about loop nesting; placed after the D loop as in upstream SinGAN)
        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            # rescore the same fake; adversarial G loss is the negated score
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                # reconstruction term: G(Z_opt) should reproduce the real image
                loss = nn.MSELoss()
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()  # keep only the value for logging
            else:  # alpha = 0: no reconstruction loss
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        #if epoch % 25 == 0 or epoch == (opt.niter-1): #replaced by tqdm
        #print(f'scale {len(Gs)}:[{epoch}/{opt.niter}]'
        #if epoch % 500 == 0 or epoch == (opt.niter-1):
        if epoch == (opt.niter - 1):  # only saved once (for small graph)
            # save the fake sample
            plt.imsave(f'{opt.outf}/fake_sample.png',
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            # save the reconstruction
            plt.imsave(f'{opt.outf}/G(z_opt).png',
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            # save the fixed reconstruction noise
            torch.save(z_opt, f'{opt.outf}/z_opt.pth')
            #plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            #plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            #plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            #plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            #plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            #plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)

        # update learning rate
        schedulerD.step()
        schedulerG.step()

    # save the model
    functions.save_networks(netG, netD, z_opt, opt)
    # return the fixed noise, the coarsest input, and the trained generator
    return z_opt, in_s, netG
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Experimental two-branch variant: trains a PAIR of G/D networks per scale.

    Here ``netD``/``netG`` are 2-element lists and ``reals`` is a pair of
    per-scale image lists; the second generator refines the output of the
    first, and each discriminator scores the opposite real image.

    NOTE(review): several metrics (D_x, D_G_z, errD) are commented out, and
    list-valued arguments are forwarded to ``draw_concat`` /
    ``calc_gradient_penalty`` — confirm those helpers accept lists before use.
    Returns ``(z_opt, in_s, netG)``.
    """
    # the two real images at the current scale (one per branch)
    real = [reals[0][len(Gs)], reals[1][len(Gs)]]
    print(len(real))  # debug output
    print(real[0].shape)
    opt.nzx = real[0].shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real[0].shape[3]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    # zero-padding widths for noise and image inputs
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha  # reconstruction-loss weight

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    # NOTE(review): int fill value 0 makes torch.full infer int64 on recent
    # PyTorch — verify the pinned torch version tolerates this.
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer — one optimizer/scheduler per branch
    optimizerD = [
        optim.Adam(netD[i].parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
        for i in range(2)
    ]
    optimizerG = [
        optim.Adam(netG[i].parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
        for i in range(2)
    ]
    schedulerD = [
        torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD[i],
                                             milestones=[1600],
                                             gamma=opt.gamma) for i in range(2)
    ]
    schedulerG = [
        torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG[i],
                                             milestones=[1600],
                                             gamma=opt.gamma) for i in range(2)
    ]

    # loss-history lists (currently unused — the appends below are commented out)
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            # first scale: single-channel noise expanded to 3 shared channels
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD[0].zero_grad()
            netD[1].zero_grad()
            output1 = netD[0](real[1]).to(
                opt.device)  # first discriminator trains vs 2nd image
            output2 = netD[1](real[0]).to(opt.device)
            #D_real_map = output.detach()
            errD_real1 = -output1.mean()  #-a
            errD_real2 = -output2.mean()  #-a
            errD_real1.backward(retain_graph=True)
            errD_real2.backward(retain_graph=True)
            # D_x1 = -errD_real.item()
            # D_x2 = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    # first scale: prev / z_prev start as all-zero tensors
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                      0,
                                      device=opt.device)
                    in_s = prev  # in_s stays unpadded
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                        0,
                                        device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                else:
                    # NOTE(review): draw_concat receives the PAIR of image
                    # lists in `reals` — confirm it indexes them as intended.
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    # noise amplitude scales with the reconstruction RMSE
                    criterion = nn.MSELoss()
                    print(real[0].shape)
                    RMSE = torch.sqrt(criterion(real[0], z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # the money maker: second generator refines the first's output
            fake1 = netG[0](noise.detach(), prev)
            fake2 = netG[1](noise.detach(), fake1)
            output1 = netD[0](fake1.detach())
            output2 = netD[1](fake2.detach())
            errD_fake1 = output1.mean()
            errD_fake2 = output2.mean()
            errD_fake1.backward(retain_graph=True)
            errD_fake2.backward(retain_graph=True)
            # D_G_z = output.mean().item()  # for graphing only
            # this is wasserstiens
            # NOTE(review): netD / real / fakes are lists here — confirm
            # calc_gradient_penalty supports list arguments.
            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, [fake1, fake2], opt.lambda_grad, opt.device)
            gradient_penalty.backward()
            #errD = errD_real + errD_fake + gradient_penalty
            optimizerD[0].step()
            optimizerD[1].step()

        #errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG[0].zero_grad()
            netG[1].zero_grad()
            output1 = netD[0](fake1)
            output2 = netD[1](fake2)
            #D_fake_map = output.detach()
            errG1 = -output1.mean()
            errG1.backward(retain_graph=True)
            errG2 = -output2.mean()
            errG2.backward(retain_graph=True)
            if alpha != 0:
                # reconstruction through both generators chained
                loss = nn.MSELoss()
                Z_opt = opt.noise_amp * z_opt + z_prev
                im1 = m_image(netG[0](Z_opt.detach(), z_prev))
                #print(im1.size(), Z_opt.detach().size(), real[0].size())
                rec_loss = alpha * loss(netG[1](im1, z_prev), real[0])
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()  # Cycle Consistency Loss
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG[0].step()
            optimizerG[1].step()

        # errG2plot.append(errG.detach()+rec_loss)
        # D_real2plot.append(D_x)
        # D_fake2plot.append(D_G_z)
        # z_opt2plot.append(rec_loss)

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            # plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            # plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            #plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            #plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            #plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            #plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            #plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            #plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD[0].step()
        schedulerG[0].step()
        schedulerD[1].step()
        schedulerG[1].step()
    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG