Example #1
def tmp_vis(img_t, to_np=True, rgb2bgr=True, remove_batch=False, save_dir=''):
    '''
        Visualization function that can be inserted at any point
        in the code; works with tensor or numpy images.
        img_t: image to display
        save_dir: optional path to also save the image
    '''
    import cv2
    from dataops.common import tensor2np

    if to_np:
        img = tensor2np(img_t.detach(),
                        rgb2bgr=rgb2bgr,
                        remove_batch=remove_batch)
    else:
        img = img_t
    print("out: ", img.shape)

    cv2.imshow('image', img)
    cv2.waitKey(0)

    if save_dir != '':
        cv2.imwrite(save_dir, img)

    cv2.destroyAllWindows()

    return None
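
A minimal usage sketch for tmp_vis (the tensor here is illustrative; assumes a 4D [B, C, H, W] float tensor in [0, 1] and an environment where cv2.imshow can open a window):

import torch
fake = torch.rand(1, 3, 64, 64)  # illustrative [B, C, H, W] tensor
tmp_vis(fake, to_np=True, rgb2bgr=True, remove_batch=True)
# pass save_dir='out.png' to also write the image to disk
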
Example #2
def save_image(image=None, num_rep=0, suffix=None, random_name=False):
    '''
        Output images to a directory instead of visualizing them;
        may be easier to compare multiple batches of images.
    '''
    import uuid
    import cv2
    from dataops.common import tensor2np
    img = tensor2np(image, remove_batch=False)  # uint8

    if random_name:
        # add a random hex string to the name so repeated saves don't overwrite
        hex_name = uuid.uuid4().hex
        cv2.imwrite(
            "D:/tmp_test/fake_" + suffix + "_" + str(num_rep) + hex_name + ".png",
            img)
    else:
        cv2.imwrite("D:/tmp_test/fake_" + suffix + "_" + str(num_rep) + ".png",
                    img)

    return None
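
A hedged usage sketch for save_image (the output directory D:/tmp_test/ is hard-coded above and must already exist; the tensor is illustrative):

import torch
batch = torch.rand(1, 3, 32, 32)  # illustrative [B, C, H, W] tensor
save_image(image=batch, num_rep=0, suffix='test', random_name=True)
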
Example #3
def get_sp_transform(train_opt: dict, znorm: bool = True):
    n_segments = train_opt.get('sp_n_segments', 200)  # 500
    max_size = train_opt.get('sp_max_size', None)  # crop_size
    # 'selective' 'cluster' 'rag' None
    reduction = train_opt.get('sp_reduction', 'selective')
    # 'seeds', 'slic', 'slico', 'mslic', 'sk_slic', 'sk_felzenszwalb'
    algo = train_opt.get('sp_algo', 'sk_felzenszwalb')
    gamma_range = train_opt.get('sp_gamma_range', (100, 120))

    superpixel_fn = transforms.Compose([
        transforms.Lambda(lambda img: tensor2np(
            img, rgb2bgr=True, denormalize=znorm, remove_batch=False)),
        transforms.Superpixels(p_replace=1,
                               n_segments=n_segments,
                               algo=algo,
                               reduction=reduction,
                               max_size=max_size,
                               p=1),
        transforms.RandomGamma(gamma_range=gamma_range, gain=1, p=1),
        transforms.Lambda(lambda img: np2tensor(
            img, bgr2rgb=True, normalize=znorm, add_batch=False))
    ])
    return superpixel_fn
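
A short usage sketch for get_sp_transform, assuming the transforms module above (with Superpixels and RandomGamma) and the np2tensor/tensor2np helpers are importable; the options dict is illustrative:

import torch
train_opt = {'sp_n_segments': 150, 'sp_algo': 'sk_slic'}  # illustrative options
sp_fn = get_sp_transform(train_opt, znorm=False)
img = torch.rand(1, 3, 64, 64)  # [B, C, H, W] tensor in [0, 1]
sp_img = sp_fn(img)  # superpixel-segmented version of the input
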
Example #4
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True,
                        help='Path to options file.')
    opt = options.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs(
        (path for key, path in opt['path'].items() if not key == 'pretrain_model_G'))
    opt = options.dict_to_nonedict(opt)

    util.setup_logger(None, opt['path']['log'],
                      'test.log', level=logging.INFO, screen=True)
    logger = logging.getLogger('base')
    logger.info(options.dict2str(opt))
    # Create test dataset and dataloader
    test_loaders = []
    znorm = False  # TMP
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)
        # Temporary: turns znorm on if any dataset uses it. A per-dataset
        # variable will be needed to differentiate them later in the loop.
        if dataset_opt['znorm'] and not znorm:
            znorm = True

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

            model.feed_data(data, need_HR=need_HR)
            img_path = data['in_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)

            # if znorm is enabled the image range is [-1, 1]; default range is [0, 1]
            # (each test "dataset" can have any name, not just train/val)
            top_img = tensor2np(visuals['top_fake'])  # uint8
            bot_img = tensor2np(visuals['bottom_fake'])  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = os.path.join(
                    dataset_dir, img_name + suffix)
            else:
                save_img_path = os.path.join(dataset_dir, img_name)
            util.save_img(top_img, save_img_path + '_top.png')
            util.save_img(bot_img, save_img_path + '_bot.png')


            #TODO: update to use metrics functions
            # calculate PSNR and SSIM
            if need_HR:
                # if znorm is enabled the image range is [-1, 1]; default range is [0, 1]
                gt_img = tensor2img(visuals['HR'], denormalize=znorm)  # uint8
                # sr_img is not defined in this snippet; assuming the top output
                # is the one compared against HR
                sr_img = top_img
                gt_img = gt_img / 255.
                sr_img = sr_img / 255.

                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

                psnr = util.calculate_psnr(
                    cropped_sr_img * 255, cropped_gt_img * 255)
                ssim = util.calculate_ssim(
                    cropped_sr_img * 255, cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(
                        cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(
                        cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                                .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
            else:
                logger.info(img_name)

        #TODO: update to use metrics functions
        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'
                        .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(
                    test_results['psnr_y']) / len(test_results['psnr_y'])
                ave_ssim_y = sum(
                    test_results['ssim_y']) / len(test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'
                            .format(ave_psnr_y, ave_ssim_y))
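
The metric block above crops scale pixels from each border before computing PSNR/SSIM, so boundary artifacts do not skew the scores. A standalone numpy sketch of that PSNR computation (not the util implementation):

import numpy as np

def crop_border_psnr(sr_img, gt_img, crop_border):
    # drop crop_border pixels on every side, then compute PSNR on uint8-range data
    sr = sr_img[crop_border:-crop_border, crop_border:-crop_border, ...].astype(np.float64)
    gt = gt_img[crop_border:-crop_border, crop_border:-crop_border, ...].astype(np.float64)
    mse = np.mean((sr - gt) ** 2)
    return float('inf') if mse == 0 else 10 * np.log10(255.0 ** 2 / mse)
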
Example #5
def test_loop(model, opt, dataloaders, data_params):
    logger = util.get_root_logger()

    # read data_params
    znorms = data_params['znorm']

    # prepare the metric calculation classes for RGB and Y_only images
    calc_metrics = opt.get('metrics', None)
    if calc_metrics:
        test_metrics = metrics.MetricsDict(metrics=calc_metrics)
        test_metrics_y = metrics.MetricsDict(metrics=calc_metrics)

    for phase, dataloader in dataloaders.items():
        name = dataloader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(name))
        dataset_dir = os.path.join(opt['path']['results_root'], name)
        util.mkdir(dataset_dir)

        for data in dataloader:
            znorm = znorms[name]
            need_HR = dataloader.dataset.opt['dataroot_HR'] is not None

            # set up per image CEM wrapper if configured
            CEM_net = get_CEM(opt, data)

            model.feed_data(data, need_HR=need_HR)  # unpack data from data loader
            # img_path = data['LR_path'][0]
            img_path = get_img_path(data)
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            # test with eval mode. This only affects layers like batchnorm and dropout.
            test_mode = opt.get('test_mode', None)
            if test_mode == 'x8':
                # geometric self-ensemble
                model.test_x8(CEM_net=CEM_net)
            elif test_mode == 'chop':
                # chop images in patches/crops, to reduce VRAM usage
                model.test_chop(patch_size=opt.get('chop_patch_size', 100), 
                                step=opt.get('chop_step', 0.9),
                                CEM_net=CEM_net)
            else:
                # normal inference
                model.test(CEM_net=CEM_net)  # run inference
            
            # get image results
            visuals = model.get_current_visuals(need_HR=need_HR)

            # post-process options if using CEM
            if opt.get('use_cem', None) and opt['cem_config'].get('out_orig', False):
                # run regular inference
                if test_mode == 'x8':
                    model.test_x8()
                elif test_mode == 'chop':
                    model.test_chop(patch_size=opt.get('chop_patch_size', 100), 
                                    step=opt.get('chop_step', 0.9))
                else:
                    model.test()
                orig_visuals = model.get_current_visuals(need_HR=need_HR)

                if opt['cem_config'].get('out_filter', False):
                    GF = GuidedFilter(ks=opt['cem_config'].get('out_filter_ks', 7))
                    filt = GF(visuals['SR'].unsqueeze(0), (visuals['SR']-orig_visuals['SR']).unsqueeze(0)).squeeze(0)
                    visuals['SR'] = orig_visuals['SR']+filt

                if opt['cem_config'].get('out_keepY', False):
                    out_regY = rgb_to_ycbcr(orig_visuals['SR']).unsqueeze(0)
                    out_cemY = rgb_to_ycbcr(visuals['SR']).unsqueeze(0)
                    visuals['SR'] = ycbcr_to_rgb(torch.cat([out_regY[:, 0:1, :, :], out_cemY[:, 1:2, :, :], out_cemY[:, 2:3, :, :]], 1)).squeeze(0)

            res_options = visuals_check(visuals.keys(), opt.get('val_comparison', None))

            # save images
            save_img_path = os.path.join(dataset_dir, img_name + opt.get('suffix', ''))

            # save single images or lr / sr comparison
            if opt['val_comparison'] and len(res_options['save_imgs']) > 1:
                comp_images = [tensor2np(visuals[save_img_name], denormalize=znorm) for save_img_name in res_options['save_imgs']]
                util.save_img_comp(comp_images, save_img_path + '.png')
            else:
                for save_img_name in res_options['save_imgs']:
                    imn = '_' + save_img_name if len(res_options['save_imgs']) > 1 else ''
                    util.save_img(tensor2np(visuals[save_img_name], denormalize=znorm), save_img_path + imn + '.png')

            # calculate metrics if HR dataset is provided and metrics are configured in options
            if need_HR and calc_metrics and res_options['aligned_metrics']:
                metric_imgs = [tensor2np(visuals[x], denormalize=znorm) for x in res_options['compare_imgs']]
                test_results = test_metrics.calculate_metrics(metric_imgs[0], metric_imgs[1], 
                                                              crop_size=opt['scale'])
                
                # prepare single image metrics log message
                logger_m = '{:20s} -'.format(img_name)
                for k, v in test_results.items():  # assuming a dict of metric name -> value
                    formatted_res = k.upper() + ': {:.6f}, '.format(v)
                    logger_m += formatted_res

                if metric_imgs[1].shape[2] == 3:  # RGB image, calculate y_only metrics
                    test_results_y = test_metrics_y.calculate_metrics(metric_imgs[0], metric_imgs[1], 
                                                                      crop_size=opt['scale'], only_y=True)
                    
                    # add the y only results to the single image log message
                    for k, v in test_results_y.items():  # assuming a dict of metric name -> value
                        formatted_res = k.upper() + ': {:.6f}, '.format(v)
                        logger_m += formatted_res
                
                logger.info(logger_m)
            else:
                logger.info(img_name)

        # average metrics results for the dataset
        if need_HR and calc_metrics:
            
            # aggregate the metrics results (automatically resets the metric classes)
            avg_metrics = test_metrics.get_averages()
            avg_metrics_y = test_metrics_y.get_averages()

            # prepare log average metrics message
            agg_logger_m = ''
            for r in avg_metrics:
                formatted_res = r['name'].upper() + ': {:.6f}, '.format(r['average'])
                agg_logger_m += formatted_res
            logger.info('----Average metrics results for {}----\n\t'.format(name) + agg_logger_m[:-2])
            
            if len(avg_metrics_y) > 0:
                # prepare log average Y channel metrics message
                agg_logger_m = ''
                for r in avg_metrics_y:
                    formatted_res = r['name'].upper() + '_Y' + ': {:.6f}, '.format(r['average'])
                    agg_logger_m += formatted_res
                logger.info('----Y channel, average metrics ----\n\t' + agg_logger_m[:-2])
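
The out_keepY branch above keeps the luma (Y) of the regular output and the chroma of the CEM output. A self-contained torch sketch of that channel swap, assuming full-range BT.601 conversion (the codebase's rgb_to_ycbcr/ycbcr_to_rgb may differ in details):

import torch

def keep_y_swap(reg_rgb, cem_rgb):
    # reg_rgb, cem_rgb: [B, 3, H, W] tensors in [0, 1]
    def rgb2ycbcr(x):
        r, g, b = x[:, 0:1], x[:, 1:2], x[:, 2:3]
        y = 0.299 * r + 0.587 * g + 0.114 * b
        cb = 0.5 + 0.564 * (b - y)
        cr = 0.5 + 0.713 * (r - y)
        return torch.cat([y, cb, cr], 1)
    def ycbcr2rgb(x):
        y, cb, cr = x[:, 0:1], x[:, 1:2] - 0.5, x[:, 2:3] - 0.5
        r = y + 1.403 * cr
        g = y - 0.344 * cb - 0.714 * cr
        b = y + 1.773 * cb
        return torch.cat([r, g, b], 1)
    reg, cem = rgb2ycbcr(reg_rgb), rgb2ycbcr(cem_rgb)
    # luma from the regular pass, chroma from the CEM pass
    return ycbcr2rgb(torch.cat([reg[:, 0:1], cem[:, 1:2], cem[:, 2:3]], 1))
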
Example #6
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None,
                      opt['path']['log'],
                      'test.log',
                      level=logging.INFO,
                      screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    scale = opt.get('scale', 4)

    # Create test dataset and dataloader
    test_loaders = []
    znorm = False  #TMP
    # znorm_list = []
    '''
    video_list = os.listdir(cfg.testset_dir)
    for idx_video in range(len(video_list)):
        video_name = video_list[idx_video]
        # dataloader
        test_set = TestsetLoader(cfg, video_name)
        test_loader = DataLoader(test_set, num_workers=1, batch_size=1, shuffle=False)
    '''

    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)
        # Temporary: this keeps only the last dataset's znorm setting. A per-dataset
        # variable will be needed to differentiate them later in the loop.
        znorm = dataset_opt.get('znorm', False)
        # znorm_list.append(znorm)

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            # tmp_vis(data['LR'][:,1,:,:,:], True)

            if opt.get('chop_forward', None):
                # data
                if len(data['LR'].size()) == 4:
                    b, n_frames, h_lr, w_lr = data['LR'].size()
                    LR_y_cube = data['LR'].view(b, -1, 1, h_lr, w_lr)  # b, t, c, h, w
                elif len(data['LR'].size()) == 5:
                    # for networks that work with 3-channel images
                    _, n_frames, _, _, _ = data['LR'].size()
                    LR_y_cube = data['LR']  # b, t, c, h, w

                # print(LR_y_cube.shape)
                # print(data['LR_bicubic'].shape)

                # crop borders to ensure each patch can be divisible by 2
                #TODO: this is modcrop, not sure if really needed, check (the dataloader already does modcrop)
                _, _, _, h, w = LR_y_cube.size()
                h = (h // 16) * 16
                w = (w // 16) * 16
                LR_y_cube = LR_y_cube[:, :, :, :h, :w]
                if isinstance(data['LR_bicubic'], torch.Tensor):
                    # SR_cb = data['LR_bicubic'][:, 1, :, :][:, :, :h * scale, :w * scale]
                    SR_cb = data['LR_bicubic'][:, 1, :h * scale, :w * scale]
                    # SR_cr = data['LR_bicubic'][:, 2, :, :][:, :, :h * scale, :w * scale]
                    SR_cr = data['LR_bicubic'][:, 2, :h * scale, :w * scale]

                SR_y = chop_forward(LR_y_cube, model, scale,
                                    need_HR=need_HR).squeeze(0)
                # SR_y = np.array(SR_y.data.cpu())
                if test_loader.dataset.opt.get('srcolors', None):
                    print(SR_y.shape, SR_cb.shape, SR_cr.shape)
                    sr_img = ycbcr_to_rgb(torch.stack((SR_y, SR_cb, SR_cr),
                                                      -3))
                else:
                    sr_img = SR_y
            else:
                # data
                model.feed_data(data, need_HR=need_HR)
                # SR_y = net(LR_y_cube).squeeze(0)
                model.test()  # test
                visuals = model.get_current_visuals(need_HR=need_HR)
                # ds = torch.nn.AvgPool2d(2, stride=2, count_include_pad=False)
                # tmp_vis(ds(visuals['SR']), True)
                # tmp_vis(visuals['SR'], True)
                if test_loader.dataset.opt.get(
                        'y_only', None) and test_loader.dataset.opt.get(
                            'srcolors', None):
                    SR_cb = data['LR_bicubic'][:, 1, :, :]
                    SR_cr = data['LR_bicubic'][:, 2, :, :]
                    # tmp_vis(ds(SR_cb), True)
                    # tmp_vis(ds(SR_cr), True)
                    sr_img = ycbcr_to_rgb(
                        torch.stack((visuals['SR'], SR_cb, SR_cr), -3))
                else:
                    sr_img = visuals['SR']

            # if znorm is enabled the image range is [-1, 1]; default range is [0, 1]
            sr_img = tensor2np(sr_img, denormalize=znorm)  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = os.path.join(dataset_dir,
                                             img_name + suffix + '.png')
            else:
                save_img_path = os.path.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            #TODO: update to use metrics functions
            # calculate PSNR and SSIM
            if need_HR:
                # if znorm is enabled the image range is [-1, 1]; default range is [0, 1]
                # note: `visuals` is only set in the normal inference branch above,
                # so HR metrics are not available after chop_forward
                gt_img = tensor2img(visuals['HR'], denormalize=znorm)  # uint8
                gt_img = gt_img / 255.
                sr_img = sr_img / 255.

                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'\
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                            img_name, psnr, ssim))
            else:
                logger.info(img_name)

        #TODO: update to use metrics functions
        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'\
                    .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'\
                    .format(ave_psnr_y, ave_ssim_y))
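
PSNR_Y/SSIM_Y above compare only the luma channel, which tracks perceived quality better than per-channel RGB metrics. A minimal sketch of the only_y conversion under the usual BT.601 assumption (the codebase's bgr2ycbcr may differ in details):

import numpy as np

def bgr_to_y(img):
    # img: float BGR image in [0, 1]; returns the Y (luma) channel in [0, 1]
    b, g, r = img[..., 0], img[..., 1], img[..., 2]
    return (24.966 * b + 128.553 * g + 65.481 * r) / 255.0 + 16.0 / 255.0
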
Example #7
def fit(model, opt, dataloaders, steps_states, data_params, loggers):
    # read data_params
    batch_size = data_params['batch_size']
    virtual_batch_size = data_params['virtual_batch_size']
    total_iters = data_params['total_iters']
    total_epochs = data_params['total_epochs']

    # read steps_states
    start_epoch = steps_states["start_epoch"]
    current_step = steps_states["current_step"]
    virtual_step = steps_states["virtual_step"]

    # read loggers
    logger = util.get_root_logger()
    tb_logger = loggers["tb_logger"]
    
    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    try:
        timer = metrics.Timer()  # iteration timer
        timerData = metrics.TickTock()  # data timer
        timerEpoch = metrics.TickTock()  # epoch timer
        # outer loop for different epochs
        for epoch in range(start_epoch, (total_epochs * (virtual_batch_size // batch_size))+1):
            timerData.tick()
            timerEpoch.tick()

            # inner iteration loop within one epoch
            for n, train_data in enumerate(dataloaders['train'], start=1):
                timerData.tock()

                virtual_step += 1
                take_step = False
                if virtual_step > 0 and virtual_step * batch_size % virtual_batch_size == 0:
                    current_step += 1
                    take_step = True
                    if current_step > total_iters:
                        break

                # training
                model.feed_data(train_data)  # unpack data from dataset and apply preprocessing
                model.optimize_parameters(virtual_step)  # calculate loss functions, get gradients, update network weights

                # log
                def eta(t_iter):
                    # calculate training ETA in hours
                    return (t_iter * (opt['train']['niter'] - current_step)) / 3600 if t_iter > 0 else 0

                if current_step % opt['logger']['print_freq'] == 0 and take_step:
                    # iteration end time
                    avg_time = timer.get_average_and_reset()
                    avg_data_time = timerData.get_average_and_reset()

                    # print training losses and save logging information to disk
                    message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.4f}s, td:{:.4f}s, eta:{:.4f}h> '.format(
                        epoch, current_step, model.get_current_learning_rate(current_step), avg_time, 
                        avg_data_time, eta(avg_time))
                    
                    # tensorboard training logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if current_step % opt['logger'].get('tb_sample_rate', 1) == 0: # Reduce rate of tb logs
                            # tb_logger.add_scalar('loss/nll', nll, current_step)
                            tb_logger.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
                            tb_logger.add_scalar('time/iteration', timer.get_last_iteration(), current_step)
                            tb_logger.add_scalar('time/data', timerData.get_last_iteration(), current_step)

                    logs = model.get_current_log()
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard loss logger
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            if current_step % opt['logger'].get('tb_sample_rate', 1) == 0: # Reduce rate of tb logs
                                tb_logger.add_scalar(k, v, current_step)
                            # tb_logger.flush()
                    logger.info(message)

                    # start time for next iteration #TODO:skip the validation time from calculation
                    timer.tick()

                # update learning rate
                if model.optGstep and model.optDstep and take_step:
                    model.update_learning_rate(current_step, warmup_iter=opt['train'].get('warmup_iter', -1))
                
                # save latest models and training states every <save_checkpoint_freq> iterations
                if current_step % opt['logger']['save_checkpoint_freq'] == 0 and take_step:
                    if model.swa: 
                        model.save(current_step, opt['logger']['overwrite_chkp'], loader=dataloaders['train'])
                    else:
                        model.save(current_step, opt['logger']['overwrite_chkp'])
                    model.save_training_state(
                        epoch=epoch + (n >= len(dataloaders['train'])),
                        iter_step=current_step,
                        latest=opt['logger']['overwrite_chkp']
                    )
                    logger.info('Models and training states saved.')

                # validation
                if dataloaders.get('val', None) and current_step % opt['train']['val_freq'] == 0 and take_step:
                    val_metrics = metrics.MetricsDict(metrics=opt['train'].get('metrics', None))
                    nlls = []
                    for val_data in dataloaders['val']:
                        
                        model.feed_data(val_data)  # unpack data from data loader
                        model.test()  # run inference
                        if hasattr(model, 'nll'):
                            nll = model.nll if model.nll else 0
                            nlls.append(nll)

                        """
                        Get Visuals
                        """
                        visuals = model.get_current_visuals()  # get image results
                        img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)
                        
                        # Save SR images for reference
                        sr_img = None
                        if hasattr(model, 'heats'):  # SRFlow
                            opt['train']['val_comparison'] = False
                            for heat in model.heats:
                                for i in range(model.n_sample):
                                    sr_img = tensor2np(visuals['SR', heat, i], denormalize=opt['datasets']['train']['znorm'])
                                    if opt['train']['overwrite_val_imgs']:
                                        save_img_path = os.path.join(img_dir,
                                                                '{:s}_h{:03d}_s{:d}.png'.format(img_name, int(heat * 100), i))
                                    else:
                                        save_img_path = os.path.join(img_dir,
                                                                '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(img_name,
                                                                                                        current_step,
                                                                                                        int(heat * 100), i))
                                    util.save_img(sr_img, save_img_path)
                        else:  # regular SR
                            sr_img = tensor2np(visuals['SR'], denormalize=opt['datasets']['train']['znorm'])
                            if opt['train']['overwrite_val_imgs']:
                                save_img_path = os.path.join(img_dir, '{:s}.png'.format(img_name))
                            else:
                                save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                            if not opt['train']['val_comparison']:
                                util.save_img(sr_img, save_img_path)
                        assert sr_img is not None

                        # Save GT images for reference
                        gt_img = tensor2np(visuals['HR'], denormalize=opt['datasets']['train']['znorm'])
                        if opt['train']['save_gt']:
                            save_img_path_gt = os.path.join(img_dir,
                                                            '{:s}_GT.png'.format(img_name))
                            if not os.path.isfile(save_img_path_gt):
                                util.save_img(gt_img, save_img_path_gt)

                        # Save LQ images for reference
                        if opt['train']['save_lr']:
                            save_img_path_lq = os.path.join(img_dir,
                                                            '{:s}_LQ.png'.format(img_name))
                            if not os.path.isfile(save_img_path_lq):
                                lq_img = tensor2np(visuals['LR'], denormalize=opt['datasets']['train']['znorm'])
                                util.save_img(lq_img, save_img_path_lq, scale=opt['scale'])

                        # save single images or LQ / SR comparison
                        if opt['train']['val_comparison']:
                            lr_img = tensor2np(visuals['LR'], denormalize=opt['datasets']['train']['znorm'])
                            util.save_img_comp([lr_img, sr_img], save_img_path)
                        # else:
                            # util.save_img(sr_img, save_img_path)

                        """
                        Get Metrics
                        # TODO: test using tensor based metrics (batch) instead of numpy.
                        """
                        val_metrics.calculate_metrics(sr_img, gt_img, crop_size=opt['scale'])  # , only_y=True)

                    avg_metrics = val_metrics.get_averages()
                    if nlls:
                        avg_nll = sum(nlls) / len(nlls)
                    del val_metrics

                    # log
                    logger_m = ''
                    for r in avg_metrics:
                        formatted_res = r['name'].upper() + ': {:.5g}, '.format(r['average'])
                        logger_m += formatted_res
                    if nlls:
                        logger_m += 'avg_nll: {:.4e}  '.format(avg_nll)

                    logger.info('# Validation # ' + logger_m[:-2])
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> '.format(epoch, current_step) + logger_m[:-2])
                    # memory_usage = torch.cuda.memory_allocated()/(1024.0 ** 3) # in GB
                    
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        for r in avg_metrics:
                            tb_logger.add_scalar(r['name'], r['average'], current_step)
                        # log the average nll once per validation run
                        if nlls:
                            tb_logger.add_scalar('average nll', avg_nll, current_step)
                        # tb_logger.flush()
                        # tb_logger_valid.add_scalar(r['name'], r['average'], current_step)
                        # tb_logger_valid.flush()
                    
                timerData.tick()
            
            timerEpoch.tock()
            logger.info('End of epoch {} / {} \t Time Taken: {:.4f} sec'.format(
                epoch, total_epochs, timerEpoch.get_last_iteration()))

        logger.info('Saving the final model.')
        if model.swa:
            model.save('latest', loader=dataloaders['train'])
        else:
            model.save('latest')
        logger.info('End of training.')

    except KeyboardInterrupt:
        # catch a KeyboardInterrupt and save the model and state to resume later
        if model.swa:
            model.save(current_step, True, loader=dataloaders['train'])
        else:
            model.save(current_step, True)
        model.save_training_state(epoch + (n >= len(dataloaders['train'])), current_step, True)
        logger.info('Training interrupted. Latest models and training states saved.')
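
The virtual_step/take_step bookkeeping above implements gradient accumulation: an optimizer step only counts once enough micro-batches have been seen. A tiny standalone sketch of that counter (the numbers are illustrative):

batch_size, virtual_batch_size = 4, 16  # accumulate 4 micro-batches per step
virtual_step, current_step = 0, 0
for _ in range(8):  # 8 micro-batches
    virtual_step += 1
    take_step = (virtual_step * batch_size) % virtual_batch_size == 0
    if take_step:
        current_step += 1
# after the loop: virtual_step == 8, current_step == 2
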
Example #8
def tmp_vis(img_t,
            to_np=True,
            rgb2bgr=True,
            remove_batch=False,
            denormalize=False,
            save_dir='',
            tensor_shape='TCHW'):
    '''
        Visualization function that can be inserted at any point
        in the code; works with tensor or numpy images.
        img_t: image (shape: [B, ..., H, W])
        save_dir: optional path to also save the image
        tensor_shape: for n_dim == 5 tensors, the order of the dimensions
            ('TCHW', 'CTHW', 'TCHW_m' or 'CTHW_m')
    '''
    import cv2
    from dataops.common import tensor2np

    if isinstance(img_t, torch.Tensor) and to_np:
        n_dim = img_t.dim()
        if n_dim == 5:
            # "volumetric" tensor [B, _, _, H, W], where indexes [1] and [2]
            # can be (for example) either channels or time. Reduce to 4D tensor
            # for visualization
            if tensor_shape == 'CTHW':
                _, _, n_frames, _, _ = img_t.size()
                frames = []
                for frame in range(n_frames):
                    frames.append(img_t[:, :, frame:frame + 1, :, :])
                img_t = torch.cat(frames, -1)
            elif tensor_shape == 'TCHW':
                _, n_frames, _, _, _ = img_t.size()
                frames = []
                for frame in range(n_frames):
                    frames.append(img_t[:, frame:frame + 1, :, :, :])
                img_t = torch.cat(frames, -1)
            elif tensor_shape == 'CTHW_m':
                # select only the middle frame of CTHW tensor
                _, _, n_frames, _, _ = img_t.size()
                center = (n_frames - 1) // 2
                img_t = img_t[:, :, center, :, :]
            elif tensor_shape == 'TCHW_m':
                # select only the middle frame of TCHW tensor
                _, n_frames, _, _, _ = img_t.size()
                center = (n_frames - 1) // 2
                img_t = img_t[:, center, :, :, :]
            else:
                raise TypeError("Unrecognized tensor_shape: {}".format(tensor_shape))

        img = tensor2np(img_t.detach(),
                        rgb2bgr=rgb2bgr,
                        remove_batch=remove_batch,
                        denormalize=denormalize)
    elif isinstance(img_t, np.ndarray) and not to_np:
        img = img_t
    else:
        raise TypeError("img_t type not supported, expected tensor or ndarray")

    print("out: ", img.shape)
    cv2.imshow('image', img)
    cv2.waitKey(0)

    if save_dir != '':
        cv2.imwrite(save_dir, img)

    cv2.destroyAllWindows()

    return None
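
A brief usage sketch for the 5D case (illustrative tensor; 'TCHW' means [B, T, C, H, W], and the frames are concatenated along the width axis for display):

import torch
clip = torch.rand(1, 5, 3, 32, 32)  # [B, T, C, H, W], 5 frames
tmp_vis(clip, tensor_shape='TCHW')    # all frames side by side
tmp_vis(clip, tensor_shape='TCHW_m')  # only the middle frame
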
Example #9
def test_loop(model, opt, dataloaders, data_params):
    logger = util.get_root_logger()

    # read data_params
    znorms = data_params['znorm']

    # prepare the metric calculation classes for RGB and Y_only images
    calc_metrics = opt.get('metrics', None)
    if calc_metrics:
        test_metrics = metrics.MetricsDict(metrics=calc_metrics)
        test_metrics_y = metrics.MetricsDict(metrics=calc_metrics)

    for phase, dataloader in dataloaders.items():
        name = dataloader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(name))
        dataset_dir = os.path.join(opt['path']['results_root'], name)
        util.mkdir(dataset_dir)

        nlls = []
        for data in dataloader:
            znorm = znorms[name]
            need_HR = dataloader.dataset.opt['dataroot_HR'] is not None

            # set up per image CEM wrapper if configured
            CEM_net = get_CEM(opt, data)

            model.feed_data(data, need_HR=need_HR)  # unpack data from data loader
            # img_path = data['LR_path'][0]
            img_path = get_img_path(data)
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            # test with eval mode. This only affects layers like batchnorm and dropout.
            test_mode = opt.get('test_mode', None)
            if test_mode == 'x8':
                # geometric self-ensemble
                # model.test_x8(CEM_net=CEM_net)
                break
            elif test_mode == 'chop':
                # chop images in patches/crops, to reduce VRAM usage
                # model.test_chop(patch_size=opt.get('chop_patch_size', 100),
                #                 step=opt.get('chop_step', 0.9),
                #                 CEM_net=CEM_net)
                break
            else:
                # normal inference
                model.test(CEM_net=CEM_net)  # run inference

            if hasattr(model, 'nll'):
                nll = model.nll if model.nll else 0
                nlls.append(nll)

            # get image results
            visuals = model.get_current_visuals(need_HR=need_HR)

            res_options = visuals_check(visuals.keys(),
                                        opt.get('val_comparison', None))

            # save images
            save_img_path = os.path.join(dataset_dir,
                                         img_name + opt.get('suffix', ''))

            # Save SR images for reference
            sr_img = None
            if hasattr(model, 'heats'):  # SRFlow
                opt['val_comparison'] = False
                for heat in model.heats:
                    for i in range(model.n_sample):
                        for save_img_name in res_options['save_imgs']:
                            imn = '_' + save_img_name if len(
                                res_options['save_imgs']) > 1 else ''
                            imn += '_h{:03d}_s{:d}'.format(int(heat * 100), i)
                            util.save_img(
                                tensor2np(visuals[save_img_name, heat, i],
                                          denormalize=znorm),
                                save_img_path + imn + '.png')
            else:  # regular SR
                if not opt['val_comparison']:
                    for save_img_name in res_options['save_imgs']:
                        imn = '_' + save_img_name if len(
                            res_options['save_imgs']) > 1 else ''
                        util.save_img(
                            tensor2np(visuals[save_img_name],
                                      denormalize=znorm),
                            save_img_path + imn + '.png')

            # save single images or lr / sr comparison
            if opt['val_comparison'] and len(res_options['save_imgs']) > 1:
                comp_images = [
                    tensor2np(visuals[save_img_name], denormalize=znorm)
                    for save_img_name in res_options['save_imgs']
                ]
                util.save_img_comp(comp_images, save_img_path + '.png')
            # else:
            # util.save_img(sr_img, save_img_path)

            # calculate metrics if HR dataset is provided and metrics are configured in options
            if need_HR and calc_metrics and res_options['aligned_metrics']:
                metric_imgs = [
                    tensor2np(visuals[x], denormalize=znorm)
                    for x in res_options['compare_imgs']
                ]
                test_results = test_metrics.calculate_metrics(
                    metric_imgs[0], metric_imgs[1], crop_size=opt['scale'])

                # prepare single image metrics log message
                logger_m = '{:20s} -'.format(img_name)
                for k, v in test_results.items():  # assuming a dict of metric name -> value
                    formatted_res = k.upper() + ': {:.6f}, '.format(v)
                    logger_m += formatted_res

                if metric_imgs[1].shape[2] == 3:  # RGB image, calculate y_only metrics
                    test_results_y = test_metrics_y.calculate_metrics(
                        metric_imgs[0],
                        metric_imgs[1],
                        crop_size=opt['scale'],
                        only_y=True)

                    # add the y only results to the single image log message
                    for k, v in test_results_y.items():  # assuming a dict of metric name -> value
                        formatted_res = k.upper() + ': {:.6f}, '.format(v)
                        logger_m += formatted_res

                logger.info(logger_m)
            else:
                logger.info(img_name)

        # average metrics results for the dataset
        if need_HR and calc_metrics:

            # aggregate the metrics results (automatically resets the metric classes)
            avg_metrics = test_metrics.get_averages()
            avg_metrics_y = test_metrics_y.get_averages()

            # prepare log average metrics message
            agg_logger_m = ''
            for r in avg_metrics:
                formatted_res = r['name'].upper() + ': {:.6f}, '.format(
                    r['average'])
                agg_logger_m += formatted_res
            logger.info(
                '----Average metrics results for {}----\n\t'.format(name) +
                agg_logger_m[:-2])

            if len(avg_metrics_y) > 0:
                # prepare log average Y channel metrics message
                agg_logger_m = ''
                for r in avg_metrics_y:
                    formatted_res = r['name'].upper() + '_Y' + ': {:.6f}, '.format(r['average'])
                    agg_logger_m += formatted_res
                logger.info('----Y channel, average metrics ----\n\t' +
                            agg_logger_m[:-2])
Example #10
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which return None for missing key.

    # train from scratch OR resume training
    if opt['path']['resume_state']:
        if os.path.isdir(opt['path']['resume_state']):
            import glob
            resume_state_path = util.sorted_nicely(glob.glob(os.path.normpath(opt['path']['resume_state']) + '/*.state'))[-1]
        else:
            resume_state_path = opt['path']['resume_state']
        resume_state = torch.load(resume_state_path)
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None, opt['path']['log'], 'train', level=logging.INFO, screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Set [resume_state] to ' + resume_state_path)
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        try:
            tb_logger = SummaryWriter(logdir='../tb_logger/' + opt['name'])  # for tensorboardX >= 1.7
        except TypeError:
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])  # for tensorboardX < 1.7

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # if the model does not change and input sizes remain the same during training then there may be benefit
    # from setting torch.backends.cudnn.benchmark = True, otherwise it may stall training
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    val_loader = False
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            batch_size = dataset_opt.get('batch_size', 4)
            virtual_batch_size = dataset_opt.get('virtual_batch_size', batch_size)
            virtual_batch_size = virtual_batch_size if virtual_batch_size > batch_size else batch_size
            train_size = int(math.ceil(len(train_set) / batch_size))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'],
                                                                      len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        virtual_step = current_step * virtual_batch_size / batch_size \
            if virtual_batch_size and virtual_batch_size > batch_size else current_step
        model.resume_training(resume_state)  # handle optimizers and schedulers
        model.update_schedulers(opt['train']) # updated schedulers in case JSON configuration has changed
        del resume_state
        # start the iteration time when resuming
        t0 = time.time()
    else:
        current_step = 0
        virtual_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    try:
        for epoch in range(start_epoch, total_epochs*(virtual_batch_size//batch_size)):
            for n, train_data in enumerate(train_loader,start=1):

                if virtual_step == 0:
                    # first iteration start time
                    t0 = time.time()

                virtual_step += 1
                take_step = False
                if virtual_step > 0 and virtual_step * batch_size % virtual_batch_size == 0:
                    current_step += 1
                    take_step = True
                    if current_step > total_iters:
                        break

                # training
                model.feed_data(train_data)
                model.optimize_parameters(virtual_step)

                # log
                if current_step % opt['logger']['print_freq'] == 0 and take_step:
                    # iteration end time
                    t1 = time.time()

                    logs = model.get_current_log()
                    message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, i_time: {:.4f} sec.> '.format(
                        epoch, current_step, model.get_current_learning_rate(current_step), (t1 - t0))
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard logger
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar(k, v, current_step)
                    logger.info(message)

                    # # start time for next iteration
                    # t0 = time.time()

                # update learning rate
                if model.optGstep and model.optDstep and take_step:
                    model.update_learning_rate(current_step, warmup_iter=opt['train'].get('warmup_iter', -1))

                # save models and training states (changed to save models before validation)
                if current_step % opt['logger']['save_checkpoint_freq'] == 0 and take_step:
                    if model.swa:
                        model.save(current_step, opt['logger']['overwrite_chkp'], loader=train_loader)
                    else:
                        model.save(current_step, opt['logger']['overwrite_chkp'])
                    model.save_training_state(epoch + (n >= len(train_loader)), current_step, opt['logger']['overwrite_chkp'])
                    logger.info('Models and training states saved.')

                # validation
                if val_loader and current_step % opt['train']['val_freq'] == 0 and take_step:
                    val_sr_imgs_list = []
                    val_gt_imgs_list = []
                    val_metrics = metrics.MetricsDict(metrics=opt['train'].get('metrics', None))
                    for val_data in val_loader:
                        img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()  # run inference; feed_data was already called above

                        """
                        Get Visuals
                        """
                        visuals = model.get_current_visuals()
                        sr_img = tensor2np(visuals['SR'], denormalize=opt['datasets']['train']['znorm'])
                        gt_img = tensor2np(visuals['HR'], denormalize=opt['datasets']['train']['znorm'])

                        # Save SR images for reference
                        if opt['train']['overwrite_val_imgs']:
                            save_img_path = os.path.join(img_dir, '{:s}.png'.format(\
                                img_name))
                        else:
                            save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(\
                                img_name, current_step))

                        # save single images or lr / sr comparison
                        if opt['train']['val_comparison']:
                            lr_img = tensor2np(visuals['LR'], denormalize=opt['datasets']['train']['znorm'])
                            util.save_img_comp([lr_img, sr_img], save_img_path)
                        else:
                            util.save_img(sr_img, save_img_path)

                        """
                        Get Metrics
                        # TODO: test using tensor based metrics (batch) instead of numpy.
                        """
                        crop_size = opt['scale']
                        val_metrics.calculate_metrics(sr_img, gt_img, crop_size=crop_size)  # , only_y=True

                    avg_metrics = val_metrics.get_averages()
                    del val_metrics

                    # log
                    logger_m = ''
                    for r in avg_metrics:
                        #print(r)
                        formatted_res = r['name'].upper()+': {:.5g}, '.format(r['average'])
                        logger_m += formatted_res

                    logger.info('# Validation # '+logger_m[:-2])
                    logger_val = logging.getLogger('val')  # validation logger
                    logger_val.info('<epoch:{:3d}, iter:{:8,d}> '.format(epoch, current_step)+logger_m[:-2])
                    # memory_usage = torch.cuda.memory_allocated()/(1024.0 ** 3) # in GB

                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        for r in avg_metrics:
                            tb_logger.add_scalar(r['name'], r['average'], current_step)

                    # # reset time for next iteration to skip the validation time from calculation
                    # t0 = time.time()

                if (current_step % opt['logger']['print_freq'] == 0 and take_step) or \
                        (val_loader and current_step % opt['train']['val_freq'] == 0 and take_step):
                    # reset time for next iteration to skip the validation time from calculation
                    t0 = time.time()

        logger.info('Saving the final model.')
        if model.swa:
            # the SWA save takes the train loader, presumably to recompute
            # batch-norm statistics for the averaged weights before saving
            model.save('latest', loader=train_loader)
        else:
            model.save('latest')
        logger.info('End of training.')

    except KeyboardInterrupt:
        # catch a KeyboardInterrupt and save the model and state to resume later
        if model.swa:
            model.save(current_step, True, loader=train_loader)
        else:
            model.save(current_step, True)
        model.save_training_state(epoch + (n >= len(train_loader)), current_step, True)
        logger.info('Training interrupted. Latest models and training states saved.')
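
# A minimal, self-contained sketch of the interrupt-and-resume pattern used
# above: wrap the training loop in try/except and checkpoint on Ctrl+C. All
# names here (train_sketch, ckpt_path) are illustrative, not part of this repo.
def train_sketch(model, optimizer, loader, ckpt_path='latest_interrupt.pth',
                 epochs=10):
    import torch
    step, epoch = 0, 0
    try:
        for epoch in range(epochs):
            for batch in loader:
                step += 1
                # ... forward / backward / optimizer.step() ...
    except KeyboardInterrupt:
        # save everything needed to resume: counters and state dicts
        torch.save({'step': step, 'epoch': epoch,
                    'model': model.state_dict(),
                    'optim': optimizer.state_dict()}, ckpt_path)
        print('Training interrupted. State saved to resume later.')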
Example #11
0
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)
    chop2 = opt['chop']
    chop_patch_size = opt['chop_patch_size']
    multi_upscale = opt['multi_upscale']
    scale = opt['scale']

    util.setup_logger(None,
                      opt['path']['log'],
                      'test.log',
                      level=logging.INFO,
                      screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))
    # Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_HR = test_loader.dataset.opt['dataroot_HR'] is not None
            # test loaders yield one image per batch, so take the single path
            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            znorm = test_loader.dataset.opt['znorm']

            if chop2:
                lowres_img = data['LR']  # .to('cuda')
                if multi_upscale:
                    # x8 self-ensemble: upscale eight rotated/flipped copies
                    # and average the results into a single image
                    LR_90 = lowres_img.transpose(2, 3).flip(2)  # rotate 90 (PyTorch > 0.4.1)
                    LR_180 = LR_90.transpose(2, 3).flip(2)  # rotate 180
                    LR_270 = LR_180.transpose(2, 3).flip(2)  # rotate 270
                    # horizontal mirrors (flip on dim=3, with B,C,H,W = 0,1,2,3)
                    LR_f = lowres_img.flip(3)
                    LR_90f = LR_90.flip(3)
                    LR_180f = LR_180.flip(3)
                    LR_270f = LR_270.flip(3)

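                    # chop_forward2 (defined elsewhere in this repo) appears to
                    # upscale in tiles of chop_patch_size and stitch the results,
                    # bounding peak GPU memory on large inputs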
                    pred = chop_forward2(lowres_img, model, scale=scale, patch_size=chop_patch_size)
                    pred_90 = chop_forward2(LR_90, model, scale=scale, patch_size=chop_patch_size)
                    pred_180 = chop_forward2(LR_180, model, scale=scale, patch_size=chop_patch_size)
                    pred_270 = chop_forward2(LR_270, model, scale=scale, patch_size=chop_patch_size)
                    pred_f = chop_forward2(LR_f, model, scale=scale, patch_size=chop_patch_size)
                    pred_90f = chop_forward2(LR_90f, model, scale=scale, patch_size=chop_patch_size)
                    pred_180f = chop_forward2(LR_180f, model, scale=scale, patch_size=chop_patch_size)
                    pred_270f = chop_forward2(LR_270f, model, scale=scale, patch_size=chop_patch_size)

                    # convert to numpy arrays (uint8); if znorm, the tensor
                    # range is [-1,1], otherwise the default range is [0,1]
                    pred = tensor2np(pred, denormalize=znorm).clip(0, 255)
                    pred_90 = tensor2np(pred_90, denormalize=znorm).clip(0, 255)
                    pred_180 = tensor2np(pred_180, denormalize=znorm).clip(0, 255)
                    pred_270 = tensor2np(pred_270, denormalize=znorm).clip(0, 255)
                    pred_f = tensor2np(pred_f, denormalize=znorm).clip(0, 255)
                    pred_90f = tensor2np(pred_90f, denormalize=znorm).clip(0, 255)
                    pred_180f = tensor2np(pred_180f, denormalize=znorm).clip(0, 255)
                    pred_270f = tensor2np(pred_270f, denormalize=znorm).clip(0, 255)

                    # undo the rotations/flips so all eight predictions align
                    pred_90 = np.rot90(pred_90, 3)
                    pred_180 = np.rot90(pred_180, 2)
                    pred_270 = np.rot90(pred_270, 1)
                    pred_f = np.fliplr(pred_f)
                    pred_90f = np.rot90(np.fliplr(pred_90f), 3)
                    pred_180f = np.rot90(np.fliplr(pred_180f), 2)
                    pred_270f = np.rot90(np.fliplr(pred_270f), 1)

                    # the arrays are uint8, so summing them directly would
                    # overflow (values wrap past 255); cast to float, average,
                    # then convert the result back to uint8
                    sr_img = (pred.astype('float') + pred_90.astype('float') +
                              pred_180.astype('float') + pred_270.astype('float') +
                              pred_f.astype('float') + pred_90f.astype('float') +
                              pred_180f.astype('float') +
                              pred_270f.astype('float')) / 8.0
                    sr_img = sr_img.astype('uint8')

                else:
                    highres_output = chop_forward2(
                        lowres_img, model, scale=scale, patch_size=chop_patch_size)

                    # convert to numpy array (uint8); if znorm, the tensor
                    # range is [-1,1], otherwise the default range is [0,1]
                    sr_img = tensor2np(highres_output, denormalize=znorm)

            else:  # will upscale each image in the batch without chopping
                model.feed_data(data, need_HR=need_HR)
                model.test()  # test
                visuals = model.get_current_visuals(need_HR=need_HR)

                # convert to numpy array (uint8); if znorm, the tensor
                # range is [-1,1], otherwise the default range is [0,1]
                sr_img = tensor2np(visuals['SR'], denormalize=znorm)

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = os.path.join(dataset_dir,
                                             img_name + suffix + '.png')
            else:
                save_img_path = os.path.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            # calculate PSNR and SSIM
            if need_HR:
                # note: the chop2 branches never call model.feed_data(), so no
                # visuals are available there; read the HR reference straight
                # from the data loader instead
                if chop2:
                    gt_img = tensor2np(data['HR'], denormalize=znorm,
                                       remove_batch=True)  # uint8
                elif znorm:  # image range is [-1,1]
                    gt_img = util.tensor2img(visuals['HR'], min_max=(-1, 1))  # uint8
                else:  # default image range is [0,1]
                    gt_img = util.tensor2img(visuals['HR'])  # uint8
                gt_img = gt_img / 255.
                sr_img = sr_img / 255.

                # crop 'scale' pixels off each border before computing
                # PSNR/SSIM, as is standard practice in SR evaluation
                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image: also report Y-channel (luma) metrics
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'\
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                            img_name, psnr, ssim))
            else:
                logger.info(img_name)

        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'\
                    .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'\
                    .format(ave_psnr_y, ave_ssim_y))
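
# For reference, a compact sketch of the x8 geometric self-ensemble that the
# multi_upscale branch above spells out by hand. model_fn stands in for any
# callable mapping a (B,C,H,W) tensor to its upscaled version (e.g. a wrapper
# around chop_forward2); the helper name and signature are illustrative, not
# part of this repo.
def self_ensemble_x8(model_fn, lr):
    import torch
    outputs = []
    for k in range(4):                  # 0/90/180/270 degree rotations
        for flip in (False, True):      # with and without horizontal mirror
            t = torch.rot90(lr, k, dims=(2, 3))
            if flip:
                t = t.flip(3)
            out = model_fn(t)
            if flip:                    # undo the transforms on the output...
                out = out.flip(3)
            out = torch.rot90(out, -k, dims=(2, 3))  # ...in reverse order
            outputs.append(out)
    # averaging in float tensor space sidesteps the uint8 overflow that the
    # numpy version above handles manually
    return torch.stack(outputs).mean(0)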