Code example #1
File: dataset.py Project: Agchai52/DeblurGANv2
    def from_config(config):
        config = deepcopy(config)
        files_a, files_b = map(
            lambda x: sorted(glob(config[x], recursive=True)),
            ('files_a', 'files_b'))
        transform_fn = aug.get_transforms(size=config['size'],
                                          scope=config['scope'],
                                          crop=config['crop'])
        normalize_fn = aug.get_normalize()
        corrupt_fn = aug.get_corrupt_function(config['corrupt'])

        # deterministic subsampling: each (blur, sharp) pair is kept or dropped
        # based on a hash of its paths, so the split is reproducible across runs
        hash_fn = hash_from_paths
        # ToDo: add more hash functions
        verbose = config.get('verbose', True)
        data = subsample(data=zip(files_a, files_b),
                         bounds=config.get('bounds', (0, 1)),
                         hash_fn=hash_fn,
                         verbose=verbose)

        files_a, files_b = map(list, zip(*data))

        return PairedDataset(files_a=files_a,
                             files_b=files_b,
                             preload=config['preload'],
                             preload_size=config['preload_size'],
                             corrupt_fn=corrupt_fn,
                             normalize_fn=normalize_fn,
                             transform_fn=transform_fn,
                             verbose=verbose)
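
For reference, a minimal sketch of the config dict this from_config expects, limited to the keys read above. The concrete values, the glob patterns and the corrupt entry are placeholder assumptions; the exact corrupt schema is defined by aug.get_corrupt_function.

config = {
    'files_a': 'data/blur/**/*.png',   # placeholder glob patterns
    'files_b': 'data/sharp/**/*.png',
    'size': 256,
    'scope': 'geometric',
    'crop': 'random',                  # or 'center'
    'preload': False,
    'preload_size': 0,
    'bounds': (0, 0.9),                # optional, defaults to (0, 1)
    'verbose': True,                   # optional, defaults to True
    'corrupt': [{'name': 'jpeg', 'quality_low': 70, 'quality_high': 90}],  # hypothetical entry
}
dataset = PairedDataset.from_config(config)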
Code example #2
    def test_aug(self):
        for scope in ('geometric', 'weak'):
            for crop in ('random', 'center'):
                aug_pipeline = get_transforms(80, scope=scope, crop=crop)
                a, b = self.make_images()
                a, b = aug_pipeline(a, b)
                np.testing.assert_allclose(a, b)
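
A standalone sketch of the contract this test relies on: the pipeline returned by get_transforms is applied to an image pair jointly, sharing one set of random parameters, so two identical inputs stay identical after augmentation. It assumes DeblurGANv2's aug module is importable and that images are HxWx3 numpy arrays.

import numpy as np
from aug import get_transforms

img = np.random.randint(0, 255, (120, 120, 3), dtype=np.uint8)
pipeline = get_transforms(80, scope='geometric', crop='random')
a, b = pipeline(img.copy(), img.copy())   # one set of random parameters for both images
np.testing.assert_allclose(a, b)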
Code example #3
File: train.py Project: lzmisscc/DeblurGANv2
def main(config_path='config/config.yaml'):
    with open(config_path, 'r') as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)

    batch_size = config.pop('batch_size')
    get_dataloader = partial(
        DataLoader,
        batch_size=batch_size,
        # single-process loading when DEBUG is set (easier to step through),
        # otherwise one worker per CPU core
        num_workers=0 if os.environ.get('DEBUG') else cpu_count(),
        shuffle=True,
        drop_last=True)

    datasets = (NoPairDataset(
        "datasets/tilt", get_transforms(256),
        get_corrupt_function(config=config['train']['corrupt'])),
                NoPairDataset("datasets/tilt", get_transforms(256)))
    train, val = map(get_dataloader, datasets)
    trainer = Trainer(config, train=train, val=val)
    trainer.train()
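
This main() reads only two things from config/config.yaml before handing the rest to Trainer: the top-level batch_size and the train.corrupt list. Below is a minimal sketch of that config, written as the dict yaml.load would return; the corrupt entry is a hypothetical placeholder, and Trainer will almost certainly require further keys not visible in this snippet.

config = {
    'batch_size': 8,
    'train': {
        'corrupt': [{'name': 'jpeg', 'quality_low': 70, 'quality_high': 90}],  # hypothetical entry
    },
    # further Trainer-specific keys omitted
}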
Code example #4
File: dataset.py Project: urevoleg/jupyter_notebooks
    def from_config(config):
        config = deepcopy(config)
        files = glob(
            f'{config.get("data_dir", "/home/arseny/datasets/idrnd_train")}/**/*.png',
            recursive=True)

        transform_fn = aug.get_transforms(size=config['size'],
                                          scope=config['scope'],
                                          crop=config['crop'])
        normalize_fn = aug.get_normalize()
        corrupt_fn = aug.get_corrupt_function(config['corrupt'])
        soften_fn = create_soften_fn(config.get('soften', 1))

        def hash_fn(x: str, salt: str = '') -> str:
            x = os.path.basename(x)
            label, video, frame = x.split('_')
            return sha1(f'{label}_{video}_{salt}'.encode()).hexdigest()

        verbose = config.get('verbose', True)
        n_fold = config['n_fold']
        total_folds = 10
        test = subsample(data=files,
                         bounds=(1 / total_folds * n_fold,
                                 1 / total_folds * (n_fold + 1)),
                         hash_fn=hash_fn,
                         verbose=verbose,
                         salt='validation')

        if config['test']:
            data = test
        else:
            files = set(files) - set(test)
            data = subsample(data=files,
                             bounds=config.get('bounds', (0, 1)),
                             hash_fn=hash_fn,
                             verbose=verbose,
                             salt=config['salt'])

        return IdRndDataset(files=tuple(data),
                            preload=config['preload'],
                            preload_size=config['preload_size'],
                            corrupt_fn=corrupt_fn,
                            normalize_fn=normalize_fn,
                            transform_fn=transform_fn,
                            soften_fn=soften_fn,
                            mixup=config['mixup'],
                            verbose=verbose)
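
The hash_fn above is what keeps the fold split leak-free at the video level: the frame index is dropped before hashing, so every frame of a video receives the same hash and therefore lands in the same fold. A self-contained sketch of that property, assuming the label_video_frame.png filename pattern implied by the snippet:

import os
from hashlib import sha1

def hash_fn(x: str, salt: str = '') -> str:
    label, video, frame = os.path.basename(x).split('_')
    return sha1(f'{label}_{video}_{salt}'.encode()).hexdigest()

# frames of one video always hash identically; frames of another video (practically) never do
assert hash_fn('real_0042_0001.png') == hash_fn('real_0042_0137.png')
assert hash_fn('real_0042_0001.png') != hash_fn('real_0099_0001.png')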
Code example #5
    def from_config(config):
        config = deepcopy(config)
        img_path = config.get('img_dir')
        labels_path = config.get('labels_path')

        imgs, labels = zip(*parse_labels(img_path, labels_path))
        transform_fn = aug.get_transforms(size=config['size'],
                                          crop=config['crop'])
        normalize_fn = aug.get_normalize()
        corrupt_fn = aug.get_corrupt_function(config['corrupt'])
        verbose = config.get('verbose', True)

        return TigerDataset(imgs=imgs,
                            labels=labels,
                            size=config['size'],
                            corrupt_fn=corrupt_fn,
                            normalize_fn=normalize_fn,
                            transform_fn=transform_fn,
                            verbose=verbose)
Code example #6
File: dataset.py Project: diyar-m/DeblurGANv2
    def from_config(config):
        config = deepcopy(config)
        files_a, files_b = [], []
        if config['phase'] == "train":
            # blur and sharp lists are paired positionally later on, so every
            # glob is sorted to keep the two sides aligned
            files_a += sorted(glob("data/TRAIN/GOPRO_Large/*/blur/*.png"))
            files_b += sorted(glob("data/TRAIN/GOPRO_Large/*/sharp/*.png"))

            files_a += sorted(glob("data/TRAIN/GOPRO_Large/*/blur_gamma/*.png"))
            files_b += sorted(glob("data/TRAIN/GOPRO_Large/*/sharp/*.png"))

            files_a += sorted(glob("data/TRAIN/quantitative_datasets/*/input/*.jpg"))
            files_b += sorted(glob("data/TRAIN/quantitative_datasets/*/GT/*.jpg"))

            files_a += sorted(glob("data/TRAIN/REDS_train/train_blur/*/*.png"))
            files_b += sorted(glob("data/TRAIN/REDS_train/train_sharp/*/*.png"))

            # for each blur-level suffix (*01.png .. *04.png), the uniform and
            # nonuniform blurred images are each paired with the full ground_truth list
            for i in range(4):
                files_a += sorted(
                    glob(
                        "data/TRAIN/synthetic_dataset/uniform/*0{}.png".format(
                            i + 1)))
                files_b += sorted(
                    glob("data/TRAIN/synthetic_dataset/ground_truth/*.png"))

                files_a += sorted(
                    glob("data/TRAIN/synthetic_dataset/nonuniform/*0{}.png".
                         format(i + 1)))
                files_b += sorted(
                    glob("data/TRAIN/synthetic_dataset/ground_truth/*.png"))

        elif config['phase'] == "val":
            files_a += sorted(glob("data/TEST/GOPRO_Large/*/blur/*.png"))
            files_b += sorted(glob("data/TEST/GOPRO_Large/*/sharp/*.png"))
            # files_a += glob("data/TEST/GOPRO_Large/*/blur_gamma/*.png")
            # files_b += glob("data/TEST/GOPRO_Large/*/sharp/*.png")

        # files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))
        transform_fn = aug.get_transforms(size=config['size'],
                                          scope=config['scope'],
                                          crop=config['crop'])
        normalize_fn = aug.get_normalize()
        corrupt_fn = aug.get_corrupt_function(config['corrupt'])

        hash_fn = hash_from_paths
        # ToDo: add more hash functions
        verbose = config.get('verbose', True)
        data = subsample(data=zip(files_a, files_b),
                         bounds=config.get('bounds', (0, 1)),
                         hash_fn=hash_fn,
                         verbose=verbose)

        files_a, files_b = map(list, zip(*data))

        return PairedDataset(files_a=files_a,
                             files_b=files_b,
                             preload=config['preload'],
                             preload_size=config['preload_size'],
                             corrupt_fn=corrupt_fn,
                             normalize_fn=normalize_fn,
                             transform_fn=transform_fn,
                             verbose=verbose)