def get_resume_state(opt):
    """Load the training state to resume from, if one is configured.

    ``opt['path']['resume_state']`` may point to a single state file or to a
    directory. For a directory, the last ``*.state`` file in
    ``util.sorted_nicely`` order is selected.

    Args:
        opt: Options dictionary. Reads ``opt['path']['resume_state']`` and
            ``opt['gpu_ids']``.

    Returns:
        The object returned by ``torch.load`` for the selected file (indexed
        here with ``'epoch'`` and ``'iter'``), or ``None`` when training
        starts from scratch.

    Raises:
        FileNotFoundError: If the configured directory holds no ``*.state``
            file.
    """
    logger = util.get_root_logger()

    resume_root = opt['path']['resume_state']
    if not resume_root:
        # training from scratch
        return None

    if os.path.isdir(resume_root):
        candidates = glob.glob(resume_root + '/*.state')
        if not candidates:
            # Without this check the [-1] below fails with a bare IndexError
            # that does not say which directory was searched.
            raise FileNotFoundError(
                'No *.state file found in resume directory: {}'.format(
                    resume_root))
        resume_state_path = util.sorted_nicely(candidates)[-1]
    else:
        resume_state_path = resume_root

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if opt['gpu_ids']:
        resume_state = torch.load(resume_state_path)
    else:
        # no GPU configured: remap all storages to the CPU
        resume_state = torch.load(resume_state_path,
                                  map_location=torch.device('cpu'))

    logger.info('Set [resume_state] to {}'.format(resume_state_path))
    logger.info('Resuming training from epoch: {}, iter: {}.'.format(
        resume_state['epoch'], resume_state['iter']))
    options.check_resume(opt)  # check resume options
    return resume_state
def get_dataloaders(opt):
    """Create one dataloader per configured dataset phase.

    Args:
        opt: Options dictionary. Reads ``gpu_ids``, ``datasets``,
            ``is_train`` and, when training, ``opt['train']['niter']``.

    Returns:
        A tuple ``(dataloaders, data_params)``. ``dataloaders`` maps each
        phase to its dataloader. When training, ``data_params`` holds
        ``batch_size``, ``virtual_batch_size``, ``total_iters`` and
        ``total_epochs``; otherwise it holds ``znorm``, a mapping from
        dataset name to that dataset's ``znorm`` option (phases named
        ``'val'`` are not included).

    Raises:
        NotImplementedError: If a training phase is not ``train`` or ``val``.
        ValueError: If the train dataloader yields no batches.
        Exception: If a dataset is empty or no dataloader was created.
    """
    logger = util.get_root_logger()

    gpu_ids = opt.get('gpu_ids', None) or []

    # Create datasets and dataloaders
    dataloaders = {}
    data_params = {}
    znorm = {}
    for phase, dataset_opt in opt['datasets'].items():
        if opt['is_train'] and phase not in ['train', 'val']:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))

        name = dataset_opt['name']
        dataset = create_dataset(dataset_opt)
        if not dataset:
            raise Exception('Dataset "{}" for phase "{}" is empty.'.format(
                name, phase))

        dataloaders[phase] = create_dataloader(dataset, dataset_opt, gpu_ids)

        if opt['is_train'] and phase == 'train':
            batch_size = dataset_opt.get('batch_size', 4)
            # the virtual batch can never be smaller than the real batch
            virtual_batch_size = max(
                dataset_opt.get('virtual_batch_size', batch_size), batch_size)

            train_size = len(dataloaders[phase])
            logger.info(f'Number of train images: {len(dataset):,d}, '
                        f'epoch iters: {train_size:,d}')
            if train_size == 0:
                # A non-empty dataset can still produce zero batches; fail
                # with a clear message instead of a ZeroDivisionError below.
                raise ValueError(
                    'Train dataloader for dataset "{}" yields no batches.'
                    .format(name))

            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info(f'Total epochs needed: {total_epochs:d} for '
                        f'iters {total_iters:,d}')
            data_params = {
                "batch_size": batch_size,
                "virtual_batch_size": virtual_batch_size,
                "total_iters": total_iters,
                "total_epochs": total_epochs
            }
        else:
            logger.info('Number of {:s} images in [{:s}]: {:,d}'.format(
                phase, name, len(dataset)))
            if phase != 'val':
                znorm[name] = dataset_opt.get('znorm', False)

    if not opt['is_train']:
        data_params['znorm'] = znorm

    if not dataloaders:
        raise Exception("No Dataloader has been created.")

    return dataloaders, data_params
def get_random_seed(opt):
    """Resolve the training seed, apply it and return the options.

    When ``opt['train']['manual_seed']`` is ``None`` a seed is drawn with
    ``random.randint(1, 10000)`` and written back into the options. The
    resulting seed is logged and passed to ``util.set_random_seed``.

    Args:
        opt: Options dictionary; ``opt['train']['manual_seed']`` is read and
            may be overwritten.

    Returns:
        The same ``opt`` object.
    """
    logger = util.get_root_logger()

    train_opt = opt['train']
    chosen_seed = train_opt['manual_seed']
    if chosen_seed is None:
        # no seed configured: pick one and record it in the options
        chosen_seed = random.randint(1, 10000)
        train_opt['manual_seed'] = chosen_seed

    logger.info('Random seed: {}'.format(chosen_seed))
    util.set_random_seed(chosen_seed)
    return opt
def configure_loggers(opt=None):
    """Configure the text loggers and, optionally, a tensorboard writer.

    Args:
        opt: Options dictionary. Reads ``logger.save_logfile``, ``is_train``,
            ``path.log``, ``use_tb_logger``, ``name`` and ``path.root``.

    Returns:
        A dictionary ``{"tb_logger": writer}`` where ``writer`` is a
        ``SummaryWriter`` or ``None`` when tensorboard logging is disabled
        (``use_tb_logger`` falsy or ``'debug'`` in ``opt['name']``).

    Raises:
        ValueError: If ``opt`` is ``None``.
    """
    if opt is None:
        # The default exists only for signature compatibility; nothing can
        # be configured without options.
        raise ValueError('configure_loggers requires an options dictionary.')

    # `or {}` also covers a `logger` section that is present but None
    tofile = (opt.get('logger') or {}).get('save_logfile', True)

    # config loggers. Before it, the log will not work
    if opt['is_train']:
        util.get_root_logger(None, opt['path']['log'], 'train',
                             level=logging.INFO, screen=True, tofile=tofile)
        util.get_root_logger('val', opt['path']['log'], 'val',
                             level=logging.INFO, tofile=tofile)
    else:
        util.get_root_logger(None, opt['path']['log'], 'test',
                             level=logging.INFO, screen=True, tofile=tofile)
    logger = util.get_root_logger()  # 'base'
    logger.info(options.dict2str(opt))

    # initialize tensorboard logger
    tb_logger = None
    if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
        log_dir = os.path.join(opt['path']['root'], 'tb_logger', opt['name'])

        # Compare (major, minor) as integers: slicing the version string and
        # converting to float misreads versions such as "1.10".
        version_parts = torch.__version__.split('+')[0].split('.')[:2]
        torch_version = tuple(int(p) for p in version_parts if p.isdigit())

        if torch_version >= (1, 1):
            # official PyTorch tensorboard, with tensorboardX as fallback.
            # `Exception` (not only ImportError) keeps the best-effort
            # fallback for any failure while importing the official writer,
            # without swallowing KeyboardInterrupt/SystemExit.
            try:
                from torch.utils.tensorboard import SummaryWriter
            except Exception:
                from tensorboardX import SummaryWriter
        else:
            logger.info(
                'You are using PyTorch {}. Using [tensorboardX].'.format(
                    torch.__version__))
            from tensorboardX import SummaryWriter

        try:
            # for versions PyTorch > 1.1 and tensorboardX < 1.6
            tb_logger = SummaryWriter(log_dir=log_dir)
        except TypeError:
            # for version tensorboardX >= 1.7 (keyword is `logdir`)
            tb_logger = SummaryWriter(logdir=log_dir)

    return {"tb_logger": tb_logger}
def test_loop(model, opt, dataloaders, data_params):
    """Run inference on every test dataloader, save results and log metrics.

    Args:
        model: Model wrapper; ``feed_data``, ``test``, ``test_x8``,
            ``test_chop`` and ``get_current_visuals`` are called on it.
        opt: Options dictionary. Reads ``metrics``, ``path.results_root``,
            ``test_mode``, ``chop_patch_size``, ``chop_step``, ``use_cem``,
            ``cem_config``, ``val_comparison``, ``suffix`` and ``scale``.
        dataloaders: Mapping of phase name to dataloader.
        data_params: Dictionary whose ``'znorm'`` entry maps a dataset name
            to the ``denormalize`` flag used when converting tensors.
    """
    logger = util.get_root_logger()

    # read data_params
    znorms = data_params['znorm']

    # prepare the metric calculation classes for RGB and Y_only images
    calc_metrics = opt.get('metrics', None)
    if calc_metrics:
        test_metrics = metrics.MetricsDict(metrics=calc_metrics)
        test_metrics_y = metrics.MetricsDict(metrics=calc_metrics)

    test_mode = opt.get('test_mode', None)
    # a suffix that is present but None must not break the concatenation
    suffix = opt.get('suffix', None) or ''

    def run_inference(**kwargs):
        """Run the inference variant selected by ``test_mode``."""
        if test_mode == 'x8':
            # geometric self-ensemble
            model.test_x8(**kwargs)
        elif test_mode == 'chop':
            # chop images in patches/crops, to reduce VRAM usage
            model.test_chop(patch_size=opt.get('chop_patch_size', 100),
                            step=opt.get('chop_step', 0.9),
                            **kwargs)
        else:
            # normal inference
            model.test(**kwargs)

    for dataloader in dataloaders.values():
        name = dataloader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(name))
        dataset_dir = os.path.join(opt['path']['results_root'], name)
        util.mkdir(dataset_dir)

        # Depends only on the dataset, so compute it once. This also keeps
        # it defined (and not stale) when the dataloader yields nothing.
        need_HR = dataloader.dataset.opt['dataroot_HR'] is not None

        for data in dataloader:
            znorm = znorms[name]

            # set up per image CEM wrapper if configured
            CEM_net = get_CEM(opt, data)

            model.feed_data(data, need_HR=need_HR)  # unpack data from data loader
            img_path = get_img_path(data)
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            run_inference(CEM_net=CEM_net)
            visuals = model.get_current_visuals(need_HR=need_HR)

            # post-process options if using CEM
            if opt.get('use_cem', None) and opt['cem_config'].get('out_orig', False):
                cem_config = opt['cem_config']

                # run regular inference (without the CEM wrapper)
                run_inference()
                orig_visuals = model.get_current_visuals(need_HR=need_HR)

                if cem_config.get('out_filter', False):
                    GF = GuidedFilter(ks=cem_config.get('out_filter_ks', 7))
                    filt = GF(
                        visuals['SR'].unsqueeze(0),
                        (visuals['SR'] - orig_visuals['SR']).unsqueeze(0)
                    ).squeeze(0)
                    visuals['SR'] = orig_visuals['SR'] + filt

                if cem_config.get('out_keepY', False):
                    # first channel from the regular output, the other two
                    # from the CEM output
                    out_regY = rgb_to_ycbcr(orig_visuals['SR']).unsqueeze(0)
                    out_cemY = rgb_to_ycbcr(visuals['SR']).unsqueeze(0)
                    visuals['SR'] = ycbcr_to_rgb(
                        torch.cat([out_regY[:, 0:1, :, :],
                                   out_cemY[:, 1:2, :, :],
                                   out_cemY[:, 2:3, :, :]], 1)).squeeze(0)

            val_comparison = opt.get('val_comparison', None)
            res_options = visuals_check(visuals.keys(), val_comparison)
            save_names = res_options['save_imgs']

            # save images
            save_img_path = os.path.join(dataset_dir, img_name + suffix)

            # save single images or lr / sr comparison
            if val_comparison and len(save_names) > 1:
                comp_images = [tensor2np(visuals[save_img_name], denormalize=znorm)
                               for save_img_name in save_names]
                util.save_img_comp(comp_images, save_img_path + '.png')
            else:
                for save_img_name in save_names:
                    imn = '_' + save_img_name if len(save_names) > 1 else ''
                    util.save_img(
                        tensor2np(visuals[save_img_name], denormalize=znorm),
                        save_img_path + imn + '.png')

            # calculate metrics if HR dataset is provided and metrics are
            # configured in options
            if need_HR and calc_metrics and res_options['aligned_metrics']:
                metric_imgs = [tensor2np(visuals[x], denormalize=znorm)
                               for x in res_options['compare_imgs']]
                test_results = test_metrics.calculate_metrics(
                    metric_imgs[0], metric_imgs[1], crop_size=opt['scale'])

                # prepare single image metrics log message
                # NOTE(review): assumes calculate_metrics returns an iterable
                # of (name, value) pairs - confirm against MetricsDict.
                logger_m = '{:20s} -'.format(img_name)
                for k, v in test_results:
                    logger_m += k.upper() + ': {:.6f}, '.format(v)

                # The original tested an undefined `gt_img` here; use the
                # image passed as the second (reference) argument above.
                ref_shape = metric_imgs[1].shape
                if len(ref_shape) == 3 and ref_shape[2] == 3:
                    # RGB image, calculate y_only metrics
                    test_results_y = test_metrics_y.calculate_metrics(
                        metric_imgs[0], metric_imgs[1],
                        crop_size=opt['scale'], only_y=True)
                    # add the y only results to the single image log message
                    for k, v in test_results_y:
                        logger_m += k.upper() + ': {:.6f}, '.format(v)

                logger.info(logger_m)
            else:
                logger.info(img_name)

        # average metrics results for the dataset
        if need_HR and calc_metrics:
            # aggregate the metrics results (automatically resets the metric classes)
            avg_metrics = test_metrics.get_averages()
            avg_metrics_y = test_metrics_y.get_averages()

            # prepare log average metrics message
            agg_logger_m = ''
            for r in avg_metrics:
                agg_logger_m += r['name'].upper() + ': {:.6f}, '.format(r['average'])
            logger.info('----Average metrics results for {}----\n\t'.format(name)
                        + agg_logger_m[:-2])

            # was `len(avg_metrics_y > 0)`, which raises TypeError
            if len(avg_metrics_y) > 0:
                # prepare log average Y channel metrics message
                agg_logger_m = ''
                for r in avg_metrics_y:
                    agg_logger_m += (r['name'].upper() + '_Y'
                                     + ': {:.6f}, '.format(r['average']))
                logger.info('----Y channel, average metrics ----\n\t'
                            + agg_logger_m[:-2])
def main():
    """Test entry point: run the model on every configured dataset.

    Parses the ``-opt`` options file, builds the test dataloaders and the
    model, saves the ``top_fake`` and ``bottom_fake`` outputs of every image
    as ``*_top.png`` / ``*_bot.png`` and, when an HR root is configured,
    logs PSNR/SSIM values.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True,
                        help='Path to options file.')
    opt = options.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items()
                 if key != 'pretrain_model_G'))
    opt = options.dict_to_nonedict(opt)

    util.get_root_logger(None, opt['path']['log'], 'test.log',
                         level=logging.INFO, screen=True)
    logger = logging.getLogger('base')
    logger.info(options.dict2str(opt))

    # Create test dataset and dataloader
    test_loaders = []
    znorm = False  # TMP
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)
        # Temporary, will turn znorm on for all the datasets. Will need to
        # introduce a variable for each dataset and differentiate each one
        # later in the loop.
        if dataset_opt['znorm']:
            znorm = True

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        # Depends only on the dataset; computing it once keeps it defined
        # (and not stale) when the loader yields nothing.
        need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

        for data in test_loader:
            model.feed_data(data, need_HR=need_HR)
            img_path = data['in_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)

            top_img = tensor2np(visuals['top_fake'])  # uint8
            bot_img = tensor2np(visuals['bottom_fake'])  # uint8

            # save images
            save_img_path = os.path.join(dataset_dir,
                                         img_name + (opt['suffix'] or ''))
            util.save_img(top_img, save_img_path + '_top.png')
            util.save_img(bot_img, save_img_path + '_bot.png')

            # TODO: update to use metrics functions
            # calculate PSNR and SSIM
            if need_HR:
                # if znorm the image range is [-1,1], Default: Image range is [0,1]
                gt_img = tensor2img(visuals['HR'], denormalize=znorm)  # uint8
                gt_img = gt_img / 255.
                # NOTE(review): `sr_img` is never assigned in this function,
                # so this line raises UnboundLocalError whenever an HR root
                # is configured. Which output (top_img / bot_img) should be
                # scored against HR cannot be told from here - confirm and
                # assign it before this point.
                sr_img = sr_img / 255.

                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                            img_name, psnr, ssim))
            else:
                logger.info(img_name)

        # TODO: update to use metrics functions
        # Average PSNR/SSIM results; skipped when nothing was measured so an
        # empty loader does not end in a ZeroDivisionError.
        if need_HR and test_results['psnr']:
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info(
                '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'
                .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info(
                    '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'
                    .format(ave_psnr_y, ave_ssim_y))
def fit(model, opt, dataloaders, steps_states, data_params, loggers):
    """Run the training loop with periodic logging, checkpoints and validation.

    Args:
        model: Model wrapper. ``feed_data``, ``optimize_parameters``,
            ``update_learning_rate``, ``get_current_learning_rate``,
            ``get_current_log``, ``save``, ``save_training_state``, ``test``
            and ``get_current_visuals`` are called on it; ``optGstep``,
            ``optDstep`` and ``swa`` are read.
        opt: Options dictionary (``train``, ``logger``, ``datasets``,
            ``path``, ``scale``, ``name`` and ``use_tb_logger`` are read).
        dataloaders: Mapping with a ``'train'`` and optionally a ``'val'``
            dataloader.
        steps_states: Dictionary with ``start_epoch``, ``current_step`` and
            ``virtual_step``.
        data_params: Dictionary with ``batch_size``, ``virtual_batch_size``,
            ``total_iters`` and ``total_epochs``.
        loggers: Dictionary with a ``tb_logger`` entry (may be ``None``).

    On ``KeyboardInterrupt`` the current models and training state are saved
    so training can be resumed later.
    """
    # read data_params
    batch_size = data_params['batch_size']
    virtual_batch_size = data_params['virtual_batch_size']
    total_iters = data_params['total_iters']
    total_epochs = data_params['total_epochs']

    # read steps_states
    start_epoch = steps_states["start_epoch"]
    current_step = steps_states["current_step"]
    virtual_step = steps_states["virtual_step"]

    # read loggers
    logger = util.get_root_logger()
    tb_logger = loggers["tb_logger"]

    # Same condition configure_loggers uses to create the writer; `.get`
    # and the None check keep a missing key or writer from crashing the run.
    use_tb = bool(tb_logger is not None
                  and opt.get('use_tb_logger', False)
                  and 'debug' not in opt['name'])

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))

    # Defined up front so the KeyboardInterrupt handler below can always
    # use them, even when interrupted before the first batch arrives.
    epoch = start_epoch
    n = 0
    try:
        timer = metrics.Timer()  # iteration timer
        timerData = metrics.TickTock()  # data timer
        timerEpoch = metrics.TickTock()  # epoch timer

        # outer loop for different epochs
        last_epoch = (total_epochs * (virtual_batch_size // batch_size)) + 1
        for epoch in range(start_epoch, last_epoch):
            # reset so an interrupt before the first batch of this epoch
            # does not reuse the previous epoch's batch counter
            n = 0
            timerData.tick()
            timerEpoch.tick()

            # inner iteration loop within one epoch
            for n, train_data in enumerate(dataloaders['train'], start=1):
                timerData.tock()

                virtual_step += 1
                take_step = False
                if virtual_step > 0 and virtual_step * batch_size % virtual_batch_size == 0:
                    current_step += 1
                    take_step = True
                    if current_step > total_iters:
                        break

                # training
                model.feed_data(train_data)  # unpack data from dataset and apply preprocessing
                # calculate loss functions, get gradients, update network weights
                model.optimize_parameters(virtual_step)

                # log
                if current_step % opt['logger']['print_freq'] == 0 and take_step:
                    # iteration end time
                    avg_time = timer.get_average_and_reset()
                    avg_data_time = timerData.get_average_and_reset()

                    # training ETA in hours (0 when no average is available)
                    if avg_time > 0:
                        eta_hours = (avg_time * (opt['train']['niter'] - current_step)) / 3600
                    else:
                        eta_hours = 0

                    # print training losses and save logging information to disk
                    message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.4f}s, td:{:.4f}s, eta:{:.4f}h> '.format(
                        epoch, current_step,
                        model.get_current_learning_rate(current_step),
                        avg_time, avg_data_time, eta_hours)

                    # reduce the rate of tensorboard logs
                    log_to_tb = (
                        use_tb
                        and current_step % opt['logger'].get('tb_sample_rate', 1) == 0)

                    # tensorboard training logger
                    if log_to_tb:
                        tb_logger.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
                        tb_logger.add_scalar('time/iteration', timer.get_last_iteration(), current_step)
                        tb_logger.add_scalar('time/data', timerData.get_last_iteration(), current_step)

                    logs = model.get_current_log()
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard loss logger
                        if log_to_tb:
                            tb_logger.add_scalar(k, v, current_step)

                    logger.info(message)

                # start time for next iteration
                # TODO: skip the validation time from calculation
                timer.tick()

                # update learning rate
                if model.optGstep and model.optDstep and take_step:
                    model.update_learning_rate(
                        current_step,
                        warmup_iter=opt['train'].get('warmup_iter', -1))

                # save latest models and training states every
                # <save_checkpoint_freq> iterations
                if current_step % opt['logger']['save_checkpoint_freq'] == 0 and take_step:
                    if model.swa:
                        model.save(current_step, opt['logger']['overwrite_chkp'],
                                   loader=dataloaders['train'])
                    else:
                        model.save(current_step, opt['logger']['overwrite_chkp'])
                    model.save_training_state(
                        # count the epoch as finished when this was its last batch
                        epoch=epoch + (n >= len(dataloaders['train'])),
                        iter_step=current_step,
                        latest=opt['logger']['overwrite_chkp']
                    )
                    logger.info('Models and training states saved.')

                # validation
                if dataloaders.get('val', None) and current_step % opt['train']['val_freq'] == 0 and take_step:
                    val_metrics = metrics.MetricsDict(
                        metrics=opt['train'].get('metrics', None))
                    znorm = opt['datasets']['train']['znorm']
                    nlls = []
                    for val_data in dataloaders['val']:
                        model.feed_data(val_data)  # unpack data from data loader
                        model.test()  # run inference
                        if hasattr(model, 'nll'):
                            nlls.append(model.nll if model.nll else 0)

                        # Get Visuals
                        visuals = model.get_current_visuals()  # get image results
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LR_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)

                        # Save SR images for reference
                        sr_img = None
                        if hasattr(model, 'heats'):  # SRFlow
                            # one image is saved per (heat, sample); no
                            # LR/SR comparison image in this mode
                            opt['train']['val_comparison'] = False
                            for heat in model.heats:
                                for i in range(model.n_sample):
                                    sr_img = tensor2np(visuals['SR', heat, i],
                                                       denormalize=znorm)
                                    if opt['train']['overwrite_val_imgs']:
                                        save_img_path = os.path.join(
                                            img_dir,
                                            '{:s}_h{:03d}_s{:d}.png'.format(
                                                img_name, int(heat * 100), i))
                                    else:
                                        save_img_path = os.path.join(
                                            img_dir,
                                            '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(
                                                img_name, current_step,
                                                int(heat * 100), i))
                                    util.save_img(sr_img, save_img_path)
                        else:  # regular SR
                            sr_img = tensor2np(visuals['SR'], denormalize=znorm)
                            if opt['train']['overwrite_val_imgs']:
                                save_img_path = os.path.join(
                                    img_dir, '{:s}.png'.format(img_name))
                            else:
                                save_img_path = os.path.join(
                                    img_dir,
                                    '{:s}_{:d}.png'.format(img_name, current_step))
                            if not opt['train']['val_comparison']:
                                util.save_img(sr_img, save_img_path)
                        assert sr_img is not None

                        # Save GT images for reference
                        gt_img = tensor2np(visuals['HR'], denormalize=znorm)
                        if opt['train']['save_gt']:
                            save_img_path_gt = os.path.join(
                                img_dir, '{:s}_GT.png'.format(img_name))
                            if not os.path.isfile(save_img_path_gt):
                                util.save_img(gt_img, save_img_path_gt)

                        # Save LQ images for reference
                        if opt['train']['save_lr']:
                            save_img_path_lq = os.path.join(
                                img_dir, '{:s}_LQ.png'.format(img_name))
                            if not os.path.isfile(save_img_path_lq):
                                lq_img = tensor2np(visuals['LR'], denormalize=znorm)
                                util.save_img(lq_img, save_img_path_lq,
                                              scale=opt['scale'])

                        # save single images or LQ / SR comparison
                        if opt['train']['val_comparison']:
                            lr_img = tensor2np(visuals['LR'], denormalize=znorm)
                            util.save_img_comp([lr_img, sr_img], save_img_path)

                        # Get Metrics
                        # TODO: test using tensor based metrics (batch)
                        # instead of numpy.
                        val_metrics.calculate_metrics(
                            sr_img, gt_img, crop_size=opt['scale'])  # , only_y=True)

                    avg_metrics = val_metrics.get_averages()
                    if nlls:
                        avg_nll = sum(nlls) / len(nlls)
                    del val_metrics

                    # log
                    # Joining the parts avoids the old `[:-2]` trim, which
                    # cut the last exponent digit off the avg_nll entry.
                    log_parts = [
                        r['name'].upper() + ': {:.5g}'.format(r['average'])
                        for r in avg_metrics
                    ]
                    if nlls:
                        log_parts.append('avg_nll: {:.4e}'.format(avg_nll))
                    logger_m = ', '.join(log_parts)

                    logger.info('# Validation # ' + logger_m)
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> '.format(
                        epoch, current_step) + logger_m)

                    # tensorboard logger
                    if use_tb:
                        for r in avg_metrics:
                            tb_logger.add_scalar(r['name'], r['average'], current_step)
                        if nlls:
                            tb_logger.add_scalar('average nll', avg_nll, current_step)

                timerData.tick()

            timerEpoch.tock()
            logger.info('End of epoch {} / {} \t Time Taken: {:.4f} sec'.format(
                epoch, total_epochs, timerEpoch.get_last_iteration()))

        logger.info('Saving the final model.')
        if model.swa:
            model.save('latest', loader=dataloaders['train'])
        else:
            model.save('latest')
        logger.info('End of training.')

    except KeyboardInterrupt:
        # catch a KeyboardInterrupt and save the model and state to resume later
        if model.swa:
            model.save(current_step, True, loader=dataloaders['train'])
        else:
            model.save(current_step, True)
        model.save_training_state(epoch + (n >= len(dataloaders['train'])),
                                  current_step, True)
        logger.info('Training interrupted. Latest models and training states saved.')
def test_loop(model, opt, dataloaders, data_params):
    """Run inference on every test dataloader, save results and log metrics.

    Only the normal ``model.test`` inference is available in this variant:
    datasets are skipped (with a warning) when ``test_mode`` is ``'x8'`` or
    ``'chop'``. Models exposing ``heats`` (SRFlow) get one image saved per
    heat and sample.

    Args:
        model: Model wrapper; ``feed_data``, ``test`` and
            ``get_current_visuals`` are called on it, ``nll``, ``heats`` and
            ``n_sample`` are read when present.
        opt: Options dictionary. Reads ``metrics``, ``path.results_root``,
            ``test_mode``, ``val_comparison``, ``suffix`` and ``scale``;
            ``val_comparison`` is set to ``False`` for SRFlow models.
        dataloaders: Mapping of phase name to dataloader.
        data_params: Dictionary whose ``'znorm'`` entry maps a dataset name
            to the ``denormalize`` flag used when converting tensors.
    """
    logger = util.get_root_logger()

    # read data_params
    znorms = data_params['znorm']

    # prepare the metric calculation classes for RGB and Y_only images
    calc_metrics = opt.get('metrics', None)
    if calc_metrics:
        test_metrics = metrics.MetricsDict(metrics=calc_metrics)
        test_metrics_y = metrics.MetricsDict(metrics=calc_metrics)

    test_mode = opt.get('test_mode', None)
    # a suffix that is present but None must not break the concatenation
    suffix = opt.get('suffix', None) or ''

    for dataloader in dataloaders.values():
        name = dataloader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(name))
        dataset_dir = os.path.join(opt['path']['results_root'], name)
        util.mkdir(dataset_dir)

        # Depends only on the dataset, so compute it once. This also keeps
        # it defined (and not stale) when the data loop ends early.
        need_HR = dataloader.dataset.opt['dataroot_HR'] is not None

        # collected per image; not reported by this function
        nlls = []
        for data in dataloader:
            znorm = znorms[name]

            # set up per image CEM wrapper if configured
            CEM_net = get_CEM(opt, data)

            model.feed_data(data, need_HR=need_HR)  # unpack data from data loader
            img_path = get_img_path(data)
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            if test_mode in ('x8', 'chop'):
                # geometric self-ensemble / chopped inference are disabled
                # here; the rest of this dataset is skipped
                logger.warning(
                    'test_mode "{}" is not supported in this test loop; '
                    'skipping dataset [{}].'.format(test_mode, name))
                break

            # normal inference
            model.test(CEM_net=CEM_net)  # run inference
            if hasattr(model, 'nll'):
                nlls.append(model.nll if model.nll else 0)

            # get image results
            visuals = model.get_current_visuals(need_HR=need_HR)

            res_options = visuals_check(visuals.keys(),
                                        opt.get('val_comparison', None))
            save_names = res_options['save_imgs']

            # save images
            save_img_path = os.path.join(dataset_dir, img_name + suffix)

            # Save SR images for reference
            if hasattr(model, 'heats'):  # SRFlow
                # one image per (heat, sample); no comparison image
                opt['val_comparison'] = False
                for heat in model.heats:
                    for i in range(model.n_sample):
                        for save_img_name in save_names:
                            imn = '_' + save_img_name if len(save_names) > 1 else ''
                            imn += '_h{:03d}_s{:d}'.format(int(heat * 100), i)
                            util.save_img(
                                tensor2np(visuals[save_img_name, heat, i],
                                          denormalize=znorm),
                                save_img_path + imn + '.png')
            elif not opt.get('val_comparison', None):  # regular SR
                for save_img_name in save_names:
                    imn = '_' + save_img_name if len(save_names) > 1 else ''
                    util.save_img(
                        tensor2np(visuals[save_img_name], denormalize=znorm),
                        save_img_path + imn + '.png')

            # save lr / sr comparison
            if opt.get('val_comparison', None) and len(save_names) > 1:
                comp_images = [
                    tensor2np(visuals[save_img_name], denormalize=znorm)
                    for save_img_name in save_names
                ]
                util.save_img_comp(comp_images, save_img_path + '.png')

            # calculate metrics if HR dataset is provided and metrics are
            # configured in options
            if need_HR and calc_metrics and res_options['aligned_metrics']:
                metric_imgs = [
                    tensor2np(visuals[x], denormalize=znorm)
                    for x in res_options['compare_imgs']
                ]
                test_results = test_metrics.calculate_metrics(
                    metric_imgs[0], metric_imgs[1], crop_size=opt['scale'])

                # prepare single image metrics log message
                # NOTE(review): assumes calculate_metrics returns an iterable
                # of (name, value) pairs - confirm against MetricsDict.
                logger_m = '{:20s} -'.format(img_name)
                for k, v in test_results:
                    logger_m += k.upper() + ': {:.6f}, '.format(v)

                # The original tested an undefined `gt_img` here; use the
                # image passed as the second (reference) argument above.
                ref_shape = metric_imgs[1].shape
                if len(ref_shape) == 3 and ref_shape[2] == 3:
                    # RGB image, calculate y_only metrics
                    test_results_y = test_metrics_y.calculate_metrics(
                        metric_imgs[0], metric_imgs[1],
                        crop_size=opt['scale'], only_y=True)
                    # add the y only results to the single image log message
                    for k, v in test_results_y:
                        logger_m += k.upper() + ': {:.6f}, '.format(v)

                logger.info(logger_m)
            else:
                logger.info(img_name)

        # average metrics results for the dataset
        if need_HR and calc_metrics:
            # aggregate the metrics results (automatically resets the metric classes)
            avg_metrics = test_metrics.get_averages()
            avg_metrics_y = test_metrics_y.get_averages()

            # prepare log average metrics message
            agg_logger_m = ''
            for r in avg_metrics:
                agg_logger_m += r['name'].upper() + ': {:.6f}, '.format(
                    r['average'])
            logger.info(
                '----Average metrics results for {}----\n\t'.format(name)
                + agg_logger_m[:-2])

            # was `len(avg_metrics_y > 0)`, which raises TypeError
            if len(avg_metrics_y) > 0:
                # prepare log average Y channel metrics message
                agg_logger_m = ''
                for r in avg_metrics_y:
                    agg_logger_m += (r['name'].upper() + '_Y'
                                     + ': {:.6f}, '.format(r['average']))
                logger.info('----Y channel, average metrics ----\n\t'
                            + agg_logger_m[:-2])
def main():
    """Test entry point for models fed with frame stacks (``data['LR']``).

    Parses the ``-opt`` options file, builds the test dataloaders and the
    model, and saves one SR image per input. With ``chop_forward`` enabled
    the LR frames are cropped to multiples of 16 and processed with
    ``chop_forward``; otherwise the model's regular ``test`` path is used.
    PSNR/SSIM are logged when an HR root is configured and the regular path
    is used (the ``chop_forward`` path produces no HR visual).
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True,
                        help='Path to options file.')
    opt = options.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items()
                 if key != 'pretrain_model_G'))

    util.get_root_logger(None, opt['path']['log'], 'test.log',
                         level=logging.INFO, screen=True)
    logger = logging.getLogger('base')
    logger.info(options.dict2str(opt))

    scale = opt.get('scale', 4)
    use_chop = opt.get('chop_forward', None)
    # `.get` so a missing suffix does not raise KeyError on a plain dict
    suffix = opt.get('suffix', None) or ''

    # Create test dataset and dataloader. Each loader is paired with its own
    # dataset's znorm flag, so one dataset's setting is not applied to all.
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append((test_loader, dataset_opt.get('znorm', False)))

    # Create model
    model = create_model(opt)

    for test_loader, znorm in test_loaders:
        loader_opt = test_loader.dataset.opt
        test_set_name = loader_opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        # Depends only on the dataset; computing it once keeps it defined
        # when the loader yields nothing.
        need_HR = loader_opt['dataroot_HR'] is not None

        # The chop_forward path never produces `visuals`, so the HR image
        # used by the metrics is unavailable there (it used to fail with a
        # NameError). Skip the metrics in that case.
        with_metrics = need_HR and not use_chop
        if need_HR and use_chop:
            logger.warning(
                'PSNR/SSIM are not computed for [{}] when chop_forward is '
                'enabled.'.format(test_set_name))

        for data in test_loader:
            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            if use_chop:
                # data
                lr_rank = len(data['LR'].size())
                if lr_rank == 4:
                    b, _, h_lr, w_lr = data['LR'].size()
                    LR_y_cube = data['LR'].view(b, -1, 1, h_lr, w_lr)  # b, t, c, h, w
                elif lr_rank == 5:
                    # for networks that work with 3 channel images
                    LR_y_cube = data['LR']  # b, t, c, h, w
                else:
                    raise ValueError(
                        'Expected a 4D or 5D LR tensor, got {}D.'.format(lr_rank))

                # crop so height and width are multiples of 16
                # TODO: this is modcrop, not sure if really needed, check
                # (the dataloader already does modcrop)
                _, _, _, h, w = LR_y_cube.size()
                h = (h // 16) * 16
                w = (w // 16) * 16
                LR_y_cube = LR_y_cube[:, :, :, :h, :w]

                if isinstance(data['LR_bicubic'], torch.Tensor):
                    SR_cb = data['LR_bicubic'][:, 1, :h * scale, :w * scale]
                    SR_cr = data['LR_bicubic'][:, 2, :h * scale, :w * scale]

                SR_y = chop_forward(LR_y_cube, model, scale,
                                    need_HR=need_HR).squeeze(0)
                if loader_opt.get('srcolors', None):
                    # NOTE(review): SR_cb / SR_cr exist only when
                    # data['LR_bicubic'] is a tensor - confirm that srcolors
                    # implies it.
                    print(SR_y.shape, SR_cb.shape, SR_cr.shape)
                    sr_img = ycbcr_to_rgb(torch.stack((SR_y, SR_cb, SR_cr), -3))
                else:
                    sr_img = SR_y
            else:
                # data
                model.feed_data(data, need_HR=need_HR)
                model.test()  # test
                visuals = model.get_current_visuals(need_HR=need_HR)

                if loader_opt.get('y_only', None) and loader_opt.get('srcolors', None):
                    SR_cb = data['LR_bicubic'][:, 1, :, :]
                    SR_cr = data['LR_bicubic'][:, 2, :, :]
                    sr_img = ycbcr_to_rgb(
                        torch.stack((visuals['SR'], SR_cb, SR_cr), -3))
                else:
                    sr_img = visuals['SR']

            # if znorm the image range is [-1,1], Default: Image range is [0,1]
            sr_img = tensor2np(sr_img, denormalize=znorm)  # uint8

            # save images
            save_img_path = os.path.join(dataset_dir, img_name + suffix + '.png')
            util.save_img(sr_img, save_img_path)

            # TODO: update to use metrics functions
            # calculate PSNR and SSIM
            if with_metrics:
                gt_img = tensor2img(visuals['HR'], denormalize=znorm)  # uint8
                gt_img = gt_img / 255.
                sr_img = sr_img / 255.

                crop_border = loader_opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                        img_name, psnr, ssim))
            else:
                logger.info(img_name)

        # TODO: update to use metrics functions
        # Average PSNR/SSIM results; skipped when nothing was measured so an
        # empty result list does not end in a ZeroDivisionError.
        if with_metrics and test_results['psnr']:
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info(
                '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'
                .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
                logger.info(
                    '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'
                    .format(ave_psnr_y, ave_ssim_y))