Example #1
0
    def __call__(
        self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
    ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
        """Warp every image named in ``self.keys`` with one shared random field.

        A single sampling grid is built per call (after ``self.randomize``)
        and reused for all keys, so every keyed image receives the same
        deformation.

        Args:
            data: mapping of key -> image. Only the entries named in
                ``self.keys`` are resampled; all other entries are copied
                through unchanged.

        Returns:
            A new ``dict`` (shallow copy of ``data``) whose ``self.keys``
            entries hold the resampled images.
        """
        elastic = self.rand_2d_elastic
        out = dict(data)

        # Output spatial size: the configured size, with the first key's
        # spatial shape (shape[1:], presumably channel-first -- TODO confirm)
        # as the fallback handed to ``fall_back_tuple``.
        reference_shape = data[self.keys[0]].shape[1:]
        target_size = fall_back_tuple(elastic.spatial_size, reference_shape)
        self.randomize(spatial_size=target_size)

        if not elastic.do_transform:
            # No deformation this call: sample with the grid from create_grid
            # (presumably an identity grid).
            grid = create_grid(spatial_size=target_size)
        else:
            control_grid = elastic.deform_grid(spatial_size=target_size)
            control_grid = elastic.rand_affine_grid(grid=control_grid)
            # Upsample the coarse control grid by the control-point spacing
            # (presumably repeated to one factor per spatial dim), then crop
            # the result back to the target size.
            upsample_factors = ensure_tuple_rep(elastic.deform_grid.spacing, 2)
            batched = control_grid.unsqueeze(0)  # interpolation expects a batch axis
            dense_grid = _torch_interp(
                input=batched,
                scale_factor=upsample_factors,
                mode=InterpolateMode.BICUBIC.value,
                align_corners=False,
            )
            grid = CenterSpatialCrop(roi_size=target_size)(dense_grid[0])

        # mode / padding_mode are matched to keys by position.
        for position, name in enumerate(self.keys):
            out[name] = elastic.resampler(
                out[name],
                grid,
                mode=self.mode[position],
                padding_mode=self.padding_mode[position],
            )
        return out
Example #2
0
    def __call__(self, data):
        """Resample each ``self.keys`` image with one shared random elastic grid.

        Args:
            data: mapping of key -> image. Entries named in ``self.keys`` are
                resampled; every other entry is copied through unchanged.

        Returns:
            A new ``dict`` (shallow copy of ``data``) in which the
            ``self.keys`` entries are replaced by the resampled images.
        """
        elastic = self.rand_2d_elastic
        result = dict(data)

        # Configured output size, with the first key's spatial shape
        # (shape[1:], presumably channel-first -- TODO confirm) as fallback.
        first_shape = data[self.keys[0]].shape[1:]
        out_size = fall_back_tuple(elastic.spatial_size, first_shape)
        self.randomize(spatial_size=out_size)

        if not elastic.do_transform:
            # No deformation drawn this call; sample with create_grid's grid
            # (presumably an identity grid).
            grid = create_grid(spatial_size=out_size)
        else:
            coarse = elastic.deform_grid(spatial_size=out_size)
            coarse = elastic.rand_affine_grid(grid=coarse)
            # Upsample the control grid by the control-point spacing, then
            # crop back to the requested output size.
            factors = list(elastic.deform_grid.spacing)
            upsampled = _torch_interp(
                input=coarse[None],  # prepend a batch axis
                scale_factor=factors,
                mode=InterpolateMode.BICUBIC.value,
                align_corners=False,
            )
            grid = CenterSpatialCrop(roi_size=out_size)(upsampled[0])

        # Per-key interpolation / padding settings are matched by position.
        for i, k in enumerate(self.keys):
            result[k] = elastic.resampler(
                result[k],
                grid,
                mode=self.mode[i],
                padding_mode=self.padding_mode[i],
            )
        return result
Example #3
0
    def __call__(self, data):
        """Apply one shared random 2D elastic deformation to all ``self.keys`` images.

        Args:
            data: mapping of key -> image. Entries named in ``self.keys`` are
                resampled; every other entry is copied through unchanged.

        Returns:
            A new ``dict`` (shallow copy of ``data``) whose ``self.keys``
            entries are replaced by the resampled images.
        """
        d = dict(data)
        spatial_size = self.rand_2d_elastic.spatial_size
        if spatial_size is None:
            # Without a configured output size, ``None`` would otherwise be
            # passed straight into randomize / deform_grid / _torch_interp /
            # create_grid. Fall back to the first key's spatial shape instead.
            # Assumes channel-first images (shape[0] is the channel axis) --
            # TODO confirm.
            spatial_size = tuple(d[self.keys[0]].shape[1:])
        self.randomize(spatial_size)

        if self.rand_2d_elastic.do_transform:
            grid = self.rand_2d_elastic.deform_grid(spatial_size)
            grid = self.rand_2d_elastic.rand_affine_grid(grid=grid)
            # Resize the batched control grid directly to the output size with
            # bicubic interpolation, then drop the batch axis again.
            grid = _torch_interp(input=grid[None],
                                 size=spatial_size,
                                 mode="bicubic",
                                 align_corners=False)[0]
        else:
            grid = create_grid(spatial_size)

        # The same grid is reused for every key; padding_mode / mode are
        # matched to keys by position.
        for idx, key in enumerate(self.keys):
            d[key] = self.rand_2d_elastic.resampler(
                d[key],
                grid,
                padding_mode=self.padding_mode[idx],
                mode=self.mode[idx])
        return d