Example #1
    @staticmethod
    def from_config(config):
        config = deepcopy(config)
        # Resolve the 'files_a'/'files_b' glob patterns into sorted, aligned file lists
        files_a, files_b = map(
            lambda x: sorted(glob(config[x], recursive=True)),
            ('files_a', 'files_b'))
        transform_fn = aug.get_transforms(size=config['size'],
                                          scope=config['scope'],
                                          crop=config['crop'])
        normalize_fn = aug.get_normalize()
        corrupt_fn = aug.get_corrupt_function(config['corrupt'])

        hash_fn = hash_from_paths
        # ToDo: add more hash functions
        verbose = config.get('verbose', True)
        # Deterministically subsample the pairs by hashing their file paths
        data = subsample(data=zip(files_a, files_b),
                         bounds=config.get('bounds', (0, 1)),
                         hash_fn=hash_fn,
                         verbose=verbose)

        files_a, files_b = map(list, zip(*data))

        return PairedDataset(files_a=files_a,
                             files_b=files_b,
                             preload=config['preload'],
                             preload_size=config['preload_size'],
                             corrupt_fn=corrupt_fn,
                             normalize_fn=normalize_fn,
                             transform_fn=transform_fn,
                             verbose=verbose)
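
For reference, here is a minimal sketch of the config dict this `from_config` appears to expect. The key names are taken from the snippet above; the glob patterns and values are hypothetical placeholders, not the project's defaults:

    # Hypothetical config for PairedDataset.from_config (all values are placeholders)
    config = {
        'files_a': 'data/blur/**/*.png',   # degraded inputs
        'files_b': 'data/sharp/**/*.png',  # matching clean targets
        'size': 256,
        'scope': 'strong',
        'crop': 'random',
        'corrupt': [],
        'preload': False,
        'preload_size': 0,
        'bounds': (0, 0.9),  # keep the first 90% of the hash space
        'verbose': True,
    }
    dataset = PairedDataset.from_config(config)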
Example #2
    def __init__(self, imgs: Sequence[str]):
        self.imgs = imgs
        self.normalize_fn = get_normalize()
        self.approx_img_size = 384
        logger.info(f'Dataset has been created with {len(self.imgs)} samples')
Example #3
    def __init__(self, weights_path: str, model_name: str = ''):
        with open('config/config.yaml') as cfg:
            # yaml.load requires an explicit Loader in PyYAML >= 5.1
            config = yaml.load(cfg, Loader=yaml.SafeLoader)
        model = get_generator(model_name or config['model'])
        model.load_state_dict(torch.load(weights_path)['model'])
        self.model = model.cuda()
        # GAN inference should run in train mode so the norm layers use the
        # actual batch statistics; this is intentional, not a bug
        self.model.train(True)
        self.normalize_fn = get_normalize()
Example #4
    def __init__(self, weights_path, model_name=''):
        # model = get_generator(model_name or config['model'])
        # The trailing 11 characters of weights_path (a fixed filename suffix)
        # are stripped to recover the identifier get_generator_new expects
        model = get_generator_new(weights_path[0:-11])
        model.load_state_dict(
            torch.load(weights_path,
                       map_location=lambda storage, loc: storage)['model'])
        # Fall back to the CPU when CUDA is unavailable
        self.model = model.cuda() if torch.cuda.is_available() else model
        # GAN inference should run in train mode so the norm layers use the
        # actual batch statistics; this is intentional, not a bug
        self.model.train(True)
        self.normalize_fn = get_normalize()
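
As a usage note, both constructors above follow the same recipe: build the generator, load a 'model' state dict, keep the network in train mode, and stash a normalize_fn for preprocessing. A hedged sketch of driving such an object follows; the class name `Predictor`, the weights path, and the (image, mask) signature of normalize_fn are assumptions, not taken from the snippets:

    import numpy as np

    # Hypothetical driver for the __init__ shown above
    predictor = Predictor(weights_path='weights/best_fpn.h5')  # placeholder path
    img = np.zeros((256, 256, 3), dtype=np.uint8)              # placeholder frame
    # assuming normalize_fn normalizes an (image, mask) pair in lockstep
    img, _ = predictor.normalize_fn(img, img)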
Example #5
    @staticmethod
    def from_config(config):
        config = deepcopy(config)
        files = glob(
            f'{config.get("data_dir", "/home/arseny/datasets/idrnd_train")}/**/*.png',
            recursive=True)

        transform_fn = aug.get_transforms(size=config['size'],
                                          scope=config['scope'],
                                          crop=config['crop'])
        normalize_fn = aug.get_normalize()
        corrupt_fn = aug.get_corrupt_function(config['corrupt'])
        soften_fn = create_soften_fn(config.get('soften', 1))

        def hash_fn(x: str, salt: str = '') -> str:
            # Only the label and video id feed the hash: every frame of a
            # video shares a digest, so a video never straddles folds
            x = os.path.basename(x)
            label, video, frame = x.split('_')
            return sha1(f'{label}_{video}_{salt}'.encode()).hexdigest()

        verbose = config.get('verbose', True)
        n_fold = config['n_fold']
        total_folds = 10
        # Hold out the n_fold-th tenth of the (salted) hash space as validation
        test = subsample(data=files,
                         bounds=(1 / total_folds * n_fold,
                                 1 / total_folds * (n_fold + 1)),
                         hash_fn=hash_fn,
                         verbose=verbose,
                         salt='validation')

        if config['test']:
            data = test
        else:
            files = set(files) - set(test)
            data = subsample(data=files,
                             bounds=config.get('bounds', (0, 1)),
                             hash_fn=hash_fn,
                             verbose=verbose,
                             salt=config['salt'])

        return IdRndDataset(files=tuple(data),
                            preload=config['preload'],
                            preload_size=config['preload_size'],
                            corrupt_fn=corrupt_fn,
                            normalize_fn=normalize_fn,
                            transform_fn=transform_fn,
                            soften_fn=soften_fn,
                            mixup=config['mixup'],
                            verbose=verbose)
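
The hash_fn above deserves a note: because the digest of f'{label}_{video}_{salt}' is deterministic and ignores the frame index, fold membership is a pure function of the video identity and the salt. Here is a self-contained sketch of the same idea; the 'label_video_frame' filename format comes from the snippet, while the example paths are made up:

    import os
    from hashlib import sha1

    def hash_fn(x: str, salt: str = '') -> str:
        # the frame index is parsed but never hashed
        label, video, frame = os.path.basename(x).split('_')
        return sha1(f'{label}_{video}_{salt}'.encode()).hexdigest()

    # Frames of one video always hash identically, so a video never
    # leaks across the train/validation boundary:
    assert hash_fn('train/real_0042_0001.png') == hash_fn('train/real_0042_0007.png')
    # A different salt reshuffles the whole hash space:
    assert hash_fn('train/real_0042_0001.png') != hash_fn('train/real_0042_0001.png', salt='validation')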
Example #6
    @staticmethod
    def from_config(config):
        config = deepcopy(config)
        img_path = config.get('img_dir')
        labels_path = config.get('labels_path')

        imgs, labels = zip(*parse_labels(img_path, labels_path))
        transform_fn = aug.get_transforms(size=config['size'],
                                          crop=config['crop'])
        normalize_fn = aug.get_normalize()
        corrupt_fn = aug.get_corrupt_function(config['corrupt'])
        verbose = config.get('verbose', True)

        return TigerDataset(imgs=imgs,
                            labels=labels,
                            size=config['size'],
                            corrupt_fn=corrupt_fn,
                            normalize_fn=normalize_fn,
                            transform_fn=transform_fn,
                            verbose=verbose)
Example #7
    @staticmethod
    def from_config(config):
        config = deepcopy(config)
        files_a, files_b = [], []
        if config['phase'] == "train":
            files_a += glob("data/TRAIN/GOPRO_Large/*/blur/*.png")
            files_b += glob("data/TRAIN/GOPRO_Large/*/sharp/*.png")

            files_a += glob("data/TRAIN/GOPRO_Large/*/blur_gamma/*.png")
            files_b += glob("data/TRAIN/GOPRO_Large/*/sharp/*.png")

            files_a += glob("data/TRAIN/quantitative_datasets/*/input/*.jpg")
            files_b += glob("data/TRAIN/quantitative_datasets/*/GT/*.jpg")

            files_a += glob("data/TRAIN/REDS_train/train_blur/*/*.png")
            files_b += glob("data/TRAIN/REDS_train/train_sharp/*/*.png")

            # Pair each of the four blur variants ('*01.png' .. '*04.png')
            # with the same sorted list of ground-truth frames
            for i in range(4):
                files_a += sorted(
                    glob(
                        "data/TRAIN/synthetic_dataset/uniform/*0{}.png".format(
                            i + 1)))
                files_b += sorted(
                    glob("data/TRAIN/synthetic_dataset/ground_truth/*.png"))

                files_a += sorted(
                    glob("data/TRAIN/synthetic_dataset/nonuniform/*0{}.png".
                         format(i + 1)))
                files_b += sorted(
                    glob("data/TRAIN/synthetic_dataset/ground_truth/*.png"))

        elif config['phase'] == "val":
            files_a += glob("data/TEST/GOPRO_Large/*/blur/*.png")
            files_b += glob("data/TEST/GOPRO_Large/*/sharp/*.png")
            # files_a += glob("data/TEST/GOPRO_Large/*/blur_gamma/*.png")
            # files_b += glob("data/TEST/GOPRO_Large/*/sharp/*.png")

        # files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))
        transform_fn = aug.get_transforms(size=config['size'],
                                          scope=config['scope'],
                                          crop=config['crop'])
        normalize_fn = aug.get_normalize()
        corrupt_fn = aug.get_corrupt_function(config['corrupt'])

        hash_fn = hash_from_paths
        # ToDo: add more hash functions
        verbose = config.get('verbose', True)
        data = subsample(data=zip(files_a, files_b),
                         bounds=config.get('bounds', (0, 1)),
                         hash_fn=hash_fn,
                         verbose=verbose)

        files_a, files_b = map(list, zip(*data))

        return PairedDataset(files_a=files_a,
                             files_b=files_b,
                             preload=config['preload'],
                             preload_size=config['preload_size'],
                             corrupt_fn=corrupt_fn,
                             normalize_fn=normalize_fn,
                             transform_fn=transform_fn,
                             verbose=verbose)
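
Examples 1 and 7 both funnel their file pairs through subsample(..., bounds=config.get('bounds', (0, 1))), where the bounds are best read as fractions of a hash space. Below is a hedged sketch of what such a helper plausibly does; it illustrates the idea and is not the project's actual subsample implementation:

    def subsample_sketch(data, hash_fn, bounds=(0.0, 1.0)):
        # Keep an item when its hash, mapped into [0, 1), falls inside bounds.
        # Deterministic by construction: the same input always yields the same split.
        lo, hi = bounds
        kept = []
        for item in data:
            digest = hash_fn(item)                        # hex string, e.g. from sha1
            position = int(digest[:8], 16) / 0x100000000  # first 8 hex chars -> [0, 1)
            if lo <= position < hi:
                kept.append(item)
        return kept

    # e.g. keep the first 90% of items by hash:
    # train_files = subsample_sketch(files, hash_fn, bounds=(0, 0.9))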