def _run_test_inference(model, opt, **kwargs):
    """Run one inference pass using the strategy selected by ``opt['test_mode']``.

    Extra keyword arguments (e.g. ``CEM_net``) are forwarded unchanged to the
    model's test method, so callers decide whether the argument is passed at all.
    """
    test_mode = opt.get('test_mode', None)
    if test_mode == 'x8':
        # geometric self-ensemble
        model.test_x8(**kwargs)
    elif test_mode == 'chop':
        # chop images in patches/crops, to reduce VRAM usage
        model.test_chop(
            patch_size=opt.get('chop_patch_size', 100),
            step=opt.get('chop_step', 0.9),
            **kwargs)
    else:
        # normal inference
        model.test(**kwargs)


def test_loop(model, opt, dataloaders, data_params):
    """Run inference on every test dataloader, save the results and log metrics.

    NOTE(review): this file defines ``test_loop`` a second time further down;
    if both definitions live in the same module the later one replaces this one.

    Args:
        model: model wrapper providing ``feed_data``, ``test``, ``test_x8``,
            ``test_chop`` and ``get_current_visuals``.
        opt: options mapping. Keys read here: ``metrics``, ``path``,
            ``test_mode``, ``chop_patch_size``, ``chop_step``, ``use_cem``,
            ``cem_config``, ``val_comparison``, ``suffix`` and ``scale``.
        dataloaders: mapping whose values are the dataloaders to test; each
            must expose ``dataset.opt`` with ``name`` and ``dataroot_HR``.
        data_params: mapping; ``data_params['znorm']`` maps a dataset name to
            the ``denormalize`` value handed to ``tensor2np``.
    """
    logger = util.get_root_logger()

    # read data_params
    znorms = data_params['znorm']

    # prepare the metric calculation classes for RGB and Y_only images
    calc_metrics = opt.get('metrics', None)
    if calc_metrics:
        test_metrics = metrics.MetricsDict(metrics=calc_metrics)
        test_metrics_y = metrics.MetricsDict(metrics=calc_metrics)

    for dataloader in dataloaders.values():
        dataset_opt = dataloader.dataset.opt
        name = dataset_opt['name']
        logger.info('\nTesting [{:s}]...'.format(name))
        dataset_dir = os.path.join(opt['path']['results_root'], name)
        util.mkdir(dataset_dir)

        # Per-dataset constants. They are bound here (not inside the image
        # loop) so the averaging step below works even for an empty dataloader.
        znorm = znorms[name]
        need_HR = dataset_opt['dataroot_HR'] is not None

        for data in dataloader:
            # set up per image CEM wrapper if configured
            CEM_net = get_CEM(opt, data)

            model.feed_data(data, need_HR=need_HR)  # unpack data from data loader
            img_path = get_img_path(data)
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            # run inference and get image results
            _run_test_inference(model, opt, CEM_net=CEM_net)
            visuals = model.get_current_visuals(need_HR=need_HR)

            # post-process options if using CEM
            if opt.get('use_cem', None) and opt['cem_config'].get('out_orig', False):
                cem_config = opt['cem_config']

                # run regular inference (no CEM wrapper) to get the reference result
                _run_test_inference(model, opt)
                orig_visuals = model.get_current_visuals(need_HR=need_HR)

                if cem_config.get('out_filter', False):
                    guided_filter = GuidedFilter(ks=cem_config.get('out_filter_ks', 7))
                    filt = guided_filter(
                        visuals['SR'].unsqueeze(0),
                        (visuals['SR'] - orig_visuals['SR']).unsqueeze(0)).squeeze(0)
                    visuals['SR'] = orig_visuals['SR'] + filt

                if cem_config.get('out_keepY', False):
                    # channel 0 from the regular result, channels 1-2 from the CEM result
                    out_regY = rgb_to_ycbcr(orig_visuals['SR']).unsqueeze(0)
                    out_cemY = rgb_to_ycbcr(visuals['SR']).unsqueeze(0)
                    visuals['SR'] = ycbcr_to_rgb(torch.cat(
                        [out_regY[:, 0:1, :, :],
                         out_cemY[:, 1:2, :, :],
                         out_cemY[:, 2:3, :, :]], 1)).squeeze(0)

            res_options = visuals_check(visuals.keys(), opt.get('val_comparison', None))
            save_imgs = res_options['save_imgs']

            # save images
            save_img_path = os.path.join(dataset_dir, img_name + (opt.get('suffix') or ''))

            # save single images or lr / sr comparison
            if opt.get('val_comparison') and len(save_imgs) > 1:
                comp_images = [tensor2np(visuals[save_img_name], denormalize=znorm)
                               for save_img_name in save_imgs]
                util.save_img_comp(comp_images, save_img_path + '.png')
            else:
                for save_img_name in save_imgs:
                    imn = '_' + save_img_name if len(save_imgs) > 1 else ''
                    util.save_img(
                        tensor2np(visuals[save_img_name], denormalize=znorm),
                        save_img_path + imn + '.png')

            # calculate metrics if HR dataset is provided and metrics are configured in options
            if need_HR and calc_metrics and res_options['aligned_metrics']:
                metric_imgs = [tensor2np(visuals[x], denormalize=znorm)
                               for x in res_options['compare_imgs']]
                test_results = test_metrics.calculate_metrics(
                    metric_imgs[0], metric_imgs[1], crop_size=opt['scale'])

                # prepare single image metrics log message
                logger_m = '{:20s} -'.format(img_name)
                for k, v in test_results:
                    logger_m += k.upper() + ': {:.6f}, '.format(v)

                # The original read an undefined ``gt_img`` here (NameError).
                # NOTE(review): assumes the second compared image is the
                # ground truth, stored as a (H, W, C) array -- confirm against
                # visuals_check / tensor2np.
                gt_img = metric_imgs[1]
                if len(gt_img.shape) == 3 and gt_img.shape[2] == 3:
                    # RGB image, calculate y_only metrics
                    test_results_y = test_metrics_y.calculate_metrics(
                        metric_imgs[0], metric_imgs[1], crop_size=opt['scale'], only_y=True)

                    # add the y only results to the single image log message
                    for k, v in test_results_y:
                        logger_m += k.upper() + ': {:.6f}, '.format(v)

                logger.info(logger_m)
            else:
                logger.info(img_name)

        # average metrics results for the dataset
        if need_HR and calc_metrics:
            # aggregate the metrics results (automatically resets the metric classes)
            avg_metrics = test_metrics.get_averages()
            avg_metrics_y = test_metrics_y.get_averages()

            # prepare log average metrics message
            agg_logger_m = ''
            for r in avg_metrics:
                agg_logger_m += r['name'].upper() + ': {:.6f}, '.format(r['average'])
            logger.info('----Average metrics results for {}----\n\t'.format(name) + agg_logger_m[:-2])

            # was ``len(avg_metrics_y > 0)``: the comparison sat inside len()
            if len(avg_metrics_y) > 0:
                # prepare log average Y channel metrics message
                agg_logger_m = ''
                for r in avg_metrics_y:
                    agg_logger_m += r['name'].upper() + '_Y' + ': {:.6f}, '.format(r['average'])
                logger.info('----Y channel, average metrics ----\n\t' + agg_logger_m[:-2])
def test_loop(model, opt, dataloaders, data_params):
    """Run inference on every test dataloader, save the results and log metrics.

    Handles both regular SR models and models exposing ``heats`` / ``n_sample``
    (marked "SRFlow" in the code), for which one image is written per
    (heat, sample) pair.

    Args:
        model: model wrapper providing ``feed_data``, ``test`` and
            ``get_current_visuals``; optionally ``nll``, ``heats`` and
            ``n_sample``.
        opt: options mapping. Keys read here: ``metrics``, ``path``,
            ``test_mode``, ``val_comparison``, ``suffix`` and ``scale``.
            ``opt['val_comparison']`` is set to ``False`` when the model
            exposes ``heats``.
        dataloaders: mapping whose values are the dataloaders to test; each
            must expose ``dataset.opt`` with ``name`` and ``dataroot_HR``.
        data_params: mapping; ``data_params['znorm']`` maps a dataset name to
            the ``denormalize`` value handed to ``tensor2np``.
    """
    logger = util.get_root_logger()

    # read data_params
    znorms = data_params['znorm']

    # prepare the metric calculation classes for RGB and Y_only images
    calc_metrics = opt.get('metrics', None)
    if calc_metrics:
        test_metrics = metrics.MetricsDict(metrics=calc_metrics)
        test_metrics_y = metrics.MetricsDict(metrics=calc_metrics)

    for dataloader in dataloaders.values():
        dataset_opt = dataloader.dataset.opt
        name = dataset_opt['name']
        logger.info('\nTesting [{:s}]...'.format(name))
        dataset_dir = os.path.join(opt['path']['results_root'], name)
        util.mkdir(dataset_dir)

        # Per-dataset constants. They are bound here (not inside the image
        # loop) so the averaging step below works even when the loop body
        # never completes (empty dataloader or a skipped test mode).
        znorm = znorms[name]
        need_HR = dataset_opt['dataroot_HR'] is not None

        nlls = []
        for data in dataloader:
            # set up per image CEM wrapper if configured
            CEM_net = get_CEM(opt, data)

            model.feed_data(data, need_HR=need_HR)  # unpack data from data loader
            img_path = get_img_path(data)
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            test_mode = opt.get('test_mode', None)
            if test_mode in ('x8', 'chop'):
                # The x8 (geometric self-ensemble) and chop (patch based)
                # calls are disabled in this loop; the dataset is abandoned.
                # Say so instead of stopping silently.
                logger.warning(
                    'test_mode [{}] is disabled in this test loop; '
                    'skipping the remaining images of [{}].'.format(test_mode, name))
                break

            # normal inference
            model.test(CEM_net=CEM_net)
            if hasattr(model, 'nll'):
                nlls.append(model.nll if model.nll else 0)

            # get image results
            visuals = model.get_current_visuals(need_HR=need_HR)
            res_options = visuals_check(visuals.keys(), opt.get('val_comparison', None))
            save_imgs = res_options['save_imgs']

            # save images
            save_img_path = os.path.join(dataset_dir, img_name + (opt.get('suffix') or ''))

            if hasattr(model, 'heats'):
                # SRFlow: one image per (heat, sample); no lr / sr comparison
                opt['val_comparison'] = False
                for heat in model.heats:
                    for i in range(model.n_sample):
                        for save_img_name in save_imgs:
                            imn = '_' + save_img_name if len(save_imgs) > 1 else ''
                            imn += '_h{:03d}_s{:d}'.format(int(heat * 100), i)
                            util.save_img(
                                tensor2np(visuals[save_img_name, heat, i], denormalize=znorm),
                                save_img_path + imn + '.png')
            elif opt.get('val_comparison') and len(save_imgs) > 1:
                # regular SR: lr / sr comparison
                comp_images = [tensor2np(visuals[save_img_name], denormalize=znorm)
                               for save_img_name in save_imgs]
                util.save_img_comp(comp_images, save_img_path + '.png')
            else:
                # regular SR: single images. This branch also covers a
                # requested comparison with only one image available, which
                # previously wrote nothing at all.
                for save_img_name in save_imgs:
                    imn = '_' + save_img_name if len(save_imgs) > 1 else ''
                    util.save_img(
                        tensor2np(visuals[save_img_name], denormalize=znorm),
                        save_img_path + imn + '.png')

            # calculate metrics if HR dataset is provided and metrics are configured in options
            if need_HR and calc_metrics and res_options['aligned_metrics']:
                metric_imgs = [tensor2np(visuals[x], denormalize=znorm)
                               for x in res_options['compare_imgs']]
                test_results = test_metrics.calculate_metrics(
                    metric_imgs[0], metric_imgs[1], crop_size=opt['scale'])

                # prepare single image metrics log message
                logger_m = '{:20s} -'.format(img_name)
                for k, v in test_results:
                    logger_m += k.upper() + ': {:.6f}, '.format(v)

                # The original read an undefined ``gt_img`` here (NameError).
                # NOTE(review): assumes the second compared image is the
                # ground truth, stored as a (H, W, C) array -- confirm against
                # visuals_check / tensor2np.
                gt_img = metric_imgs[1]
                if len(gt_img.shape) == 3 and gt_img.shape[2] == 3:
                    # RGB image, calculate y_only metrics
                    test_results_y = test_metrics_y.calculate_metrics(
                        metric_imgs[0], metric_imgs[1], crop_size=opt['scale'], only_y=True)

                    # add the y only results to the single image log message
                    for k, v in test_results_y:
                        logger_m += k.upper() + ': {:.6f}, '.format(v)

                logger.info(logger_m)
            else:
                logger.info(img_name)

        # report the collected NLL values (previously gathered but never used)
        if nlls:
            logger.info('----Average NLL for {}: {:.4e}----'.format(name, sum(nlls) / len(nlls)))

        # average metrics results for the dataset
        if need_HR and calc_metrics:
            # aggregate the metrics results (automatically resets the metric classes)
            avg_metrics = test_metrics.get_averages()
            avg_metrics_y = test_metrics_y.get_averages()

            # prepare log average metrics message
            agg_logger_m = ''
            for r in avg_metrics:
                agg_logger_m += r['name'].upper() + ': {:.6f}, '.format(r['average'])
            logger.info('----Average metrics results for {}----\n\t'.format(name) + agg_logger_m[:-2])

            # was ``len(avg_metrics_y > 0)``: the comparison sat inside len()
            if len(avg_metrics_y) > 0:
                # prepare log average Y channel metrics message
                agg_logger_m = ''
                for r in avg_metrics_y:
                    agg_logger_m += r['name'].upper() + '_Y' + ': {:.6f}, '.format(r['average'])
                logger.info('----Y channel, average metrics ----\n\t' + agg_logger_m[:-2])
def _fit_validate(model, opt, val_loader, epoch, current_step, logger, tb_logger):
    """Run one validation pass for ``fit``: save result images and log averaged metrics."""
    train_opt = opt['train']
    znorm = opt['datasets']['train']['znorm']

    val_metrics = metrics.MetricsDict(metrics=train_opt.get('metrics', None))
    nlls = []
    for val_data in val_loader:
        model.feed_data(val_data)  # unpack data from data loader
        model.test()  # run inference
        if hasattr(model, 'nll'):
            nlls.append(model.nll if model.nll else 0)

        # get image results
        visuals = model.get_current_visuals()
        img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
        img_dir = os.path.join(opt['path']['val_images'], img_name)
        util.mkdir(img_dir)

        # Save SR images for reference
        sr_img = None
        if hasattr(model, 'heats'):
            # SRFlow: one image per (heat, sample); no LQ / SR comparison.
            # The last sample is the one used for the metrics below.
            train_opt['val_comparison'] = False
            for heat in model.heats:
                for i in range(model.n_sample):
                    sr_img = tensor2np(visuals['SR', heat, i], denormalize=znorm)
                    if train_opt['overwrite_val_imgs']:
                        file_name = '{:s}_h{:03d}_s{:d}.png'.format(
                            img_name, int(heat * 100), i)
                    else:
                        file_name = '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(
                            img_name, current_step, int(heat * 100), i)
                    save_img_path = os.path.join(img_dir, file_name)
                    util.save_img(sr_img, save_img_path)
        else:
            # regular SR
            sr_img = tensor2np(visuals['SR'], denormalize=znorm)
            if train_opt['overwrite_val_imgs']:
                save_img_path = os.path.join(img_dir, '{:s}.png'.format(img_name))
            else:
                save_img_path = os.path.join(
                    img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
            if not train_opt['val_comparison']:
                util.save_img(sr_img, save_img_path)
        assert sr_img is not None

        # Save GT images for reference
        gt_img = tensor2np(visuals['HR'], denormalize=znorm)
        if train_opt['save_gt']:
            save_img_path_gt = os.path.join(img_dir, '{:s}_GT.png'.format(img_name))
            if not os.path.isfile(save_img_path_gt):
                util.save_img(gt_img, save_img_path_gt)

        # Save LQ images for reference
        if train_opt['save_lr']:
            save_img_path_lq = os.path.join(img_dir, '{:s}_LQ.png'.format(img_name))
            if not os.path.isfile(save_img_path_lq):
                lq_img = tensor2np(visuals['LR'], denormalize=znorm)
                util.save_img(lq_img, save_img_path_lq, scale=opt['scale'])

        # save LQ / SR comparison (single images were already saved above)
        if train_opt['val_comparison']:
            lr_img = tensor2np(visuals['LR'], denormalize=znorm)
            util.save_img_comp([lr_img, sr_img], save_img_path)

        # TODO: test using tensor based metrics (batch) instead of numpy.
        val_metrics.calculate_metrics(sr_img, gt_img, crop_size=opt['scale'])  # , only_y=True)

    avg_metrics = val_metrics.get_averages()

    # Build the message from separate fragments. The previous
    # "concatenate, then strip the last two characters" approach cut the final
    # exponent digit off avg_nll, whose fragment ended in ' ' rather than ', '.
    parts = [r['name'].upper() + ': {:.5g}'.format(r['average']) for r in avg_metrics]
    avg_nll = None
    if nlls:
        avg_nll = sum(nlls) / len(nlls)
        parts.append('avg_nll: {:.4e}'.format(avg_nll))
    logger_m = ', '.join(parts)

    logger.info('# Validation # ' + logger_m)
    logger_val = logging.getLogger('val')  # validation logger
    logger_val.info('<epoch:{:3d}, iter:{:8,d}> '.format(epoch, current_step) + logger_m)

    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        for r in avg_metrics:
            tb_logger.add_scalar(r['name'], r['average'], current_step)
        if nlls:
            tb_logger.add_scalar('average nll', avg_nll, current_step)


def fit(model, opt, dataloaders, steps_states, data_params, loggers):
    """Train ``model``, periodically logging, checkpointing and validating.

    A weight update ("step") is counted every time the number of samples seen
    (``virtual_step * batch_size``) reaches a multiple of
    ``virtual_batch_size``; logging, checkpoints and validation only trigger on
    such iterations.

    Args:
        model: model wrapper providing ``feed_data``, ``optimize_parameters``,
            ``get_current_log``, ``get_current_learning_rate``,
            ``update_learning_rate``, ``save``, ``save_training_state``,
            ``test`` and ``get_current_visuals`` plus the ``swa``,
            ``optGstep`` and ``optDstep`` attributes.
        opt: options mapping (``train``, ``logger``, ``path``, ``datasets``,
            ``scale``, ``name`` and ``use_tb_logger`` are read).
        dataloaders: mapping with a ``'train'`` dataloader and, optionally, a
            ``'val'`` one.
        steps_states: mapping with ``start_epoch``, ``current_step`` and
            ``virtual_step``.
        data_params: mapping with ``batch_size``, ``virtual_batch_size``,
            ``total_iters`` and ``total_epochs``.
        loggers: mapping with the tensorboard writer under ``tb_logger``.

    A ``KeyboardInterrupt`` is caught: the model and the training state are
    saved so that training can be resumed later.
    """
    # read data_params
    batch_size = data_params['batch_size']
    virtual_batch_size = data_params['virtual_batch_size']
    total_iters = data_params['total_iters']
    total_epochs = data_params['total_epochs']

    # read steps_states
    start_epoch = steps_states["start_epoch"]
    current_step = steps_states["current_step"]
    virtual_step = steps_states["virtual_step"]

    # read loggers
    logger = util.get_root_logger()
    tb_logger = loggers["tb_logger"]

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))

    train_loader = dataloaders['train']
    val_loader = dataloaders.get('val', None)

    # Bound up front so the interrupt handler can always read them, even when
    # the interrupt arrives before the first batch has been fetched.
    epoch = start_epoch
    n = 0
    try:
        timer = metrics.Timer()  # iteration timer
        timerData = metrics.TickTock()  # data timer
        timerEpoch = metrics.TickTock()  # epoch timer

        finished = False
        last_epoch = total_epochs * (virtual_batch_size // batch_size)

        # outer loop for different epochs
        for epoch in range(start_epoch, last_epoch + 1):
            timerData.tick()
            timerEpoch.tick()

            # inner iteration loop within one epoch
            for n, train_data in enumerate(train_loader, start=1):
                timerData.tock()

                virtual_step += 1
                take_step = False
                if virtual_step > 0 and virtual_step * batch_size % virtual_batch_size == 0:
                    current_step += 1
                    take_step = True
                if current_step > total_iters:
                    # also leave the epoch loop (see below); otherwise every
                    # remaining epoch would still load a batch just to break
                    finished = True
                    break

                # training
                model.feed_data(train_data)  # unpack data from dataset and apply preprocessing
                model.optimize_parameters(virtual_step)  # calculate loss functions, get gradients, update network weights

                # log
                if current_step % opt['logger']['print_freq'] == 0 and take_step:
                    # iteration end time
                    avg_time = timer.get_average_and_reset()
                    avg_data_time = timerData.get_average_and_reset()

                    # training ETA in hours
                    if avg_time > 0:
                        eta = (avg_time * (opt['train']['niter'] - current_step)) / 3600
                    else:
                        eta = 0

                    # print training losses and save logging information to disk
                    message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.4f}s, td:{:.4f}s, eta:{:.4f}h> '.format(
                        epoch, current_step, model.get_current_learning_rate(current_step),
                        avg_time, avg_data_time, eta)

                    # tensorboard logging is sub-sampled with tb_sample_rate
                    log_to_tb = (
                        opt['use_tb_logger'] and 'debug' not in opt['name']
                        and current_step % opt['logger'].get('tb_sample_rate', 1) == 0)

                    # tensorboard training logger
                    if log_to_tb:
                        tb_logger.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
                        tb_logger.add_scalar('time/iteration', timer.get_last_iteration(), current_step)
                        tb_logger.add_scalar('time/data', timerData.get_last_iteration(), current_step)

                    logs = model.get_current_log()
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard loss logger
                        if log_to_tb:
                            tb_logger.add_scalar(k, v, current_step)
                    logger.info(message)

                    # start time for next iteration
                    # TODO: skip the validation time from calculation
                    # NOTE(review): the timer is ticked only when a message is
                    # printed, so the reported time spans a print interval --
                    # confirm against metrics.Timer whether a per-iteration
                    # tick was intended.
                    timer.tick()

                # update learning rate
                if model.optGstep and model.optDstep and take_step:
                    model.update_learning_rate(
                        current_step, warmup_iter=opt['train'].get('warmup_iter', -1))

                # save latest models and training states every <save_checkpoint_freq> iterations
                if current_step % opt['logger']['save_checkpoint_freq'] == 0 and take_step:
                    if model.swa:
                        model.save(current_step, opt['logger']['overwrite_chkp'], loader=train_loader)
                    else:
                        model.save(current_step, opt['logger']['overwrite_chkp'])
                    model.save_training_state(
                        # count the epoch as finished when this was its last batch
                        epoch=epoch + (n >= len(train_loader)),
                        iter_step=current_step,
                        latest=opt['logger']['overwrite_chkp']
                    )
                    logger.info('Models and training states saved.')

                # validation
                if val_loader and current_step % opt['train']['val_freq'] == 0 and take_step:
                    _fit_validate(model, opt, val_loader, epoch, current_step, logger, tb_logger)

                timerData.tick()

            timerEpoch.tock()
            logger.info('End of epoch {} / {} \t Time Taken: {:.4f} sec'.format(
                epoch, total_epochs, timerEpoch.get_last_iteration()))

            if finished:
                break

        logger.info('Saving the final model.')
        if model.swa:
            model.save('latest', loader=train_loader)
        else:
            model.save('latest')
        logger.info('End of training.')

    except KeyboardInterrupt:
        # catch a KeyboardInterrupt and save the model and state to resume later
        if model.swa:
            model.save(current_step, True, loader=train_loader)
        else:
            model.save(current_step, True)
        model.save_training_state(epoch + (n >= len(train_loader)), current_step, True)
        logger.info('Training interrupted. Latest models and training states saved.')
def main():
    """Command line entry point: parse the options file, build data and model, then train.

    Reads the path given with ``-opt``, optionally resumes from a saved
    training state, sets up the loggers and the random seed, creates the train
    / val dataloaders and the model and runs the training loop with periodic
    logging, checkpointing and validation.

    A ``KeyboardInterrupt`` during training is caught: the model and the
    training state are saved so that training can be resumed later.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which return None for missing key.

    # train from scratch OR resume training
    if opt['path']['resume_state']:
        if os.path.isdir(opt['path']['resume_state']):
            # a directory was given: use the last state file in "nice" order
            import glob
            state_files = util.sorted_nicely(
                glob.glob(os.path.normpath(opt['path']['resume_state']) + '/*.state'))
            if not state_files:
                raise FileNotFoundError(
                    'No *.state file found in [{}].'.format(opt['path']['resume_state']))
            resume_state_path = state_files[-1]
        else:
            resume_state_path = opt['path']['resume_state']
        # NOTE: torch.load unpickles the file; only resume from trusted state files.
        resume_state = torch.load(resume_state_path)
    else:
        # training from scratch
        resume_state = None
        util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None, opt['path']['log'], 'train', level=logging.INFO, screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Set [resume_state] to ' + resume_state_path)
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))

    # tensorboard logger
    use_tb_logger = opt['use_tb_logger'] and 'debug' not in opt['name']
    tb_logger = None
    if use_tb_logger:
        from tensorboardX import SummaryWriter
        try:
            tb_logger = SummaryWriter(logdir='../tb_logger/' + opt['name'])  # for version tensorboardX >= 1.7
        except TypeError:
            # Older releases name the argument ``log_dir``. Only the resulting
            # TypeError is caught; the previous bare ``except`` would also
            # have swallowed KeyboardInterrupt and unrelated failures.
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])  # for version tensorboardX < 1.6

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # if the model does not change and input sizes remain the same during training then there may be benefit
    # from setting torch.backends.cudnn.benchmark = True, otherwise it may stall training
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    train_loader = None  # bound up front so the check below can actually fail cleanly
    val_loader = False
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            batch_size = dataset_opt.get('batch_size', 4)
            virtual_batch_size = dataset_opt.get('virtual_batch_size', batch_size)
            # the virtual batch can never be smaller than the real one
            virtual_batch_size = virtual_batch_size if virtual_batch_size > batch_size else batch_size
            train_size = int(math.ceil(len(train_set) / batch_size))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None, 'A [train] dataset is required for training.'

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        # Integer division keeps the step counter an int; true division turned
        # it into a float that was then passed on to optimize_parameters.
        virtual_step = current_step * virtual_batch_size // batch_size \
            if virtual_batch_size and virtual_batch_size > batch_size else current_step
        model.resume_training(resume_state)  # handle optimizers and schedulers
        model.update_schedulers(opt['train'])  # updated schedulers in case JSON configuration has changed
        del resume_state
        # start the iteration time when resuming
        t0 = time.time()
    else:
        current_step = 0
        virtual_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))

    # Bound up front so the interrupt handler can always read them, even when
    # the interrupt arrives before the first batch has been fetched.
    epoch = start_epoch
    n = 0
    try:
        finished = False
        for epoch in range(start_epoch, total_epochs * (virtual_batch_size // batch_size)):
            for n, train_data in enumerate(train_loader, start=1):

                if virtual_step == 0:
                    # first iteration start time
                    t0 = time.time()

                virtual_step += 1
                take_step = False
                if virtual_step > 0 and virtual_step * batch_size % virtual_batch_size == 0:
                    current_step += 1
                    take_step = True
                if current_step > total_iters:
                    # also leave the epoch loop (see below); otherwise every
                    # remaining epoch would still load a batch just to break
                    finished = True
                    break

                # training
                model.feed_data(train_data)
                model.optimize_parameters(virtual_step)

                # log
                if current_step % opt['logger']['print_freq'] == 0 and take_step:
                    # iteration end time
                    t1 = time.time()

                    logs = model.get_current_log()
                    message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, i_time: {:.4f} sec.> '.format(
                        epoch, current_step, model.get_current_learning_rate(current_step), (t1 - t0))
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard logger
                        if use_tb_logger:
                            tb_logger.add_scalar(k, v, current_step)
                    logger.info(message)

                # update learning rate
                if model.optGstep and model.optDstep and take_step:
                    model.update_learning_rate(
                        current_step, warmup_iter=opt['train'].get('warmup_iter', -1))

                # save models and training states (changed to save models before validation)
                if current_step % opt['logger']['save_checkpoint_freq'] == 0 and take_step:
                    if model.swa:
                        model.save(current_step, opt['logger']['overwrite_chkp'], loader=train_loader)
                    else:
                        model.save(current_step, opt['logger']['overwrite_chkp'])
                    # count the epoch as finished when this was its last batch
                    model.save_training_state(
                        epoch + (n >= len(train_loader)), current_step, opt['logger']['overwrite_chkp'])
                    logger.info('Models and training states saved.')

                # validation
                if val_loader and current_step % opt['train']['val_freq'] == 0 and take_step:
                    znorm = opt['datasets']['train']['znorm']
                    val_metrics = metrics.MetricsDict(metrics=opt['train'].get('metrics', None))
                    for val_data in val_loader:
                        img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        # NOTE(review): was ``model.test(val_data)``; every
                        # other call in this file passes no positional
                        # argument (only an optional ``CEM_net`` keyword) --
                        # confirm against the model's ``test`` signature.
                        model.test()

                        # get visuals
                        visuals = model.get_current_visuals()
                        sr_img = tensor2np(visuals['SR'], denormalize=znorm)
                        gt_img = tensor2np(visuals['HR'], denormalize=znorm)

                        # Save SR images for reference
                        if opt['train']['overwrite_val_imgs']:
                            save_img_path = os.path.join(img_dir, '{:s}.png'.format(img_name))
                        else:
                            save_img_path = os.path.join(
                                img_dir, '{:s}_{:d}.png'.format(img_name, current_step))

                        # save single images or lr / sr comparison
                        if opt['train']['val_comparison']:
                            lr_img = tensor2np(visuals['LR'], denormalize=znorm)
                            # images go in as one list, matching the other
                            # save_img_comp call sites in this file
                            util.save_img_comp([lr_img, sr_img], save_img_path)
                        else:
                            util.save_img(sr_img, save_img_path)

                        # get metrics
                        # TODO: test using tensor based metrics (batch) instead of numpy.
                        val_metrics.calculate_metrics(sr_img, gt_img, crop_size=opt['scale'])  # , only_y=True)

                    avg_metrics = val_metrics.get_averages()
                    del val_metrics

                    # log
                    logger_m = ', '.join(
                        r['name'].upper() + ': {:.5g}'.format(r['average']) for r in avg_metrics)
                    logger.info('# Validation # ' + logger_m)
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> '.format(epoch, current_step) + logger_m)

                    # tensorboard logger
                    if use_tb_logger:
                        for r in avg_metrics:
                            tb_logger.add_scalar(r['name'], r['average'], current_step)

                # reset time for next iteration to skip the validation time from calculation
                if take_step and (
                        current_step % opt['logger']['print_freq'] == 0
                        or (val_loader and current_step % opt['train']['val_freq'] == 0)):
                    t0 = time.time()

            if finished:
                break

        logger.info('Saving the final model.')
        if model.swa:
            model.save('latest', loader=train_loader)
        else:
            model.save('latest')
        logger.info('End of training.')

    except KeyboardInterrupt:
        # catch a KeyboardInterrupt and save the model and state to resume later
        if model.swa:
            model.save(current_step, True, loader=train_loader)
        else:
            model.save(current_step, True)
        model.save_training_state(epoch + (n >= len(train_loader)), current_step, True)
        logger.info('Training interrupted. Latest models and training states saved.')