示例#1
0
def create_dataset(dataset_opt):
    """Instantiate the dataset class selected by ``dataset_opt["mode"]``.

    The dataset module is imported inside the matching branch, so only the
    module for the requested mode is loaded.

    Args:
        dataset_opt: Options mapping. ``"mode"`` selects the dataset class
            (``"LQ"``, ``"LQGTker"`` or ``"SRker"``) and ``"name"`` is used in
            the log message. The whole mapping is passed to the dataset
            constructor.

    Returns:
        The newly constructed dataset instance.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported values.
    """
    mode = dataset_opt["mode"]
    if mode == "LQ":  # Predictor
        from data.LQ_dataset import LQDataset as D
    elif mode == "LQGTker":  # SFTMD
        from data.LQGTker_dataset import LQGTKerDataset as D
    elif mode == "SRker":  # Corrector
        from data.SRker_dataset import SRkerDataset as D
    else:
        # "{}" rather than "{:s}": the "s" format type rejects non-string
        # values (None, ints), which would hide this error behind a TypeError.
        raise NotImplementedError("Dataset [{}] is not recognized.".format(mode))

    # Every supported dataset class takes the options mapping as its only
    # constructor argument, so construction happens once, after dispatch.
    dataset = D(dataset_opt)

    logger = logging.getLogger("base")
    # %-style arguments are formatted by the logging module and accept any
    # value, so a non-string name cannot break dataset creation.
    logger.info(
        "Dataset [%s - %s] is created.",
        dataset.__class__.__name__,
        dataset_opt["name"],
    )
    return dataset
示例#2
0
def create_dataset(dataset_opt):
    """Create the image or video restoration dataset named by ``mode``.

    Each dataset module is imported only when its mode is requested.

    Args:
        dataset_opt: Options mapping. ``'mode'`` selects the dataset class and
            ``'name'`` is used in the log message. The whole mapping is
            forwarded to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``mode`` does not match a supported dataset.
    """
    mode = dataset_opt['mode']
    # datasets for image restoration
    if mode == 'LQ':
        from data.LQ_dataset import LQDataset as D
    elif mode == 'LQGT':
        from data.LQGT_dataset import LQGTDataset as D
    elif mode == 'Color':
        from data.Color_dataset import ColorDataset as D
    elif mode == 'ContinueLQGT':
        from data.ContinueLQGT_dataset import ContinueLQGTDataset as D
    # datasets for video restoration
    elif mode == 'REDS':
        from data.REDS_dataset import REDSDataset as D
    elif mode == 'Vimeo90K':
        from data.Vimeo90K_dataset import Vimeo90KDataset as D
    elif mode == 'video_test':
        from data.video_test_dataset import VideoTestDataset as D
    else:
        # '{}' accepts any value; '{:s}' raises TypeError for non-strings and
        # would mask this error when the mode is None or a number.
        raise NotImplementedError(
            'Dataset [{}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)

    logger = logging.getLogger('base')
    # Let logging do the interpolation so a non-string name is still logged.
    logger.info('Dataset [%s - %s] is created.',
                dataset.__class__.__name__, dataset_opt['name'])
    return dataset
示例#3
0
def create_dataset(dataset_opt):
    """Pick and build a dataset according to ``dataset_opt['mode']``.

    Mode notes carried over from the original comments:
        Vimeo90K: Vimeo90K train & val
        video_test: Vid4 test

    Args:
        dataset_opt: Options mapping. ``'mode'`` selects the dataset class and
            ``'name'`` appears in the log message. The mapping itself is the
            only argument passed to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``mode`` is not a supported value.
    """
    mode = dataset_opt['mode']
    if mode == 'LQ':
        from data.LQ_dataset import LQDataset as D
    elif mode == 'LQGT':
        from data.LQGT_dataset import LQGTDataset as D
    elif mode == 'Vimeo90K':
        from data.Vimeo90K_dataset import Vimeo90KDataset as D
    elif mode == 'video_test':
        from data.video_test_dataset import VideoTestDataset as D
    elif mode in ('DIV2K_easy', 'DIV2K_train'):
        # Both training modes share one dataset class.
        from data.DIV2K_dataset import ImageTrainDataset as D
    elif mode == 'DIV2K_val':
        from data.DIV2K_dataset import ImageValDataset as D
    else:
        # '{:s}' only formats str; '{}' keeps this error reachable for a
        # mode of None or any other non-string value.
        raise NotImplementedError(
            'Dataset [{}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)

    logger = logging.getLogger('base')
    # Deferred %-formatting tolerates a non-string dataset name.
    logger.info('Dataset [%s - %s] is created.',
                dataset.__class__.__name__, dataset_opt['name'])
    return dataset
示例#4
0
def create_dataset(dataset_opt):
    """Return a dataset instance for the mode given in ``dataset_opt``.

    The import for each dataset lives in its own branch, so modules for
    unused modes are never loaded.

    Args:
        dataset_opt: Options mapping. ``'mode'`` chooses among ``'LQ'``,
            ``'LQGT'``, ``'VISR'``, ``'SEV'``, ``'REDS'`` and
            ``'video_test'``; ``'name'`` is used in the log message. The
            mapping is handed unchanged to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``mode`` is not one of the values above.
    """
    mode = dataset_opt['mode']
    if mode == 'LQ':
        from data.LQ_dataset import LQDataset as D
    elif mode == 'LQGT':
        from data.LQGT_dataset import LQGTDataset as D
    elif mode == 'VISR':
        from data.VISR_dataset import VISRDataset as D
    elif mode == 'SEV':
        from data.SEV_dataset import SEVDataset as D
    elif mode == 'REDS':
        from data.REDS_dataset import REDSDataset as D
    elif mode == 'video_test':
        from data.video_test_dataset import VideoTestDataset as D
    else:
        # Format with '{}' so a non-string mode (e.g. None) still yields this
        # NotImplementedError rather than a TypeError from the '{:s}' spec.
        raise NotImplementedError(
            'Dataset [{}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)

    logger = logging.getLogger('base')
    # Arguments are interpolated by logging, which accepts non-string names.
    logger.info('Dataset [%s - %s] is created.',
                dataset.__class__.__name__, dataset_opt['name'])
    return dataset
def create_dataset(dataset_opt):
    """Build the dataset selected by ``dataset_opt['mode']``.

    For ``'FastMRI'`` the dataset is additionally given a random k-space mask
    function and a transform that derives a masked reconstruction from each
    target image.

    Args:
        dataset_opt: Options mapping. ``'mode'`` is one of ``'LQ'``,
            ``'LQGT'`` or ``'FastMRI'``; ``'name'`` is used in the log
            message. The mapping is passed to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``mode`` is not a supported value.
    """
    mode = dataset_opt['mode']  # mode ~ which dataset to use
    if mode == 'LQ':
        from data.LQ_dataset import LQDataset as D
        dataset = D(dataset_opt)
    elif mode == 'LQGT':
        from data.LQGT_dataset import LQGTDataset as D
        dataset = D(dataset_opt)
    elif mode == 'FastMRI':
        from data.fastmri_dataset import FASTMRIDataset as D
        from data.fastmri import subsample, transforms
        # Create a mask function
        mask_func = subsample.RandomMaskFunc(center_fractions=[0.08],
                                             accelerations=[4])

        class DataTransform:
            """Callable that turns a target image into ``(img_LF, target)``."""

            def __call__(self, target, mask_func, seed=None):
                """Return the masked reconstruction and the tensor target.

                Args:
                    target: Array whose third axis has length 1
                        (presumably [H, W, 1] -- TODO confirm with the
                        dataset class).
                    mask_func: Mask function forwarded to
                        ``transforms.apply_mask``.
                    seed: Optional seed forwarded to
                        ``transforms.apply_mask``.

                Raises:
                    ValueError: If the third axis of ``target`` is not 1.
                """
                # Only single-channel input is handled. Previously any other
                # channel count left ``img`` unbound and failed with an
                # UnboundLocalError on the next line.
                if target.shape[2] != 1:
                    raise ValueError(
                        'DataTransform expects a single-channel target '
                        '(shape[2] == 1), got shape {}.'.format(target.shape))
                # Append an all-zero second channel (presumably the imaginary
                # part expected by the fastMRI helpers -- TODO confirm).
                img = np.concatenate((target, np.zeros_like(target)),
                                     axis=2)
                assert img.shape[-1] == 2
                img = transforms.to_tensor(img)
                kspace = transforms.fft2(img)

                center_kspace, _ = transforms.apply_mask(kspace,
                                                         mask_func,
                                                         seed=seed)
                img_LF = transforms.complex_abs(
                    transforms.ifft2(center_kspace))
                # Add a leading axis of size 1.
                img_LF = img_LF.unsqueeze(0)
                # Move the channel axis first: target shape [1, H, W]
                target = transforms.to_tensor(np.transpose(
                    target, (2, 0, 1)))
                return img_LF, target

        dataset = D(dataset_opt, mask_func, transform=DataTransform())
    else:
        # '{}' instead of '{:s}': the latter raises TypeError for a
        # non-string mode and would hide this error.
        raise NotImplementedError(
            'Dataset [{}] is not recognized.'.format(mode))

    logger = logging.getLogger('base')
    # Deferred %-formatting accepts a non-string dataset name.
    logger.info('Dataset [%s - %s] is created.',
                dataset.__class__.__name__, dataset_opt['name'])
    return dataset
示例#6
0
文件: __init__.py 项目: zoq/RankSRGAN
def create_dataset(dataset_opt, is_train=True):
    """Create the dataset selected by ``dataset_opt['mode']``.

    Args:
        dataset_opt: Options mapping. ``'mode'`` is one of ``'LQ'``,
            ``'LQGT'`` or ``'RANK_IMIM_Pair'``; ``'name'`` is used in the log
            message. The mapping is passed to the dataset constructor.
        is_train: Forwarded as the ``is_train`` keyword to the
            ``'RANK_IMIM_Pair'`` dataset only; ignored for the other modes.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``mode`` is not a supported value.
    """
    mode = dataset_opt['mode']
    # Extra constructor keywords; filled in by the branch that needs them so
    # no second test on ``mode`` is required after dispatch.
    extra_kwargs = {}
    # datasets for image restoration
    if mode == 'LQ':
        from data.LQ_dataset import LQDataset as D
    elif mode == 'LQGT':
        from data.LQGT_dataset import LQGTDataset as D
    elif mode == 'RANK_IMIM_Pair':
        from data.Rank_IMIM_Pair_dataset import RANK_IMIM_Pair_Dataset as D
        extra_kwargs['is_train'] = is_train
    else:
        # '{}' formats any value; '{:s}' raised TypeError for non-string
        # modes and masked this error.
        raise NotImplementedError(
            'Dataset [{}] is not recognized.'.format(mode))
    dataset = D(dataset_opt, **extra_kwargs)

    logger = logging.getLogger('base')
    # Let logging interpolate so a non-string name does not raise here.
    logger.info('Dataset [%s - %s] is created.',
                dataset.__class__.__name__, dataset_opt['name'])
    return dataset
示例#7
0
def create_dataset(dataset_opt):
    """Create an image restoration dataset for ``dataset_opt['mode']``.

    Args:
        dataset_opt: Options mapping. ``'mode'`` is ``'LQ'`` or ``'LQGT'``;
            ``'name'`` is used in the log message. The mapping is passed to
            the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``mode`` is neither ``'LQ'`` nor ``'LQGT'``.
    """
    mode = dataset_opt['mode']
    # datasets for image restoration
    if mode == 'LQ':
        from data.LQ_dataset import LQDataset as D
    elif mode == 'LQGT':
        from data.LQGT_dataset import LQGTDataset as D
    else:
        # A plain '{}' field works for any value, whereas '{:s}' raises
        # TypeError when the mode is not a string (e.g. None).
        raise NotImplementedError(
            'Dataset [{}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)

    logger = logging.getLogger('base')
    # logging performs the %-interpolation and accepts non-string names.
    logger.info('Dataset [%s - %s] is created.',
                dataset.__class__.__name__, dataset_opt['name'])
    return dataset
示例#8
0
def create_dataset(dataset_opt):
    """Build the dataset that matches ``dataset_opt['mode']``.

    Args:
        dataset_opt: Options mapping. ``'mode'`` is one of ``'LQ'``,
            ``'LQGT'`` or ``'LQGT_nopatch'``; ``'name'`` is used in the log
            message. The mapping is passed to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``mode`` is not a supported value.
    """
    mode = dataset_opt['mode']
    if mode == 'LQ':
        from data.LQ_dataset import LQDataset as D
    elif mode == 'LQGT':
        from data.LQGT_dataset import LQGTDataset as D
    elif mode == 'LQGT_nopatch':
        # Same class name as the 'LQGT' branch, but from a different module.
        from data.LQGT_nopatch_dataset import LQGTDataset as D
    else:
        # '{:s}' rejects non-string values with a TypeError; '{}' keeps the
        # intended error for a mode such as None.
        raise NotImplementedError(
            'Dataset [{}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)

    logger = logging.getLogger('base')
    # %-style arguments are rendered by logging and tolerate any name type.
    logger.info('Dataset [%s - %s] is created.',
                dataset.__class__.__name__, dataset_opt['name'])
    return dataset
示例#9
0
def create_dataset(dataset_opt):
    """Create the image or video dataset selected by ``dataset_opt['mode']``.

    Imports are done per branch, so only the module backing the requested
    mode is loaded.

    Args:
        dataset_opt: Options mapping. ``'mode'`` selects the dataset class and
            ``'name'`` is used in the log message. The mapping is passed to
            the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``mode`` does not match a supported dataset.
    """
    mode = dataset_opt['mode']
    # datasets for image restoration
    if mode == 'LQ':
        from data.LQ_dataset import LQDataset as D
    elif mode == 'LQGT':
        from data.LQGT_dataset import LQGTDataset as D
    # datasets for video restoration
    elif mode == 'REDS':
        from data.REDS_dataset import REDSDataset as D
    elif mode == 'REDSImg':
        from data.REDS_dataset import REDSImgDataset as D
    elif mode == 'REDSMultiImg':
        from data.REDS_dataset import REDSMultiImgDataset as D
    elif mode == 'MultiREDS':
        from data.REDS_dataset import MultiREDSDataset as D
    elif mode == 'MetaREDS':
        from data.REDS_dataset import MetaREDSDataset as D
    elif mode == 'MetaREDSOnline':
        from data.REDS_dataset import MetaREDSDatasetOnline as D
    elif mode == 'UPREDS':
        from data.REDS_dataset import UPREDSDataset as D
    elif mode == 'Vimeo90K':
        from data.Vimeo90K_dataset import Vimeo90KDataset as D
    elif mode == 'UPVimeo':
        from data.Vimeo90K_dataset import UPVimeoDataset as D
    elif mode == 'video_test':
        from data.video_test_dataset import VideoTestDataset as D
    elif mode == 'online_video_test':
        from data.video_test_dataset import OnlineVideoTestDataset as D
    elif mode == 'img_test':
        from data.video_test_dataset import ImgTestDataset as D
    else:
        # '{}' handles any value; the former '{:s}' raised TypeError for a
        # non-string mode and concealed this error.
        raise NotImplementedError('Dataset [{}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)

    logger = logging.getLogger('base')
    # Hand the arguments to logging so a non-string name cannot raise here.
    logger.info('Dataset [%s - %s] is created.', dataset.__class__.__name__,
                dataset_opt['name'])
    return dataset