class SingleImageDataset(data.Dataset):
    """Read only lq images in the test phase.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).

    There are three modes:
    1. 'lmdb': Use lmdb files, if opt['io_backend']['type'] == 'lmdb'.
    2. 'meta_info_file': Use meta information file to generate paths. Only
        the first space-separated field of each non-blank line is used.
    3. 'folder': Scan folders to generate paths (sorted for a deterministic
        order).

    Args:
        opt (dict): Config for test datasets. It contains the following keys:
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str, optional): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple, optional): Image mean for normalization.
            std (list | tuple, optional): Image std for normalization.
    """

    def __init__(self, opt):
        super(SingleImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt.get('mean')
        self.std = opt.get('std')
        self.lq_folder = opt['dataroot_lq']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder]
            self.io_backend_opt['client_keys'] = ['lq']
            self.paths = single_paths_from_lmdb([self.lq_folder], ['lq'])
        elif self.opt.get('meta_info_file') is not None:
            with open(self.opt['meta_info_file'], 'r') as fin:
                # rstrip() keeps the trailing newline out of the path when a
                # line has no extra fields; blank lines are skipped.
                self.paths = [
                    osp.join(self.lq_folder, line.rstrip().split(' ')[0])
                    for line in fin
                    if line.strip()
                ]
        else:
            self.paths = [
                osp.join(self.lq_folder, v)
                for v in sorted(mmcv.scandir(self.lq_folder))
            ]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # The meta-info and folder branches store plain path strings. The lmdb
        # helper's return type is not visible here, so dict entries carrying
        # an 'lq_path' key are accepted as well.
        entry = self.paths[index]
        lq_path = entry['lq_path'] if isinstance(entry, dict) else entry
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255.

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lq = totensor(img_lq, bgr2rgb=True, float32=True)
        # normalize (only when at least one of mean/std is configured)
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
        return {'lq': img_lq, 'lq_path': lq_path}

    def __len__(self):
        return len(self.paths)
class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN.

    NOTE: a second class with the same name is defined later in this module
    and replaces this one at import time.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt (must end with '.lmdb'
                when the lmdb backend is used).
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
    """

    def __init__(self, opt):
        super(FFHQDataset, self).__init__()
        self.opt = opt
        self.file_client = None  # built lazily on the first __getitem__ call
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']
        self.mean, self.std = opt['mean'], opt['std']
        self.paths = self._list_gt_paths()

    def _list_gt_paths(self):
        """Return lmdb keys (lmdb backend) or png file paths (otherwise)."""
        root = self.gt_folder
        if self.io_backend_opt['type'] != 'lmdb':
            # FFHQ has 70000 images in total
            return [osp.join(root, f'{idx:08d}.png') for idx in range(70000)]

        self.io_backend_opt['db_paths'] = root
        if not root.endswith('.lmdb'):
            raise ValueError("'dataroot_gt' should end with '.lmdb', "
                             f'but received {root}')
        with open(osp.join(root, 'meta_info.txt')) as meta:
            # keys are the file names without their extension
            return [entry.split('.')[0] for entry in meta]

    def __getitem__(self, index):
        if self.file_client is None:
            backend_type = self.io_backend_opt.pop('type')
            self.file_client = FileClient(backend_type, **self.io_backend_opt)

        gt_path = self.paths[index]
        img = imfrombytes(self.file_client.get(gt_path), float32=True)
        # random horizontal flip only; no rotation
        img = augment(img, hflip=self.opt['use_hflip'], rotation=False)
        # BGR -> RGB, HWC -> CHW, numpy -> tensor, then normalize in place
        img = img2tensor(img, bgr2rgb=True, float32=True)
        normalize(img, self.mean, self.std, inplace=True)
        return {'gt': img, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
class REDSDataset(data.Dataset):
    """REDS dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow. May be
                omitted or None to disable flow loading.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or
                'official'.
            io_backend (dict): IO backend type and other kwarg.

            num_frame (int): Window size for input frames. Must be odd.
            gt_size (int): Cropped patched size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (int): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(
            opt['dataroot_lq'])
        # 'dataroot_flow' is optional: a missing key behaves like None.
        flow_root = opt.get('dataroot_flow')
        self.flow_root = Path(flow_root) if flow_root is not None else None
        assert opt['num_frame'] % 2 == 1, (
            f'num_frame should be odd number, but got {opt["num_frame"]}')
        self.num_frame = opt['num_frame']
        self.num_half_frames = opt['num_frame'] // 2

        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                self.keys.extend(
                    [f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # remove the video clips used in validation
        if opt['val_partition'] == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif opt['val_partition'] == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(
                f'Wrong validation partition {opt["val_partition"]}.'
                f"Supported ones are ['official', 'REDS4'].")
        self.keys = [
            v for v in self.keys if v.split('/')[0] not in val_partition
        ]

        # file client (io backend); created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            if self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [
                    self.lq_root, self.gt_root, self.flow_root
                ]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        interval_str = ','.join(str(x) for x in opt['interval_list'])
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def _read_flow(self, clip_name, frame_name, suffix):
        """Read one quantized flow map of `frame_name` and dequantize it.

        `suffix` is 'p{i}' (previous) or 'n{i}' (next), matching the flow
        file / lmdb key naming.
        """
        if self.is_lmdb:
            flow_path = f'{clip_name}/{frame_name}_{suffix}'
        else:
            flow_path = (self.flow_root / clip_name / f'{frame_name}_{suffix}.png')
        img_bytes = self.file_client.get(flow_path, 'flow')
        cat_flow = imfrombytes(
            img_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
        # dx and dy are stacked along axis 0 of a single grayscale image
        dx, dy = np.split(cat_flow, 2, axis=0)
        # we use max_val 20 here.
        return dequantize_flow(dx, dy, max_val=20, denorm=False)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000
        center_frame_idx = int(frame_name)

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = center_frame_idx - self.num_half_frames * interval
        end_frame_idx = center_frame_idx + self.num_half_frames * interval
        # each clip has 100 frames starting from 0 to 99
        # NOTE: when the window is resampled here, the returned `key` still
        # names the original frame, not the center frame actually used.
        while (start_frame_idx < 0) or (end_frame_idx > 99):
            center_frame_idx = random.randint(0, 99)
            start_frame_idx = (
                center_frame_idx - self.num_half_frames * interval)
            end_frame_idx = center_frame_idx + self.num_half_frames * interval
        frame_name = f'{center_frame_idx:08d}'
        neighbor_list = list(
            range(center_frame_idx - self.num_half_frames * interval,
                  center_frame_idx + self.num_half_frames * interval + 1,
                  interval))
        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == self.num_frame, (
            f'Wrong length of neighbor list: {len(neighbor_list)}')

        # get the GT frame (as the center frame)
        if self.is_lmdb:
            img_gt_path = f'{clip_name}/{frame_name}'
        else:
            img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # get flows
        if self.flow_root is not None:
            # previous flows p{num_half_frames} ... p1, then next flows
            # n1 ... n{num_half_frames}
            img_flows = [
                self._read_flow(clip_name, frame_name, f'p{i}')
                for i in range(self.num_half_frames, 0, -1)
            ]
            img_flows.extend(
                self._read_flow(clip_name, frame_name, f'n{i}')
                for i in range(1, self.num_half_frames + 1))

            # for random crop, here, img_flows and img_lqs have the same
            # spatial size
            img_lqs.extend(img_flows)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale,
                                             img_gt_path)
        if self.flow_root is not None:
            img_lqs, img_flows = (img_lqs[:self.num_frame],
                                  img_lqs[self.num_frame:])

        # augmentation - flip, rotate
        img_lqs.append(img_gt)
        if self.flow_root is not None:
            img_results, img_flows = augment(img_lqs, self.opt['use_flip'],
                                             self.opt['use_rot'], img_flows)
        else:
            img_results = augment(img_lqs, self.opt['use_flip'],
                                  self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        if self.flow_root is not None:
            img_flows = img2tensor(img_flows)
            # add the zero center flow
            img_flows.insert(self.num_half_frames,
                             torch.zeros_like(img_flows[0]))
            img_flows = torch.stack(img_flows, dim=0)

        # img_lqs: (t, c, h, w)
        # img_flows: (t, 2, h, w)
        # img_gt: (c, h, w)
        # key: str
        if self.flow_root is not None:
            return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
        else:
            return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
class REDSRecurrentDataset(data.Dataset):
    """REDS recurrent dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth, length is equal to LQ;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    In training mode every frame of the non-validation clips is a key. In
    non-training mode only the validation clips are used, with one key every
    `num_frame` frames starting at `num_frame // 2`.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or
                'official'.
            io_backend (dict): IO backend type and other kwarg.
            is_train (bool, optional): Training mode. Default: False.

            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patched size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (int): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSRecurrentDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(
            opt['dataroot_lq'])
        self.is_train = opt.get('is_train', False)
        self.num_frame = opt['num_frame']
        self.num_half_frames = opt['num_frame'] // 2

        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                frame_num = int(frame_num)
                if self.is_train:
                    # every frame of the clip is a sample
                    self.keys.extend(
                        [f'{folder}/{i:08d}' for i in range(frame_num)])
                else:
                    assert frame_num % opt['num_frame'] == 0, (
                        f'frame_num of "{folder}" is not divided by '
                        f'{opt["num_frame"]}')
                    # one key every num_frame frames, bounded by the clip's
                    # own frame count from the meta file
                    self.keys.extend([
                        f'{folder}/{i:08d}' for i in range(
                            self.num_half_frames, frame_num, self.num_frame)
                    ])

        # remove the video clips used in validation
        if opt['val_partition'] == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif opt['val_partition'] == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(
                f'Wrong validation partition {opt["val_partition"]}.'
                f"Supported ones are ['official', 'REDS4'].")
        if self.is_train:
            self.keys = [
                v for v in self.keys if v.split('/')[0] not in val_partition
            ]
        else:
            self.keys = [
                v for v in self.keys if v.split('/')[0] in val_partition
            ]

        # file client (io backend); created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        interval_str = ','.join(str(x) for x in opt['interval_list'])
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt.get('gt_size')
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000
        center_frame_idx = int(frame_name)

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = center_frame_idx - self.num_half_frames * interval
        end_frame_idx = start_frame_idx + (self.num_frame - 1) * interval
        # each clip has 100 frames starting from 0 to 99
        while (start_frame_idx < 0) or (end_frame_idx > 99):
            center_frame_idx = random.randint(
                self.num_half_frames * interval,
                99 - self.num_half_frames * interval)
            start_frame_idx = (
                center_frame_idx - self.num_half_frames * interval)
            end_frame_idx = start_frame_idx + (self.num_frame - 1) * interval
        frame_list = list(range(start_frame_idx, end_frame_idx + 1, interval))

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            frame_list.reverse()

        assert len(frame_list) == self.num_frame, (
            f'Wrong length of frame list: {len(frame_list)}')

        # get the GT frames (one per entry of frame_list)
        img_gts = []
        for frame in frame_list:
            if self.is_lmdb:
                img_gt_path = f'{clip_name}/{frame:08d}'
            else:
                img_gt_path = self.gt_root / clip_name / f'{frame:08d}.png'
            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)
            img_gts.append(img_gt)

        # get the LQ frames
        img_lqs = []
        for frame in frame_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{frame:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{frame:08d}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # randomly crop (training only)
        if self.is_train:
            img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size,
                                                  scale, clip_name)

        # augmentation - flip, rotate (training only); LQs first, then GTs
        img_lqs.extend(img_gts)
        if self.is_train:
            img_lqs = augment(img_lqs, self.opt['use_flip'],
                              self.opt['use_rot'])

        img_results = img2tensor(img_lqs)
        img_lqs = torch.stack(img_results[:self.num_frame], dim=0)
        img_gts = torch.stack(img_results[self.num_frame:], dim=0)

        # img_lqs: (t, c, h, w)
        # img_gt: (t, c, h, w)
        # key: str
        return {
            'lq': img_lqs,
            'gt': img_gts,
            'key': key,
            'frame_list': frame_list
        }

    def __len__(self):
        return len(self.keys)
class PairedImageWTDataset(data.Dataset):
    """Paired dataset of wavelet-transformed inputs and GT images.

    Reads WT (Wavelet Transformed, 3 times in total from 256x256) inputs and
    GT image pairs. The LQ (WT) file is a raw float32 buffer that is reshaped
    to (gt_h // scale, gt_w // scale, C), with C inferred from its size.

    There are three modes:
    1. 'lmdb': Use lmdb files.
        If opt['io_backend'] == lmdb.
    2. 'meta_info_file': Use meta information file to generate paths.
        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. 'folder': Scan folders to generate paths.
        The rest.

    No augmentation (random crop / flip / rotation) is applied, so `gt_size`,
    `use_flip`, `use_rot` and `phase` are currently not read here.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the
                template excludes the file extension. Default: '{}'.
            gt_size (int): Cropped patched size for gt patches.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (int): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(PairedImageWTDataset, self).__init__()
        self.opt = opt
        self.file_client = None  # built lazily on the first __getitem__ call
        self.io_backend_opt = opt['io_backend']
        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        meta_info_file = self.opt.get('meta_info_file')
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb(
                [self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif meta_info_file is not None:
            self.paths = paired_paths_from_meta_info_file(
                [self.lq_folder, self.gt_folder], ['lq', 'gt'],
                meta_info_file, self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder(
                [self.lq_folder, self.gt_folder], ['lq', 'gt'],
                self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            backend = self.io_backend_opt.pop('type')
            self.file_client = FileClient(backend, **self.io_backend_opt)

        scale = self.opt['scale']
        record = self.paths[index]

        # GT: decoded image, HWC, BGR, float32 in [0, 1].
        gt_path = record['gt_path']
        gt_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = mmcv.imfrombytes(gt_bytes).astype(np.float32) / 255.
        gt_h, gt_w = img_gt.shape[0], img_gt.shape[1]

        # LQ: raw float32 buffer; copied so the array is writable.
        lq_path = record['lq_path']
        lq_bytes = self.file_client.get(lq_path, 'lq')
        lq_flat = np.copy(np.frombuffer(lq_bytes, dtype='float32'))
        img_lq = lq_flat.reshape(gt_h // scale, gt_w // scale, -1)

        # TODO: color space transform
        # BGR -> RGB, HWC -> CHW, numpy -> tensor
        img_gt, img_lq = totensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        return {
            'lq': img_lq,
            'gt': img_gt,
            'lq_path': lq_path,
            'gt_path': gt_path
        }

    def __len__(self):
        return len(self.paths)
class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN.

    Reading a gt image is retried up to 3 times; after a failed read a random
    other image is picked (after a 1 s pause). If every attempt fails, the
    last file client error is re-raised.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt (must end with '.lmdb'
                when the lmdb backend is used).
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
    """

    def __init__(self, opt):
        super(FFHQDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(
                    f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # FFHQ has 70000 images in total
            self.paths = [
                osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)
            ]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        max_retry = 3
        for attempt in range(max_retry):
            try:
                img_bytes = self.file_client.get(gt_path)
            except Exception as e:
                remaining = max_retry - attempt - 1
                logger = get_root_logger()
                logger.warning(
                    f'File client error: {e}, remaining retry times: {remaining}')
                if remaining == 0:
                    # out of retries: surface the error instead of failing
                    # later on an unbound `img_bytes`
                    raise
                # change another file to read; randint includes both ends
                index = random.randint(0, len(self) - 1)
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break

        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        return {'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
class Vimeo90KDataset(data.Dataset):
    """Vimeo90K dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt

    Each line contains:
    1. clip name; 2. frame number; 3. image shape, separated by a white space.
    Examples:
        00001/0001 7 (256,448,3)
        00001/0002 7 (256,448,3)

    Key examples: "00001/0001"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    The neighboring frame list for different num_frame:
    num_frame | frame list
             1 | 4
             3 | 3,4,5
             5 | 2,3,4,5,6
             7 | 1,2,3,4,5,6,7

    Besides the frames, a pre-computed flow array is loaded from the GT clip
    folder ('flow_160.npy', falling back to 'flow.npy') and returned under
    the 'flow' key.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.

            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patched size for gt patches.
            random_reverse (bool): Currently ignored; reversal is disabled.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (int): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(Vimeo90KDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(
            opt['dataroot_lq'])

        with open(opt['meta_info_file'], 'r') as fin:
            self.keys = [line.split(' ')[0] for line in fin]

        # file client (io backend); created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # indices of input images, centered on im4
        self.neighbor_list = [
            i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])
        ]

        # temporal augmentation configs
        # NOTE(review): opt['random_reverse'] is ignored and reversal is
        # forced off; presumably because the pre-computed flow is tied to the
        # forward frame order. Confirm before re-enabling.
        self.random_reverse = False
        logger = get_root_logger()
        logger.info(f'Random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse: use a local copy so the instance's neighbor_list is
        # never mutated across calls
        neighbor_list = self.neighbor_list
        if self.random_reverse and random.random() < 0.5:
            neighbor_list = neighbor_list[::-1]

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the GT frame (im4.png)
        if self.is_lmdb:
            img_gt_path = f'{key}/im4'
        else:
            img_gt_path = self.gt_root / clip / seq / 'im4.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255.

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255.
            img_lqs.append(img_lq)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale,
                                             img_gt_path)

        # augmentation - flip, rotate
        img_lqs.append(img_gt)
        img_results = augment(img_lqs, self.opt['use_flip'],
                              self.opt['use_rot'])

        img_results = totensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        # img_lqs: (t, c, h, w)
        # img_gt: (c, h, w)
        # key: str

        # Pre-computed flow stored next to the GT frame: [:-7] strips the
        # trailing 'im4.png' of the disk path.
        # NOTE(review): for the lmdb backend img_gt_path is '<key>/im4', so
        # this slicing does not yield the clip folder. Flow loading appears
        # to assume the disk backend; confirm.
        flow_dir = f'{img_gt_path}'[:-7]
        # NOTE(review): allow_pickle=True unpickles file content; only use
        # with trusted flow files.
        try:
            flow = np.load(flow_dir + 'flow_160.npy', allow_pickle=True)
        except OSError:
            # flow_160.npy is missing/unreadable: fall back to flow.npy
            flow = np.load(flow_dir + 'flow.npy', allow_pickle=True)
        # NOTE(review): the reason for halving the flow is not visible here.
        flow = flow / 2.0
        # move the last axis to position 1 (the original note gives the
        # result as 7 x 2 x 448 x 448)
        flow = np.transpose(flow, [0, 3, 1, 2])
        # NOTE(review): the flow is neither cropped nor flipped/rotated
        # together with the frames above; confirm consumers expect that.

        return {'lq': img_lqs, 'gt': img_gt, 'key': key, 'flow': flow}

    def __len__(self):
        return len(self.keys)
class RealESRGANDataset(data.Dataset):
    """Dataset used for Real-ESRGAN model.

    It loads GT images (flip/rotate augmented, then padded or randomly cropped
    to 400x400) and samples, per item, the blur kernels for the first and
    second degradation plus a final sinc kernel. The degradations themselves
    are not applied here; only the kernels are returned.

    Reading a gt image is retried up to 3 times; after a failed read a random
    other image is picked (after a 1 s pause). If every attempt fails, the
    last file client error is re-raised.

    Args:
        opt (dict): Config containing, among others:
            dataroot_gt (str): Data root path for gt (must end with '.lmdb'
                when the lmdb backend is used).
            meta_info (str): Meta file listing gt images (non-lmdb backends).
            io_backend (dict): IO backend type and other kwarg.
            use_hflip (bool), use_rot (bool): Augmentation switches.
            blur_kernel_size, kernel_list, kernel_prob, blur_sigma,
            betag_range, betap_range, sinc_prob: First degradation kernels.
            The same keys suffixed with '2': Second degradation kernels.
            final_sinc_prob (float): Probability of a final sinc kernel.
    """

    def __init__(self, opt):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.gt_folder]
            self.io_backend_opt['client_keys'] = ['gt']
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(
                    f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip() for line in fin]
            self.paths = [os.path.join(self.gt_folder, v) for v in paths]

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']
        self.betap_range = opt['betap_range']
        self.sinc_prob = opt['sinc_prob']

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor = torch.zeros(21, 21).float()
        self.pulse_tensor[10, 10] = 1

    def _sample_blur_kernel(self, sinc_prob, kernel_list, kernel_prob,
                            blur_sigma, betag_range, betap_range):
        """Sample one blur kernel and zero-pad it to 21x21.

        With probability `sinc_prob` a circular low-pass (sinc) kernel is
        drawn, otherwise one from `random_mixed_kernels`.
        """
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                kernel_list,
                kernel_prob,
                kernel_size,
                blur_sigma,
                blur_sigma, [-math.pi, math.pi],
                betag_range,
                betap_range,
                noise_range=None)
        # pad kernel to 21x21
        pad_size = (21 - kernel_size) // 2
        return np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        max_retry = 3
        for attempt in range(max_retry):
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except Exception as e:
                remaining = max_retry - attempt - 1
                logger = get_root_logger()
                logger.warning(
                    f'File client error: {e}, remaining retry times: {remaining}')
                if remaining == 0:
                    # out of retries: surface the error instead of failing
                    # later on an unbound `img_bytes`
                    raise
                # change another file to read; randint includes both ends
                index = random.randint(0, len(self) - 1)
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
        img_gt = imfrombytes(img_bytes, float32=True)

        # -------------------- augmentation for training: flip, rotation -------------------- #
        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

        # crop or pad to 400: 400 is hard-coded. You may change it accordingly
        h, w = img_gt.shape[0:2]
        crop_pad_size = 400
        # pad
        if h < crop_pad_size or w < crop_pad_size:
            pad_h = max(0, crop_pad_size - h)
            pad_w = max(0, crop_pad_size - w)
            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w,
                                        cv2.BORDER_REFLECT_101)
        # crop
        if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
            h, w = img_gt.shape[0:2]
            # randomly choose top and left coordinates
            top = random.randint(0, h - crop_pad_size)
            left = random.randint(0, w - crop_pad_size)
            img_gt = img_gt[top:top + crop_pad_size,
                            left:left + crop_pad_size, ...]

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel = self._sample_blur_kernel(
            self.opt['sinc_prob'], self.kernel_list, self.kernel_prob,
            self.blur_sigma, self.betag_range, self.betap_range)

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel2 = self._sample_blur_kernel(
            self.opt['sinc_prob2'], self.kernel_list2, self.kernel_prob2,
            self.blur_sigma2, self.betag_range2, self.betap_range2)

        # ------------------------------------- sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return_d = {
            'gt': img_gt,
            'kernel1': kernel,
            'kernel2': kernel2,
            'sinc_kernel': sinc_kernel,
            'gt_path': gt_path
        }
        return return_d

    def __len__(self):
        return len(self.paths)
class FFHQDegradationDataset(data.Dataset):
    """FFHQ dataset for GFPGAN.

    It reads high resolution images, and then generates low-quality (LQ) images on-the-fly
    (blur -> downsample -> optional noise -> optional JPEG -> resize back, plus optional
    color jitter / gray conversion).

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
            out_size (int): Image width used to mirror component coordinates on hflip.
            crop_components (bool): Whether to return facial component locations.
                Default: False.
            component_path (str): Path of the pre-processed component pth file
                (read only when crop_components is True).
            eye_enlarge_ratio (float): Scale factor for the eye half-length. Default: 1.
            blur_kernel_size, kernel_list, kernel_prob, blur_sigma: Blur settings.
            downsample_range (list): [min, max] downsample scale.
            noise_range (list | None): Gaussian noise range; None disables noise.
            jpeg_range (list | None): JPEG quality range; None disables JPEG.
            color_jitter_prob, color_jitter_pt_prob, color_jitter_shift, gray_prob,
            gt_gray, brightness, contrast, saturation, hue: Optional color options.
            Please see more options in the codes.
    """

    def __init__(self, opt):
        super(FFHQDegradationDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']
        self.out_size = opt['out_size']

        self.crop_components = opt.get('crop_components', False)  # facial components
        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # whether enlarge eye regions

        if self.crop_components:
            # load component list from a pre-processed pth file
            # NOTE(review): torch.load unpickles the file; only load trusted component files.
            self.components_list = torch.load(opt.get('component_path'))

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend: scan file list from a folder
            self.paths = paths_from_folder(self.gt_folder)

        # degradation configurations
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']
        self.blur_sigma = opt['blur_sigma']
        self.downsample_range = opt['downsample_range']
        self.noise_range = opt['noise_range']
        self.jpeg_range = opt['jpeg_range']

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob')
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        # to gray
        self.gray_prob = opt.get('gray_prob')

        logger = get_root_logger()
        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
        logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
        # Noise and JPEG are optional (None disables them in __getitem__), so only log
        # their ranges when configured; joining None would raise TypeError.
        if self.noise_range is not None:
            logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
        if self.jpeg_range is not None:
            logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
        if self.gray_prob is not None:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')

        # the shift is configured on a [0, 255] scale; images are in [0, 1]
        self.color_jitter_shift /= 255.

    @staticmethod
    def color_jitter(img, shift):
        """jitter color: randomly jitter the RGB values, in numpy formats.

        Adds one uniform offset in [-shift, shift] per channel and clips to [0, 1].
        """
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats.

        The four adjustments are applied in a random order; any factor range given as
        None is skipped.
        """
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def get_component_coordinates(self, index, status):
        """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file.

        Args:
            index (int): Image index; looked up under the key f'{index:08d}'.
            status (sequence): Augmentation status from ``augment``; ``status[0]`` is
                treated as the horizontal-flip flag.

        Returns:
            list[Tensor]: One float tensor per component (left eye, right eye, mouth)
                holding ``[c0 - h + 1, c1 - h + 1, c0 + h, c1 + h]``, where ``(c0, c1)``
                are the first two stored values and ``h`` the (possibly enlarged) third.
        """
        import copy  # local import: only needed for the defensive copy below

        # Work on a copy: the flip branch rewrites entries, and mutating the cached
        # ``components_list`` would corrupt the bboxes returned for later reads of
        # this index (e.g. the next epoch).
        components_bbox = copy.deepcopy(self.components_list[f'{index:08d}'])
        if status[0]:  # hflip
            # exchange right and left eye
            components_bbox['left_eye'], components_bbox['right_eye'] = (components_bbox['right_eye'],
                                                                         components_bbox['left_eye'])
            # modify the width coordinate (index 0) to mirror the flip
            for part in ('left_eye', 'right_eye', 'mouth'):
                components_bbox[part][0] = self.out_size - components_bbox[part][0]

        # get coordinates
        locations = []
        for part in ['left_eye', 'right_eye', 'mouth']:
            mean = components_bbox[part][0:2]
            half_len = components_bbox[part][2]
            if 'eye' in part:
                half_len *= self.eye_enlarge_ratio
            loc = np.hstack((mean - half_len + 1, mean + half_len))
            loc = torch.from_numpy(loc).float()
            locations.append(loc)
        return locations

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
        h, w, _ = img_gt.shape

        # get facial component coordinates (mirrored when the image was flipped)
        if self.crop_components:
            locations = self.get_component_coordinates(index, status)
            loc_left_eye, loc_right_eye, loc_mouth = locations

        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = degradations.random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma, [-math.pi, math.pi],
            noise_range=None)
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
        # noise
        if self.noise_range is not None:
            img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)

        # resize to original size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
            img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
            if self.opt.get('gt_gray'):  # whether convert GT to gray images
                img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)

        # round and clip: quantize lq to 8-bit levels while staying in [0, 1]
        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        normalize(img_lq, self.mean, self.std, inplace=True)

        if self.crop_components:
            return_dict = {
                'lq': img_lq,
                'gt': img_gt,
                'gt_path': gt_path,
                'loc_left_eye': loc_left_eye,
                'loc_right_eye': loc_right_eye,
                'loc_mouth': loc_mouth
            }
            return return_dict
        else:
            return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
class RealESRGANDataset(data.Dataset):
    """Dataset used for Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It loads gt (Ground-Truth) images, and augments them.
    It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed in tensors on GPUS for faster processing.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            Please see more options in the codes.
    """

    def __init__(self, opt):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.gt_folder]
            self.io_backend_opt['client_keys'] = ['gt']
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip().split(' ')[0] for line in fin]
            self.paths = [os.path.join(self.gt_folder, v) for v in paths]

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

    def _read_gt_bytes(self, index, retry=3):
        """Read the raw bytes of a gt image, falling back to random other images on IO errors.

        Args:
            index (int): Index of the image to read first.
            retry (int): Total number of read attempts. Default: 3.

        Returns:
            tuple[str, bytes]: The path that was actually read and its bytes.

        Raises:
            OSError: If every attempt fails.
        """
        gt_path = self.paths[index]
        last_error = None
        for remaining in reversed(range(retry)):
            try:
                return gt_path, self.file_client.get(gt_path, 'gt')
            except OSError as e:  # IOError is an alias of OSError
                last_error = e
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {remaining}')
                if remaining > 0:
                    # change another file to read; randint's upper bound is inclusive
                    gt_path = self.paths[random.randint(0, len(self) - 1)]
                    time.sleep(1)  # sleep 1s for occasional server congestion
        raise OSError(f'Failed to read gt image after {retry} attempts (last path: {gt_path}).') from last_error

    @staticmethod
    def _pad_and_random_crop(img, size):
        """Reflect-pad an HWC image to at least size x size, then randomly crop a size x size patch."""
        h, w = img.shape[0:2]
        # pad
        if h < size or w < size:
            pad_h = max(0, size - h)
            pad_w = max(0, size - w)
            img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
        # crop
        if img.shape[0] > size or img.shape[1] > size:
            h, w = img.shape[0:2]
            # randomly choose top and left coordinates
            top = random.randint(0, h - size)
            left = random.randint(0, w - size)
            img = img[top:top + size, left:left + size, ...]
        return img

    def _random_blur_kernel(self, kernel_list, kernel_prob, blur_sigma, betag_range, betap_range, sinc_prob):
        """Sample one blur kernel (sinc or mixed kernel) and zero-pad it to 21x21."""
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                kernel_list,
                kernel_prob,
                kernel_size,
                blur_sigma,
                blur_sigma, [-math.pi, math.pi],
                betag_range,
                betap_range,
                noise_range=None)
        # pad kernel so every kernel has the same 21x21 shape
        pad_size = (21 - kernel_size) // 2
        return np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        # avoid errors caused by high latency in reading files
        gt_path, img_bytes = self._read_gt_bytes(index)
        img_gt = imfrombytes(img_bytes, float32=True)

        # -------------------- Do augmentation for training: flip, rotation -------------------- #
        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

        # crop or pad to 400
        # TODO: 400 is hard-coded. You may change it accordingly
        img_gt = self._pad_and_random_crop(img_gt, 400)

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel = self._random_blur_kernel(self.kernel_list, self.kernel_prob, self.blur_sigma, self.betag_range,
                                          self.betap_range, self.opt['sinc_prob'])

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel2 = self._random_blur_kernel(self.kernel_list2, self.kernel_prob2, self.blur_sigma2,
                                           self.betag_range2, self.betap_range2, self.opt['sinc_prob2'])

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]

        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
        return return_d

    def __len__(self):
        return len(self.paths)
class Vimeo90KRecurrentDataset(Vimeo90KDataset):
    """Vimeo90K dataset for recurrent networks: returns every frame of a clip as LQ and GT.

    Reuses keys, data roots and IO settings from ``Vimeo90KDataset``. Extra opt keys:
        flip_sequence (bool): Whether to append the time-reversed sequence
            (7 frames -> 14 frames).
    """

    def __init__(self, opt):
        super(Vimeo90KRecurrentDataset, self).__init__(opt)

        self.flip_sequence = opt['flip_sequence']
        # use all seven frames (im1 ... im7) as both LQ inputs and GT targets
        self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse: build a local (possibly reversed) list so that
        # ``self.neighbor_list`` is never mutated and the frame order of one
        # sample does not depend on earlier calls.
        neighbor_list = self.neighbor_list
        if self.random_reverse and random.random() < 0.5:
            neighbor_list = neighbor_list[::-1]
        num_frames = len(neighbor_list)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the neighboring LQ and GT frames
        img_lqs = []
        img_gts = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
                img_gt_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
                img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
            # LQ
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            # GT
            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)

            img_lqs.append(img_lq)
            img_gts.append(img_gt)

        # randomly crop
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate (LQ and GT frames go through one augment call
        # together, then are split back by frame count)
        img_lqs.extend(img_gts)
        img_results = augment(img_lqs, self.opt['use_flip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[:num_frames], dim=0)
        img_gts = torch.stack(img_results[num_frames:], dim=0)

        if self.flip_sequence:  # flip the sequence: 7 frames to 14 frames
            img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
            img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)

        # img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)
class PairedImgPSFNpyDataset(data.Dataset):
    """Paired image dataset with its corresponding PSF.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image
    pairs together with a PSF code for each pair. Paths are collected from every entry
    of ``opt['folders']`` via its meta information file (the only supported mode).

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            folders (dict): Mapping name -> folder config, each with
                dataroot_gt (str): Data root path for gt.
                dataroot_lq (str): Data root path for lq.
                meta_info_file (str): Path for meta information file (required).
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template
                excludes the file extension. Default: '{}'.
            lq_map_type (str) / gt_map_type (str): Tone mapping for lq / gt
                ('mu_law', 'simple' or 'same').
            crop_scale (float | None): Optional resize factor applied before cropping.
            gt_size (int): Cropped patched size for gt patches.
            use_flip (bool): Use horizontal and vertical flips.
            use_rot (bool): Use rotation (use transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # file client (io backend); created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        self.paths = []
        for folder_opt in opt['folders'].values():
            assert folder_opt['meta_info_file'] is not None, (
                'Only support loading image and PSF by meta info file.')
            gt_folder, lq_folder = folder_opt['dataroot_gt'], folder_opt['dataroot_lq']
            self.paths += paired_paths_PSF_from_meta_info_file([lq_folder, gt_folder], ['lq', 'gt'],
                                                               folder_opt['meta_info_file'], self.filename_tmpl)

    def _tonemap(self, x, type='simple'):
        """Tone-map an HDR array.

        Args:
            x (ndarray): Input image.
            type (str): 'mu_law' (normalize by the max, then log(1 + 1e4 x) / log(1 + 1e4)),
                'simple' (x / (x + 0.25)) or 'same' (identity).

        Raises:
            NotImplementedError: For an unknown ``type``.
        """
        if type == 'mu_law':
            peak = x.max()
            # an all-zero image would otherwise divide 0 by 0 and produce NaNs
            norm_x = x / peak if peak != 0 else x
            mapped_x = np.log(1 + 10000 * norm_x) / np.log(1 + 10000)
        elif type == 'simple':
            mapped_x = x / (x + 0.25)
        elif type == 'same':
            mapped_x = x
        else:
            raise NotImplementedError('tone mapping type [{:s}] is not recognized.'.format(type))
        return mapped_x

    def _expand_dim(self, x):
        """Add a trailing channel axis to 2-D (gray) images; other arrays are returned unchanged."""
        if x.ndim == 2:
            return x[:, :, None]
        else:
            return x

    def _rescale(self, img, factor):
        """Resize an HWC image by ``factor`` using its own height/width, keeping the HWC layout."""
        h, w = img.shape[:2]
        resized = cv2.resize(img, (int(w * factor), int(h * factor)), interpolation=cv2.INTER_LINEAR)
        # cv2.resize drops a trailing singleton channel axis; restore it
        return self._expand_dim(resized)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        lq_map_type = self.opt['lq_map_type']
        gt_map_type = self.opt['gt_map_type']
        crop_scale = self.opt.get('crop_scale', None)

        # Load gt and lq images. Dimension order: HWC; channel order: RGGB;
        # HDR image range: [0, +inf], float32.
        # NOTE(review): assumes the file client returns decoded arrays (e.g. .npy),
        # since the results are used directly as arrays below.
        gt_path = self.paths[index]['gt_path']
        lq_path = self.paths[index]['lq_path']
        psf_path = self.paths[index]['psf_path']
        img_gt = self.file_client.get(gt_path)
        img_lq = self.file_client.get(lq_path)
        psf_code = self.file_client.get(psf_path)

        # tone mapping
        img_lq = self._tonemap(img_lq, type=lq_map_type)
        img_gt = self._tonemap(img_gt, type=gt_map_type)

        # expand dimension
        img_gt = self._expand_dim(img_gt)
        img_lq = self._expand_dim(img_lq)

        # Rescale for random crop. Each image is resized from its own size so that an
        # lq/gt size ratio (scale != 1) is preserved for paired_random_crop.
        if crop_scale is not None:
            img_lq = self._rescale(img_lq, crop_scale)
            img_gt = self._rescale(img_gt, crop_scale)

        # augmentation
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])

        # HWC to CHW, numpy to tensor; channels are kept in stored order (no BGR->RGB swap)
        img_gt, img_lq = totensor([img_gt, img_lq], bgr2rgb=False, float32=True)
        # append two singleton trailing axes to the PSF code
        psf_code = torch.from_numpy(psf_code)[..., None, None]

        return {
            'lq': img_lq,
            'gt': img_gt,
            'psf_code': psf_code,
            'lq_path': lq_path,
            'gt_path': gt_path,
            'psf_path': psf_path,
        }

    def __len__(self):
        return len(self.paths)
class RealESRGANPairedDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:
    1. 'lmdb': Use lmdb files.
        If opt['io_backend'] == lmdb.
    2. 'meta_info': Use meta information file to generate paths.
        If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
        Each non-empty line is "<gt relative path>, <lq relative path>".
    3. 'folder': Scan folders to generate paths.
        The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patched size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(RealESRGANPairedDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        # mean and std for normalizing the input images
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif self.opt.get('meta_info') is not None:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative paths of a gt/lq pair
            with open(self.opt['meta_info']) as fin:
                lines = [line.strip() for line in fin]
            self.paths = []
            for line in lines:
                if not line:
                    # tolerate blank lines (e.g. extra trailing newlines); they would
                    # otherwise fail to unpack into a (gt, lq) pair
                    continue
                gt_path, lq_path = line.split(', ')
                self.paths.append({
                    'gt_path': os.path.join(self.gt_folder, gt_path),
                    'lq_path': os.path.join(self.lq_folder, lq_path),
                })
        else:
            # disk backend
            # it will scan the whole folder to get meta info
            # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
save_dict = {} for item_idx, item in enumerate(json_data.values()): print(f'\r{item_idx} / {len(json_data)}, {item["image"]["file_path"]} ', end='', flush=True) # parse landmarks lm = np.array(item['image']['face_landmarks']) lm = lm * scale item_dict = {} # get image if save_img: img_bytes = file_client.get(paths[item_idx]) img = imfrombytes(img_bytes, float32=True) # get landmarks for each component map_left_eye = list(range(36, 42)) map_right_eye = list(range(42, 48)) map_mouth = list(range(48, 68)) # eye_left mean_left_eye = np.mean(lm[map_left_eye], 0) # (x, y) half_len_left_eye = np.max( (np.max(np.max(lm[map_left_eye], 0) - np.min(lm[map_left_eye], 0)) / 2, 16)) item_dict['left_eye'] = [ mean_left_eye[0], mean_left_eye[1], half_len_left_eye ]
class Vimeo90KDataset(data.Dataset):
    """Vimeo90K dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt

    Each line contains:
    1. clip name; 2. frame number; 3. image shape, separated by a white space.
    Examples:
        00001/0001 7 (256,448,3)
        00001/0002 7 (256,448,3)

    Key examples: "00001/0001"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    The neighboring frame list for different num_frame:
    num_frame | frame list
             1 | 4
             3 | 3,4,5
             5 | 2,3,4,5,6
             7 | 1,2,3,4,5,6,7

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.

            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patched size for gt patches.
            random_reverse (bool): Random reverse input frames.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(Vimeo90KDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])

        with open(opt['meta_info_file'], 'r') as fin:
            self.keys = [line.split(' ')[0] for line in fin]

        # file client (io backend); created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # indices of input images, centred on frame 4
        self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # temporal augmentation configs
        self.random_reverse = opt['random_reverse']
        logger = get_root_logger()
        logger.info(f'Random reverse is {self.random_reverse}.')

    def _frame_order(self):
        """Return the LQ frame indices for one sample.

        With ``random_reverse`` the order is reversed with probability 0.5. A new list
        is returned in that case, so ``self.neighbor_list`` is never mutated and one
        sample's frame order does not depend on earlier calls.
        """
        if self.random_reverse and random.random() < 0.5:
            return self.neighbor_list[::-1]
        return self.neighbor_list

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse
        neighbor_list = self._frame_order()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the GT frame (im4.png)
        if self.is_lmdb:
            img_gt_path = f'{key}/im4'
        else:
            img_gt_path = self.gt_root / clip / seq / 'im4.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255.

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255.
            img_lqs.append(img_lq)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate (GT is appended last so it shares the LQ transform)
        img_lqs.append(img_gt)
        img_results = augment(img_lqs, self.opt['use_flip'], self.opt['use_rot'])

        img_results = totensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        # img_lqs: (t, c, h, w)
        # img_gt: (c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
class PairedImageDataset(data.Dataset):
    """Paired LQ/GT image dataset for image restoration.

    Each item pairs an LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc)
    image with its GT image.

    Path discovery, in priority order:
    1. 'lmdb': opt['io_backend']['type'] == 'lmdb' -> read both lmdb databases.
    2. 'meta_info_file': a non-None opt['meta_info_file'] lists the pairs.
    3. 'folder': otherwise both folders are scanned.

    Args:
        opt (dict): Dataset config with keys:
            dataroot_gt (str): Root directory of GT images.
            dataroot_lq (str): Root directory of LQ images.
            meta_info_file (str): Optional meta information file.
            io_backend (dict): IO backend type plus extra kwargs.
            filename_tmpl (str): Filename template without extension. Default: '{}'.
            gt_size (int): GT patch size used for random cropping.
            use_flip (bool): Enable horizontal flips.
            use_rot (bool): Enable rotation (vertical flip + h/w transpose).
            scale (bool): Scale factor, filled in automatically.
            phase (str): 'train' or 'val'; only 'train' crops and augments.
            mean / std (optional): Normalization statistics.
    """

    def __init__(self, opt):
        super(PairedImageDataset, self).__init__()
        self.opt = opt
        # the IO backend is built on the first __getitem__ call
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder = opt['dataroot_gt']
        self.lq_folder = opt['dataroot_lq']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        meta_info_file = self.opt.get('meta_info_file')
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
            return
        if meta_info_file is not None:
            self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
                                                          meta_info_file, self.filename_tmpl)
            return
        self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            backend_type = self.io_backend_opt.pop('type')
            self.file_client = FileClient(backend_type, **self.io_backend_opt)

        scale = self.opt['scale']
        record = self.paths[index]

        # decoded images are HWC, BGR, float32 in [0, 1]
        gt_path = record['gt_path']
        img_gt = imfrombytes(self.file_client.get(gt_path, 'gt'), float32=True)
        lq_path = record['lq_path']
        img_lq = imfrombytes(self.file_client.get(lq_path, 'lq'), float32=True)

        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # crop matching patches, then apply the same flip/rotation to both
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])

        # TODO: color space transform
        # BGR -> RGB, HWC -> CHW, ndarray -> tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        if self.mean is not None or self.std is not None:
            for tensor in (img_lq, img_gt):
                normalize(tensor, self.mean, self.std, inplace=True)

        return dict(lq=img_lq, gt=img_gt, lq_path=lq_path, gt_path=gt_path)

    def __len__(self):
        return len(self.paths)
class PairedImagePyramidDataset(data.Dataset):
    """Paired LQ/GT image dataset that also returns a GT pyramid in training.

    Paths are generated from an lmdb database (``io_backend.type == 'lmdb'``),
    from a meta information file (when ``meta_info_file`` is set), or by
    scanning the lq/gt folders.

    In the ``'train'`` phase, the pair is randomly cropped and augmented. The
    returned ``'gt'`` tensor is then the cropped GT concatenated along dim 0
    with blurred copies of itself. Each copy is downsampled by a factor in
    ``PYRAMID_FACTORS`` and bilinearly upsampled back to ``gt_size``. In any
    other phase, the LQ image is resized down to a multiple of 16 if needed,
    and the GT image is resized to ``scale`` times that size.

    Args:
        opt (dict): Config for the dataset. Keys read by this class:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwargs.
            meta_info_file (str, optional): Path for meta information file.
            filename_tmpl (str, optional): Filename template passed to the
                path helpers. Default: '{}'.
            mean (list | tuple, optional): Image mean for normalization.
            std (list | tuple, optional): Image std for normalization.
            scale (int): Scale factor between gt and lq.
            phase (str): 'train' enables cropping, augmentation and the gt
                pyramid; every other value takes the evaluation path.
            gt_size (int): Cropped gt patch size (train only).
            use_flip (bool): Whether to randomly flip (train only).
            use_rot (bool): Whether to randomly rotate (train only).
    """

    # Downsampling factors used to build the blurred GT pyramid levels.
    PYRAMID_FACTORS = (2, 4, 8, 16)

    def __init__(self, opt):
        super(PairedImagePyramidDataset, self).__init__()
        self.opt = opt
        # file client (io backend), created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        folders = [self.lq_folder, self.gt_folder]
        keys = ['lq', 'gt']
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = folders
            self.io_backend_opt['client_keys'] = keys
            self.paths = paired_paths_from_lmdb(folders, keys)
        elif self.opt.get('meta_info_file') is not None:
            self.paths = paired_paths_from_meta_info_file(
                folders, keys, self.opt['meta_info_file'], self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder(folders, keys,
                                                  self.filename_tmpl)

    def _to_tensors(self, imgs):
        """Convert HWC BGR numpy images to normalized CHW RGB tensors."""
        tensors = img2tensor(imgs, bgr2rgb=True, float32=True)
        if self.mean is not None or self.std is not None:
            for tensor in tensors:
                # torchvision's normalize takes (tensor, mean, std, inplace).
                normalize(tensor, self.mean, self.std, inplace=True)
        return tensors

    @staticmethod
    def _build_gt_pyramid(img_gt, gt_size, factors):
        """Return copies of the square gt patch, down- then up-sampled."""
        full_size = (gt_size, gt_size)
        levels = []
        for factor in factors:
            small = cv2.resize(img_gt, (gt_size // factor, gt_size // factor),
                               interpolation=cv2.INTER_LINEAR)
            levels.append(
                cv2.resize(small, full_size, interpolation=cv2.INTER_LINEAR))
        return levels

    def __getitem__(self, index):
        """Load, preprocess and return the pair at ``index``.

        Returns:
            dict: 'lq' and 'gt' tensors plus 'lq_path' and 'gt_path'. In the
            train phase, 'gt' holds the gt patch followed by its pyramid
            levels, concatenated along dim 0.
        """
        if self.file_client is None:
            # Pop from a copy so the shared config dict keeps its 'type' key.
            backend_opt = dict(self.io_backend_opt)
            self.file_client = FileClient(backend_opt.pop('type'),
                                          **backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_gt = imfrombytes(
            self.file_client.get(gt_path, 'gt'), float32=True)
        lq_path = self.paths[index]['lq_path']
        img_lq = imfrombytes(
            self.file_client.get(lq_path, 'lq'), float32=True)

        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size,
                                                scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'],
                                     self.opt['use_rot'])
            # After the crop img_gt is gt_size x gt_size.
            pyramid = self._build_gt_pyramid(img_gt, gt_size,
                                             self.PYRAMID_FACTORS)

            # TODO: color space transform
            img_gt, img_lq, *pyramid = self._to_tensors(
                [img_gt, img_lq] + pyramid)
            return {
                'lq': img_lq,
                'gt': torch.cat((img_gt, *pyramid), 0),
                'lq_path': lq_path,
                'gt_path': gt_path
            }

        # Evaluation path (val/test). Round LQ down to a multiple of 16,
        # presumably so the network's downsampling stages divide evenly.
        h, w = img_lq.shape[:2]
        if h % 16 != 0 or w % 16 != 0:
            h = h // 16 * 16
            w = w // 16 * 16
            # cv2.resize takes dsize as (width, height).
            img_lq = cv2.resize(img_lq, (w, h),
                                interpolation=cv2.INTER_LINEAR)
            img_gt = cv2.resize(img_gt, (scale * w, scale * h),
                                interpolation=cv2.INTER_LINEAR)

        # TODO: color space transform
        img_gt, img_lq = self._to_tensors([img_gt, img_lq])
        return {
            'lq': img_lq,
            'gt': img_gt,
            'lq_path': lq_path,
            'gt_path': gt_path
        }

    def __len__(self):
        return len(self.paths)