Пример #1
0
    def _get_transform(**params) -> Callable:
        """Recursively build a transform from a (possibly nested) config.

        Two config shapes are supported, selected by the ``_key_value``
        flag, which is removed from ``params`` before anything else:

        * flag truthy: every remaining entry maps a key to a nested
          transform config. Each nested config is built recursively.
          It is then wrapped in an ``Augmentor`` whose ``dict_key``,
          ``input_key`` and ``output_key`` are all that key. The wrappers
          are combined with ``AugmentorCompose``.
        * flag falsy/absent: if a ``"transforms"`` entry exists, each of
          its items is built recursively and the list is replaced in
          place. The result is then produced by
          ``TRANSFORMS.get_from_params(**params)`` (presumably a registry
          lookup by name -- confirm against the registry definition).

        Returns:
            The constructed transform callable.
        """
        is_key_value = params.pop("_key_value", False)

        if not is_key_value:
            if "transforms" in params:
                params["transforms"] = [
                    ConfigExperiment._get_transform(**sub_params)
                    for sub_params in params["transforms"]
                ]
            return TRANSFORMS.get_from_params(**params)

        # Build every nested transform first, then wrap them, so that
        # construction order matches a plain two-phase build.
        per_key_transforms = {
            name: ConfigExperiment._get_transform(**sub_params)
            for name, sub_params in params.items()
        }
        augmentors = {
            name: Augmentor(
                dict_key=name,
                augment_fn=fn,
                input_key=name,
                output_key=name,
            )
            for name, fn in per_key_transforms.items()
        }
        return AugmentorCompose(augmentors)
Пример #2
0
 def get_transforms(self, stage: str = None, mode: str = None):
     """Return the transform pipeline for a loader mode.

     Args:
         stage (str): stage name. It is not used; every stage gets the
             same pipeline.
         mode (str): loader mode. ``"train"`` and ``"valid"`` get the
             same pipeline. Any other value (including ``None``) makes
             the method return ``None``.

     Returns:
         ``transforms.Compose`` of two ``Augmentor`` s: one over
         ``"images"`` (``torch.from_numpy(x).float()``) and one over
         ``"labels"`` (``torch.from_numpy(x)``). Returns ``None`` for an
         unrecognized ``mode``.
     """
     # "train" and "valid" share one pipeline; building it in a single
     # place keeps the two modes from drifting apart.
     if mode not in ("train", "valid"):
         return None

     images_augmentor = Augmentor(
         dict_key="images",
         augment_fn=lambda x: torch.from_numpy(x).float(),
     )
     labels_augmentor = Augmentor(
         dict_key="labels",
         augment_fn=torch.from_numpy,
     )
     return transforms.Compose([images_augmentor, labels_augmentor])
Пример #3
0
 def get_transforms(
         stage: str = None,
         dataset: str = None
 ):
     """Build the image transform wrapper for a dataset split.

     Args:
         stage: stage name. It is not used by this function.
         dataset: ``'train'`` composes ``pre_transforms()``,
             ``hard_transforms()`` and ``post_transforms()``.
             ``'valid'`` composes ``pre_transforms()`` and
             ``post_transforms()`` only.

     Returns:
         An ``Augmentor`` over the ``"image"`` key. Its function calls the
         composed pipeline as ``pipeline(image=x)["image"]``.

     Raises:
         NotImplementedError: for any other ``dataset`` value; the
             message names the rejected value.
     """
     if dataset == 'train':
         pipeline = compose([
             pre_transforms(),
             hard_transforms(),
             post_transforms()
         ])
     elif dataset == 'valid':
         pipeline = compose([pre_transforms(), post_transforms()])
     else:
         raise NotImplementedError(
             "no transforms defined for dataset={!r}; "
             "expected 'train' or 'valid'".format(dataset)
         )

     # `pipeline` is never reassigned after this point, so the closure
     # safely captures the split-specific composition.
     return Augmentor(
         dict_key="image",
         augment_fn=lambda x: pipeline(image=x)["image"]
     )
Пример #4
0
        dict_transform=transform,
        batch_size=bs,
        num_workers=num_workers,
        shuffle=False,
    )

    loaders["train"] = train_loader
    loaders["valid"] = valid_loader

    return loaders


def _to_unit_float_tensor(x):
    """Convert an array to a float32 tensor divided by 255, with a new leading axis.

    ``x`` is presumably a uint8 image array in [0, 255] of shape (H, W);
    TODO confirm against the dataset. The result would then be a
    (1, H, W) tensor in [0, 1].
    """
    # astype() returns a new array by default (copy=True), so no separate
    # .copy() is needed before handing the buffer to torch.
    return torch.from_numpy(x.astype(np.float32) / 255.0).unsqueeze_(0)


data_transform = transforms.Compose([
    Augmentor(
        dict_key="features",
        augment_fn=_to_unit_float_tensor,
    ),
    # Normalize computes (x - 0.5) / 0.5 per channel, mapping [0, 1] to [-1, 1].
    Augmentor(
        dict_key="features",
        augment_fn=transforms.Normalize((0.5, ), (0.5, )),
    ),
    # Targets get the same scaling as features but are not normalized.
    Augmentor(
        dict_key="targets",
        augment_fn=_to_unit_float_tensor,
    ),
])

loaders = get_loaders(data_transform)