def main(): ############################################ # # set options # ############################################ parser = argparse.ArgumentParser() parser.add_argument('--opt', type=str, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) ############################################ # # distributed training settings # ############################################ if args.launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') else: opt['dist'] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() print("Rank:", rank) print("------------------DIST-------------------------") ############################################ # # loading resume state if exists # ############################################ if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None ############################################ # # mkdir and loggers # ############################################ if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: util.mkdir_and_rename( opt['path'] ['experiments_root']) # rename experiment folder if exists util.mkdirs( (path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) util.setup_logger('base_val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger_val = logging.getLogger('base_val') logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. Tensorboard will use [tensorboardX]' .format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) else: # config loggers. 
Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_', level=logging.INFO, screen=True) print("set train log") util.setup_logger('base_val', opt['path']['log'], 'val_', level=logging.INFO, screen=True) print("set val log") logger = logging.getLogger('base') logger_val = logging.getLogger('base_val') # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True ############################################ # # create train and val dataloader # ############################################ #### # dataset_ratio = 200 # enlarge the size of each epoch, todo: what it is dataset_ratio = 1 # enlarge the size of each epoch, todo: what it is for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) # total_iters = int(opt['train']['niter']) # total_epochs = int(math.ceil(total_iters / train_size)) total_iters = train_size total_epochs = int(opt['train']['epoch']) if opt['dist']: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) # total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) total_epochs = int(opt['train']['epoch']) if opt['train']['enable'] == False: total_epochs = 1 else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) if rank <= 0: logger.info( 'Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None ############################################ # # create model # ############################################ #### model = create_model(opt) print("Model Created! 
") #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 print("Not Resume Training") ############################################ # # training # ############################################ #### #### logger.info('Start training from epoch: {:d}, iter: {:d}'.format( start_epoch, current_step)) Avg_train_loss = AverageMeter() # total if (opt['train']['pixel_criterion'] == 'cb+ssim'): Avg_train_loss_pix = AverageMeter() Avg_train_loss_ssim = AverageMeter() elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'): Avg_train_loss_pix = AverageMeter() Avg_train_loss_ssim = AverageMeter() Avg_train_loss_vmaf = AverageMeter() elif (opt['train']['pixel_criterion'] == 'ssim'): Avg_train_loss_ssim = AverageMeter() elif (opt['train']['pixel_criterion'] == 'msssim'): Avg_train_loss_msssim = AverageMeter() elif (opt['train']['pixel_criterion'] == 'cb+msssim'): Avg_train_loss_pix = AverageMeter() Avg_train_loss_msssim = AverageMeter() saved_total_loss = 10e10 saved_total_PSNR = -1 for epoch in range(start_epoch, total_epochs): ############################################ # # Start a new epoch # ############################################ # Turn into training mode #model = model.train() # reset total loss Avg_train_loss.reset() current_step = 0 if (opt['train']['pixel_criterion'] == 'cb+ssim'): Avg_train_loss_pix.reset() Avg_train_loss_ssim.reset() elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'): Avg_train_loss_pix.reset() Avg_train_loss_ssim.reset() Avg_train_loss_vmaf.reset() elif (opt['train']['pixel_criterion'] == 'ssim'): Avg_train_loss_ssim = AverageMeter() elif (opt['train']['pixel_criterion'] == 'msssim'): Avg_train_loss_msssim = AverageMeter() elif (opt['train']['pixel_criterion'] == 'cb+msssim'): Avg_train_loss_pix = AverageMeter() Avg_train_loss_msssim = AverageMeter() if opt['dist']: train_sampler.set_epoch(epoch) for train_idx, train_data in enumerate(train_loader): if 'debug' in opt['name']: img_dir = os.path.join(opt['path']['train_images']) util.mkdir(img_dir) LQ = train_data['LQs'] GT = train_data['GT'] GT_img = util.tensor2img(GT) # uint8 save_img_path = os.path.join( img_dir, '{:4d}_{:s}.png'.format(train_idx, 'debug_GT')) util.save_img(GT_img, save_img_path) for i in range(5): LQ_img = util.tensor2img(LQ[0, i, ...]) # uint8 save_img_path = os.path.join( img_dir, '{:4d}_{:s}_{:1d}.png'.format(train_idx, 'debug_LQ', i)) util.save_img(LQ_img, save_img_path) if (train_idx >= 3): break if opt['train']['enable'] == False: message_train_loss = 'None' break current_step += 1 if current_step > total_iters: print("Total Iteration Reached !") break #### update learning rate if opt['train']['lr_scheme'] == 'ReduceLROnPlateau': pass else: model.update_learning_rate( current_step, warmup_iter=opt['train']['warmup_iter']) #### training model.feed_data(train_data) # if opt['train']['lr_scheme'] == 'ReduceLROnPlateau': # model.optimize_parameters_without_schudlue(current_step) # else: model.optimize_parameters(current_step) if (opt['train']['pixel_criterion'] == 'cb+ssim'): Avg_train_loss.update(model.log_dict['total_loss'], 1) Avg_train_loss_pix.update(model.log_dict['l_pix'], 1) Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1) elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'): 
Avg_train_loss.update(model.log_dict['total_loss'], 1) Avg_train_loss_pix.update(model.log_dict['l_pix'], 1) Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1) Avg_train_loss_vmaf.update(model.log_dict['vmaf_loss'], 1) elif (opt['train']['pixel_criterion'] == 'ssim'): Avg_train_loss.update(model.log_dict['total_loss'], 1) Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1) elif (opt['train']['pixel_criterion'] == 'msssim'): Avg_train_loss.update(model.log_dict['total_loss'], 1) Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1) elif (opt['train']['pixel_criterion'] == 'cb+msssim'): Avg_train_loss.update(model.log_dict['total_loss'], 1) Avg_train_loss_pix.update(model.log_dict['l_pix'], 1) Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1) else: Avg_train_loss.update(model.log_dict['l_pix'], 1) # add total train loss if (opt['train']['pixel_criterion'] == 'cb+ssim'): message_train_loss = ' pix_avg_loss: {:.4e}'.format( Avg_train_loss_pix.avg) message_train_loss += ' ssim_avg_loss: {:.4e}'.format( Avg_train_loss_ssim.avg) message_train_loss += ' total_avg_loss: {:.4e}'.format( Avg_train_loss.avg) elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'): message_train_loss = ' pix_avg_loss: {:.4e}'.format( Avg_train_loss_pix.avg) message_train_loss += ' ssim_avg_loss: {:.4e}'.format( Avg_train_loss_ssim.avg) message_train_loss += ' vmaf_avg_loss: {:.4e}'.format( Avg_train_loss_vmaf.avg) message_train_loss += ' total_avg_loss: {:.4e}'.format( Avg_train_loss.avg) elif (opt['train']['pixel_criterion'] == 'ssim'): message_train_loss = ' ssim_avg_loss: {:.4e}'.format( Avg_train_loss_ssim.avg) message_train_loss += ' total_avg_loss: {:.4e}'.format( Avg_train_loss.avg) elif (opt['train']['pixel_criterion'] == 'msssim'): message_train_loss = ' msssim_avg_loss: {:.4e}'.format( Avg_train_loss_msssim.avg) message_train_loss += ' total_avg_loss: {:.4e}'.format( Avg_train_loss.avg) elif (opt['train']['pixel_criterion'] == 'cb+msssim'): message_train_loss = ' pix_avg_loss: {:.4e}'.format( Avg_train_loss_pix.avg) message_train_loss += ' msssim_avg_loss: {:.4e}'.format( Avg_train_loss_msssim.avg) message_train_loss += ' total_avg_loss: {:.4e}'.format( Avg_train_loss.avg) else: message_train_loss = ' train_avg_loss: {:.4e}'.format( Avg_train_loss.avg) #### log if current_step % opt['logger']['print_freq'] == 0: logs = model.get_current_log() message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format( epoch, current_step) for v in model.get_current_learning_rate(): message += '{:.3e},'.format(v) message += ')] ' for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: if rank <= 0: tb_logger.add_scalar(k, v, current_step) message += message_train_loss if rank <= 0: logger.info(message) ############################################ # # end of one epoch, save epoch model # ############################################ #### save models and training states # if current_step % opt['logger']['save_checkpoint_freq'] == 0: # if rank <= 0: # logger.info('Saving models and training states.') # model.save(current_step) # model.save('latest') # # model.save_training_state(epoch, current_step) # # todo delete previous weights # previous_step = current_step - opt['logger']['save_checkpoint_freq'] # save_filename = '{}_{}.pth'.format(previous_step, 'G') # save_path = os.path.join(opt['path']['models'], save_filename) # if os.path.exists(save_path): # os.remove(save_path) if epoch == 1: save_filename = 
'{:04d}_{}.pth'.format(0, 'G') save_path = os.path.join(opt['path']['models'], save_filename) if os.path.exists(save_path): os.remove(save_path) save_filename = '{:04d}_{}.pth'.format(epoch - 1, 'G') save_path = os.path.join(opt['path']['models'], save_filename) if os.path.exists(save_path): os.remove(save_path) if rank <= 0: logger.info('Saving models and training states.') save_filename = '{:04d}'.format(epoch) model.save(save_filename) # model.save('latest') # model.save_training_state(epoch, current_step) ############################################ # # end of one epoch, do validation # ############################################ #### validation #if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: if opt['datasets'].get('val', None): if opt['model'] in [ 'sr', 'srgan' ] and rank <= 0: # image restoration validation # does not support multi-GPU validation pbar = util.ProgressBar(len(val_loader)) avg_psnr = 0. idx = 0 for val_data in val_loader: idx += 1 img_name = os.path.splitext( os.path.basename(val_data['LQ_path'][0]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) model.feed_data(val_data) model.test() visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # Save SR images for reference save_img_path = os.path.join( img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) #util.save_img(sr_img, save_img_path) # calculate PSNR sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) avg_psnr += util.calculate_psnr(sr_img, gt_img) pbar.update('Test {}'.format(img_name)) avg_psnr = avg_psnr / idx # log logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr', avg_psnr, current_step) else: # video restoration validation if opt['dist']: # todo : multi-GPU testing psnr_rlt = {} # with border and center frames psnr_rlt_avg = {} psnr_total_avg = 0. ssim_rlt = {} # with border and center frames ssim_rlt_avg = {} ssim_total_avg = 0. val_loss_rlt = {} val_loss_rlt_avg = {} val_loss_total_avg = 0. 
if rank == 0: pbar = util.ProgressBar(len(val_set)) for idx in range(rank, len(val_set), world_size): print('idx', idx) if 'debug' in opt['name']: if (idx >= 3): break val_data = val_set[idx] val_data['LQs'].unsqueeze_(0) val_data['GT'].unsqueeze_(0) folder = val_data['folder'] idx_d, max_idx = val_data['idx'].split('/') idx_d, max_idx = int(idx_d), int(max_idx) if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32, device='cuda') if ssim_rlt.get(folder, None) is None: ssim_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32, device='cuda') if val_loss_rlt.get(folder, None) is None: val_loss_rlt[folder] = torch.zeros( max_idx, dtype=torch.float32, device='cuda') # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda') model.feed_data(val_data) # model.test() # model.test_stitch() if opt['stitch'] == True: model.test_stitch() else: model.test() # large GPU memory # visuals = model.get_current_visuals() visuals = model.get_current_visuals( save=True, name='{}_{}'.format(folder, idx), save_path=opt['path']['val_images']) rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # calculate PSNR psnr = util.calculate_psnr(rlt_img, gt_img) psnr_rlt[folder][idx_d] = psnr # calculate SSIM ssim = util.calculate_ssim(rlt_img, gt_img) ssim_rlt[folder][idx_d] = ssim # calculate Val loss val_loss = model.get_loss() val_loss_rlt[folder][idx_d] = val_loss logger.info( '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format( folder, idx, psnr, ssim)) if rank == 0: for _ in range(world_size): pbar.update('Test {} - {}/{}'.format( folder, idx_d, max_idx)) # # collect data for _, v in psnr_rlt.items(): dist.reduce(v, 0) for _, v in ssim_rlt.items(): dist.reduce(v, 0) for _, v in val_loss_rlt.items(): dist.reduce(v, 0) dist.barrier() if rank == 0: psnr_rlt_avg = {} psnr_total_avg = 0. for k, v in psnr_rlt.items(): psnr_rlt_avg[k] = torch.mean(v).cpu().item() psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = '# Validation # PSNR: {:.4e}:'.format( psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) # ssim ssim_rlt_avg = {} ssim_total_avg = 0. for k, v in ssim_rlt.items(): ssim_rlt_avg[k] = torch.mean(v).cpu().item() ssim_total_avg += ssim_rlt_avg[k] ssim_total_avg /= len(ssim_rlt) log_s = '# Validation # PSNR: {:.4e}:'.format( ssim_total_avg) for k, v in ssim_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) # added val_loss_rlt_avg = {} val_loss_total_avg = 0. 
for k, v in val_loss_rlt.items(): val_loss_rlt_avg[k] = torch.mean(v).cpu().item() val_loss_total_avg += val_loss_rlt_avg[k] val_loss_total_avg /= len(val_loss_rlt) log_l = '# Validation # Loss: {:.4e}:'.format( val_loss_total_avg) for k, v in val_loss_rlt_avg.items(): log_l += ' {}: {:.4e}'.format(k, v) logger.info(log_l) message = '' for v in model.get_current_learning_rate(): message += '{:.5e}'.format(v) logger_val.info( 'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}' .format(epoch, message, psnr_total_avg, ssim_total_avg, message_train_loss, val_loss_total_avg)) if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) # add val loss tb_logger.add_scalar('val_loss_avg', val_loss_total_avg, current_step) for k, v in val_loss_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) else: # Todo: our function One GPU pbar = util.ProgressBar(len(val_loader)) psnr_rlt = {} # with border and center frames psnr_rlt_avg = {} psnr_total_avg = 0. ssim_rlt = {} # with border and center frames ssim_rlt_avg = {} ssim_total_avg = 0. val_loss_rlt = {} val_loss_rlt_avg = {} val_loss_total_avg = 0. for val_inx, val_data in enumerate(val_loader): if 'debug' in opt['name']: if (val_inx >= 5): break folder = val_data['folder'][0] # idx_d = val_data['idx'].item() idx_d = val_data['idx'] # border = val_data['border'].item() if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = [] if ssim_rlt.get(folder, None) is None: ssim_rlt[folder] = [] if val_loss_rlt.get(folder, None) is None: val_loss_rlt[folder] = [] # process the black blank [B N C H W] print(val_data['LQs'].size()) H_S = val_data['LQs'].size(3) # 540 W_S = val_data['LQs'].size(4) # 960 print(H_S) print(W_S) blank_1_S = 0 blank_2_S = 0 print(val_data['LQs'][0, 2, 0, :, :].size()) for i in range(H_S): if not sum(val_data['LQs'][0, 2, 0, i, :]) == 0: blank_1_S = i - 1 # assert not sum(data_S[:, :, 0][i+1]) == 0 break for i in range(H_S): if not sum(val_data['LQs'][0, 2, 0, :, H_S - i - 1]) == 0: blank_2_S = (H_S - 1) - i - 1 # assert not sum(data_S[:, :, 0][blank_2_S-1]) == 0 break print('LQ :', blank_1_S, blank_2_S) if blank_1_S == -1: print('LQ has no blank') blank_1_S = 0 blank_2_S = H_S # val_data['LQs'] = val_data['LQs'][:,:,:,blank_1_S:blank_2_S,:] print("LQ", val_data['LQs'].size()) # end of process the black blank model.feed_data(val_data) if opt['stitch'] == True: model.test_stitch() else: model.test() # large GPU memory # process blank blank_1_L = blank_1_S << 2 blank_2_L = blank_2_S << 2 print(blank_1_L, blank_2_L) print(model.fake_H.size()) if not blank_1_S == 0: # model.fake_H = model.fake_H[:,:,blank_1_L:blank_2_L,:] model.fake_H[:, :, 0:blank_1_L, :] = 0 model.fake_H[:, :, blank_2_L:H_S, :] = 0 # end of # process blank visuals = model.get_current_visuals( save=True, name='{}_{:02d}'.format(folder, val_inx), save_path=opt['path']['val_images']) rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # calculate PSNR psnr = util.calculate_psnr(rlt_img, gt_img) psnr_rlt[folder].append(psnr) # calculate SSIM ssim = util.calculate_ssim(rlt_img, gt_img) ssim_rlt[folder].append(ssim) # val loss val_loss = model.get_loss() val_loss_rlt[folder].append(val_loss.item()) logger.info( '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format( folder, val_inx, psnr, ssim)) pbar.update('Test {} - {}'.format(folder, idx_d)) # average PSNR for k, v 
in psnr_rlt.items(): psnr_rlt_avg[k] = sum(v) / len(v) psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = '# Validation # PSNR: {:.4e}:'.format( psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) # average SSIM for k, v in ssim_rlt.items(): ssim_rlt_avg[k] = sum(v) / len(v) ssim_total_avg += ssim_rlt_avg[k] ssim_total_avg /= len(ssim_rlt) log_s = '# Validation # SSIM: {:.4e}:'.format( ssim_total_avg) for k, v in ssim_rlt_avg.items(): log_s += ' {}: {:.4e}'.format(k, v) logger.info(log_s) # average VMAF # average Val LOSS for k, v in val_loss_rlt.items(): val_loss_rlt_avg[k] = sum(v) / len(v) val_loss_total_avg += val_loss_rlt_avg[k] val_loss_total_avg /= len(val_loss_rlt) log_l = '# Validation # Loss: {:.4e}:'.format( val_loss_total_avg) for k, v in val_loss_rlt_avg.items(): log_l += ' {}: {:.4e}'.format(k, v) logger.info(log_l) # toal validation log message = '' for v in model.get_current_learning_rate(): message += '{:.5e}'.format(v) logger_val.info( 'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}' .format(epoch, message, psnr_total_avg, ssim_total_avg, message_train_loss, val_loss_total_avg)) # end add if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) # tb_logger.add_scalar('ssim_avg', ssim_total_avg, current_step) # for k, v in ssim_rlt_avg.items(): # tb_logger.add_scalar(k, v, current_step) # add val loss tb_logger.add_scalar('val_loss_avg', val_loss_total_avg, current_step) for k, v in val_loss_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) ############################################ # # end of validation, save model # ############################################ # logger.info("Finished an epoch, Check and Save the model weights") # we check the validation loss instead of training loss. OK~ if saved_total_loss >= val_loss_total_avg: saved_total_loss = val_loss_total_avg #torch.save(model.state_dict(), args.save_path + "/best" + ".pth") model.save('best') logger.info( "Best Weights updated for decreased validation loss") else: logger.info( "Weights Not updated for undecreased validation loss") if saved_total_PSNR <= psnr_total_avg: saved_total_PSNR = psnr_total_avg model.save('bestPSNR') logger.info( "Best Weights updated for increased validation PSNR") else: logger.info( "Weights Not updated for unincreased validation PSNR") ############################################ # # end of one epoch, schedule LR # ############################################ # add scheduler todo if opt['train']['lr_scheme'] == 'ReduceLROnPlateau': for scheduler in model.schedulers: # scheduler.step(val_loss_total_avg) scheduler.step(val_loss_total_avg) if rank <= 0: logger.info('Saving the final model.') model.save('last') logger.info('End of training.') tb_logger.close()
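# ---------------------------------------------------------------------------
# NOTE: the training script above tracks per-criterion losses through an
# `AverageMeter` that is not defined in this file. A minimal sketch of the
# conventional implementation its .reset()/.update(v, n)/.avg usage assumes:
# ---------------------------------------------------------------------------
class AverageMeter:
    """Track the running value, sum, count, and mean of a scalar metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count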
def main():
    # Load the CSV dataframe
    print("------Loading Dataframe---------")
    data_df = data.load(DATA_FILE)
    print("------DataFrame Loading DONE--------")

    # Convert to numpy and normalize
    print("------Normalizing Data---------")
    train_data = data.normalize(data_df, NORMALIZE)
    print("------Normalizing Data DONE---------")

    # Create Pytorch dataloader
    dataloader = data.create_dataloader(train_data, BATCH_SIZE, DEVICE, NOISE,
                                        NOISE_PARAM)
    print("------Created Dataloader---------")

    N, ninp = train_data.shape

    # CREATE MODEL
    if BOTTLENECK:
        net = model.DACBottle(ninp, NHID, NBOT, NHLAYERS, BOTTLENECK,
                              RESNET_TRICK, RANDOM_SEED).to(DEVICE)
    else:
        net = model.DAC(ninp, NHID, NHLAYERS, RESNET_TRICK,
                        RANDOM_SEED).to(DEVICE)
    print("------Loaded Model---------")

    if ((USE_EXISTING_CHECKPOINT or MODE == 'generate')
            and os.path.isfile(CHECKPOINT_FILE)):
        print("----------Loading CHECKPOINT-----------------")
        checkpoint = torch.load(CHECKPOINT_FILE)
        net.load_state_dict(checkpoint['model_state_dict'])
        print("----------Loading CHECKPOINT DONE-----------------")

    if MODE == 'train':
        # GET NORM DATA WITH NOISE AND GENERATE PREDICTIONS
        print("-----------Starting training--------------")
        trainer = train.Trainer(net, LR, LR_DECAY, REG)
        best_loss = np.inf
        for i in range(EPOCHS):
            for bx, by in dataloader:
                bx = bx.to(DEVICE)
                by = by.to(DEVICE)
                loss = trainer.step(bx, by)
            if i % PRINT_EVERY == 0:
                print(f"Epoch: {i}\t Training Loss: {loss}")
            if loss < best_loss:
                best_loss = loss
                torch.save({'model_state_dict': net.state_dict()},
                           CHECKPOINT_FILE)
        print("-----------Training DONE--------------")
    elif MODE == 'generate':
        # GET CLEAN NORM DATA AND GENERATE FEATURES FROM ACTIVATIONS
        print("----------Generating FEATURES-----------------")
        net.eval()
        with torch.no_grad():
            all_data = []
            for bx, by in dataloader:
                x = bx.to(DEVICE)
                out = net.generate(x)
                if len(all_data) == 0:
                    all_data = out
                else:
                    all_data = np.vstack((all_data, out))
        np.savetxt(OUT_FILE, all_data, delimiter=",")
        print("----------FEATURES generated and saved to file------------")
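# ---------------------------------------------------------------------------
# NOTE: the generate branch above grows `all_data` with repeated np.vstack,
# which copies the accumulated array on every batch (quadratic in the number
# of batches). A sketch of the usual pattern -- collect chunks in a list and
# stack once -- assuming net.generate returns a (batch, features) ndarray:
# ---------------------------------------------------------------------------
import numpy as np
import torch


def generate_features(net, dataloader, device, out_file):
    """Run the network over the loader and save all activations as CSV."""
    net.eval()
    chunks = []
    with torch.no_grad():
        for bx, _ in dataloader:
            chunks.append(net.generate(bx.to(device)))
    np.savetxt(out_file, np.vstack(chunks), delimiter=",")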
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True,
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    # util.mkdirs((path for key, path in opt['path'].items() if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)

    # util.setup_logger(None, opt['path']['log'], 'test.log', level=logging.INFO, screen=True)
    # logger = logging.getLogger('base')
    # logger.info(option.dict2str(opt))

    # Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        # logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        print('\nTesting [{:s}]...'.format(test_set_name))
        # logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        # dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        dataset_dir = test_loader.dataset.opt['dataroot_HR']
        # util.mkdir(dataset_dir)

        idx = 0
        for data in test_loader:
            idx += 1
            need_HR = False  # if test_loader.dataset.opt['dataroot_HR'] is None else True
            model.feed_data(data, need_HR=need_HR)
            img_path = data['LR_path'][0]
            print('img_path', img_path)
            sys.stdout.flush()
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            print('img_name', img_name)
            sys.stdout.flush()

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)
            sr_img = util.tensor2img(visuals['SR'])  # uint8

            # save images
            baseinput = os.path.splitext(os.path.basename(img_path))[0][:-8]
            print('baseinput', baseinput)
            sys.stdout.flush()
            model_path = opt['path']['pretrain_model_G']
            print('model_path', model_path)
            sys.stdout.flush()
            modelname = os.path.splitext(os.path.basename(model_path))[0]
            print('modelname', modelname)
            sys.stdout.flush()
            if not os.path.exists('{1:s}/Models/{0:s}/'.format(modelname, dataset_dir)):
                os.makedirs('{1:s}/Models/{0:s}/'.format(modelname, dataset_dir))
            util.save_img(sr_img, '{2:s}/Models/{0:s}/{1:s}.png'.format(
                modelname, img_name, dataset_dir))
            print(idx, img_name)
            sys.stdout.flush()
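# ---------------------------------------------------------------------------
# NOTE: the save-path handling above mixes positional format indices with a
# manual existence check. The same logic reads more simply with os.path.join
# and os.makedirs(..., exist_ok=True); a small sketch using the variable
# names from the loop above:
# ---------------------------------------------------------------------------
import os


def model_save_dir(dataset_dir, modelname):
    """Return (and create if needed) the per-model output directory."""
    out_dir = os.path.join(dataset_dir, 'Models', modelname)
    os.makedirs(out_dir, exist_ok=True)  # no-op when it already exists
    return out_dir

# usage:
# util.save_img(sr_img, os.path.join(model_save_dir(dataset_dir, modelname),
#                                    img_name + '.png'))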
def do_train(args):
    paddle.set_device(args.device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args.seed)

    train_ds = load_dataset(read_custom_data,
                            filename=os.path.join(args.data_dir, "train.txt"),
                            is_test=False,
                            lazy=False)
    dev_ds = load_dataset(read_custom_data,
                          filename=os.path.join(args.data_dir, "dev.txt"),
                          is_test=False,
                          lazy=False)

    tokenizer = ErnieCtmTokenizer.from_pretrained("nptag")
    model = ErnieCtmNptagModel.from_pretrained("nptag")
    vocab_size = model.ernie_ctm.config["vocab_size"]

    trans_func = partial(convert_example,
                         tokenzier=tokenizer,  # (sic) matches convert_example's parameter name
                         max_seq_len=args.max_seq_len)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),  # input_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'),  # token_type_ids
        Pad(axis=0, pad_val=-100, dtype='int64'),  # labels
    ): fn(samples)

    train_data_loader = create_dataloader(train_ds,
                                          mode="train",
                                          batch_size=args.batch_size,
                                          batchify_fn=batchify_fn,
                                          trans_fn=trans_func)
    dev_data_loader = create_dataloader(dev_ds,
                                        mode="dev",
                                        batch_size=args.batch_size,
                                        batchify_fn=batchify_fn,
                                        trans_fn=trans_func)

    if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
        state_dict = paddle.load(args.init_from_ckpt)
        model.set_dict(state_dict)

    model = paddle.DataParallel(model)
    num_training_steps = len(train_data_loader) * args.num_train_epochs
    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps,
                                         args.warmup_proportion)

    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    logger.info("Total steps: %s" % num_training_steps)

    metric = NPTagAccuracy()
    criterion = paddle.nn.CrossEntropyLoss()

    global_step = 0
    for epoch in range(1, args.num_train_epochs + 1):
        logger.info(f"Epoch {epoch} beginning")
        start_time = time.time()
        for step, batch in enumerate(train_data_loader):
            global_step += 1
            input_ids, token_type_ids, labels = batch
            logits = model(input_ids, token_type_ids)
            loss = criterion(logits.reshape([-1, vocab_size]),
                             labels.reshape([-1]))
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            lr_scheduler.step()

            if global_step % args.logging_steps == 0 and rank == 0:
                end_time = time.time()
                speed = float(args.logging_steps) / (end_time - start_time)
                logger.info(
                    "global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s"
                    % (global_step, epoch, loss.numpy().item(), speed))
                start_time = time.time()

            if (global_step % args.save_steps == 0
                    or global_step == num_training_steps) and rank == 0:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model._layers.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

        evaluate(model, metric, criterion, dev_data_loader, vocab_size)
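# ---------------------------------------------------------------------------
# NOTE: batchify_fn above pads input_ids/token_type_ids with the tokenizer's
# pad ids and pads labels with -100, the conventional ignore index for
# cross-entropy, so padded positions do not contribute to the loss. A plain
# NumPy stand-in (not the paddlenlp Pad/Tuple API) illustrating the collation:
# ---------------------------------------------------------------------------
import numpy as np


def pad_batch(seqs, pad_val):
    """Right-pad variable-length integer sequences to a common length."""
    max_len = max(len(s) for s in seqs)
    return np.array([s + [pad_val] * (max_len - len(s)) for s in seqs],
                    dtype='int64')

samples = [([1, 5, 9], [0, 0, 0], [1, 5, 9]), ([1, 7], [0, 0], [1, 7])]
input_ids = pad_batch([s[0] for s in samples], pad_val=0)       # tokenizer.pad_token_id
token_type_ids = pad_batch([s[1] for s in samples], pad_val=0)  # pad_token_type_id
labels = pad_batch([s[2] for s in samples], pad_val=-100)       # masked out of the loss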
assert opt.real_stat_path is not None
if opt.phase == 'train':
    warnings.warn('You are using training set for inference.')


if __name__ == '__main__':
    opt = TestOptions().parse()
    print(' '.join(sys.argv))

    if opt.config_str is not None:
        assert 'super' in opt.netG or 'sub' in opt.netG
        config = decode_config(opt.config_str)
    else:
        assert 'super' not in opt.model
        config = None

    dataloader = create_dataloader(opt)
    model = create_model(opt)
    model.setup(opt)

    web_dir = opt.results_dir  # define the website directory
    webpage = html.HTML(web_dir, 'restore_G_path: %s' % opt.restore_G_path)

    fakes, names = [], []
    for i, data in enumerate(tqdm.tqdm(dataloader)):
        model.set_input(data)  # unpack data from data loader
        if i == 0 and opt.need_profile:
            model.profile(config)
        model.test(config)  # run inference
        visuals = model.get_current_visuals()  # get image results
        generated = visuals['fake_B'].cpu()
        fakes.append(generated)
        for path in model.get_image_paths():
def SR(solver, opt, model_name):
    # fetch the datasets (takes at most ~0.002 sec each)
    bm_names = []
    test_loaders = []
    for _, dataset_opt in sorted(opt['datasets'].items()):
        start = time.time()
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        test_loaders.append(test_loader)
        print('===> Test Dataset: [%s] Number of images: [%d] elapsed time: %.4f sec'
              % (test_set.name(), len(test_set), time.time() - start))
        bm_names.append(test_set.name())

    # run SR over each test set
    for bm, test_loader in zip(bm_names, test_loaders):
        print("Test set : [%s]" % bm)
        sr_list = []
        path_list = []
        total_psnr = []
        total_ssim = []
        total_time = []
        scale = 4

        need_HR = False if test_loader.dataset.__class__.__name__.find('LRHR') < 0 else True

        for iter, batch in enumerate(test_loader):
            solver.feed_data(batch, need_HR=need_HR)

            # measure inference time
            t0 = time.time()
            solver.test()  # SR
            t1 = time.time()
            total_time.append(t1 - t0)

            visuals = solver.get_current_visual(need_HR=need_HR)
            sr_list.append(visuals['SR'])

            # calculate PSNR/SSIM metrics on Python
            if need_HR:
                psnr, ssim = util.calc_metrics(visuals['SR'], visuals['HR'],
                                               crop_border=scale)
                total_psnr.append(psnr)
                total_ssim.append(ssim)
                path_list.append(os.path.basename(batch['HR_path'][0]).replace('HR', model_name))
                print("[%d/%d] %s || PSNR(dB)/SSIM: %.2f/%.4f || Timer: %.4f sec ."
                      % (iter + 1, len(test_loader),
                         os.path.basename(batch['LR_path'][0]), psnr, ssim, (t1 - t0)))
            else:
                path_list.append(os.path.basename(batch['LR_path'][0]))
                print("[%d/%d] %s || Timer: %.4f sec ."
                      % (iter + 1, len(test_loader),
                         os.path.basename(batch['LR_path'][0]), (t1 - t0)))

        if need_HR:
            print("---- Average PSNR(dB) /SSIM /Speed(s) for [%s] ----" % bm)
            print("PSNR: %.2f SSIM: %.4f Speed: %.4f"
                  % (sum(total_psnr) / len(total_psnr),
                     sum(total_ssim) / len(total_ssim),
                     sum(total_time) / len(total_time)))
        else:
            print("---- Average Speed(s) for [%s] is %.4f sec ----"
                  % (bm, sum(total_time) / len(total_time)))

        # save SR results for further evaluation on MATLAB
        if need_HR:
            save_img_path = os.path.join('./results/SR/' + degrad, model_name,
                                         bm, "x%d" % scale)
        else:
            save_img_path = os.path.join('./results/SR/' + bm, model_name,
                                         "x%d" % scale)

        if not os.path.exists(save_img_path):
            os.makedirs(save_img_path)

        for img, name in zip(sr_list, path_list):
            s = time.time()
            # matplotlib.image.save(os.path.join(save_img_path, name), img)
            cv2.imwrite(os.path.join(save_img_path, name),
                        cv2.cvtColor(img, cv2.COLOR_RGB2BGR))  # ~0.609 sec; ~0.07 sec per image on average
            print("NAME: %s DOWNLOAD TIME:%s\n" % (name, time.time() - s))
            # save(os.path.join(save_img_path, name), img)                   # 5.3 sec
            # Image.fromarray(img).save(os.path.join(save_img_path, name))   # 2.4 sec
            # imageio.imwrite(os.path.join(save_img_path, name), img)        # 9.8 sec
        print("===> Total Saving SR images of [%s]... Save Path: [%s] Time: %s\n"
              % (bm, save_img_path, time.time() - s))

    print("==================================================")
    print("===> Finished !!")
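# ---------------------------------------------------------------------------
# NOTE: the inline comments above record one-off save timings per backend
# (~0.07 sec/image for cv2.imwrite vs ~2.4 sec for PIL in the author's run).
# A small harness in the same spirit for re-measuring on other hardware; the
# synthetic image size is a placeholder:
# ---------------------------------------------------------------------------
import time

import cv2
import numpy as np


def time_writer(write_fn, path, img, repeats=10):
    """Average wall-clock seconds for one image write."""
    start = time.time()
    for _ in range(repeats):
        write_fn(path, img)
    return (time.time() - start) / repeats

img = np.random.randint(0, 256, (1080, 1920, 3), dtype=np.uint8)
print('cv2.imwrite:', time_writer(cv2.imwrite, 'tmp.png', img))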
def main(jsonPath): # options opt = option.parse(jsonPath, is_train=False) util.mkdirs((path for key, path in opt["path"].items() if not key == "pretrain_model_G")) opt = option.dict_to_nonedict(opt) util.setup_logger(None, opt["path"]["log"], "test.log", level=logging.INFO, screen=True) logger = logging.getLogger("base") logger.info(option.dict2str(opt)) # Create test dataset and dataloader test_loaders = [] for phase, dataset_opt in sorted(opt["datasets"].items()): test_set = create_dataset(dataset_opt) test_loader = create_dataloader(test_set, dataset_opt) logger.info("Number of test images in [{:s}]: {:d}".format( dataset_opt["name"], len(test_set))) test_loaders.append(test_loader) # Create model model = create_model(opt) for test_loader in test_loaders: test_set_name = test_loader.dataset.opt["name"] logger.info("\nTesting [{:s}]...".format(test_set_name)) # test_start_time = time.time() dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name) util.mkdir(dataset_dir) test_results = OrderedDict() test_results["psnr"] = [] test_results["ssim"] = [] test_results["psnr_y"] = [] test_results["ssim_y"] = [] for data in test_loader: need_HR = False if test_loader.dataset.opt[ "dataroot_HR"] is None else True model.feed_data(data, need_HR=need_HR) img_path = data["LR_path"][0] img_name = os.path.splitext(os.path.basename(img_path))[0] model.test() # test visuals = model.get_current_visuals(need_HR=need_HR) sr_img = util.tensor2img(visuals["SR"]) # uint8 # save images suffix = opt["suffix"] if suffix: save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png") else: save_img_path = os.path.join(dataset_dir, img_name + ".png") util.save_img(sr_img, save_img_path) # calculate PSNR and SSIM if need_HR: gt_img = util.tensor2img(visuals["HR"]) gt_img = gt_img / 255.0 sr_img = sr_img / 255.0 crop_border = test_loader.dataset.opt["scale"] cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :] cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :] psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255) ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255) test_results["psnr"].append(psnr) test_results["ssim"].append(ssim) if gt_img.shape[2] == 3: # RGB image sr_img_y = bgr2ycbcr(sr_img, only_y=True) gt_img_y = bgr2ycbcr(gt_img, only_y=True) cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border] cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border] psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255) ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255) test_results["psnr_y"].append(psnr_y) test_results["ssim_y"].append(ssim_y) logger.info( "{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}." 
.format(img_name, psnr, ssim, psnr_y, ssim_y)) else: logger.info( "{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.".format( img_name, psnr, ssim)) else: logger.info(img_name) if need_HR: # metrics # Average PSNR/SSIM results ave_psnr = sum(test_results["psnr"]) / len(test_results["psnr"]) ave_ssim = sum(test_results["ssim"]) / len(test_results["ssim"]) logger.info( "----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n" .format(test_set_name, ave_psnr, ave_ssim)) if test_results["psnr_y"] and test_results["ssim_y"]: ave_psnr_y = sum(test_results["psnr_y"]) / len( test_results["psnr_y"]) ave_ssim_y = sum(test_results["ssim_y"]) / len( test_results["ssim_y"]) logger.info( "----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n" .format(ave_psnr_y, ave_ssim_y))
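# ---------------------------------------------------------------------------
# NOTE: both test scripts crop `scale` border pixels and then call
# util.calculate_psnr on 0-255 images. For reference, PSNR on 8-bit data is
# 20*log10(255/sqrt(MSE)); a minimal NumPy sketch consistent with that usage:
# ---------------------------------------------------------------------------
import numpy as np


def calculate_psnr(img1, img2):
    """PSNR in dB between two images given as float arrays in [0, 255]."""
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 20 * np.log10(255.0 / np.sqrt(mse))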
    ##### hyperparameters for federated learning #####
    num_clients = opt['fed']['num_clients']
    num_selected = int(num_clients * opt['fed']['sample_fraction'])
    num_rounds = opt['fed']['num_rounds']
    client_epochs = opt['fed']['epochs']

    ##### create dataloader for clients and server #####
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            # note: random_split requires the shard lengths to sum to
            # len(train_set), so this fails when the dataset size is not
            # divisible by num_clients
            train_set_split = vutils.data.random_split(
                train_set,
                [int(len(train_set) / num_clients) for _ in range(num_clients)])
            train_loaders = [create_dataloader(x, dataset_opt) for x in train_set_split]
            print("=====> Train Dataset: %s" % train_set.name())
            print("=====> Number of images in each client: %d" % len(train_set_split[0]))
            if train_loaders is None:
                raise ValueError("[Error] The training data does not exist")
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('======> Val Dataset: %s, Number of images: [%d]'
                  % (val_set.name(), len(val_set)))
        else:
            raise NotImplementedError(
                "[Error] Dataset phase [%s] in *.json is not recognized." % phase)
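# ---------------------------------------------------------------------------
# NOTE: as flagged in the comment above, equal integer shares only sum to
# len(train_set) when the dataset size divides evenly by num_clients. A
# remainder-safe variant (torch.utils.data.random_split; `vutils` above is
# assumed to alias torch.utils):
# ---------------------------------------------------------------------------
from torch.utils.data import random_split


def split_for_clients(train_set, num_clients):
    """Split a dataset into num_clients near-equal random shards."""
    base, rem = divmod(len(train_set), num_clients)
    # the first `rem` clients receive one extra sample so lengths sum exactly
    lengths = [base + 1 if i < rem else base for i in range(num_clients)]
    return random_split(train_set, lengths)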
def attach_dataloader(self, opt):
    aux_opt = self._create_auxiliary_opt(opt)
    self.loader = data.create_dataloader(aux_opt)
def main(): #### setup options of three networks parser = argparse.ArgumentParser() parser.add_argument("-opt", type=str, help="Path to option YMAL file.") parser.add_argument( "--launcher", choices=["none", "pytorch"], default="none", help="job launcher" ) parser.add_argument("--local_rank", type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) # choose small opt for SFTMD test, fill path of pre-trained model_F #### set random seed # seed = opt["train"]["manual_seed"] # if seed is None: # seed = random.randint(1, 10000) # load PCA matrix of enough kernel print("load PCA matrix") pca_matrix = torch.load( opt["pca_matrix_path"], map_location=lambda storage, loc: storage ) print("PCA matrix shape: {}".format(pca_matrix.shape)) #### distributed training settings if args.launcher == "none": # disabled distributed training opt["dist"] = False opt["dist"] = False rank = -1 print("Disabled distributed training.") else: opt["dist"] = True opt["dist"] = True init_dist() world_size = ( torch.distributed.get_world_size() ) # Returns the number of processes in the current process group rank = torch.distributed.get_rank() # Returns the rank of current process group util.set_random_seed(0) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True ###### Predictor&Corrector train ###### #### loading resume state if exists if opt["path"].get("resume_state", None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt["path"]["resume_state"], map_location=lambda storage, loc: storage.cuda(device_id), ) option.check_resume(opt, resume_state["iter"]) # check resume options else: resume_state = None #### mkdir and loggers if rank <= 0: # normal training (rank -1) OR distributed training (rank 0-7) if resume_state is None: # Predictor path util.mkdir_and_rename( opt["path"]["experiments_root"] ) # rename experiment folder if exists util.mkdirs( ( path for key, path in opt["path"].items() if not key == "experiments_root" and "pretrain_model" not in key and "resume" not in key ) ) os.system("rm ./log") os.symlink(os.path.join(opt["path"]["experiments_root"], ".."), "./log") # config loggers. Before it, the log will not work util.setup_logger( "base", opt["path"]["log"], "train_" + opt["name"], level=logging.INFO, screen=False, tofile=True, ) util.setup_logger( "val", opt["path"]["log"], "val_" + opt["name"], level=logging.INFO, screen=False, tofile=True, ) logger = logging.getLogger("base") logger.info(option.dict2str(opt)) # tensorboard logger if opt["use_tb_logger"] and "debug" not in opt["name"]: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( "You are using PyTorch {}. 
Tensorboard will use [tensorboardX]".format( version ) ) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir="log/{}/tb_logger/".format(opt["name"])) else: util.setup_logger( "base", opt["path"]["log"], "train", level=logging.INFO, screen=False ) logger = logging.getLogger("base") torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt["datasets"].items(): if phase == "train": train_set = create_dataset(dataset_opt) train_size = int(math.ceil(len(train_set) / dataset_opt["batch_size"])) total_iters = int(opt["train"]["niter"]) total_epochs = int(math.ceil(total_iters / train_size)) if opt["dist"]: train_sampler = DistIterSampler( train_set, world_size, rank, dataset_ratio ) total_epochs = int( math.ceil(total_iters / (train_size * dataset_ratio)) ) else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) if rank <= 0: logger.info( "Number of train images: {:,d}, iters: {:,d}".format( len(train_set), train_size ) ) logger.info( "Total epochs needed: {:d} for iters {:,d}".format( total_epochs, total_iters ) ) elif phase == "val": val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info( "Number of val images in [{:s}]: {:d}".format( dataset_opt["name"], len(val_set) ) ) else: raise NotImplementedError("Phase [{:s}] is not recognized.".format(phase)) assert train_loader is not None assert val_loader is not None #### create model model = create_model(opt) # load pretrained model of SFTMD #### resume training if resume_state: logger.info( "Resuming training from epoch: {}, iter: {}.".format( resume_state["epoch"], resume_state["iter"] ) ) start_epoch = resume_state["epoch"] current_step = resume_state["iter"] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 prepro = util.SRMDPreprocessing( scale=opt["scale"], pca_matrix=pca_matrix, cuda=True, **opt["degradation"] ) kernel_size = opt["degradation"]["ksize"] padding = kernel_size // 2 #### training logger.info( "Start training from epoch: {:d}, iter: {:d}".format(start_epoch, current_step) ) for epoch in range(start_epoch, total_epochs + 1): if opt["dist"]: train_sampler.set_epoch(epoch) for _, train_data in enumerate(train_loader): current_step += 1 if current_step > total_iters: break LR_img, ker_map, kernels = prepro(train_data["GT"], True) LR_img = (LR_img * 255).round() / 255 model.feed_data( LR_img, GT_img=train_data["GT"], ker_map=ker_map, kernel=kernels ) model.optimize_parameters(current_step) model.update_learning_rate( current_step, warmup_iter=opt["train"]["warmup_iter"] ) visuals = model.get_current_visuals() if current_step % opt["logger"]["print_freq"] == 0: logs = model.get_current_log() message = "<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> ".format( epoch, current_step, model.get_current_learning_rate() ) for k, v in logs.items(): message += "{:s}: {:.4e} ".format(k, v) # tensorboard logger if opt["use_tb_logger"] and "debug" not in opt["name"]: if rank <= 0: tb_logger.add_scalar(k, v, current_step) if rank == 0: logger.info(message) # validation, to produce ker_map_list(fake) if current_step % opt["train"]["val_freq"] == 0 and rank <= 0: avg_psnr = 0.0 idx = 0 for _, val_data in enumerate(val_loader): # LR_img, ker_map = prepro(val_data['GT']) LR_img = val_data["LQ"] lr_img = 
util.tensor2img(LR_img) # save LR image for reference # valid Predictor model.feed_data(LR_img, val_data["GT"]) model.test() visuals = model.get_current_visuals() # Save images for reference img_name = val_data["LQ_path"][0] img_dir = os.path.join(opt["path"]["val_images"], img_name) # img_dir = os.path.join(opt['path']['val_images'], str(current_step), '_', str(step)) util.mkdir(img_dir) save_lr_path = os.path.join(img_dir, "{:s}_LR.png".format(img_name)) util.save_img(lr_img, save_lr_path) sr_img = util.tensor2img(visuals["SR"].squeeze()) # uint8 gt_img = util.tensor2img(visuals["GT"].squeeze()) # uint8 save_img_path = os.path.join( img_dir, "{:s}_{:d}.png".format(img_name, current_step) ) kernel = ( visuals["ker"] .numpy() .reshape( opt["degradation"]["ksize"], opt["degradation"]["ksize"] ) ) kernel = 1 / (np.max(kernel) + 1e-4) * 255 * kernel cv2.imwrite(save_img_path, kernel) util.save_img(sr_img, save_img_path) # calculate PSNR crop_size = opt["scale"] gt_img = gt_img / 255.0 sr_img = sr_img / 255.0 cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size] cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size] avg_psnr += util.calculate_psnr( cropped_sr_img * 255, cropped_gt_img * 255 ) idx += 1 avg_psnr = avg_psnr / idx # log logger.info("# Validation # PSNR: {:.6f}".format(avg_psnr)) logger_val = logging.getLogger("val") # validation logger logger_val.info( "<epoch:{:3d}, iter:{:8,d}, psnr: {:.6f}".format( epoch, current_step, avg_psnr ) ) # tensorboard logger if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger.add_scalar("psnr", avg_psnr, current_step) #### save models and training states if current_step % opt["logger"]["save_checkpoint_freq"] == 0: if rank <= 0: logger.info("Saving models and training states.") model.save(current_step) model.save_training_state(epoch, current_step) if rank <= 0: logger.info("Saving the final model.") model.save("latest") logger.info("End of Predictor and Corrector training.") tb_logger.close()
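# ---------------------------------------------------------------------------
# NOTE: the validation loop above rescales the predicted blur kernel with
# 1 / (max + 1e-4) * 255 before cv2.imwrite, then immediately writes the SR
# image to the same save_img_path, overwriting the kernel image -- which
# looks unintended. The normalization step as a standalone helper:
# ---------------------------------------------------------------------------
import cv2
import numpy as np


def save_kernel_img(kernel, path):
    """Scale a blur kernel to 0-255 and write it as an 8-bit image."""
    vis = kernel / (np.max(kernel) + 1e-4) * 255.0
    cv2.imwrite(path, vis.astype(np.uint8))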
def main_worker(gpu, world_size, idx_server, opt):
    print('Use GPU: {} for training'.format(gpu))
    ngpus_per_node = world_size
    world_size = opt.world_size
    rank = idx_server * ngpus_per_node + gpu
    opt.gpu = gpu
    dist.init_process_group(backend='nccl', init_method=opt.dist_url,
                            world_size=world_size, rank=rank)
    torch.cuda.set_device(opt.gpu)

    # load the dataset
    dataloader = data.create_dataloader(opt, world_size, rank)

    # create trainer for our model
    trainer = Pix2PixTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader), world_size, rank)

    # create tool for visualization
    visualizer = Visualizer(opt, rank)

    for epoch in iter_counter.training_epochs():
        # set epoch for data sampler
        dataloader.sampler.set_epoch(epoch)
        iter_counter.record_epoch_start(epoch)
        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # Training
            # train generator
            trainer.run_generator_one_step(data_i)

            # train discriminator
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                losses, iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses,
                                               iter_counter.total_steps_so_far)
                visuals = OrderedDict([('input_label', data_i['label']),
                                       ('synthesized_image', trainer.get_latest_generated()),
                                       ('real_image', data_i['image'])])
                visualizer.display_current_results(visuals, epoch,
                                                   iter_counter.total_steps_so_far)

        if rank == 0:
            print('saving the latest model (epoch %d, total_steps %d)'
                  % (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if (epoch % opt.save_epoch_freq == 0
                or epoch == iter_counter.total_epochs) and (rank == 0):
            print('saving the model at the end of epoch %d, iters %d'
                  % (epoch, iter_counter.total_steps_so_far))
            trainer.save(epoch)

    print('Training was successfully finished.')
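# ---------------------------------------------------------------------------
# NOTE: main_worker takes the local GPU index as its first argument, which is
# exactly how torch.multiprocessing.spawn invokes its target function. A
# sketch of the launcher this worker appears to assume (idx_server and opt
# as above):
# ---------------------------------------------------------------------------
import torch
import torch.multiprocessing as mp


def launch(idx_server, opt):
    """Spawn one training process per visible GPU on this node."""
    ngpus_per_node = torch.cuda.device_count()
    # spawn calls main_worker(local_rank, *args) for local_rank in [0, ngpus)
    mp.spawn(main_worker, nprocs=ngpus_per_node,
             args=(ngpus_per_node, idx_server, opt))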
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which returns None for missing keys.

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None, opt['path']['log'], 'train', level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))

    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                visuals = model.get_current_visuals(isTrain=True)
                os.makedirs('valid/' + opt['name'], exist_ok=True)
                util.save_img(
                    util.tensor2img(visuals['LR']),
                    'valid/' + opt['name'] + '/' + str(current_step) + '_LR.png')
                util.save_img(
                    util.tensor2img(visuals['HR']),
                    'valid/' + opt['name'] + '/' + str(current_step) + '_HR.png')
                util.save_img(
                    util.tensor2img(visuals['SR']),
                    'valid/' + opt['name'] + '/' + str(current_step) + '_SR.png')
                util.save_img(
                    util.tensor2img(visuals['SRgray']),
                    'valid/' + opt['name'] + '/' + str(current_step) + '_SRgray.png')
                util.save_img(
                    util.tensor2img(visuals['HRgray']),
                    'valid/' + opt['name'] + '/' + str(current_step) + '_HRgray.png')

            # update learning rate
            model.update_learning_rate()

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
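# ---------------------------------------------------------------------------
# NOTE: util.tensor2img converts network outputs to uint8 images throughout
# these scripts. A minimal sketch of the conventional behaviour the call
# sites assume (clamp to [0, 1], scale to 0-255, CHW -> HWC):
# ---------------------------------------------------------------------------
import numpy as np
import torch


def tensor2img(tensor):
    """Convert a CHW float tensor in [0, 1] to an HWC uint8 image."""
    img = tensor.squeeze().float().clamp_(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))  # CHW -> HWC
    return (img * 255.0).round().astype(np.uint8)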
def main(): parser = argparse.ArgumentParser( description='Test Super Resolution Models') parser.add_argument('-opt', type=str, required=True, help='Path to options JSON file.') opt = option.parse(parser.parse_args().opt) opt = option.dict_to_nonedict(opt) # initial configure scale = opt['scale'] degrad = opt['degradation'] network_opt = opt['networks'] model_name = network_opt['which_model'].upper() if opt['self_ensemble']: model_name += 'plus' # create test dataloader bm_names = [] test_loaders = [] for _, dataset_opt in sorted(opt['datasets'].items()): test_set = create_dataset(dataset_opt) test_loader = create_dataloader(test_set, dataset_opt) test_loaders.append(test_loader) print('===> Test Dataset: [%s] Number of images: [%d]' % (test_set.name(), len(test_set))) bm_names.append(test_set.name()) # create solver (and load model) solver = create_solver(opt) # Test phase print('===> Start Test') print("==================================================") print("Method: %s || Scale: %d || Degradation: %s" % (model_name, scale, degrad)) for bm, test_loader in zip(bm_names, test_loaders): print("Test set : [%s]" % bm) sr_list = [] path_list = [] total_psnr = [] total_ssim = [] total_time = [] need_HR = False if test_loader.dataset.__class__.__name__.find( 'LRHR') < 0 else True for iter, batch in enumerate(test_loader): solver.feed_data(batch, need_HR=need_HR) print(batch["LR"].shape) # calculate forward time t0 = time.time() solver.test() t1 = time.time() total_time.append((t1 - t0)) visuals = solver.get_current_visual(need_HR=need_HR) sr_list.append(visuals['SR']) # calculate PSNR/SSIM metrics on Python if need_HR: psnr, ssim = util.calc_metrics(visuals['SR'], visuals['HR'], crop_border=scale) print( visuals['SR'].shape, visuals['HR'].shape, "save....", ) SRxxx = PIL.Image.fromarray(visuals['SR']) HRxxx = PIL.Image.fromarray(visuals['HR']) SRxxx.save("sr_%d.jpg" % (iter)) HRxxx.save("hr_%d.jpg" % (iter)) total_psnr.append(psnr) total_ssim.append(ssim) path_list.append( os.path.basename(batch['HR_path'][0]).replace( 'HR', model_name)) print( "[%d/%d] %s || PSNR(dB)/SSIM: %.2f/%.4f || Timer: %.4f sec ." % (iter + 1, len(test_loader), os.path.basename(batch['LR_path'][0]), psnr, ssim, (t1 - t0))) else: path_list.append(os.path.basename(batch['LR_path'][0])) print("[%d/%d] %s || Timer: %.4f sec ." % (iter + 1, len(test_loader), os.path.basename(batch['LR_path'][0]), (t1 - t0))) if need_HR: print("---- Average PSNR(dB) /SSIM /Speed(s) for [%s] ----" % bm) print("PSNR: %.2f SSIM: %.4f Speed: %.4f" % (sum(total_psnr) / len(total_psnr), sum(total_ssim) / len(total_ssim), sum(total_time) / len(total_time))) else: print("---- Average Speed(s) for [%s] is %.4f sec ----" % (bm, sum(total_time) / len(total_time))) # save SR results for further evaluation on MATLAB if need_HR: save_img_path = os.path.join('./results/SR/' + degrad, model_name, bm, "x%d" % scale) else: save_img_path = os.path.join('./results/SR/' + bm, model_name, "x%d" % scale) print("===> Saving SR images of [%s]... Save Path: [%s]\n" % (bm, save_img_path)) if not os.path.exists(save_img_path): os.makedirs(save_img_path) for img, name in zip(sr_list, path_list): imageio.imwrite(os.path.join(save_img_path, name), img) print("==================================================") print("===> Finished !")
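# ---------------------------------------------------------------------------
# NOTE: opt['self_ensemble'] appends 'plus' to the model name, which in the
# EDSR line of work denotes x8 self-ensemble: average predictions over the
# eight flip/rotation variants of the input. A hedged sketch of that scheme,
# assuming `forward` maps an NCHW tensor to an NCHW tensor:
# ---------------------------------------------------------------------------
import torch


def forward_x8(x, forward):
    """Average a model over the 8 flip/rotate variants of an NCHW input."""
    outputs = []
    for hflip in (False, True):
        for rot in range(4):
            t = torch.flip(x, dims=[3]) if hflip else x
            t = torch.rot90(t, rot, dims=[2, 3])
            y = forward(t)
            # undo the transform on the output, in reverse order
            y = torch.rot90(y, -rot, dims=[2, 3])
            if hflip:
                y = torch.flip(y, dims=[3])
            outputs.append(y)
    return torch.stack(outputs).mean(dim=0)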
for k in range(8): mask = torch.eq(argmax, k) color.select(0, 0).masked_fill_(mask, lookup_table[k][0]) # R color.select(0, 1).masked_fill_(mask, lookup_table[k][1]) # G color.select(0, 2).masked_fill_(mask, lookup_table[k][2]) # B # void mask = torch.eq(argmax, 255) color.select(0, 0).masked_fill_(mask, lookup_table[8][0]) # R color.select(0, 1).masked_fill_(mask, lookup_table[8][1]) # G color.select(0, 2).masked_fill_(mask, lookup_table[8][2]) # B return color util.mkdir('tmp') train_set = create_dataset(opt) train_loader = create_dataloader(train_set, opt) nrow = int(math.sqrt(opt['batch_size'])) if opt['phase'] == 'train': padding = 2 else: padding = 0 for i, data in enumerate(train_loader): # test dataloader time # if i == 1: # start_time = time.time() # if i == 500: # print(time.time() - start_time) # break if i > 5: break
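# The masked_fill_ block above paints each class id of an argmax map with an RGB triple
# from a lookup table. The same pattern, self-contained, with a made-up 3-class palette:
import torch

lookup_table = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]  # hypothetical palette
argmax = torch.randint(0, 3, (8, 8))                    # fake class-id map
color = torch.zeros(3, 8, 8, dtype=torch.uint8)
for k, (r, g, b) in enumerate(lookup_table):
    mask = torch.eq(argmax, k)
    color.select(0, 0).masked_fill_(mask, r)  # R
    color.select(0, 1).masked_fill_(mask, g)  # G
    color.select(0, 2).masked_fill_(mask, b)  # B
print(color.shape)  # torch.Size([3, 8, 8])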
# [src_ids, token_type_ids] predict_batchify_fn = lambda samples, fn=Tuple( Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_ids Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type_ids ): [data for data in fn(samples)] predict_trans_func = partial( convert_example, tokenizer=tokenizer, max_seq_length=args.max_seq_length, is_test=True) test_data_loader = create_dataloader( test_ds, mode='eval', batch_size=args.batch_size, batchify_fn=predict_batchify_fn, trans_fn=predict_trans_func) # Load parameters of best model on test_public.json of current task if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt): state_dict = paddle.load(args.init_from_ckpt) model.set_dict(state_dict) print("Loaded parameters from %s" % args.init_from_ckpt) else: raise ValueError( "Please set --init_from_ckpt to a valid pretrained model file") y_pred_labels = do_predict( model, tokenizer,
def main(): #### options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) label_path = opt['datasets']['val']['dataroot_label_file'] #### distributed training settings if args.launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') else: opt['dist'] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() #### loading resume state if exists if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load(opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None #### mkdir and loggers if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: util.mkdir_and_rename( opt['path']['experiments_root']) # rename experiment folder if exists util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]'.format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) else: util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) logger = logging.getLogger('base') # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) total_iters = int(opt['train']['niter']) total_epochs = int(math.ceil(total_iters / train_size)) if opt['dist']: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) if rank <= 0: logger.info('Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt,is_train = False) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None #### create model model = create_model(opt) #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 #### training logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) for epoch in range(start_epoch, total_epochs + 1): if opt['dist']: train_sampler.set_epoch(epoch) for _, train_data in enumerate(train_loader): current_step += 1 if current_step > total_iters: break #### update learning rate model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) #### training model.feed_data(train_data) model.optimize_parameters(current_step) #### log if current_step % opt['logger']['print_freq'] == 0: logs = model.get_current_log() message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step) for v in model.get_current_learning_rate(): message += '{:.3e},'.format(v) message += ')] ' for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: if rank <= 0: tb_logger.add_scalar(k, v, current_step) if rank <= 0: logger.info(message) #### validation if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0: if rank <= 0: # # does not support multi-GPU validation pbar = util.ProgressBar(len(val_loader)) idx = 0 for val_data in val_loader: idx += 1 img_name = os.path.splitext(os.path.basename(val_data['img1_path'][0]))[0] img_dir = 
os.path.join(opt['path']['val_images'], str(current_step)) util.mkdir(img_dir) f = open(os.path.join(img_dir, 'predict_score.txt'), 'a') model.feed_data(val_data) model.test() visuals = model.get_current_visuals() predict_score1 = visuals['predict_score1'].numpy() # Save predict scores f.write('%s %f\n' % (img_name + '.png', predict_score1)) f.close() pbar.update('Test {}'.format(img_name)) # calculate accuracy aligned_pair_accuracy, accuracy_esrganbig, accuracy_srganbig = rank_pair_test(\ os.path.join(img_dir, 'predict_score.txt'), label_path) # log logger.info( '# Validation # Accuracy: {:.4e}, Accuracy_pair1_class1: {:.4e}, Accuracy_pair1_class2: {:.4e} '.format( aligned_pair_accuracy, accuracy_esrganbig, accuracy_srganbig)) logger_val = logging.getLogger('val') # validation logger logger_val.info( '<epoch:{:3d}, iter:{:8,d}> Accuracy: {:.4e}, Accuracy_pair1_class1: {:.4e}, Accuracy_pair1_class2: {:.4e} '.format( epoch, current_step, aligned_pair_accuracy, accuracy_esrganbig, accuracy_srganbig)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('Accuracy', aligned_pair_accuracy, current_step) tb_logger.add_scalar('Accuracy_pair1_class1', accuracy_esrganbig, current_step) tb_logger.add_scalar('Accuracy_pair1_class2', accuracy_srganbig, current_step) #### save models and training states if current_step % opt['logger']['save_checkpoint_freq'] == 0: if rank <= 0: logger.info('Saving models and training states.') model.save(current_step) model.save_training_state(epoch, current_step) if rank <= 0: logger.info('Saving the final model.') model.save('latest') logger.info('End of training.') tb_logger.close()
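# rank_pair_test above is project code. Given the 'predict_score.txt' format written in the
# validation loop ('<img_name>.png <score>' per line), reading the scores back is simply
# the helper below; the pairing/accuracy logic itself lives in rank_pair_test and is omitted.
import os
import tempfile

def read_scores(path):
    scores = {}
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) == 2:
                scores[parts[0]] = float(parts[1])
    return scores

tmp = os.path.join(tempfile.mkdtemp(), 'predict_score.txt')
with open(tmp, 'w') as f:
    f.write('a.png 0.91\nb.png 0.13\n')
print(read_scores(tmp))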
def inference_single_audio(opt, path_label, model): # opt.path_label = path_label dataloader = data.create_dataloader(opt) processed_file_savepath = dataloader.dataset.get_processed_file_savepath() idx = 0 if opt.driving_pose: video_names = [ 'Input_', 'G_Pose_Driven_', 'Pose_Source_', 'Mouth_Source_' ] else: video_names = ['Input_', 'G_Fix_Pose_', 'Mouth_Source_'] is_mouth_frame = os.path.isdir(dataloader.dataset.mouth_frame_path) if not is_mouth_frame: video_names.pop() save_paths = [] for name in video_names: save_path = os.path.join(processed_file_savepath, name) util.mkdir(save_path) save_paths.append(save_path) for data_i in tqdm(dataloader): # print('==============', i, '===============') fake_image_original_pose_a, fake_image_driven_pose_a = model.forward( data_i, mode='inference') for num in range(len(fake_image_driven_pose_a)): util.save_torch_img( data_i['input'][num], os.path.join(save_paths[0], video_names[0] + str(idx) + '.jpg')) if opt.driving_pose: util.save_torch_img( fake_image_driven_pose_a[num], os.path.join(save_paths[1], video_names[1] + str(idx) + '.jpg')) util.save_torch_img( data_i['driving_pose_frames'][num], os.path.join(save_paths[2], video_names[2] + str(idx) + '.jpg')) else: util.save_torch_img( fake_image_original_pose_a[num], os.path.join(save_paths[1], video_names[1] + str(idx) + '.jpg')) if is_mouth_frame: util.save_torch_img( data_i['target'][num], os.path.join(save_paths[-1], video_names[-1] + str(idx) + '.jpg')) idx += 1 if opt.gen_video: for i, video_name in enumerate(video_names): img2video(processed_file_savepath, video_name, save_paths[i]) video_concat(processed_file_savepath, 'concat', video_names, dataloader.dataset.audio_path) print('results saved...' + processed_file_savepath) del dataloader return
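# img2video / video_concat above are project utilities; one plausible implementation of the
# frames-to-video step is a plain ffmpeg call over the '<Name_><idx>.jpg' files saved above.
# The frame rate and codec below are assumptions, not values taken from this repo.
import subprocess

def frames_to_video(frame_dir, prefix, out_path, fps=25):
    # expects frames named <prefix>0.jpg, <prefix>1.jpg, ... inside frame_dir
    subprocess.run([
        'ffmpeg', '-y',
        '-framerate', str(fps),
        '-i', '{}/{}%d.jpg'.format(frame_dir, prefix),
        '-c:v', 'libx264', '-pix_fmt', 'yuv420p',
        out_path,
    ], check=True)

# e.g. frames_to_video(save_paths[0], 'Input_', 'Input_.mp4')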
def main(): parser = argparse.ArgumentParser( description='Train Super Resolution Models') parser.add_argument('-opt', type=str, required=True, help='Path to options JSON file.', default='options/train/train_SRFBN.json') opt = option.parse(parser.parse_args().opt) writer = SummaryWriter(comment=f'wjh') # random seed seed = opt['solver']['manual_seed'] if seed is None: seed = random.randint(1, 10000) print("===> Random Seed: [%d]" % seed) random.seed(seed) torch.manual_seed(seed) # create train and val dataloader for phase, dataset_opt in sorted(opt['datasets'].items()): if phase == 'train': train_set = create_dataset(dataset_opt) train_loader = create_dataloader(train_set, dataset_opt) print('===> Train Dataset: %s Number of images: [%d]' % (train_set.name(), len(train_set))) if train_loader is None: raise ValueError("[Error] The training data does not exist") elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt) print('===> Val Dataset: %s Number of images: [%d]' % (val_set.name(), len(val_set))) else: raise NotImplementedError( "[Error] Dataset phase [%s] in *.json is not recognized." % phase) solver = create_solver(opt) scale = opt['scale'] model_name = opt['networks']['which_model'].upper() print('===> Start Train') print("==================================================") solver_log = solver.get_current_log() NUM_EPOCH = int(opt['solver']['num_epochs']) start_epoch = solver_log['epoch'] print("Method: %s || Scale: %d || Epoch Range: (%d ~ %d)" % (model_name, scale, start_epoch, NUM_EPOCH)) step = 0 for epoch in range(start_epoch, NUM_EPOCH + 1): print('\n===> Training Epoch: [%d/%d]... Learning Rate: %f' % (epoch, NUM_EPOCH, solver.get_current_learning_rate())) # Initialization solver_log['epoch'] = epoch # Train model train_loss_list = [] with tqdm(total=len(train_loader), desc='Epoch: [%d/%d]' % (epoch, NUM_EPOCH), miniters=1) as t: for iter, batch in enumerate(train_loader): step += 1 solver.feed_data(batch) iter_loss, loss_log = solver.train_step(step) batch_size = batch['LR'].size(0) train_loss_list.append(iter_loss * batch_size) t.set_postfix_str( '''Batch Loss: {}, pixel_loss: {}, feature_loss: {}, tv_loss: {}, style_loss: {}, fft_loss: {}, generator_vanilla_loss: {}, discriminator_loss: {}'''. 
format(iter_loss, loss_log["pixel_loss"], loss_log["feature_loss"], loss_log["tv_loss"], loss_log["style_loss"], loss_log['fft_loss'], loss_log['generator_vanilla_loss'], loss_log['discriminator_loss'])) writer.add_scalar('Batch Loss/train', iter_loss, step) writer.add_scalar('pixel_loss/train', loss_log["pixel_loss"], step) writer.add_scalar('feature_loss/train', loss_log["feature_loss"], step) writer.add_scalar('tv_loss/train', loss_log["tv_loss"], step) writer.add_scalar('style_loss/train', loss_log["style_loss"], step) writer.add_scalar('fft_loss/train', loss_log["fft_loss"], step) writer.add_scalar('generator_vanilla_loss/train', loss_log["generator_vanilla_loss"], step) writer.add_scalar('discriminator_loss/train', loss_log["discriminator_loss"], step) t.update() solver_log['records']['train_loss'].append( sum(train_loss_list) / len(train_set)) solver_log['records']['lr'].append(solver.get_current_learning_rate()) print('\nEpoch: [%d/%d] Avg Train Loss: %.6f' % (epoch, NUM_EPOCH, sum(train_loss_list) / len(train_set))) print('===> Validating...', ) psnr_list = [] ssim_list = [] val_loss_list = [] for iter, batch in enumerate(val_loader): solver.feed_data(batch) iter_loss = solver.test() val_loss_list.append(iter_loss) # calculate evaluation metrics visuals = solver.get_current_visual() psnr, ssim = util.calc_metrics(visuals['SR'], visuals['HR'], crop_border=scale) psnr_list.append(psnr) ssim_list.append(ssim) if opt["save_image"]: solver.save_current_visual(epoch, iter) solver_log['records']['val_loss'].append( sum(val_loss_list) / len(val_loss_list)) solver_log['records']['psnr'].append(sum(psnr_list) / len(psnr_list)) solver_log['records']['ssim'].append(sum(ssim_list) / len(ssim_list)) writer.add_scalar('val_loss/val', sum(val_loss_list) / len(val_loss_list), step) writer.add_scalar('psnr/val', sum(psnr_list) / len(psnr_list), step) writer.add_scalar('ssim/val', sum(ssim_list) / len(ssim_list), step) # record the best epoch epoch_is_best = False if solver_log['best_pred'] < (sum(psnr_list) / len(psnr_list)): solver_log['best_pred'] = (sum(psnr_list) / len(psnr_list)) epoch_is_best = True solver_log['best_epoch'] = epoch print( "[%s] PSNR: %.2f SSIM: %.4f Loss: %.6f Best PSNR: %.2f in Epoch: [%d]" % (val_set.name(), sum(psnr_list) / len(psnr_list), sum(ssim_list) / len(ssim_list), sum(val_loss_list) / len(val_loss_list), solver_log['best_pred'], solver_log['best_epoch'])) solver.set_current_log(solver_log) solver.save_checkpoint(epoch, epoch_is_best) solver.save_current_log() # update lr solver.update_learning_rate(epoch) writer.close() print('===> Finished !')
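# The epoch loss above is a sample-weighted mean: each batch loss is multiplied by its batch
# size before summing, then divided by len(train_set), so a short final batch is not
# over-weighted. A toy check with hypothetical numbers:
losses = [0.5, 0.4, 0.9]  # per-batch mean losses
sizes = [16, 16, 8]       # the last batch is smaller
epoch_loss = sum(l * n for l, n in zip(losses, sizes)) / sum(sizes)
print('%.4f' % epoch_loss)  # 0.5400 (weighted), vs 0.6000 for the plain mean of `losses`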
def do_train(): paddle.set_device(args.device) rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1: paddle.distributed.init_parallel_env() set_seed(args.seed) train_ds = load_dataset(read_text_pair, data_path=args.train_set_file, lazy=False) # If you want to use a BERT/RoBERTa pretrained model: # pretrained_model = ppnlp.transformers.BertModel.from_pretrained('bert-base-chinese') # pretrained_model = ppnlp.transformers.RobertaModel.from_pretrained('roberta-wwm-ext') pretrained_model = ppnlp.transformers.ErnieModel.from_pretrained( 'ernie-1.0') # If you want to use a BERT/RoBERTa tokenizer: # tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained('bert-base-chinese') # tokenizer = ppnlp.transformers.RobertaTokenizer.from_pretrained('roberta-wwm-ext') tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0') trans_func = partial(convert_example, tokenizer=tokenizer, max_seq_length=args.max_seq_length) batchify_fn = lambda samples, fn=Tuple( Pad(axis=0, pad_val=tokenizer.pad_token_id), # query_input Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # query_segment Pad(axis=0, pad_val=tokenizer.pad_token_id), # title_input Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # title_segment ): [data for data in fn(samples)] train_data_loader = create_dataloader(train_ds, mode='train', batch_size=args.batch_size, batchify_fn=batchify_fn, trans_fn=trans_func) model = SemanticIndexBatchNeg(pretrained_model, margin=args.margin, scale=args.scale, output_emb_size=args.output_emb_size) if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt): state_dict = paddle.load(args.init_from_ckpt) model.set_dict(state_dict) print("Warm start from: {}".format(args.init_from_ckpt)) model = paddle.DataParallel(model) num_training_steps = len(train_data_loader) * args.epochs lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, args.warmup_proportion) # Generate parameter names needed to perform weight decay. # All bias and LayerNorm parameters are excluded.
decay_params = [ p.name for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "norm"]) ] optimizer = paddle.optimizer.AdamW( learning_rate=lr_scheduler, parameters=model.parameters(), weight_decay=args.weight_decay, apply_decay_param_fun=lambda x: x in decay_params) if args.use_amp: scaler = paddle.amp.GradScaler(init_loss_scaling=args.amp_loss_scale) global_step = 0 tic_train = time.time() for epoch in range(1, args.epochs + 1): for step, batch in enumerate(train_data_loader, start=1): query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids = batch with paddle.amp.auto_cast( args.use_amp, custom_white_list=["layer_norm", "softmax", "gelu"]): loss = model(query_input_ids=query_input_ids, title_input_ids=title_input_ids, query_token_type_ids=query_token_type_ids, title_token_type_ids=title_token_type_ids) if args.use_amp: scaled = scaler.scale(loss) scaled.backward() scaler.minimize(optimizer, scaled) else: loss.backward() optimizer.step() global_step += 1 if global_step % 10 == 0 and rank == 0: print( "global step %d, epoch: %d, batch: %d, loss: %.5f, speed: %.2f step/s" % (global_step, epoch, step, loss, 10 / (time.time() - tic_train))) tic_train = time.time() lr_scheduler.step() optimizer.clear_grad() if global_step % args.save_steps == 0 and rank == 0: save_dir = os.path.join(args.save_dir, "model_%d" % global_step) if not os.path.exists(save_dir): os.makedirs(save_dir) save_param_path = os.path.join(save_dir, 'model_state.pdparams') paddle.save(model.state_dict(), save_param_path) tokenizer.save_pretrained(save_dir)
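# The Tuple(Pad(...), Pad(...), ...) batchify_fn above pads every field of a mini-batch to
# the longest sequence in that batch. A framework-free sketch of the same idea, with a
# hypothetical pad id of 0:
def pad_batch(seqs, pad_val=0):
    max_len = max(len(s) for s in seqs)
    return [s + [pad_val] * (max_len - len(s)) for s in seqs]

batch = [[101, 7, 8, 102], [101, 5, 102]]
print(pad_batch(batch))  # [[101, 7, 8, 102], [101, 5, 102, 0]]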
def main(): ###### SFTMD train ###### #### setup options parser = argparse.ArgumentParser() parser.add_argument('-opt_F', type=str, help='Path to option YAML file of SFTMD_Net.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt_F = option.parse(args.opt_F, is_train=True) # convert to NoneDict, which returns None for missing keys opt_F = option.dict_to_nonedict(opt_F) #### random seed seed = opt_F['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) util.set_random_seed(seed) # create a PCA matrix from a sufficiently large batch of random kernels batch_ker = util.random_batch_kernel(batch=30000, l=21, sig_min=0.2, sig_max=4.0, rate_iso=1.0, scaling=3, tensor=False) print('batch kernel shape: {}'.format(batch_ker.shape)) b = np.size(batch_ker, 0) batch_ker = batch_ker.reshape((b, -1)) pca_matrix = util.PCA(batch_ker, k=10).float() print('PCA matrix shape: {}'.format(pca_matrix.shape)) #### distributed training settings if args.launcher == 'none': # disabled distributed training opt_F['dist'] = False rank = -1 print('Disabled distributed training.') else: opt_F['dist'] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### loading resume state if exists if opt_F['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load(opt_F['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt_F, resume_state['iter']) # check resume options else: resume_state = None #### mkdir and loggers if rank <= 0: if resume_state is None: util.mkdir_and_rename( opt_F['path']['experiments_root']) # rename experiment folder if exists util.mkdirs((path for key, path in opt_F['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt_F['path']['log'], 'train_' + opt_F['name'], level=logging.INFO, screen=True, tofile=True) util.setup_logger('val', opt_F['path']['log'], 'val_' + opt_F['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt_F)) # tensorboard logger if opt_F['use_tb_logger'] and 'debug' not in opt_F['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]'.format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt_F['name']) else: util.setup_logger('base', opt_F['path']['log'], 'train', level=logging.INFO, screen=True) logger = logging.getLogger('base') #### create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt_F['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) total_iters = int(opt_F['train']['niter']) total_epochs = int(math.ceil(total_iters / train_size)) if opt_F['dist']: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt_F, train_sampler) if rank <= 0: logger.info('Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt_F, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None assert val_loader is not None #### create model model_F = create_model(opt_F) #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model_F.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 #### training logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) for epoch in range(start_epoch, total_epochs + 1): if opt_F['dist']: train_sampler.set_epoch(epoch) for _, train_data in enumerate(train_loader): current_step += 1 if current_step > total_iters: break #### preprocessing for LR_img and kernel map prepro = util.SRMDPreprocessing(opt_F['scale'], pca_matrix, para_input=10, kernel=21, noise=False, cuda=True, sig_min=0.2, sig_max=4.0, rate_iso=1.0, scaling=3, rate_cln=0.2, noise_high=0.0) LR_img, ker_map = prepro(train_data['GT']) #### update learning rate, schedulers model_F.update_learning_rate(current_step, warmup_iter=opt_F['train']['warmup_iter']) #### training model_F.feed_data(train_data, LR_img, ker_map) model_F.optimize_parameters(current_step) #### log if current_step % opt_F['logger']['print_freq'] == 0: logs = model_F.get_current_log() message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format( epoch, current_step, model_F.get_current_learning_rate()) for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt_F['use_tb_logger'] and 'debug' not in opt_F['name']: if rank <= 0: tb_logger.add_scalar(k, v, current_step) if rank <= 0: logger.info(message) # validation if current_step % opt_F['train']['val_freq'] == 0 and rank <= 0: avg_psnr = 0.0 idx = 0 for _, val_data in enumerate(val_loader): idx += 1 #### preprocessing for LR_img and kernel map prepro = util.SRMDPreprocessing(opt_F['scale'], pca_matrix, para_input=15, noise=False, cuda=True, sig_min=0.2, sig_max=4.0, rate_iso=1.0, scaling=3, rate_cln=0.2, noise_high=0.0) LR_img, 
ker_map = prepro(val_data['GT']) model_F.feed_data(val_data, LR_img, ker_map) model_F.test() visuals = model_F.get_current_visuals() sr_img = util.tensor2img(visuals['SR']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # Save SR images for reference img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] #img_dir = os.path.join(opt_F['path']['val_images'], img_name) img_dir = os.path.join(opt_F['path']['val_images'], str(current_step)) util.mkdir(img_dir) save_img_path = os.path.join(img_dir,'{:s}_{:d}.png'.format(img_name, current_step)) util.save_img(sr_img, save_img_path) # calculate PSNR crop_size = opt_F['scale'] gt_img = gt_img / 255. sr_img = sr_img / 255. cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255) avg_psnr = avg_psnr / idx # log logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) logger_val = logging.getLogger('val') # validation logger logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(epoch, current_step, avg_psnr)) # tensorboard logger if opt_F['use_tb_logger'] and 'debug' not in opt_F['name']: tb_logger.add_scalar('psnr', avg_psnr, current_step) #### save models and training states if current_step % opt_F['logger']['save_checkpoint_freq'] == 0: if rank <= 0: logger.info('Saving models and training states.') model_F.save(current_step) model_F.save_training_state(epoch, current_step) if rank <= 0: logger.info('Saving the final model.') model_F.save('latest') logger.info('End of SFTMD training.')
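# util.PCA earlier in this script reduces 30000 random 21x21 blur kernels (flattened to
# 441-dim rows) to a k=10 code. A numpy sketch of that reduction via SVD on centered data;
# the repo's exact conventions (centering, sign, dtype) may differ:
import numpy as np

def pca_projection(X, k):
    X = X - X.mean(axis=0, keepdims=True)      # center the flattened kernels
    _, _, Vt = np.linalg.svd(X, full_matrices=False)
    return Vt[:k].T                            # (441, k) projection matrix

batch_ker = np.random.rand(3000, 21 * 21).astype(np.float32)  # smaller batch for speed
P = pca_projection(batch_ker, k=10)
print(P.shape)              # (441, 10)
codes = batch_ker @ P       # 10-dim kernel codes, analogous to the ker_map inputs above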
def main(): dataset = 'REDS' # REDS | Vimeo90K | DIV2K800_sub opt = {} opt['dist'] = False opt['gpu_ids'] = [0] if dataset == 'REDS': opt['name'] = 'test_REDS' opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb' opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb' opt['mode'] = 'REDS' opt['N_frames'] = 5 opt['phase'] = 'train' opt['use_shuffle'] = True opt['n_workers'] = 8 opt['batch_size'] = 16 opt['GT_size'] = 256 opt['LQ_size'] = 64 opt['scale'] = 4 opt['use_flip'] = True opt['use_rot'] = True opt['interval_list'] = [1] opt['random_reverse'] = False opt['border_mode'] = False opt['cache_keys'] = None opt['data_type'] = 'lmdb' # img | lmdb | mc elif dataset == 'Vimeo90K': opt['name'] = 'test_Vimeo90K' opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb' opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb' opt['mode'] = 'Vimeo90K' opt['N_frames'] = 7 opt['phase'] = 'train' opt['use_shuffle'] = True opt['n_workers'] = 8 opt['batch_size'] = 16 opt['GT_size'] = 256 opt['LQ_size'] = 64 opt['scale'] = 4 opt['use_flip'] = True opt['use_rot'] = True opt['interval_list'] = [1] opt['random_reverse'] = False opt['border_mode'] = False opt['cache_keys'] = None opt['data_type'] = 'lmdb' # img | lmdb | mc elif dataset == 'DIV2K800_sub': opt['name'] = 'DIV2K800' opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb' opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb' opt['mode'] = 'LQGT' opt['phase'] = 'train' opt['use_shuffle'] = True opt['n_workers'] = 8 opt['batch_size'] = 16 opt['GT_size'] = 128 opt['scale'] = 4 opt['use_flip'] = True opt['use_rot'] = True opt['color'] = 'RGB' opt['data_type'] = 'lmdb' # img | lmdb else: raise ValueError('Please implement by yourself.') util.mkdir('tmp') train_set = create_dataset(opt) train_loader = create_dataloader(train_set, opt, opt, None) nrow = int(math.sqrt(opt['batch_size'])) padding = 2 if opt['phase'] == 'train' else 0 print('start...') for i, data in enumerate(train_loader): if i > 5: break print(i) if dataset == 'REDS' or dataset == 'Vimeo90K': LQs = data['LQs'] else: LQ = data['LQ'] GT = data['GT'] if dataset == 'REDS' or dataset == 'Vimeo90K': for j in range(LQs.size(1)): torchvision.utils.save_image(LQs[:, j, :, :, :], 'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow, padding=padding, normalize=False) else: torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow, padding=padding, normalize=False) torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding, normalize=False)
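# A self-contained version of the torchvision grid dump at the end of the dataset test
# above, with a random tensor standing in for a real LQ/GT batch:
import math
import torch
import torchvision

batch = torch.rand(16, 3, 64, 64)  # fake image batch in [0, 1]
nrow = int(math.sqrt(batch.size(0)))
torchvision.utils.save_image(batch, 'tmp_grid.png', nrow=nrow, padding=2, normalize=False)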
def main(): #### options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) #### distributed training settings if args.launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') else: opt['dist'] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() #### loading resume state if exists if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load(opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None #### mkdir and loggers if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: util.mkdir_and_rename( opt['path']['experiments_root']) # rename experiment folder if exists util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]'.format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) else: util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) logger = logging.getLogger('base') # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) # -------------------------------------------- ADDED -------------------------------------------- filter_low = filters.FilterLow(gaussian=False) l1_loss = torch.nn.L1Loss() mse_loss = torch.nn.MSELoss() if torch.cuda.is_available(): filter_low = filter_low.cuda() l1_loss = l1_loss.cuda() mse_loss = mse_loss.cuda() # ----------------------------------------------------------------------------------------------- #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) total_iters = int(opt['train']['niter']) total_epochs = int(math.ceil(total_iters / train_size)) if opt['dist']: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) if rank <= 0: logger.info('Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None #### create model model = create_model(opt) #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 #### training logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) for epoch in range(start_epoch, total_epochs + 1): if opt['dist']: train_sampler.set_epoch(epoch) for _, train_data in enumerate(train_loader): current_step += 1 if current_step > total_iters: break #### update learning rate model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) #### training model.feed_data(train_data) model.optimize_parameters(current_step) #### log if current_step % opt['logger']['print_freq'] == 0: logs = model.get_current_log() message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format( epoch, current_step, model.get_current_learning_rate()) for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: if rank <= 0: tb_logger.add_scalar(k, v, current_step) if rank <= 0: 
logger.info(message) # validation if current_step % opt['train']['val_freq'] == 0 and rank <= 0: avg_psnr = val_pix_err_f = val_pix_err_nf = val_mean_color_err = 0.0 idx = 0 for val_data in val_loader: idx += 1 img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) model.feed_data(val_data) model.test() visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals['SR']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # Save SR images for reference save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) util.save_img(sr_img, save_img_path) # calculate PSNR crop_size = opt['scale'] gt_img = gt_img / 255. sr_img = sr_img / 255. cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255) # ----------------------------------------- ADDED ----------------------------------------- val_pix_err_f += l1_loss(filter_low(visuals['SR']), filter_low(visuals['GT'])) val_pix_err_nf += l1_loss(visuals['SR'], visuals['GT']) val_mean_color_err += mse_loss(visuals['SR'].mean(2).mean(1), visuals['GT'].mean(2).mean(1)) # ----------------------------------------------------------------------------------------- avg_psnr = avg_psnr / idx val_pix_err_f /= idx val_pix_err_nf /= idx val_mean_color_err /= idx # log logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) logger_val = logging.getLogger('val') # validation logger logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format( epoch, current_step, avg_psnr)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr', avg_psnr, current_step) tb_logger.add_scalar('val_pix_err_f', val_pix_err_f, current_step) tb_logger.add_scalar('val_pix_err_nf', val_pix_err_nf, current_step) tb_logger.add_scalar('val_mean_color_err', val_mean_color_err, current_step) #### save models and training states if current_step % opt['logger']['save_checkpoint_freq'] == 0: if rank <= 0: logger.info('Saving models and training states.') model.save(current_step) model.save_training_state(epoch, current_step) if rank <= 0: logger.info('Saving the final model.') model.save('latest') logger.info('End of training.')
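# The val_mean_color_err term above compares the per-channel mean colors of SR and GT with
# an MSE; reproduced here on random CHW tensors:
import torch

mse_loss = torch.nn.MSELoss()
sr, gt = torch.rand(3, 64, 64), torch.rand(3, 64, 64)
err = mse_loss(sr.mean(2).mean(1), gt.mean(2).mean(1))  # compare two 3-dim mean colors
print(err.item())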
"base", opt_P["path"]["log"], "test_" + opt_P["name"], level=logging.INFO, screen=True, tofile=True, ) logger = logging.getLogger("base") logger.info(option.dict2str(opt_P)) logger.info(option.dict2str(opt_C)) #### Create test dataset and dataloader test_loaders = [] for phase, dataset_opt in sorted(opt_P["datasets"].items()): test_set = create_dataset(dataset_opt) test_loader = create_dataloader(test_set, dataset_opt) logger.info("Number of test images in [{:s}]: {:d}".format( dataset_opt["name"], len(test_set))) test_loaders.append(test_loader) # load pretrained model by default model_F = create_model(opt_F) model_P = create_model(opt_P) model_C = create_model(opt_C) for test_loader in test_loaders: test_set_name = test_loader.dataset.opt["name"] # path opt[''] logger.info("\nTesting [{:s}]...".format(test_set_name)) test_start_time = time.time() dataset_dir = os.path.join(opt_P["path"]["results_root"], test_set_name) util.mkdir(dataset_dir)
import time import torch.nn from options.train_options import TrainOptions from data import create_dataloader from models import create_model from utils.util import SaveResults from utils import dataset_util, util import numpy as np import cv2 torch.manual_seed(0) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False np.random.seed(0) if __name__ == '__main__': opt = TrainOptions().parse() train_data_loader = create_dataloader(opt) train_dataset_size = len(train_data_loader) print('#training images = %d' % train_dataset_size) model = create_model(opt) model.setup(opt) save_results = SaveResults(opt) total_steps = 0 lr = opt.lr_task for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): epoch_start_time = time.time() iter_data_time = time.time() epoch_iter = 0
import time from tqdm import tqdm from options.train_options import TrainOptions import data as Dataset from model import create_model from util.visualizer import Visualizer if __name__ == '__main__': # get training options opt = TrainOptions().parse() opt.serial_batches = True opt.nThreads = 0 # create a dataset dataset = Dataset.create_dataloader(opt) dataset_size = len(dataset) * opt.batchSize print('training images = %d' % dataset_size) for epoch in range(10): print("epoch", epoch) for i, data in tqdm(enumerate(dataset), total=len(dataset)): pass
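# The loop above is a throughput smoke test: iterate the loader and let tqdm report it/s.
# An explicit variant that times any iterable with time.perf_counter:
import time

def measure(iterable):
    t0 = time.perf_counter()
    n = sum(1 for _ in iterable)
    dt = time.perf_counter() - t0
    print('%d batches in %.3fs (%.1f it/s)' % (n, dt, n / dt if dt > 0 else 0.0))

measure(range(1000))  # stand-in for `dataset`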
import sys from collections import OrderedDict from options.train_options import TrainOptions import data from util.iter_counter import IterationCounter from util.visualizer import Visualizer from trainers.pix2pix_trainer import Pix2PixTrainer # parse options opt = TrainOptions().parse() # print options to help debugging print(' '.join(sys.argv)) # load the dataset dataloader = data.create_dataloader(opt) if opt.unpairTrain: dataloader2 = data.create_dataloader(opt, 2) # create trainer for our model trainer = Pix2PixTrainer(opt) # create tool for counting iterations iter_counter = IterationCounter(opt, len(dataloader)) data_size = len(dataloader) # create tool for visualization visualizer = Visualizer(opt) for epoch in iter_counter.training_epochs(): # for unpair training
state_dict = paddle.load(args.params_path) model.set_dict(state_dict) logger.info("Loaded parameters from %s" % args.params_path) else: raise ValueError( "Please set --params_path to a valid pretrained model file") id2corpus = gen_id2corpus(args.corpus_file) # convert_example function's input must be a dict corpus_list = [{idx: text} for idx, text in id2corpus.items()] corpus_ds = MapDataset(corpus_list) corpus_data_loader = create_dataloader(corpus_ds, mode='predict', batch_size=args.batch_size, batchify_fn=batchify_fn, trans_fn=trans_func) final_index = build_index(args, corpus_data_loader, model) text_list, text2similar_text = gen_text_file(args.similar_text_pair_file) query_ds = MapDataset(text_list) query_data_loader = create_dataloader(query_ds, mode='predict', batch_size=args.batch_size, batchify_fn=batchify_fn, trans_fn=trans_func)
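# build_index above is project code (typically an ANN index over corpus embeddings). A
# brute-force numpy sketch of the same retrieval idea, assuming L2-normalized embeddings so
# that a dot product equals cosine similarity; shapes and sizes here are made up:
import numpy as np

def l2_normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)

corpus_emb = l2_normalize(np.random.rand(1000, 256).astype(np.float32))
query_emb = l2_normalize(np.random.rand(4, 256).astype(np.float32))
scores = query_emb @ corpus_emb.T              # cosine similarity per (query, doc)
top_ids = np.argsort(-scores, axis=1)[:, :5]   # 5 nearest corpus ids per query
print(top_ids.shape)                           # (4, 5)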
def main(): #### options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) #### distributed training settings if args.launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') else: opt['dist'] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() #### loading resume state if exists if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None #### mkdir and loggers if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: util.mkdir_and_rename( opt['path'] ['experiments_root']) # rename experiment folder if exists util.mkdirs( (path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: version = float(torch.__version__[0:3]) if version >= 1.1: # PyTorch 1.1 from torch.utils.tensorboard import SummaryWriter else: logger.info( 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]' .format(version)) from tensorboardX import SummaryWriter tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) else: util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) logger = logging.getLogger('base') # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) #### random seed seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) if rank <= 0: logger.info('Random seed: {}'.format(seed)) util.set_random_seed(seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### create train and val dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt) train_size = int( math.ceil(len(train_set) / dataset_opt['batch_size'])) total_iters = int(opt['train']['niter']) total_epochs = int(math.ceil(total_iters / train_size)) if opt['dist']: train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) total_epochs = int( math.ceil(total_iters / (train_size * dataset_ratio))) else: train_sampler = None train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) if rank <= 0: logger.info( 'Number of train images: {:,d}, iters: {:,d}'.format( len(train_set), train_size)) logger.info('Total epochs needed: {:d} for iters {:,d}'.format( total_epochs, total_iters)) elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) else: raise NotImplementedError( 'Phase [{:s}] is not recognized.'.format(phase)) assert train_loader is not None #### create model model = create_model(opt) ssim_fn = SSIM().to(model.device) ssim_fn.eval() #### resume training if resume_state: logger.info('Resuming training from epoch: {}, iter: {}.'.format( resume_state['epoch'], resume_state['iter'])) start_epoch = resume_state['epoch'] current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 #### training logger.info('Start training from epoch: {:d}, iter: {:d}'.format( start_epoch, current_step)) for epoch in range(start_epoch, total_epochs + 1): if opt['dist']: train_sampler.set_epoch(epoch) for _, train_data in enumerate(train_loader): if opt['mode'] == 'train': current_step += 1 if current_step > total_iters: break #### training model.feed_data(train_data) model.optimize_parameters(current_step) #### update learning rate model.update_learning_rate( current_step, warmup_iter=opt['train']['warmup_iter']) #### log if current_step % opt['logger']['print_freq'] == 0: logs = model.get_current_log() message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format( epoch, current_step) for v in model.get_current_learning_rate(): message += '{:.3e},'.format(v) message += ')] ' for k, v in logs.items(): message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: if rank <= 0: tb_logger.add_scalar(k, v, current_step) if rank <= 0: logger.info(message) #### save models and training states if current_step % opt['logger']['save_checkpoint_freq'] == 0: if rank <= 0: logger.info( 'Saving models and training states {}.'.format( current_step)) model.save(current_step) model.save_training_state(epoch, current_step) else: 
opt['train']['val_freq'] = 1 #### validation if opt['datasets'].get( 'val', None) and current_step % opt['train']['val_freq'] == 0: if opt['model'] in [ 'sr', 'srgan' ] and rank <= 0: # image restoration validation # does not support multi-GPU validation pbar = util.ProgressBar(len(val_loader)) avg_psnr = 0. idx = 0 for val_data in val_loader: idx += 1 img_name = os.path.splitext( os.path.basename(val_data['LQ_path'][0]))[0] img_dir = os.path.join(opt['path']['val_images'], img_name) util.mkdir(img_dir) model.feed_data(val_data) model.test() visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # Save SR images for reference save_img_path = os.path.join( img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) util.save_img(sr_img, save_img_path) # calculate PSNR sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) avg_psnr += util.calculate_psnr(sr_img, gt_img) pbar.update('Test {}'.format(img_name)) avg_psnr = avg_psnr / idx # log logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr', avg_psnr, current_step) else: # video restoration validation if opt['dist']: # multi-GPU testing psnr_rlt = {} # with border and center frames if rank == 0: pbar = util.ProgressBar(len(val_set)) for idx in range(rank, len(val_set), world_size): val_data = val_set[idx] val_data['LQs'].unsqueeze_(0) val_data['GT'].unsqueeze_(0) folder = val_data['folder'] idx_d, max_idx = val_data['idx'].split('/') idx_d, max_idx = int(idx_d), int(max_idx) if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = torch.zeros( max_idx, dtype=torch.float32, device='cuda') # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda') model.feed_data(val_data) model.test() visuals = model.get_current_visuals() rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # calculate PSNR psnr_rlt[folder][idx_d] = util.calculate_psnr( rlt_img, gt_img) if opt['datasets']['val'][ 'save_imgs'] and rank <= 0: save_folder = os.path.join( opt['path']['val_images'], 'val_{}_{}'.format(opt['name'], current_step), folder) util.mkdirs(save_folder) cv2.imwrite( os.path.join(save_folder, '{}.png'.format(idx)), rlt_img) if rank == 0: for _ in range(world_size): pbar.update('Test {} - {}/{}'.format( folder, idx_d, max_idx)) # # collect data for _, v in psnr_rlt.items(): dist.reduce(v, 0) dist.barrier() if rank == 0: psnr_rlt_avg = {} psnr_total_avg = 0. for k, v in psnr_rlt.items(): psnr_rlt_avg[k] = torch.mean(v).cpu().item() psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = '# Validation # PSNR: {:.4f}:'.format( psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4f}'.format(k, v) logger.info(log_s) if opt['use_tb_logger'] and 'debug' not in opt[ 'name']: tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k, v, current_step) else: pbar = util.ProgressBar(len(val_loader)) psnr_rlt = {} # with border and center frames ssim_rlt = {} psnr_rlt_avg = {} ssim_rlt_avg = {} psnr_total_avg = 0. ssim_total_avg = 0. 
for val_data in val_loader: folder = val_data['folder'][0] idx_d = val_data['idx'] idx = idx_d[0].split('/')[0] # border = val_data['border'].item() if psnr_rlt.get(folder, None) is None: psnr_rlt[folder] = [] if ssim_rlt.get(folder, None) is None: ssim_rlt[folder] = [] model.feed_data(val_data) model.test() visuals = model.get_current_visuals() rlt_img = util.tensor2img(visuals['rlt']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # save images if opt['datasets']['val']['save_imgs']: save_folder = os.path.join( opt['path']['val_images'], 'val_{}_{}'.format(opt['name'], current_step), folder) util.mkdirs(save_folder) cv2.imwrite( os.path.join(save_folder, '{}.png'.format(idx)), rlt_img) # calculate PSNR psnr = util.calculate_psnr(rlt_img, gt_img) psnr_rlt[folder].append(psnr) # calculate ssim with torch.no_grad(): ssim = ssim_fn(model.fake_H, model.real_H).data.cpu().item() ssim_rlt[folder].append(ssim) pbar.update('Test {} - {}'.format(folder, idx_d)) for k, v in ssim_rlt.items(): ssim_rlt_avg[k] = sum(v) / len(v) ssim_total_avg += ssim_rlt_avg[k] ssim_total_avg /= len(ssim_rlt) log_s = '# Validation # SSIM: {:.4f}:'.format( ssim_total_avg) for k, v in ssim_rlt_avg.items(): log_s += ' {}: {:.4f}'.format(k, v) logger.info(log_s) for k, v in psnr_rlt.items(): psnr_rlt_avg[k] = sum(v) / len(v) psnr_total_avg += psnr_rlt_avg[k] psnr_total_avg /= len(psnr_rlt) log_s = '# Validation # PSNR: {:.4f}:'.format( psnr_total_avg) for k, v in psnr_rlt_avg.items(): log_s += ' {}: {:.4f}'.format(k, v) logger.info(log_s) if opt['use_tb_logger'] and 'debug' not in opt['name']: tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step) tb_logger.add_scalar('ssim_avg', ssim_total_avg, current_step) for k, v in psnr_rlt_avg.items(): tb_logger.add_scalar(k + '_psnr', v, current_step) for k, v in ssim_rlt_avg.items(): tb_logger.add_scalar(k + '_ssim', v, current_step) if rank <= 0: logger.info('Saving the final model.') model.save('latest') logger.info('End of training.') tb_logger.close()
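# The psnr_rlt / ssim_rlt bookkeeping above is a dict of per-folder metric lists that is
# averaged per folder and then across folders; the same two-stage reduction in isolation:
psnr_rlt = {'clip_a': [30.1, 29.8], 'clip_b': [27.5, 28.0, 27.9]}  # made-up values
psnr_rlt_avg = {k: sum(v) / len(v) for k, v in psnr_rlt.items()}
psnr_total_avg = sum(psnr_rlt_avg.values()) / len(psnr_rlt_avg)
print(psnr_rlt_avg, '-> total: %.4f' % psnr_total_avg)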
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). """ import sys import os from collections import OrderedDict import data from options.test_options import TestOptions from models.pix2pix_model import Pix2PixModel from util.visualizer import Visualizer from util import html opt = TestOptions().parse() dataloader = data.create_dataloader(opt) model = Pix2PixModel(opt) model.eval() visualizer = Visualizer(opt) # create a webpage that summarizes the all results web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) webpage = html.HTML( web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) # test for i, data_i in enumerate(dataloader):
def main(): #### options parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, default='options/test/test_KPSAGAN.yml', help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=False) #### distributed training settings if args.launcher == 'none': # disabled distributed training opt['dist'] = False rank = -1 print('Disabled distributed training.') else: opt['dist'] = True init_dist() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() #### loading resume state if exists if opt['path'].get('resume_state', None): # distributed resuming: all load into default GPU device_id = torch.cuda.current_device() resume_state = torch.load( opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) option.check_resume(opt, resume_state['iter']) # check resume options else: resume_state = None #### mkdir and loggers if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) if resume_state is None: # util.mkdir_and_rename( #opt['path']['experiments_root']) # rename experiment folder if exists util.mkdirs( (path for key, path in opt['path'].items() if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) # config loggers. Before it, the log will not work util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, screen=True, tofile=True) util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, screen=True, tofile=True) logger = logging.getLogger('base') logger.info(option.dict2str(opt)) else: util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) logger = logging.getLogger('base') # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True #### create test dataloader dataset_ratio = 200 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt, None) if rank <= 0: logger.info('Number of val images in [{:s}]: {:d}'.format( dataset_opt['name'], len(val_set))) #### create model model = create_model(opt) avg_psnr = 0.0 idx = 0 dataset_dir = '/srv/wuyichao/Super-Resolution/KPSAGAN/BasicSR-master/BasicSR-master-c/result_600000/' util.mkdir(dataset_dir) for val_data in val_loader: idx += 1 img_name = os.path.splitext(os.path.basename( val_data['LQ_path'][0]))[0] logger.info(img_name) #img_dir = os.path.join(opt['path']['val_images'], img_name) #util.mkdir(img_dir) model.feed_data(val_data) model.test() visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals['SR']) # uint8 gt_img = util.tensor2img(visuals['GT']) # uint8 # save images suffix = 'cut' #opt['suffix'] if suffix: save_img_path = osp.join(dataset_dir, img_name + suffix + '.png') else: save_img_path = osp.join(dataset_dir, img_name + '.png') util.save_img(sr_img, save_img_path) # calculate PSNR crop_size = opt['scale'] gt_img = gt_img / 255. sr_img = sr_img / 255. 
cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255) avg_psnr = avg_psnr / idx # log logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) logger_val = logging.getLogger('val') # validation logger logger_val.info('psnr: {:.4e}'.format(avg_psnr))