Example #1: get_resume_state
def get_resume_state(opt):
    logger = util.get_root_logger()

    # train from scratch OR resume training
    if opt['path']['resume_state']:
        if os.path.isdir(opt['path']['resume_state']):
            resume_state_path = glob.glob(opt['path']['resume_state'] +
                                          '/*.state')
            resume_state_path = util.sorted_nicely(resume_state_path)[-1]
        else:
            resume_state_path = opt['path']['resume_state']

        if opt['gpu_ids']:
            resume_state = torch.load(resume_state_path)
        else:
            resume_state = torch.load(resume_state_path,
                                      map_location=torch.device('cpu'))

        logger.info('Set [resume_state] to {}'.format(resume_state_path))
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        options.check_resume(opt)  # check resume options
    else:  # training from scratch
        resume_state = None
    return resume_state
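A minimal usage sketch (the option values are illustrative; `opt` follows the layout this function reads):

opt = {
    'path': {'resume_state': 'experiments/my_model/training_state'},  # dir or .state file
    'gpu_ids': [0],
}
resume_state = get_resume_state(opt)
if resume_state:
    start_epoch = resume_state['epoch']
    current_step = resume_state['iter']
else:
    start_epoch, current_step = 0, 0  # training from scratch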
Example #2: get_dataloaders
def get_dataloaders(opt):
    logger = util.get_root_logger()

    gpu_ids = opt.get('gpu_ids', None)
    gpu_ids = gpu_ids if gpu_ids else []

    # Create datasets and dataloaders
    dataloaders = {}
    data_params = {}
    znorm = {}
    for phase, dataset_opt in opt['datasets'].items():
        if opt['is_train'] and phase not in ['train', 'val']:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))

        name = dataset_opt['name']
        dataset = create_dataset(dataset_opt)

        if not dataset:
            raise Exception('Dataset "{}" for phase "{}" is empty.'.format(
                name, phase))

        dataloaders[phase] = create_dataloader(dataset, dataset_opt, gpu_ids)

        if opt['is_train'] and phase == 'train':
            batch_size = dataset_opt.get('batch_size', 4)
            virtual_batch_size = dataset_opt.get('virtual_batch_size', batch_size)
            virtual_batch_size = max(virtual_batch_size, batch_size)
            # train_size = int(math.ceil(len(dataset) / batch_size))
            train_size = len(dataloaders[phase])
            logger.info(f'Number of train images: {len(dataset):,d}, '
                        f'epoch iters: {train_size:,d}')
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info(f'Total epochs needed: {total_epochs:d} for '
                        f'iters {total_iters:,d}')
            data_params = {
                "batch_size": batch_size,
                "virtual_batch_size": virtual_batch_size,
                "total_iters": total_iters,
                "total_epochs": total_epochs
            }
            assert dataset is not None
        else:
            logger.info('Number of {:s} images in [{:s}]: {:,d}'.format(
                phase, name, len(dataset)))
            if phase != 'val':
                znorm[name] = dataset_opt.get('znorm', False)

    if not opt['is_train']:
        data_params['znorm'] = znorm

    if not dataloaders:
        raise Exception("No Dataloader has been created.")

    return dataloaders, data_params
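A hedged sketch of the option layout this function expects (names and values are illustrative, based only on the keys read above):

opt = {
    'is_train': True,
    'gpu_ids': [0],
    'train': {'niter': 100000},
    'datasets': {
        'train': {'name': 'DIV2K', 'batch_size': 16, 'virtual_batch_size': 32},
        'val': {'name': 'Set5'},
    },
}
dataloaders, data_params = get_dataloaders(opt)
# data_params -> {'batch_size': 16, 'virtual_batch_size': 32,
#                 'total_iters': 100000, 'total_epochs': <depends on dataset size>}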
Example #3: get_random_seed
def get_random_seed(opt):
    logger = util.get_root_logger()
    # set random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
        opt['train']['manual_seed'] = seed
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)
    return opt
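For reference, a typical seeding helper touches every RNG the pipeline may use; this is a sketch of what the assumed `util.set_random_seed` could look like, not the repository's actual implementation:

import random
import numpy as np
import torch

def set_random_seed(seed):
    # seed every RNG the training pipeline may touch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)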
Example #4: configure_loggers
def configure_loggers(opt):
    tofile = opt.get('logger', {}).get('save_logfile', True)
    if opt['is_train']:
        # config loggers. Before it, the log will not work
        util.get_root_logger(None,
                             opt['path']['log'],
                             'train',
                             level=logging.INFO,
                             screen=True,
                             tofile=tofile)
        util.get_root_logger('val',
                             opt['path']['log'],
                             'val',
                             level=logging.INFO,
                             tofile=tofile)
    else:
        util.get_root_logger(None,
                             opt['path']['log'],
                             'test',
                             level=logging.INFO,
                             screen=True,
                             tofile=tofile)
    logger = util.get_root_logger()  # 'base'
    logger.info(options.dict2str(opt))

    # initialize tensorboard logger
    tb_logger = None
    if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
        version = float(torch.__version__[0:3])  # crude parse: '1.13.0' -> 1.1
        log_dir = os.path.join(opt['path']['root'], 'tb_logger', opt['name'])
        # log_dir = os.path.join(opt['path']['experiments_root'], opt['name'], 'tb')
        # logdir_valid = os.path.join(opt['path']['root'], 'tb_logger', opt['name'] + 'valid')
        # logdir_valid = os.path.join(opt['path']['experiments_root'], opt['name'], 'tb_valid')
        if version >= 1.1:  # PyTorch 1.1
            # official PyTorch tensorboard
            try:
                from torch.utils.tensorboard import SummaryWriter
            except ImportError:
                from tensorboardX import SummaryWriter
        else:
            logger.info(
                'You are using PyTorch {}. Using [tensorboardX].'.format(
                    version))
            from tensorboardX import SummaryWriter
        try:
            # for versions PyTorch > 1.1 and tensorboardX < 1.6
            tb_logger = SummaryWriter(log_dir=log_dir)
            # tb_logger_valid = SummaryWriter(log_dir=logdir_valid)
        except TypeError:
            # for version tensorboardX >= 1.7
            tb_logger = SummaryWriter(logdir=log_dir)
            # tb_logger_valid = SummaryWriter(logdir=logdir_valid)
    return {"tb_logger": tb_logger}
Example #5: test_loop
def test_loop(model, opt, dataloaders, data_params):
    logger = util.get_root_logger()

    # read data_params
    znorms = data_params['znorm']

    # prepare the metric calculation classes for RGB and Y_only images
    calc_metrics = opt.get('metrics', None)
    if calc_metrics:
        test_metrics = metrics.MetricsDict(metrics=calc_metrics)
        test_metrics_y = metrics.MetricsDict(metrics=calc_metrics)

    for phase, dataloader in dataloaders.items():
        name = dataloader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(name))
        dataset_dir = os.path.join(opt['path']['results_root'], name)
        util.mkdir(dataset_dir)

        for data in dataloader:
            znorm = znorms[name]
            need_HR = dataloader.dataset.opt['dataroot_HR'] is not None

            # set up per image CEM wrapper if configured
            CEM_net = get_CEM(opt, data)

            model.feed_data(data, need_HR=need_HR)  # unpack data from data loader
            # img_path = data['LR_path'][0]
            img_path = get_img_path(data)
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            # test with eval mode. This only affects layers like batchnorm and dropout.
            test_mode = opt.get('test_mode', None)
            if test_mode == 'x8':
                # geometric self-ensemble
                model.test_x8(CEM_net=CEM_net)
            elif test_mode == 'chop':
                # chop images in patches/crops, to reduce VRAM usage
                model.test_chop(patch_size=opt.get('chop_patch_size', 100), 
                                step=opt.get('chop_step', 0.9),
                                CEM_net=CEM_net)
            else:
                # normal inference
                model.test(CEM_net=CEM_net)  # run inference
            
            # get image results
            visuals = model.get_current_visuals(need_HR=need_HR)

            # post-process options if using CEM
            if opt.get('use_cem', None) and opt['cem_config'].get('out_orig', False):
                # run regular inference
                if test_mode == 'x8':
                    model.test_x8()
                elif test_mode == 'chop':
                    model.test_chop(patch_size=opt.get('chop_patch_size', 100), 
                                    step=opt.get('chop_step', 0.9))
                else:
                    model.test()
                orig_visuals = model.get_current_visuals(need_HR=need_HR)

                if opt['cem_config'].get('out_filter', False):
                    GF = GuidedFilter(ks=opt['cem_config'].get('out_filter_ks', 7))
                    filt = GF(visuals['SR'].unsqueeze(0), (visuals['SR']-orig_visuals['SR']).unsqueeze(0)).squeeze(0)
                    visuals['SR'] = orig_visuals['SR']+filt

                if opt['cem_config'].get('out_keepY', False):
                    out_regY = rgb_to_ycbcr(orig_visuals['SR']).unsqueeze(0)
                    out_cemY = rgb_to_ycbcr(visuals['SR']).unsqueeze(0)
                    visuals['SR'] = ycbcr_to_rgb(torch.cat([out_regY[:, 0:1, :, :], out_cemY[:, 1:2, :, :], out_cemY[:, 2:3, :, :]], 1)).squeeze(0)

            res_options = visuals_check(visuals.keys(), opt.get('val_comparison', None))

            # save images
            save_img_path = os.path.join(dataset_dir, img_name + opt.get('suffix', ''))

            # save single images or lr / sr comparison
            if opt['val_comparison'] and len(res_options['save_imgs']) > 1:
                comp_images = [tensor2np(visuals[save_img_name], denormalize=znorm) for save_img_name in res_options['save_imgs']]
                util.save_img_comp(comp_images, save_img_path + '.png')
            else:
                for save_img_name in res_options['save_imgs']:
                    imn = '_' + save_img_name if len(res_options['save_imgs']) > 1 else ''
                    util.save_img(tensor2np(visuals[save_img_name], denormalize=znorm), save_img_path + imn + '.png')

            # calculate metrics if HR dataset is provided and metrics are configured in options
            if need_HR and calc_metrics and res_options['aligned_metrics']:
                metric_imgs = [tensor2np(visuals[x], denormalize=znorm) for x in res_options['compare_imgs']]
                test_results = test_metrics.calculate_metrics(metric_imgs[0], metric_imgs[1], 
                                                              crop_size=opt['scale'])
                
                # prepare single image metrics log message
                logger_m = '{:20s} -'.format(img_name)
                for k, v in test_results:
                    formatted_res = k.upper() + ': {:.6f}, '.format(v)
                    logger_m += formatted_res

                if metric_imgs[1].shape[2] == 3:  # RGB image, calculate y_only metrics
                    test_results_y = test_metrics_y.calculate_metrics(metric_imgs[0], metric_imgs[1], 
                                                                      crop_size=opt['scale'], only_y=True)
                    
                    # add the y only results to the single image log message
                    for k, v in test_results_y:
                        formatted_res = k.upper() + ': {:.6f}, '.format(v)
                        logger_m += formatted_res
                
                logger.info(logger_m)
            else:
                logger.info(img_name)

        # average metrics results for the dataset
        if need_HR and calc_metrics:
            
            # aggregate the metrics results (automatically resets the metric classes)
            avg_metrics = test_metrics.get_averages()
            avg_metrics_y = test_metrics_y.get_averages()

            # prepare log average metrics message
            agg_logger_m = ''
            for r in avg_metrics:
                formatted_res = r['name'].upper() + ': {:.6f}, '.format(r['average'])
                agg_logger_m += formatted_res
            logger.info('----Average metrics results for {}----\n\t'.format(name) + agg_logger_m[:-2])
            
            if len(avg_metrics_y) > 0:
                # prepare log average Y channel metrics message
                agg_logger_m = ''
                for r in avg_metrics_y:
                    formatted_res = r['name'].upper() + '_Y' + ': {:.6f}, '.format(r['average'])
                    agg_logger_m += formatted_res
                logger.info('----Y channel, average metrics ----\n\t' + agg_logger_m[:-2])
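The helpers above are typically wired together as below (a sketch; the options path is illustrative and `create_model` comes from the same codebase):

opt = options.parse('options/test.yml', is_train=False)  # illustrative path
opt = options.dict_to_nonedict(opt)
configure_loggers(opt)
dataloaders, data_params = get_dataloaders(opt)
model = create_model(opt)
test_loop(model, opt, dataloaders, data_params)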
Example #6: main (test script)
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options file.')
    opt = options.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'pretrain_model_G'))
    opt = options.dict_to_nonedict(opt)

    util.get_root_logger(None,
                         opt['path']['log'],
                         'test.log',
                         level=logging.INFO,
                         screen=True)
    logger = logging.getLogger('base')
    logger.info(options.dict2str(opt))
    # Create test dataset and dataloader
    test_loaders = []
    znorm = False  # TMP
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)
        # Temporary: turns znorm on for all the datasets. A per-dataset
        # variable will be needed to differentiate each one later in the loop.
        if dataset_opt['znorm'] and not znorm:
            znorm = True

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

            model.feed_data(data, need_HR=need_HR)
            img_path = data['in_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)

            # if znorm is set, the image range is [-1,1]; by default it is [0,1]
            top_img = tensor2np(visuals['top_fake'])  # uint8
            bot_img = tensor2np(visuals['bottom_fake'])  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = os.path.join(dataset_dir, img_name + suffix)
            else:
                save_img_path = os.path.join(dataset_dir, img_name)
            util.save_img(top_img, save_img_path + '_top.png')
            util.save_img(bot_img, save_img_path + '_bot.png')

            #TODO: update to use metrics functions
            # calculate PSNR and SSIM
            if need_HR:
                # if znorm is set, the image range is [-1,1]; by default it is [0,1]
                gt_img = tensor2img(visuals['HR'], denormalize=znorm)  # uint8
                gt_img = gt_img / 255.
                sr_img = top_img / 255.  # assumption: compare the top output against HR

                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                            img_name, psnr, ssim))
            else:
                logger.info(img_name)

        #TODO: update to use metrics functions
        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info(
                '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'
                .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info(
                    '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'
                    .format(ave_psnr_y, ave_ssim_y))
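The PSNR/SSIM block crops `scale` pixels from every border before measuring. That step could be factored into a small helper; a sketch, not part of the original script:

import numpy as np

def crop_border(img: np.ndarray, border: int) -> np.ndarray:
    """Remove `border` pixels from every edge of an HW or HWC image."""
    if border == 0:
        return img
    if img.ndim == 3:
        return img[border:-border, border:-border, :]
    return img[border:-border, border:-border]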
Example #7: fit (training loop)
def fit(model, opt, dataloaders, steps_states, data_params, loggers):
    # read data_params
    batch_size = data_params['batch_size']
    virtual_batch_size = data_params['virtual_batch_size']
    total_iters = data_params['total_iters']
    total_epochs = data_params['total_epochs']

    # read steps_states
    start_epoch = steps_states["start_epoch"]
    current_step = steps_states["current_step"]
    virtual_step = steps_states["virtual_step"]

    # read loggers
    logger = util.get_root_logger()
    tb_logger = loggers["tb_logger"]
    
    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    try:
        timer = metrics.Timer()  # iteration timer
        timerData = metrics.TickTock()  # data timer
        timerEpoch = metrics.TickTock()  # epoch timer
        # outer loop for different epochs
        for epoch in range(start_epoch, (total_epochs * (virtual_batch_size // batch_size))+1):
            timerData.tick()
            timerEpoch.tick()

            # inner iteration loop within one epoch
            for n, train_data in enumerate(dataloaders['train'], start=1):
                timerData.tock()

                virtual_step += 1
                take_step = False
                if virtual_step > 0 and (virtual_step * batch_size) % virtual_batch_size == 0:
                    current_step += 1
                    take_step = True
                    if current_step > total_iters:
                        break

                # training
                model.feed_data(train_data)  # unpack data from dataset and apply preprocessing
                model.optimize_parameters(virtual_step)  # calculate loss functions, get gradients, update network weights

                # log
                def eta(t_iter):
                    # calculate training ETA in hours
                    return (t_iter * (opt['train']['niter'] - current_step)) / 3600 if t_iter > 0 else 0

                if current_step % opt['logger']['print_freq'] == 0 and take_step:
                    # iteration end time
                    avg_time = timer.get_average_and_reset()
                    avg_data_time = timerData.get_average_and_reset()

                    # print training losses and save logging information to disk
                    message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.4f}s, td:{:.4f}s, eta:{:.4f}h> '.format(
                        epoch, current_step, model.get_current_learning_rate(current_step), avg_time, 
                        avg_data_time, eta(avg_time))
                    
                    # tensorboard training logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if current_step % opt['logger'].get('tb_sample_rate', 1) == 0: # Reduce rate of tb logs
                            # tb_logger.add_scalar('loss/nll', nll, current_step)
                            tb_logger.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
                            tb_logger.add_scalar('time/iteration', timer.get_last_iteration(), current_step)
                            tb_logger.add_scalar('time/data', timerData.get_last_iteration(), current_step)

                    logs = model.get_current_log()
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard loss logger
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            if current_step % opt['logger'].get('tb_sample_rate', 1) == 0: # Reduce rate of tb logs
                                tb_logger.add_scalar(k, v, current_step)
                            # tb_logger.flush()
                    logger.info(message)

                    # start time for next iteration #TODO:skip the validation time from calculation
                    timer.tick()

                # update learning rate
                if model.optGstep and model.optDstep and take_step:
                    model.update_learning_rate(current_step, warmup_iter=opt['train'].get('warmup_iter', -1))
                
                # save latest models and training states every <save_checkpoint_freq> iterations
                if current_step % opt['logger']['save_checkpoint_freq'] == 0 and take_step:
                    if model.swa: 
                        model.save(current_step, opt['logger']['overwrite_chkp'], loader=dataloaders['train'])
                    else:
                        model.save(current_step, opt['logger']['overwrite_chkp'])
                    model.save_training_state(
                        epoch=epoch + (n >= len(dataloaders['train'])),  # bump epoch if this was the last batch
                        iter_step=current_step,
                        latest=opt['logger']['overwrite_chkp']
                    )
                    logger.info('Models and training states saved.')

                # validation
                if dataloaders.get('val', None) and current_step % opt['train']['val_freq'] == 0 and take_step:
                    val_metrics = metrics.MetricsDict(metrics=opt['train'].get('metrics', None))
                    nlls = []
                    for val_data in dataloaders['val']:
                        
                        model.feed_data(val_data)  # unpack data from data loader
                        model.test()  # run inference
                        if hasattr(model, 'nll'):
                            nll = model.nll if model.nll else 0
                            nlls.append(nll)

                        """
                        Get Visuals
                        """
                        visuals = model.get_current_visuals()  # get image results
                        img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)
                        
                        # Save SR images for reference
                        sr_img = None
                        if hasattr(model, 'heats'):  # SRFlow
                            opt['train']['val_comparison'] = False
                            for heat in model.heats:
                                for i in range(model.n_sample):
                                    sr_img = tensor2np(visuals['SR', heat, i], denormalize=opt['datasets']['train']['znorm'])
                                    if opt['train']['overwrite_val_imgs']:
                                        save_img_path = os.path.join(img_dir,
                                                                '{:s}_h{:03d}_s{:d}.png'.format(img_name, int(heat * 100), i))
                                    else:
                                        save_img_path = os.path.join(img_dir,
                                                                '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(img_name,
                                                                                                        current_step,
                                                                                                        int(heat * 100), i))
                                    util.save_img(sr_img, save_img_path)
                        else:  # regular SR
                            sr_img = tensor2np(visuals['SR'], denormalize=opt['datasets']['train']['znorm'])
                            if opt['train']['overwrite_val_imgs']:
                                save_img_path = os.path.join(img_dir, '{:s}.png'.format(img_name))
                            else:
                                save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                            if not opt['train']['val_comparison']:
                                util.save_img(sr_img, save_img_path)
                        assert sr_img is not None

                        # Save GT images for reference
                        gt_img = tensor2np(visuals['HR'], denormalize=opt['datasets']['train']['znorm'])
                        if opt['train']['save_gt']:
                            save_img_path_gt = os.path.join(img_dir,
                                                            '{:s}_GT.png'.format(img_name))
                            if not os.path.isfile(save_img_path_gt):
                                util.save_img(gt_img, save_img_path_gt)

                        # Save LQ images for reference
                        if opt['train']['save_lr']:
                            save_img_path_lq = os.path.join(img_dir,
                                                            '{:s}_LQ.png'.format(img_name))
                            if not os.path.isfile(save_img_path_lq):
                                lq_img = tensor2np(visuals['LR'], denormalize=opt['datasets']['train']['znorm'])
                                util.save_img(lq_img, save_img_path_lq, scale=opt['scale'])

                        # save single images or LQ / SR comparison
                        if opt['train']['val_comparison']:
                            lr_img = tensor2np(visuals['LR'], denormalize=opt['datasets']['train']['znorm'])
                            util.save_img_comp([lr_img, sr_img], save_img_path)
                        # else:
                            # util.save_img(sr_img, save_img_path)

                        """
                        Get Metrics
                        # TODO: test using tensor based metrics (batch) instead of numpy.
                        """
                        val_metrics.calculate_metrics(sr_img, gt_img, crop_size=opt['scale'])  # , only_y=True)

                    avg_metrics = val_metrics.get_averages()
                    if nlls:
                        avg_nll = sum(nlls) / len(nlls)
                    del val_metrics

                    # log
                    logger_m = ''
                    for r in avg_metrics:
                        formatted_res = r['name'].upper() + ': {:.5g}, '.format(r['average'])
                        logger_m += formatted_res
                    if nlls:
                        logger_m += 'avg_nll: {:.4e}  '.format(avg_nll)

                    logger.info('# Validation # ' + logger_m[:-2])
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> '.format(epoch, current_step) + logger_m[:-2])
                    # memory_usage = torch.cuda.memory_allocated()/(1024.0 ** 3) # in GB
                    
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        for r in avg_metrics:
                            tb_logger.add_scalar(r['name'], r['average'], current_step)
                            if nlls:
                                tb_logger.add_scalar('average nll', avg_nll, current_step)
                            # tb_logger.flush()
                            # tb_logger_valid.add_scalar(r['name'], r['average'], current_step)
                            # tb_logger_valid.flush()
                    
                timerData.tick()
            
            timerEpoch.tock()
            logger.info('End of epoch {} / {} \t Time Taken: {:.4f} sec'.format(
                epoch, total_epochs, timerEpoch.get_last_iteration()))

        logger.info('Saving the final model.')
        if model.swa:
            model.save('latest', loader=dataloaders['train'])
        else:
            model.save('latest')
        logger.info('End of training.')

    except KeyboardInterrupt:
        # catch a KeyboardInterrupt and save the model and state to resume later
        if model.swa:
            model.save(current_step, True, loader=dataloaders['train'])
        else:
            model.save(current_step, True)
        model.save_training_state(epoch + (n >= len(dataloaders['train'])), current_step, True)
        logger.info('Training interrupted. Latest models and training states saved.')
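The `virtual_batch_size` bookkeeping above implements gradient accumulation: the optimizer steps once per `virtual_batch_size // batch_size` loader batches. A stripped-down sketch of the same idea (all names here are illustrative):

def train_with_accumulation(model, loader, optimizer, compute_loss, accum):
    # optimizer steps once every `accum` loader batches (gradient accumulation)
    for i, batch in enumerate(loader, start=1):
        loss = compute_loss(model, batch) / accum  # scale so gradients average
        loss.backward()
        if i % accum == 0:
            optimizer.step()
            optimizer.zero_grad()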
Example #8: test_loop (SRFlow variant)
def test_loop(model, opt, dataloaders, data_params):
    logger = util.get_root_logger()

    # read data_params
    znorms = data_params['znorm']

    # prepare the metric calculation classes for RGB and Y_only images
    calc_metrics = opt.get('metrics', None)
    if calc_metrics:
        test_metrics = metrics.MetricsDict(metrics=calc_metrics)
        test_metrics_y = metrics.MetricsDict(metrics=calc_metrics)

    for phase, dataloader in dataloaders.items():
        name = dataloader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(name))
        dataset_dir = os.path.join(opt['path']['results_root'], name)
        util.mkdir(dataset_dir)

        nlls = []
        for data in dataloader:
            znorm = znorms[name]
            need_HR = dataloader.dataset.opt['dataroot_HR'] is not None

            # set up per image CEM wrapper if configured
            CEM_net = get_CEM(opt, data)

            model.feed_data(data,
                            need_HR=need_HR)  # unpack data from data loader
            # img_path = data['LR_path'][0]
            img_path = get_img_path(data)
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            # test with eval mode. This only affects layers like batchnorm and dropout.
            test_mode = opt.get('test_mode', None)
            if test_mode == 'x8':
                # geometric self-ensemble
                # model.test_x8(CEM_net=CEM_net)
                break  # mode disabled here; skip the rest of this dataset
            elif test_mode == 'chop':
                # chop images in patches/crops, to reduce VRAM usage
                # model.test_chop(patch_size=opt.get('chop_patch_size', 100),
                #                 step=opt.get('chop_step', 0.9),
                #                 CEM_net=CEM_net)
                break  # mode disabled here; skip the rest of this dataset
            else:
                # normal inference
                model.test(CEM_net=CEM_net)  # run inference

            if hasattr(model, 'nll'):
                nll = model.nll if model.nll else 0
                nlls.append(nll)

            # get image results
            visuals = model.get_current_visuals(need_HR=need_HR)

            res_options = visuals_check(visuals.keys(),
                                        opt.get('val_comparison', None))

            # save images
            save_img_path = os.path.join(dataset_dir,
                                         img_name + opt.get('suffix', ''))

            # Save SR images for reference
            sr_img = None
            if hasattr(model, 'heats'):  # SRFlow
                opt['val_comparison'] = False
                for heat in model.heats:
                    for i in range(model.n_sample):
                        for save_img_name in res_options['save_imgs']:
                            imn = '_' + save_img_name if len(
                                res_options['save_imgs']) > 1 else ''
                            imn += '_h{:03d}_s{:d}'.format(int(heat * 100), i)
                            util.save_img(
                                tensor2np(visuals[save_img_name, heat, i],
                                          denormalize=znorm),
                                save_img_path + imn + '.png')
            else:  # regular SR
                if not opt['val_comparison']:
                    for save_img_name in res_options['save_imgs']:
                        imn = '_' + save_img_name if len(
                            res_options['save_imgs']) > 1 else ''
                        util.save_img(
                            tensor2np(visuals[save_img_name],
                                      denormalize=znorm),
                            save_img_path + imn + '.png')

            # save single images or lr / sr comparison
            if opt['val_comparison'] and len(res_options['save_imgs']) > 1:
                comp_images = [
                    tensor2np(visuals[save_img_name], denormalize=znorm)
                    for save_img_name in res_options['save_imgs']
                ]
                util.save_img_comp(comp_images, save_img_path + '.png')
            # else:
            # util.save_img(sr_img, save_img_path)

            # calculate metrics if HR dataset is provided and metrics are configured in options
            if need_HR and calc_metrics and res_options['aligned_metrics']:
                metric_imgs = [
                    tensor2np(visuals[x], denormalize=znorm)
                    for x in res_options['compare_imgs']
                ]
                test_results = test_metrics.calculate_metrics(
                    metric_imgs[0], metric_imgs[1], crop_size=opt['scale'])

                # prepare single image metrics log message
                logger_m = '{:20s} -'.format(img_name)
                for k, v in test_results:
                    formatted_res = k.upper() + ': {:.6f}, '.format(v)
                    logger_m += formatted_res

                if metric_imgs[1].shape[2] == 3:  # RGB image, calculate y_only metrics
                    test_results_y = test_metrics_y.calculate_metrics(
                        metric_imgs[0],
                        metric_imgs[1],
                        crop_size=opt['scale'],
                        only_y=True)

                    # add the y only results to the single image log message
                    for k, v in test_results_y:
                        formatted_res = k.upper() + ': {:.6f}, '.format(v)
                        logger_m += formatted_res

                logger.info(logger_m)
            else:
                logger.info(img_name)

        # average metrics results for the dataset
        if need_HR and calc_metrics:

            # aggregate the metrics results (automatically resets the metric classes)
            avg_metrics = test_metrics.get_averages()
            avg_metrics_y = test_metrics_y.get_averages()

            # prepare log average metrics message
            agg_logger_m = ''
            for r in avg_metrics:
                formatted_res = r['name'].upper() + ': {:.6f}, '.format(
                    r['average'])
                agg_logger_m += formatted_res
            logger.info(
                '----Average metrics results for {}----\n\t'.format(name) +
                agg_logger_m[:-2])

            if len(avg_metrics_y) > 0:
                # prepare log average Y channel metrics message
                agg_logger_m = ''
                for r in avg_metrics_y:
                    formatted_res = r['name'].upper() + '_Y' + ': {:.6f}, '.format(r['average'])
                    agg_logger_m += formatted_res
                logger.info('----Y channel, average metrics ----\n\t' +
                            agg_logger_m[:-2])
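For SRFlow-style models the loop saves one image per (heat, sample) pair. An illustration of the resulting naming scheme (heat values and sample count are illustrative):

for heat in (0.0, 0.5, 0.75):      # illustrative heat values
    for i in range(2):             # illustrative n_sample
        print('img001_h{:03d}_s{:d}.png'.format(int(heat * 100), i))
# img001_h000_s0.png, img001_h000_s1.png, ..., img001_h075_s1.png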
Example #9: main (video SR test script)
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to options file.')
    opt = options.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items() if not key == 'pretrain_model_G'))

    util.get_root_logger(None, opt['path']['log'], 'test.log', level=logging.INFO, screen=True)
    logger = logging.getLogger('base')
    logger.info(options.dict2str(opt))

    scale = opt.get('scale', 4)

    # Create test dataset and dataloader
    test_loaders = []
    znorm = False  # TMP
    # znorm_list = []

    '''
    video_list = os.listdir(cfg.testset_dir)
    for idx_video in range(len(video_list)):
        video_name = video_list[idx_video]
        # dataloader
        test_set = TestsetLoader(cfg, video_name)
        test_loader = DataLoader(test_set, num_workers=1, batch_size=1, shuffle=False)
    '''

    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)
        # Temporary: turns znorm on for all the datasets. A per-dataset
        # variable will be needed to differentiate each one later in the loop.
        # if dataset_opt.get('znorm') and not znorm:
        #     znorm = True
        znorm = dataset_opt.get('znorm', False)
        # znorm_list.append(znorm)

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            # tmp_vis(data['LR'][:,1,:,:,:], True)

            if opt.get('chop_forward', None):
                # data
                if len(data['LR'].size()) == 4:
                    b, n_frames, h_lr, w_lr = data['LR'].size()
                    LR_y_cube = data['LR'].view(b, -1, 1, h_lr, w_lr)  # b, t, c, h, w
                elif len(data['LR'].size()) == 5:  # for networks that work with 3 channel images
                    _, n_frames, _, _, _ = data['LR'].size()
                    LR_y_cube = data['LR']  # b, t, c, h, w

                # print(LR_y_cube.shape)
                # print(data['LR_bicubic'].shape)

                # crop borders to ensure each patch can be divisible by 2
                # TODO: this is modcrop, not sure if really needed, check (the dataloader already does modcrop)
                _, _, _, h, w = LR_y_cube.size()
                h = (h // 16) * 16
                w = (w // 16) * 16
                LR_y_cube = LR_y_cube[:, :, :, :h, :w]
                if isinstance(data['LR_bicubic'], torch.Tensor):
                    # SR_cb = data['LR_bicubic'][:, 1, :, :][:, :, :h * scale, :w * scale]
                    SR_cb = data['LR_bicubic'][:, 1, :h * scale, :w * scale]
                    # SR_cr = data['LR_bicubic'][:, 2, :, :][:, :, :h * scale, :w * scale]
                    SR_cr = data['LR_bicubic'][:, 2, :h * scale, :w * scale]

                SR_y = chop_forward(LR_y_cube, model, scale, need_HR=need_HR).squeeze(0)
                # SR_y = np.array(SR_y.data.cpu())
                if test_loader.dataset.opt.get('srcolors', None):
                    print(SR_y.shape, SR_cb.shape, SR_cr.shape)
                    sr_img = ycbcr_to_rgb(torch.stack((SR_y, SR_cb, SR_cr), -3))
                else:
                    sr_img = SR_y
            else:
                # data
                model.feed_data(data, need_HR=need_HR)
                # SR_y = net(LR_y_cube).squeeze(0)
                model.test()  # test
                visuals = model.get_current_visuals(need_HR=need_HR)
                # ds = torch.nn.AvgPool2d(2, stride=2, count_include_pad=False)
                # tmp_vis(ds(visuals['SR']), True)
                # tmp_vis(visuals['SR'], True)
                if test_loader.dataset.opt.get('y_only', None) and test_loader.dataset.opt.get('srcolors', None):
                    SR_cb = data['LR_bicubic'][:, 1, :, :]
                    SR_cr = data['LR_bicubic'][:, 2, :, :]
                    # tmp_vis(ds(SR_cb), True)
                    # tmp_vis(ds(SR_cr), True)
                    sr_img = ycbcr_to_rgb(torch.stack((visuals['SR'], SR_cb, SR_cr), -3))
                else:
                    sr_img = visuals['SR']

            # if znorm is set, the image range is [-1,1]; by default it is [0,1]
            sr_img = tensor2np(sr_img, denormalize=znorm)  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = os.path.join(dataset_dir, img_name + suffix + '.png')
            else:
                save_img_path = os.path.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            # TODO: update to use metrics functions
            # calculate PSNR and SSIM
            if need_HR:
                # if znorm is set, the image range is [-1,1]; by default it is [0,1]
                gt_img = tensor2img(visuals['HR'], denormalize=znorm)  # uint8
                gt_img = gt_img / 255.
                sr_img = sr_img / 255.

                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.' \
                                .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
            else:
                logger.info(img_name)

        # TODO: update to use metrics functions
        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n' \
                        .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n' \
                            .format(ave_psnr_y, ave_ssim_y))
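In `y_only`/`srcolors` mode this script super-resolves only the Y channel and reuses the bicubically upscaled Cb/Cr. A compact sketch of that recombination (a hypothetical helper; tensors follow the (b, c, h, w) convention and `ycbcr_to_rgb` is the conversion used above):

import torch

def combine_y_with_bicubic_chroma(sr_y, bicubic_ycbcr):
    # sr_y: (b, 1, H, W) super-resolved luma
    # bicubic_ycbcr: (b, 3, H, W) bicubic upscale in YCbCr
    cb = bicubic_ycbcr[:, 1:2]
    cr = bicubic_ycbcr[:, 2:3]
    return ycbcr_to_rgb(torch.cat([sr_y, cb, cr], dim=1))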