def test(img_path, img_path2, crop_border, test_y_channel=False):
    """Print PSNR/SSIM of one image pair computed through several code paths.

    The pair is evaluated with the NumPy metrics, then with the PyTorch
    metrics on CPU, on GPU, and on GPU with a batch made of two copies of the
    pair, so that the printed numbers can be compared.

    Args:
        img_path (str): Path of the first image (read with
            ``cv2.IMREAD_UNCHANGED``).
        img_path2 (str): Path of the second image (read the same way).
        crop_border (int): Forwarded as ``crop_border`` to every metric call.
        test_y_channel (bool): Forwarded as ``test_y_channel`` to every metric
            call. Default: False.
    """
    # Arguments shared by every metric call below.
    metric_kwargs = dict(crop_border=crop_border, test_y_channel=test_y_channel)

    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED)

    # --------------------- Numpy ---------------------
    psnr = calculate_psnr(img, img2, input_order='HWC', **metric_kwargs)
    ssim = calculate_ssim(img, img2, input_order='HWC', **metric_kwargs)
    print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')

    # --------------------- PyTorch (CPU) ---------------------
    # Scale to [0, 1], convert to tensors and add a leading batch dimension.
    img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
    img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
    psnr_pth = calculate_psnr_pt(img, img2, **metric_kwargs)
    ssim_pth = calculate_ssim_pt(img, img2, **metric_kwargs)
    print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')

    # --------------------- PyTorch (GPU) ---------------------
    img = img.cuda()
    img2 = img2.cuda()
    psnr_pth = calculate_psnr_pt(img, img2, **metric_kwargs)
    ssim_pth = calculate_ssim_pt(img, img2, **metric_kwargs)
    print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')

    # Batch of two: every sample is repeated twice along dim 0.
    img_batch = torch.repeat_interleave(img, 2, dim=0)
    img2_batch = torch.repeat_interleave(img2, 2, dim=0)
    psnr_pth = calculate_psnr_pt(img_batch, img2_batch, **metric_kwargs)
    ssim_pth = calculate_ssim_pt(img_batch, img2_batch, **metric_kwargs)
    print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,'
          f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}')
def __getitem__(self, index):
    """Load one GT/LQ image pair and return it as tensors.

    Args:
        index (int): Position of the pair in ``self.paths``.

    Returns:
        dict: ``lq`` and ``gt`` tensors together with ``lq_path`` and
            ``gt_path``.
    """
    # The file client is created on first use.
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    scale = self.opt['scale']

    # Both images are loaded as HWC, BGR, float32 in [0, 1].
    gt_path = self.paths[index]['gt_path']
    img_gt = imfrombytes(self.file_client.get(gt_path, 'gt'), float32=True)
    lq_path = self.paths[index]['lq_path']
    img_lq = imfrombytes(self.file_client.get(lq_path, 'lq'), float32=True)

    # Training-only augmentation: paired random crop, then flip / rotation.
    if self.opt['phase'] == 'train':
        gt_size = self.opt['gt_size']
        img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
        img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])

    # TODO: color space transform
    # BGR to RGB, HWC to CHW, numpy to tensor
    img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

    # Normalize in place unless neither mean nor std is configured.
    if not (self.mean is None and self.std is None):
        normalize(img_lq, self.mean, self.std, inplace=True)
        normalize(img_gt, self.mean, self.std, inplace=True)

    sample = {'lq': img_lq, 'gt': img_gt}
    sample['lq_path'] = lq_path
    sample['gt_path'] = gt_path
    return sample
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
    """Read a sequence of images and stack them into one tensor.

    Args:
        path (list[str] | str): List of image paths or image folder path.
            A folder is scanned and its entries are sorted.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.
        return_imgname (bool): Whether to also return the image names.
            Default: False.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
        list[str]: Image names (base name without extension); only returned
            when ``return_imgname`` is True.
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(scandir(path, full_path=True))

    # Read every frame and scale it to [0, 1] as float32.
    frames = []
    for img_path in img_paths:
        frames.append(cv2.imread(img_path).astype(np.float32) / 255.)

    if require_mod_crop:
        frames = [mod_crop(frame, scale) for frame in frames]

    stacked = torch.stack(img2tensor(frames, bgr2rgb=True, float32=True), dim=0)

    if not return_imgname:
        return stacked
    imgnames = [osp.splitext(osp.basename(img_path))[0] for img_path in img_paths]
    return stacked, imgnames
def __getitem__(self, index):
    """Load one GT image, with retries on read failure, and return it as a tensor.

    If reading fails, a warning is logged and another randomly chosen file is
    tried after a one-second pause. Up to three reads are attempted in total.

    Args:
        index (int): Position of the image in ``self.paths``.

    Returns:
        dict: ``gt`` (normalized tensor) and ``gt_path`` (path of the file
            that was actually read, which may differ from ``self.paths[index]``
            after a retry).

    Raises:
        Exception: The error of the last read attempt is re-raised when all
            attempts fail.
    """
    # The file client is created on first use.
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    # load gt image
    gt_path = self.paths[index]
    # avoid errors caused by high latency in reading files
    retry = 3
    while True:
        try:
            img_bytes = self.file_client.get(gt_path)
        except Exception as e:
            retry -= 1
            logger = get_root_logger()
            logger.warning(f'File client error: {e}, remaining retry times: {retry}')
            if retry == 0:
                # Out of attempts: surface the real I/O error instead of
                # failing later on an unbound ``img_bytes``.
                raise
            # change another file to read; randint is inclusive on both ends,
            # so the upper bound must be the last valid index.
            index = random.randint(0, self.__len__() - 1)
            gt_path = self.paths[index]
            time.sleep(1)  # sleep 1s for occasional server congestion
        else:
            break

    img_gt = imfrombytes(img_bytes, float32=True)

    # random horizontal flip
    img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
    # BGR to RGB, HWC to CHW, numpy to tensor
    img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
    # normalize
    normalize(img_gt, self.mean, self.std, inplace=True)
    return {'gt': img_gt, 'gt_path': gt_path}
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
    """Restore the faces in an image and optionally paste them back.

    Args:
        img: Input image. When ``has_aligned`` is True it is resized to
            512x512 and used directly as the single cropped face.
        has_aligned (bool): Whether the input is already an aligned face.
            Default: False.
        only_center_face (bool): Forwarded to the landmark detection.
            Default: False.
        paste_back (bool): Whether to paste the restored faces back into the
            input image (only when ``has_aligned`` is False). Default: True.

    Returns:
        tuple: ``(cropped_faces, restored_faces, restored_img)``;
            ``restored_img`` is None when nothing is pasted back.
    """
    helper = self.face_helper
    helper.clean_all()

    if not has_aligned:
        helper.read_image(img)
        # get face landmarks for each face
        # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
        # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
        helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
        # align and warp each face
        helper.align_warp_face()
    else:
        # the inputs are already aligned
        img = cv2.resize(img, (512, 512))
        helper.cropped_faces = [img]

    # face restoration
    for cropped_face in helper.cropped_faces:
        # prepare data: scale to [0, 1], convert to tensor, normalize in place
        face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        normalize(face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        face_t = face_t.unsqueeze(0).to(self.device)

        try:
            output = self.gfpgan(face_t, return_rgb=False)[0]
            # convert to image
            restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
        except RuntimeError as error:
            # Fall back to the unrestored crop when inference fails.
            print(f'\tFailed inference for GFPGAN: {error}.')
            restored_face = cropped_face

        helper.add_restored_face(restored_face.astype('uint8'))

    restored_img = None
    if not has_aligned and paste_back:
        # upsample the background
        bg_img = None
        if self.bg_upsampler is not None:
            # Now only support RealESRGAN for upsampling background
            bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]

        helper.get_inverse_affine(None)
        # paste each restored face to the input image
        restored_img = helper.paste_faces_to_input_image(upsample_img=bg_img)

    return helper.cropped_faces, helper.restored_faces, restored_img
def __getitem__(self, index):
    """Load the LQ and GT frames of one clip and return them as stacked tensors.

    Args:
        index (int): Position of the clip key in ``self.keys``.

    Returns:
        dict: ``lq`` and ``gt`` (frames stacked along a new leading
            dimension) and ``key`` (str).
    """
    # The file client is created on first use.
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    # random reverse
    # NOTE(review): this reverses the instance-level list in place, so the
    # order persists into later calls - confirm this is intended.
    if self.random_reverse and random.random() < 0.5:
        self.neighbor_list.reverse()

    scale = self.opt['scale']
    gt_size = self.opt['gt_size']
    key = self.keys[index]
    clip, seq = key.split('/')  # key example: 00001/0001

    # get the neighboring LQ and GT frames
    img_lqs, img_gts = [], []
    for neighbor in self.neighbor_list:
        if not self.is_lmdb:
            img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
            img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
        else:
            img_lq_path = f'{clip}/{seq}/im{neighbor}'
            img_gt_path = f'{clip}/{seq}/im{neighbor}'
        # LQ first, then GT
        lq_bytes = self.file_client.get(img_lq_path, 'lq')
        img_lq = imfrombytes(lq_bytes, float32=True)
        gt_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(gt_bytes, float32=True)

        img_lqs.append(img_lq)
        img_gts.append(img_gt)

    # randomly crop (the path of the last GT frame is passed along)
    img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

    # augmentation - flip, rotate; LQ and GT go through augment as one list
    img_lqs.extend(img_gts)
    img_results = img2tensor(augment(img_lqs, self.opt['use_flip'], self.opt['use_rot']))

    # NOTE(review): the split point 7 is hard-coded - presumably the number
    # of entries in self.neighbor_list; confirm.
    img_lqs = torch.stack(img_results[:7], dim=0)
    img_gts = torch.stack(img_results[7:], dim=0)

    if self.flip_sequence:  # flip the sequence: 7 frames to 14 frames
        img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
        img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)

    # img_lqs: (t, c, h, w)
    # img_gts: stacked the same way as img_lqs
    # key: str
    return {'lq': img_lqs, 'gt': img_gts, 'key': key}
def __getitem__(self, index):
    """Load one GT/LQ image pair, optionally as Y channel only, as tensors.

    Args:
        index (int): Position of the pair in ``self.paths``.

    Returns:
        dict: ``lq`` and ``gt`` tensors together with ``lq_path`` and
            ``gt_path``.
    """
    # The file client is created on first use.
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    scale = self.opt['scale']

    # Load gt and lq images. Dimension order: HWC; channel order: BGR;
    # image range: [0, 1], float32.
    gt_path = self.paths[index]['gt_path']
    img_gt = imfrombytes(self.file_client.get(gt_path, 'gt'), float32=True)
    lq_path = self.paths[index]['lq_path']
    img_lq = imfrombytes(self.file_client.get(lq_path, 'lq'), float32=True)

    # augmentation for training: paired random crop, then flip / rotation
    if self.opt['phase'] == 'train':
        gt_size = self.opt['gt_size']
        img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
        img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

    # color space transform: keep only Y and restore a trailing channel axis
    # NOTE(review): the images are described above as BGR, yet an RGB
    # conversion is applied - confirm the expected channel order.
    if self.opt.get('color') == 'y':
        img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None]
        img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]

    # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
    # TODO: It is better to update the datasets, rather than force to crop
    if self.opt['phase'] != 'train':
        lq_h, lq_w = img_lq.shape[0], img_lq.shape[1]
        img_gt = img_gt[0:lq_h * scale, 0:lq_w * scale, :]

    # BGR to RGB, HWC to CHW, numpy to tensor
    img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

    # Normalize in place unless neither mean nor std is configured.
    if not (self.mean is None and self.std is None):
        normalize(img_lq, self.mean, self.std, inplace=True)
        normalize(img_gt, self.mean, self.std, inplace=True)

    sample = {'lq': img_lq, 'gt': img_gt}
    sample['lq_path'] = lq_path
    sample['gt_path'] = gt_path
    return sample
def __getitem__(self, index):
    """Load one LQ image and return it as a tensor.

    Args:
        index (int): Position of the image in ``self.paths``.

    Returns:
        dict: ``lq`` (tensor) and ``lq_path``.
    """
    # The file client is created on first use.
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    # load lq image
    lq_path = self.paths[index]
    raw_bytes = self.file_client.get(lq_path, 'lq')

    # TODO: color space transform
    # Decode, then BGR to RGB, HWC to CHW, numpy to tensor
    img_lq = img2tensor(imfrombytes(raw_bytes, float32=True), bgr2rgb=True, float32=True)

    # Normalize in place unless neither mean nor std is configured.
    if not (self.mean is None and self.std is None):
        normalize(img_lq, self.mean, self.std, inplace=True)

    sample = {'lq': img_lq}
    sample['lq_path'] = lq_path
    return sample
def __getitem__(self, index):
    """Load one GT image, randomly flip it horizontally, and return a tensor.

    Args:
        index (int): Position of the image in ``self.paths``.

    Returns:
        dict: ``gt`` (normalized tensor) and ``gt_path``.
    """
    # The file client is created on first use.
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    # load gt image
    gt_path = self.paths[index]
    raw_bytes = self.file_client.get(gt_path)
    decoded = imfrombytes(raw_bytes, float32=True)

    # random horizontal flip (rotation is disabled)
    flipped = augment(decoded, hflip=self.opt['use_hflip'], rotation=False)

    # BGR to RGB, HWC to CHW, numpy to tensor
    img_gt = img2tensor(flipped, bgr2rgb=True, float32=True)

    # normalize (always applied, in place)
    normalize(img_gt, self.mean, self.std, inplace=True)

    sample = {'gt': img_gt}
    sample['gt_path'] = gt_path
    return sample
def main():
    """Print per-image and average LPIPS (VGG) between two image folders.

    Every file found in ``folder_gt`` is paired with the file of the same
    base name (plus ``suffix``) and extension in ``folder_restored``. Both
    images are scaled to [0, 1], converted to RGB tensors, normalized to
    [-1, 1] and compared with LPIPS on the GPU.
    """
    # Function-scope import: only needed for the no-grad context below.
    import torch

    # Configurations
    # -------------------------------------------------------------------------
    folder_gt = 'datasets/celeba/celeba_512_validation'
    folder_restored = 'datasets/celeba/celeba_512_validation_lq'
    # crop_border = 4
    suffix = ''
    # -------------------------------------------------------------------------
    loss_fn_vgg = lpips.LPIPS(net='vgg').cuda()  # RGB, normalized to [-1,1]
    lpips_all = []
    img_list = sorted(glob.glob(osp.join(folder_gt, '*')))
    if not img_list:
        # Nothing to evaluate; avoid dividing by zero in the average below.
        print(f'No images found in {folder_gt}.')
        return

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    for i, img_path in enumerate(img_list):
        basename, ext = osp.splitext(osp.basename(img_path))
        img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        img_restored = cv2.imread(
            osp.join(folder_restored, basename + suffix + ext),
            cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.

        img_gt, img_restored = img2tensor([img_gt, img_restored], bgr2rgb=True, float32=True)
        # norm to [-1, 1]
        normalize(img_gt, mean, std, inplace=True)
        normalize(img_restored, mean, std, inplace=True)

        # calculate lpips
        # No gradients are needed for evaluation. ``.item()`` turns the
        # one-element result tensor into a Python float, which the ``:.6f``
        # format specs below require (a multi-dimensional tensor rejects them).
        with torch.no_grad():
            lpips_val = loss_fn_vgg(
                img_restored.unsqueeze(0).cuda(), img_gt.unsqueeze(0).cuda()).item()

        print(f'{i+1:3d}: {basename:25}. \tLPIPS: {lpips_val:.6f}.')
        lpips_all.append(lpips_val)

    print(f'Average: LPIPS: {sum(lpips_all) / len(lpips_all):.6f}')
def __getitem__(self, index):
    """Load one GT/LQ pair; in training, also build a blurred GT pyramid.

    In the ``train`` phase the GT is randomly cropped and augmented together
    with the LQ image. Four extra GT versions are then made by downscaling by
    2, 4, 8 and 16 and upscaling back to ``gt_size``; these are concatenated
    after the GT along dim 0. In the ``val`` phase the LQ image is resized so
    that both sides are multiples of 16, and the GT is resized to twice that.

    Args:
        index (int): Position of the pair in ``self.paths``.

    Returns:
        dict | None: ``lq``, ``gt``, ``lq_path`` and ``gt_path``. For any
            phase other than ``train`` or ``val`` nothing is returned.
    """
    # The file client is created on first use.
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    scale = self.opt['scale']

    # Load gt and lq images. Dimension order: HWC; channel order: BGR;
    # image range: [0, 1], float32.
    gt_path = self.paths[index]['gt_path']
    img_bytes = self.file_client.get(gt_path, 'gt')
    img_gt = imfrombytes(img_bytes, float32=True)
    lq_path = self.paths[index]['lq_path']
    img_bytes = self.file_client.get(lq_path, 'lq')
    img_lq = imfrombytes(img_bytes, float32=True)

    # augmentation for training
    if self.opt['phase'] == 'train':
        gt_size = self.opt['gt_size']
        # random crop
        img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
        # flip, rotation
        img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])

        # to create pyramid for img_gt: downscale by each factor, then
        # upscale back to the full GT size
        pyramid = []
        for factor in (2, 4, 8, 16):
            small = cv2.resize(
                img_gt, (gt_size // factor, gt_size // factor), interpolation=cv2.INTER_LINEAR)
            pyramid.append(
                cv2.resize(small, (gt_size, gt_size), interpolation=cv2.INTER_LINEAR))

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        tensors = img2tensor([img_gt, img_lq] + pyramid, bgr2rgb=True, float32=True)
        img_gt, img_lq = tensors[0], tensors[1]
        pyramid = tensors[2:]

        # normalize
        # NOTE(review): ``align_corners`` is forwarded exactly as before;
        # confirm that the ``normalize`` in scope accepts this keyword.
        if self.mean is not None or self.std is not None:
            for tensor in tensors:
                normalize(tensor, self.mean, self.std, inplace=True, align_corners=True)

        return {
            'lq': img_lq,
            'gt': torch.cat((img_gt, *pyramid), 0),
            'lq_path': lq_path,
            'gt_path': gt_path
        }

    elif self.opt['phase'] == 'val':
        h, w, c = img_lq.shape
        if h % 16 != 0 or w % 16 != 0:
            # Round both sides down to a multiple of 16.
            h = h // 16 * 16
            w = w // 16 * 16
            # cv2.resize takes dsize as (width, height), so w comes first.
            img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
            # NOTE(review): the GT is resized with a fixed factor of 2, not
            # ``scale`` - confirm this is intended.
            img_gt = cv2.resize(img_gt, (2 * w, 2 * h), interpolation=cv2.INTER_LINEAR)

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True, align_corners=True)
            normalize(img_gt, self.mean, self.std, inplace=True, align_corners=True)

        return {
            'lq': img_lq,
            'gt': img_gt,
            'lq_path': lq_path,
            'gt_path': gt_path
        }
def __getitem__(self, index):
    """Load a window of LQ frames, its center GT frame and optional flows.

    The window is centered on the frame named in the key. If it would exceed
    the clip borders (frames 0 to 99), a new center frame is drawn at random
    until the window fits.

    Args:
        index (int): Position of the key in ``self.keys``.

    Returns:
        dict: ``lq`` (t, c, h, w), ``gt`` (c, h, w) and ``key`` (str), plus
            ``flow`` (t, 2, h, w) when ``self.flow_root`` is set.
    """
    # The file client is created on first use.
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    scale = self.opt['scale']
    gt_size = self.opt['gt_size']
    key = self.keys[index]
    clip_name, frame_name = key.split('/')  # key example: 000/00000000
    center_frame_idx = int(frame_name)
    use_flow = self.flow_root is not None

    # determine the neighboring frames
    interval = random.choice(self.interval_list)
    half_span = self.num_half_frames * interval

    # ensure not exceeding the borders
    # each clip has 100 frames starting from 0 to 99
    while center_frame_idx - half_span < 0 or center_frame_idx + half_span > 99:
        center_frame_idx = random.randint(0, 99)
    frame_name = f'{center_frame_idx:08d}'
    neighbor_list = list(
        range(center_frame_idx - half_span, center_frame_idx + half_span + 1, interval))
    # random reverse
    if self.random_reverse and random.random() < 0.5:
        neighbor_list.reverse()

    assert len(neighbor_list) == self.num_frame, (
        f'Wrong length of neighbor list: {len(neighbor_list)}')

    # get the GT frame (as the center frame)
    if not self.is_lmdb:
        img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
    else:
        img_gt_path = f'{clip_name}/{frame_name}'
    img_gt = imfrombytes(self.file_client.get(img_gt_path, 'gt'), float32=True)

    # get the neighboring LQ frames
    img_lqs = []
    for neighbor in neighbor_list:
        if not self.is_lmdb:
            img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
        else:
            img_lq_path = f'{clip_name}/{neighbor:08d}'
        img_lqs.append(imfrombytes(self.file_client.get(img_lq_path, 'lq'), float32=True))

    # get flows
    if use_flow:

        def read_flow(flow_path):
            """Read one stored flow image and dequantize it."""
            flow_bytes = self.file_client.get(flow_path, 'flow')
            cat_flow = imfrombytes(flow_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
            # The two flow components are stored stacked along axis 0.
            dx, dy = np.split(cat_flow, 2, axis=0)
            return dequantize_flow(dx, dy, max_val=20, denorm=False)  # we use max_val 20 here.

        img_flows = []
        # read previous flows
        for i in range(self.num_half_frames, 0, -1):
            if not self.is_lmdb:
                flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png')
            else:
                flow_path = f'{clip_name}/{frame_name}_p{i}'
            img_flows.append(read_flow(flow_path))
        # read next flows
        for i in range(1, self.num_half_frames + 1):
            if not self.is_lmdb:
                flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png')
            else:
                flow_path = f'{clip_name}/{frame_name}_n{i}'
            img_flows.append(read_flow(flow_path))

        # for random crop, here, img_flows and img_lqs have the same
        # spatial size
        img_lqs.extend(img_flows)

    # randomly crop
    img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
    if use_flow:
        # Separate the cropped flows from the cropped LQ frames again.
        img_flows = img_lqs[self.num_frame:]
        img_lqs = img_lqs[:self.num_frame]

    # augmentation - flip, rotate (GT is appended so it gets the same transform)
    img_lqs.append(img_gt)
    if use_flow:
        img_results, img_flows = augment(img_lqs, self.opt['use_flip'], self.opt['use_rot'], img_flows)
    else:
        img_results = augment(img_lqs, self.opt['use_flip'], self.opt['use_rot'])

    img_results = img2tensor(img_results)
    img_lqs = torch.stack(img_results[0:-1], dim=0)
    img_gt = img_results[-1]

    # img_lqs: (t, c, h, w)
    # img_flows: (t, 2, h, w)
    # img_gt: (c, h, w)
    # key: str
    if not use_flow:
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    img_flows = img2tensor(img_flows)
    # add the zero center flow
    img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
    img_flows = torch.stack(img_flows, dim=0)
    return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
def __getitem__(self, index):
    """Load matching windows of GT and LQ frames and return them stacked.

    The window is centered on the frame named in the key. If it would exceed
    the clip borders (frames 0 to 99), a new center frame is drawn at random
    until the window fits. Cropping and augmentation are only applied when
    ``self.is_train`` is set.

    Args:
        index (int): Position of the key in ``self.keys``.

    Returns:
        dict: ``lq`` (t, c, h, w), ``gt`` (t, c, h, w), ``key`` (str) and
            ``frame_list`` (the frame indices that were loaded).
    """
    # The file client is created on first use.
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    scale = self.opt['scale']
    gt_size = self.opt.get('gt_size', None)
    key = self.keys[index]
    clip_name, frame_name = key.split('/')  # key example: 000/00000000
    center_frame_idx = int(frame_name)

    # determine the frames of the window
    interval = random.choice(self.interval_list)
    half_span = self.num_half_frames * interval
    window_span = (self.num_frame - 1) * interval

    # ensure not exceeding the borders
    start_frame_idx = center_frame_idx - half_span
    end_frame_idx = start_frame_idx + window_span
    # each clip has 100 frames starting from 0 to 99
    while start_frame_idx < 0 or end_frame_idx > 99:
        center_frame_idx = random.randint(half_span, 99 - half_span)
        start_frame_idx = center_frame_idx - half_span
        end_frame_idx = start_frame_idx + window_span
    # Kept in step with the (possibly re-drawn) center; not read again below.
    frame_name = f'{center_frame_idx:08d}'
    frame_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
    # random reverse
    if self.random_reverse and random.random() < 0.5:
        frame_list.reverse()

    assert len(frame_list) == self.num_frame, (
        f'Wrong length of frame list: {len(frame_list)}')

    def load_frames(root, client_key):
        """Read every frame of ``frame_list`` from ``root``, in order."""
        frames = []
        for frame in frame_list:
            if not self.is_lmdb:
                frame_path = root / clip_name / f'{frame:08d}.png'
            else:
                frame_path = f'{clip_name}/{frame:08d}'
            frames.append(imfrombytes(self.file_client.get(frame_path, client_key), float32=True))
        return frames

    # get the GT frames first, then the LQ frames
    img_gts = load_frames(self.gt_root, 'gt')
    img_lqs = load_frames(self.lq_root, 'lq')

    # randomly crop
    if self.is_train:
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, clip_name)

    # augmentation - flip, rotate; LQ and GT are handled as one list
    img_lqs.extend(img_gts)
    if self.is_train:
        img_lqs = augment(img_lqs, self.opt['use_flip'], self.opt['use_rot'])

    img_results = img2tensor(img_lqs)
    # The first num_frame entries are LQ, the rest are GT.
    img_lqs = torch.stack(img_results[:self.num_frame], dim=0)
    img_gts = torch.stack(img_results[self.num_frame:], dim=0)

    # img_lqs: (t, c, h, w)
    # img_gts: (t, c, h, w)
    # key: str
    sample = {'lq': img_lqs, 'gt': img_gts}
    sample['key'] = key
    sample['frame_list'] = frame_list
    return sample
def __getitem__(self, index):
    """Load a run of consecutive LQ/GT frame pairs and return them stacked.

    The run starts at the frame named in the key. If it would not fit inside
    the clip, a new start frame is drawn at random.

    Args:
        index (int): Position of the key in ``self.keys``.

    Returns:
        dict: ``lq`` (t, c, h, w), ``gt`` (t, c, h, w) and ``key`` (str).
    """
    # The file client is created on first use.
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    scale = self.opt['scale']
    gt_size = self.opt['gt_size']
    key = self.keys[index]
    clip_name, frame_name = key.split('/')  # key example: 000/00000000

    # determine the neighboring frames
    interval = random.choice(self.interval_list)
    run_length = self.num_frame * interval

    # ensure not exceeding the borders
    start_frame_idx = int(frame_name)
    if start_frame_idx > 100 - run_length:
        start_frame_idx = random.randint(0, 100 - run_length)
    neighbor_list = list(range(start_frame_idx, start_frame_idx + run_length, interval))

    # random reverse
    if self.random_reverse and random.random() < 0.5:
        neighbor_list.reverse()

    # get the neighboring LQ and GT frames
    img_lqs, img_gts = [], []
    for neighbor in neighbor_list:
        if not self.is_lmdb:
            img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
            img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'
        else:
            img_lq_path = f'{clip_name}/{neighbor:08d}'
            img_gt_path = f'{clip_name}/{neighbor:08d}'

        # get LQ
        img_lqs.append(imfrombytes(self.file_client.get(img_lq_path, 'lq'), float32=True))
        # get GT
        img_gts.append(imfrombytes(self.file_client.get(img_gt_path, 'gt'), float32=True))

    # randomly crop (the path of the last GT frame is passed along)
    img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

    # augmentation - flip, rotate; LQ and GT are handled as one list
    img_lqs.extend(img_gts)
    img_results = img2tensor(augment(img_lqs, self.opt['use_flip'], self.opt['use_rot']))

    # After the extend above, the first half of the list is LQ and the
    # second half is GT.
    half = len(img_lqs) // 2
    img_gts = torch.stack(img_results[half:], dim=0)
    img_lqs = torch.stack(img_results[:half], dim=0)

    # img_lqs: (t, c, h, w)
    # img_gts: (t, c, h, w)
    # key: str
    return {'lq': img_lqs, 'gt': img_gts, 'key': key}
def __getitem__(self, index):
    """Load one GT image and sample the blur kernels for its degradation.

    The GT image is read (with retries), augmented, and padded and/or
    randomly cropped to 400x400. Two blur kernels (padded to 21x21) and a
    final sinc kernel are then sampled for use by the caller.

    Args:
        index (int): Position of the image in ``self.paths``.

    Returns:
        dict: ``gt`` (tensor), ``kernel1``, ``kernel2`` and ``sinc_kernel``
            (tensors) and ``gt_path`` (path of the file that was actually
            read, which may differ from ``self.paths[index]`` after a retry).

    Raises:
        OSError: The error of the last read attempt is re-raised when all
            three attempts fail.
    """
    # The file client is created on first use.
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    # -------------------------------- Load gt images -------------------------------- #
    # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
    gt_path = self.paths[index]
    # avoid errors caused by high latency in reading files
    retry = 3
    while True:
        try:
            img_bytes = self.file_client.get(gt_path, 'gt')
        except (IOError, OSError) as e:
            retry -= 1
            logger = get_root_logger()
            logger.warning(f'File client error: {e}, remaining retry times: {retry}')
            if retry == 0:
                # Out of attempts: surface the real I/O error instead of
                # failing later on an unbound ``img_bytes``.
                raise
            # change another file to read; randint is inclusive on both ends,
            # so the upper bound must be the last valid index.
            index = random.randint(0, self.__len__() - 1)
            gt_path = self.paths[index]
            time.sleep(1)  # sleep 1s for occasional server congestion
        else:
            break
    img_gt = imfrombytes(img_bytes, float32=True)

    # -------------------- Do augmentation for training: flip, rotation -------------------- #
    img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

    # crop or pad to 400
    # TODO: 400 is hard-coded. You may change it accordingly
    h, w = img_gt.shape[0:2]
    crop_pad_size = 400
    # pad (bottom and right only, with reflection)
    if h < crop_pad_size or w < crop_pad_size:
        pad_h = max(0, crop_pad_size - h)
        pad_w = max(0, crop_pad_size - w)
        img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
    # crop
    if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
        h, w = img_gt.shape[0:2]
        # randomly choose top and left coordinates
        top = random.randint(0, h - crop_pad_size)
        left = random.randint(0, w - crop_pad_size)
        img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]

    # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
    kernel_size = random.choice(self.kernel_range)
    if np.random.uniform() < self.opt['sinc_prob']:
        # this sinc filter setting is for kernels ranging from [7, 21]
        if kernel_size < 13:
            omega_c = np.random.uniform(np.pi / 3, np.pi)
        else:
            omega_c = np.random.uniform(np.pi / 5, np.pi)
        kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
    else:
        kernel = random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            kernel_size,
            self.blur_sigma,
            self.blur_sigma, [-math.pi, math.pi],
            self.betag_range,
            self.betap_range,
            noise_range=None)
    # pad kernel to 21x21 so that every sample has the same kernel size
    pad_size = (21 - kernel_size) // 2
    kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

    # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
    kernel_size = random.choice(self.kernel_range)
    if np.random.uniform() < self.opt['sinc_prob2']:
        if kernel_size < 13:
            omega_c = np.random.uniform(np.pi / 3, np.pi)
        else:
            omega_c = np.random.uniform(np.pi / 5, np.pi)
        kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
    else:
        kernel2 = random_mixed_kernels(
            self.kernel_list2,
            self.kernel_prob2,
            kernel_size,
            self.blur_sigma2,
            self.blur_sigma2, [-math.pi, math.pi],
            self.betag_range2,
            self.betap_range2,
            noise_range=None)

    # pad kernel
    pad_size = (21 - kernel_size) // 2
    kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

    # ------------------------------------- the final sinc kernel ------------------------------------- #
    if np.random.uniform() < self.opt['final_sinc_prob']:
        kernel_size = random.choice(self.kernel_range)
        omega_c = np.random.uniform(np.pi / 3, np.pi)
        sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
        sinc_kernel = torch.FloatTensor(sinc_kernel)
    else:
        # No final sinc filter: the pre-built pulse tensor is used instead.
        sinc_kernel = self.pulse_tensor

    # BGR to RGB, HWC to CHW, numpy to tensor
    img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
    kernel = torch.FloatTensor(kernel)
    kernel2 = torch.FloatTensor(kernel2)

    return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
    return return_d
def __getitem__(self, index):
    """Load one GT face image and synthesize a degraded LQ version of it.

    The LQ image is produced from the GT by blur, random downsampling,
    optional noise, optional JPEG compression, resizing back to the GT size,
    and optional color jitter / grayscale conversion.

    Args:
        index (int): Position of the image in ``self.paths``.

    Returns:
        dict: ``lq`` and ``gt`` (normalized tensors) and ``gt_path``; when
            ``self.crop_components`` is set, also ``loc_left_eye``,
            ``loc_right_eye`` and ``loc_mouth``.
    """
    # The file client is created on first use.
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    # load gt image
    # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
    gt_path = self.paths[index]
    img_gt = imfrombytes(self.file_client.get(gt_path), float32=True)

    # random horizontal flip; the returned status is needed for the
    # component coordinates below
    img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
    h, w, _ = img_gt.shape

    # get facial component coordinates
    if self.crop_components:
        loc_left_eye, loc_right_eye, loc_mouth = self.get_component_coordinates(index, status)

    # ------------------------ generate lq image ------------------------ #
    # blur
    blur_kernel = degradations.random_mixed_kernels(
        self.kernel_list,
        self.kernel_prob,
        self.blur_kernel_size,
        self.blur_sigma,
        self.blur_sigma, [-math.pi, math.pi],
        noise_range=None)
    img_lq = cv2.filter2D(img_gt, -1, blur_kernel)

    # downsample by a random factor
    down_factor = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
    small_size = (int(w // down_factor), int(h // down_factor))
    img_lq = cv2.resize(img_lq, small_size, interpolation=cv2.INTER_LINEAR)

    # noise
    if self.noise_range is not None:
        img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
    # jpeg compression
    if self.jpeg_range is not None:
        img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)

    # resize to original size
    img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

    # random color jitter (only for lq)
    if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
        img_lq = self.color_jitter(img_lq, self.color_jitter_shift)

    # random to gray (only for lq)
    if self.gray_prob and np.random.uniform() < self.gray_prob:
        gray_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
        img_lq = np.tile(gray_lq[:, :, None], [1, 1, 3])
        if self.opt.get('gt_gray'):  # whether convert GT to gray images
            gray_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
            img_gt = np.tile(gray_gt[:, :, None], [1, 1, 3])  # repeat the color channels

    # BGR to RGB, HWC to CHW, numpy to tensor
    img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

    # random color jitter (pytorch version) (only for lq)
    if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
        brightness = self.opt.get('brightness', (0.5, 1.5))
        contrast = self.opt.get('contrast', (0.5, 1.5))
        saturation = self.opt.get('saturation', (0, 1.5))
        hue = self.opt.get('hue', (-0.1, 0.1))
        img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)

    # round and clip: snap the LQ values to 1/255 steps within [0, 1]
    img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

    # normalize
    normalize(img_gt, self.mean, self.std, inplace=True)
    normalize(img_lq, self.mean, self.std, inplace=True)

    sample = {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
    if self.crop_components:
        sample['loc_left_eye'] = loc_left_eye
        sample['loc_right_eye'] = loc_right_eye
        sample['loc_mouth'] = loc_mouth
    return sample
x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0) y, cb, cr = self.compress(x, factor=factor) recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor) recovered = recovered[:, :, 0:h, 0:w] return recovered if __name__ == '__main__': import cv2 from basicsr.utils import img2tensor, tensor2img img_gt = cv2.imread('test.png') / 255. # -------------- cv2 -------------- # encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20] _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param) img_lq = np.float32(cv2.imdecode(encimg, 1)) cv2.imwrite('cv2_JPEG_20.png', img_lq) # -------------- DiffJPEG -------------- # jpeger = DiffJPEG(differentiable=False).cuda() img_gt = img2tensor(img_gt) img_gt = torch.stack([img_gt, img_gt]).cuda() quality = img_gt.new_tensor([20, 40]) out = jpeger(img_gt, quality=quality) cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0])) cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))