Example #1
    def __call__(self, data):
        d = dict(data)

        sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size,
                                  data[self.keys[0]].shape[1:])
        self.randomize(spatial_size=sp_size)

        if self.rand_2d_elastic.do_transform:
            grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size)
            grid = self.rand_2d_elastic.rand_affine_grid(grid=grid)
            grid = _torch_interp(
                input=grid.unsqueeze(0),
                scale_factor=list(self.rand_2d_elastic.deform_grid.spacing),
                mode=InterpolateMode.BICUBIC.value,
                align_corners=False,
            )
            grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])
        else:
            grid = create_grid(spatial_size=sp_size)

        for idx, key in enumerate(self.keys):
            d[key] = self.rand_2d_elastic.resampler(
                d[key],
                grid,
                mode=self.mode[idx],
                padding_mode=self.padding_mode[idx])
        return d
Example #2
    def __call__(
        self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
    ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
        d = dict(data)

        sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:])
        self.randomize(spatial_size=sp_size)

        if self.rand_2d_elastic.do_transform:
            grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size)
            grid = self.rand_2d_elastic.rand_affine_grid(grid=grid)
            grid = torch.nn.functional.interpolate(  # type: ignore
                recompute_scale_factor=True,
                input=grid.unsqueeze(0),
                scale_factor=ensure_tuple_rep(self.rand_2d_elastic.deform_grid.spacing, 2),
                mode=InterpolateMode.BICUBIC.value,
                align_corners=False,
            )
            grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])
        else:
            grid = create_grid(spatial_size=sp_size)

        for idx, key in enumerate(self.keys):
            d[key] = self.rand_2d_elastic.resampler(
                d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx]
            )
        return d
Example #3
    def __call__(
        self,
        img: Union[np.ndarray, torch.Tensor],
        spatial_size: Optional[Union[Tuple[int, int], int]] = None,
        mode: Optional[Union[GridSampleMode, str]] = None,
        padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
    ) -> Union[np.ndarray, torch.Tensor]:
        sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:])
        self.randomize()
        if self.do_transform:
            grid = self.deform_grid(spatial_size=sp_size)
            grid = torch.nn.functional.interpolate(
                input=grid.unsqueeze(0),
                scale_factor=list(ensure_tuple(self.deform_grid.spacing)),
                mode=InterpolateMode.BICUBIC.value,
                align_corners=False,
            )
            grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])
        else:
            grid = create_grid(spatial_size=sp_size)
        return self.resampler(
            img,
            grid,
            mode=mode or self.mode,
            padding_mode=padding_mode or self.padding_mode,
        )
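
In examples #1-#3, CenterSpatialCrop trims the upsampled deformation grid back to the requested spatial size: interpolating by the deformation spacing can produce a grid slightly larger than sp_size, and the centre crop removes the excess symmetrically. Below is a minimal standalone sketch of that step; the channel-first grid shape (one channel per spatial dimension plus the homogeneous coordinate) is an assumption based on recent MONAI releases, not taken from the snippets above.

import numpy as np
from monai.transforms import CenterSpatialCrop

sp_size = (96, 96)
# Stand-in for an interpolated 2D coordinate grid: channel-first and slightly
# larger than the target spatial size after upsampling.
grid = np.random.rand(3, 100, 100).astype(np.float32)

# The channel dimension is kept; the spatial dimensions are cropped
# symmetrically around the centre down to sp_size.
cropped = CenterSpatialCrop(roi_size=sp_size)(grid)
print(cropped.shape)  # channel dim unchanged, spatial dims cropped to 96 x 96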
Example #4
    def __call__(self, data):
        d = dict(data)
        self.randomize(d[self.keys[0]].shape[1:])  # image shape from the first data key
        for key in self.keys:
            if self.random_center:
                d[key] = d[key][self._slices]
            else:
                cropper = CenterSpatialCrop(self._size)
                d[key] = cropper(d[key])
        return d
Example #5
    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
        d = dict(data)
        self.randomize(d[self.keys[0]].shape[1:])  # image shape from the first data key
        assert self._size is not None
        for key in self.keys:
            if self.random_center:
                d[key] = d[key][self._slices]
            else:
                cropper = CenterSpatialCrop(self._size)
                d[key] = cropper(d[key])
        return d
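
Examples #4 and #5 (and #8 further down) show RandSpatialCropd falling back to a plain centre crop whenever random_center is False. Here is a hedged sketch of the equivalent behaviour with the array-level transforms; the RandSpatialCrop keyword arguments assume a recent MONAI version.

import numpy as np
from monai.transforms import CenterSpatialCrop, RandSpatialCrop

img = np.arange(64, dtype=np.float32).reshape(1, 8, 8)

# With random_center=False and random_size=False, RandSpatialCrop should reduce
# to a deterministic centre crop of roi_size, matching CenterSpatialCrop.
rand_crop = RandSpatialCrop(roi_size=(4, 4), random_center=False, random_size=False)
center_crop = CenterSpatialCrop(roi_size=(4, 4))

same = np.allclose(np.asarray(rand_crop(img)), np.asarray(center_crop(img)))
print(same)  # expected: True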
Example #6
    def inverse(data: dict) -> Dict[Hashable, np.ndarray]:
        if not isinstance(data, Mapping):
            raise RuntimeError("Inverse can only currently be applied on dictionaries.")

        d = dict(data)
        for key in d:
            transforms = None
            if isinstance(d[key], MetaTensor):
                transforms = d[key].applied_operations
            else:
                transform_key = InvertibleTransform.trace_key(key)
                if transform_key in d:
                    transforms = d[transform_key]
            if not transforms or not isinstance(transforms[-1], Dict):
                continue
            if transforms[-1].get(TraceKeys.CLASS_NAME) == PadListDataCollate.__name__:
                xform = transforms.pop()
                cropping = CenterSpatialCrop(xform.get(TraceKeys.ORIG_SIZE, -1))
                with cropping.trace_transform(False):
                    d[key] = cropping(d[key])  # fallback to image size
        return d
Example #7
    def inverse(data: dict) -> Dict[Hashable, np.ndarray]:
        if not isinstance(data, dict):
            raise RuntimeError("Inverse can only currently be applied on dictionaries.")

        d = deepcopy(data)
        for key in d.keys():
            transform_key = str(key) + InverseKeys.KEY_SUFFIX
            if transform_key in d.keys():
                transform = d[transform_key][-1]
                if transform[InverseKeys.CLASS_NAME] == PadListDataCollate.__name__:
                    d[key] = CenterSpatialCrop(transform["orig_size"])(d[key])
                    # remove transform
                    d[transform_key].pop()
        return d
Example #8
    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
        d = dict(data)
        self.randomize(d[self.keys[0]].shape[1:])  # image shape from the first data key
        if self._size is None:
            raise AssertionError
        for key in self.key_iterator(d):
            if self.random_center:
                self.push_transform(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]})  # type: ignore
                d[key] = d[key][self._slices]
            else:
                self.push_transform(d, key)
                cropper = CenterSpatialCrop(self._size)
                d[key] = cropper(d[key])
        return d
Example #9
    def inverse(data: dict) -> Dict[Hashable, np.ndarray]:
        if not isinstance(data, dict):
            raise RuntimeError("Inverse can only currently be applied on dictionaries.")

        d = deepcopy(data)
        for key in d:
            transform_key = InvertibleTransform.trace_key(key)
            if transform_key in d:
                transform = d[transform_key][-1]
                if not isinstance(transform, Dict):
                    continue
                if transform.get(TraceKeys.CLASS_NAME) == PadListDataCollate.__name__:
                    d[key] = CenterSpatialCrop(transform.get("orig_size", -1))(d[key])  # fallback to image size
                    # remove transform
                    d[transform_key].pop()
        return d
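
Examples #6, #7 and #9 all invert PadListDataCollate the same way: look up the recorded original size in the traced transform and centre-crop the padded item back to it. The sketch below shows that pad-then-crop round trip with the array-level transforms; padding via SpatialPad here is an illustrative assumption, and only the crop step mirrors the snippets above.

import numpy as np
from monai.transforms import CenterSpatialCrop, SpatialPad

orig = np.random.rand(1, 50, 60).astype(np.float32)

# Collation pads every item of a batch up to a common spatial size ...
padded = SpatialPad(spatial_size=(64, 64))(orig)

# ... and the inverse centre-crops each item back to its recorded orig_size.
restored = CenterSpatialCrop(roi_size=orig.shape[1:])(padded)
print(np.asarray(restored).shape)  # (1, 50, 60)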
Example #10
    def __init__(
        self, keys: KeysCollection, roi_size: Union[Sequence[int], int], allow_missing_keys: bool = False
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.cropper = CenterSpatialCrop(roi_size)
Example #11
    def __init__(self, keys: KeysCollection, roi_size: Union[Sequence[int], int]) -> None:
        super().__init__(keys)
        self.cropper = CenterSpatialCrop(roi_size)
Example #12
    def __init__(self, keys: KeysCollection, roi_size):
        super().__init__(keys)
        self.cropper = CenterSpatialCrop(roi_size)
Example #13
 def __init__(self, keys, roi_size):
     super().__init__(keys)
     self.cropper = CenterSpatialCrop(roi_size)