Example #1
    def calculate_metrics(self, img1, img2, crop_size=4, only_y=False):
        # Convention: numpy arrays are in the [0, 255] range,
        # tensors are in the [0, 1] range

        # TODO: should images be converted from tensor here?

        calculations = {}  # initialized here so the tensor branch also returns a dict
        if isinstance(img1, torch.Tensor) and isinstance(img2, torch.Tensor):
            pass  # TODO
        else:
            if only_y:
                img1 = bgr2ycbcr(img1, only_y=True)
                img2 = bgr2ycbcr(img2, only_y=True)

            # will handle all cases, RGB or grayscale. Numpy in HWC
            img1 = img1[crop_size:-crop_size, crop_size:-crop_size, ...]
            img2 = img2[crop_size:-crop_size, crop_size:-crop_size, ...]

            for m in self.metrics_list:
                if m['name'] == 'psnr':
                    psnr = calculate_psnr(img1, img2, False)
                    self.psnr_total(psnr)
                    calculations['psnr'] = psnr
                elif m['name'] == 'ssim':
                    ssim = calculate_ssim(img1, img2, False)
                    self.ssim_total(ssim)
                    calculations['ssim'] = ssim
                elif m['name'] == 'lpips' and not only_y:  # single channel images not supported by LPIPS
                    lpips = calculate_lpips([img1], [img2],
                                            model=self.lpips_model).item()
                    self.lpips_total(lpips)
                    calculations['lpips'] = lpips
        self.count += 1
        return calculations
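
The tensor branch above is still a TODO. Below is a minimal sketch of the conversion it would need, assuming the convention stated in the comments (tensors are CHW in [0, 1], numpy arrays are HWC in [0, 255]); tensor_to_np_img is a hypothetical helper, not part of this codebase. With it, the branch could convert both inputs and fall through to the existing numpy path:

import numpy as np
import torch

def tensor_to_np_img(t: torch.Tensor) -> np.ndarray:
    """Convert a CHW float tensor in [0, 1] to an HWC uint8 array in [0, 255]."""
    arr = t.detach().cpu().clamp(0, 1).numpy()    # CHW float in [0, 1]
    arr = (arr * 255.0).round().astype(np.uint8)  # scale to [0, 255]
    return arr.transpose(1, 2, 0)                 # CHW -> HWC

# inside the tensor branch:
# img1, img2 = tensor_to_np_img(img1), tensor_to_np_img(img2)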
Example #2
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True,
                        help='Path to options file.')
    opt = options.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs(
        (path for key, path in opt['path'].items() if key != 'pretrain_model_G'))
    opt = options.dict_to_nonedict(opt)

    util.setup_logger(None, opt['path']['log'],
                      'test.log', level=logging.INFO, screen=True)
    logger = logging.getLogger('base')
    logger.info(options.dict2str(opt))
    # Create test dataset and dataloader
    test_loaders = []
    znorm = False  # TMP
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)
        # Temporary: turn znorm on if any dataset requires it. A per-dataset
        # variable will be needed to differentiate datasets later in the loop.
        if dataset_opt['znorm'] and not znorm:
            znorm = True

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

            model.feed_data(data, need_HR=need_HR)
            img_path = data['in_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)

            # if znorm, the image range is [-1, 1]; default range is [0, 1].
            # In testing, each "dataset" can have a different name (not train, val or other).
            top_img = tensor2np(visuals['top_fake'], denormalize=znorm)  # uint8
            bot_img = tensor2np(visuals['bottom_fake'], denormalize=znorm)  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = os.path.join(
                    dataset_dir, img_name + suffix)
            else:
                save_img_path = os.path.join(dataset_dir, img_name)
            util.save_img(top_img, save_img_path + '_top.png')
            util.save_img(bot_img, save_img_path + '_bot.png')


            #TODO: update to use metrics functions
            # calculate PSNR and SSIM
            if need_HR:
                # if znorm, the image range is [-1, 1]; default range is [0, 1]
                gt_img = tensor2img(visuals['HR'], denormalize=znorm)  # uint8
                gt_img = gt_img / 255.
                # NOTE: assuming the top output is the one compared against HR;
                # `sr_img` was previously undefined in this script
                sr_img = top_img / 255.

                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

                psnr = util.calculate_psnr(
                    cropped_sr_img * 255, cropped_gt_img * 255)
                ssim = util.calculate_ssim(
                    cropped_sr_img * 255, cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(
                        cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(
                        cropped_sr_img_y * 255, cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'
                                .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
            else:
                logger.info(img_name)

        #TODO: update to use metrics functions
        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'
                        .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(
                    test_results['psnr_y']) / len(test_results['psnr_y'])
                ave_ssim_y = sum(
                    test_results['ssim_y']) / len(test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'
                            .format(ave_psnr_y, ave_ssim_y))
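
Both #TODO markers above ask for this aggregation to move into metrics functions. One possible consolidation of the repeated sum/len blocks (a sketch; average_metrics is a hypothetical helper, not an existing utility in this codebase):

from collections import OrderedDict

def average_metrics(test_results):
    """Average each accumulated metric list, skipping metrics with no samples."""
    return {name: sum(values) / len(values)
            for name, values in test_results.items() if values}

# Example:
# results = OrderedDict(psnr=[30.1, 31.5], ssim=[0.91, 0.93], psnr_y=[], ssim_y=[])
# average_metrics(results)  # -> {'psnr': 30.8, 'ssim': 0.92}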
Example #3
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None,
                      opt['path']['log'],
                      'test.log',
                      level=logging.INFO,
                      screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    scale = opt.get('scale', 4)

    # Create test dataset and dataloader
    test_loaders = []
    znorm = False  # TMP
    # znorm_list = []

    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)
        # Temporary: reading each dataset's znorm here overwrites the previous one.
        # A per-dataset value (e.g. a znorm_list) will be needed to differentiate
        # datasets later in the test loop.
        znorm = dataset_opt.get('znorm', False)
        # znorm_list.append(znorm)

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            # tmp_vis(data['LR'][:,1,:,:,:], True)

            if opt.get('chop_forward', None):
                # data
                if len(data['LR'].size()) == 4:
                    b, n_frames, h_lr, w_lr = data['LR'].size()
                    LR_y_cube = data['LR'].view(b, -1, 1, h_lr, w_lr)  # b, t, c, h, w
                elif len(data['LR'].size()) == 5:
                    # for networks that work with 3-channel images
                    _, n_frames, _, _, _ = data['LR'].size()
                    LR_y_cube = data['LR']  # b, t, c, h, w

                # print(LR_y_cube.shape)
                # print(data['LR_bicubic'].shape)

                # crop borders so each dimension is divisible by 16
                #TODO: this is modcrop, not sure if really needed, check (the dataloader already does modcrop)
                _, _, _, h, w = LR_y_cube.size()
                h = (h // 16) * 16
                w = (w // 16) * 16
                LR_y_cube = LR_y_cube[:, :, :, :h, :w]
                if isinstance(data['LR_bicubic'], torch.Tensor):
                    SR_cb = data['LR_bicubic'][:, 1, :h * scale, :w * scale]
                    SR_cr = data['LR_bicubic'][:, 2, :h * scale, :w * scale]

                SR_y = chop_forward(LR_y_cube, model, scale,
                                    need_HR=need_HR).squeeze(0)
                # SR_y = np.array(SR_y.data.cpu())
                if test_loader.dataset.opt.get('srcolors', None):
                    print(SR_y.shape, SR_cb.shape, SR_cr.shape)
                    sr_img = ycbcr_to_rgb(torch.stack((SR_y, SR_cb, SR_cr),
                                                      -3))
                else:
                    sr_img = SR_y
            else:
                # data
                model.feed_data(data, need_HR=need_HR)
                # SR_y = net(LR_y_cube).squeeze(0)
                model.test()  # test
                visuals = model.get_current_visuals(need_HR=need_HR)
                # ds = torch.nn.AvgPool2d(2, stride=2, count_include_pad=False)
                # tmp_vis(ds(visuals['SR']), True)
                # tmp_vis(visuals['SR'], True)
                if (test_loader.dataset.opt.get('y_only', None)
                        and test_loader.dataset.opt.get('srcolors', None)):
                    SR_cb = data['LR_bicubic'][:, 1, :, :]
                    SR_cr = data['LR_bicubic'][:, 2, :, :]
                    # tmp_vis(ds(SR_cb), True)
                    # tmp_vis(ds(SR_cr), True)
                    sr_img = ycbcr_to_rgb(
                        torch.stack((visuals['SR'], SR_cb, SR_cr), -3))
                else:
                    sr_img = visuals['SR']

            # if znorm, the image range is [-1, 1]; default range is [0, 1].
            # In testing, each "dataset" can have a different name (not train, val or other).
            sr_img = tensor2np(sr_img, denormalize=znorm)  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = os.path.join(dataset_dir,
                                             img_name + suffix + '.png')
            else:
                save_img_path = os.path.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            #TODO: update to use metrics functions
            # calculate PSNR and SSIM
            if need_HR:
                # NOTE: `visuals` only exists in the no-chop branch above; with
                # chop_forward the HR image would have to come from `data` instead.
                # if znorm, the image range is [-1, 1]; default range is [0, 1]
                gt_img = tensor2img(visuals['HR'], denormalize=znorm)  # uint8
                gt_img = gt_img / 255.
                sr_img = sr_img / 255.

                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'\
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                            img_name, psnr, ssim))
            else:
                logger.info(img_name)

        #TODO: update to use metrics functions
        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'\
                    .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'\
                    .format(ave_psnr_y, ave_ssim_y))
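
chop_forward is called above but not defined in this example. Below is a self-contained sketch of the usual EDSR-style implementation such helpers are modeled on, written for a 4D (B, C, H, W) input (the video variant above additionally carries a frame dimension); the real helper in this codebase may differ in overlap size and recursion threshold:

import torch

def chop_forward_sketch(x, forward, scale, min_size=160000):
    """Memory-saving inference: split the input into 4 overlapping quadrants,
    upscale each (recursing while the patches are still too large), and
    stitch the upscaled quadrants back together."""
    b, c, h, w = x.size()
    h_half, w_half = h // 2, w // 2
    h_size, w_size = h_half + 8, w_half + 8  # small overlap to hide seams
    patches = [
        x[..., 0:h_size, 0:w_size],              # top-left
        x[..., 0:h_size, (w - w_size):w],        # top-right
        x[..., (h - h_size):h, 0:w_size],        # bottom-left
        x[..., (h - h_size):h, (w - w_size):w],  # bottom-right
    ]
    if h_size * w_size < min_size:
        outs = [forward(p) for p in patches]
    else:  # still too large: recurse into each quadrant
        outs = [chop_forward_sketch(p, forward, scale, min_size) for p in patches]

    # paste the valid region of each upscaled quadrant into the output
    h, w = h * scale, w * scale
    h_half, w_half = h_half * scale, w_half * scale
    h_size, w_size = h_size * scale, w_size * scale
    out = x.new_zeros(b, c, h, w)
    out[..., 0:h_half, 0:w_half] = outs[0][..., 0:h_half, 0:w_half]
    out[..., 0:h_half, w_half:w] = outs[1][..., 0:h_half, (w_size - w + w_half):w_size]
    out[..., h_half:h, 0:w_half] = outs[2][..., (h_size - h + h_half):h_size, 0:w_half]
    out[..., h_half:h, w_half:w] = outs[3][..., (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
    return out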
Example #4
    def __getitem__(self, idx):
        scale = self.opt.get('scale', 4)
        HR_size = self.opt.get('HR_size', 128)
        LR_size = HR_size // scale
        idx_center = (self.num_frames - 1) // 2
        ds_kernel = None
        
        # Default case: tensor will be in the [0, 1] range
        # Alternative: tensor will be z-normalized to the [-1, 1] range
        znorm = self.opt.get('znorm', False)

        if self.opt['phase'] == 'train':
            if (self.opt.get('lr_downscale', None) and self.opt.get('dataroot_kernels', None)
                    and 999 in self.opt['lr_downscale_types']):
                ds_kernel = self.ds_kernels  # KernelDownscale(scale, self.kernel_paths, self.num_kernel)

            # get a random video directory
            idx_video = random.randint(0, len(self.video_list)-1)
            video_dir = self.video_list[idx_video]
            # print(video_dir)
        else:
            # only one video and paths_LR/paths_HR is already the video dir
            video_dir = ""
        
        # list the frames in the directory
        paths_HR = util.get_image_paths(self.opt['data_type'], os.path.join(self.paths_HR, video_dir))
        # print(paths_HR)

        if self.opt['phase'] == 'train':
            # random reverse augmentation
            random_reverse = self.opt.get('random_reverse', False)
            
            # skipping intermediate frames to learn from low FPS videos augmentation
            # testing random frameskip up to 'max_frameskip' frames
            max_frameskip = self.opt.get('max_frameskip', 0)
            if max_frameskip > 0:
                max_frameskip = min(max_frameskip, len(paths_HR)//(self.num_frames-1))
                frameskip = random.randint(1, max_frameskip)
            else:
                frameskip = 1
            # print("max_frameskip: ", max_frameskip)

            assert ((self.num_frames-1)*frameskip) <= (len(paths_HR)-1), (
                f'num_frame*frameskip must be smaller than the number of frames per video, check {video_dir}')
            
            # if number of frames of training video is for example 31, "max index -num_frames" = 31-3=28
            idx_frame = random.randint(0, (len(paths_HR)-1)-((self.num_frames-1)*frameskip))
            # print('frameskip:', frameskip)
        else:
            frameskip = 1
            idx_frame = idx
        
        # list-based frame loading
        # set the downscaling algorithm first: it is needed both when generating
        # LR frames on the fly and when provided LR frames match the HR size
        ds_algo = 777  # default to matlab-like bicubic downscale
        if self.opt.get('lr_downscale', None):  # if downscale algorithms are manually set, use those
            ds_algo = self.opt.get('lr_downscale_types', 777)
        if self.paths_LR:
            paths_LR = util.get_image_paths(self.opt['data_type'], os.path.join(self.paths_LR, video_dir))
        else:
            paths_LR = paths_HR

        # get the video directory
        HR_dir, _ = os.path.split(paths_HR[idx_frame])
        LR_dir, _ = os.path.split(paths_LR[idx_frame])

        # read HR & LR frames
        HR_list = []
        LR_list = []
        resize_type = None
        LR_bicubic = None
        HR_center = None

        # print('len(paths_HR)', len(paths_HR))
        for i_frame in range(self.num_frames):
            # print('frame path:', paths_HR[int(idx_frame)+(frameskip*i_frame)])
            HR_img = util.read_img(None, paths_HR[int(idx_frame)+(frameskip*i_frame)], out_nc=self.image_channels)
            HR_img = util.modcrop(HR_img, scale)

            if self.opt['phase'] == 'train':
                '''
                If using individual image augmentations, get cropping parameters for reuse
                '''
                if self.otf_noise and i_frame == 0: #only need to calculate once, from the first frame
                    # reuse the cropping parameters for all LR and HR frames
                    hr_crop_params, lr_crop_params = get_crop_params(HR_img, LR_size, scale)
                    if self.opt.get('lr_noise', None):
                        # reuse the same noise type for all the frames
                        noise_option = get_noise(self.opt.get('lr_noise_types', None), self.noise_patches)
                    if self.opt.get('lr_blur', None):
                        # reuse the same blur type for all the frames
                        blur_option = get_blur(self.opt.get('lr_blur_types', None))

            if self.paths_LR:
                # LR frames are provided; if they match the HR size, downscale them below
                LR_img = util.read_img(None, paths_LR[int(idx_frame)+(frameskip*i_frame)], out_nc=self.image_channels)
                if LR_img.shape == HR_img.shape:
                    LR_img, resize_type = Scale(img=HR_img, scale=scale, algo=ds_algo, ds_kernel=ds_kernel, resize_type=resize_type)
            else:
                # generate LR images on the fly
                LR_img, resize_type = Scale(img=HR_img, scale=scale, algo=ds_algo, ds_kernel=ds_kernel, resize_type=resize_type)

            # get the bicubic upscale of the center frame to concatenate for SR
            if self.y_only and self.srcolors and i_frame == idx_center:
                LR_bicubic, _ = Scale(img=LR_img, scale=1/scale, algo=777) # bicubic upscale
                HR_center = HR_img
                # tmp_vis(LR_bicubic, False)
                # tmp_vis(HR_center, False)
            
            if self.y_only:
                # extract Y channel from frames
                # normal path, only Y for both
                HR_img = util.bgr2ycbcr(HR_img, only_y=True)
                LR_img = util.bgr2ycbcr(LR_img, only_y=True)

            # crop patches randomly if using otf noise
            #TODO: make a BasicSR composable random_crop
            #TODO: the original crop would go here, cropping after loading each image,
            # but it could be much simpler to crop after concatenating. Check the speed difference.
            if self.otf_noise and self.opt['phase'] == 'train':
                HR_img, LR_img = apply_crop_params(HR_img, LR_img, hr_crop_params, lr_crop_params)
                if self.y_only and self.srcolors and i_frame == idx_center:
                    LR_bicubic, _ = apply_crop_params(LR_bicubic, None, hr_crop_params, None)
                    HR_center, _ = apply_crop_params(HR_center, None, hr_crop_params, None)

            # expand Y images to add the channel dimension
            # normal path, only Y for both
            if self.y_only:
                HR_img = util.fix_img_channels(HR_img, 1)
                LR_img = util.fix_img_channels(LR_img, 1)

            if self.opt['phase'] == 'train':
                # single frame augmentation (noise, blur, etc). Would only be efficient if patches are cropped in this loop
                if self.opt.get('lr_blur', None):
                    if blur_option:
                        LR_img = blur_option(LR_img)
                if self.opt.get('lr_noise', None):
                    if noise_option:
                        LR_img = noise_option(LR_img)
            
                # expand LR images to add the channel dimension again if needed
                # (blur removes the grayscale channel)
                #TODO: add an if condition; could compare to the ndim before the augmentations,
                # or move this inside the augmentation condition
                if self.y_only:
                    LR_img = util.fix_img_channels(LR_img, 1)
            
            # print("HR_img.shape: ", HR_img.shape)
            # print("LR_img.shape", LR_img.shape)

            HR_list.append(HR_img) # h, w, c
            LR_list.append(LR_img) # h, w, c

        # print(len(HR_list))
        # print(len(LR_list))

        if self.opt['phase'] == 'train':
            # random reverse sequence augmentation
            if random_reverse and random.random() < 0.5:
                HR_list.reverse()
                LR_list.reverse()

        if not self.y_only:
            t = self.num_frames
            HR = [np.asarray(GT) for GT in HR_list]  # list -> numpy, each item [H,W,C]
            HR = np.asarray(HR)  # numpy, [T,H,W,C]
            h_HR, w_HR, c = HR_img.shape  # could use HR_center.shape #TODO: check, may be risky
            HR = HR.transpose(1, 2, 3, 0).reshape(h_HR, w_HR, -1)  # numpy, [H',W',CT]
            LR = [np.asarray(LT) for LT in LR_list]  # list -> numpy, each item [H,W,C]
            LR = np.asarray(LR)  # numpy, [T,H,W,C]
            LR = LR.transpose(1, 2, 3, 0).reshape(h_HR//scale, w_HR//scale, -1)  # numpy, [Hl',Wl',CT]
        else:
            HR = np.concatenate(HR_list, axis=2)  # h, w, t
            LR = np.concatenate(LR_list, axis=2)  # h, w, t

        if self.opt['phase'] == 'train':
            # If not using individual image augmentations, cropping here is faster (done only once):
            # crop patches randomly; if not using otf noise, crop all concatenated images
            if not self.otf_noise:
                HR, LR, hr_crop_params, _ = random_crop_mod(HR, LR, LR_size, scale)
                if self.y_only and self.srcolors:
                    # None in place of the unused second image
                    LR_bicubic, _, _, _ = random_crop_mod(LR_bicubic, None, LR_size, scale, hr_crop_params)
                    HR_center, _, _, _ = random_crop_mod(HR_center, None, LR_size, scale, hr_crop_params)
                    # tmp_vis(LR_bicubic, False)
                    # tmp_vis(HR_center, False)

            # data augmentation
            #TODO: use BasicSR augmentations
            #TODO: use variables from config
            LR, HR, LR_bicubic, HR_center = augmentation()([LR, HR, LR_bicubic, HR_center])

        # tmp_vis(HR, False)
        # tmp_vis(LR, False)
        # tmp_vis(LR_bicubic, False)
        # tmp_vis(HR_center, False)

        if self.y_only:
            HR = util.np2tensor(HR, normalize=znorm, bgr2rgb=False, add_batch=False) # Tensor, [CT',H',W'] or [T, H, W]
            LR = util.np2tensor(LR, normalize=znorm, bgr2rgb=False, add_batch=False) # Tensor, [CT',H',W'] or [T, H, W]
        else:
            HR = util.np2tensor(HR, normalize=znorm, bgr2rgb=True, add_batch=False) # Tensor, [CT',H',W'] or [T, H, W]
            LR = util.np2tensor(LR, normalize=znorm, bgr2rgb=True, add_batch=False) # Tensor, [CT',H',W'] or [T, H, W]
        
        #TODO: TMP to test generating 3 channel images for SR loss
        # HR = util.np2tensor(HR, normalize=znorm, bgr2rgb=False, add_batch=True) # Tensor, [CT',H',W'] or [T, H, W]
        # LR = util.np2tensor(LR, normalize=znorm, bgr2rgb=False, add_batch=True) # Tensor, [CT',H',W'] or [T, H, W]
        
        # if self.srcolors:
        #     HR = HR.view(c,t,HR_size,HR_size) # Tensor, [C,T,H,W]
        if not self.y_only:
            HR = HR.view(c,t,HR_size,HR_size) # Tensor, [C,T,H,W]
            LR = LR.view(c,t,LR_size,LR_size) # Tensor, [C,T,H,W]
            if self.shape == 'TCHW':
                HR = HR.transpose(0,1) # Tensor, [T,C,H,W]
                LR = LR.transpose(0,1) # Tensor, [T,C,H,W]

        # generate Cr, Cb channels using bicubic interpolation
        #TODO: check, it might be easier to return the whole image and separate later when needed
        if self.y_only and self.srcolors:
            LR_bicubic = util.bgr2ycbcr(LR_bicubic, only_y=False)
            # HR_center = util.bgr2ycbcr(HR_center, only_y=False) #not needed, can directly use rgb image
            ## LR_bicubic = util.ycbcr2rgb(LR_bicubic, only_y=False) #test, looks ok
            ## HR_center = util.ycbcr2rgb(HR_center, only_y=False) #test, looks ok
            ## _, SR_cb, SR_cr = util.bgr2ycbcr(LR_bicubic, only_y=False, separate=True)
            LR_bicubic = util.np2tensor(LR_bicubic, normalize=znorm, bgr2rgb=False, add_batch=False)
            # HR_center = util.np2tensor(HR_center, normalize=znorm, bgr2rgb=False, add_batch=False) # will test using rgb image instead
            HR_center = util.np2tensor(HR_center, normalize=znorm, bgr2rgb=True, add_batch=False)
            #TODO: TMP to test generating 3 channel images for SR loss
            # LR_bicubic = util.np2tensor(LR_bicubic, normalize=znorm, bgr2rgb=False, add_batch=True)
            # HR_center = util.np2tensor(HR_center, normalize=znorm, bgr2rgb=False, add_batch=True)
        elif self.y_only and not self.srcolors:
            LR_bicubic = []
            HR_center = []
        else:
            HR_center = HR[:,idx_center,:,:] if self.shape == 'CTHW' else HR[idx_center,:,:,:]
            LR_bicubic = []

        # return toTensor(LR), toTensor(HR)
        return {'LR': LR, 'HR': HR, 'LR_path': LR_dir, 'HR_path': HR_dir, 'LR_bicubic': LR_bicubic, 'HR_center': HR_center}
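
The frameskip sampling in the training branch packs several index bounds into single lines. Here is the same arithmetic as a standalone sketch (sample_frame_indices is a hypothetical helper mirroring the logic above; the clamp uses the tight bound implied by the assert, where the code above uses len(paths_HR) // (num_frames - 1)):

import random

def sample_frame_indices(n_total, num_frames, max_frameskip):
    """Pick a start index and frameskip so every sampled frame stays in range.

    A window of num_frames frames spaced frameskip apart spans
    (num_frames - 1) * frameskip steps, so the start index can be at most
    (n_total - 1) minus that span -- the bound the assert above enforces.
    """
    if max_frameskip > 0:
        max_frameskip = max(1, min(max_frameskip, (n_total - 1) // (num_frames - 1)))
        frameskip = random.randint(1, max_frameskip)
    else:
        frameskip = 1
    start = random.randint(0, (n_total - 1) - (num_frames - 1) * frameskip)
    return [start + i * frameskip for i in range(num_frames)]

# e.g. with a 31-frame video: sample_frame_indices(31, 3, 15) might return [0, 15, 30]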
Example #5
    def __getitem__(self, idx):
        scale = self.opt.get('scale', 4)
        idx_center = (self.num_frames - 1) // 2
        h_LR = None
        w_LR = None

        # Default case: tensor will be in the [0, 1] range
        # Alternative: tensor will be z-normalized to the [-1, 1] range
        znorm = self.opt.get('znorm', False)

        # only one video and paths_LR/paths_HR is already the video dir
        video_dir = ""
        
        # list the frames in the directory (list-based frame loading)
        paths_LR = util.get_image_paths(self.opt['data_type'], os.path.join(self.paths_LR, video_dir))

        assert self.num_frames <= len(paths_LR), (
            f'num_frame must be smaller than the number of frames per video, check {video_dir}')

        idx_frame = idx
        LR_name = paths_LR[idx_frame + 1]  # center frame (assumes num_frames == 3)
        # print(LR_name)
        # print(len(self.video_list))

        # read LR frames
        # HR_list = []
        LR_list = []
        resize_type = None
        LR_bicubic = None
        for i_frame in range(self.num_frames):
            if idx_frame == len(self.video_list)-2 and self.num_frames == 3:
                # second-to-last frame: repeat the last frame once
                if i_frame == 0:
                    LR_img = util.read_img(None, paths_LR[int(idx_frame)], out_nc=self.image_channels)
                else:
                    LR_img = util.read_img(None, paths_LR[int(idx_frame)+1], out_nc=self.image_channels)
            elif idx_frame == len(self.video_list)-1 and self.num_frames == 3:
                # last frame: repeat it for the whole window
                LR_img = util.read_img(None, paths_LR[int(idx_frame)], out_nc=self.image_channels)
            else:
                # every other (internal) frame
                LR_img = util.read_img(None, paths_LR[int(idx_frame)+i_frame], out_nc=self.image_channels)
            #TODO: check if this is necessary
            LR_img = util.modcrop(LR_img, scale)

            # get the bicubic upscale of the center frame to concatenate for SR
            # (condition fixed: the y_only + srcolors branch below consumes LR_bicubic)
            if self.y_only and self.srcolors and i_frame == idx_center:
                if self.opt.get('denoise_LRbic', False):
                    LR_bicubic = transforms.RandomAverageBlur(p=1, kernel_size=3)(LR_img)
                    # LR_bicubic = transforms.RandomBoxBlur(p=1, kernel_size=3)(LR_img)
                else:
                    LR_bicubic = LR_img
                LR_bicubic, _ = Scale(img=LR_bicubic, scale=1/scale, algo=777) # bicubic upscale
                # HR_center = HR_img
                # tmp_vis(LR_bicubic, False)
                # tmp_vis(HR_center, False)
            
            if self.y_only:
                # extract Y channel from frames
                # normal path, only Y for both
                LR_img = util.bgr2ycbcr(LR_img, only_y=True)

                # expand Y images to add the channel dimension
                # normal path, only Y for both
                LR_img = util.fix_img_channels(LR_img, 1)
                
                # print("HR_img.shape: ", HR_img.shape)
                # print("LR_img.shape", LR_img.shape)

            LR_list.append(LR_img) # h, w, c
            
            if not self.y_only and (not h_LR or not w_LR):
                h_LR, w_LR, c = LR_img.shape
        
        if not self.y_only:
            t = self.num_frames
            LR = [np.asarray(LT) for LT in LR_list]  # list -> numpy # input: list (contatin numpy: [H,W,C])
            LR = np.asarray(LR) # numpy, [T,H,W,C]
            LR = LR.transpose(1,2,3,0).reshape(h_LR, w_LR, -1) # numpy, [Hl',Wl',CT]
        else:
            LR = np.concatenate((LR_list), axis=2) # h, w, t

        if self.y_only:
            LR = util.np2tensor(LR, normalize=znorm, bgr2rgb=False, add_batch=False) # Tensor, [CT',H',W'] or [T, H, W]
        else:
            LR = util.np2tensor(LR, normalize=znorm, bgr2rgb=True, add_batch=False) # Tensor, [CT',H',W'] or [T, H, W]
            LR = LR.view(c,t,h_LR,w_LR) # Tensor, [C,T,H,W]
            if self.shape == 'TCHW':
                LR = LR.transpose(0,1) # Tensor, [T,C,H,W]

        if self.y_only and self.srcolors:
            # generate Cr, Cb channels using bicubic interpolation
            LR_bicubic = util.bgr2ycbcr(LR_bicubic, only_y=False)
            LR_bicubic = util.np2tensor(LR_bicubic, normalize=znorm, bgr2rgb=False, add_batch=False)
            HR_center = []
        else:
            LR_bicubic = []
            HR_center = []

        # return toTensor(LR), toTensor(HR)
        return {'LR': LR, 'LR_path': LR_name, 'LR_bicubic': LR_bicubic, 'HR_center': HR_center}
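
The num_frames == 3 edge cases in the loading loop reduce to clamping a forward window of indices to the last valid frame. A sketch of that generalization (window_indices is a hypothetical helper, not in the codebase):

def window_indices(idx_frame, n_total, num_frames):
    """Forward window of frame indices starting at idx_frame, clamped so the
    last frame repeats instead of reading past the end of the clip."""
    return [min(idx_frame + i, n_total - 1) for i in range(num_frames)]

# matches the branches above for a 10-frame clip and a 3-frame window:
# window_indices(7, 10, 3) -> [7, 8, 9]   (internal frame)
# window_indices(8, 10, 3) -> [8, 9, 9]   (second-to-last frame)
# window_indices(9, 10, 3) -> [9, 9, 9]   (last frame)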
Example #6
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)
    chop2 = opt['chop']
    chop_patch_size = opt['chop_patch_size']
    multi_upscale = opt['multi_upscale']
    scale = opt['scale']

    util.setup_logger(None,
                      opt['path']['log'],
                      'test.log',
                      level=logging.INFO,
                      screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))
    # Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)
        # znorm is read per dataset inside the test loop below

    # Create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        for data in test_loader:
            need_HR = test_loader.dataset.opt['dataroot_HR'] is not None
            img_path = data['LR_path'][0]  # batch size is 1, so a single path per batch
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            znorm = test_loader.dataset.opt['znorm']

            if chop2:
                lowres_img = data['LR']  #.to('cuda')
                if multi_upscale:
                    # upscale 8 rotated/flipped variants and average the results into a single image
                    LR_90 = lowres_img.transpose(2, 3).flip(2)  # 90 degree rotation (PyTorch > 0.4.1)
                    LR_180 = LR_90.transpose(2, 3).flip(2)
                    LR_270 = LR_180.transpose(2, 3).flip(2)
                    LR_f = lowres_img.flip(3)  # horizontal mirror, dim=3 of (B,C,H,W)
                    LR_90f = LR_90.flip(3)
                    LR_180f = LR_180.flip(3)
                    LR_270f = LR_270.flip(3)

                    pred = chop_forward2(lowres_img, model, scale=scale, patch_size=chop_patch_size)
                    pred_90 = chop_forward2(LR_90, model, scale=scale, patch_size=chop_patch_size)
                    pred_180 = chop_forward2(LR_180, model, scale=scale, patch_size=chop_patch_size)
                    pred_270 = chop_forward2(LR_270, model, scale=scale, patch_size=chop_patch_size)
                    pred_f = chop_forward2(LR_f, model, scale=scale, patch_size=chop_patch_size)
                    pred_90f = chop_forward2(LR_90f, model, scale=scale, patch_size=chop_patch_size)
                    pred_180f = chop_forward2(LR_180f, model, scale=scale, patch_size=chop_patch_size)
                    pred_270f = chop_forward2(LR_270f, model, scale=scale, patch_size=chop_patch_size)

                    # convert to numpy arrays
                    # if znorm, the image range is [-1, 1]; default range is [0, 1]
                    pred = tensor2np(pred, denormalize=znorm).clip(0, 255)  # uint8
                    pred_90 = tensor2np(pred_90, denormalize=znorm).clip(0, 255)  # uint8
                    pred_180 = tensor2np(pred_180, denormalize=znorm).clip(0, 255)  # uint8
                    pred_270 = tensor2np(pred_270, denormalize=znorm).clip(0, 255)  # uint8
                    pred_f = tensor2np(pred_f, denormalize=znorm).clip(0, 255)  # uint8
                    pred_90f = tensor2np(pred_90f, denormalize=znorm).clip(0, 255)  # uint8
                    pred_180f = tensor2np(pred_180f, denormalize=znorm).clip(0, 255)  # uint8
                    pred_270f = tensor2np(pred_270f, denormalize=znorm).clip(0, 255)  # uint8

                    # undo the rotations/flips on each output
                    pred_90 = np.rot90(pred_90, 3)
                    pred_180 = np.rot90(pred_180, 2)
                    pred_270 = np.rot90(pred_270, 1)
                    pred_f = np.fliplr(pred_f)
                    pred_90f = np.rot90(np.fliplr(pred_90f), 3)
                    pred_180f = np.rot90(np.fliplr(pred_180f), 2)
                    pred_270f = np.rot90(np.fliplr(pred_270f), 1)

                    # average in float to avoid uint8 overflow (sums above 255 would wrap
                    # around to 0), then convert the blended result back to uint8
                    sr_img = (pred.astype('float') + pred_90.astype('float') +
                              pred_180.astype('float') + pred_270.astype('float') +
                              pred_f.astype('float') + pred_90f.astype('float') +
                              pred_180f.astype('float') + pred_270f.astype('float')) / 8.0
                    sr_img = sr_img.astype('uint8')

                else:
                    highres_output = chop_forward2(lowres_img, model, scale=scale,
                                                   patch_size=chop_patch_size)

                    # convert to numpy array
                    # if znorm, the image range is [-1, 1]; default range is [0, 1]
                    sr_img = tensor2np(highres_output, denormalize=znorm)  # uint8

            else:  # upscale each image in the batch without chopping
                model.feed_data(data, need_HR=need_HR)
                model.test()  # test
                visuals = model.get_current_visuals(need_HR=need_HR)

                # if znorm, the image range is [-1, 1]; default range is [0, 1]
                sr_img = tensor2np(visuals['SR'], denormalize=znorm)  # uint8

            # save images
            suffix = opt['suffix']
            if suffix:
                save_img_path = os.path.join(dataset_dir,
                                             img_name + suffix + '.png')
            else:
                save_img_path = os.path.join(dataset_dir, img_name + '.png')
            util.save_img(sr_img, save_img_path)

            # calculate PSNR and SSIM
            if need_HR:
                # NOTE: `visuals` only exists in the no-chop branch above; with chop2
                # enabled, the HR image would have to be read from `data` instead
                if znorm:  # if znorm, the image range is [-1, 1]; default: [0, 1]
                    gt_img = util.tensor2img(visuals['HR'], min_max=(-1, 1))  # uint8
                else:
                    gt_img = util.tensor2img(visuals['HR'])  # uint8
                gt_img = gt_img / 255.
                sr_img = sr_img / 255.

                crop_border = test_loader.dataset.opt['scale']
                cropped_sr_img = sr_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]
                cropped_gt_img = gt_img[crop_border:-crop_border,
                                        crop_border:-crop_border, :]

                psnr = util.calculate_psnr(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                ssim = util.calculate_ssim(cropped_sr_img * 255,
                                           cropped_gt_img * 255)
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)

                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    cropped_sr_img_y = sr_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    cropped_gt_img_y = gt_img_y[crop_border:-crop_border,
                                                crop_border:-crop_border]
                    psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                                 cropped_gt_img_y * 255)
                    test_results['psnr_y'].append(psnr_y)
                    test_results['ssim_y'].append(ssim_y)
                    logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'\
                        .format(img_name, psnr, ssim, psnr_y, ssim_y))
                else:
                    logger.info(
                        '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                            img_name, psnr, ssim))
            else:
                logger.info(img_name)

        if need_HR:  # metrics
            # Average PSNR/SSIM results
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'\
                    .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'\
                    .format(ave_psnr_y, ave_ssim_y))
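
The multi_upscale branch above unrolls the standard x8 self-ensemble by hand. Below is a compact sketch of the same procedure, assuming a forward callable that wraps chop_forward2 plus tensor2np and returns an HWC uint8 numpy image (names here are illustrative, not part of the codebase):

import numpy as np
import torch

def x8_self_ensemble(lowres_img, forward):
    """Rotation/flip test-time augmentation: run the model on all 8 dihedral
    variants of a (B, C, H, W) input, undo each transform on the output, and
    average in float to avoid uint8 overflow."""
    acc = None
    rot = lowres_img
    for k in range(4):                  # 0, 90, 180, 270 degree rotations
        if k > 0:
            rot = rot.transpose(2, 3).flip(2)  # rotate 90 degrees more
        for flip in (False, True):      # optionally mirror horizontally
            inp = rot.flip(3) if flip else rot
            out = forward(inp).astype(np.float64)
            if flip:
                out = np.fliplr(out)    # undo the mirror first
            out = np.rot90(out, -k)     # then undo the k * 90 degree rotation
            acc = out if acc is None else acc + out
    return (acc / 8.0).round().astype(np.uint8)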