# Imports assumed by the variants below. Repo-local helpers (`functions`,
# `draw_concat`, model factories) come from the surrounding SinGAN-style
# codebase; variant-specific imports (nvidia_smi, ssim, tqdm) appear next to
# the variants that use them.
import copy
import datetime
import pickle
import random
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim


def _train_discriminator_with_fake(netD, fake, opt, real):
    # Score the fake (detached, so no gradient reaches G), backprop the WGAN
    # discriminator loss on it, then add the gradient penalty.
    output = netD(fake.detach())
    errD_fake = output.mean()
    errD_fake.backward(retain_graph=True)
    D_G_z = output.mean().item()
    gradient_penalty = functions.calc_gradient_penalty(netD, real, fake,
                                                       opt.lambda_grad, opt.device)
    gradient_penalty.backward()
    return D_G_z, errD_fake, gradient_penalty
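# A minimal sketch of the WGAN-GP penalty that `functions.calc_gradient_penalty`
# is assumed to compute (its body is not shown in this section): score a random
# interpolation between real and fake and push the norm of the discriminator's
# gradient at that point toward 1. `_gradient_penalty_sketch` is a hypothetical
# name, not the repo's helper.
def _gradient_penalty_sketch(netD, real, fake, lambda_grad, device):
    alpha = torch.rand(1, 1, device=device).expand(real.size())  # one blend factor, broadcast
    interpolates = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
    disc_interpolates = netD(interpolates)
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                    grad_outputs=torch.ones_like(disc_interpolates),
                                    create_graph=True, retain_graph=True)[0]
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_grad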
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # generate_noise(size, num_samp=1, device='cuda', type='gaussian', scale=1)
    # size: [opt.nc_z, opt.nzx, opt.nzy]
    # z_opt: input noise for calculating the reconstruction loss in the generator
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    # TODO: these plots are not visualized
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        start_time = time.time()
        if (Gs == []) & (opt.mode != 'SR_train'):
            # Bottom generator, here without zero init
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            # noise_: input noise for the discriminator
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            # D_real_map = output.detach()
            errD_real = -output.mean()  # -a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                # Initialize prev and z_prev
                # prev: image output from the previous scale
                # z_prev: image output from the previous scale for the fixed noise z_opt
                if (Gs == []) & (opt.mode != 'SR_train'):
                    # in_s, prev and z_prev all start as zero tensors
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            # D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch > 1 and (epoch % 100 == 0 or epoch == (opt.niter - 1)):
            total_time = time.time() - start_time
            start_time = time.time()
            print('scale %d:[%d/%d], total time: %f' % (len(Gs), epoch, opt.niter, total_time))
            memory = torch.cuda.max_memory_allocated()
            # print('allocated memory: %dG %dM %dk %d' %
            #       (memory // (1024*1024*1024),
            #        (memory // (1024*1024)) % 1024,
            #        (memory // 1024) % 1024,
            #        memory % 1024))
            print('allocated memory: %.03f GB' % (memory / (1024 * 1024 * 1024 * 1.0)))

        # if epoch % 500 == 0 or epoch == (opt.niter-1):
        if epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            # plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            # plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            # plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            plt.imsave('%s/prev_plus_noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
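# A minimal sketch of how the values returned above are consumed by the outer
# multi-scale loop, assuming the usual SinGAN training scheme; `init_models`
# is a hypothetical per-scale model factory, not a function shown in this
# section.
def _train_pyramid_sketch(opt, reals):
    Gs, Zs, NoiseAmp = [], [], []
    in_s = 0  # placeholder; train_single_scale replaces it at the first scale
    for scale_num in range(opt.stop_scale + 1):
        netD, netG = init_models(opt)  # hypothetical factory
        z_curr, in_s, netG = train_single_scale(netD, netG, reals, Gs, Zs,
                                                in_s, NoiseAmp, opt)
        Gs.append(netG)                 # trained generator for this scale
        Zs.append(z_curr)               # fixed reconstruction noise for this scale
        NoiseAmp.append(opt.noise_amp)  # learned noise amplitude
    return Gs, Zs, NoiseAmp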
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            # D_real_map = output.detach()
            errD_real = -output.mean()  # -a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            # D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            # timestamp, full_memory and full_time are assumed to be module-level
            # lists, with timestamp seeded before training so timestamp[-2] exists
            # on the first log.
            stamp = datetime.datetime.now()
            timestamp.append(stamp)
            delta = (timestamp[-1] - timestamp[-2]).seconds
            mbs, percent = memory_check()
            print('scale %d:[%d/%d] | Mb: %.3f | Percent %.3f | secs %.3f ' %
                  (len(Gs), epoch, opt.niter, mbs, 100 * percent, delta))
            full_memory.append([mbs, percent])
            full_time.append(delta)

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            # plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            # plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            # plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            # plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            # plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            # plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)

    nvidia_smi.nvmlInit()
    # card id 0 is hardcoded here; there is also a call to get all available card ids, so we could iterate
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
    mem_res = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    mbs = mem_res.used / (1024 ** 2)
    percent = mem_res.used / mem_res.total
    print(f'mem: {mem_res.used / (1024 ** 2)} (MiB)')  # usage in MiB
    print(f'mem: {100 * (mem_res.used / mem_res.total):.6f}%')  # percentage
    return z_opt, in_s, netG, mbs, percent
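# A minimal sketch of the `memory_check` helper called above (not shown in this
# section), assuming it mirrors the nvidia_smi block at the end of this variant:
# report used MiB and the used fraction of GPU 0.
import nvidia_smi  # pip install nvidia-ml-py3

def memory_check():
    nvidia_smi.nvmlInit()
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
    mem_res = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    return mem_res.used / (1024 ** 2), mem_res.used / mem_res.total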
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        # note: this variant steps the schedulers at the top of each epoch,
        # before the optimizer steps (the pre-1.1 PyTorch ordering)
        schedulerD.step()
        schedulerG.step()

        if (Gs == []) & (opt.mode != 'SR_train'):
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy])
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            # D_real_map = output.detach()
            errD_real = -output.mean()  # -a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            # D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            # plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            # plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            # plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            # plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            # plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            # plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
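# A minimal sketch of what `functions.generate_noise` is assumed to do in the
# default gaussian case used by these variants: draw a noise map of shape
# [num_samp, size[0], size[1], size[2]]. The real helper also supports other
# noise types and a scale argument; this sketch covers only the path exercised
# above.
def _generate_noise_sketch(size, num_samp=1, device='cuda', type='gaussian'):
    if type == 'gaussian':
        return torch.randn(num_samp, size[0], size[1], size[2], device=device)
    raise NotImplementedError(type)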
def train_single_scale(netD, netG, reals, crops, masks, eye_color, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    real_fullsize = reals[len(Gs)]
    crop_size = crops[len(Gs)].size()[2]
    fixed_crop = real_fullsize[:, :, 0:crop_size, 0:crop_size]
    if opt.random_crop:
        real, h_idx, w_idx = functions.random_crop(real_fullsize.clone(), crop_size)
    else:
        real = real_fullsize.clone()
    mask = masks[len(Gs)]

    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer); height (NCHW)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer); width
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    eye = functions.generate_eye_mask(opt, masks[-1], opt.stop_scale - len(Gs))

    for epoch in range(opt.niter):
        if opt.resize:
            max_patch_size = int(min(real.size()[2], real.size()[3], mask.size()[2] * 1.25))
            min_patch_size = int(max(mask.size()[2] * 0.75, 1))
            patch_size = random.randint(min_patch_size, max_patch_size)
            mask_in = nn.functional.interpolate(mask.clone(), size=patch_size)
            eye_in = nn.functional.interpolate(eye.clone(), size=patch_size)
        else:
            mask_in = mask.clone()
            eye_in = eye.clone()

        eye_colored = eye_in.clone()
        if opt.random_eye_color:
            eye_color = functions.get_eye_color(real)
        eye_colored[:, 0, :, :] *= (eye_color[0] / 255)
        eye_colored[:, 1, :, :] *= (eye_color[1] / 255)
        eye_colored[:, 2, :, :] *= (eye_color[2] / 255)

        if (Gs == []) & (opt.mode != 'SR_train'):
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            # D_real_map = output.detach()
            errD_real = -output.mean()  # -a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = functions.draw_concat(Gs, Zs, reals, crops, masks, eye_colored, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = functions.draw_concat(Gs, Zs, reals, crops, masks, eye_colored, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    # print(z_prev.get_device())
                    # print(real.get_device())
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = functions.draw_concat(Gs, Zs, reals, crops, masks, eye_colored, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # Stack the masks and the noise to build the generator input
            G_input = functions.make_input(noise, mask_in, eye_colored)
            fake_background = netG(G_input.detach(), prev)
            netG_copy = copy.deepcopy(netG)  # snapshot of G; returned at the final scale

            # Crop the mask shape from the generated image and paste it onto the real image at a random location
            fake, fake_ind, eye_ind = functions.gen_fake(real, fake_background, mask_in, eye_in, eye_color, opt)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            # D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                input_opt = functions.make_input(Z_opt, mask_in, eye_in)
                rec_loss = alpha * loss(netG(input_opt.detach(), z_prev), real)
                # rec_loss = alpha*loss(netG(input_opt.detach(),z_prev),fixed_crop)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()))
            plt.imsave('%s/fake_indicator.png' % (opt.outf), functions.convert_image_np(fake_ind.detach()))
            plt.imsave('%s/eye_indicator.png' % (opt.outf), functions.convert_image_np(eye_ind.detach()))
            plt.imsave('%s/background.png' % (opt.outf), functions.convert_image_np(fake_background.detach()))
            # plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(netG(input_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            # plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            # plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            # plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            # plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            # plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            # plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

        if opt.random_crop:
            real, h_idx, w_idx = functions.random_crop(real_fullsize, crop_size)  # randomly pick a new crop
        if opt.random_eye:
            eye = functions.generate_eye_mask(opt, masks[-1], opt.stop_scale - len(Gs))

    functions.save_networks(netG, netD, z_opt, opt)
    if len(Gs) == (opt.stop_scale):
        netG = netG_copy
    return z_opt, in_s, netG
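# A minimal sketch of the `functions.random_crop` helper this variant relies
# on (assumed behavior): slice a random square patch from an NCHW image and
# return it together with its top-left corner.
def _random_crop_sketch(image, crop_size):
    _, _, h, w = image.shape
    h_idx = random.randint(0, h - crop_size)
    w_idx = random.randint(0, w - crop_size)
    crop = image[:, :, h_idx:h_idx + crop_size, w_idx:w_idx + crop_size]
    return crop, h_idx, w_idx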
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    # print("Gs:", Gs)
    # Gs: list of generators, one per scale
    print("len(Gs):", len(Gs))
    # real image at the current scale
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    # receptive field
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    print(opt.receptive_field)  # out: 3 + 2*4*1 = 11
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    print(pad_noise)  # pad_noise: 5
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    # ZeroPad2d zero-pads the input tensor on all four sides
    m_noise = nn.ZeroPad2d(int(pad_noise))
    print('m_noise', m_noise)  # ZeroPad2d(padding=(5, 5, 5, 5), value=0.0)
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha
    print(alpha)

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
    # print('fixed_noise', fixed_noise.shape)  # torch.Size([1, 3, 76, 76])
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # set up optimizers
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    # lists for plotting losses
    list_ssim = []
    list_ssim_1 = []
    list_ssim_2 = []
    list_ssim_3 = []
    list_ssim_4 = []
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    # training loop
    for epoch in range(opt.niter):
        schedulerD.step()
        schedulerG.step()

        # `&` acts as a logical AND here
        if (Gs == []) & (opt.mode != 'SR_train'):
            # opt.nzx and opt.nzy are the spatial size at the current scale
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy])
            # expand to (1, 3, opt.nzx, opt.nzy)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        # Dsteps = 3
        for j in range(opt.Dsteps):
            # train with the real image
            netD.zero_grad()
            output = netD(real).to(opt.device)
            # print(netD)
            # print('output', output)  # 4-D tensor
            errD_real = -output.mean()  # -a
            # print('errD_real', errD_real)  # e.g. -2.
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with the fake image
            # z_prev (the fixed-noise reconstruction input) is only initialized on the first iteration
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    # nc_z: 3 channels
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    # print('z_prev', z_prev)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    # MSE (mean squared error) loss
                    criterion = nn.MSELoss()
                    # root mean squared error
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    # RMSE
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # .detach() cuts the backward graph
            fake = netG(noise.detach(), prev)
            # added SSIM loss (note: data_range=255 assumes inputs in an 8-bit range)
            ssim_loss = ssim(real, fake, data_range=255, size_average=True)
            output = netD(fake.detach())
            # backprop the discriminator loss on the fake
            errD_fake = output.mean()
            # print('errD_fake', errD_fake)  # e.g. -0.0072
            errD_fake.backward(retain_graph=True)
            # D(G(z))
            D_G_z = output.mean().item()
            # print('D_G_z', D_G_z)

            # gradient penalty ---------------------------------------------------------------------------
            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad)
            # print('gradient_penalty', gradient_penalty)
            gradient_penalty.backward()

            # discriminator loss: adversarial terms + gradient penalty
            D_ssim_3 = errD_real + errD_fake + gradient_penalty
            errD = errD_real + errD_fake + gradient_penalty
            # print('item', errD.item())
            # SSIM-weighted loss variants that were tried:
            # D_ssim = 0.8 * (errD_real + errD_fake) + 0.68 * gradient_penalty + (1 - ssim_loss)
            # D_ssim_1 = 0.7 * (errD_real + errD_fake) + 0.6 * gradient_penalty + 1.2 * (1 - ssim_loss)
            # D_ssim_2 = 0.75 * (errD_real + errD_fake) + 0.6 * gradient_penalty + 1.2 * (1 - ssim_loss)
            # errD = (errD_real + errD_fake) + 0.5 * gradient_penalty + 1.4 * (1 - ssim_loss)
            # D_ssim_4 = 0.6 * (errD_real + errD_fake) + 0.5 * gradient_penalty + 1.4 * (1 - ssim_loss)
            # int_ssim = round(D_ssim.item(), 4)
            # int_ssim_1 = round(D_ssim_1.item(), 4)
            # int_ssim_2 = round(D_ssim_2.item(), 4)
            int_ssim_3 = D_ssim_3.item()
            int_ssim_3 = round(int_ssim_3, 4)

            optimizerD.step()

        errDint = []
        errD2plot.append(errD.detach())
        # print('errD2plot', errD2plot)
        for i in range(len(errD2plot)):
            errDint.append(errD2plot[i].cpu().numpy())
        # list_ssim.append(int_ssim)
        # list_ssim_1.append(int_ssim_1)
        # list_ssim_2.append(int_ssim_2)
        # list_ssim_3.append(int_ssim_3)
        # list_ssim_4.append(int_ssim_4)
        # print('list_ssim', list_ssim)

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            # D_fake_map = output.detach()
            # generator adversarial loss
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errGint = []
        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)
        for i in range(len(errG2plot)):
            errGint.append(errG2plot[i].cpu().numpy())

        if epoch % 100 == 0 or epoch == (opt.niter - 1):
            # len(Gs): scale index; niter = 2000
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            # plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            # plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            # plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            # plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            # plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

    # training a single scale; plot learning curves for selected scales
    if len(Gs) == 0:
        name = 'G_loos & G_loos'
        functions.plot_learning_curves(errGint, errDint, opt.niter, 'Generator', 'Discriminator', name)
    if len(Gs) == 1:
        name = 'G_loos & G_loos_1'
        functions.plot_learning_curves_1(errGint, errDint, opt.niter, 'Generator', 'Discriminator', name)
    # if len(Gs) == 2:
    #     name = 'G_loos & G_loos_2'
    #     functions.plot_learning_curves(errGint, errDint, list_ssim_3, list_ssim_4, opt.niter,
    #                                    'labelG', 'labelD', 'ssim_3', 'ssim_4', name)
    # branches for len(Gs) == 3..7 follow the same pattern, plotting
    # list_ssim .. list_ssim_4 alongside the generator/discriminator curves
    if len(Gs) == 8:
        name = 'G_loos & G_loos_8'
        functions.plot_learning_curves_8(errGint, errDint, opt.niter, 'Generator', 'Discriminator', name)

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
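# A self-contained sketch of the `ssim` call used above, assuming it comes from
# the pytorch-msssim package. That function expects inputs in [0, data_range];
# for tensors normalized to [-1, 1], shifting into [0, 1] with data_range=1.0
# is the safer call, whereas data_range=255 assumes 8-bit pixel values.
from pytorch_msssim import ssim  # pip install pytorch-msssim

_real = torch.rand(1, 3, 64, 64) * 2 - 1  # stand-in for a [-1, 1] image
_fake = torch.rand(1, 3, 64, 64) * 2 - 1
_ssim_loss = 1 - ssim((_real + 1) / 2, (_fake + 1) / 2,
                      data_range=1.0, size_average=True)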
def train_single_scale(netD, netG, reals, masks, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    real = reals[len(Gs)]
    mask = masks[len(Gs)]
    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # Here we calculate the decrease in output size due to the conv layers,
    # required for setting the size of the discriminator's map.
    # TODO: currently the calculation doesn't consider opt.stride
    if (opt.ker_size % 2 == 0):
        r = (opt.num_layer) * (opt.ker_size - 1)
    else:  # opt.ker_size % 2 != 0
        r = (opt.num_layer) * (opt.ker_size - 1) / 2
    r = int(r)
    _, _, h, w = mask.size()
    discriminators_mask = mask.detach()[:, :, r:h - r, r:w - r][:, 0, :, :].unsqueeze(0)
    _, _, h, w = discriminators_mask.size()

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    norm = []
    norm.append(1)
    norm.append((h * w) / discriminators_mask.sum().item())

    plt.imsave('%s/mask.png' % (opt.outf), functions.convert_image_np(real * mask))

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            if (epoch == 0):
                z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
                z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            output = output * discriminators_mask
            D_real_map = output.detach()
            errD_real = -(output.mean()) * norm[opt.norm]  # -a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, opt.device,
                                                               discriminators_mask * norm[opt.norm])
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                netG_out = netG(Z_opt.detach(), z_prev)
                netG_out = netG_out * mask
                real = real * mask
                rec_loss = alpha * loss(netG_out, real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errD2plot.append(errD.detach())
        errG2plot.append(errG.detach())
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            # plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            # plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            # plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            # plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))
            with open('%s/D_real2plot' % (opt.outf), "wb") as fp:  # pickling
                pickle.dump(D_real2plot, fp)
            with open('%s/D_fake2plot' % (opt.outf), "wb") as fp:  # pickling
                pickle.dump(D_fake2plot, fp)
            # with open('%s/D_real2plot' % (opt.outf), "rb") as fp:  # unpickling
            #     D_fake2plot = pickle.load(fp)
            with open('%s/errD2plot' % (opt.outf), "wb") as fp:  # pickling
                pickle.dump(errD2plot, fp)
            with open('%s/errG2plot' % (opt.outf), "wb") as fp:  # pickling
                pickle.dump(errG2plot, fp)
            with open('%s/z_opt2plot' % (opt.outf), "wb") as fp:  # pickling
                pickle.dump(z_opt2plot, fp)

        schedulerD.step()
        schedulerG.step()

    # plt.imsave('%s/masked_img.png' % (opt.outf), functions.convert_image_np(real*mask))
    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
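# A minimal sketch of reading the curves pickled above back for inspection.
# The default path is a hypothetical stand-in for opt.outf; torch must be
# importable to unpickle the stored tensors (and CUDA-resident tensors need a
# CUDA-capable machine to load).
def _plot_pickled_curve(outf='path/to/opt.outf'):
    with open('%s/errD2plot' % outf, 'rb') as fp:
        errD2plot = pickle.load(fp)
    plt.plot([float(e) for e in errD2plot])  # entries are 0-dim tensors or floats
    plt.savefig('%s/errD_curve.png' % outf)
    plt.close()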
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """This is the core to understand it all."""
    real = reals[len(Gs)]  # real image at current scale
    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # The supplementary material says noise is generated on the border in animation mode
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])  # select a fixed noise at start
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)  # but it is not used; an all-zero tensor replaces it
    z_opt = m_noise(z_opt)  # zero-padded (still all zeros)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []      # collect the error in D training
    errG2plot = []      # collect the error in G training
    D_real2plot = []    # collect D_x, the D score on the real image
    D_fake2plot = []    # discriminator score on the fake image
    z_opt2plot = []     # collect the reconstruction error

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):  # if it's the first scale
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))  # redrawn at the start of each epoch
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)  # single-channel noise
            # expand broadcasts the single-channel noise along the channel axis,
            # so all three RGB channels share the same values
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            # note: a single noise_ draw is reused for all D steps in this epoch
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):  # a few D steps first
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)  # the mean of netD's output map is the score
            # D_real_map = output.detach()
            errD_real = -output.mean()  # -a  # maximize the D output for patches of the real image
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()  # D score on the real image

            # train with fake; generate it through the previous generators
            if (j == 0) & (epoch == 0):  # first D step at this scale (epoch 0)
                if (Gs == []) & (opt.mode != 'SR_train'):  # initial scale
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev  # in_s doesn't get padded!
                    prev = m_image(prev)  # prev gets padded!
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))  # RMSE between real and z_prev
                    opt.noise_amp = opt.noise_amp_init * RMSE  # learn the noise amplitude
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)  # run in_s through the pyramid
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):  # first scale
                noise = noise_  # pure noise input
            else:  # later scales
                noise = opt.noise_amp * noise_ + prev  # scaled noise plus the previous-scale image

            # generate a single fake image through G and pass it to D
            fake = netG(noise.detach(), prev)  # netG takes two inputs: it processes noise and adds prev
            output = netD(fake.detach())  # netD scores the fake; detached, so no gradient reaches G
            errD_fake = output.mean()  # decrease the D output on fakes
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach().item())  # combined D loss on real, fake, plus the gradient penalty

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):  # then a few G steps
            netG.zero_grad()
            output = netD(fake)  # note: the same fake image is scored on every G step
            # D_fake_map = output.detach()
            errG = -output.mean()  # adversarial part: minimize -D(G(z)) so G fools D
            # retain_graph: the graph of `fake` is backpropagated through again on later G steps
            errG.backward(retain_graph=True)
            if alpha != 0:  # compute the reconstruction loss if alpha is non-zero
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)  # z_prev is inherited from the D training block
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)  # retained for the same reason
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        # float() handles both the detached-tensor and the int-0 (alpha == 0) case
        errG2plot.append(errG.detach().item() + float(rec_loss))
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(float(rec_loss))

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            # noise is the generator input, i.e. prev + amp * noise_
            plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            plt.imsave('%s/Z_opt.png' % (opt.outf), functions.convert_image_np(Z_opt), vmin=0, vmax=1)
            plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            # plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            # plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            # plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            # plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            torch.save({"errD": errD2plot,
                        "errG": errG2plot,
                        "D_real": D_real2plot,
                        "D_fake": D_fake2plot,
                        "recons": z_opt2plot}, '%s/loss_trace.pth' % (opt.outf))
            plt.figure()
            plt.plot(errD2plot, label="errD")
            plt.plot(errG2plot, label="errG")
            plt.plot(D_real2plot, label="Dreal")
            plt.plot(D_fake2plot, label="Dfake")
            plt.plot(z_opt2plot, label="recons")
            plt.legend()
            plt.savefig('%s/loss_trace.png' % (opt.outf))
            plt.close()
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
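# A minimal demonstration of why retain_graph=True is needed above: PyTorch
# frees a computation graph after the first backward() through it, and both
# inner loops above backpropagate through the same `fake` / reconstruction
# graph on every step.
_x = torch.ones(2, requires_grad=True)
_y = (3 * _x).sum()
_y.backward(retain_graph=True)  # keep the graph alive for a second pass
_y.backward()                   # would raise RuntimeError without retain_graph above
print(_x.grad)                  # gradients accumulate: tensor([6., 6.])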
def train_single_scale(
        netD,          # current discriminator
        netG,          # current generator
        reals,         # the list of all resized data
        Gs,            # list of trained generators from coarser scales
        Zs,            # list of fixed reconstruction noises from coarser scales
        in_s,          # base input image (all-zero at the coarsest scale)
        NoiseAmp,      # list of noise amplitudes from coarser scales
        opt,           # parameters
        centers=None):
    real = reals[len(Gs)]  # get the resized real picture for the current scale

    # get the x and y sizes
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]

    # receptive field
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride

    # padding widths
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)

    # build torch.nn modules that add zero padding (the TensorFlow equivalent is more involved)
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    # get alpha from opt
    alpha = opt.alpha

    # generate a noise tensor of the following size
    fixed_noise = functions.generate_noise(
        [
            opt.nc_z,  # number of noise channels
            opt.nzx,
            opt.nzy
        ],
        device=opt.device)
    # tensor of shape fixed_noise.shape filled with 0. (float fill keeps the dtype float)
    z_opt = torch.full(fixed_noise.shape, 0., device=opt.device)
    z_opt = m_noise(z_opt)  # zero pad with width int(pad_noise)

    # set up optimizers and learning-rate schedulers
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    # lists for plotting
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    # main training-iteration loop (tqdm_notebook requires: from tqdm import tqdm_notebook)
    for epoch in tqdm_notebook(range(opt.niter), desc=f"scale {len(Gs)}", leave=False):
        # at the first scale, G needs an additional input
        if (Gs == []):
            # generate a noise of size [1, opt.nzx, opt.nzy]
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            # expand to 3 channels (all channels share the same noise), then zero pad with width int(pad_noise)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            # generate another noise
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            # likewise expand to 3 identical channels and zero pad
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        # when it's not the first scale
        else:
            # nc_z is the number of noise channels
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            # zero pad in all spatial dimensions
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        # discriminator inner-step loop
        for j in range(opt.Dsteps):
            # train with real
            # reset gradients before training (torch operation)
            netD.zero_grad()
            # score the real image
            output = netD(real).to(opt.device)
            # D loss on real: minimizing -D(x) pushes D(x) up
            errD_real = -output.mean()  # -a
            # compute all the gradients
            errD_real.backward(retain_graph=True)
            # record D(x) as a Python float
            D_x = -errD_real.item()

            # train with fake
            # for the first inner step of the first epoch
            if (j == 0) & (epoch == 0):
                # if it's the first scale
                if (Gs == []):
                    # set prev to all 0
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0., device=opt.device)
                    in_s = prev
                    # zero padding with width int(pad_image)
                    prev = m_image(prev)
                    # set z_prev to all 0
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0., device=opt.device)
                    # zero padding with width int(pad_noise)
                    z_prev = m_noise(z_prev)
                    # set the noise amplitude to 1
                    opt.noise_amp = 1
                else:
                    # generate prev in 'rand' mode
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    # zero padding with width int(pad_image)
                    prev = m_image(prev)
                    # generate z_prev in 'rec' mode
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    # set the noise amplitude to opt.noise_amp_init * RMSE(real, z_prev)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    # add a padding of width int(pad_image)
                    z_prev = m_image(z_prev)
            else:
                # generate prev in 'rand' mode
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                # add a padding of width int(pad_image)
                prev = m_image(prev)

            # if it's the first scale
            if (Gs == []):
                # noise_ is already expanded to 3 identical channels and padded
                noise = noise_
            else:
                # amplify the padded noise and add prev
                noise = opt.noise_amp * noise_ + prev

            # generate the fake image from the noise.
            # detach() detaches a tensor from the computational graph,
            # so no gradient will be backpropagated along this variable;
            # in the very first loop G is still untrained (raw)
            fake = netG(noise.detach(), prev)
            # score the fake
            output = netD(fake.detach())
            # D loss on fake: minimizing D(G(z)) pushes it down
            errD_fake = output.mean()
            # compute all the gradients
            errD_fake.backward(retain_graph=True)
            # record D(G(z)) as a Python float
            D_G_z = output.mean().item()
            # calculate the gradient penalty
            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            # backpropagate the penalty
            gradient_penalty.backward()
            # total discriminator loss
            errD = errD_real + errD_fake + gradient_penalty
            # update the parameters
            optimizerD.step()
        # record the discriminator loss
        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            # reset gradients to 0
            netG.zero_grad()
            # score the fake with the discriminator
            output = netD(fake)
            # the loss of G is the negative of D's score (adversarial competition)
            errG = -output.mean()
            # backpropagate
            errG.backward(retain_graph=True)
            if alpha != 0:
                # define the MSE loss
                loss = nn.MSELoss()
                # scale the fixed noise and add the reconstruction input
                Z_opt = opt.noise_amp * z_opt + z_prev
                # reconstruction loss: MSE between G(Z_opt.detach(), z_prev) and real, scaled by alpha
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                # backpropagate
                rec_loss.backward(retain_graph=True)
                # detach to a plain tensor for logging
                rec_loss = rec_loss.detach()
            else:  # alpha = 0
                # keep Z_opt as the raw z_opt
                Z_opt = z_opt
                # set rec_loss to 0
                rec_loss = 0
            # update the parameters
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        # if epoch % 25 == 0 or epoch == (opt.niter - 1):  # replaced by tqdm
        #     print(f'scale {len(Gs)}:[{epoch}/{opt.niter}]')
        # if epoch % 500 == 0 or epoch == (opt.niter - 1):
        if epoch == (opt.niter - 1):  # save only once, at the final epoch
            # save the fake sample
            plt.imsave(f'{opt.outf}/fake_sample.png',
                       functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            # save the reconstruction G(z_opt)
            plt.imsave(f'{opt.outf}/G(z_opt).png',
                       functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0, vmax=1)
            # save the reconstruction noise
            torch.save(z_opt, f'{opt.outf}/z_opt.pth')
            # plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            # plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            # plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            # plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            # plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            # plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)

        # update the learning rates
        schedulerD.step()
        schedulerG.step()

    # save the models
    functions.save_networks(netG, netD, z_opt, opt)
    # return the reconstruction noise z_opt, the base input in_s
    # (the all-zero image created at the coarsest scale), and the trained generator
    return z_opt, in_s, netG
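# Both variants of train_single_scale lean on functions.calc_gradient_penalty
# for the WGAN-GP term. For reference, here is a minimal sketch of the standard
# penalty from Gulrajani et al. (2017); this is an assumption for illustration,
# not the repo's actual helper, which may differ in details (e.g. how it
# handles the paired-network variant below).
import torch


def calc_gradient_penalty_sketch(netD, real, fake, lambda_grad, device):
    # one random interpolation weight, broadcast over the whole image
    alpha = torch.rand(1, 1, device=device)
    alpha = alpha.expand(real.size())
    # interpolate between the real and fake samples; the penalty should only
    # update D, so both endpoints are detached
    interpolates = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    disc_interpolates = netD(interpolates)
    # gradient of D's score with respect to the interpolated input
    gradients = torch.autograd.grad(outputs=disc_interpolates,
                                    inputs=interpolates,
                                    grad_outputs=torch.ones_like(disc_interpolates),
                                    create_graph=True,
                                    retain_graph=True)[0]
    # penalize any deviation of the gradient norm from 1
    gradients = gradients.view(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_grad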
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    # two paired images at the current scale (netD, netG, and reals are all pairs here)
    real = [reals[0][len(Gs)], reals[1][len(Gs)]]
    print(len(real))       # debug
    print(real[0].shape)   # debug
    opt.nzx = real[0].shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real[0].shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))
    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0., device=opt.device)
    z_opt = m_noise(z_opt)

    # setup one optimizer and scheduler per network in each pair
    optimizerD = [
        optim.Adam(netD[i].parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
        for i in range(2)
    ]
    optimizerG = [
        optim.Adam(netG[i].parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
        for i in range(2)
    ]
    schedulerD = [
        torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD[i], milestones=[1600], gamma=opt.gamma)
        for i in range(2)
    ]
    schedulerG = [
        torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG[i], milestones=[1600], gamma=opt.gamma)
        for i in range(2)
    ]

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD[0].zero_grad()
            netD[1].zero_grad()
            output1 = netD[0](real[1]).to(opt.device)  # the first discriminator trains against the 2nd image
            output2 = netD[1](real[0]).to(opt.device)  # and vice versa
            # D_real_map = output.detach()
            errD_real1 = -output1.mean()  # -a
            errD_real2 = -output2.mean()  # -a
            errD_real1.backward(retain_graph=True)
            errD_real2.backward(retain_graph=True)
            # D_x1 = -errD_real.item()
            # D_x2 = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0., device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0., device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    print(real[0].shape)  # debug
                    RMSE = torch.sqrt(criterion(real[0], z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # the money maker: the first generator produces fake1, the second refines it into fake2
            fake1 = netG[0](noise.detach(), prev)
            fake2 = netG[1](noise.detach(), fake1)
            output1 = netD[0](fake1.detach())
            output2 = netD[1](fake2.detach())
            errD_fake1 = output1.mean()
            errD_fake2 = output2.mean()
            errD_fake1.backward(retain_graph=True)
            errD_fake2.backward(retain_graph=True)
            # D_G_z = output.mean().item()  # for graphing only

            # this is the Wasserstein gradient penalty
            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, [fake1, fake2], opt.lambda_grad, opt.device)
            gradient_penalty.backward()
            # errD = errD_real + errD_fake + gradient_penalty
            optimizerD[0].step()
            optimizerD[1].step()
        # errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG[0].zero_grad()
            netG[1].zero_grad()
            output1 = netD[0](fake1)
            output2 = netD[1](fake2)
            # D_fake_map = output.detach()
            errG1 = -output1.mean()
            errG1.backward(retain_graph=True)
            errG2 = -output2.mean()
            errG2.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                Z_opt = opt.noise_amp * z_opt + z_prev
                im1 = m_image(netG[0](Z_opt.detach(), z_prev))
                # print(im1.size(), Z_opt.detach().size(), real[0].size())
                rec_loss = alpha * loss(netG[1](im1, z_prev), real[0])
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()  # Cycle Consistency Loss
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG[0].step()
            optimizerG[1].step()

        # errG2plot.append(errG.detach() + rec_loss)
        # D_real2plot.append(D_x)
        # D_fake2plot.append(D_G_z)
        # z_opt2plot.append(rec_loss)

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            # plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            # plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            # plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            # plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            # plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            # plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            # plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            # plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD[0].step()
        schedulerG[0].step()
        schedulerD[1].step()
        schedulerG[1].step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
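# For context, a hedged sketch of how a train_single_scale variant is usually
# driven across the scale pyramid. The loop shape and the init_models helper
# are assumptions for illustration; the repo's actual outer train() wrapper
# may differ.
def train_pyramid_sketch(reals, opt, init_models):
    Gs, Zs, NoiseAmp = [], [], []
    in_s = 0
    for scale in range(len(reals)):
        # fresh discriminator/generator pair for this scale (init_models is a
        # hypothetical helper)
        netD, netG = init_models(opt)
        z_opt, in_s, netG = train_single_scale(netD, netG, reals, Gs, Zs,
                                               in_s, NoiseAmp, opt)
        # freeze the finished generator so coarser outputs stay fixed
        netG = netG.eval()
        for p in netG.parameters():
            p.requires_grad_(False)
        Gs.append(netG)
        Zs.append(z_opt)
        NoiseAmp.append(opt.noise_amp)
    return Gs, Zs, NoiseAmp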