def main(args):
    """Calculate PSNR and SSIM for images."""
    psnr_all = []
    ssim_all = []
    img_list_gt = sorted(list(scandir(args.gt, recursive=True, full_path=True)))
    img_list_restored = sorted(list(scandir(args.restored, recursive=True, full_path=True)))

    if args.test_y_channel:
        print('Testing Y channel.')
    else:
        print('Testing RGB channels.')

    for i, img_path in enumerate(img_list_gt):
        basename, ext = osp.splitext(osp.basename(img_path))
        img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        if args.suffix == '':
            img_path_restored = img_list_restored[i]
        else:
            img_path_restored = osp.join(args.restored, basename + args.suffix + ext)
        img_restored = cv2.imread(img_path_restored, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.

        if args.correct_mean_var:
            mean_l = []
            std_l = []
            for j in range(3):
                mean_l.append(np.mean(img_gt[:, :, j]))
                std_l.append(np.std(img_gt[:, :, j]))
            for j in range(3):
                # correct twice: rescaling the std in the first pass shifts the
                # mean again, so a second pass matches both statistics exactly
                mean = np.mean(img_restored[:, :, j])
                img_restored[:, :, j] = img_restored[:, :, j] - mean + mean_l[j]
                std = np.std(img_restored[:, :, j])
                img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j]

                mean = np.mean(img_restored[:, :, j])
                img_restored[:, :, j] = img_restored[:, :, j] - mean + mean_l[j]
                std = np.std(img_restored[:, :, j])
                img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j]

        if args.test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3:
            img_gt = bgr2ycbcr(img_gt, y_only=True)
            img_restored = bgr2ycbcr(img_restored, y_only=True)

        # calculate PSNR and SSIM
        psnr = calculate_psnr(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
        ssim = calculate_ssim(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
        print(f'{i+1:3d}: {basename:25}. \tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
        psnr_all.append(psnr)
        ssim_all.append(ssim)

    print(args.gt)
    print(args.restored)
    print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f} dB, SSIM: {sum(ssim_all) / len(ssim_all):.6f}')
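# A minimal argparse wiring sketch for the main() above; the flag names simply
# mirror the attributes accessed on `args` (gt, restored, suffix, crop_border,
# correct_mean_var, test_y_channel), and the defaults are assumptions.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gt', type=str, required=True, help='Path to gt (Ground-Truth) folder')
    parser.add_argument('--restored', type=str, required=True, help='Path to restored images')
    parser.add_argument('--crop_border', type=int, default=0, help='Crop border for each side')
    parser.add_argument('--suffix', type=str, default='', help='Suffix for restored images')
    parser.add_argument('--test_y_channel', action='store_true', help='Test on Y channel of YCbCr instead of RGB')
    parser.add_argument('--correct_mean_var', action='store_true', help='Match the mean/std of restored images to gt')
    args = parser.parse_args()
    main(args)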
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.
        return_imgname (bool): Whether to return image names. Default: False.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
        list[str]: Returned image name list.
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(list(scandir(path, full_path=True)))
    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]

    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)

    if return_imgname:
        imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
        return imgs, imgnames
    else:
        return imgs
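# Usage sketch for read_img_seq; the clip folder path is hypothetical. The
# frames come back as a single (t, c, h, w) RGB tensor in [0, 1], with each
# frame mod-cropped so height and width are divisible by the scale factor.
imgs, imgnames = read_img_seq('datasets/Vid4/GT/calendar', require_mod_crop=True, scale=4, return_imgname=True)
print(imgs.shape, imgnames[:3])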
def load_resume_state(opt):
    resume_state_path = None
    if opt['auto_resume']:
        state_path = osp.join('experiments', opt['name'], 'training_states')
        if osp.isdir(state_path):
            states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
            if len(states) != 0:
                # state files are named by iteration; resume from the latest
                states = [float(v.split('.state')[0]) for v in states]
                resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
                opt['path']['resume_state'] = resume_state_path
    else:
        if opt['path'].get('resume_state'):
            resume_state_path = opt['path']['resume_state']

    if resume_state_path is None:
        resume_state = None
    else:
        device_id = torch.cuda.current_device()
        resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
        check_resume(opt, resume_state['iter'])
    return resume_state
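# Behaviour sketch for load_resume_state (paths hypothetical): with
# auto_resume enabled and experiments/<name>/training_states containing
# '5000.state' and '10000.state', the scan above resolves resume_state_path to
# the latest checkpoint, '10000.state', and loads it onto the current GPU.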
def extract_subimages(opt):
    """Crop images to subimages.

    Args:
        opt (dict): Configuration dict. It contains:
            input_folder (str): Path to the input folder.
            save_folder (str): Path to save folder.
            n_thread (int): Thread number.
    """
    input_folder = opt['input_folder']
    save_folder = opt['save_folder']
    if not osp.exists(save_folder):
        os.makedirs(save_folder)
        print(f'mkdir {save_folder} ...')
    else:
        print(f'Folder {save_folder} already exists. Exit.')
        sys.exit(1)

    img_list = list(scandir(input_folder, full_path=True))

    pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
    pool = Pool(opt['n_thread'])
    for path in img_list:
        pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
    pool.close()
    pool.join()
    pbar.close()
    print('All processes done.')
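# Usage sketch for extract_subimages; the paths are hypothetical, and the
# cropping-related keys (crop_size, step, thresh_size, compression_level) are
# assumptions about what worker() consumes in the full script.
opt = {
    'input_folder': 'datasets/DIV2K/DIV2K_train_HR',
    'save_folder': 'datasets/DIV2K/DIV2K_train_HR_sub',
    'n_thread': 20,
    'crop_size': 480,
    'step': 240,
    'thresh_size': 0,
    'compression_level': 3,
}
extract_subimages(opt)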
def main():
    """Calculate PSNR and SSIM for images.

    Configurations:
        folder_gt (str): Path to gt (Ground-Truth).
        folder_restored (str): Path to restored images.
        crop_border (int): Crop border for each side.
        suffix (str): Suffix for restored images.
        test_y_channel (bool): If True, test Y channel (in MATLAB YCbCr
            format). If False, test RGB channels.
    """
    # Configurations
    # -------------------------------------------------------------------------
    folder_gt = 'datasets/val_set14/Set14'
    folder_restored = 'results/exp/visualization/val_set14'
    crop_border = 4
    suffix = '_expname'
    test_y_channel = False
    # -------------------------------------------------------------------------

    psnr_all = []
    ssim_all = []
    img_list = sorted(scandir(folder_gt, recursive=True, full_path=True))

    if test_y_channel:
        print('Testing Y channel.')
    else:
        print('Testing RGB channels.')

    for i, img_path in enumerate(img_list):
        basename, ext = osp.splitext(osp.basename(img_path))
        img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        img_restored = cv2.imread(osp.join(folder_restored, basename + suffix + ext),
                                  cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.

        if test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3:
            img_gt = bgr2ycbcr(img_gt, y_only=True)
            img_restored = bgr2ycbcr(img_restored, y_only=True)

        # calculate PSNR and SSIM
        psnr = calculate_psnr(img_gt * 255, img_restored * 255, crop_border=crop_border, input_order='HWC')
        ssim = calculate_ssim(img_gt * 255, img_restored * 255, crop_border=crop_border, input_order='HWC')
        print(f'{i+1:3d}: {basename:25}. \tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
        psnr_all.append(psnr)
        ssim_all.append(ssim)

    print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f} dB, '
          f'SSIM: {sum(ssim_all) / len(ssim_all):.6f}')
def __init__(self, opt):
    super(VidTestDataset, self).__init__()
    self.opt = opt
    self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
    # file client (io backend)
    self.file_client = None
    self.io_backend_opt = opt['io_backend']
    assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

    logger = get_root_logger()
    logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')

    self.data_info = {
        'lq_path': [],
        'gt_path': [],
        'clip_name': [],
        'max_idx': [],
    }
    self.lq_frames, self.gt_frames = {}, {}

    self.clip_list = os.listdir(osp.abspath(self.gt_root))
    self.clip_list.sort()
    for clip_name in self.clip_list:
        lq_frames_path = osp.join(self.lq_root, clip_name)
        lq_frames_path = sorted(list(scandir(lq_frames_path, full_path=True)))
        gt_frames_path = osp.join(self.gt_root, clip_name)
        gt_frames_path = sorted(list(scandir(gt_frames_path, full_path=True)))

        max_idx = len(lq_frames_path)
        assert max_idx == len(gt_frames_path), (f'Different number of images in lq ({max_idx})'
                                                f' and gt folders ({len(gt_frames_path)})')

        self.data_info['lq_path'].extend(lq_frames_path)
        self.data_info['gt_path'].extend(gt_frames_path)
        self.data_info['clip_name'].append(clip_name)
        self.data_info['max_idx'].append(max_idx)

        self.lq_frames[clip_name] = lq_frames_path
        self.gt_frames[clip_name] = gt_frames_path
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: Returned path list, where each element maps
            '{input_key}_path' and '{gt_key}_path' to the paired paths.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    paths = []
    for idx in range(len(gt_paths)):
        gt_path = gt_paths[idx]
        basename, ext = osp.splitext(osp.basename(gt_path))
        input_path = input_paths[idx]
        basename_input, ext_input = osp.splitext(osp.basename(input_path))
        input_name = f'{filename_tmpl.format(basename)}{ext_input}'
        input_path = osp.join(input_folder, input_name)
        assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
        gt_path = osp.join(gt_folder, gt_path)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths
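# Usage sketch for paired_paths_from_folder; the folders and filenames are
# hypothetical. With filename_tmpl='{}_x4', a gt file 'baby.png' is paired
# with the input file 'baby_x4.png' in the lq folder.
paths = paired_paths_from_folder(['datasets/Set5/LRbicx4', 'datasets/Set5/GTmod12'], ['lq', 'gt'], '{}_x4')
# e.g. [{'lq_path': 'datasets/Set5/LRbicx4/baby_x4.png', 'gt_path': 'datasets/Set5/GTmod12/baby.png'}, ...]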
def paths_from_folder(folder):
    """Generate paths from folder.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """
    paths = list(scandir(folder))
    paths = [osp.join(folder, path) for path in paths]
    return paths
def prepare_keys_reds(folder_path):
    """Prepare image path list and keys for REDS dataset.

    Args:
        folder_path (str): Folder path.

    Returns:
        list[str]: Image path list.
        list[str]: Key list.
    """
    print('Reading image path list ...')
    img_path_list = sorted(list(scandir(folder_path, suffix='png', recursive=True)))
    keys = [v.split('.png')[0] for v in img_path_list]  # example: 000/00000000
    return img_path_list, keys
def prepare_keys_div2k(folder_path):
    """Prepare image path list and keys for DIV2K dataset.

    Args:
        folder_path (str): Folder path.

    Returns:
        list[str]: Image path list.
        list[str]: Key list.
    """
    print('Reading image path list ...')
    img_path_list = sorted(list(scandir(folder_path, suffix='png', recursive=False)))
    # img_path_list is already sorted, so the keys follow the same order
    keys = [img_path.split('.png')[0] for img_path in img_path_list]
    return img_path_list, keys
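# Key format sketch: the recursive REDS scan keeps the clip subfolder in the
# key ('000/00000000'), while the flat DIV2K scan yields just the filename
# stem (e.g. '0001_s001.png' -> '0001_s001'); the example names are
# illustrative, not taken from an actual listing.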
def __init__(self, opt):
    super(SingleImageDataset, self).__init__()
    self.opt = opt
    # file client (io backend)
    self.file_client = None
    self.io_backend_opt = opt['io_backend']
    self.mean = opt['mean'] if 'mean' in opt else None
    self.std = opt['std'] if 'std' in opt else None
    self.lq_folder = opt['dataroot_lq']

    if 'meta_info_file' in self.opt:
        with open(self.opt['meta_info_file'], 'r') as fin:
            self.paths = [osp.join(self.lq_folder, line.split(' ')[0]) for line in fin]
    else:
        self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
def main(args):
    niqe_all = []
    img_list = sorted(scandir(args.input, recursive=True, full_path=True))

    for i, img_path in enumerate(img_list):
        basename, _ = os.path.splitext(os.path.basename(img_path))
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)

        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=RuntimeWarning)
            niqe_score = calculate_niqe(img, args.crop_border, input_order='HWC', convert_to='y')
        print(f'{i+1:3d}: {basename:25}. \tNIQE: {niqe_score:.6f}')
        niqe_all.append(niqe_score)

    print(args.input)
    print(f'Average: NIQE: {sum(niqe_all) / len(niqe_all):.6f}')
def generate_meta_info_div2k():
    """Generate meta info for DIV2K dataset."""
    gt_folder = 'datasets/DIV2K/DIV2K_train_HR_sub/'
    meta_info_txt = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt'

    img_list = sorted(list(scandir(gt_folder)))

    with open(meta_info_txt, 'w') as f:
        for idx, img_path in enumerate(img_list):
            img = Image.open(osp.join(gt_folder, img_path))  # lazy load
            width, height = img.size
            mode = img.mode
            if mode == 'RGB':
                n_channel = 3
            elif mode == 'L':
                n_channel = 1
            else:
                raise ValueError(f'Unsupported mode {mode}.')

            info = f'{img_path} ({height},{width},{n_channel})'
            print(idx + 1, info)
            f.write(f'{info}\n')
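# Each line written above has the form '<path> (<height>,<width>,<n_channel>)',
# e.g. '0001_s001.png (480,480,3)' for a 480x480 RGB sub-image; the example
# filename and size are illustrative.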
import importlib
from copy import deepcopy
from os import path as osp

from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import MODEL_REGISTRY

__all__ = ['build_model']

# automatically scan and import model modules for registry
# scan all the files under the 'models' folder and collect files ending with
# '_model.py'
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]


def build_model(opt):
    """Build model from options.

    Args:
        opt (dict): Configuration. It must contain:
            model_type (str): Model type.
    """
    opt = deepcopy(opt)
    model = MODEL_REGISTRY.get(opt['model_type'])(opt)
    logger = get_root_logger()
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model
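# A sketch of how the automatic scan above pays off: any class decorated with
# MODEL_REGISTRY.register() in a '*_model.py' file becomes reachable by name
# through build_model. ToyModel is a hypothetical stub, not an actual model.
from basicsr.utils.registry import MODEL_REGISTRY


@MODEL_REGISTRY.register()
class ToyModel:

    def __init__(self, opt):
        self.opt = opt


model = build_model({'model_type': 'ToyModel'})  # logs: Model [ToyModel] is created.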
def __init__(self, opt):
    super(VideoTestDataset, self).__init__()
    self.opt = opt
    self.cache_data = opt['cache_data']
    self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
    self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
    # file client (io backend)
    self.file_client = None
    self.io_backend_opt = opt['io_backend']
    assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

    logger = get_root_logger()
    logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
    self.imgs_lq, self.imgs_gt = {}, {}
    if 'meta_info_file' in opt:
        with open(opt['meta_info_file'], 'r') as fin:
            subfolders = [line.split(' ')[0] for line in fin]
            subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
            subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
    else:
        subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
        subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))

    if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
        for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
            # get frame list for lq and gt
            subfolder_name = osp.basename(subfolder_lq)
            img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
            img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))

            max_idx = len(img_paths_lq)
            assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
                                                  f' and gt folders ({len(img_paths_gt)})')

            self.data_info['lq_path'].extend(img_paths_lq)
            self.data_info['gt_path'].extend(img_paths_gt)
            self.data_info['folder'].extend([subfolder_name] * max_idx)
            for i in range(max_idx):
                self.data_info['idx'].append(f'{i}/{max_idx}')
            # mark frames near the clip boundaries, where the temporal window
            # must be padded
            border_l = [0] * max_idx
            for i in range(self.opt['num_frame'] // 2):
                border_l[i] = 1
                border_l[max_idx - i - 1] = 1
            self.data_info['border'].extend(border_l)

            # cache data or save the frame list
            if self.cache_data:
                logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
                self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
                self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
            else:
                self.imgs_lq[subfolder_name] = img_paths_lq
                self.imgs_gt[subfolder_name] = img_paths_gt
    else:
        raise ValueError(f'Non-supported video test dataset: {opt["name"]}')
import importlib
import torch
import torch.utils.data
from functools import partial
from os import path as osp

from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info

__all__ = ['create_dataset', 'create_dataloader']

# automatically scan and import dataset modules
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]


def create_dataset(dataset_opt):
    """Create dataset.

    Args:
        dataset_opt (dict): Configuration for dataset. It contains:
            name (str): Dataset name.
import importlib
from os import path as osp

from basicsr.utils import get_root_logger, scandir

# automatically scan and import model modules
# scan all the files under the 'models' folder and collect files ending with
# '_model.py'
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]


def create_model(opt):
    """Create model.

    Args:
        opt (dict): Configuration. It contains:
            model_type (str): Model type.
    """
    model_type = opt['model_type']

    # dynamic instantiation
    for module in _model_modules:
import importlib
from os import path as osp

from basicsr.utils import scandir

# automatically scan and import arch modules for registry
# scan all the files that end with '_arch.py' under the archs folder
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames]