コード例 #1
0
    def __init__(self, opt):
        """Build a shuffled, weighted list of HQ keys drawn from several lmdb roots.

        Args:
            opt: Option mapping. Keys read here: "dataroots_HQ",
                "dataroots_LQ", "identical_loss" and "dataset_weights"
                (one integer repeat count per dataset root).
        """
        super(MixDataset, self).__init__()
        self.opt = opt

        self.HQ_roots = opt["dataroots_HQ"]
        self.LQ_roots = opt["dataroots_LQ"]
        self.use_identical = opt["identical_loss"]
        repeat_counts = opt["dataset_weights"]
        self.data_type = "lmdb"
        # lmdb environments are not opened here; both start out as None.
        self.HQ_envs, self.LQ_envs = None, None

        # Each entry is (dataset index, key). A dataset's entries are repeated
        # `repeat_counts[idx]` times so that it is sampled proportionally more.
        weighted_keys = []
        # The LQ roots are zipped in only to bound the iteration to the
        # shorter of the two root lists; the LQ root itself is not read here.
        for ds_idx, (hq_root, _lq_root) in enumerate(
                zip(self.HQ_roots, self.LQ_roots)):
            hq_keys, _ = util.get_image_paths(self.data_type, hq_root)
            tagged = [(ds_idx, key) for key in hq_keys]
            weighted_keys.extend(tagged * repeat_counts[ds_idx])
        self.paths_HQ = weighted_keys
        random.shuffle(self.paths_HQ)
        logger.info("Using lmdb meta info for cache keys.")
コード例 #2
0
    def __init__(self, opt):
        """Initialise the dataset from the option mapping ``opt``.

        Reads the temporal-augmentation settings, resolves the list of GT
        keys (lmdb meta info or a pickled cache-key file), removes the clips
        '000', '011', '015' and '020' (held out as REDS4 for testing) and
        prepares the per-backend handles.

        Args:
            opt: Option mapping. Keys read here: 'interval_list',
                'random_reverse', 'N_frames', 'dataroot_GT', 'dataroot_LQ',
                'data_type', 'GT_size', 'LQ_size' and, when the data type is
                not 'lmdb', 'cache_keys'.

        Raises:
            ValueError: if no key source is available, or if 'data_type' is
                not one of 'lmdb', 'mc', 'img'.
            AssertionError: if no GT key is left after filtering.
        """
        super(REDSDataset, self).__init__()
        self.opt = opt
        # temporal augmentation
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
            ','.join(str(x) for x in opt['interval_list']), self.random_reverse))

        self.half_N_frames = opt['N_frames'] // 2
        self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
        self.data_type = self.opt['data_type']
        # Low-resolution inputs whenever the GT and LQ sizes differ.
        self.LR_input = opt['GT_size'] != opt['LQ_size']

        #### directly load image keys
        if self.data_type == 'lmdb':
            self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
            logger.info('Using lmdb meta info for cache keys.')
        elif opt['cache_keys']:
            logger.info('Using cache keys: {}'.format(opt['cache_keys']))
            # Use a context manager so the cache file is closed right after
            # loading instead of being left to the garbage collector.
            with open(opt['cache_keys'], 'rb') as cache_file:
                self.paths_GT = pickle.load(cache_file)['keys']
        else:
            raise ValueError(
                'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')

        # remove the REDS4 for testing
        reds4_clips = ('000', '011', '015', '020')
        self.paths_GT = [
            v for v in self.paths_GT if v.split('_')[0] not in reds4_clips
        ]
        assert self.paths_GT, 'Error: GT path is empty.'

        if self.data_type == 'lmdb':
            self.GT_env, self.LQ_env = None, None
        elif self.data_type == 'mc':  # memcached
            self.mclient = None
        elif self.data_type == 'img':
            pass
        else:
            raise ValueError('Wrong data type: {}'.format(self.data_type))
コード例 #3
0
ファイル: REDS_dataset.py プロジェクト: DengpanFu/mmsr
    def __init__(self, opt):
        """Collect the GT frame keys and the scale/size settings from ``opt``.

        Keys have the form '<sequence>_<frame name>'. Sequences '000', '011',
        '015' and '020' are always left out.

        Args:
            opt: Option mapping. Keys read here: 'interval_list',
                'random_reverse', 'N_frames', 'dataroot_GT', 'data_type',
                'scale', 'GT_size' and 'LQ_size'.

        Raises:
            AssertionError: if no GT key is found or 'scale' is empty.
        """
        super(MetaREDSDatasetOnline, self).__init__()
        self.opt = opt

        # temporal augmentation
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        intervals_text = ','.join(str(x) for x in opt['interval_list'])
        logger.info(
            'Temporal augmentation interval list: [{}], with random reverse is {}.'
            .format(intervals_text, self.random_reverse))

        self.half_N_frames = opt['N_frames'] // 2
        self.GT_root = opt['dataroot_GT']
        self.data_type = self.opt['data_type']

        excluded_seqs = ['000', '011', '015', '020']

        #### Load image keys
        if self.data_type == 'lmdb':
            all_keys, self.GT_size_tuple = util.get_image_paths(
                self.data_type, self.GT_root)
            self.paths_GT = all_keys
            logger.info('Using lmdb meta info for cache keys.')
            self.paths_GT = [
                key for key in self.paths_GT
                if key.split('_')[0] not in excluded_seqs
            ]
        else:
            # Walk the sequence folders directly under the GT root.
            self.paths_GT = []
            for seq_name in sorted(os.listdir(self.GT_root)):
                if seq_name in excluded_seqs:
                    continue
                frame_files = os.listdir(osp.join(self.GT_root, seq_name))
                # Drop the last four characters of each file name
                # (presumably a 3-letter extension such as '.png').
                self.paths_GT.extend(
                    [seq_name + '_' + fname[:-4] for fname in frame_files])
        assert self.paths_GT, 'Error: GT path is empty.'

        if self.data_type == 'lmdb':
            self.GT_env = None

        self.scales = self.opt['scale']
        assert (len(self.scales) >= 1)
        self.GT_sizes = self.opt['GT_size']
        self.LQ_size = self.opt['LQ_size']
コード例 #4
0
    def __init__(self, opt, transform=DataTransform2(), acc_fac=None):
        """Index every slice of every HDF5 file found under opt['dataroot_HR'].

        Args:
            opt: Option mapping; 'dataroot_HR' is read here.
            transform: Transform object stored on the instance.
            acc_fac: Acceleration factor, 4 or 8. When None, a factor is drawn
                with ``choice((4, 8))`` independently for each slice.

        Raises:
            AssertionError: if ``acc_fac`` is neither None, 4 nor 8, or if the
                slice index does not match the summed slice counts.
        """
        super(HDF5Dataset, self).__init__()

        self.opt = opt
        self.acc_fac = acc_fac
        self.transform = transform
        self.paths_h5 = None
        self.data_type = 'h5'

        if self.acc_fac is not None:  # Use both if self.acc_fac is None.
            assert self.acc_fac in (4, 8), 'Invalid acceleration factor'

        _, self.paths_h5 = util.get_image_paths(self.data_type,
                                                opt['dataroot_HR'])
        file_names = self.paths_h5
        slice_counts = [
            self.get_slice_number(file_name) for file_name in file_names
        ]
        self.num_slices = sum(slice_counts)

        # One [file name, slice index, acceleration factor] entry per slice.
        # A single loop covers both the fixed-factor and random-factor cases;
        # the random draw still happens once per slice, in the same order.
        names_and_slices = []
        for name, slice_num in zip(file_names, slice_counts):
            for s_idx in range(slice_num):
                if self.acc_fac is not None:
                    factor = self.acc_fac
                else:
                    factor = choice((4, 8))
                names_and_slices.append([name, s_idx, factor])

        self.names_and_slices = names_and_slices
        assert self.num_slices == len(names_and_slices), 'Error in length'
コード例 #5
0
    def __init__(self, opt, specific_image=None):
        """Set up image paths and JPEG quality-factor sampling from ``opt``.

        Args:
            opt: Option mapping. Keys read here: 'input_downsampling',
                'jpeg_quality_factor', 'QF_probs', 'subset_file', 'phase',
                'dataroot_Uncomp', 'dataroot_LR', 'data_type', 'scales' and
                'patch_size'.
            specific_image: Accepted for interface compatibility; not used in
                this method.

        Raises:
            NotImplementedError: if a subset file is combined with a
                'dataroot_LR'.
            AssertionError: on inconsistent option lengths, an empty path
                list, or a training patch size that is not a multiple of 8.
        """
        super(JpegDataset, self).__init__()
        self.opt = opt
        self.paths_Uncomp = None
        self.Uncomp_env = None
        self.block_size = 8
        if opt['input_downsampling'] is not None:
            self.block_size *= opt['input_downsampling']
        self.quality_factors = opt['jpeg_quality_factor']
        if not isinstance(self.quality_factors, list):
            self.quality_factors = [self.quality_factors]

        raw_probs = opt['QF_probs']
        if raw_probs is None:
            probs = np.ones([len(self.quality_factors)])
        else:
            assert len(raw_probs) == len(self.quality_factors)
            # Copy into a float array: a plain list has no .sum(), in-place
            # division of an integer array raises a casting error, and the
            # caller's option value must not be normalised in place.
            probs = np.array(raw_probs, dtype=float)
        self.QF_probs = probs / probs.sum()

        # read image list from subset list txt
        if opt['subset_file'] is not None and opt['phase'] == 'train':
            with open(opt['subset_file']) as f:
                self.paths_Uncomp = sorted([os.path.join(opt['dataroot_Uncomp'], line.rstrip('\n')) \
                        for line in f])
            if opt['dataroot_LR'] is not None:
                raise NotImplementedError(
                    'Now subset only supports generating LR on-the-fly.')
        else:  # read image list from lmdb or image files
            self.Uncomp_env, self.paths_Uncomp = util.get_image_paths(
                opt['data_type'], opt['dataroot_Uncomp'])
        if opt['scales'] is not None:
            assert len(opt['scales']) == 3
            # Repeat the paths of each scale `prob_ratio` times; paths are
            # matched on the '_scale<N>_' substring.
            new_paths_list = []
            for scale_num, prob_ratio in enumerate(opt['scales']):
                if prob_ratio == 0:
                    continue
                new_paths_list += prob_ratio * [
                    path for path in self.paths_Uncomp if '_scale%d_' %
                    (scale_num) in path
                ]
            self.paths_Uncomp = new_paths_list
        assert self.paths_Uncomp, 'Error: Uncomp path is empty.'
        if self.opt['phase'] == 'train':
            assert not self.opt[
                'patch_size'] % 8, 'Training for JPEG compression artifacts removal - Training images should have an integer number of 8x8 blocks.'
        else:
            # Outside training, assign a fixed quality factor to every index.
            if len(self.quality_factors) >= len(self):
                # NOTE(review): the linspace samples span the index range
                # [0, len(quality_factors)], yet membership is tested with the
                # quality-factor values themselves. This looks unintended;
                # behaviour is left unchanged pending confirmation.
                sampled_QFs = np.round(
                    np.linspace(start=0,
                                stop=len(self.quality_factors),
                                num=len(self))).astype(int)
                per_range_len = [
                    1 if (QF in sampled_QFs) else 0
                    for QF in self.quality_factors
                ]
            else:
                is_range = [isinstance(QF, list) for QF in self.quality_factors]
                num_exact_values = sum([not flag for flag in is_range])
                # Exact values get one index each; the remaining indices are
                # split evenly between the [start, stop] ranges.
                per_range_len = [
                    ((len(self) - num_exact_values) //
                     (len(self.quality_factors) - num_exact_values))
                    if flag else 1
                    for flag in is_range
                ]
                # Hand the division remainder to the first range, or to the
                # first entry when there is no range at all.
                if any(is_range):
                    per_range_len[np.argwhere(is_range)[0][0]] += len(self) - sum(per_range_len)
                else:
                    per_range_len[0] += len(self) - sum(per_range_len)
            self.per_index_QF = []
            for i, QF_range_len in enumerate(per_range_len):
                if isinstance(self.quality_factors[i], list):
                    self.per_index_QF += list(
                        np.round(
                            np.linspace(start=self.quality_factors[i][0],
                                        stop=self.quality_factors[i][1] - 1,
                                        num=QF_range_len)).astype(int))
                else:
                    self.per_index_QF += [self.quality_factors[i]
                                          ] * QF_range_len

        self.random_scale_list = [1]
コード例 #6
0
    def __init__(self, opt):
        """Collect LR, HR and reference image paths (or lmdb keys) from ``opt``.

        Args:
            opt: Option mapping. Keys read here: 'data_type', 'dataroot_LR',
                'dataroot_HR', and for the non-lmdb case 'phase',
                'subset_file' and 'dataroot_ref'.

        Raises:
            NotImplementedError: if a subset file is combined with a
                'dataroot_LR'.
            AssertionError: if both path lists are empty (file mode) or the LR
                and HR lists differ in length.
        """
        super(LRHRRefDataset, self).__init__()
        self.opt = opt
        self.paths_LR = []
        self.paths_HR = []
        self.paths_ref = []

        if opt['data_type'] == 'lmdb':  # only used in train phase
            import lmdb

            def _open_env_and_keys(root):
                """Open the lmdb at `root`; return (env, sorted non-meta keys)."""
                env = lmdb.open(root, readonly=True, \
                        lock=False, readahead=False, meminit=False)
                keys_cache_file = os.path.join(root, '_keys_cache.p')
                if os.path.isfile(keys_cache_file):
                    print('read lmdb keys from cache: {}'.format(
                        keys_cache_file))
                    # NOTE: pickle must only be loaded from trusted files.
                    with open(keys_cache_file, "rb") as cache_file:
                        keys = pickle.load(cache_file)
                else:
                    with env.begin(write=False) as txn:
                        print('creating lmdb keys cache: {}'.format(
                            keys_cache_file))
                        keys = [key.decode('ascii') for key, _ in txn.cursor()]
                    # Close the cache file explicitly so that the key list is
                    # fully flushed to disk before anyone else reads it.
                    with open(keys_cache_file, "wb") as cache_file:
                        pickle.dump(keys, cache_file)
                return env, sorted(
                    [key for key in keys if not key.endswith('.meta')])

            if opt['dataroot_LR'] is not None:
                self.LR_env, self.paths_LR = _open_env_and_keys(
                    opt['dataroot_LR'])
            if opt['dataroot_HR'] is not None:
                self.HR_env, self.paths_HR = _open_env_and_keys(
                    opt['dataroot_HR'])
            if self.paths_LR and self.paths_HR:
                assert len(self.paths_LR) == len(self.paths_HR), \
                    'HR and LR lmdb datasets have different number of images.'
            # TODO lmdb does not support ref image now
        else:  # read image from files
            if opt['phase'] == 'train' and opt['subset_file'] is not None:
                # get HR image paths from list
                with open(opt['subset_file']) as f:
                    self.paths_HR = sorted([os.path.join(opt['dataroot_HR'], line.rstrip('\n')) \
                            for line in f])
                if opt['dataroot_LR'] is not None:
                    raise NotImplementedError(
                        'Now subset only support generating LR on-the-fly.')
            else:
                if opt['dataroot_LR'] is not None:
                    self.paths_LR = sorted(get_image_paths(opt['dataroot_LR']))
                if opt['dataroot_HR'] is not None:
                    self.paths_HR = sorted(get_image_paths(opt['dataroot_HR']))
                assert self.paths_LR or self.paths_HR, 'Both LR and HR paths are empty.'
                if self.paths_LR and self.paths_HR:
                    assert len(self.paths_LR) == len(self.paths_HR), \
                        'HR and LR datasets have different number of images.'
            # ref images
            if opt['dataroot_ref']:
                self.paths_ref = sorted(get_image_paths(opt['dataroot_ref']))
コード例 #7
0
    def __init__(self, opt):
        """Resolve HR/LR paths and parse the on-the-fly augmentation options.

        Args:
            opt: Option mapping. Keys read here: 'subset_file', 'phase',
                'dataroot_HR', 'dataroot_LR', 'data_type', the switches
                'hr_crop', 'hr_rrot', 'hr_noise', 'lr_downscale', 'lr_blur',
                'lr_noise', 'lr_noise2', 'lr_cutout', 'lr_erasing', and the
                '*_types' list that belongs to each enabled switch.

        Raises:
            NotImplementedError: if a subset file is combined with a
                'dataroot_LR'.
            AssertionError: if no HR path is found or the LR and HR lists
                differ in length.
        """
        super(LRHRDataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_HR = None
        self.LR_env = None  # environment for lmdb
        self.HR_env = None
        # On-the-fly switches; each stays None unless enabled in `opt`.
        for switch_attr in ('HR_crop', 'HR_rrot', 'LR_scale', 'LR_blur',
                            'HR_noise', 'LR_noise', 'LR_noise2',
                            'LR_cutout', 'LR_erasing'):
            setattr(self, switch_attr, None)

        use_subset_list = (opt['subset_file'] is not None
                           and opt['phase'] == 'train')
        if use_subset_list:
            # One relative path per line in the subset file.
            with open(opt['subset_file']) as subset:
                joined = [
                    os.path.join(opt['dataroot_HR'], entry.rstrip('\n'))
                    for entry in subset
                ]
            self.paths_HR = sorted(joined)
            if opt['dataroot_LR'] is not None:
                raise NotImplementedError('Now subset only supports generating LR on-the-fly.')
        else:
            # Image list comes from lmdb or from image files.
            self.HR_env, self.paths_HR = util.get_image_paths(
                opt['data_type'], opt['dataroot_HR'])
            self.LR_env, self.paths_LR = util.get_image_paths(
                opt['data_type'], opt['dataroot_LR'])

        assert self.paths_HR, 'Error: HR path is empty.'
        if self.paths_LR and self.paths_HR:
            assert len(self.paths_LR) == len(self.paths_HR), \
                'HR and LR datasets have different number of images - {}, {}.'.format(
                    len(self.paths_LR), len(self.paths_HR))

        # Each row: (option switch, attribute set to True,
        #            (types attribute, types option key) or None, message).
        otf_table = (
            ('hr_crop', 'HR_crop', None,
             "Automatic crop of HR images enabled"),
            ('hr_rrot', 'HR_rrot', None, "HR random rotation enabled"),
            ('hr_noise', 'HR_noise', ('hr_noise_types', 'hr_noise_types'),
             "HR_noise enabled"),
            ('lr_downscale', 'LR_scale', ('scale_algos', 'lr_downscale_types'),
             "LR_scale enabled"),
            ('lr_blur', 'LR_blur', ('blur_algos', 'lr_blur_types'),
             "LR_blur enabled"),
            ('lr_noise', 'LR_noise', ('noise_types', 'lr_noise_types'),
             "LR_noise enabled"),
            ('lr_noise2', 'LR_noise2', ('noise_types2', 'lr_noise_types2'),
             "LR_noise 2 enabled"),
            ('lr_cutout', 'LR_cutout', None, "LR cutout enabled"),
            ('lr_erasing', 'LR_erasing', None, "LR random erasing enabled"),
        )
        for switch_key, switch_attr, types_spec, message in otf_table:
            if not opt[switch_key]:
                continue
            setattr(self, switch_attr, True)
            if types_spec is None:
                print(message)
                continue
            types_attr, types_key = types_spec
            setattr(self, types_attr, opt[types_key])
            print(message)
            print(getattr(self, types_attr))

        self.random_scale_list = [1]
コード例 #8
0
 def __init__(self, opt):
     """Store ``opt`` and list the images found under opt['input_folder']."""
     self.opt = opt
     # Only the first element of the helper's return value is kept.
     lookup_result = data_util.get_image_paths('img', opt['input_folder'])
     self.images = lookup_result[0]
     print("Found %i images" % (len(self.images), ))
コード例 #9
0
    def __init__(self, opt):
        """Initialise the training dataset from the option mapping ``opt``.

        Resolves the GT keys, optionally drops the validation clips
        ('00000' .. '00009') and any clips listed as bad data, prepares the
        per-backend handles and loads the frame-notation JSON file.

        Args:
            opt: Option mapping. Keys read here: 'interval_list',
                'random_reverse', 'N_frames', 'dataroot_GT', 'dataroot_LQ',
                'data_type', 'GT_size', 'LQ_size', 'cache_keys' (non-lmdb
                only), 'all_dataset_for_train', 'baddata_list' and
                'frame_notation'.

        Raises:
            ValueError: if no key source is available, or if 'data_type' is
                not one of 'lmdb', 'mc', 'img'.
            AssertionError: if no GT key is left after filtering.
        """
        super(REDSDataset, self).__init__()
        self.opt = opt
        # temporal augmentation
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
            ','.join(str(x) for x in opt['interval_list']), self.random_reverse))

        self.half_N_frames = opt['N_frames'] // 2
        self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
        self.data_type = self.opt['data_type']
        # Low-resolution inputs whenever the GT and LQ sizes differ.
        self.LR_input = opt['GT_size'] != opt['LQ_size']

        #### directly load image keys
        if self.data_type == 'lmdb':
            self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
            logger.info('Using lmdb meta info for cache keys.')
        elif opt['cache_keys']:
            logger.info('Using cache keys: {}'.format(opt['cache_keys']))
            # Context manager so the cache file handle is closed promptly.
            with open(opt['cache_keys'], 'rb') as cache_file:
                self.paths_GT = pickle.load(cache_file)['keys']
        else:
            raise ValueError(
                'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')

        # remove the 4Khdr validation clips unless the whole dataset is used.
        # The `== True` comparison is kept on purpose so that the set of
        # option values treated as "use everything" is unchanged.
        train_on_everything = opt['all_dataset_for_train'] == True  # noqa: E712
        if train_on_everything:
            # Previously printed unconditionally, which contradicted the
            # "exclude" message when the validation clips had been removed.
            print("Include validation set for training")
        else:
            print("exclude validation set for training")
            validation_clips = ['00000', '00001', '00002', '00003', '00004',
                                '00005', '00006', '00007', '00008', '00009']
            self.paths_GT = [
                v for v in self.paths_GT if v.split('_')[0] not in validation_clips
            ]

        # Truthiness also covers a missing (None) bad-data list, which
        # `len(...)` would crash on.
        if opt['baddata_list']:
            print("exclude bad training data")
            print(opt['baddata_list'])
            self.paths_GT = [
                v for v in self.paths_GT if v.split('_')[0] not in opt['baddata_list']
            ]

        assert self.paths_GT, 'Error: GT path is empty.'

        if self.data_type == 'lmdb':
            self.GT_env, self.LQ_env = None, None
        elif self.data_type == 'mc':  # memcached
            self.mclient = None
        elif self.data_type == 'img':
            pass
        else:
            raise ValueError('Wrong data type: {}'.format(self.data_type))

        # load screen notation
        with open(opt['frame_notation']) as f:
            self.frame_notation = json.load(f)

        print("Training Dataset Initialized")
コード例 #10
0
    def __init__(self, opt):
        """Build the per-frame bookkeeping lists for a video test set.

        ``self.data_info`` holds parallel lists; 'idx' entries are
        '<frame index>/<frames in sequence>' strings and 'border' marks the
        first and last ``half_N_frames`` frames of each sequence with 1.

        Args:
            opt: Option mapping. Keys read here: 'cache_data', 'N_frames',
                'dataroot_GT', 'dataroot_LQ', 'data_type' and 'name'.

        Raises:
            ValueError: if opt['name'] is not a supported dataset name.
            AssertionError: if GT and LQ entries do not line up.
        """
        super(VideoTestDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.half_N_frames = opt['N_frames'] // 2
        self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
        self.data_type = self.opt['data_type']
        self.data_info = {
            field: []
            for field in ('path_LQ', 'path_GT', 'folder', 'idx', 'border')
        }
        #### Generate data info and cache data
        self.imgs_LQ, self.imgs_GT = {}, {}

        def _record_sequence(num_frames):
            """Append the 'idx' and 'border' entries of one finished sequence."""
            for frame_no in range(num_frames):
                self.data_info['idx'].append('{}/{}'.format(
                    frame_no, num_frames))
            border_flags = [0] * num_frames
            for offset in range(self.half_N_frames):
                border_flags[offset] = 1
                border_flags[num_frames - offset - 1] = 1
            self.data_info['border'].extend(border_flags)

        dataset_name = opt['name'].lower()
        if dataset_name in ['vimeo90k-test']:
            return  # TODO
        if dataset_name not in ['vid4', 'reds4']:
            raise ValueError(
                'Not support video test dataset. Support Vid4, REDS4 and Vimeo90k-Test.'
            )

        if self.data_type == 'lmdb':
            self.lmdb_paths_GT, _ = util.get_image_paths(
                self.data_type, self.GT_root)
            self.lmdb_paths_LQ, _ = util.get_image_paths(
                self.data_type, self.LQ_root)
            self.GT_env, self.LQ_env = None, None
            # Keys are split on '_' into two parts; a sequence ends when the
            # first part changes, and the second part of its last key gives
            # the highest frame index.
            last_seq, last_frame = None, None
            for key_GT, key_LQ in zip(self.lmdb_paths_GT,
                                      self.lmdb_paths_LQ):
                seq_name, frame_name = key_GT.split('_')
                assert key_GT == key_LQ, 'GT path and LQ path in lmdb is not matched'
                if last_seq is not None and last_seq != seq_name:
                    _record_sequence(int(last_frame) + 1)
                self.data_info['folder'].append(seq_name)
                last_seq, last_frame = seq_name, frame_name
            # Flush the final sequence.
            _record_sequence(int(last_frame) + 1)
            return

        folders_LQ = util.glob_file_list(self.LQ_root)
        folders_GT = util.glob_file_list(self.GT_root)
        for folder_LQ, folder_GT in zip(folders_LQ, folders_GT):
            seq_name = osp.basename(folder_GT)
            frames_LQ = util.glob_file_list(folder_LQ)
            frames_GT = util.glob_file_list(folder_GT)
            num_frames = len(frames_LQ)
            assert num_frames == len(
                frames_GT
            ), 'Different number of images in LQ and GT folders'
            self.data_info['path_LQ'].extend(frames_LQ)
            self.data_info['path_GT'].extend(frames_GT)
            self.data_info['folder'].extend([seq_name] * num_frames)
            _record_sequence(num_frames)

            # Keep either the decoded frames or just their paths.
            if self.cache_data:
                self.imgs_LQ[seq_name] = util.read_img_seq(frames_LQ)
                self.imgs_GT[seq_name] = util.read_img_seq(frames_GT)
            else:
                self.imgs_LQ[seq_name] = frames_LQ
                self.imgs_GT[seq_name] = frames_GT
コード例 #11
0
ファイル: LRHROTF_dataset.py プロジェクト: liuguoyou/BasicSR
    def __init__(self, opt):
        """Resolve HR/LR paths and parse the on-the-fly augmentation options.

        The LR set may be smaller than the HR set. In that case
        ``self.paths_LR`` is rebuilt to the length of ``self.paths_HR``, with
        None wherever no LR file with the same file name was matched.

        Args:
            opt: Option mapping. Keys read here: 'subset_file', 'phase',
                'dataroot_HR', 'dataroot_LR', 'data_type', the switches
                'hr_rrot', 'hr_noise', 'lr_downscale', 'lr_blur', 'lr_noise',
                'lr_noise2', 'lr_cutout', 'lr_erasing', and the '*_types' list
                that belongs to each enabled switch.

        Raises:
            NotImplementedError: if a subset file is combined with a
                'dataroot_LR'.
            AssertionError: if no HR path is found or there are more LR than
                HR images.
        """
        # Must be the first statement. An import anywhere in a function makes
        # `os` local to the whole function, so importing it further down (as
        # before) left `os` unbound for the subset-file branch below.
        import os

        super(LRHRDataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_HR = None
        self.LR_env = None  # environment for lmdb
        self.HR_env = None
        self.HR_rrot = None
        self.LR_scale = None
        self.scale_algos = None
        self.LR_blur = None
        self.HR_noise = None
        self.LR_noise = None
        self.LR_noise2 = None
        self.LR_cutout = None
        self.LR_erasing = None
        self.output_sample_imgs = None

        # read image list from subset list txt
        if opt['subset_file'] is not None and opt['phase'] == 'train':
            with open(opt['subset_file']) as f:
                self.paths_HR = sorted([os.path.join(opt['dataroot_HR'], line.rstrip('\n')) \
                        for line in f])
            if opt['dataroot_LR'] is not None:
                raise NotImplementedError(
                    'Now subset only supports generating LR on-the-fly.')
        else:  # read image list from lmdb or image files
            self.HR_env, self.paths_HR = util.get_image_paths(
                opt['data_type'], opt['dataroot_HR'])
            self.LR_env, self.paths_LR = util.get_image_paths(
                opt['data_type'], opt['dataroot_LR'])

        assert self.paths_HR, 'Error: HR path is empty.'
        if self.paths_LR and self.paths_HR:
            # HR and LR folders may hold different amounts of images:
            # - If an LR image pair is not found, downscale HR on the fly, else, use the LR
            # - If all LR are provided and 'lr_downscale' is enabled, randomize use of provided LR and OTF LR for augmentation
            assert len(self.paths_HR) >= len(self.paths_LR), \
                'HR dataset contains less images than LR dataset  - {}, {}.'.format(\
                len(self.paths_LR), len(self.paths_HR))
            if len(self.paths_LR) < len(self.paths_HR):
                print(
                    'LR contains less images than HR dataset  - {}, {}. Will generate missing images on the fly.'
                    .format(len(self.paths_LR), len(self.paths_HR)))
                # Walk both lists in step: the LR cursor only advances when
                # the current LR file name equals the current HR file name.
                lr_cursor = 0
                aligned_LR = []
                for hr_path in self.paths_HR:
                    matched = None
                    if lr_cursor < len(self.paths_LR):
                        lr_head, lr_tail = os.path.split(
                            self.paths_LR[lr_cursor])
                        if lr_tail == os.path.split(hr_path)[1]:
                            matched = os.path.join(lr_head, lr_tail)
                            lr_cursor += 1
                    aligned_LR.append(matched)
                self.paths_LR = aligned_LR

        # parse on the fly options
        if opt['hr_rrot']:  # automatic rotate HR image and generate LR
            self.HR_rrot = True
            print("HR random rotation enabled")
        if opt['hr_noise']:  # add noise to HR image
            self.HR_noise = True
            self.hr_noise_types = opt['hr_noise_types']
            print("HR_noise enabled")
            print(self.hr_noise_types)
        if opt['lr_downscale']:  # automatic downscale of HR images to LR pair, controlled by the scale of the model
            self.LR_scale = True
            self.scale_algos = opt['lr_downscale_types']
            print("LR_scale enabled")
            print(self.scale_algos)
        if opt['lr_blur']:  # automatic blur of LR images
            self.LR_blur = True
            self.blur_algos = opt['lr_blur_types']
            print("LR_blur enabled")
            print(self.blur_algos)
        if opt['lr_noise']:  # add noise to LR image
            self.LR_noise = True
            self.noise_types = opt['lr_noise_types']
            print("LR_noise enabled")
            print(self.noise_types)
        if opt['lr_noise2']:  # add a secondary noise to LR image
            self.LR_noise2 = True
            self.noise_types2 = opt['lr_noise_types2']
            print("LR_noise 2 enabled")
            print(self.noise_types2)
        if opt['lr_cutout']:  # random cutout
            self.LR_cutout = True
            print("LR cutout enabled")
        if opt['lr_erasing']:  # random erasing
            self.LR_erasing = True
            print("LR random erasing enabled")
コード例 #12
0
    def __init__(self, opt):
        """Initialise the dataset from the option mapping ``opt``.

        In lmdb mode the GT keys can be restricted to one video class
        ('movie', 'cartoon' or 'lego') through the matching list file under
        ``data/``; any other class value keeps all keys. Keys whose prefix is
        '15922480' are always dropped as bad data.

        Args:
            opt: Option mapping. Keys read here: 'video_class',
                'interval_list', 'random_reverse', 'N_frames', 'dataroot_GT',
                'dataroot_LQ', 'data_type', 'GT_size', 'LQ_size' and, for data
                types other than 'lmdb' and 'img', 'cache_keys'.

        Raises:
            AssertionError: if no GT key is available.
            ValueError: if 'data_type' is not one of 'lmdb', 'mc', 'img'.
        """
        super(AI4KDataset, self).__init__()
        self.opt = opt
        # all | movie | cartoon | lego
        self.video_class = opt['video_class'] if opt['video_class'] else 'all'

        # temporal augmentation
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        logger.info(
            'Temporal augmentation interval list: [{}], with random reverse is {}.'
            .format(','.join(str(x) for x in opt['interval_list']),
                    self.random_reverse))

        self.half_N_frames = opt['N_frames'] // 2
        self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
        self.data_type = self.opt['data_type']
        # Low-resolution inputs whenever the GT and LQ sizes differ.
        self.LR_input = opt['GT_size'] != opt['LQ_size']

        #### directly load image keys
        if self.data_type == 'lmdb':
            self.paths_GT, _ = util.get_image_paths(self.data_type,
                                                    opt['dataroot_GT'])
            class_list_files = {
                'movie': 'data/movie_list.txt',
                'cartoon': 'data/cartoon_list.txt',
                'lego': 'data/lego_list.txt',
            }
            list_file = class_list_files.get(self.video_class)
            if list_file is not None:
                with open(list_file, 'r') as f:
                    wanted_ids = {line.strip() for line in f}
                # Single filtering pass with set lookups, replacing the
                # quadratic per-item `list.remove` calls.
                self.paths_GT = [
                    item for item in self.paths_GT
                    if item.split('_')[0] in wanted_ids
                ]

            # clear bad data
            self.paths_GT = [
                item for item in self.paths_GT
                if item.split('_')[0] != '15922480'
            ]

            logger.info('Using lmdb meta info for cache keys.')
        elif self.data_type == 'img':
            self.paths_GT, _ = util.get_image_paths(self.data_type,
                                                    opt['dataroot_GT'])
        elif opt['cache_keys']:
            logger.info('Using cache keys: {}'.format(opt['cache_keys']))
            # Context manager so the cache file handle is closed promptly.
            with open(opt['cache_keys'], 'rb') as cache_file:
                self.paths_GT = pickle.load(cache_file)['keys']
        else:
            # No key source: leave the list empty so the assertion below
            # reports it, instead of failing on a missing attribute.
            self.paths_GT = []

        assert self.paths_GT, 'Error: GT path is empty.'
        print(len(self.paths_GT))

        if self.data_type == 'lmdb':
            self.GT_env, self.LQ_env = None, None
        elif self.data_type == 'mc':  # memcached
            self.mclient = None
        elif self.data_type == 'img':
            pass
        else:
            raise ValueError('Wrong data type: {}'.format(self.data_type))
コード例 #13
0
    def __init__(self, opt):
        """Collect image / label path lists and the colour palette.

        Keys read from ``opt``: 'subset_file', 'phase', 'dataroot_HR_Image',
        'dataroot_LR_Image', 'dataroot_HR_Label', 'max_len' and 'palette'.

        Raises:
            NotImplementedError: when a training subset file is combined with
                a non-None 'dataroot_LR_Image'.
            AssertionError: when the image path list is empty.
        """
        super(ImageLabelDatasetFiveLevels, self).__init__()
        self.opt = opt

        use_subset = opt['subset_file'] is not None and opt['phase'] == 'train'
        if use_subset:
            # One relative image path per line of the subset file.
            # NOTE(review): this branch only sets `paths_Images`; the trimming
            # and the assert below read `paths_HR_images`, which is not
            # assigned here - confirm whether this path is still supported.
            subset_entries = []
            with open(opt['subset_file']) as listing:
                for raw in listing:
                    subset_entries.append(
                        os.path.join(opt['dataroot_HR_Image'],
                                     raw.rstrip('\n')))
            subset_entries.sort()
            self.paths_Images = subset_entries
            if opt['dataroot_LR_Image'] is not None:
                raise NotImplementedError(
                    'Now subset only supports generating LR on-the-fly.')
        else:
            # Images and labels are always looked up with the "lmdb" data
            # type, regardless of any 'data_type' option.
            hr_env, hr_paths = util.get_image_paths("lmdb",
                                                    opt['dataroot_HR_Image'])
            self.HR_images_env = hr_env
            self.paths_HR_images = hr_paths
            label_env, label_paths = util.get_image_paths(
                "lmdb", opt['dataroot_HR_Label'])
            self.LR_labels_env = label_env
            self.paths_LR_labels = label_paths

            print("Images: {}".format(len(self.paths_HR_images)))
            print("Labels: {}".format(len(self.paths_LR_labels)))

        limit = opt['max_len']
        if limit is not None:
            # Keep only the leading `limit` entries of both lists.
            print("Trimming the first", limit)
            self.paths_HR_images = self.paths_HR_images[:limit]
            self.paths_LR_labels = self.paths_LR_labels[:limit]

        assert self.paths_HR_images, 'Error: Images path is empty.'

        self.random_scale_list = [1]
        # Rarity masks start empty; nothing in this method fills them.
        self.rarity_masks = []

        # color palette dataset
        self.dataset_color_map = util.PaletteDataset(opt['palette'])
コード例 #14
0
ファイル: video_test_dataset.py プロジェクト: DengpanFu/mmsr
    def __init__(self, opt):
        """Index the ground-truth frames of a video test set.

        Args:
            opt: option mapping. Keys read here: 'name', 'cache_data',
                'N_frames', 'dataroot_GT' and 'scale'.

        Raises:
            TypeError: if a Vid4 / REDS4 root ends with 'lmdb'.
            ValueError: if 'name' is not a supported dataset name.
        """
        super(OnlineVideoTestDataset, self).__init__()
        self.opt = opt
        self.data_name = opt['name']
        self.cache_data = opt['cache_data']
        self.half_N_frames = opt['N_frames'] // 2
        self.GT_root = opt['dataroot_GT']
        self.is_lmdb = self.GT_root.endswith('lmdb')

        self.data_info = {
            'path_GT': [],
            'idx': [],
            'folder': [],
            'border': []
        }

        self.scale = opt['scale']
        self.imgs_GT = {}
        if self.data_name.lower() in ['vid4', 'reds4']:
            if self.is_lmdb:
                raise TypeError("{} data should not lmdb".format(
                    self.data_name))
            # Every sub-folder of the GT root is one clip.
            subs = sorted(os.listdir(self.GT_root))
            for sub in subs:
                sub_dir = osp.join(self.GT_root, sub)
                im_names = sorted(os.listdir(sub_dir))
                im_paths = [osp.join(sub_dir, name) for name in im_names]
                max_idx = len(im_names)
                self.data_info['path_GT'].extend(im_paths)
                self.data_info['folder'].extend([sub] * max_idx)
                for i in range(max_idx):
                    self.data_info['idx'].append('{}/{}'.format(i, max_idx))
                # Flag the first and last `half_N_frames` frames of the clip
                # as border frames (1), all others as 0.
                border_l = [0] * max_idx
                for i in range(self.half_N_frames):
                    border_l[i] = 1
                    border_l[max_idx - i - 1] = 1
                self.data_info['border'].extend(border_l)
                if self.cache_data:
                    self.imgs_GT[sub] = self.read_sub_images(im_paths)

        elif self.data_name.lower() in ['vimeo', 'vimeo90k', 'vimeo90k-test']:
            if self.is_lmdb:
                paths_GT, self.GT_size_tuple = util.get_image_paths(
                    'lmdb', self.GT_root)
                # Only keys ending in '_4' are kept.
                self.data_info['path_GT'] = [
                    x for x in paths_GT if x.endswith('_4')
                ]
            else:
                split_file = osp.join(self.GT_root, 'sep_trainlist.txt')
                # Fixed: the file was previously opened through an undefined
                # name, which raised NameError on this branch.
                with open(split_file, 'r') as f:
                    img_list = [line.strip() for line in f]
                paths_GT = []
                for item in img_list:
                    # Each listed sequence contributes its 'im4.png' frame.
                    key = osp.join(*item.split('/'), 'im4.png')
                    paths_GT.append(osp.join(self.GT_root, 'sequences', key))
                self.data_info['path_GT'] = paths_GT
        else:
            raise ValueError(
                'Not support video test dataset. Support Vid4, REDS4 and Vimeo90k-Test.'
            )
コード例 #15
0
ファイル: LRHROTF_dataset.py プロジェクト: canhnht/BasicSR
    def __init__(self, opt):
        """Collect HR / LR image paths and line the LR list up with the HR list.

        Keys read from ``opt``: 'subset_file', 'phase', 'dataroot_HR',
        'dataroot_LR' and 'data_type'. Each data root may be a single string
        or a list of strings.

        Raises:
            NotImplementedError: when a training subset file is combined with
                a non-None 'dataroot_LR'.
            AssertionError: when no HR path is found, or when there are more
                LR paths than HR paths.
        """
        super(LRHRDataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_HR = None
        self.LR_env = None  # environment for lmdb
        self.HR_env = None
        self.output_sample_imgs = None

        if opt['subset_file'] is not None and opt['phase'] == 'train':
            # read image list from subset list txt
            with open(opt['subset_file']) as listing:
                self.paths_HR = sorted([
                    os.path.join(opt['dataroot_HR'], entry.rstrip('\n'))
                    for entry in listing
                ])
            if opt['dataroot_LR'] is not None:
                raise NotImplementedError(
                    'Now subset only supports generating LR on-the-fly.')
        else:
            # read image list from lmdb or image files
            # Note: lmdb will not currently work with a list of roots.
            hr_roots = opt['dataroot_HR']
            if type(hr_roots) is list:
                hr_envs = self.HR_env = []
                hr_paths = self.paths_HR = []
                for root in hr_roots:
                    env, found = util.get_image_paths(opt['data_type'], root)
                    if type(env) is list:
                        hr_envs.extend(env)
                    hr_paths.extend(found)
                # No environment at all (or only None entries) collapses to
                # None; otherwise the collected environments are sorted.
                if hr_envs.count(None) == len(hr_envs):
                    self.HR_env = None
                else:
                    self.HR_env = sorted(hr_envs)
                self.paths_HR = sorted(hr_paths)
            elif type(hr_roots) is str:
                self.HR_env, self.paths_HR = util.get_image_paths(
                    opt['data_type'], hr_roots)

            # Same procedure for the LR root(s).
            lr_roots = opt['dataroot_LR']
            if type(lr_roots) is list:
                lr_envs = self.LR_env = []
                lr_paths = self.paths_LR = []
                for root in lr_roots:
                    env, found = util.get_image_paths(opt['data_type'], root)
                    if type(env) is list:
                        lr_envs.extend(env)
                    lr_paths.extend(found)
                if lr_envs.count(None) == len(lr_envs):
                    self.LR_env = None
                else:
                    self.LR_env = sorted(lr_envs)
                self.paths_LR = sorted(lr_paths)
            elif type(lr_roots) is str:
                self.LR_env, self.paths_LR = util.get_image_paths(
                    opt['data_type'], lr_roots)

        assert self.paths_HR, 'Error: HR path is empty.'
        if self.paths_LR and self.paths_HR:
            # HR and LR folders may hold different numbers of images, as long
            # as LR is not the larger one. HR images without an LR partner
            # get a None placeholder in the LR list.
            assert len(self.paths_HR) >= len(self.paths_LR), \
                'HR dataset contains less images than LR dataset  - {}, {}.'.format(
                len(self.paths_LR), len(self.paths_HR))
            if len(self.paths_LR) < len(self.paths_HR):
                print(
                    'LR contains less images than HR dataset  - {}, {}. Will generate missing images on the fly.'
                    .format(len(self.paths_LR), len(self.paths_HR)))
                aligned = []
                cursor = 0  # next LR entry that has not been matched yet
                for hr_path in self.paths_HR:
                    hr_name = os.path.split(hr_path)[1]
                    partner = None
                    if cursor < len(self.paths_LR):
                        lr_dir, lr_name = os.path.split(self.paths_LR[cursor])
                        # Pair by file name only; the LR cursor advances
                        # solely on a match.
                        if lr_name == hr_name:
                            partner = os.path.join(lr_dir, lr_name)
                            cursor += 1
                    aligned.append(partner)
                self.paths_LR = aligned
コード例 #16
0
    def __init__(self, opt):
        """Read the dataset options and build the list of image paths.

        Args:
            opt: option mapping. Required keys read here: 'scale', 'paths',
                'fetch_alt_image' and 'fetch_alt_tiled_image'; 'weights' is
                required when 'paths' is a list. Optional keys:
                'center_crop_hq_sz', 'target_size', 'force_multiple',
                'corrupt_before_downsize', 'skip_lq', 'disable_flip',
                'rgb_n1_to_1', 'force_square', 'fixed_parameters',
                'all_image_color_jitter', 'normalize' and 'labeler'.

        Raises:
            Exception: if 'normalize' names an unsupported scheme.
            ValueError: if 'labeler' names an unsupported type.
            AssertionError: if both alternate-image options are set, if the
                target size is incompatible with scale / multiple, or if a
                labeler is used with more than one base path.
        """
        self.opt = opt
        self.corruptor = ImageCorruptor(opt)
        if 'center_crop_hq_sz' in opt:
            self.center_crop = functools.partial(ndarray_center_crop,
                                                 opt['center_crop_hq_sz'])
        self.target_hq_size = opt['target_size'] if 'target_size' in opt else None
        self.multiple = opt['force_multiple'] if 'force_multiple' in opt else 1
        self.scale = opt['scale']
        self.paths = opt['paths']
        self.corrupt_before_downsize = (opt['corrupt_before_downsize']
                                        if 'corrupt_before_downsize' in opt
                                        else False)
        # If specified, this dataset will attempt to find a second image
        # from the same video source. Search for 'fetch_alt_image' for more
        # info.
        self.fetch_alt_image = opt['fetch_alt_image']
        # If specified, this dataset will attempt to find another tile from
        # the same source image. Search for 'fetch_alt_tiled_image' for more
        # info.
        self.fetch_alt_tiled_image = opt['fetch_alt_tiled_image']
        # These are mutually exclusive.
        assert not (self.fetch_alt_image and self.fetch_alt_tiled_image)
        self.skip_lq = opt_get(opt, ['skip_lq'], False)
        self.disable_flip = opt_get(opt, ['disable_flip'], False)
        self.rgb_n1_to_1 = opt_get(opt, ['rgb_n1_to_1'], False)
        self.force_square = opt_get(opt, ['force_square'], True)
        self.fixed_parameters = {
            k: torch.tensor(v)
            for k, v in opt_get(opt, ['fixed_parameters'], {}).items()
        }
        self.all_image_color_jitter = opt_get(opt, ['all_image_color_jitter'],
                                              0)
        if 'normalize' in opt:
            if opt['normalize'] == 'stylegan2_norm':
                self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                                           inplace=True)
            elif opt['normalize'] == 'imagenet':
                self.normalize = Normalize((0.485, 0.456, 0.406),
                                           (0.229, 0.224, 0.225),
                                           inplace=True)
            else:
                raise Exception('Unsupported normalize')
        else:
            self.normalize = None
        if self.target_hq_size is not None:
            # If we dont throw here, we get some really obscure errors.
            assert (self.target_hq_size //
                    self.scale) % self.multiple == 0
        if not isinstance(self.paths, list):
            self.paths = [self.paths]
            self.weights = [1]
        else:
            self.weights = opt['weights']

        if 'labeler' in opt:
            labeler_type = opt['labeler']['type']
            if labeler_type != 'patch_labels':
                # Previously an unknown type left `self.labeler` unset and
                # failed a few lines later with an AttributeError.
                raise ValueError(
                    'Unsupported labeler type: {}'.format(labeler_type))
            self.labeler = VsNetImageLabeler(opt['labeler']['label_file'])
            # Only a single base-path is supported for labeled images.
            assert len(self.paths) == 1
            self.image_paths = self.labeler.get_labeled_paths(self.paths[0])
        else:
            self.labeler = None

            # Scan each given directory for images, going through a
            # per-directory 'cache.pth' file when one exists.
            self.image_paths = []
            for path, weight in zip(self.paths, self.weights):
                cache_path = os.path.join(path, 'cache.pth')
                if os.path.exists(cache_path):
                    imgs = torch.load(cache_path)
                else:
                    print(
                        "Building image folder cache, this can take some time for large datasets.."
                    )
                    imgs = util.get_image_paths('img', path)[0]
                    torch.save(imgs, cache_path)
                # A directory with weight w contributes its images w times.
                for _ in range(weight):
                    self.image_paths.extend(imgs)
        self.len = len(self.image_paths)
コード例 #17
0
def main():
    """Generate augmented sharp/blur image pairs with a KernelWizard model.

    For each requested image, a random crop is read from a source LQ/HQ pair
    and from a target HQ image. A kernel is sampled from the model's output
    for the source pair and applied to the target HQ crop. The target crop
    and its degraded version are written under ``<save_path>/sharp`` and
    ``<save_path>/blur``, and the PSNR between the real and the re-generated
    source LQ crop is logged.

    All settings come from the command line; every argument is required.
    """
    device = torch.device("cuda")

    parser = argparse.ArgumentParser(description="Kernel extractor testing")

    # (flag, help text, type) - kept in the original declaration order so the
    # generated --help output is unchanged.
    cli_arguments = [
        ("--source_H", "source image height", int),
        ("--source_W", "source image width", int),
        ("--target_H", "target image height", int),
        ("--target_W", "target image width", int),
        ("--augmented_H", "desired height of the augmented images", int),
        ("--augmented_W", "desired width of the augmented images", int),
        ("--source_LQ_root", "source low-quality dataroot", str),
        ("--source_HQ_root", "source high-quality dataroot", str),
        ("--target_HQ_root", "target high-quality dataroot", str),
        ("--save_path", "save path", str),
        ("--yml_path", "yml path", str),
        ("--num_images", "number of desire augmented images", int),
    ]
    for flag, help_text, arg_type in cli_arguments:
        parser.add_argument(flag,
                            action="store",
                            help=help_text,
                            type=arg_type,
                            required=True)

    args = parser.parse_args()

    source_LQ_root = args.source_LQ_root
    source_HQ_root = args.source_HQ_root
    target_HQ_root = args.target_HQ_root

    save_path = args.save_path
    source_H, source_W = args.source_H, args.source_W
    target_H, target_W = args.target_H, args.target_W
    augmented_H, augmented_W = args.augmented_H, args.augmented_W
    yml_path = args.yml_path
    num_images = args.num_images

    # Initializing logger
    logger = logging.getLogger("base")
    os.makedirs(save_path, exist_ok=True)
    util.setup_logger("base",
                      save_path,
                      "test",
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger.info("source LQ root: {}".format(source_LQ_root))
    logger.info("source HQ root: {}".format(source_HQ_root))
    logger.info("target HQ root: {}".format(target_HQ_root))
    logger.info("augmented height: {}".format(augmented_H))
    logger.info("augmented width: {}".format(augmented_W))
    logger.info("Number of augmented images: {}".format(num_images))

    # Initializing model
    logger.info("Loading model...")
    with open(yml_path, "r") as f:
        print(yml_path)
        # safe_load: yaml.load() without an explicit Loader is rejected by
        # current PyYAML releases and can build arbitrary objects on old ones.
        opt = yaml.safe_load(f)["KernelWizard"]
    model_path = opt["pretrained"]
    model = KernelWizard(opt)
    model.eval()
    model.load_state_dict(torch.load(model_path))
    model = model.to(device)
    logger.info("Done")

    # processing data
    lmdb_options = dict(readonly=True,
                        lock=False,
                        readahead=False,
                        meminit=False)
    source_HQ_env = lmdb.open(source_HQ_root, **lmdb_options)
    source_LQ_env = lmdb.open(source_LQ_root, **lmdb_options)
    target_HQ_env = lmdb.open(target_HQ_root, **lmdb_options)
    try:
        paths_source_HQ, _ = data_util.get_image_paths("lmdb", source_HQ_root)
        paths_target_HQ, _ = data_util.get_image_paths("lmdb", target_HQ_root)

        psnr_avg = 0

        for i in range(num_images):
            source_key = np.random.choice(paths_source_HQ)
            target_key = np.random.choice(paths_target_HQ)

            # Random top-left corner of the crop; 0 when the image is not
            # larger than the requested crop size.
            source_rnd_h = random.randint(0, max(0, source_H - augmented_H))
            source_rnd_w = random.randint(0, max(0, source_W - augmented_W))
            target_rnd_h = random.randint(0, max(0, target_H - augmented_H))
            target_rnd_w = random.randint(0, max(0, target_W - augmented_W))

            # The source LQ and HQ crops use the same key and offsets.
            source_LQ = read_image(source_LQ_env, source_key, source_rnd_h,
                                   source_rnd_w, augmented_H, augmented_W)
            source_HQ = read_image(source_HQ_env, source_key, source_rnd_h,
                                   source_rnd_w, augmented_H, augmented_W)
            target_HQ = read_image(target_HQ_env, target_key, target_rnd_h,
                                   target_rnd_w, augmented_H, augmented_W)

            source_LQ = torch.Tensor(source_LQ).unsqueeze(0).to(device)
            source_HQ = torch.Tensor(source_HQ).unsqueeze(0).to(device)
            target_HQ = torch.Tensor(target_HQ).unsqueeze(0).to(device)

            with torch.no_grad():
                kernel_mean, kernel_sigma = model(source_HQ, source_LQ)
                # Sample a kernel around the predicted mean.
                kernel = kernel_mean + kernel_sigma * torch.randn_like(
                    kernel_mean)
                fake_source_LQ = model.adaptKernel(source_HQ, kernel)
                target_LQ = model.adaptKernel(target_HQ, kernel)

            LQ_img = util.tensor2img(source_LQ)
            fake_LQ_img = util.tensor2img(fake_source_LQ)
            target_LQ_img = util.tensor2img(target_LQ)
            target_HQ_img = util.tensor2img(target_HQ)

            # Outputs are grouped into folders of 100 images each.
            target_HQ_dst = osp.join(
                save_path, "sharp/{:03d}/{:08d}.png".format(i // 100, i % 100))
            target_LQ_dst = osp.join(
                save_path, "blur/{:03d}/{:08d}.png".format(i // 100, i % 100))

            os.makedirs(osp.dirname(target_HQ_dst), exist_ok=True)
            os.makedirs(osp.dirname(target_LQ_dst), exist_ok=True)

            cv2.imwrite(target_HQ_dst, target_HQ_img)
            cv2.imwrite(target_LQ_dst, target_LQ_img)

            psnr = util.calculate_psnr(LQ_img, fake_LQ_img)

            logger.info(
                "Reconstruction PSNR of image #{:03d}/{:03d}: {:.2f}db".format(
                    i, num_images, psnr))
            psnr_avg += psnr
    finally:
        # Release the lmdb handles even if processing fails part-way.
        for env in (source_HQ_env, source_LQ_env, target_HQ_env):
            env.close()

    if num_images > 0:
        logger.info("Average reconstruction PSNR: {:.2f}db".format(
            psnr_avg / num_images))
    else:
        # Nothing was processed; avoid dividing by zero.
        logger.info("No images were generated; average PSNR is undefined.")