def __getitem__(self, index):
    """Fetch one test sample: a temporal window of LQ frames plus the GT
    center frame, together with bookkeeping fields for evaluation.
    """
    info = self.data_info
    folder = info['folder'][index]
    frame_idx, num_frames = (int(v) for v in info['idx'][index].split('/'))
    border = info['border'][index]
    lq_path = info['lq_path'][index]

    # Neighbouring frame indices around the center frame, padded at clip
    # borders according to the configured padding mode.
    select_idx = util.generate_frame_indices(
        frame_idx, num_frames, self.opt['num_frame'],
        padding=self.opt['padding'])

    if self.cache_data:
        # Clips are pre-loaded as tensors; just gather the window.
        imgs_lq = self.imgs_lq[folder].index_select(
            0, torch.LongTensor(select_idx))
        img_gt = self.imgs_gt[folder][frame_idx]
    else:
        # Lazy path: read only the selected frames from disk.
        imgs_lq = util.read_img_seq(
            [self.imgs_lq[folder][i] for i in select_idx])
        img_gt = util.read_img_seq([self.imgs_gt[folder][frame_idx]])
        img_gt.squeeze_(0)  # (1, c, h, w) -> (c, h, w)

    return {
        'lq': imgs_lq,  # (t, c, h, w)
        'gt': img_gt,  # (c, h, w)
        'folder': folder,  # folder name
        'idx': info['idx'][index],  # e.g., 0/99
        'border': border,  # 1 for border, 0 for non-border
        'lq_path': lq_path  # center frame
    }
def __getitem__(self, index):
    """Return one test sample: an LQ frame window plus the GT center frame.

    Fix: the non-cached branch previously returned 'gt' as (1, c, h, w)
    (``read_img_seq`` stacks a leading frame dim) while the cached branch
    returned (c, h, w); squeeze the leading dim so both branches agree,
    matching the sibling ``__getitem__`` implementations in this file.
    """
    folder = self.data_info['folder'][index]
    idx, max_idx = self.data_info['idx'][index].split('/')
    idx, max_idx = int(idx), int(max_idx)
    border = self.data_info['border'][index]
    lq_path = self.data_info['lq_path'][index]

    # Neighbouring frame indices, padded at clip borders.
    select_idx = util.generate_frame_indices(
        idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

    if self.cache_data:
        imgs_lq = self.imgs_lq[folder].index_select(
            0, torch.LongTensor(select_idx))
        img_gt = self.imgs_gt[folder][idx]  # indexing drops t -> (c, h, w)
    else:
        img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
        imgs_lq = util.read_img_seq(img_paths_lq)
        img_gt = util.read_img_seq([self.imgs_gt[folder][idx]])
        img_gt.squeeze_(0)  # (1, c, h, w) -> (c, h, w)

    return {
        'lq': imgs_lq,  # (t, c, h, w)
        'gt': img_gt,  # (c, h, w)
        'folder': folder,  # folder name
        'idx': self.data_info['idx'][index],  # e.g., 0/99
        'border': border,  # 1 for border, 0 for non-border
        'lq_path': lq_path  # center frame
    }
def __getitem__(self, index):
    """Return one DUF test sample.

    When ``opt['use_duf_downsampling']`` is set, the LQ clip is synthesized
    by blur-downsampling the GT frames instead of reading a separate LQ
    folder; otherwise the stored LQ frames are used directly.
    """
    info = self.data_info
    folder = info['folder'][index]
    frame_idx, clip_len = (int(v) for v in info['idx'][index].split('/'))
    border = info['border'][index]
    lq_path = info['lq_path'][index]

    select_idx = util.generate_frame_indices(
        frame_idx, clip_len, self.opt['num_frame'],
        padding=self.opt['padding'])

    if self.cache_data:
        gather = torch.LongTensor(select_idx)
        if self.opt['use_duf_downsampling']:
            # read imgs_gt to generate low-resolution frames
            imgs_lq = duf_downsample(
                self.imgs_gt[folder].index_select(0, gather),
                kernel_size=13,
                scale=self.opt['scale'])
        else:
            imgs_lq = self.imgs_lq[folder].index_select(0, gather)
        img_gt = self.imgs_gt[folder][frame_idx]
    else:
        if self.opt['use_duf_downsampling']:
            # read imgs_gt to generate low-resolution frames
            gt_window = [self.imgs_gt[folder][i] for i in select_idx]
            imgs_lq = util.read_img_seq(
                gt_window, require_mod_crop=True, scale=self.opt['scale'])
            imgs_lq = duf_downsample(
                imgs_lq, kernel_size=13, scale=self.opt['scale'])
        else:
            imgs_lq = util.read_img_seq(
                [self.imgs_lq[folder][i] for i in select_idx])
        img_gt = util.read_img_seq(
            [self.imgs_gt[folder][frame_idx]],
            require_mod_crop=True,
            scale=self.opt['scale'])
        img_gt.squeeze_(0)  # (1, c, h, w) -> (c, h, w)

    return {
        'lq': imgs_lq,  # (t, c, h, w)
        'gt': img_gt,  # (c, h, w)
        'folder': folder,  # folder name
        'idx': info['idx'][index],  # e.g., 0/99
        'border': border,  # 1 for border, 0 for non-border
        'lq_path': lq_path  # center frame
    }
def __getitem__(self, index):
    """Return one test sample: the pre-listed LQ clip and its GT frame.

    Fix: add ``img_gt.squeeze_(0)`` — this block was identical to the
    sibling ``__getitem__`` below except for the missing squeeze, so 'gt'
    came back as (1, c, h, w) instead of (c, h, w) like every other
    dataset in this file (``read_img_seq`` stacks a leading frame dim).
    """
    lq_path = self.data_info['lq_path'][index]
    gt_path = self.data_info['gt_path'][index]
    imgs_lq = util.read_img_seq(lq_path)
    img_gt = util.read_img_seq([gt_path])
    img_gt.squeeze_(0)  # (1, c, h, w) -> (c, h, w)
    return {
        'lq': imgs_lq,  # (t, c, h, w)
        'gt': img_gt,  # (c, h, w)
        'folder': self.data_info['folder'][index],  # folder name
        'idx': self.data_info['idx'][index],
        'border': self.data_info['border'][index],
        'lq_path': lq_path[self.opt['num_frame'] // 2]  # center frame
    }
def __getitem__(self, index):
    """Return one test sample: the whole pre-listed LQ clip plus the GT
    frame, with the path of the temporal center frame for bookkeeping.
    """
    info = self.data_info
    clip_paths = info['lq_path'][index]
    gt_frame_path = info['gt_path'][index]

    imgs_lq = util.read_img_seq(clip_paths)
    img_gt = util.read_img_seq([gt_frame_path])
    img_gt.squeeze_(0)  # (1, c, h, w) -> (c, h, w)

    center = self.opt['num_frame'] // 2
    return {
        'lq': imgs_lq,  # (t, c, h, w)
        'gt': img_gt,  # (c, h, w)
        'folder': info['folder'][index],  # folder name
        'idx': info['idx'][index],  # e.g., 0/843
        'border': info['border'][index],  # 0 for non-border
        'lq_path': clip_paths[center]  # center frame
    }
def __getitem__(self, index):
    """Return an LQ clip, its GT frame, and the matching '3d' HR companion
    frame located by rewriting the GT path.
    """
    info = self.data_info
    clip_paths = info['lq_path'][index]
    gt_frame_path = info['gt_path'][index]

    imgs_lq = util.read_img_seq(clip_paths)
    img_gt = util.read_img_seq([gt_frame_path])
    # Companion path: swap the 'YTB-2/GT' segment for '3d', strip the
    # extension, and append the '_hr.png' suffix.
    hr_3d = util.read_img_seq(
        [gt_frame_path.replace('YTB-2/GT', '3d')[:-4] + '_hr.png'])
    # Only gt loses its leading frame dim; hr_3d stays (t, c, h, w).
    img_gt.squeeze_(0)

    return {
        'lq': imgs_lq,  # (t, c, h, w)
        'gt': img_gt,  # (c, h, w)
        'hr_3d': hr_3d,
        'folder': info['folder'][index],  # folder name
        'idx': info['idx'][index],  # e.g., 0/843
        'border': info['border'][index],  # 0 for non-border
        'lq_path': clip_paths[self.opt['num_frame'] // 2]  # center frame
    }
def __init__(self, opt):
    """Build the per-frame test index for a video test set.

    Args:
        opt (dict): Dataset configuration. Keys read here:
            cache_data (bool): if True, pre-load every clip via
                ``util.read_img_seq``; otherwise store only frame paths.
            dataroot_gt / dataroot_lq (str): roots of the GT / LQ clips.
            io_backend (dict): io backend options; 'lmdb' is rejected for
                validation/test.
            meta_info_file (str | None): optional file listing subfolder
                keys (first space-separated token per line); when falsy,
                subfolders are discovered by globbing the roots.
            num_frame (int): temporal window size; used to flag border
                frames.
            name (str): dataset name; must be one of vid4/reds4/redsofficial.

    Raises:
        ValueError: for an unsupported dataset name.
    """
    super(VideoTestDataset, self).__init__()
    self.opt = opt
    self.cache_data = opt['cache_data']
    self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
    # Flat per-frame bookkeeping lists; entry i describes test sample i.
    self.data_info = {
        'lq_path': [],
        'gt_path': [],
        'folder': [],
        'idx': [],
        'border': []
    }
    # file client (io backend)
    self.file_client = None
    self.io_backend_opt = opt['io_backend']
    assert self.io_backend_opt[
        'type'] != 'lmdb', 'No need to use lmdb during validation/test.'

    logger = get_root_logger()
    logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
    self.imgs_lq, self.imgs_gt = {}, {}
    if opt['meta_info_file']:
        # Restrict to the subfolders named in the meta-info file.
        with open(opt['meta_info_file'], 'r') as fin:
            subfolders = [line.split(' ')[0] for line in fin]
            subfolders_lq = [
                osp.join(self.lq_root, key) for key in subfolders
            ]
            subfolders_gt = [
                osp.join(self.gt_root, key) for key in subfolders
            ]
    else:
        # Discover subfolders on disk; sorted so lq/gt zip in lockstep.
        subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
        subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))

    if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
        for subfolder_lq, subfolder_gt in zip(subfolders_lq,
                                              subfolders_gt):
            # get frame list for lq and gt
            subfolder_name = osp.basename(subfolder_lq)
            img_paths_lq = sorted([
                osp.join(subfolder_lq, v)
                for v in mmcv.scandir(subfolder_lq)
            ])
            img_paths_gt = sorted([
                osp.join(subfolder_gt, v)
                for v in mmcv.scandir(subfolder_gt)
            ])
            max_idx = len(img_paths_lq)
            assert max_idx == len(img_paths_gt), (
                f'Different number of images in lq ({max_idx})'
                f' and gt folders ({len(img_paths_gt)})')
            self.data_info['lq_path'].extend(img_paths_lq)
            self.data_info['gt_path'].extend(img_paths_gt)
            self.data_info['folder'].extend([subfolder_name] * max_idx)
            for i in range(max_idx):
                self.data_info['idx'].append(f'{i}/{max_idx}')
            # The first/last num_frame//2 frames cannot sit at the center
            # of a full temporal window -> flag them as border frames.
            border_l = [0] * max_idx
            for i in range(self.opt['num_frame'] // 2):
                border_l[i] = 1
                border_l[max_idx - i - 1] = 1
            self.data_info['border'].extend(border_l)

            # cache data or save the frame list
            if self.cache_data:
                logger.info(
                    f'Cache {subfolder_name} for VideoTestDataset...')
                self.imgs_lq[subfolder_name] = util.read_img_seq(
                    img_paths_lq)
                self.imgs_gt[subfolder_name] = util.read_img_seq(
                    img_paths_gt)
            else:
                self.imgs_lq[subfolder_name] = img_paths_lq
                self.imgs_gt[subfolder_name] = img_paths_gt
    else:
        raise ValueError(
            f'Non-supported video test dataset: {type(opt["name"])}')
def main():
    """Evaluate a pretrained EDVR model on Vid4 or REDS4 test clips.

    Restores every frame of every subfolder with a sliding temporal window,
    optionally saves the outputs, and logs per-frame / per-folder / total
    PSNR, reported separately for center and border frames.
    """
    #################
    # configurations
    #################
    device = torch.device('cuda')
    data_mode = 'Vid4'
    # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    # blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False

    # model: pick the pretrained checkpoint matching data_mode and stage
    if data_mode == 'Vid4':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_Vimeo90K_SR_L.pth'  # noqa: E501
        else:
            raise ValueError('Vid4 does not support stage 2.')
    elif data_mode == 'sharp_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_L.pth'
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SR_Stage2.pth'  # noqa: E501
    elif data_mode == 'blur_bicubic':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_L.pth'  # noqa: E501
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_SRblur_Stage2.pth'  # noqa: E501
    elif data_mode == 'blur':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_L.pth'  # noqa: E501
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblur_Stage2.pth'  # noqa: E501
    elif data_mode == 'blur_comp':
        if stage == 1:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_L.pth'  # noqa: E501
        else:
            model_path = '../experiments/pretrained_models/EDVR_REDS_deblurcomp_Stage2.pth'  # noqa: E501
    else:
        raise NotImplementedError

    if data_mode == 'Vid4':
        N_in = 7  # use N_in images to restore one HR image
    else:
        N_in = 5

    # Architecture knobs depend on the task variant.
    predeblur, hr_in = False, False
    num_reconstruct_block = 40
    if data_mode == 'blur_bicubic':
        predeblur = True
    if data_mode == 'blur' or data_mode == 'blur_comp':
        predeblur, hr_in = True, True
    if stage == 2:
        hr_in = True
        num_reconstruct_block = 20
    model = EDVR_arch.EDVR(128, N_in, 8, 5, num_reconstruct_block,
                           predeblur=predeblur, hr_in=hr_in)

    # dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
        gt_dataset_folder = '../datasets/Vid4/GT'
    else:
        if stage == 1:
            test_dataset_folder = '../datasets/REDS4/{}'.format(data_mode)
        else:
            # stage 2 consumes stage-1 outputs
            test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
            print('You should modify the test_dataset_folder path for stage 2')
        gt_dataset_folder = '../datasets/REDS4/GT'

    # evaluation
    crop_border = 0
    border_frame = N_in // 2  # border frames when evaluate
    # temporal padding mode
    if data_mode == 'Vid4' or data_mode == 'sharp_bicubic':
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True
    save_folder = '../results/{}'.format(data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base', save_folder, 'test', level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')

    # log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    # set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
    subfolder_name_l = []

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))
    subfolder_gt_l = sorted(glob.glob(osp.join(gt_dataset_folder, '*')))
    # for each subfolder
    for subfolder, subfolder_gt in zip(subfolder_l, subfolder_gt_l):
        subfolder_name = osp.basename(subfolder)
        subfolder_name_l.append(subfolder_name)
        save_subfolder = osp.join(save_folder, subfolder_name)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        # read lq and gt images
        imgs_lq = data_util.read_img_seq(subfolder)
        img_gt_l = []
        for img_gt_path in sorted(glob.glob(osp.join(subfolder_gt, '*'))):
            img_gt_l.append(data_util.read_img(None, img_gt_path))

        avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0
        N_border, N_center = 0, 0

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            img_name = osp.splitext(osp.basename(img_path))[0]
            # sliding temporal window around the current frame
            select_idx = data_util.generate_frame_indices(img_idx, max_idx,
                                                          N_in,
                                                          padding=padding)
            imgs_in = imgs_lq.index_select(
                0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)

            if flip_test:
                output = util.flipx4_forward(model, imgs_in)
            else:
                output = util.single_forward(model, imgs_in)
            output = util.tensor2img(output.squeeze(0))

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)),
                    output)

            # calculate PSNR
            output = output / 255.
            gt = np.copy(img_gt_l[img_idx])
            # For REDS, evaluate on RGB channels; for Vid4, evaluate on the
            # Y channel
            if data_mode == 'Vid4':  # bgr2y, [0, 1]
                gt = data_util.bgr2ycbcr(gt, only_y=True)
                output = data_util.bgr2ycbcr(output, only_y=True)

            output, gt = util.crop_border([output, gt], crop_border)
            crt_psnr = util.calculate_psnr(output * 255, gt * 255)
            logger.info('{:3d} - {:25} \tPSNR: {:.6f} dB'.format(
                img_idx + 1, img_name, crt_psnr))

            if img_idx >= border_frame and img_idx < max_idx - border_frame:
                # center frames
                avg_psnr_center += crt_psnr
                N_center += 1
            else:  # border frames
                avg_psnr_border += crt_psnr
                N_border += 1

        # per-folder averages (border average guarded against N_border == 0)
        avg_psnr = (avg_psnr_center + avg_psnr_border) / (
            N_center + N_border)
        avg_psnr_center = avg_psnr_center / N_center
        avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border
        avg_psnr_l.append(avg_psnr)
        avg_psnr_center_l.append(avg_psnr_center)
        avg_psnr_border_l.append(avg_psnr_border)

        logger.info('Folder {} - Average PSNR: {:.6f} dB for {} frames; '
                    'Center PSNR: {:.6f} dB for {} frames; '
                    'Border PSNR: {:.6f} dB for {} frames.'.format(
                        subfolder_name, avg_psnr, (N_center + N_border),
                        avg_psnr_center, N_center, avg_psnr_border,
                        N_border))

    logger.info('#### Tidy Outputs ####')
    for subfolder_name, psnr, psnr_center, psnr_border in zip(
            subfolder_name_l, avg_psnr_l, avg_psnr_center_l,
            avg_psnr_border_l):
        logger.info('Folder {} - Average PSNR: {:.6f} dB. '
                    'Center PSNR: {:.6f} dB. '
                    'Border PSNR: {:.6f} dB.'.format(subfolder_name, psnr,
                                                     psnr_center,
                                                     psnr_border))
    logger.info('#### Final Results ####')
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
    logger.info('Total Average PSNR: {:.6f} dB for {} clips. '
                'Center PSNR: {:.6f} dB. Border PSNR: {:.6f} dB.'.format(
                    sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l),
                    sum(avg_psnr_center_l) / len(avg_psnr_center_l),
                    sum(avg_psnr_border_l) / len(avg_psnr_border_l)))