def train():
    # Inspired by https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
    # Lists to keep track of progress
    num_epochs = args.max_epoch
    #img_list = []
    iters = 0
    print("Starting Training Loop...")
    sys.stdout.flush()
    for epoch in range(num_epochs):
        start = time.time()
        G_losses = []
        D_losses = []
        for i, data in enumerate(train_loader, 0):
            global_i = len(train_loader) * epoch + i

            ############################
            # D network
            ############################
            ## Train with all-real batch
            #with torch.autograd.detect_anomaly():
            optimizerD.zero_grad()
            real_image, mask, masked_image, loss_mask = (data[0].to(device), data[1].to(device),
                                                         data[2].to(device), data[3].to(device))
            # Instance noise, decayed each epoch, to stabilize the discriminator
            jitter_real = torch.empty_like(real_image, device=device).uniform_(
                -0.05 * (0.99 ** epoch), 0.05 * (0.99 ** epoch))
            jitter_fake = torch.empty_like(real_image, device=device).uniform_(
                -0.05 * (0.99 ** epoch), 0.05 * (0.99 ** epoch))
            #real_preds, real_feats = netD(real_image, mask)
            real_preds, real_feats = netD(torch.clamp(real_image + jitter_real, -1, 1), mask)

            ## Train with all-fake batch
            # noise = torch.randn(b_size, nz, 1, 1, device=device)
            _, skips = netS1(masked_image)
            embed, _ = netS1(real_image, False)
            #embed, _ = netS2(real_image, False)
            style_code, mu, sigma = netM(embed)
            fake = netG(style_code, mask, skips)
            fake_preds, fake_feats = netD(torch.clamp(fake.detach() + jitter_fake, -1, 1), mask)
            #fake_preds, fake_feats = netD(fake.detach(), mask)
            errD = 0.0
            for fp, rp in zip(fake_preds, real_preds):
                errD += losses.hinge_loss_discriminator(fp, rp)
            errD.backward()
            optimizerD.step()

            # dump train metrics to tensorboard
            if writer is not None:
                writer.add_scalar("loss_D", errD.item(), global_i)

            ############################
            # G network
            ############################
            optimizerM.zero_grad()
            optimizerS1.zero_grad()
            #optimizerS2.zero_grad()
            optimizerG.zero_grad()
            #l1 = losses.masked_l1(fake, masked_image, loss_mask) * args.cycle_lambda
            #l1.backward(retain_graph=True)
            dkl = losses.KL_divergence(mu, sigma) * args.kl_lambda
            dkl.backward(retain_graph=True)
            # VGG activations for the perceptual loss
            real_vgg_f = vgg(real_image)
            fake_vgg_f = vgg(fake)
            errG_p = 0.0
            for ff, rf in zip(fake_vgg_f, real_vgg_f):
                errG_p += losses.perceptual_loss(ff, rf.detach(), args.fm_lambda)
            errG_p.backward(retain_graph=True)
            fake_preds, fake_feats = netD(fake, mask)
            errG_hinge = 0.0
            for fp in fake_preds:
                errG_hinge += losses.hinge_loss_generator(fp)
            errG_hinge.backward(retain_graph=True)
            errG_fm = 0.0
            for ff, rf in zip(fake_feats, real_feats):
                errG_fm += losses.perceptual_loss(ff, rf.detach(), args.fm_lambda)
            errG_fm.backward()
            errG = errG_hinge.item() + errG_fm.item() + errG_p.item()  #+ l1.item()
            if args.G_orth > 0.0:
                losses.ortho(netG, args.G_orth, blacklist=[])
            optimizerG.step()
            optimizerS1.step()
            #optimizerS2.step()
            optimizerM.step()

            if writer is not None:
                writer.add_scalar("loss_G", errG, global_i)

            # Output training stats
            if i % 500 == 499:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
                      % (epoch, num_epochs, i, len(train_loader), errD.item(), errG))
                sys.stdout.flush()
                with torch.no_grad():
                    netG.eval()
                    netS1.eval()
                    #netS2.eval()
                    netM.eval()
                    _, test_skips = netS1(fixed_test_images)
                    test_embed, _ = netS1(fixed_test_real_images, False)
                    #test_embed, _ = netS2(fixed_test_real_images, False)
                    style_code, _, _ = netM(test_embed)
                    test_generated = netG(style_code, fixed_test_masks, test_skips).detach().cpu()
                    netG.train()
                    netS1.train()
                    #netS2.train()
                    netM.train()
                # build an image grid for tensorboard (make_grid returns CHW)
                tim = vutils.make_grid(test_generated.data[:16], normalize=True)
                writer.add_image('generated', tim, global_i)
            G_losses.append(errG)
            D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on the fixed test batch
        end = time.time()
        hours, rem = divmod(end - start, 3600)
        minutes, seconds = divmod(rem, 60)
        if epoch % args.save_frequency == 0:
            with torch.no_grad():
                netG.eval()
                _, test_skips = netS1(fixed_test_images)
                test_embed, _ = netS1(fixed_test_real_images, False)
                #test_embed, _ = netS2(fixed_test_real_images, False)
                style_code, _, _ = netM(test_embed)
                test_generated = netG(style_code, fixed_test_masks, test_skips).detach().cpu()
                netG.train()
                netS1.train()
                #netS2.train()
                netM.train()
            # img_list.append(fake.data.numpy())
            print("Epoch {} - Elapsed time: {:0>2}:{:0>2}:{:05.2f}".format(
                epoch, int(hours), int(minutes), seconds))
            sys.stdout.flush()
            _ = vutils.save_image(test_generated.data[:16],
                                  '%s/%d.png' % (test_save_dir, epoch), normalize=True)
            # writer.add_image('generated', tim, epoch, dataformats='HWC')
            torch.save(netG.state_dict(), os.path.join(args.root_path, 'NetG' + best_model_path))
            torch.save(netD.state_dict(), os.path.join(args.root_path, 'NetD' + best_model_path))
            # plt.imsave(os.path.join(
            #     './{}/'.format(test_save_dir) + 'img{}.png'.format(datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))),
            #     ((img_list[-1][0] + 1) / 2.0).transpose([1, 2, 0]), cmap='gray', interpolation="none")
        if writer is not None:
            writer.add_scalar("loss_G_epoch", np.sum(G_losses) / len(train_loader), epoch)
            writer.add_scalar("loss_D_epoch", np.sum(D_losses) / len(train_loader), epoch)
        iters += 1
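
# The `losses` module used above is not shown in this file. A minimal sketch of
# what its helpers might look like, assuming standard SPADE/pix2pixHD-style
# objectives (hinge GAN loss, VAE KL term, L1 feature matching); names and
# signatures are inferred from the call sites, so the real module may differ.
import torch
import torch.nn.functional as F

def hinge_loss_discriminator(fake_pred, real_pred):
    # Hinge loss for one discriminator scale: push real logits above +1
    # and fake logits below -1.
    return F.relu(1.0 - real_pred).mean() + F.relu(1.0 + fake_pred).mean()

def hinge_loss_generator(fake_pred):
    # Generator side of the hinge loss: raise the fake logits.
    return -fake_pred.mean()

def KL_divergence(mu, logvar):
    # KL(N(mu, sigma^2) || N(0, 1)) for a diagonal Gaussian, assuming the
    # second value returned by netM is a log-variance.
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

def perceptual_loss(fake_feat, real_feat, weight):
    # Weighted L1 distance between one pair of intermediate activations
    # (used both for the VGG perceptual loss and discriminator feature matching).
    return weight * F.l1_loss(fake_feat, real_feat)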
transform = Compose([Resize(resize_height, resize_width),
                     HorizontalFlip(p=0.5),
                     ToTensorV2()])
train_dataset = DeepFashion2Dataset(os.path.join(args.data_root, 'train'),
                                    transform=transform,
                                    return_masked_image=True)
print('Loading Dataset...')
sys.stdout.flush()
train_loader = data_utils.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

test_batch = next(iter(train_loader))
fixed_test_images = test_batch[2].to(device)
fixed_test_masks = test_batch[1].to(device)
fixed_test_real_images = test_batch[0].to(device)
_ = vutils.save_image(fixed_test_real_images.cpu().data[:16], '!test.png', normalize=True)
_ = vutils.save_image(fixed_test_images.cpu().data[:16], '!test_noise.png', normalize=True)

## MODELS
netD = MultiscaleDiscriminator(args.mask_channels + 3).to(device)
netD.apply(weights_init)
netG = GauGANUnetStylizationGenerator(args.mask_channels, args.encoder_latent_dim, 2,
                                      args.unet_ch, device).to(device)
netG.apply(weights_init)
netS1 = StyleEncoder(args.encoder_latent_dim, args.unet_ch, 2).to(device)
netS1.apply(weights_init)
#netS2 = StyleEncoder(args.encoder_latent_dim, args.unet_ch, 2, need_skips=False).to(device)
#netS2.apply(weights_init)
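
# This fragment stops before the rest of the setup that train() relies on:
# the mapping network netM, the VGG feature extractor, the optimizers, and the
# summary writer. A plausible continuation, assuming a hypothetical
# MappingNetwork module and hypothetical args.lr_D / args.lr_G arguments, with
# the zero-beta1 Adam settings common for hinge-loss GANs:
netM = MappingNetwork(args.encoder_latent_dim).to(device)  # hypothetical module
netM.apply(weights_init)
vgg = Vgg19Full().to(device)
vgg.eval()
optimizerD = torch.optim.Adam(netD.parameters(), lr=args.lr_D, betas=(0.0, 0.999))   # args.lr_D assumed
optimizerG = torch.optim.Adam(netG.parameters(), lr=args.lr_G, betas=(0.0, 0.999))   # args.lr_G assumed
optimizerS1 = torch.optim.Adam(netS1.parameters(), lr=args.lr_G, betas=(0.0, 0.999))
optimizerM = torch.optim.Adam(netM.parameters(), lr=args.lr_G, betas=(0.0, 0.999))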
train_dataset = DeepFashion2Dataset(os.path.join(args.data_root, 'train'),
                                    transform=transform,
                                    return_masked_image=True)
#val_dataset = DeepFashion2Dataset(os.path.join(args.data_root, 'validation'),
#                                  transform=transform,
#                                  return_masked_image=True)
print('Loading Dataset...')
sys.stdout.flush()
train_loader = data_utils.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
#val_loader = data_utils.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True)

test_batch = next(iter(train_loader))
fixed_test_images = test_batch[2].to(device)
fixed_test_masks = test_batch[1].to(device)
_ = vutils.save_image(test_batch[0].data[:16], '!test.png', normalize=True)

## MODELS
netD = MultiscaleDiscriminator(args.mask_channels + 3).to(device)
netD.apply(weights_init)
netG = GauGANUnetGenerator(args.mask_channels, args.encoder_latent_dim, 2, args.unet_ch).to(device)
netG.apply(weights_init)
netE = UnetEncoder(args.encoder_latent_dim, args.unet_ch, 2).to(device)
netE.apply(weights_init)
vgg = Vgg19Full().to(device)
vgg.eval()
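
# `weights_init` is applied to every network above but is not defined in these
# fragments. Given the DCGAN tutorial referenced in train(), it is most likely
# the standard DCGAN initializer; a sketch under that assumption:
import torch.nn as nn

def weights_init(m):
    # N(0, 0.02) for conv weights, N(1, 0.02) for batch-norm scales,
    # zero for batch-norm biases, as in the PyTorch DCGAN tutorial.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)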
def RL_restore(blurry_tensor, initial_output, kernels, masks, n_iters, GPU,
               SAVE_INTERMIDIATE=True, method='basic', saturation_threshold=0.7,
               reg_factor=1e-3, gamma_correction_factor=1.0, isDebug=False):
    epsilon = 1e-6
    debug_folder = './RL_debug'
    if not os.path.exists(debug_folder):
        os.makedirs(debug_folder)

    blurry_tensor_ph = blurry_tensor ** gamma_correction_factor  # to photon space
    output = initial_output
    kernels_flipped = torch.flip(kernels, dims=(2, 3))
    for it in range(n_iters):
        output_ph = output ** gamma_correction_factor
        output_reblurred_ph = forward_reblur(output_ph, kernels, masks, GPU, size='same',
                                             padding_mode='reflect',
                                             manage_saturated_pixels=False, max_value=1)
        #output_reblurred = output_reblurred_ph ** (1.0 / gamma_correction_factor)  # to pixel space

        # Ratio between the observed blurry image and the current re-blurred estimate
        if method == 'basic':
            relative_blur = torch.div(blurry_tensor_ph, output_reblurred_ph + epsilon)
        elif method == 'function':
            # Saturation handled through a smooth clipping function and its derivative
            R = apply_saturation_function(output_reblurred_ph, max_value=1)
            R_prima = apply_saturation_function(output_reblurred_ph, max_value=1,
                                                get_derivative=True)
            relative_blur = torch.div(blurry_tensor_ph * R_prima, R + epsilon) + 1 - R_prima
        elif method == 'masked':
            # Saturated pixels are excluded from the update
            mask = blurry_tensor < saturation_threshold
            relative_blur = torch.div(blurry_tensor_ph * mask,
                                      output_reblurred_ph + epsilon) + 1 - mask.float()

        # Multiplicative RL update with the flipped (adjoint) kernels,
        # plus a total-variation-style regularization term
        error_estimate = forward_reblur(relative_blur, kernels_flipped, masks, GPU, size='same',
                                        padding_mode='reflect',
                                        manage_saturated_pixels=False, max_value=1)
        output_ph = output_ph * error_estimate
        J_reg_grad = reg_factor * normalised_gradient_divergence(output)
        output = output_ph ** (1.0 / gamma_correction_factor) * (1.0 / (1 - J_reg_grad))

        if isDebug:
            # compute ||K*I - B||
            # reblur_loss = model.reblurLoss(2*output_reblurred, model.real_A[:, :, K//2:-K//2+1, K//2:-K//2+1]) * opt.lambda_reblur
            reblur_loss = torch.mean(
                (output_reblurred_ph ** (1.0 / gamma_correction_factor) - blurry_tensor) ** 2)
            PSNR_reblur = 10 * np.log10(1 / reblur_loss.item())
            print(it, 'PSNR_reblur:', PSNR_reblur)

        if (SAVE_INTERMIDIATE and it % np.max([1, n_iters // 10]) == 0) or it == (n_iters - 1):
            if it == (n_iters - 1):
                filename = os.path.join(debug_folder, 'iter_%06i_restored.png' % n_iters)
            else:
                filename = os.path.join(debug_folder, 'iter_%06i.png' % it)
            save_image(tensor2im(output[0].detach().clamp(0, 1) - 0.5), filename)

    return output
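
# `forward_reblur` is defined elsewhere in this repository. Conceptually it
# applies a spatially varying blur as a mask-weighted sum of K basis
# convolutions. A minimal sketch of that idea, assuming kernels of shape
# (B, K, ks, ks) with odd ks and mixing masks of shape (B, K, H, W), and
# ignoring the size/saturation options of the real implementation:
import torch
import torch.nn.functional as F

def forward_reblur_sketch(image, kernels, masks, padding_mode='reflect'):
    B, C, H, W = image.shape
    _, K, ks, _ = kernels.shape
    pad = ks // 2
    padded = F.pad(image, (pad, pad, pad, pad), mode=padding_mode)
    reblurred = torch.zeros_like(image)
    for b in range(B):
        for k in range(K):
            # Convolve every channel with the k-th basis kernel...
            kernel = kernels[b, k].expand(C, 1, ks, ks)
            blurred_k = F.conv2d(padded[b:b + 1], kernel, groups=C)
            # ...and blend it in where the k-th mixing mask is active.
            reblurred[b:b + 1] += masks[b, k] * blurred_k
    return reblurred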
blurry_image = gray2rgb(blurry_image)
new_shape = (int(args.resize_factor * M), int(args.resize_factor * N), C)
blurry_image = resize(blurry_image, new_shape).astype(np.float32)
initial_image = blurry_image.copy()

blurry_tensor = transforms.ToTensor()(blurry_image)
blurry_tensor = blurry_tensor[None, :, :, :]
blurry_tensor = blurry_tensor.cuda(args.gpu_id)

# The blurry image itself serves as the initial restoration estimate
initial_restoration_tensor = transforms.ToTensor()(initial_image)
initial_restoration_tensor = initial_restoration_tensor[None, :, :, :]
initial_restoration_tensor = initial_restoration_tensor.cuda(args.gpu_id)
save_image(tensor2im(initial_restoration_tensor[0] - 0.5),
           os.path.join(args.output_folder, img_name + '.png'))

# Predict the per-pixel blur kernel basis and mixing masks
with torch.no_grad():
    blurry_tensor_to_compute_kernels = blurry_tensor ** args.gamma_factor - 0.5
    kernels, masks = two_heads(blurry_tensor_to_compute_kernels)
    save_kernels_grid(blurry_tensor[0], kernels[0], masks[0],
                      os.path.join(args.output_folder, img_name + '_kernels' + '.png'))

output = initial_restoration_tensor
with torch.no_grad():
    if args.saturation_method == 'combined':
        output = combined_RL_restore(
            blurry_tensor,