def main(args):
    """Denoise and super-resolve every file in ``args.data_path``.

    Loads ``final_weights_G_1.pkl`` and ``final_weights_SR.pkl`` from
    ``args.weights_dir`` into ``Generator_lr`` / ``EDSR`` models. For each
    file in ``args.data_path`` it writes:

    * the ``resolv_deonoise(G_1, image)`` result to
      ``<args.output_dir>/clean/<file name>``;
    * the ``resolv_sr(G_1, SR, image)`` result to
      ``<args.output_dir>/SR/<file name>``.

    Args:
        args: namespace providing ``weights_dir``, ``data_path`` and
            ``output_dir``.
    """
    G_1 = Generator_lr(in_channels=3)
    SR = EDSR(n_colors=3)

    # Load pretrained weights. map_location='cpu' lets a checkpoint saved on
    # any GPU be restored; the models are moved to the GPU right afterwards.
    G_1.load_state_dict(torch.load(
        os.path.join(args.weights_dir, 'final_weights_G_1.pkl'),
        map_location='cpu'))
    SR.load_state_dict(torch.load(
        os.path.join(args.weights_dir, 'final_weights_SR.pkl'),
        map_location='cpu'))
    G_1.cuda()
    G_1.eval()
    SR.cuda()
    SR.eval()

    # Output folders (makedirs also creates args.output_dir itself).
    clean_dir = os.path.join(args.output_dir, 'clean')
    sr_dir = os.path.join(args.output_dir, 'SR')
    os.makedirs(clean_dir, exist_ok=True)
    os.makedirs(sr_dir, exist_ok=True)

    # Inference only: disable autograd so no computation graphs are kept.
    with torch.no_grad():
        for image_name in tqdm(os.listdir(args.data_path)):
            image_path = os.path.join(args.data_path, image_name)
            if not os.path.isfile(image_path):
                # Sub-directories and other non-file entries cannot be opened
                # as images; skip them instead of crashing.
                continue
            # The context manager closes the file handle once both outputs
            # are written, so large folders do not leak descriptors.
            with Image.open(image_path) as image:
                # denoise
                clean_image = resolv_deonoise(G_1, image)
                clean_image.save(os.path.join(clean_dir, image_name))
                # super-resolve
                sr_image = resolv_sr(G_1, SR, image)
                sr_image.save(os.path.join(sr_dir, image_name))
def _save_weights(models, log_dir, prefix):
    """Save each model's ``state_dict`` to ``<log_dir>/<prefix><name>.pkl``."""
    for name, model in models.items():
        torch.save(model.state_dict(),
                   os.path.join(log_dir, '{}{}.pkl'.format(prefix, name)))


def main(args):
    """Train the LR (denoising) GAN and the SR GAN together.

    Per batch, two phases run:

    1. LR phase: update ``D_1``, then ``G_1``/``G_2`` with the adversarial,
       cycle, identity and total-variation losses.
    2. SR phase: update ``D_2``, then ``SR``/``G_3`` (and step ``G_1``) with
       the same family of losses applied to the detached ``G_1`` output.

    After every epoch a sample image is super-resolved and all six models
    are checkpointed to ``args.log_dir`` as ``ep-<epoch>_<name>.pkl``. Final
    weights are saved as ``final_weights_<name>.pkl``.

    Args:
        args: namespace providing ``log_dir``, ``data_path``, ``batch_size``,
            ``lr`` and ``epochs``. Optional fields: ``sample_image`` (path
            of the image super-resolved after each epoch) and
            ``num_workers`` (DataLoader workers, default 3).
    """
    os.makedirs(args.log_dir, exist_ok=True)

    # Image super-resolved after each epoch. Output file names are derived
    # from its stem ('0001x4d' for the default path).
    sample_path = getattr(args, 'sample_image',
                          '/data/data/DIV2K/unsupervised/lr/0001x4d.png')
    sample_stem = os.path.splitext(os.path.basename(sample_path))[0]

    # create models
    G_1 = Generator_lr(in_channels=3)
    G_2 = Generator_lr(in_channels=3)
    D_1 = Discriminator_lr(in_channels=3, in_h=16, in_w=16)
    SR = EDSR(n_colors=3)
    G_3 = Generator_sr(in_channels=3)
    D_2 = Discriminator_sr(in_channels=3, in_h=64, in_w=64)
    models = {'G_1': G_1, 'G_2': G_2, 'D_1': D_1,
              'SR': SR, 'G_3': G_3, 'D_2': D_2}
    for model in models.values():
        model.cuda()
        model.train()

    # tensorboard
    writer = SummaryWriter(log_dir=args.log_dir)

    def _adam(model, lr):
        # Optimise only trainable parameters.
        return torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, model.parameters()),
            lr=lr)

    # Generators use 5x the base learning rate, except G_3.
    optimizers = {
        'G_1': _adam(G_1, args.lr * 5),
        'G_2': _adam(G_2, args.lr * 5),
        'D_1': _adam(D_1, args.lr),
        'SR': _adam(SR, args.lr * 5),
        'G_3': _adam(G_3, args.lr),
        'D_2': _adam(D_2, args.lr),
    }

    # get dataloader
    train_dataset = DIV2KDataset(root=args.data_path)
    trainloader = DataLoader(train_dataset, batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=getattr(args, 'num_workers', 3))

    print('-' * 20)
    print('Start training')
    print('-' * 20)
    iter_index = 0
    for epoch in range(args.epochs):
        G_1.train()
        SR.train()
        start = timeit.default_timer()
        for image, label_hr, label_lr in trainloader:
            iter_index += 1
            image = image.cuda()
            label_hr = label_hr.cuda()
            label_lr = label_lr.cuda()
            log_now = iter_index % 100 == 0
            step = iter_index // 100

            # ---- LR GAN: update D_1, then G_1 and G_2 ----
            for opt in optimizers.values():
                opt.zero_grad()

            # D loss for D_1. The fake is detached so this backward pass
            # gives gradients to D_1 only. Otherwise G_1 would accumulate the
            # discriminator's gradient and step in the discriminator's favour.
            image_clean = G_1(image).detach()
            loss_D1 = discriminator_loss(discriminator=D_1, fake=image_clean,
                                         real=label_lr)
            loss_D1.backward()
            optimizers['D_1'].step()
            # GD loss for G_1
            loss_G1 = generator_discriminator_loss(generator=G_1,
                                                   discriminator=D_1,
                                                   input=image)
            loss_G1.backward()
            # cycle loss for G_1 and G_2
            loss_cycle = 10 * cycle_loss(G_1, G_2, image)
            loss_cycle.backward()
            # idt loss for G_1
            loss_idt = 5 * identity_loss(clean_image=label_lr, generator=G_1)
            loss_idt.backward()
            # tvloss for G_1
            loss_tv = 0.5 * tvloss(input=image, generator=G_1)
            loss_tv.backward()
            # optimize G_1 and G_2
            optimizers['G_1'].step()
            optimizers['G_2'].step()

            if log_now:
                print(
                    'iter {}: LR: loss_D1={}, loss_GD={}, loss_cycle={}, loss_idt={}, loss_tv={}'
                    .format(iter_index, loss_D1.item(), loss_G1.item(),
                            loss_cycle.item(), loss_idt.item(),
                            loss_tv.item()))
                for tag, value in (('LR/loss_D1', loss_D1),
                                   ('LR/loss_GD', loss_G1),
                                   ('LR/loss_cycle', loss_cycle),
                                   ('LR/loss_idt', loss_idt),
                                   ('LR/loss_tv', loss_tv)):
                    writer.add_scalar(tag, value.item(), step)
                # Visualisation only: no autograd graph needed.
                with torch.no_grad():
                    writer.add_image('LR/origin', image[0], step)
                    writer.add_image('LR/denoise', G_1(image)[0], step)

            # ---- SR GAN: update D_2, then SR and G_3 ----
            for opt in optimizers.values():
                opt.zero_grad()

            image_clean = G_1(image).detach()
            # D loss for D_2. The fake is detached so SR receives no gradient
            # from the discriminator's own objective.
            image_sr = SR(image_clean)
            loss_D2 = discriminator_loss(discriminator=D_2,
                                         fake=image_sr.detach(),
                                         real=label_hr)
            loss_D2.backward()
            optimizers['D_2'].step()
            # GD loss for SR
            loss_SR = generator_discriminator_loss(generator=SR,
                                                   discriminator=D_2,
                                                   input=image_clean)
            loss_SR.backward()
            # cycle loss for SR and G_3
            loss_cycle = 10 * cycle_loss(SR, G_3, image_clean)
            loss_cycle.backward()
            # idt loss for SR
            loss_idt = 5 * identity_loss_sr(clean_image_lr=label_lr,
                                            clean_image_hr=label_hr,
                                            generator=SR)
            loss_idt.backward()
            # tvloss for SR
            loss_tv = 0.5 * tvloss(input=image_clean, generator=SR)
            loss_tv.backward()
            # NOTE(review): image_clean is detached above, so none of the
            # SR-phase losses give G_1 a gradient. On PyTorch versions where
            # zero_grad() fills grads with zeros (instead of setting them to
            # None), Adam still moves G_1 here via its momentum. Confirm
            # whether G_1 is meant to be trained jointly in this phase.
            optimizers['G_1'].step()
            optimizers['SR'].step()
            optimizers['G_3'].step()

            if log_now:
                print(
                    ' SR: loss_D2={}, loss_SR={}, loss_cycle={}, loss_idt={}, loss_tv={}'
                    .format(loss_D2.item(), loss_SR.item(), loss_cycle.item(),
                            loss_idt.item(), loss_tv.item()))
                for tag, value in (('SR/loss_D2', loss_D2),
                                   ('SR/loss_SR', loss_SR),
                                   ('SR/loss_cycle', loss_cycle),
                                   ('SR/loss_idt', loss_idt),
                                   ('SR/loss_tv', loss_tv)):
                    writer.add_scalar(tag, value.item(), step)
                with torch.no_grad():
                    writer.add_image('SR/origin', image[0], step)
                    writer.add_image('SR/clean_image', G_1(image)[0], step)
                    writer.add_image('SR/SR', SR(G_1(image))[0], step)
                writer.flush()

        end = timeit.default_timer()
        print('epoch {}, using {} seconds'.format(epoch, end - start))

        # Per-epoch snapshot of the sample image, then checkpoint all models.
        G_1.eval()
        SR.eval()
        with torch.no_grad(), Image.open(sample_path) as sample:
            sr_image = resolv_sr(G_1, SR, sample)
            sr_image.save(os.path.join(
                args.log_dir, '{}_sr_{}.png'.format(sample_stem, epoch)))
        _save_weights(models, args.log_dir, 'ep-{}_'.format(epoch))

    writer.close()
    print('Training done.')
    _save_weights(models, args.log_dir, 'final_weights_')

    # Save the sample input next to its final super-resolved version.
    with torch.no_grad(), Image.open(sample_path) as sample:
        sample.save(os.path.join(args.log_dir, '{}.png'.format(sample_stem)))
        sr_image = resolv_sr(G_1, SR, sample)
        sr_image.save(
            os.path.join(args.log_dir, '{}_sr.png'.format(sample_stem)))
def create_reconstruction(input_shape=(32, 32, 3)):
    """Build a Keras functional ``Model`` wrapping ``EDSR``.

    Args:
        input_shape: shape of one input sample, excluding the batch axis.
            Defaults to ``(32, 32, 3)``, the shape this function always used
            before it became a parameter.

    Returns:
        A ``Model`` whose output is ``EDSR(inputs, config.filters,
        config.nBlocks)`` applied to an ``Input`` of ``input_shape``.
    """
    # Named `inputs` rather than `input` to avoid shadowing the builtin.
    inputs = Input(shape=input_shape)
    outputs = EDSR(inputs, config.filters, config.nBlocks)
    return Model(inputs=inputs, outputs=outputs)
def load_edsr(device, n_resblocks=16, n_feats=64, model_details=True):
    """
    Loads the EDSR model

    Builds the model with ``edsr.make_model`` from a fixed configuration
    dict, moves it to ``device`` and then calls ``edsr.load(model)`` on it.

    Parameters
    ----------
    device : str or torch.device
        Device passed to ``.to()``.
    n_resblocks : int, optional
        Number of residual blocks written into the model config.
        The default is 16.
    n_feats : int, optional
        Number of feature maps written into the model config.
        The default is 64.
    model_details : bool, optional
        Currently unused; kept for backward compatibility.

    Returns
    -------
    model : torch.nn.Module
        EDSR model.

    Notes
    -----
    ``n_resblocks`` and ``n_feats`` were previously ignored (the config
    always used 16/64). With ``"pre_train": "download"``, ``edsr.load``
    presumably fetches pretrained weights. If so, non-default sizes must
    match an available checkpoint (TODO confirm against ``edsr.load``).
    """
    args = {
        "G0": 64,
        "RDNconfig": "B",
        "RDNkSize": 3,
        "act": "relu",
        "batch_size": 16,
        "betas": (0.9, 0.999),
        "chop": True,
        # NOTE(review): "cpu" stays True regardless of `device`; confirm
        # whether edsr.make_model/edsr.load read this flag.
        "cpu": True,
        "data_range": "1-800/801-810",
        "data_test": ["Demo"],
        "data_train": ["DIV2K"],
        "debug": False,
        "decay": "200",
        "dilation": False,
        "dir_data": "../../../dataset",
        "dir_demo": "../test",
        "epochs": 300,
        "epsilon": 1e-08,
        "ext": "sep",
        "extend": ".",
        "gamma": 0.5,
        "gan_k": 1,
        "gclip": 0,
        "load": "",
        "loss": "1*L1",
        "lr": 0.0001,
        "model": "EDSR",
        "momentum": 0.9,
        "n_GPUs": 1,
        "n_colors": 3,
        # Honour the caller's architecture arguments instead of fixed 64/16.
        "n_feats": n_feats,
        "n_resblocks": n_resblocks,
        "n_resgroups": 10,
        "n_threads": 6,
        "no_augment": False,
        "optimizer": "ADAM",
        "patch_size": 192,
        "pre_train": "download",
        "precision": "single",
        "print_every": 100,
        "reduction": 16,
        "res_scale": 1,
        "reset": False,
        "resume": 0,
        "rgb_range": 255,
        "save": "test",
        "save_gt": False,
        "save_models": False,
        "save_results": True,
        "scale": [4],
        "seed": 1,
        "self_ensemble": False,
        "shift_mean": True,
        "skip_threshold": 100000000.0,
        "split_batch": 1,
        "template": ".",
        "test_every": 1000,
        "test_only": True,
        "weight_decay": 0,
    }
    model = edsr.make_model(args).to(device)
    edsr.load(model)
    return model
# Download and unpack the flower dataset on first run.
if not os.path.exists("data"):
    import shutil
    import tarfile
    import tempfile
    import urllib.request

    print("Downloading flower dataset...")
    # Extract into a scratch directory and only rename it to "data" once the
    # whole archive has been unpacked. A failed or interrupted download then
    # never leaves a partial "data" directory that would make later runs skip
    # the download. Unlike `curl | tar`, any network error here raises
    # instead of being masked by the pipeline's exit status.
    _tmp_dir = tempfile.mkdtemp(prefix="data-", dir=".")
    try:
        with urllib.request.urlopen(
                "https://storage.googleapis.com/wandb/flower-enhance.tar.gz"
        ) as _response, tarfile.open(fileobj=_response,
                                     mode="r|gz") as _archive:
            # Fixed, trusted source URL; the archive is not user-supplied.
            _archive.extractall(_tmp_dir)
        os.rename(_tmp_dir, "data")
    except BaseException:
        shutil.rmtree(_tmp_dir, ignore_errors=True)
        raise

# Whole batches per epoch, counted from the "*-in.jpg" input files.
config.steps_per_epoch = len(
    glob.glob(config.train_dir + "/*-in.jpg")) // config.batch_size
config.val_steps_per_epoch = len(
    glob.glob(config.val_dir + "/*-in.jpg")) // config.batch_size

# Neural network
input1 = Input(shape=(config.input_height, config.input_width, 3),
               dtype='float32')
model = Model(inputs=input1,
              outputs=EDSR(input1, config.filters, config.nBlocks))
#print(model.summary())
#model.load_weights('edsr.h5')

#es = EarlyStopping(monitor='val_perceptual_distance', mode='min', verbose = 1, patience=2)
# Keep only the checkpoint with the lowest validation perceptual distance.
mc = ModelCheckpoint('edsr.h5', monitor='val_perceptual_distance',
                     mode='min', save_best_only=True)

##DONT ALTER metrics=[perceptual_distance]
model.compile(optimizer='adam', loss=[perceptual_distance],
              metrics=[perceptual_distance])
def _save_dataset_splits(named_splits, experiment_num,
                         data_folder='final/data'):
    """Stack each split's ``'x'``/``'y'`` samples and save them with torch.

    Files are written as
    ``<data_folder>/<name>_data_x_<experiment_num>.pt`` and
    ``<data_folder>/<name>_data_y_<experiment_num>.pt``. A split with no
    samples is skipped because ``torch.cat`` rejects an empty list.
    """
    os.makedirs(data_folder, exist_ok=True)
    for name, d_set in named_splits:
        list_x = []
        list_y = []
        for sample in d_set:
            list_x.append(sample['x'].unsqueeze(0))
            list_y.append(sample['y'].unsqueeze(0))
        if not list_x:
            continue
        tensor_x = torch.cat(list_x, 0)
        tensor_y = torch.cat(list_y, 0)
        torch.save(tensor_x,
                   f'{data_folder}/{name}_data_x_{experiment_num}.pt')
        torch.save(tensor_y,
                   f'{data_folder}/{name}_data_y_{experiment_num}.pt')


def main(args):
    """Train a super-resolution generator and return its training history.

    Args:
        args: experiment arguments. Fields read here: ``is_optimisation``,
            ``experiment_num``, ``device``, ``filename_x``, ``filename_y``,
            ``data_root``, ``test_percentage``, ``val_percentage``,
            ``model``, ``latent_dim``, ``num_res_blocks``, ``lr``,
            ``scheduler_patience``, ``criterion_type``, ``n_epochs``,
            ``batch_size``, ``use_fk_loss``, ``eval_interval``,
            ``is_psnr_step``, ``save_interval`` and ``save_test_dataset``.

    Returns:
        ``(plot_log, generator, test_data)``. ``plot_log`` maps ``'G'`` (per
        training epoch) and ``'G_val'`` / ``'psnr_val'`` (per evaluation) to
        lists. In hyper-optimisation mode ``test_data`` is reduced to a
        random 2-sample subset.

    Raises:
        ValueError: if ``args.model`` is not ``'SRCNN'``, ``'EDSR'`` or
            ``'VDSR'``.
    """
    # Fail fast: an unknown model name used to leave `generator` unbound and
    # crash only after the dataset had been loaded.
    if args.model not in ('SRCNN', 'EDSR', 'VDSR'):
        raise ValueError(f"Unknown model {args.model!r}; expected one of "
                         f"'SRCNN', 'EDSR', 'VDSR'.")

    # Create directories if it's not hyper-optimisation round.
    if not args.is_optimisation:
        results_directory = f'results/result_{args.experiment_num}'
        os.makedirs('images', exist_ok=True)
        os.makedirs(results_directory, exist_ok=True)

        # Save arguments for experiment reproducibility.
        with open(os.path.join(results_directory, 'arguments.txt'),
                  'w') as file:
            json.dump(args.__dict__, file, indent=2)

    # Set size for plots.
    plt.rcParams['figure.figsize'] = (10, 10)

    # Select the device to train the model on.
    device = torch.device(args.device)

    # Load the dataset.
    # TODO: Add normalisation transforms.Normalize(
    #     torch.tensor(-4.4713e-07).float(),
    #     torch.tensor(0.1018).float())
    # TODO: Add more data augmentation transforms.
    data_transforms = transforms.Compose([
        # RandomHorizontalFlip(),
        ToTensor()
    ])
    dataset = Data(args.filename_x, args.filename_y, args.data_root,
                   transform=data_transforms)
    if not args.is_optimisation:
        print(f"Data sizes, input: {dataset.input_dim}, output: "
              f"{dataset.output_dim}, Fk: {dataset.output_dim_fk}")

    # split_dataset is called with the combined test+val fraction, and its
    # second part is then split 50/50 into test and validation sets.
    train_data, test_data = split_dataset(
        dataset, args.test_percentage + args.val_percentage)
    test_data, val_data = split_dataset(test_data, 0.5)

    # Initialize generator model (args.model was validated above).
    if args.model == 'SRCNN':
        generator = SRCNN(input_dim=dataset.input_dim,
                          output_dim=dataset.output_dim).to(device)
    elif args.model == 'EDSR':
        generator = EDSR(args.latent_dim, args.num_res_blocks,
                         output_dim=dataset.output_dim).to(device)
    else:  # 'VDSR'
        generator = VDSR(args.latent_dim, args.num_res_blocks,
                         output_dim=dataset.output_dim).to(device)

    # Optimizers
    optim_G = optim.Adam(generator.parameters(), lr=args.lr)
    # ReduceLROnPlateau defaults to mode='min'. PSNR is higher-is-better, so
    # when the scheduler is stepped on PSNR it needs mode='max'. Otherwise it
    # would cut the learning rate precisely while PSNR keeps improving.
    scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optim_G,
        mode='max' if args.is_psnr_step else 'min',
        patience=args.scheduler_patience,
        verbose=True)

    # losses type
    criterion_dictionary = {
        "MSE": nn.MSELoss(),
        "L1": nn.L1Loss(),
    }
    reconstruction_criterion = criterion_dictionary[args.criterion_type]

    # Initialize a dict of empty lists for plotting.
    plot_log = defaultdict(list)

    for epoch in range(args.n_epochs):
        # Train model for one epoch.
        loss = iter_epoch(generator, optim_G, train_data, device,
                          batch_size=args.batch_size,
                          reconstruction_criterion=reconstruction_criterion,
                          use_fk_loss=args.use_fk_loss)

        # Report model performance.
        if not args.is_optimisation:
            print(f"Epoch: {epoch}, Loss: {loss['G']}, "
                  f"PSNR: {loss['psnr']}")
        plot_log['G'].append(loss['G'])

        # Model evaluation every eval_interval epochs, plus the last epoch
        # during hyper-optimisation.
        if epoch % args.eval_interval == 0 \
                or (args.is_optimisation and epoch == args.n_epochs - 1):
            loss_val = iter_epoch(
                generator, None, val_data, device,
                batch_size=args.batch_size, eval=True,
                reconstruction_criterion=reconstruction_criterion,
                use_fk_loss=args.use_fk_loss)

            if not args.is_optimisation:
                print(f"Validation on epoch: {epoch}, Loss: {loss_val['G']}, "
                      f" PSNR: {loss_val['psnr']}")

            plot_log['G_val'].append(loss_val['G'])
            plot_log['psnr_val'].append(loss_val['psnr'])

            # Update scheduler based on PSNR or the validation loss.
            if args.is_psnr_step:
                scheduler_g.step(loss_val['psnr'])
            else:
                scheduler_g.step(loss_val['G'])

        if not args.is_optimisation and epoch % args.save_interval == 0:
            # Plot results.
            plot_samples(generator, val_data, epoch, device,
                         results_directory)
            plot_samples(generator, train_data, epoch, device,
                         results_directory, is_train=True)
            save_loss_plot(plot_log['G'], results_directory)

    if not args.is_optimisation:
        # Save the trained generator model.
        torch.save(generator,
                   os.path.join(results_directory, 'generator.pth'))

        if args.save_test_dataset:
            _save_dataset_splits(
                zip(['test', 'val', 'train'],
                    [test_data, val_data, train_data]),
                args.experiment_num)

        return plot_log, generator, test_data

    # Hyper-optimisation: return only a random 2-sample slice of test data.
    __, test_data = random_split(test_data, [len(test_data) - 2, 2])
    return plot_log, generator, test_data