Example #1
0
    def __getitem__(self, index):
        """Assemble one test sample: a window of LQ frames plus its GT frame.

        Args:
            index: Position of the sample in the ``self.data_info`` lists.

        Returns:
            dict: ``lq`` (frame window), ``gt`` (target frame), ``folder``,
            ``idx`` (the raw ``"<frame>/<count>"`` string), ``border`` and
            ``lq_path``.
        """
        info = self.data_info
        clip = info['folder'][index]
        # 'idx' entries are strings of the form "<frame index>/<frame count>".
        idx_str = info['idx'][index]
        center, total = map(int, idx_str.split('/'))
        border = info['border'][index]
        lq_path = info['lq_path'][index]

        select_idx = util.generate_frame_indices(
            center, total, self.opt['num_frame'], padding=self.opt['padding'])

        if not self.cache_data:
            # Only paths are stored: load the window and the GT frame now.
            window_paths = [self.imgs_lq[clip][i] for i in select_idx]
            imgs_lq = util.read_img_seq(window_paths)
            img_gt = util.read_img_seq([self.imgs_gt[clip][center]])
            # A one-element sequence was read; drop its leading axis in place.
            img_gt.squeeze_(0)
        else:
            # Preloaded clips: gather the window along the first axis.
            cached_lq = self.imgs_lq[clip]
            imgs_lq = cached_lq.index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[clip][center]

        sample = {}
        sample['lq'] = imgs_lq  # (t, c, h, w)
        sample['gt'] = img_gt  # (c, h, w)
        sample['folder'] = clip  # folder name
        sample['idx'] = idx_str  # e.g., 0/99
        sample['border'] = border  # 1 for border, 0 for non-border
        sample['lq_path'] = lq_path  # center frame
        return sample
Example #2
0
    def __getitem__(self, index):
        """Return the LQ frame window and the GT frame for sample ``index``.

        The GT entry is a single frame in both code paths: the cached path
        indexes one frame out of the preloaded clip, and the read-from-disk
        path drops the leading axis left by reading a one-element sequence so
        that both paths hand back the same layout.

        Args:
            index: Position of the sample in the ``self.data_info`` lists.

        Returns:
            dict: ``lq`` (frame window), ``gt`` (target frame), ``folder``
            (clip name), ``idx`` (the raw ``"<frame>/<count>"`` string),
            ``border`` and ``lq_path``.
        """
        folder = self.data_info['folder'][index]
        # 'idx' entries are strings of the form "<frame index>/<frame count>".
        idx_str = self.data_info['idx'][index]
        idx, max_idx = idx_str.split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = util.generate_frame_indices(idx,
                                                 max_idx,
                                                 self.opt['num_frame'],
                                                 padding=self.opt['padding'])

        if self.cache_data:
            # Preloaded clips: gather the window along the first axis.
            imgs_lq = self.imgs_lq[folder].index_select(
                0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
            imgs_lq = util.read_img_seq(img_paths_lq)
            img_gt = util.read_img_seq([self.imgs_gt[folder][idx]])
            # Presumably read_img_seq stacks frames along a leading axis
            # (TODO confirm); remove that length-1 axis so 'gt' has the same
            # rank as in the cached branch above.
            img_gt = img_gt.squeeze(0)

        return {
            'lq': imgs_lq,  # window of LQ frames
            'gt': img_gt,  # single GT frame
            'folder': folder,  # clip (sub-folder) name
            'idx': idx_str,  # "<frame index>/<frame count>"
            'border': border,
            'lq_path': lq_path
        }
Example #3
0
    def __getitem__(self, index):
        """Build one test sample, optionally synthesising LQ frames from GT.

        When ``opt['use_duf_downsampling']`` is truthy the LQ window is not
        taken from the LQ data; instead the corresponding GT frames are passed
        through ``duf_downsample`` (kernel size 13, ``opt['scale']``).

        Args:
            index: Position of the sample in the ``self.data_info`` lists.

        Returns:
            dict: ``lq``, ``gt``, ``folder``, ``idx``, ``border`` and
            ``lq_path``.
        """
        info = self.data_info
        clip = info['folder'][index]
        # 'idx' entries are strings of the form "<frame index>/<frame count>".
        idx_str = info['idx'][index]
        center, total = map(int, idx_str.split('/'))
        border = info['border'][index]
        lq_path = info['lq_path'][index]

        select_idx = util.generate_frame_indices(
            center, total, self.opt['num_frame'], padding=self.opt['padding'])

        use_duf = self.opt['use_duf_downsampling']

        if self.cache_data:
            # With DUF downsampling the window is gathered from the cached GT
            # clip and shrunk afterwards; otherwise from the cached LQ clip.
            source = self.imgs_gt if use_duf else self.imgs_lq
            imgs_lq = source[clip].index_select(
                0, torch.LongTensor(select_idx))
            if use_duf:
                imgs_lq = duf_downsample(imgs_lq,
                                         kernel_size=13,
                                         scale=self.opt['scale'])
            img_gt = self.imgs_gt[clip][center]
        else:
            # GT images read from disk are mod-cropped to the scale factor.
            mod_crop = {'require_mod_crop': True, 'scale': self.opt['scale']}
            if use_duf:
                # read GT frames and derive the low-resolution window
                window_paths = [self.imgs_gt[clip][i] for i in select_idx]
                hr_window = util.read_img_seq(window_paths, **mod_crop)
                imgs_lq = duf_downsample(hr_window,
                                         kernel_size=13,
                                         scale=self.opt['scale'])
            else:
                window_paths = [self.imgs_lq[clip][i] for i in select_idx]
                imgs_lq = util.read_img_seq(window_paths)
            img_gt = util.read_img_seq([self.imgs_gt[clip][center]],
                                       **mod_crop)
            # A one-element sequence was read; drop its leading axis in place.
            img_gt.squeeze_(0)

        sample = {}
        sample['lq'] = imgs_lq  # (t, c, h, w)
        sample['gt'] = img_gt  # (c, h, w)
        sample['folder'] = clip  # folder name
        sample['idx'] = idx_str  # e.g., 0/99
        sample['border'] = border  # 1 for border, 0 for non-border
        sample['lq_path'] = lq_path  # center frame
        return sample
Example #4
0
    def __getitem__(self, index):
        """Load the LQ frame sequence and GT image of sample ``index``.

        ``gt`` is returned exactly as ``util.read_img_seq`` produces it for a
        one-element list; no axis is squeezed here.

        Args:
            index: Position of the sample in the ``self.data_info`` lists.

        Returns:
            dict: ``lq``, ``gt``, ``folder``, ``idx``, ``border`` and
            ``lq_path`` (the middle entry of the sample's LQ path list).
        """
        info = self.data_info
        lq_path = info['lq_path'][index]
        gt_path = info['gt_path'][index]

        sample = {
            'lq': util.read_img_seq(lq_path),
            'gt': util.read_img_seq([gt_path]),
        }
        # Per-sample metadata is passed through unchanged.
        for key in ('folder', 'idx', 'border'):
            sample[key] = info[key][index]
        # Report the path of the middle frame of the window.
        middle = self.opt['num_frame'] // 2
        sample['lq_path'] = lq_path[middle]
        return sample
Example #5
0
    def __getitem__(self, index):
        """Load the LQ frame sequence and the single GT frame of a sample.

        Args:
            index: Position of the sample in the ``self.data_info`` lists.

        Returns:
            dict: ``lq``, ``gt``, ``folder``, ``idx``, ``border`` and
            ``lq_path`` (the middle entry of the sample's LQ path list).
        """
        info = self.data_info
        lq_path = info['lq_path'][index]
        gt_path = info['gt_path'][index]

        lq_frames = util.read_img_seq(lq_path)
        gt_frame = util.read_img_seq([gt_path])
        # A one-element sequence was read; drop its leading axis in place.
        gt_frame.squeeze_(0)

        sample = {
            'lq': lq_frames,  # (t, c, h, w)
            'gt': gt_frame,  # (c, h, w)
        }
        # folder name / e.g., 0/843 / 0 for non-border
        for key in ('folder', 'idx', 'border'):
            sample[key] = info[key][index]
        middle = self.opt['num_frame'] // 2
        sample['lq_path'] = lq_path[middle]  # center frame
        return sample
Example #6
0
    def __getitem__(self, index):
        """Load the LQ sequence, the GT frame and its companion ``hr_3d`` image.

        The ``hr_3d`` path is derived from the GT path by replacing
        ``'YTB-2/GT'`` with ``'3d'``, dropping the last four characters and
        appending ``'_hr.png'``.

        Args:
            index: Position of the sample in the ``self.data_info`` lists.

        Returns:
            dict: ``lq``, ``gt``, ``hr_3d``, ``folder``, ``idx``, ``border``
            and ``lq_path`` (the middle entry of the sample's LQ path list).
        """
        info = self.data_info
        lq_path = info['lq_path'][index]
        gt_path = info['gt_path'][index]

        lq_frames = util.read_img_seq(lq_path)
        gt_frame = util.read_img_seq([gt_path])

        # [:-4] presumably strips a 4-character extension such as ".png"
        # -- TODO confirm against the dataset layout.
        hr_3d_path = gt_path.replace('YTB-2/GT', '3d')[:-4] + '_hr.png'
        hr_3d = util.read_img_seq([hr_3d_path])

        # Only the GT frame loses its leading axis; hr_3d is returned exactly
        # as read_img_seq produced it.
        gt_frame.squeeze_(0)

        sample = {
            'lq': lq_frames,  # (t, c, h, w)
            'gt': gt_frame,  # (c, h, w)
            'hr_3d': hr_3d,
        }
        # folder name / e.g., 0/843 / 0 for non-border
        for key in ('folder', 'idx', 'border'):
            sample[key] = info[key][index]
        middle = self.opt['num_frame'] // 2
        sample['lq_path'] = lq_path[middle]  # center frame
        return sample
Example #7
0
    def __init__(self, opt):
        """Index (and optionally preload) the LQ/GT frames of every test clip.

        Clip folders come from ``opt['meta_info_file']`` when that option is
        present and non-empty (first whitespace-separated token of each
        non-blank line), otherwise from globbing the LQ and GT roots.

        Args:
            opt (dict): Dataset options. Keys read here: ``name``,
                ``cache_data``, ``dataroot_gt``, ``dataroot_lq``,
                ``io_backend``, ``num_frame`` and, optionally,
                ``meta_info_file``.

        Raises:
            AssertionError: If the io backend type is ``'lmdb'`` or a clip has
                a different number of LQ and GT frames.
            ValueError: If ``opt['name']`` is not a supported dataset name.
        """
        super(VideoTestDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        # Parallel lists: entry i of every list describes sample i.
        self.data_info = {
            'lq_path': [],
            'gt_path': [],
            'folder': [],
            'idx': [],
            'border': []
        }
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt[
            'type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        # Per clip: stacked frames when caching, otherwise lists of paths.
        self.imgs_lq, self.imgs_gt = {}, {}

        # The meta info file is optional: a missing key is treated like an
        # empty value instead of raising KeyError.
        meta_info_file = opt.get('meta_info_file')
        if meta_info_file:
            with open(meta_info_file, 'r') as fin:
                # Whitespace splitting drops the trailing newline of lines
                # that hold only a folder name; blank lines are skipped.
                subfolders = [
                    line.split()[0] for line in fin if line.strip()
                ]
            subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
            subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
        else:
            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
            subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))

        if opt['name'].lower() not in ['vid4', 'reds4', 'redsofficial']:
            # Report the offending name itself (not its type).
            raise ValueError(
                f'Non-supported video test dataset: {opt["name"]}')

        for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
            # get frame list for lq and gt
            subfolder_name = osp.basename(subfolder_lq)
            img_paths_lq = sorted(
                osp.join(subfolder_lq, v) for v in mmcv.scandir(subfolder_lq))
            img_paths_gt = sorted(
                osp.join(subfolder_gt, v) for v in mmcv.scandir(subfolder_gt))

            max_idx = len(img_paths_lq)
            assert max_idx == len(img_paths_gt), (
                f'Different number of images in lq ({max_idx})'
                f' and gt folders ({len(img_paths_gt)})')

            self.data_info['lq_path'].extend(img_paths_lq)
            self.data_info['gt_path'].extend(img_paths_gt)
            self.data_info['folder'].extend([subfolder_name] * max_idx)
            self.data_info['idx'].extend(
                f'{i}/{max_idx}' for i in range(max_idx))

            # The first and last num_frame // 2 frames of a clip are flagged
            # as border frames (1); all other frames get 0.
            border_l = [0] * max_idx
            for i in range(self.opt['num_frame'] // 2):
                border_l[i] = 1
                border_l[max_idx - i - 1] = 1
            self.data_info['border'].extend(border_l)

            # cache data or save the frame list
            if self.cache_data:
                logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
                self.imgs_lq[subfolder_name] = util.read_img_seq(img_paths_lq)
                self.imgs_gt[subfolder_name] = util.read_img_seq(img_paths_gt)
            else:
                self.imgs_lq[subfolder_name] = img_paths_lq
                self.imgs_gt[subfolder_name] = img_paths_gt
def main():
    """Evaluate a pretrained EDVR model on Vid4 or REDS4 and log PSNR.

    All settings (data mode, stage, paths, padding mode, ...) are hard-coded
    at the top of the function. For every clip folder the model is run on each
    frame with a temporal window of ``N_in`` frames, the outputs are optionally
    written as PNG files, and PSNR is accumulated separately for center and
    border frames.

    An average over zero frames is reported as 0 (the convention the border
    average already followed), so short clips without any center frame no
    longer raise ``ZeroDivisionError``. If no clip was evaluated at all, a
    warning is logged and the function returns without computing totals.

    Raises:
        ValueError: If stage 2 is requested for the ``Vid4`` data mode.
        NotImplementedError: If ``data_mode`` is not a known mode.
    """

    def mean_or_zero(total, count):
        # An empty set of frames averages to 0 instead of dividing by zero.
        return 0 if count == 0 else total / count

    #####
    # configurations
    #####
    device = torch.device('cuda')
    data_mode = 'Vid4'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False
    ###################
    # model
    model_dir = '../experiments/pretrained_models'
    # data_mode -> (stage-1 checkpoint, stage-2 checkpoint or None)
    checkpoints = {
        'Vid4': ('EDVR_Vimeo90K_SR_L.pth', None),
        'sharp_bicubic': ('EDVR_REDS_SR_L.pth', 'EDVR_REDS_SR_Stage2.pth'),
        'blur_bicubic': ('EDVR_REDS_SRblur_L.pth',
                         'EDVR_REDS_SRblur_Stage2.pth'),
        'blur': ('EDVR_REDS_deblur_L.pth', 'EDVR_REDS_deblur_Stage2.pth'),
        'blur_comp': ('EDVR_REDS_deblurcomp_L.pth',
                      'EDVR_REDS_deblurcomp_Stage2.pth'),
    }
    if data_mode not in checkpoints:
        raise NotImplementedError
    stage1_ckpt, stage2_ckpt = checkpoints[data_mode]
    if stage == 1:
        model_path = '{}/{}'.format(model_dir, stage1_ckpt)
    elif stage2_ckpt is None:
        raise ValueError('Vid4 does not support stage 2.')
    else:
        model_path = '{}/{}'.format(model_dir, stage2_ckpt)

    # use N_in images to restore one HR image
    N_in = 7 if data_mode == 'Vid4' else 5

    predeblur, hr_in = False, False
    num_reconstruct_block = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode in ('blur', 'blur_comp'):
        predeblur, hr_in = True, True
    if stage == 2:
        hr_in = True
        num_reconstruct_block = 20
    model = EDVR_arch.EDVR(128,
                           N_in,
                           8,
                           5,
                           num_reconstruct_block,
                           predeblur=predeblur,
                           hr_in=hr_in)

    # dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        gt_dataset_folder = '../datasets/Vid4/GT'
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        gt_dataset_folder = '../datasets/REDS4/GT'

    # evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode in ('Vid4', 'sharp_bicubic'):
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    # log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    # set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_gt_l = sorted(glob.glob(osp.join(gt_dataset_folder, '*')))
    # for each subfolder (LQ and GT folders are paired by sorted order)
    for subfolder, subfolder_gt in zip(subfolder_l, subfolder_gt_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        # read lq and gt images
        imgs_lq = data_util.read_img_seq(subfolder)
        img_gt_l = []
        for img_gt_path in sorted(glob.glob(osp.join(subfolder_gt, '*'))):
            img_gt_l.append(data_util.read_img(None, img_gt_path))

        # Running PSNR sums and frame counts for this clip.
        psnr_sum_center, psnr_sum_border = 0, 0
        N_border, N_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            select_idx = data_util.generate_frame_indices(img_idx,
                                                          max_idx,
                                                          N_in,
                                                          padding=padding)
            imgs_in = imgs_lq.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR
            output = output / 255.
            gt = np.copy(img_gt_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the
            # Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                gt = data_util.bgr2ycbcr(gt, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, gt = util.crop_border([output, gt], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, gt * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_psnr))

            if border_frame <= img_idx < max_idx - border_frame:
                # center frames
                psnr_sum_center += crt_psnr
                N_center += 1
            else:  # border frames
                psnr_sum_border += crt_psnr
                N_border += 1

        # Clips with at most 2 * border_frame frames have no center frame, so
        # every average is guarded against an empty count.
        n_frames = N_center + N_border
        avg_psnr = mean_or_zero(psnr_sum_center + psnr_sum_border, n_frames)
        avg_psnr_center = mean_or_zero(psnr_sum_center, N_center)
        avg_psnr_border = mean_or_zero(psnr_sum_border, N_border)
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, n_frames, avg_psnr_center,
                        N_center, avg_psnr_border, N_border))

    logger.info('#### Tidy Outputs ####')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center, psnr_border))
    logger.info('#### Final Results ####')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    n_clips = len(avg_psnr_l)  # clips actually evaluated (after zip pairing)
    if n_clips == 0:
        logger.warning('No clips were evaluated (checked {} and {}).'.format(
            test_dataset_folder, gt_dataset_folder))
        return
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / n_clips, n_clips,
                    sum(avg_psnr_center_l) / n_clips,
                    sum(avg_psnr_border_l) / n_clips))