Example #1
    def __getitem__(self, idx):
        # Randomly choose a sequence in a video
        # key = self.keys[idx]
        # str_idx, seq_idx = self.find_set(key)

        key, str_idx, seq_idx = self.find_set(idx)
        set_hr = self.dict_hr[key]

        if self.stride[str_idx] > 0:
            seq_end = seq_idx + self.sample_length[str_idx]
        else:
            seq_end = seq_idx - self.sample_length[str_idx]

        if seq_end >= 0:
            name_hr = set_hr[seq_idx:seq_end:self.stride[str_idx]]
        else:
            name_hr = set_hr[seq_idx::self.stride[str_idx]]

        if self.img_type == 'img':
            fn_read = imageio.imread
        elif self.img_type == 'bin':
            fn_read = np.load
        else:
            raise ValueError('Wrong img type: {}'.format(self.img_type))

        name = [path.join(key, path.basename(f)) for f in name_hr]
        name = [path.splitext(f)[0] for f in name]
        seq_hr = [fn_read(f) for f in name_hr]
        seq_hr = np.stack(seq_hr, axis=-1)
        seq_hr = preprocessing.np2tensor(seq_hr)

        if self.opt['network_E']['which_model_E'] == 'gaussargs':
            kwargs = preprocessing.set_kernel_params(base='bicubic')
        else:
            kwargs = preprocessing.set_kernel_params()

        kwargs_l = kwargs['sigma']
        kwargs_l.append(kwargs['theta'])
        kwargs_l = torch.Tensor(kwargs_l)
        basis_label = int(5 * kwargs['type'])

        kernel_gen = rkg.Degradation(self.opt_data['train']['kernel_size'],
                                     self.scale, **kwargs)
        seq_lr = kernel_gen.apply(seq_hr)

        if self.train:
            seq_hr, seq_lr = preprocessing.crop(
                seq_hr,
                seq_lr,
                patch_size=self.opt_data['train']['patch_size'])
            # seq_hr, seq_lr = preprocessing.augment(seq_hr, seq_lr)

        return {
            'LQs': seq_lr,
            'GT': seq_hr,
            'Kernel': kernel_gen.kernel,
            'Kernel_type': basis_label,
            'Kernel_args': kwargs_l,
            'name': name
        }
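The dictionary returned above is what a standard PyTorch DataLoader collates into batches. Below is a self-contained sketch of that contract using a stand-in dataset with random tensors; the class name, shapes, and sizes are illustrative assumptions, not part of the example itself.

import torch
from torch.utils.data import Dataset, DataLoader


class _ToyClipDataset(Dataset):
    """Stand-in dataset mimicking the keys returned by __getitem__ above."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        t, c, h, w, scale, ksize = 5, 3, 64, 64, 4, 21
        return {
            'LQs': torch.rand(t, c, h // scale, w // scale),
            'GT': torch.rand(t, c, h, w),
            'Kernel': torch.rand(ksize, ksize),
            'Kernel_type': 0,
            'Kernel_args': torch.rand(3),  # e.g. [sigma_x, sigma_y, theta]
            'name': ['seq/{:08d}'.format(idx)] * t,
        }


loader = DataLoader(_ToyClipDataset(), batch_size=2, shuffle=True)
batch = next(iter(loader))
print(batch['LQs'].shape, batch['GT'].shape)  # [2, 5, 3, 16, 16] [2, 5, 3, 64, 64]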
Example #2
    def __getitem__(self, idx):
        name_hr = path.join(self.data_root, 'sequences', self.img_list[idx])
        names_hr = sorted(glob.glob(path.join(name_hr, '*.png')))
        names = [path.splitext(path.basename(f))[0] for f in names_hr]
        names = [path.join(self.img_list[idx], name) for name in names]

        seq_hr = [imageio.imread(f) for f in names_hr]
        seq_hr = np.stack(seq_hr, axis=-1)
        start_frame = (7 - self.nframes) // 2
        seq_hr = seq_hr[..., start_frame:start_frame + self.nframes]

        if self.train:
            seq_hr = preprocessing.crop_border(seq_hr, border=[4, 4])
            seq_hr = preprocessing.crop(
                seq_hr, patch_size=self.opt['datasets']['train']['patch_size'])
            seq_hr = preprocessing.augment(seq_hr, rot=False)

        seq_hr = preprocessing.np2tensor(seq_hr)

        if self.opt['network_E']['which_model_E'] == 'gaussargs':
            kwargs = preprocessing.set_kernel_params(base='bicubic')
        else:
            kwargs = preprocessing.set_kernel_params()

        kernel_gen = rkg.Degradation(
            self.opt['datasets']['train']['kernel_size'], self.scale, **kwargs)
        seq_lr = kernel_gen.apply(seq_hr)

        return {
            'LQs': seq_lr,
            'GT': seq_hr,
            'Kernel': kernel_gen.kernel,
            'Kernel_args': kwargs,
            'name': names
        }
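preprocessing.np2tensor is not shown in these examples. As a rough illustration of the conversion implied above, turning a (H, W, C, T) uint8 stack into a (T, C, H, W) float tensor in [0, 1], here is a minimal standalone sketch; treat it as an assumption about the helper, not its actual implementation.

import numpy as np
import torch


def np2tensor(seq: np.ndarray) -> torch.Tensor:
    """(H, W, C, T) uint8 array -> (T, C, H, W) float tensor in [0, 1]."""
    seq = np.ascontiguousarray(seq.transpose(3, 2, 0, 1))
    return torch.from_numpy(seq).float() / 255.0


frames = np.random.randint(0, 256, size=(64, 64, 3, 7), dtype=np.uint8)
print(np2tensor(frames).shape)  # torch.Size([7, 3, 64, 64])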
Example #3
    def __getitem__(self, idx):
        name_hr = path.join(self.data_root, 'sequences', self.img_list[idx])
        names_hr = sorted(glob.glob(path.join(name_hr, '*.png')))
        names = [path.splitext(path.basename(f))[0] for f in names_hr]
        names = [path.join(self.img_list[idx], name) for name in names]

        seq_hr = [imageio.imread(f) for f in names_hr]
        seq_hr = np.stack(seq_hr, axis=-1)
        start_frame = random.randint(0, 7 - self.nframes)
        seq_hr = seq_hr[..., start_frame:start_frame + self.nframes]
        seq_hr = preprocessing.np2tensor(seq_hr)

        seq_hr = preprocessing.crop_border(seq_hr, border=[4, 4])
        # To save time, crop seq_hr first and then downsample it to seq_lr
        # if self.train:
        # since patch_size is defined at the seq_LR scale, it has to be twice as large here
        # seq_hr = preprocessing.common_crop(img=seq_hr, patch_size=self.opt['datasets']['train']['patch_size']*2)

        kwargs = preprocessing.set_kernel_params()
        kwargs_l = kwargs['sigma']
        kwargs_l.append(kwargs['theta'])
        kwargs_l = torch.Tensor(kwargs_l)

        base_type = random.random()

        # include random noise for each frame
        kernel_gen = rkg.Degradation(
            self.opt['datasets']['train']['kernel_size'], self.scale,
            base_type, **kwargs)
        seq_lr = []
        seq_superlr = []
        for i in range(seq_hr.shape[0]):
            # kernel_gen.gen_new_noise()
            seq_lr_slice = kernel_gen.apply(seq_hr[i])
            seq_lr.append(seq_lr_slice)
            seq_superlr.append(kernel_gen.apply(seq_lr_slice))

        seq_lr = torch.stack(seq_lr, dim=0)
        seq_superlr = torch.stack(seq_superlr, dim=0)

        if self.train:
            seq_hr, seq_lr, seq_superlr = preprocessing.crop(
                seq_hr,
                seq_lr,
                seq_superlr,
                patch_size=self.opt['datasets']['train']['patch_size'])
            seq_hr, seq_lr, seq_superlr = preprocessing.augment(
                seq_hr, seq_lr, seq_superlr)

        return {
            'SuperLQs': seq_superlr,
            'LQs': seq_lr,
            'GT': seq_hr,
            'Kernel': kernel_gen.kernel,
            'Kernel_args': kwargs_l,
            'name': names
        }
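rkg.Degradation and preprocessing.set_kernel_params are external to these snippets. For readers unfamiliar with the degradation model, the following standalone sketch shows one common way to build an anisotropic Gaussian blur kernel from sigma = (sigma_x, sigma_y) and a rotation angle theta; it is an assumption about the kind of kernel being parameterized, not the repository's implementation.

import numpy as np


def gaussian_kernel(kernel_size=21, sigma=(1.3, 2.6), theta=0.0):
    """Return a normalized 2-D anisotropic Gaussian kernel."""
    sig_x, sig_y = sigma
    # Rotated covariance matrix.
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta), np.cos(theta)]])
    cov = rot @ np.diag([sig_x ** 2, sig_y ** 2]) @ rot.T
    inv_cov = np.linalg.inv(cov)

    # Evaluate the Gaussian on an integer grid centred on the kernel middle.
    ax = np.arange(kernel_size) - (kernel_size - 1) / 2
    xx, yy = np.meshgrid(ax, ax)
    coords = np.stack([xx, yy], axis=-1)                      # (k, k, 2)
    expo = np.einsum('...i,ij,...j->...', coords, inv_cov, coords)
    kernel = np.exp(-0.5 * expo)
    return kernel / kernel.sum()


k = gaussian_kernel(21, sigma=(1.3, 2.6), theta=np.pi / 4)
print(k.shape, k.sum())  # (21, 21) ~1.0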
Example #4
    def __getitem__(self, idx):
        name_hr = path.join(self.data_root, 'sequences', self.img_list[idx])
        names_hr = sorted(glob.glob(path.join(name_hr, '*.png')))
        names = [path.splitext(path.basename(f))[0] for f in names_hr]
        names = [path.join(self.img_list[idx], name) for name in names]

        seq_hr = [imageio.imread(f) for f in names_hr]
        seq_hr = np.stack(seq_hr, axis=-1)
        start_frame = random.randint(0, 7 - self.nframes)
        seq_hr = seq_hr[..., start_frame:start_frame + self.nframes]
        seq_hr = preprocessing.np2tensor(seq_hr)

        # seq_hr = preprocessing.crop_border(seq_hr, border=[4, 4])
        # To save time, crop seq_hr first and then downsample it to seq_lr
        # if self.train:
        # since patch_size is defined at the seq_LR scale, it has to be twice as large here
        # seq_hr = preprocessing.common_crop(img=seq_hr, patch_size=self.opt['datasets']['train']['patch_size']*2)

        # include random noise for each frame
        '''
        kernel_set = []
        for i in range(5):
            kwargs = preprocessing.set_kernel_params()
            kernel_set.append(rkg.Degradation(self.opt['datasets']['train']['kernel_size'], self.scale, **kwargs).kernel)
        kernel_set = np.stack(kernel_set, axis=0)
        
        kernel_temp = rkg.Degradation(self.opt['datasets']['train']['kernel_size'], self.scale)
        kernel_temp.set_kernel_directly(kernel_set)

        seq_lr = kernel_temp.apply(seq_hr)
        seq_lr = seq_lr.mul(255).clamp(0, 255).round().div(255)
        kernel_temp.set_kernel_directly(kernel_set[2])
        seq_superlr = kernel_temp.apply(seq_lr)
        '''
        kwargs = preprocessing.set_kernel_params()
        kernel_gen = rkg.Degradation(
            self.opt['datasets']['train']['kernel_size'], self.scale, **kwargs)

        seq_lr = kernel_gen.apply(seq_hr)
        seq_lr = seq_lr.mul(255).clamp(0, 255).round().div(255)
        seq_superlr = kernel_gen.apply(seq_lr)

        if self.train:
            # seq_hr, seq_lr, seq_superlr = preprocessing.crop(seq_hr, seq_lr, seq_superlr, patch_size=self.opt['datasets']['train']['patch_size'])
            seq_hr, seq_lr, seq_superlr = preprocessing.augment(
                seq_hr, seq_lr, seq_superlr)

        return {
            'SuperLQs': seq_superlr,
            'LQs': seq_lr,
            'GT': seq_hr,
            # 'Kernel': kernel_gen.kernel
        }
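The mul(255).clamp(0, 255).round().div(255) step above mimics writing the LR frames as 8-bit images and reading them back, so the on-the-fly degradation matches what a PNG-based pipeline would see. A minimal helper capturing that intent (an interpretation, not code from the repository):

import torch


def quantize(x: torch.Tensor) -> torch.Tensor:
    """Round a [0, 1] float tensor to the nearest 8-bit level."""
    return x.mul(255).clamp(0, 255).round().div(255)


x = torch.rand(3, 32, 32)
assert torch.allclose(quantize(x), quantize(quantize(x)))        # idempotent
print((quantize(x) - x).abs().max().item() <= 0.5 / 255 + 1e-6)  # True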
Example #5
    def __getitem__(self, index):
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]

        select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'],
                                           padding=self.opt['padding'])
        imgs_GT = self.imgs_GT[folder].index_select(0, torch.LongTensor(select_idx))

        # Fix the kernel: type=1, sigma=(1.3, 1.3), no noise
        '''
        kwargs = preprocessing.set_kernel_params()
        kwargs_l = kwargs['sigma']
        kwargs_l.append(kwargs['theta'])
        kwargs_l = torch.Tensor(kwargs_l)
        '''
        kernel_gen = rkg.Degradation(self.kernel_size, self.scale, type=1, sigma=[1.3, 1.3], theta=0)
        imgs_LR = kernel_gen.apply(imgs_GT)
        imgs_SuperLR = kernel_gen.apply(imgs_LR)
        #imgs_LR = []
        #imgs_SuperLR = []
        '''
        for i in range(imgs_GT.shape[0]):
            kernel_gen.gen_new_noise()
            imgs_LR_slice = kernel_gen.apply(imgs_GT[i])
            imgs_LR.append(imgs_LR_slice)
            imgs_SuperLR.append(kernel_gen.apply(imgs_LR_slice))
        imgs_LR = torch.stack(imgs_LR, dim=0)
        imgs_SuperLR = torch.stack(imgs_SuperLR, dim=0)
        '''
        
        return {
            'SuperLQs': imgs_SuperLR,
            'LQs': imgs_LR,
            'GT': imgs_GT,
            'Kernel': kernel_gen.kernel,
            'folder': folder,
            'idx': self.data_info['idx'][index],
            'border': border
        }
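util.index_generation is used above to pick the temporal neighbours of a frame. The sketch below shows one plausible behaviour for padding='new_info' (out-of-range positions are replaced by extra frames from the other side of the window); the exact rule in the repository's util module may differ, so treat this as an assumption.

def index_generation(crt_i, max_n, N, padding='new_info'):
    """Return N frame indices centred on crt_i for a clip of length max_n."""
    max_i = max_n - 1
    n_pad = N // 2
    out = []
    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
        if i < 0:
            add_i = crt_i + n_pad - i if padding == 'new_info' else 0
        elif i > max_i:
            add_i = crt_i - n_pad - (i - max_i) if padding == 'new_info' else max_i
        else:
            add_i = i
        out.append(add_i)
    return out


print(index_generation(0, 30, 7))   # [6, 5, 4, 0, 1, 2, 3]
print(index_generation(29, 30, 7))  # [26, 27, 28, 29, 25, 24, 23]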
Example #6
def main():
    #################
    # configurations
    #################
    device = torch.device('cuda')
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    prog = argparse.ArgumentParser()
    prog.add_argument('--dataset_mode',
                      '-m',
                      type=str,
                      default='Vid4+REDS',
                      help='data_mode')
    prog.add_argument('--degradation_mode',
                      '-d',
                      type=str,
                      default='impulse',
                      choices=('impulse', 'bicubic', 'preset'),
                      help='degradation mode')
    prog.add_argument('--sigma_x',
                      '-sx',
                      type=float,
                      default=1,
                      help='sigma_x')
    prog.add_argument('--sigma_y',
                      '-sy',
                      type=float,
                      default=0,
                      help='sigma_y')
    prog.add_argument('--theta', '-t', type=float, default=0, help='theta')
    prog.add_argument('--scale',
                      '-sc',
                      type=int,
                      default=2,
                      choices=(2, 4),
                      help='scale factor')

    args = prog.parse_args()

    data_modes = args.dataset_mode
    degradation_mode = args.degradation_mode  # impulse | bicubic
    sig_x, sig_y, the = args.sigma_x, args.sigma_y, args.theta * math.pi / 180
    if sig_y == 0:
        sig_y = sig_x

    scale = args.scale
    kernel_size = 21

    N_frames = 7
    padding = 'new_info'

    data_mode_l = data_modes.split('+')

    for i in range(len(data_mode_l)):
        data_mode = data_mode_l[i]
        #### dataset
        if data_mode == 'Vid4':
            kernel_folder = '../pretrained_models/Mixed/Vid4.npy'
            dataset_folder = '../dataset/Vid4'
        elif data_mode == 'REDS':
            kernel_folder = '../pretrained_models/Mixed/REDS.npy'
            dataset_folder = '../dataset/REDS/train'
        elif data_mode == 'Vimeo':
            if degradation_mode == 'preset':
                raise NotImplementedError(
                    'We do not support preset mode in Vimeo dataset')
            dataset_folder = '../dataset/vimeo_septuplet'
        else:
            raise NotImplementedError()

        save_folder_name = 'preset' if degradation_mode == 'preset' else '{}_{:.1f}_{:.1f}_{:.1f}'.format(
            degradation_mode, sig_x, sig_y, args.theta)
        save_folder = osp.join(dataset_folder, 'LR_' + save_folder_name, 'X2')
        if not osp.exists(save_folder):
            os.makedirs(save_folder)

        save_folder2 = osp.join(dataset_folder, 'LR_' + save_folder_name, 'X4')
        if not osp.exists(save_folder2):
            os.makedirs(save_folder2)

        if scale == 4:
            save_folder3 = osp.join(dataset_folder, 'LR_' + save_folder_name,
                                    'X16')
            if not osp.exists(save_folder3):
                os.makedirs(save_folder3)

        if data_mode == 'Vimeo':
            GT_dataset_folder = osp.join(dataset_folder, 'sequences')
            meta = osp.join(dataset_folder, 'sep_testlist.txt')
            with open(meta, 'r') as f:
                seq_list = sorted(f.read().splitlines())
            subfolder_GT_l = [
                osp.join(GT_dataset_folder, seq_ind) for seq_ind in seq_list
            ]

        else:
            GT_dataset_folder = osp.join(dataset_folder, 'HR')
            subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder,
                                                       '*')))
            if data_mode == 'REDS':
                subfolder_GT_l = [
                    k for k in subfolder_GT_l
                    if k.find('000') >= 0 or k.find('011') >= 0
                    or k.find('015') >= 0 or k.find('020') >= 0
                ]

        sig_x, sig_y, the = float(sig_x), float(sig_y), float(the)

        for subfolder_GT in subfolder_GT_l:
            print(subfolder_GT)
            gen_kwargs = set_kernel_params(sigma_x=sig_x,
                                           sigma_y=sig_y,
                                           theta=the)
            if degradation_mode == 'impulse' or degradation_mode == 'preset':
                kernel_gen = rkg.Degradation(kernel_size, 2, **gen_kwargs)
                if degradation_mode == 'preset':
                    kernel_preset = np.load(kernel_folder)
            else:
                kernel_gen = oldkg.Degradation(kernel_size,
                                               2,
                                               type=0.7,
                                               **gen_kwargs)

            if data_mode == 'Vimeo':
                sub1 = osp.basename(osp.dirname(subfolder_GT))
                sub2 = osp.basename(subfolder_GT)
                subfolder_name = osp.join(sub1, sub2)
            else:
                subfolder_name = osp.basename(subfolder_GT)

            save_subfolder = osp.join(save_folder, subfolder_name)
            if not osp.exists(save_subfolder):
                os.makedirs(save_subfolder)

            save_subfolder2 = osp.join(save_folder2, subfolder_name)
            if not osp.exists(save_subfolder2):
                os.makedirs(save_subfolder2)

            if scale == 4:
                save_subfolder3 = osp.join(save_folder3, subfolder_name)
                if not osp.exists(save_subfolder3):
                    os.makedirs(save_subfolder3)

            img_GT_path_l = sorted(glob.glob(osp.join(subfolder_GT, '*')))
            seq_length = len(img_GT_path_l)

            imgs_GT = data_util.read_img_seq(subfolder_GT)  # T C H W

            if degradation_mode == 'preset':
                for index in range(seq_length):
                    save_subsubfolder = osp.join(
                        save_subfolder,
                        osp.splitext(osp.basename(img_GT_path_l[index]))[0])
                    save_subsubfolder2 = osp.join(
                        save_subfolder2,
                        osp.splitext(osp.basename(img_GT_path_l[index]))[0])
                    if not osp.exists(save_subsubfolder):
                        os.mkdir(save_subsubfolder)
                    if not osp.exists(save_subsubfolder2):
                        os.mkdir(save_subsubfolder2)
                    if scale == 4:
                        # Mirror the X2/X4 sub-folders for the X16 outputs.
                        save_subsubfolder3 = osp.join(
                            save_subfolder3,
                            osp.splitext(osp.basename(img_GT_path_l[index]))[0])
                        if not osp.exists(save_subsubfolder3):
                            os.mkdir(save_subsubfolder3)

                    kernel_gen.set_kernel_directly(kernel_preset[index])
                    imgs_HR = imgs_GT[data_util.index_generation(
                        index, seq_length, N_frames, padding)]
                    imgs_LR = kernel_gen.apply(imgs_HR)
                    imgs_LR = imgs_LR.mul(255).clamp(0, 255).round().div(255)
                    imgs_LR_np = imgs_LR.permute(0, 2, 3, 1).cpu().numpy()
                    imgs_LR_np = (imgs_LR_np * 255).astype('uint8')
                    for i, img_LR in enumerate(imgs_LR_np):
                        imageio.imwrite(
                            osp.join(save_subsubfolder, 'img{}.png'.format(i)),
                            img_LR)

                    imgs_SuperLR = kernel_gen.apply(imgs_LR)
                    imgs_SuperLR = imgs_SuperLR.mul(255).clamp(
                        0, 255).round().div(255)
                    imgs_SuperLR_np = imgs_SuperLR.permute(0, 2, 3,
                                                           1).cpu().numpy()
                    imgs_SuperLR_np = (imgs_SuperLR_np * 255).astype('uint8')
                    for i, img_SuperLR in enumerate(imgs_SuperLR_np):
                        imageio.imwrite(
                            osp.join(save_subsubfolder2,
                                     'img{}.png'.format(i)), img_SuperLR)

                    if scale == 4:
                        # Degrade the X4 frames twice more to reach X16.
                        imgs_SuperLR = kernel_gen.apply(imgs_SuperLR)
                        imgs_SuperLR = imgs_SuperLR.mul(255).clamp(
                            0, 255).round().div(255)
                        imgs_SuperLR = kernel_gen.apply(imgs_SuperLR)
                        imgs_SuperLR = imgs_SuperLR.mul(255).clamp(
                            0, 255).round().div(255)
                        imgs_SuperLR_np = imgs_SuperLR.permute(
                            0, 2, 3, 1).cpu().numpy()
                        imgs_SuperLR_np = (imgs_SuperLR_np *
                                           255).astype('uint8')
                        for i, img_SuperLR in enumerate(imgs_SuperLR_np):
                            imageio.imwrite(
                                osp.join(save_subsubfolder3,
                                         'img{}.png'.format(i)), img_SuperLR)

            else:
                count = 0
                imgs_GT_l = imgs_GT.split(32)
                for img_batch in imgs_GT_l:
                    img_lr_batch = kernel_gen.apply(img_batch)
                    img_lr_batch = img_lr_batch.permute(0, 2, 3,
                                                        1).cpu().numpy()
                    img_lr_batch = (img_lr_batch.clip(0, 1) * 255).round()
                    img_lr_batch = img_lr_batch.astype('uint8')
                    count_temp = count
                    for img_lr in img_lr_batch:
                        filename = osp.basename(img_GT_path_l[count])
                        imageio.imwrite(osp.join(save_subfolder, filename),
                                        img_lr)
                        count += 1

                    img_lr_batch = img_lr_batch.astype('float32') / 255
                    img_lr_batch = torch.from_numpy(img_lr_batch).permute(
                        0, 3, 1, 2)

                    img_superlr_batch = kernel_gen.apply(img_lr_batch)
                    img_superlr_batch = img_superlr_batch.permute(
                        0, 2, 3, 1).cpu().numpy()
                    img_superlr_batch = (img_superlr_batch.clip(0, 1) *
                                         255).round()
                    img_superlr_batch = img_superlr_batch.astype('uint8')
                    count = count_temp
                    for img_superlr in img_superlr_batch:
                        filename = osp.basename(img_GT_path_l[count])
                        imageio.imwrite(osp.join(save_subfolder2, filename),
                                        img_superlr)
                        count += 1
                    if scale == 4:
                        img_superlr_batch = img_superlr_batch.astype(
                            'float32') / 255
                        img_superlr_batch = torch.from_numpy(
                            img_superlr_batch).permute(0, 3, 1, 2)
                        img_superlr_batch = kernel_gen.apply(img_superlr_batch)
                        img_superlr_batch = img_superlr_batch.permute(
                            0, 2, 3, 1).cpu().numpy()
                        img_superlr_batch = (img_superlr_batch.clip(0, 1) *
                                             255).round()
                        img_superlr_batch = img_superlr_batch.astype(
                            'float32') / 255
                        img_superlr_batch = torch.from_numpy(
                            img_superlr_batch).permute(0, 3, 1, 2)
                        img_superlr_batch = kernel_gen.apply(img_superlr_batch)
                        img_superlr_batch = img_superlr_batch.permute(
                            0, 2, 3, 1).cpu().numpy()
                        img_superlr_batch = (img_superlr_batch.clip(0, 1) *
                                             255).round()

                        img_superlr_batch = img_superlr_batch.astype('uint8')
                        count = count_temp
                        for img_superlr in img_superlr_batch:
                            filename = osp.basename(img_GT_path_l[count])
                            imageio.imwrite(
                                osp.join(save_subfolder3, filename),
                                img_superlr)
                            count += 1
Example #7
    def __getitem__(self, idx):
        # Randomly choose a sequence in a video
        # key = self.keys[idx]
        # str_idx, seq_idx = self.find_set(key)

        key, str_idx, seq_idx = self.find_set(idx)
        set_hr = self.dict_hr[key]

        if self.stride[str_idx] > 0:
            seq_end = seq_idx + self.sample_length[str_idx]
        else:
            seq_end = seq_idx - self.sample_length[str_idx]

        if seq_end >= 0:
            name_hr = set_hr[seq_idx:seq_end:self.stride[str_idx]]
        else:
            name_hr = set_hr[seq_idx::self.stride[str_idx]]

        if self.img_type == 'img':
            fn_read = imageio.imread
        elif self.img_type == 'bin':
            fn_read = np.load
        else:
            raise ValueError('Wrong img type: {}'.format(self.img_type))

        name = [path.join(key, path.basename(f)) for f in name_hr]
        seq_hr = [fn_read(f) for f in name_hr]
        seq_hr = np.stack(seq_hr, axis=-1)
        # if self.train:
        # since patch_size is defined at the seq_LR scale, it has to be twice as large here
        seq_hr = preprocessing.np_common_crop(
            seq_hr, patch_size=self.opt['datasets']['train']['patch_size'] * 2)
        seq_hr = preprocessing.np2tensor(seq_hr)

        # include random noise for each frame
        '''
        kernel_set = []
        for i in range(5):
            kwargs = preprocessing.set_kernel_params()
            kernel_set.append(rkg.Degradation(self.opt['datasets']['train']['kernel_size'], self.scale, **kwargs).kernel)
        kernel_set = np.stack(kernel_set, axis=0)
        
        kernel_temp = rkg.Degradation(self.opt['datasets']['train']['kernel_size'], self.scale)
        kernel_temp.set_kernel_directly(kernel_set)

        seq_lr = kernel_temp.apply(seq_hr)
        seq_lr = seq_lr.mul(255).clamp(0, 255).round().div(255)
        kernel_temp.set_kernel_directly(kernel_set[2])
        seq_superlr = kernel_temp.apply(seq_lr)
        '''
        kwargs = preprocessing.set_kernel_params()
        kernel_gen = rkg.Degradation(self.opt_data['train']['kernel_size'],
                                     self.scale, **kwargs)

        seq_lr = kernel_gen.apply(seq_hr)
        seq_lr = seq_lr.mul(255).clamp(0, 255).round().div(255)
        seq_superlr = kernel_gen.apply(seq_lr)
        if self.train:
            # seq_hr, seq_lr, seq_superlr = preprocessing.crop(seq_hr, seq_lr, seq_superlr, patch_size=self.opt['datasets']['train']['patch_size'])
            seq_hr, seq_lr, seq_superlr = preprocessing.augment(
                seq_hr, seq_lr, seq_superlr)

        return {
            'SuperLQs': seq_superlr,
            'LQs': seq_lr,
            'GT': seq_hr,
            #'Kernel': kernel_gen.kernel,
        }
Example #8
def main():
    #################
    # configurations
    #################
    device = torch.device('cuda')
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    prog = argparse.ArgumentParser()
    prog.add_argument('--dataset_mode',
                      '-m',
                      type=str,
                      default='Vid4+REDS',
                      help='data_mode')
    prog.add_argument('--degradation_mode',
                      '-d',
                      type=str,
                      default='impulse',
                      choices=('impulse', 'bicubic', 'preset'),
                      help='degradation mode')
    prog.add_argument('--sigma_x',
                      '-sx',
                      type=float,
                      default=1,
                      help='sigma_x')
    prog.add_argument('--sigma_y',
                      '-sy',
                      type=float,
                      default=0,
                      help='sigma_y')
    prog.add_argument('--theta', '-t', type=float, default=0, help='theta')

    args = prog.parse_args()

    data_modes = args.dataset_mode
    degradation_mode = args.degradation_mode  # impulse | bicubic
    sig_x, sig_y, the = args.sigma_x, args.sigma_y, args.theta * math.pi / 180
    if sig_y == 0:
        sig_y = sig_x

    scale = 2
    kernel_size = 21

    N_frames = 7
    padding = 'new_info'
    ############################################################################

    # model = EDVR_arch.EDVR(n_feats, N_in, 8, 5, back_RBs, predeblur=False, HR_in=False, scale=scale)
    # est_model = LR_arch.DirectKernelEstimator_CMS(nf=n_feats)
    data_mode_l = data_modes.split('+')

    for i in range(len(data_mode_l)):
        data_mode = data_mode_l[i]
        #### dataset
        if data_mode == 'Vid4':
            kernel_folder = '../experiments/pretrained_models/Vid4Gauss.npy'
            dataset_folder = '../dataset/Vid4'
        elif data_mode == 'MM522':
            kernel_folder = '../experiments/pretrained_models/MM522Gauss.npy'
            dataset_folder = '../dataset/MM522val'
        else:
            kernel_folder = '../experiments/pretrained_models/REDSGauss.npy'
            dataset_folder = '../dataset/REDS/train'

        GT_dataset_folder = osp.join(dataset_folder, 'HR')
        save_folder_name = 'preset' if degradation_mode == 'preset' else '{}_{:.1f}_{:.1f}_{:.1f}'.format(
            degradation_mode, sig_x, sig_y, args.theta)
        save_folder = osp.join(dataset_folder, 'LR_' + save_folder_name,
                               'X' + str(scale))
        if not osp.exists(save_folder):
            os.makedirs(save_folder)

        save_folder2 = osp.join(dataset_folder, 'LR_' + save_folder_name,
                                'X' + str(scale * scale))  #*scale*scale))
        if not osp.exists(save_folder2):
            os.makedirs(save_folder2)
        #### log info
        # logger.info('Data: {} - {}'.format(data_mode, lr_set_method))
        # logger.info('Padding mode: {}'.format(padding))
        # logger.info('Model path: {}'.format(model_path))
        # logger.info('Save images: {}'.format(save_imgs))
        # logger.info('Flip test: {}'.format(flip_test))

        #### set up the models
        # model.load_state_dict(torch.load(model_path), strict=True)
        # model.eval()
        # model = model.to(device)

        avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], []
        subfolder_name_l = []
        # subfolder_l = sorted(glob.glob(osp.join(bicubic_dataset_folder, '*')))
        subfolder_GT_l = sorted(glob.glob(osp.join(GT_dataset_folder, '*')))
        if data_mode == 'REDS':
            subfolder_GT_l = [
                k for k in subfolder_GT_l
                if k.find('000') >= 0 or k.find('011') >= 0
                or k.find('015') >= 0 or k.find('020') >= 0
            ]
        elif data_mode == 'MM522':
            subfolder_GT_l = [
                k for k in subfolder_GT_l
                if k.find('001') >= 0 or k.find('005') >= 0
                or k.find('008') >= 0 or k.find('009') >= 0
            ]
        # for each subfolder
        # for subfolder, subfolder_GT in zip(subfolder_l, subfolder_GT_l):

        sig_x, sig_y, the = float(sig_x), float(sig_y), float(the)

        for subfolder_GT in subfolder_GT_l:
            print(subfolder_GT)
            gen_kwargs = set_kernel_params(sigma_x=sig_x,
                                           sigma_y=sig_y,
                                           theta=the)
            if degradation_mode == 'impulse' or degradation_mode == 'preset':
                kernel_gen = rkg.Degradation(kernel_size, scale, **gen_kwargs)
                if degradation_mode == 'preset':
                    kernel_preset = np.load(kernel_folder)
            else:
                kernel_gen = oldkg.Degradation(kernel_size,
                                               scale,
                                               type=0.7,
                                               **gen_kwargs)

            subfolder_name = osp.basename(subfolder_GT)
            subfolder_name_l.append(subfolder_name)
            save_subfolder = osp.join(save_folder, subfolder_name)
            if not osp.exists(save_subfolder):
                os.mkdir(save_subfolder)

            save_subfolder2 = osp.join(save_folder2, subfolder_name)
            if not osp.exists(save_subfolder2):
                os.mkdir(save_subfolder2)

            img_GT_path_l = sorted(glob.glob(osp.join(subfolder_GT, '*')))
            seq_length = len(img_GT_path_l)

            # max_idx = len(img_GT_path_l)
            # if save_imgs:
            #    util.mkdirs(save_subfolder)
            imgs_GT = data_util.read_img_seq(subfolder_GT)  # T C H W

            if degradation_mode == 'preset':
                for index in range(seq_length):
                    save_subsubfolder = osp.join(
                        save_subfolder,
                        osp.splitext(osp.basename(img_GT_path_l[index]))[0])
                    save_subsubfolder2 = osp.join(
                        save_subfolder2,
                        osp.splitext(osp.basename(img_GT_path_l[index]))[0])
                    if not osp.exists(save_subsubfolder):
                        os.mkdir(save_subsubfolder)
                    if not osp.exists(save_subsubfolder2):
                        os.mkdir(save_subsubfolder2)
                    kernel_gen.set_kernel_directly(kernel_preset[index])
                    imgs_HR = imgs_GT[data_util.index_generation(
                        index, seq_length, N_frames, padding)]
                    imgs_LR = kernel_gen.apply(imgs_HR)
                    imgs_LR = imgs_LR.mul(255).clamp(0, 255).round().div(255)
                    imgs_LR_np = imgs_LR.permute(0, 2, 3, 1).cpu().numpy()
                    imgs_LR_np = (imgs_LR_np * 255).astype('uint8')
                    for i, img_LR in enumerate(imgs_LR_np):
                        imageio.imwrite(
                            osp.join(save_subsubfolder, 'img{}.png'.format(i)),
                            img_LR)

                    imgs_SuperLR = kernel_gen.apply(imgs_LR)
                    imgs_SuperLR = imgs_SuperLR.mul(255).clamp(
                        0, 255).round().div(255)
                    imgs_SuperLR_np = imgs_SuperLR.permute(0, 2, 3,
                                                           1).cpu().numpy()
                    imgs_SuperLR_np = (imgs_SuperLR_np * 255).astype('uint8')
                    for i, img_SuperLR in enumerate(imgs_SuperLR_np):
                        imageio.imwrite(
                            osp.join(save_subsubfolder2,
                                     'img{}.png'.format(i)), img_SuperLR)

            else:
                count = 0
                imgs_GT_l = imgs_GT.split(32)
                for img_batch in imgs_GT_l:
                    if degradation_mode == 'preset':
                        kernel_gen.set_kernel_directly(kernel_preset[count])
                    img_lr_batch = kernel_gen.apply(img_batch)
                    img_lr_batch = img_lr_batch.permute(0, 2, 3,
                                                        1).cpu().numpy()
                    img_lr_batch = (img_lr_batch.clip(0, 1) * 255).round()
                    img_lr_batch = img_lr_batch.astype('uint8')
                    count_temp = count
                    for img_lr in img_lr_batch:
                        filename = osp.basename(img_GT_path_l[count])
                        imageio.imwrite(osp.join(save_subfolder, filename),
                                        img_lr)
                        count += 1

                    img_lr_batch = img_lr_batch.astype('float32') / 255
                    img_lr_batch = torch.from_numpy(img_lr_batch).permute(
                        0, 3, 1, 2)

                    img_superlr_batch = kernel_gen.apply(img_lr_batch)
                    img_superlr_batch = img_superlr_batch.permute(
                        0, 2, 3, 1).cpu().numpy()
                    img_superlr_batch = (img_superlr_batch.clip(0, 1) *
                                         255).round()
                    '''
                    img_superlr_batch = img_superlr_batch.astype('float32') / 255
                    img_superlr_batch = torch.from_numpy(img_superlr_batch).permute(0, 3, 1, 2)
                    img_superlr_batch = kernel_gen.apply(img_superlr_batch)
                    img_superlr_batch = img_superlr_batch.permute(0, 2, 3, 1).cpu().numpy()
                    img_superlr_batch = (img_superlr_batch.clip(0, 1) * 255).round()
                    img_superlr_batch = img_superlr_batch.astype('float32') / 255
                    img_superlr_batch = torch.from_numpy(img_superlr_batch).permute(0, 3, 1, 2)
                    img_superlr_batch = kernel_gen.apply(img_superlr_batch)
                    img_superlr_batch = img_superlr_batch.permute(0, 2, 3, 1).cpu().numpy()
                    img_superlr_batch = (img_superlr_batch.clip(0, 1) * 255).round()
                    '''
                    img_superlr_batch = img_superlr_batch.astype('uint8')
                    count = count_temp
                    for img_superlr in img_superlr_batch:
                        filename = osp.basename(img_GT_path_l[count])
                        imageio.imwrite(osp.join(save_subfolder2, filename),
                                        img_superlr)
                        count += 1
Example #9
    def __init__(self, opt, **kwargs):
        super(VideoTestDataset, self).__init__()
        self.scale = kwargs['scale']
        self.kernel_size = kwargs['kernel_size']
        self.model_name = kwargs['model_name']
        idx = kwargs['idx'] if 'idx' in kwargs else None
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.half_N_frames = opt['N_frames'] // 2
        if idx is None:
            self.name = opt['name']
            self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
            degradation_type = opt['degradation_type']
            opt_sigma_x = opt['sigma_x']
            opt_sigma_y = opt['sigma_y']
            opt_theta = opt['theta']
        else:
            self.name = opt['name'].split('+')[idx]
            self.GT_root, self.LQ_root = opt['dataroot_GT'].split(
                '+')[idx], opt['dataroot_LQ'].split('+')[idx]
            if '+' in opt['degradation_type']:
                degradation_type = opt['degradation_type'].split('+')[idx]
                if '+' in str(opt['sigma_x']):
                    opt_sigma_x = float(opt['sigma_x'].split('+')[idx])
                    opt_sigma_y = float(opt['sigma_y'].split('+')[idx])
                    opt_theta = float(opt['theta'].split('+')[idx])

                else:
                    opt_sigma_x = opt['sigma_x']
                    opt_sigma_y = opt['sigma_y']
                    opt_theta = opt['theta']

            else:
                degradation_type = opt['degradation_type']
                opt_sigma_x = opt['sigma_x']
                opt_sigma_y = opt['sigma_y']
                opt_theta = opt['theta']

        self.data_type = self.opt['data_type']
        self.data_info = {
            'path_SLQ': [],
            'path_LQ': [],
            'path_GT': [],
            'folder': [],
            'idx': [],
            'border': []
        }
        if self.data_type == 'lmdb':
            raise ValueError('No need to use LMDB during validation/test.')
        #### Generate data info and cache data
        self.imgs_SLQ, self.imgs_LQ, self.imgs_GT = {}, {}, {}

        if opt['degradation_mode'] == 'preset':
            self.LQ_root = self.LQ_root + '_preset'
        else:
            if isinstance(opt_sigma_x, list):
                assert len(opt_sigma_x) == len(opt_sigma_y)
                assert len(opt_sigma_x) == len(opt_theta)

                LQ_root_list = []
                for i, (sigma_x, sigma_y, theta) in enumerate(
                        zip(opt_sigma_x, opt_sigma_y, opt_theta)):
                    LQ_root_list.append(self.LQ_root + '_' + degradation_type + '_' + str('{:.1f}'.format(opt_sigma_x[i]))\
                           + '_' + str('{:.1f}'.format(opt_sigma_y[i])) + '_' + str('{:.1f}'.format(opt_theta[i])))
                self.LQ_root = LQ_root_list

            else:
                self.LQ_root = self.LQ_root + '_' + degradation_type + '_' + str('{:.1f}'.format(opt_sigma_x))\
                           + '_' + str('{:.1f}'.format(opt_sigma_y)) + '_' + str('{:.1f}'.format(opt_theta))

        slr_name = '' if opt['slr_mode'] is None else '_{}'.format(
            opt['slr_mode'])

        print(self.LQ_root)

        if self.name.lower() in ['vid4', 'reds', 'mm522']:
            if self.name.lower() == 'vid4':
                img_type = 'img'
                subfolders_GT = util.glob_file_list(self.GT_root)
                if isinstance(self.LQ_root, list):
                    num_settings = len(self.LQ_root)
                    subfolders_LQ_list = [
                        util.glob_file_list(
                            osp.join(LQ_root, 'X{}'.format(self.scale)))
                        for LQ_root in self.LQ_root
                    ]
                    subfolders_SLQ_list = [
                        util.glob_file_list(
                            osp.join(
                                LQ_root,
                                'X{}{}'.format(self.scale * self.scale,
                                               slr_name)))
                        for LQ_root in self.LQ_root
                    ]

                    subfolders_LQ = []
                    subfolders_SLQ = []
                    for i in range(len(subfolders_LQ_list[0])):
                        subfolders_LQ.append([
                            subfolders_LQ_list[j][i]
                            for j in range(len(subfolders_LQ_list))
                        ])
                        subfolders_SLQ.append([
                            subfolders_SLQ_list[j][i]
                            for j in range(len(subfolders_SLQ_list))
                        ])

                else:
                    subfolders_LQ = util.glob_file_list(
                        osp.join(self.LQ_root, 'X{}'.format(self.scale)))
                    subfolders_SLQ = util.glob_file_list(
                        osp.join(
                            self.LQ_root,
                            'X{}{}'.format(self.scale * self.scale, slr_name)))

            elif self.name.lower() == 'reds':
                img_type = 'img'
                list_hr_seq = util.glob_file_list(self.GT_root)
                subfolders_GT = [
                    k for k in list_hr_seq
                    if k.find('000') >= 0 or k.find('011') >= 0
                    or k.find('015') >= 0 or k.find('020') >= 0
                ]
                if isinstance(self.LQ_root, list):
                    num_settings = len(self.LQ_root)
                    subfolders_LQ_list = []
                    subfolders_SLQ_list = []

                    for i in range(num_settings):
                        list_lr_seq = util.glob_file_list(
                            osp.join(self.LQ_root[i],
                                     'X{}'.format(self.scale)))
                        list_slr_seq = util.glob_file_list(
                            osp.join(
                                self.LQ_root[i],
                                'X{}{}'.format(self.scale * self.scale,
                                               slr_name)))
                        subfolder_LQ = [
                            k for k in list_lr_seq
                            if k.find('000') >= 0 or k.find('011') >= 0
                            or k.find('015') >= 0 or k.find('020') >= 0
                        ]
                        subfolder_SLQ = [
                            k for k in list_slr_seq
                            if k.find('000') >= 0 or k.find('011') >= 0
                            or k.find('015') >= 0 or k.find('020') >= 0
                        ]
                        subfolders_LQ_list.append(subfolder_LQ)
                        subfolders_SLQ_list.append(subfolder_SLQ)
                    subfolders_LQ = []
                    subfolders_SLQ = []
                    for i in range(len(subfolders_LQ_list[0])):
                        subfolders_LQ.append([
                            subfolders_LQ_list[j][i]
                            for j in range(len(subfolders_LQ_list))
                        ])
                        subfolders_SLQ.append([
                            subfolders_SLQ_list[j][i]
                            for j in range(len(subfolders_SLQ_list))
                        ])

                else:
                    list_lr_seq = util.glob_file_list(
                        osp.join(self.LQ_root, 'X{}'.format(self.scale)))
                    list_slr_seq = util.glob_file_list(
                        osp.join(
                            self.LQ_root,
                            'X{}{}'.format(self.scale * self.scale, slr_name)))
                    #subfolders_GT = [k for k in list_hr_seq if
                    #                   k.find('000') >= 0 or k.find('011') >= 0 or k.find('015') >= 0 or k.find('020') >= 0]
                    subfolders_LQ = [
                        k for k in list_lr_seq
                        if k.find('000') >= 0 or k.find('011') >= 0
                        or k.find('015') >= 0 or k.find('020') >= 0
                    ]
                    subfolders_SLQ = [
                        k for k in list_slr_seq
                        if k.find('000') >= 0 or k.find('011') >= 0
                        or k.find('015') >= 0 or k.find('020') >= 0
                    ]

            else:
                img_type = 'img'
                list_hr_seq = util.glob_file_list(self.GT_root)
                list_lr_seq = util.glob_file_list(
                    osp.join(self.LQ_root, 'X{}'.format(self.scale)))
                list_slr_seq = util.glob_file_list(
                    osp.join(self.LQ_root,
                             'X{}'.format(self.scale * self.scale)))
                subfolders_GT = [
                    k for k in list_hr_seq
                    if k.find('001') >= 0 or k.find('005') >= 0
                    or k.find('008') >= 0 or k.find('009') >= 0
                ]
                subfolders_LQ = [
                    k for k in list_lr_seq
                    if k.find('001') >= 0 or k.find('005') >= 0
                    or k.find('008') >= 0 or k.find('009') >= 0
                ]
                subfolders_SLQ = [
                    k for k in list_slr_seq
                    if k.find('001') >= 0 or k.find('005') >= 0
                    or k.find('008') >= 0 or k.find('009') >= 0
                ]

            print(subfolders_GT[0], '\n', subfolders_LQ[0], '\n',
                  subfolders_SLQ[0])

            for subfolder_SLQ, subfolder_LQ, subfolder_GT in zip(
                    subfolders_SLQ, subfolders_LQ, subfolders_GT):
                subfolder_name = osp.basename(subfolder_GT)
                img_paths_GT = util.glob_file_list(subfolder_GT)
                if isinstance(subfolder_LQ, list):
                    img_paths_LQ_list = [
                        util.glob_file_list(subf_LQ)
                        for subf_LQ in subfolder_LQ
                    ]
                    img_paths_SLQ_list = [
                        util.glob_file_list(subf_SLQ)
                        for subf_SLQ in subfolder_SLQ
                    ]
                    img_paths_LQ = []
                    img_paths_SLQ = []
                    for i in range(len(img_paths_GT)):
                        img_paths_LQ.append(img_paths_LQ_list[i %
                                                              num_settings][i])
                        img_paths_SLQ.append(
                            img_paths_SLQ_list[i % num_settings][i])
                else:
                    img_paths_LQ = util.glob_file_list(subfolder_LQ)
                    img_paths_SLQ = util.glob_file_list(subfolder_SLQ)

                max_idx = len(img_paths_GT)
                self.data_info['path_SLQ'].extend(img_paths_SLQ)
                self.data_info['path_LQ'].extend(img_paths_LQ)
                self.data_info['path_GT'].extend(img_paths_GT)
                self.data_info['folder'].extend([subfolder_name] * max_idx)
                for i in range(max_idx):
                    self.data_info['idx'].append('{}/{}'.format(i, max_idx))
                border_l = [0] * max_idx
                for i in range(self.half_N_frames):
                    border_l[i] = 1
                    border_l[max_idx - i - 1] = 1
                self.data_info['border'].extend(border_l)
                self.imgs_GT[subfolder_name] = util.read_img_seq(
                    img_paths_GT, img_type)
                if opt['degradation_mode'] == 'preset':
                    self.imgs_LQ[subfolder_name] = torch.stack([
                        util.read_img_seq(util.glob_file_list(paths_LQ),
                                          img_type)
                        for paths_LQ in img_paths_LQ
                    ],
                                                               dim=0)
                    self.imgs_SLQ[subfolder_name] = torch.stack([
                        util.read_img_seq(util.glob_file_list(paths_SLQ),
                                          img_type)
                        for paths_SLQ in img_paths_SLQ
                    ],
                                                                dim=0)
                else:
                    self.imgs_LQ[subfolder_name] = util.read_img_seq(
                        img_paths_LQ, img_type)
                    self.imgs_SLQ[subfolder_name] = util.read_img_seq(
                        img_paths_SLQ, img_type)
                h, w = self.imgs_SLQ[subfolder_name].shape[-2:]
                if h % 4 != 0 or w % 4 != 0:
                    self.imgs_SLQ[subfolder_name] = self.imgs_SLQ[
                        subfolder_name][..., :h - (h % 4), :w - (w % 4)]
                    self.imgs_LQ[subfolder_name] = self.imgs_LQ[
                        subfolder_name][..., :self.scale *
                                        (h - (h % 4)), :self.scale * (w -
                                                                      (w % 4))]
                    self.imgs_GT[subfolder_name] = self.imgs_GT[
                        subfolder_name][..., :self.scale * self.scale *
                                        (h - (h % 4)), :self.scale *
                                        self.scale * (w - (w % 4))]

        else:
            raise ValueError(
                'Unsupported video test dataset. Supported: Vid4, REDS4 and Vimeo90k-Test.'
            )
        '''
        if opt['degradation_mode'] == 'set':
            sigma_x = float(opt['sigma_x'])
            sigma_y = float(opt['sigma_y'])
            theta = float(opt['theta'])
            gen_kwargs = preprocessing.set_kernel_params(sigma_x=sigma_x, sigma_y=sigma_y, theta=theta)
            self.kernel_gen = rkg.Degradation(self.kernel_size, self.scale, **gen_kwargs)
            self.gen_kwargs_l = [gen_kwargs['sigma'][0], gen_kwargs['sigma'][1], gen_kwargs['theta']]
        '''
        if opt['degradation_mode'] == 'preset':
            self.kernel_gen = rkg.Degradation(self.kernel_size, self.scale)
            if self.name.lower() == 'vid4':
                self.kernel_dict = np.load(
                    '../pretrained_models/Mixed/Vid4.npy')
            elif self.name.lower() == 'reds':
                self.kernel_dict = np.load(
                    '../pretrained_models/Mixed/REDS.npy')
            else:
                raise NotImplementedError()
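The LQ paths assembled above follow the same directory convention as the generation scripts earlier: <LQ_root>_<degradation_type>_<sigma_x>_<sigma_y>_<theta>/X<scale> for the LR clips and .../X<scale*scale> for the twice-degraded SuperLR clips. A small sketch of that naming scheme (an illustration only, not repository code):

from os import path as osp


def lq_roots(lq_root, degradation_type, sigma_x, sigma_y, theta, scale=2):
    """Build the LR and SuperLR folder roots used by the test dataset."""
    base = '{}_{}_{:.1f}_{:.1f}_{:.1f}'.format(
        lq_root, degradation_type, sigma_x, sigma_y, theta)
    return (osp.join(base, 'X{}'.format(scale)),
            osp.join(base, 'X{}'.format(scale * scale)))


print(lq_roots('../dataset/Vid4/LR', 'impulse', 1.3, 1.3, 0.0))
# ('../dataset/Vid4/LR_impulse_1.3_1.3_0.0/X2',
#  '../dataset/Vid4/LR_impulse_1.3_1.3_0.0/X4')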
Example #10
    def __getitem__(self, idx):
        # Randomly choose a sequence in a video
        # key = self.keys[idx]
        # str_idx, seq_idx = self.find_set(key)
        key, str_idx, seq_idx = self.find_set(idx)
        set_hr = self.dict_hr[key]

        if self.stride[str_idx] > 0:
            seq_end = seq_idx + self.sample_length[str_idx]
        else:
            seq_end = seq_idx - self.sample_length[str_idx]

        if seq_end >= 0:
            name_hr = set_hr[seq_idx:seq_end:self.stride[str_idx]]
        else:
            name_hr = set_hr[seq_idx::self.stride[str_idx]]

        if self.img_type == 'img':
            fn_read = imageio.imread
        elif self.img_type == 'bin':
            fn_read = np.load
        else:
            raise ValueError('Wrong img type: {}'.format(self.img_type))

        name = [path.join(key, path.basename(f)) for f in name_hr]
        seq_hr = [fn_read(f) for f in name_hr]
        seq_hr = np.stack(seq_hr, axis=-1)
        
        # if self.train:
            # since patch_size is defined at the seq_LR scale, it has to be twice as large here
            # seq_hr = preprocessing.common_crop(seq_hr, patch_size=self.opt['datasets']['train']['patch_size']*2)
        seq_hr = preprocessing.np2tensor(seq_hr)
        seq_hr = preprocessing.common_crop(
            seq_hr, patch_size=self.opt['datasets']['train']['patch_size'] * 3)
        kwargs = preprocessing.set_kernel_params()
        kwargs_l = kwargs['sigma']
        kwargs_l.append(kwargs['theta'])
        kwargs_l = torch.Tensor(kwargs_l)

        base_type = random.random()

        # include random noise for each frame
        kernel_gen = rkg.Degradation(
            self.opt['datasets']['train']['kernel_size'], self.scale,
            base_type, **kwargs)
        seq_lr = []
        seq_superlr = []
        for i in range(seq_hr.shape[0]):
            # kernel_gen.gen_new_noise()
            seq_lr_slice = kernel_gen.apply(seq_hr[i])
            seq_lr.append(seq_lr_slice)
            seq_superlr.append(kernel_gen.apply(seq_lr_slice))

        seq_lr = torch.stack(seq_lr, dim=0)
        seq_superlr = torch.stack(seq_superlr, dim=0)
        '''
        if not os.path.exists(os.path.join('../result', self.opt['name'])):
            os.makedirs(os.path.join('../result', self.opt['name']), exist_ok=True)
        imageio.imwrite(os.path.join('../result', self.opt['name'], name[2].split('/')[0]+'_'+name[2].split('/')[1]+'_before.png'), seq_lr[2].numpy().transpose(1,2,0))
        imageio.imwrite(os.path.join('../result', self.opt['name'], name[2].split('/')[0]+'_'+name[2].split('/')[1]+'_after.png'), seq_hr[2].numpy().transpose(1,2,0))

        if self.train:
            seq_hr, seq_lr = preprocessing.crop(seq_hr, seq_lr, patch_size=self.opt_data['train']['patch_size'])
            # seq_hr, seq_lr = preprocessing.augment(seq_hr, seq_lr)
        # imageio.imwrite(os.path.join('../result', self.opt['name'], name[2].split('/')[0]+'_'+name[2].split('/')[1]+'_after.png'), seq_lr[2].numpy().transpose(1,2,0))
        '''

        if self.train:
            seq_hr, seq_lr, seq_superlr = preprocessing.crop(
                seq_hr,
                seq_lr,
                seq_superlr,
                patch_size=self.opt['datasets']['train']['patch_size'])
            seq_hr, seq_lr, seq_superlr = preprocessing.augment(
                seq_hr, seq_lr, seq_superlr)

        return {
            'SuperLQs': seq_superlr,
            'LQs': seq_lr,
            'GT': seq_hr,
            'Kernel': kernel_gen.kernel,
            'Kernel_args': kwargs_l,
            'name': name
        }