Example #1
0
    def set_transform(obj, dataset_opt):
        """Build the transforms described in ``dataset_opt`` and attach them to ``obj``.

        First, ``pre_transform``, ``test_transform``, ``train_transform``,
        ``val_transform`` and ``inference_transform`` on ``obj`` are all reset to
        ``None``. Then every option key whose name contains ``"transform"`` is
        passed to ``instantiate_transforms``. The result is stored on ``obj``
        under the key name with ``"transforms"`` rewritten to ``"transform"``
        (e.g. ``train_transforms`` -> ``train_transform``). If building a key
        fails, the error is logged with ``log.exception`` and that key is
        skipped, leaving the attribute untouched.

        Finally, ``inference_transform`` is rebuilt as a ``Compose`` of the
        exploded pre-transform followed by the exploded test transform. It is
        set to ``None`` when that combined list is empty.
        """
        for reset_name in ("pre_transform", "test_transform", "train_transform",
                           "val_transform", "inference_transform"):
            setattr(obj, reset_name, None)

        for key_name in dataset_opt.keys():
            if "transform" not in key_name:
                continue
            attr_name = key_name.replace("transforms", "transform")
            try:
                built = instantiate_transforms(getattr(dataset_opt, key_name))
            except Exception:
                # Best-effort: a broken transform config is reported, not fatal.
                log.exception("Error trying to create {}, {}".format(
                    attr_name, getattr(dataset_opt, key_name)))
            else:
                setattr(obj, attr_name, built)

        # Inference pipeline = pre-processing steps followed by test-time steps.
        pipeline = explode_transform(obj.pre_transform)
        pipeline += explode_transform(obj.test_transform)
        if len(pipeline) > 0:
            obj.inference_transform = Compose(pipeline)
        else:
            obj.inference_transform = None
Example #2
0
    def test_lottery_transform_from_yaml(self):
        """Check a LotteryTransform whose candidate transforms come from YAML.

        The YAML lists GridSampling3D (``size: 0.1``) and Center as
        ``transform_options``. After the transform is built and applied once to
        a random point cloud, its ``random_transforms`` must hold those two
        transforms in declaration order.
        """
        yaml_text = """

        - transform: LotteryTransform
          params:
            transform_options:
              - transform: GridSampling3D
                params:
                  size: 0.1
              - transform: Center
        """
        conf = OmegaConf.create(yaml_text)
        num_points = 10000
        # Keyword arguments are evaluated left to right, so the RNG draws
        # happen in the same order as separate assignments would.
        data = Data(
            pos=torch.randn(num_points, 3),
            x=torch.randn(num_points, 6),
            dummy=torch.randn(num_points, 6),
        )

        lottery = instantiate_transforms(conf).transforms[0]
        lottery(data)
        candidates = lottery.random_transforms.transforms
        self.assertIsInstance(candidates[0], GridSampling3D)
        self.assertIsInstance(candidates[1], T.Center)
    def test_random_param_transform_with_sphere_dropout(self):
        """Smoke-test a RandomParamTransform that wraps RandomSphereDropout.

        The YAML configures ``radius`` with min 1 / max 2 (type float) and
        ``num_sphere`` with min 1 / max 5 (type int). The test only checks
        that the transform can be built and applied to a random point cloud
        without raising.
        """
        yaml_text = """

        - transform: RandomParamTransform
          params:
            transform_name: RandomSphereDropout
            transform_params:
                radius:
                    min: 1
                    max: 2
                    type: "float"
                num_sphere:
                    min: 1
                    max: 5
                    type: "int"

        """
        conf = OmegaConf.create(yaml_text)

        num_points = 10000
        data = Data(
            pos=torch.randn(num_points, 3),
            x=torch.randn(num_points, 6),
            dummy=torch.randn(num_points, 6),
        )

        param_transform = instantiate_transforms(conf).transforms[0]
        param_transform(data)
    def test_random_param_transform_with_grid_sampling(self):
        """Smoke-test a RandomParamTransform that wraps GridSampling3D.

        The YAML configures ``size`` with min 0.1 / max 0.3 (type float) and
        gives ``mode`` as ``value: "last"``. The test only checks that the
        transform can be built and applied to a random point cloud without
        raising.
        """
        yaml_text = """

        - transform: RandomParamTransform
          params:
            transform_name: GridSampling3D
            transform_params:
                size:
                    min: 0.1
                    max: 0.3
                    type: "float"
                mode:
                    value: "last"

        """
        conf = OmegaConf.create(yaml_text)
        num_points = 10000
        data = Data(
            pos=torch.randn(num_points, 3),
            x=torch.randn(num_points, 6),
            dummy=torch.randn(num_points, 6),
        )

        param_transform = instantiate_transforms(conf).transforms[0]
        param_transform(data)
Example #5
0
 def test_InstantiateTransforms(self):
     """Instantiate transforms from a ``ListConfig`` built out of plain dicts.

     The first spec (GridSampling with ``size`` 0.1) and the second (Center,
     no params) must appear in the returned object's ``transforms`` in the
     same order.
     """
     grid_spec = {"transform": "GridSampling", "params": {"size": 0.1}}
     center_spec = {"transform": "Center"}
     composed = instantiate_transforms(ListConfig([grid_spec, center_spec]))
     built = composed.transforms
     self.assertIsInstance(built[0], GridSampling)
     self.assertIsInstance(built[1], T.Center)
    def test_compose_transform(self):
        """Smoke-test a ComposeTransform declared in YAML.

        The composed options are GridSampling3D (``size: 0.1``) followed by
        RandomNoise (``sigma: 0.05``). The test only checks that the pipeline
        can be built and run on a random point cloud without raising.
        """
        yaml_text = """
        - transform: ComposeTransform
          params:
            transform_options:
              - transform: GridSampling3D
                params:
                  size: 0.1
              - transform: RandomNoise
                params:
                  sigma: 0.05
        """

        conf = OmegaConf.create(yaml_text)

        num_points = 10000
        data = Data(
            pos=torch.randn(num_points, 3),
            x=torch.randn(num_points, 6),
            dummy=torch.randn(num_points, 6),
        )
        pipeline = instantiate_transforms(conf)
        pipeline(data)