Example #1
def test_loop(model, opt, dataloaders, data_params):
    logger = util.get_root_logger()

    # read data_params
    znorms = data_params['znorm']

    # prepare the metric calculation classes for RGB and Y_only images
    calc_metrics = opt.get('metrics', None)
    if calc_metrics:
        test_metrics = metrics.MetricsDict(metrics=calc_metrics)
        test_metrics_y = metrics.MetricsDict(metrics=calc_metrics)

    for phase, dataloader in dataloaders.items():
        name = dataloader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(name))
        dataset_dir = os.path.join(opt['path']['results_root'], name)
        util.mkdir(dataset_dir)

        for data in dataloader:
            znorm = znorms[name]
            need_HR = dataloader.dataset.opt['dataroot_HR'] is not None

            # set up per image CEM wrapper if configured
            CEM_net = get_CEM(opt, data)

            model.feed_data(data, need_HR=need_HR)  # unpack data from data loader
            # img_path = data['LR_path'][0]
            img_path = get_img_path(data)
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            # test with eval mode. This only affects layers like batchnorm and dropout.
            test_mode = opt.get('test_mode', None)
            if test_mode == 'x8':
                # geometric self-ensemble
                model.test_x8(CEM_net=CEM_net)
            elif test_mode == 'chop':
                # chop images in patches/crops, to reduce VRAM usage
                model.test_chop(patch_size=opt.get('chop_patch_size', 100), 
                                step=opt.get('chop_step', 0.9),
                                CEM_net=CEM_net)
            else:
                # normal inference
                model.test(CEM_net=CEM_net)  # run inference
            
            # get image results
            visuals = model.get_current_visuals(need_HR=need_HR)

            # post-process options if using CEM
            if opt.get('use_cem', None) and opt['cem_config'].get('out_orig', False):
                # run regular inference
                if test_mode == 'x8':
                    model.test_x8()
                elif test_mode == 'chop':
                    model.test_chop(patch_size=opt.get('chop_patch_size', 100), 
                                    step=opt.get('chop_step', 0.9))
                else:
                    model.test()
                orig_visuals = model.get_current_visuals(need_HR=need_HR)

                if opt['cem_config'].get('out_filter', False):
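                    # smooth the (CEM - regular) residual with the CEM output as
                    # the guide, then add it back onto the regular output below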
                    GF = GuidedFilter(ks=opt['cem_config'].get('out_filter_ks', 7))
                    filt = GF(visuals['SR'].unsqueeze(0), (visuals['SR']-orig_visuals['SR']).unsqueeze(0)).squeeze(0)
                    visuals['SR'] = orig_visuals['SR']+filt

                if opt['cem_config'].get('out_keepY', False):
                    out_regY = rgb_to_ycbcr(orig_visuals['SR']).unsqueeze(0)
                    out_cemY = rgb_to_ycbcr(visuals['SR']).unsqueeze(0)
                    visuals['SR'] = ycbcr_to_rgb(torch.cat([out_regY[:, 0:1, :, :], out_cemY[:, 1:2, :, :], out_cemY[:, 2:3, :, :]], 1)).squeeze(0)

            res_options = visuals_check(visuals.keys(), opt.get('val_comparison', None))

            # save images
            save_img_path = os.path.join(dataset_dir, img_name + opt.get('suffix', ''))

            # save single images or lr / sr comparison
            if opt['val_comparison'] and len(res_options['save_imgs']) > 1:
                comp_images = [tensor2np(visuals[save_img_name], denormalize=znorm) for save_img_name in res_options['save_imgs']]
                util.save_img_comp(comp_images, save_img_path + '.png')
            else:
                for save_img_name in res_options['save_imgs']:
                    imn = '_' + save_img_name if len(res_options['save_imgs']) > 1 else ''
                    util.save_img(tensor2np(visuals[save_img_name], denormalize=znorm), save_img_path + imn + '.png')

            # calculate metrics if HR dataset is provided and metrics are configured in options
            if need_HR and calc_metrics and res_options['aligned_metrics']:
                metric_imgs = [tensor2np(visuals[x], denormalize=znorm) for x in res_options['compare_imgs']]
                test_results = test_metrics.calculate_metrics(metric_imgs[0], metric_imgs[1], 
                                                              crop_size=opt['scale'])
                
                # prepare single image metrics log message
                logger_m = '{:20s} -'.format(img_name)
                for k, v in test_results:
                    formatted_res = k.upper() + ': {:.6f}, '.format(v)
                    logger_m += formatted_res

                if metric_imgs[1].shape[2] == 3:  # RGB ground truth, also calculate Y-only metrics
                    test_results_y = test_metrics_y.calculate_metrics(metric_imgs[0], metric_imgs[1], 
                                                                      crop_size=opt['scale'], only_y=True)
                    
                    # add the y only results to the single image log message
                    for k, v in test_results_y:
                        formatted_res = k.upper() + ': {:.6f}, '.format(v)
                        logger_m += formatted_res
                
                logger.info(logger_m)
            else:
                logger.info(img_name)

        # average metrics results for the dataset
        if need_HR and calc_metrics:
            
            # aggregate the metrics results (automatically resets the metric classes)
            avg_metrics = test_metrics.get_averages()
            avg_metrics_y = test_metrics_y.get_averages()

            # prepare log average metrics message
            agg_logger_m = ''
            for r in avg_metrics:
                formatted_res = r['name'].upper() + ': {:.6f}, '.format(r['average'])
                agg_logger_m += formatted_res
            logger.info('----Average metrics results for {}----\n\t'.format(name) + agg_logger_m[:-2])
            
            if len(avg_metrics_y) > 0:
                # prepare log average Y channel metrics message
                agg_logger_m = ''
                for r in avg_metrics_y:
                    formatted_res = r['name'].upper() + '_Y' + ': {:.6f}, '.format(r['average'])
                    agg_logger_m += formatted_res
                logger.info('----Y channel, average metrics ----\n\t' + agg_logger_m[:-2])
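
Note on the `out_keepY` branch above: it merges two outputs in YCbCr space, keeping the luma (Y) channel of the regular result and the chroma (Cb, Cr) channels of the CEM-constrained result. Below is a minimal, self-contained sketch of that merge, assuming full-range BT.601 conversions; the tensors `reg` and `cem` are stand-ins for `orig_visuals['SR']` and `visuals['SR']`.

import torch

def rgb_to_ycbcr_bt601(img: torch.Tensor) -> torch.Tensor:
    # img: (N, 3, H, W), values in [0, 1]; full-range BT.601
    r, g, b = img[:, 0:1], img[:, 1:2], img[:, 2:3]
    y = 0.299 * r + 0.587 * g + 0.114 * b
    cb = -0.168736 * r - 0.331264 * g + 0.5 * b + 0.5
    cr = 0.5 * r - 0.418688 * g - 0.081312 * b + 0.5
    return torch.cat([y, cb, cr], dim=1)

def ycbcr_to_rgb_bt601(img: torch.Tensor) -> torch.Tensor:
    y, cb, cr = img[:, 0:1], img[:, 1:2] - 0.5, img[:, 2:3] - 0.5
    r = y + 1.402 * cr
    g = y - 0.344136 * cb - 0.714136 * cr
    b = y + 1.772 * cb
    return torch.cat([r, g, b], dim=1)

# keep the luma of the regular output, take the chroma from the CEM output
reg = torch.rand(1, 3, 8, 8)  # stand-in for orig_visuals['SR']
cem = torch.rand(1, 3, 8, 8)  # stand-in for visuals['SR']
reg_y = rgb_to_ycbcr_bt601(reg)
cem_y = rgb_to_ycbcr_bt601(cem)
merged = ycbcr_to_rgb_bt601(torch.cat([reg_y[:, 0:1], cem_y[:, 1:3]], dim=1))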
Example #2
def test_loop(model, opt, dataloaders, data_params):
    logger = util.get_root_logger()

    # read data_params
    znorms = data_params['znorm']

    # prepare the metric calculation classes for RGB and Y_only images
    calc_metrics = opt.get('metrics', None)
    if calc_metrics:
        test_metrics = metrics.MetricsDict(metrics=calc_metrics)
        test_metrics_y = metrics.MetricsDict(metrics=calc_metrics)

    for phase, dataloader in dataloaders.items():
        name = dataloader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(name))
        dataset_dir = os.path.join(opt['path']['results_root'], name)
        util.mkdir(dataset_dir)

        nlls = []
        for data in dataloader:
            znorm = znorms[name]
            need_HR = dataloader.dataset.opt['dataroot_HR'] is not None

            # set up per image CEM wrapper if configured
            CEM_net = get_CEM(opt, data)

            model.feed_data(data,
                            need_HR=need_HR)  # unpack data from data loader
            # img_path = data['LR_path'][0]
            img_path = get_img_path(data)
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            # test with eval mode. This only affects layers like batchnorm and dropout.
            test_mode = opt.get('test_mode', None)
            if test_mode == 'x8':
                # geometric self-ensemble is disabled in this variant;
                # `break` skips the rest of this dataset
                # model.test_x8(CEM_net=CEM_net)
                break
            elif test_mode == 'chop':
                # chopped (patch-based) inference is disabled in this variant;
                # `break` skips the rest of this dataset
                # model.test_chop(patch_size=opt.get('chop_patch_size', 100),
                #                 step=opt.get('chop_step', 0.9),
                #                 CEM_net=CEM_net)
                break
            else:
                # normal inference
                model.test(CEM_net=CEM_net)  # run inference

            if hasattr(model, 'nll'):
                nll = model.nll if model.nll else 0
                nlls.append(nll)

            # get image results
            visuals = model.get_current_visuals(need_HR=need_HR)

            res_options = visuals_check(visuals.keys(),
                                        opt.get('val_comparison', None))

            # save images
            save_img_path = os.path.join(dataset_dir,
                                         img_name + opt.get('suffix', ''))

            # Save SR images for reference
            sr_img = None
            if hasattr(model, 'heats'):  # SRFlow
                opt['val_comparison'] = False
                for heat in model.heats:
                    for i in range(model.n_sample):
                        for save_img_name in res_options['save_imgs']:
                            imn = '_' + save_img_name if len(
                                res_options['save_imgs']) > 1 else ''
                            imn += '_h{:03d}_s{:d}'.format(int(heat * 100), i)
                            util.save_img(
                                tensor2np(visuals[save_img_name, heat, i],
                                          denormalize=znorm),
                                save_img_path + imn + '.png')
            else:  # regular SR
                if not opt['val_comparison']:
                    for save_img_name in res_options['save_imgs']:
                        imn = '_' + save_img_name if len(
                            res_options['save_imgs']) > 1 else ''
                        util.save_img(
                            tensor2np(visuals[save_img_name],
                                      denormalize=znorm),
                            save_img_path + imn + '.png')

            # save single images or lr / sr comparison
            if opt['val_comparison'] and len(res_options['save_imgs']) > 1:
                comp_images = [
                    tensor2np(visuals[save_img_name], denormalize=znorm)
                    for save_img_name in res_options['save_imgs']
                ]
                util.save_img_comp(comp_images, save_img_path + '.png')
            # else:
            # util.save_img(sr_img, save_img_path)

            # calculate metrics if HR dataset is provided and metrics are configured in options
            if need_HR and calc_metrics and res_options['aligned_metrics']:
                metric_imgs = [
                    tensor2np(visuals[x], denormalize=znorm)
                    for x in res_options['compare_imgs']
                ]
                test_results = test_metrics.calculate_metrics(
                    metric_imgs[0], metric_imgs[1], crop_size=opt['scale'])

                # prepare single image metrics log message
                logger_m = '{:20s} -'.format(img_name)
                for k, v in test_results:
                    formatted_res = k.upper() + ': {:.6f}, '.format(v)
                    logger_m += formatted_res

                if metric_imgs[1].shape[2] == 3:  # RGB ground truth, also calculate Y-only metrics
                    test_results_y = test_metrics_y.calculate_metrics(
                        metric_imgs[0],
                        metric_imgs[1],
                        crop_size=opt['scale'],
                        only_y=True)

                    # add the y only results to the single image log message
                    for k, v in test_results_y:
                        formatted_res = k.upper() + ': {:.6f}, '.format(v)
                        logger_m += formatted_res

                logger.info(logger_m)
            else:
                logger.info(img_name)

        # average metrics results for the dataset
        if need_HR and calc_metrics:

            # aggregate the metrics results (automatically resets the metric classes)
            avg_metrics = test_metrics.get_averages()
            avg_metrics_y = test_metrics_y.get_averages()

            # prepare log average metrics message
            agg_logger_m = ''
            for r in avg_metrics:
                formatted_res = r['name'].upper() + ': {:.6f}, '.format(
                    r['average'])
                agg_logger_m += formatted_res
            logger.info(
                '----Average metrics results for {}----\n\t'.format(name) +
                agg_logger_m[:-2])

            if len(avg_metrics_y) > 0:
                # prepare log average Y channel metrics message
                agg_logger_m = ''
                for r in avg_metrics_y:
                    formatted_res = r['name'].upper() + '_Y' + ': {:.6f}, '.format(r['average'])
                    agg_logger_m += formatted_res
                logger.info('----Y channel, average metrics ----\n\t' +
                            agg_logger_m[:-2])
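
Note on the SRFlow branch above: `visuals[save_img_name, heat, i]` indexes a dict keyed by `(name, heat, sample)` tuples, one entry per sampling temperature and random draw. A sketch of how such a dict can be filled and iterated, with made-up `heats`/`n_sample` values and a stand-in `sample` function:

import torch

def sample(heat: float) -> torch.Tensor:
    # stand-in for one stochastic SRFlow draw at temperature `heat`
    return heat * torch.rand(3, 8, 8)

heats, n_sample = [0.0, 0.5, 0.9], 2
visuals = {}
for heat in heats:
    for i in range(n_sample):
        # tuple keys: Python hashes the (name, heat, i) tuple directly
        visuals['SR', heat, i] = sample(heat)

# the file-naming scheme used above: heat scaled to an integer percentage
for (name, heat, i), img in visuals.items():
    print('{:s}_h{:03d}_s{:d}.png'.format(name, int(heat * 100), i))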
Example #3
def fit(model, opt, dataloaders, steps_states, data_params, loggers):
    # read data_params
    batch_size = data_params['batch_size']
    virtual_batch_size = data_params['virtual_batch_size']
    total_iters = data_params['total_iters']
    total_epochs = data_params['total_epochs']

    # read steps_states
    start_epoch = steps_states["start_epoch"]
    current_step = steps_states["current_step"]
    virtual_step = steps_states["virtual_step"]

    # read loggers
    logger = util.get_root_logger()
    tb_logger = loggers["tb_logger"]
    
    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    try:
        timer = metrics.Timer()  # iteration timer
        timerData = metrics.TickTock()  # data timer
        timerEpoch = metrics.TickTock()  # epoch timer
        # outer loop for different epochs
        for epoch in range(start_epoch, (total_epochs * (virtual_batch_size // batch_size))+1):
            timerData.tick()
            timerEpoch.tick()

            # inner iteration loop within one epoch
            for n, train_data in enumerate(dataloaders['train'], start=1):
                timerData.tock()

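                # gradient accumulation: an optimizer step is only taken once
                # enough mini-batches have accumulated to fill a virtual batch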
                virtual_step += 1
                take_step = False
                if virtual_step > 0 and virtual_step * batch_size % virtual_batch_size == 0:
                    current_step += 1
                    take_step = True
                    if current_step > total_iters:
                        break

                # training
                model.feed_data(train_data)  # unpack data from dataset and apply preprocessing
                model.optimize_parameters(virtual_step)  # calculate loss functions, get gradients, update network weights

                # log
                def eta(t_iter):
                    # calculate training ETA in hours
                    return (t_iter * (opt['train']['niter'] - current_step)) / 3600 if t_iter > 0 else 0

                if current_step % opt['logger']['print_freq'] == 0 and take_step:
                    # iteration end time
                    avg_time = timer.get_average_and_reset()
                    avg_data_time = timerData.get_average_and_reset()

                    # print training losses and save logging information to disk
                    message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.4f}s, td:{:.4f}s, eta:{:.4f}h> '.format(
                        epoch, current_step, model.get_current_learning_rate(current_step), avg_time, 
                        avg_data_time, eta(avg_time))
                    
                    # tensorboard training logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if current_step % opt['logger'].get('tb_sample_rate', 1) == 0: # Reduce rate of tb logs
                            # tb_logger.add_scalar('loss/nll', nll, current_step)
                            tb_logger.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
                            tb_logger.add_scalar('time/iteration', timer.get_last_iteration(), current_step)
                            tb_logger.add_scalar('time/data', timerData.get_last_iteration(), current_step)

                    logs = model.get_current_log()
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard loss logger
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            if current_step % opt['logger'].get('tb_sample_rate', 1) == 0: # Reduce rate of tb logs
                                tb_logger.add_scalar(k, v, current_step)
                            # tb_logger.flush()
                    logger.info(message)

                    # start time for next iteration  # TODO: skip the validation time from this calculation
                    timer.tick()

                # update learning rate
                if model.optGstep and model.optDstep and take_step:
                    model.update_learning_rate(current_step, warmup_iter=opt['train'].get('warmup_iter', -1))
                
                # save latest models and training states every <save_checkpoint_freq> iterations
                if current_step % opt['logger']['save_checkpoint_freq'] == 0 and take_step:
                    if model.swa: 
                        model.save(current_step, opt['logger']['overwrite_chkp'], loader=dataloaders['train'])
                    else:
                        model.save(current_step, opt['logger']['overwrite_chkp'])
                    model.save_training_state(
                        epoch=epoch + (n >= len(dataloaders['train'])),
                        iter_step=current_step,
                        latest=opt['logger']['overwrite_chkp']
                    )
                    logger.info('Models and training states saved.')

                # validation
                if dataloaders.get('val', None) and current_step % opt['train']['val_freq'] == 0 and take_step:
                    val_metrics = metrics.MetricsDict(metrics=opt['train'].get('metrics', None))
                    nlls = []
                    for val_data in dataloaders['val']:
                        
                        model.feed_data(val_data)  # unpack data from data loader
                        model.test()  # run inference
                        if hasattr(model, 'nll'):
                            nll = model.nll if model.nll else 0
                            nlls.append(nll)

                        """
                        Get Visuals
                        """
                        visuals = model.get_current_visuals()  # get image results
                        img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)
                        
                        # Save SR images for reference
                        sr_img = None
                        if hasattr(model, 'heats'):  # SRFlow
                            opt['train']['val_comparison'] = False
                            for heat in model.heats:
                                for i in range(model.n_sample):
                                    sr_img = tensor2np(visuals['SR', heat, i], denormalize=opt['datasets']['train']['znorm'])
                                    if opt['train']['overwrite_val_imgs']:
                                        save_img_path = os.path.join(img_dir,
                                                                '{:s}_h{:03d}_s{:d}.png'.format(img_name, int(heat * 100), i))
                                    else:
                                        save_img_path = os.path.join(img_dir,
                                                                '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(img_name,
                                                                                                        current_step,
                                                                                                        int(heat * 100), i))
                                    util.save_img(sr_img, save_img_path)
                        else:  # regular SR
                            sr_img = tensor2np(visuals['SR'], denormalize=opt['datasets']['train']['znorm'])
                            if opt['train']['overwrite_val_imgs']:
                                save_img_path = os.path.join(img_dir, '{:s}.png'.format(img_name))
                            else:
                                save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                            if not opt['train']['val_comparison']:
                                util.save_img(sr_img, save_img_path)
                        assert sr_img is not None

                        # Save GT images for reference
                        gt_img = tensor2np(visuals['HR'], denormalize=opt['datasets']['train']['znorm'])
                        if opt['train']['save_gt']:
                            save_img_path_gt = os.path.join(img_dir,
                                                            '{:s}_GT.png'.format(img_name))
                            if not os.path.isfile(save_img_path_gt):
                                util.save_img(gt_img, save_img_path_gt)

                        # Save LQ images for reference
                        if opt['train']['save_lr']:
                            save_img_path_lq = os.path.join(img_dir,
                                                            '{:s}_LQ.png'.format(img_name))
                            if not os.path.isfile(save_img_path_lq):
                                lq_img = tensor2np(visuals['LR'], denormalize=opt['datasets']['train']['znorm'])
                                util.save_img(lq_img, save_img_path_lq, scale=opt['scale'])

                        # save single images or LQ / SR comparison
                        if opt['train']['val_comparison']:
                            lr_img = tensor2np(visuals['LR'], denormalize=opt['datasets']['train']['znorm'])
                            util.save_img_comp([lr_img, sr_img], save_img_path)
                        # else:
                            # util.save_img(sr_img, save_img_path)

                        """
                        Get Metrics
                        # TODO: test using tensor based metrics (batch) instead of numpy.
                        """
                        val_metrics.calculate_metrics(sr_img, gt_img, crop_size=opt['scale'])  # , only_y=True)

                    avg_metrics = val_metrics.get_averages()
                    if nlls:
                        avg_nll = sum(nlls) / len(nlls)
                    del val_metrics

                    # log
                    logger_m = ''
                    for r in avg_metrics:
                        formatted_res = r['name'].upper() + ': {:.5g}, '.format(r['average'])
                        logger_m += formatted_res
                    if nlls:
                        logger_m += 'avg_nll: {:.4e}  '.format(avg_nll)

                    logger.info('# Validation # ' + logger_m[:-2])
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> '.format(epoch, current_step) + logger_m[:-2])
                    # memory_usage = torch.cuda.memory_allocated()/(1024.0 ** 3) # in GB
                    
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        for r in avg_metrics:
                            tb_logger.add_scalar(r['name'], r['average'], current_step)
                        # log the average nll once, outside the metrics loop
                        if nlls:
                            tb_logger.add_scalar('average nll', avg_nll, current_step)
                        # tb_logger.flush()
                        # tb_logger_valid.add_scalar(r['name'], r['average'], current_step)
                        # tb_logger_valid.flush()
                    
                timerData.tick()
            
            timerEpoch.tock()
            logger.info('End of epoch {} / {} \t Time Taken: {:.4f} sec'.format(
                epoch, total_epochs, timerEpoch.get_last_iteration()))

        logger.info('Saving the final model.')
        if model.swa:
            model.save('latest', loader=dataloaders['train'])
        else:
            model.save('latest')
        logger.info('End of training.')

    except KeyboardInterrupt:
        # catch a KeyboardInterrupt and save the model and state to resume later
        if model.swa:
            model.save(current_step, True, loader=dataloaders['train'])
        else:
            model.save(current_step, True)
        model.save_training_state(epoch + (n >= len(dataloaders['train'])), current_step, True)
        logger.info('Training interrupted. Latest models and training states saved.')
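
Both training loops step the optimizer only when `virtual_step * batch_size` lands on a multiple of `virtual_batch_size`, i.e. gradient accumulation. A self-contained sketch of that bookkeeping with made-up sizes:

batch_size = 4
virtual_batch_size = 16  # accumulate 4 mini-batches per optimizer step

current_step = 0
for virtual_step in range(1, 33):  # 32 mini-batches
    take_step = (virtual_step * batch_size) % virtual_batch_size == 0
    if take_step:
        current_step += 1  # one "real" iteration

assert current_step == 8  # 32 mini-batches of 4 -> 8 virtual batches of 16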
Example #4
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(opt)  # convert to NoneDict, which returns None for missing keys

    # train from scratch OR resume training
    if opt['path']['resume_state']:
        if os.path.isdir(opt['path']['resume_state']):
            import glob
            resume_state_path = util.sorted_nicely(glob.glob(os.path.normpath(opt['path']['resume_state']) + '/*.state'))[-1]
        else:
            resume_state_path = opt['path']['resume_state']
        resume_state = torch.load(resume_state_path)
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None, opt['path']['log'], 'train', level=logging.INFO, screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Set [resume_state] to ' + resume_state_path)
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        try:
            tb_logger = SummaryWriter(logdir='../tb_logger/' + opt['name'])  # tensorboardX >= 1.7 uses 'logdir'
        except TypeError:
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])  # older tensorboardX versions use 'log_dir'

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # if the model does not change and input sizes remain the same during training then there may be benefit
    # from setting torch.backends.cudnn.benchmark = True, otherwise it may stall training
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            batch_size = dataset_opt.get('batch_size', 4)
            virtual_batch_size = dataset_opt.get('virtual_batch_size', batch_size)
            virtual_batch_size = max(virtual_batch_size, batch_size)
            train_size = int(math.ceil(len(train_set) / batch_size))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'],
                                                                      len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        virtual_step = current_step * virtual_batch_size // batch_size \
            if virtual_batch_size and virtual_batch_size > batch_size else current_step
        model.resume_training(resume_state)  # handle optimizers and schedulers
        model.update_schedulers(opt['train']) # updated schedulers in case JSON configuration has changed
        del resume_state
        # start the iteration time when resuming
        t0 = time.time()
    else:
        current_step = 0
        virtual_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    try:
        for epoch in range(start_epoch, total_epochs*(virtual_batch_size//batch_size)):
            for n, train_data in enumerate(train_loader, start=1):

                if virtual_step == 0:
                    # first iteration start time
                    t0 = time.time()

                virtual_step += 1
                take_step = False
                if virtual_step > 0 and virtual_step * batch_size % virtual_batch_size == 0:
                    current_step += 1
                    take_step = True
                    if current_step > total_iters:
                        break

                # training
                model.feed_data(train_data)
                model.optimize_parameters(virtual_step)

                # log
                if current_step % opt['logger']['print_freq'] == 0 and take_step:
                    # iteration end time
                    t1 = time.time()

                    logs = model.get_current_log()
                    message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, i_time: {:.4f} sec.> '.format(
                        epoch, current_step, model.get_current_learning_rate(current_step), (t1 - t0))
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard logger
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar(k, v, current_step)
                    logger.info(message)

                    # # start time for next iteration
                    # t0 = time.time()

                # update learning rate
                if model.optGstep and model.optDstep and take_step:
                    model.update_learning_rate(current_step, warmup_iter=opt['train'].get('warmup_iter', -1))

                # save models and training states (changed to save models before validation)
                if current_step % opt['logger']['save_checkpoint_freq'] == 0 and take_step:
                    if model.swa:
                        model.save(current_step, opt['logger']['overwrite_chkp'], loader=train_loader)
                    else:
                        model.save(current_step, opt['logger']['overwrite_chkp'])
                    model.save_training_state(epoch + (n >= len(train_loader)), current_step, opt['logger']['overwrite_chkp'])
                    logger.info('Models and training states saved.')

                # validation
                if val_loader and current_step % opt['train']['val_freq'] == 0 and take_step:
                    val_sr_imgs_list = []
                    val_gt_imgs_list = []
                    val_metrics = metrics.MetricsDict(metrics=opt['train'].get('metrics', None))
                    for val_data in val_loader:
                        img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()  # run inference

                        """
                        Get Visuals
                        """
                        visuals = model.get_current_visuals()
                        sr_img = tensor2np(visuals['SR'], denormalize=opt['datasets']['train']['znorm'])
                        gt_img = tensor2np(visuals['HR'], denormalize=opt['datasets']['train']['znorm'])

                        # Save SR images for reference
                        if opt['train']['overwrite_val_imgs']:
                            save_img_path = os.path.join(img_dir, '{:s}.png'.format(img_name))
                        else:
                            save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step))

                        # save single images or lr / sr comparison
                        if opt['train']['val_comparison']:
                            lr_img = tensor2np(visuals['LR'], denormalize=opt['datasets']['train']['znorm'])
                            util.save_img_comp([lr_img, sr_img], save_img_path)
                        else:
                            util.save_img(sr_img, save_img_path)

                        """
                        Get Metrics
                        # TODO: test using tensor based metrics (batch) instead of numpy.
                        """
                        crop_size = opt['scale']
                        val_metrics.calculate_metrics(sr_img, gt_img, crop_size=crop_size)  # , only_y=True)

                    avg_metrics = val_metrics.get_averages()
                    del val_metrics

                    # log
                    logger_m = ''
                    for r in avg_metrics:
                        formatted_res = r['name'].upper()+': {:.5g}, '.format(r['average'])
                        logger_m += formatted_res

                    logger.info('# Validation # '+logger_m[:-2])
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> '.format(epoch, current_step)+logger_m[:-2])
                    # memory_usage = torch.cuda.memory_allocated()/(1024.0 ** 3) # in GB

                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        for r in avg_metrics:
                            tb_logger.add_scalar(r['name'], r['average'], current_step)

                    # # reset time for next iteration to skip the validation time from calculation
                    # t0 = time.time()

                if (current_step % opt['logger']['print_freq'] == 0 and take_step) or \
                        (val_loader and current_step % opt['train']['val_freq'] == 0 and take_step):
                    # reset time for next iteration to skip the validation time from calculation
                    t0 = time.time()

        logger.info('Saving the final model.')
        if model.swa:
            model.save('latest', loader=train_loader)
        else:
            model.save('latest')
        logger.info('End of training.')

    except KeyboardInterrupt:
        # catch a KeyboardInterrupt and save the model and state to resume later
        if model.swa:
            model.save(current_step, True, loader=train_loader)
        else:
            model.save(current_step, True)
        model.save_training_state(epoch + (n >= len(train_loader)), current_step, True)
        logger.info('Training interrupted. Latest models and training states saved.')
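
A note on the resume branch: `util.sorted_nicely` matters there because a plain lexicographic sort ranks '10000.state' before '9000.state', so taking the last element would pick an older checkpoint. A sketch of a natural-sort key, assuming checkpoints are named '<iter>.state':

import re

def natural_key(s: str):
    # split into digit and non-digit runs so numbers compare numerically
    return [int(t) if t.isdigit() else t.lower() for t in re.split(r'(\d+)', s)]

states = ['9000.state', '10000.state', '500.state']
print(sorted(states))                   # ['10000.state', '500.state', '9000.state']
print(sorted(states, key=natural_key))  # ['500.state', '9000.state', '10000.state']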