def set_transform(obj, dataset_opt):
    """Instantiate the transforms described in ``dataset_opt`` and attach them to ``obj``.

    Every key of ``dataset_opt`` whose name contains ``"transform"`` is passed
    through ``instantiate_transforms`` and the result is stored on ``obj``
    under the key name with ``"transforms"`` replaced by ``"transform"``.
    A key whose instantiation raises is logged and skipped, leaving the
    corresponding attribute at its reset value.

    Finally ``obj.inference_transform`` is set to a ``Compose`` of the
    exploded ``pre_transform`` followed by the exploded ``test_transform``,
    or to ``None`` when that combined list is empty.
    """
    # Start from a clean slate so attributes exist even if no key matches.
    for slot in (
        "pre_transform",
        "test_transform",
        "train_transform",
        "val_transform",
        "inference_transform",
    ):
        setattr(obj, slot, None)

    for option_key in dataset_opt.keys():
        if "transform" not in option_key:
            continue

        attr_name = option_key.replace("transforms", "transform")
        try:
            built = instantiate_transforms(getattr(dataset_opt, option_key))
        except Exception:
            # Best effort: report the failing entry and move on to the next key.
            log.exception("Error trying to create {}, {}".format(
                attr_name, getattr(dataset_opt, option_key)))
        else:
            setattr(obj, attr_name, built)

    # Inference applies the pre-transform steps first, then the test-time ones.
    flattened = explode_transform(obj.pre_transform)
    flattened += explode_transform(obj.test_transform)
    if len(flattened) > 0:
        obj.inference_transform = Compose(flattened)
    else:
        obj.inference_transform = None
def test_lottery_transform_from_yaml(self):
    """Build a LotteryTransform from a YAML description and check its options.

    The YAML lists two entries under ``transform_options``. After applying
    the instantiated transform once to a random point cloud, the test checks
    that ``tr.random_transforms.transforms`` holds a ``GridSampling3D``
    followed by a ``T.Center``.
    """
    # The YAML has to keep its block structure (one key per line, nested by
    # indentation); flattened onto a single line it cannot be parsed into
    # the nested list of transform descriptions.
    string = """
    - transform: LotteryTransform
      params:
        transform_options:
          - transform: GridSampling3D
            params:
              size: 0.1
          - transform: Center
    """
    conf = OmegaConf.create(string)

    # Random sample carrying positions plus two extra per-point attributes.
    pos = torch.randn(10000, 3)
    x = torch.randn(10000, 6)
    dummy = torch.randn(10000, 6)
    data = Data(pos=pos, x=x, dummy=dummy)

    # The YAML holds a single top-level entry, so take the first transform.
    tr = instantiate_transforms(conf).transforms[0]
    tr(data)

    self.assertIsInstance(tr.random_transforms.transforms[0], GridSampling3D)
    self.assertIsInstance(tr.random_transforms.transforms[1], T.Center)
def test_random_param_transform_with_sphere_dropout(self):
    """Build a RandomParamTransform around RandomSphereDropout from YAML.

    This is a smoke test: it only checks that the transform can be
    instantiated from the configuration and applied to a random point cloud
    without raising.
    """
    # The YAML has to keep its block structure; flattened onto one line it
    # cannot be parsed into nested mappings.
    # NOTE(review): the nesting below (``radius`` and ``num_sphere`` as
    # children of ``transform_params``, each with min/max/type) is
    # reconstructed from the key order -- confirm against
    # RandomParamTransform's expected schema.
    string = """
    - transform: RandomParamTransform
      params:
        transform_name: RandomSphereDropout
        transform_params:
          radius:
            min: 1
            max: 2
            type: "float"
          num_sphere:
            min: 1
            max: 5
            type: "int"
    """
    conf = OmegaConf.create(string)

    # Random sample carrying positions plus two extra per-point attributes.
    pos = torch.randn(10000, 3)
    x = torch.randn(10000, 6)
    dummy = torch.randn(10000, 6)
    data = Data(pos=pos, x=x, dummy=dummy)

    # The YAML holds a single top-level entry, so take the first transform.
    tr = instantiate_transforms(conf).transforms[0]
    tr(data)
def test_random_param_transform_with_grid_sampling(self):
    """Build a RandomParamTransform around GridSampling3D from YAML.

    This is a smoke test: it only checks that the transform can be
    instantiated from the configuration and applied to a random point cloud
    without raising.
    """
    # The YAML has to keep its block structure; flattened onto one line it
    # cannot be parsed into nested mappings.
    # NOTE(review): the nesting below (``size`` with min/max/type and
    # ``mode`` with a fixed ``value``, both children of
    # ``transform_params``) is reconstructed from the key order -- confirm
    # against RandomParamTransform's expected schema.
    string = """
    - transform: RandomParamTransform
      params:
        transform_name: GridSampling3D
        transform_params:
          size:
            min: 0.1
            max: 0.3
            type: "float"
          mode:
            value: "last"
    """
    conf = OmegaConf.create(string)

    # Random sample carrying positions plus two extra per-point attributes.
    pos = torch.randn(10000, 3)
    x = torch.randn(10000, 6)
    dummy = torch.randn(10000, 6)
    data = Data(pos=pos, x=x, dummy=dummy)

    # The YAML holds a single top-level entry, so take the first transform.
    tr = instantiate_transforms(conf).transforms[0]
    tr(data)
def test_InstantiateTransforms(self):
    """Instantiate a two-entry transform list and check the resulting types.

    The configuration names ``GridSampling`` (with ``size`` 0.1) and
    ``Center``; the composed result must expose instances of
    ``GridSampling`` and ``T.Center`` in that order.
    """
    grid_entry = {"transform": "GridSampling", "params": {"size": 0.1}}
    center_entry = {"transform": "Center"}
    conf = ListConfig([grid_entry, center_entry])

    composed = instantiate_transforms(conf)

    self.assertIsInstance(composed.transforms[0], GridSampling)
    self.assertIsInstance(composed.transforms[1], T.Center)
def test_compose_transform(self):
    """Build a ComposeTransform from YAML and apply it to a random sample.

    This is a smoke test: the YAML lists ``GridSampling3D`` (``size`` 0.1)
    and ``RandomNoise`` (``sigma`` 0.05) under ``transform_options``, and the
    test only checks that instantiating and calling the result does not
    raise.
    """
    # The YAML has to keep its block structure (one key per line, nested by
    # indentation); flattened onto a single line it cannot be parsed into
    # the nested list of transform descriptions.
    string = """
    - transform: ComposeTransform
      params:
        transform_options:
          - transform: GridSampling3D
            params:
              size: 0.1
          - transform: RandomNoise
            params:
              sigma: 0.05
    """
    conf = OmegaConf.create(string)

    # Random sample carrying positions plus two extra per-point attributes.
    pos = torch.randn(10000, 3)
    x = torch.randn(10000, 6)
    dummy = torch.randn(10000, 6)
    data = Data(pos=pos, x=x, dummy=dummy)

    # Unlike the single-transform tests, the whole instantiated pipeline is
    # applied here rather than its first element.
    tr = instantiate_transforms(conf)
    tr(data)