def __init__(
        self,
        keys: KeysCollection,
        range_x: Union[Tuple[float, float], float] = 0.0,
        range_y: Union[Tuple[float, float], float] = 0.0,
        range_z: Union[Tuple[float, float], float] = 0.0,
        prob: float = 0.1,
        keep_size: bool = True,
        mode: GridSampleModeSequence = GridSampleMode.BILINEAR,
        padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER,
        align_corners: Union[Sequence[bool], bool] = False,
    ):
        super().__init__(keys)
        self.range_x = ensure_tuple(range_x)
        if len(self.range_x) == 1:
            self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]]))
        self.range_y = ensure_tuple(range_y)
        if len(self.range_y) == 1:
            self.range_y = tuple(sorted([-self.range_y[0], self.range_y[0]]))
        self.range_z = ensure_tuple(range_z)
        if len(self.range_z) == 1:
            self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]]))

        self.prob = prob
        self.keep_size = keep_size
        self.mode = ensure_tuple_rep(mode, len(self.keys))
        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))
        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))

        self._do_transform = False
        self.x = 0.0
        self.y = 0.0
        self.z = 0.0
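A scalar range is expanded into a symmetric (min, max) pair by the lines above. A minimal standalone sketch of that normalization, using a stand-in for ponai's ensure_tuple so it runs on its own:

    # stand-in for ponai's ensure_tuple, included so the sketch is self-contained
    def ensure_tuple(vals):
        return tuple(vals) if isinstance(vals, (list, tuple)) else (vals,)

    range_x = ensure_tuple(0.4)
    if len(range_x) == 1:
        range_x = tuple(sorted([-range_x[0], range_x[0]]))
    print(range_x)  # (-0.4, 0.4): a single float becomes a symmetric rotation range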
Example #2
    def randomize(self) -> None:  # type: ignore # see issue #495
        self._do_transform = self.R.random_sample() < self.prob
        if isinstance(self.min_zoom, Iterable):
            _min_zoom = ensure_tuple(self.min_zoom)
            _max_zoom = ensure_tuple(self.max_zoom)
            self._zoom = [self.R.uniform(l, h) for l, h in zip(_min_zoom, _max_zoom)]
        else:
            # to keep the spatial shape ratio, use the same random zoom factor for all dims
            self._zoom = self.R.uniform(self.min_zoom, self.max_zoom)
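The two branches above yield either one zoom factor per spatial dim or a single shared factor. A minimal sketch with hypothetical zoom ranges:

    import numpy as np

    R = np.random.RandomState(0)
    min_zoom, max_zoom = (0.8, 0.9), (1.1, 1.2)  # hypothetical per-dim ranges
    per_dim = [R.uniform(l, h) for l, h in zip(min_zoom, max_zoom)]  # independent factor per dim
    shared = R.uniform(0.8, 1.2)  # a single factor for all dims preserves the aspect ratio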
Example #3
    def __init__(self, keys: KeysCollection):
        self.keys: Tuple[Any, ...] = ensure_tuple(keys)
        if not self.keys:
            raise ValueError("keys unspecified")
        for key in self.keys:
            if not isinstance(key, Hashable):
                raise ValueError(f"keys should be a hashable or a sequence of hashables, got {type(key)}")
Example #4
    def __init__(self, keys: KeysCollection, times: int, names: KeysCollection):
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`ponai.transforms.compose.MapTransform`
            times: expected number of copies; for example, if keys is "img" and times is 3,
                it will add 3 copies of the "img" data to the dictionary.
            names: the names corresponding to the newly copied data,
                the length should match `len(keys) x times`. For example, if keys is ["img", "seg"]
                and times is 2, names can be: ["img_1", "seg_1", "img_2", "seg_2"].

        Raises:
            ValueError: times must be greater than 0.
            ValueError: length of names does not match `len(keys) x times`.

        """
        super().__init__(keys)
        if times < 1:
            raise ValueError("times must be greater than 0.")
        self.times = times
        names = ensure_tuple(names)
        if len(names) != (len(self.keys) * times):
            raise ValueError("length of names does not match `len(keys) x times`.")
        self.names = names
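A hedged usage sketch; `CopyItemsd` is an assumed name for the transform whose constructor is shown above. With keys=["img", "seg"] and times=2, `names` must hold exactly len(keys) x times = 4 entries:

    copier = CopyItemsd(keys=["img", "seg"], times=2,
                        names=["img_1", "seg_1", "img_2", "seg_2"])  # 2 keys x 2 copies = 4 names
    # times=0 or a names list of the wrong length raises ValueError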
Example #5
def generate_spatial_bounding_box(
    img: np.ndarray,
    select_fn: Callable = lambda x: x > 0,
    channel_indexes: Optional[IndexSelection] = None,
    margin: int = 0,
):
    """
    Generate the spatial bounding box of the foreground in the image, as start-end positions.
    Users can define an arbitrary function to select the expected foreground from the whole image
    or from the specified channels, and a margin can be added to every dim of the bounding box.

    Args:
        img (ndarray): source image to generate the bounding box from.
        select_fn: function to select the expected foreground, default is to select values > 0.
        channel_indexes: if defined, select foreground only on the specified channels
            of image. if None, select foreground on the whole image.
        margin: add margin to all dims of the bounding box.
    """
    assert isinstance(margin, int), "margin must be int type."
    data = img[[*(ensure_tuple(channel_indexes))]] if channel_indexes is not None else img
    data = np.any(select_fn(data), axis=0)
    nonzero_idx = np.nonzero(data)

    box_start = list()
    box_end = list()
    for i in range(data.ndim):
        assert len(nonzero_idx[i]) > 0, f"did not find nonzero index at spatial dim {i}"
        box_start.append(max(0, np.min(nonzero_idx[i]) - margin))
        box_end.append(min(data.shape[i], np.max(nonzero_idx[i]) + margin + 1))
    return box_start, box_end
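A worked example of the function above on a tiny channel-first image; the returned box is half-open per spatial dim, i.e. [start, end):

    import numpy as np

    img = np.zeros((1, 4, 4))  # channel-first: 1 channel, 4x4 spatial
    img[0, 1:3, 2:4] = 1.0     # foreground occupies rows 1-2, cols 2-3
    start, end = generate_spatial_bounding_box(img)
    print(start, end)          # [1, 2] [3, 4]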
Example #6
    def randomize(self, img_size):
        self._size = fall_back_tuple(self.roi_size, img_size)
        if self.random_size:
            self._size = [self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size))]
        if self.random_center:
            valid_size = get_valid_patch_size(img_size, self._size)
            self._slices = ensure_tuple(slice(None)) + get_random_patch(img_size, valid_size, self.R)
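The last line above builds an index whose leading slice(None) keeps the channel dim intact while the random patch slices crop only the spatial dims. A standalone sketch of that indexing:

    import numpy as np

    img = np.ones((3, 8, 8))                              # channel-first
    slices = (slice(None),) + (slice(2, 6), slice(1, 5))  # keep channels, crop an assumed 4x4 patch
    print(img[slices].shape)                              # (3, 4, 4)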
Example #7
    def __call__(self, img, mode: Optional[Union[NumpyPadMode, str]] = None):
        """
        Args:
            img: data to be transformed, assuming `img` is channel-first and
                padding doesn't apply to the channel dim.
            mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
                ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
                One of the listed string values or a user supplied function. Defaults to ``self.mode``.
                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html

        Raises:
            ValueError: spatial_border must be an int and cannot be less than 0.
            ValueError: unsupported length of spatial_border definition.
        """
        spatial_shape = img.shape[1:]
        spatial_border = ensure_tuple(self.spatial_border)
        for b in spatial_border:
            if not isinstance(b, int) or b < 0:
                raise ValueError("spatial_border must be an int and cannot be less than 0.")

        if len(spatial_border) == 1:
            data_pad_width = [(spatial_border[0], spatial_border[0]) for _ in range(len(spatial_shape))]
        elif len(spatial_border) == len(spatial_shape):
            data_pad_width = [(spatial_border[i], spatial_border[i]) for i in range(len(spatial_shape))]
        elif len(spatial_border) == len(spatial_shape) * 2:
            data_pad_width = [(spatial_border[2 * i], spatial_border[2 * i + 1]) for i in range(len(spatial_shape))]
        else:
            raise ValueError("unsupported length of spatial_border definition.")

        return np.pad(
            img, [(0, 0)] + data_pad_width, mode=self.mode.value if mode is None else NumpyPadMode(mode).value
        )
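The three accepted forms of `spatial_border` for a 2D channel-first image, replicated directly with np.pad (a sketch of the branch logic above, not the transform itself):

    import numpy as np

    img = np.ones((1, 4, 4))  # the channel dim is never padded
    np.pad(img, [(0, 0), (2, 2), (2, 2)]).shape  # border=(2,): same width on every side -> (1, 8, 8)
    np.pad(img, [(0, 0), (1, 1), (2, 2)]).shape  # border=(1, 2): symmetric per dim -> (1, 6, 8)
    np.pad(img, [(0, 0), (1, 2), (3, 4)]).shape  # border=(1, 2, 3, 4): per side -> (1, 7, 11)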
Example #8
def create_translate(spatial_dims: int, shift):
    """
    create a translation matrix

    Args:
        spatial_dims: spatial rank
        shift (floats): translate factors; dims beyond len(shift) default to 0.
    """
    shift = ensure_tuple(shift)
    affine = np.eye(spatial_dims + 1)
    for i, a in enumerate(shift[:spatial_dims]):
        affine[i, spatial_dims] = a
    return affine
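A worked call: in 2D the shift lands in the last column of the homogeneous matrix.

    create_translate(2, (3.0, -1.0))
    # array([[ 1.,  0.,  3.],
    #        [ 0.,  1., -1.],
    #        [ 0.,  0.,  1.]])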
Example #9
    def __init__(
        self, select_fn: Callable = lambda x: x > 0, channel_indexes: Optional[IndexSelection] = None, margin: int = 0
    ):
        """
        Args:
            select_fn: function to select expected foreground, default is to select values > 0.
            channel_indexes: if defined, select foreground only on the specified channels
                of image. if None, select foreground on the whole image.
            margin: add margin to all dims of the bounding box.
        """
        self.select_fn = select_fn
        self.channel_indexes = ensure_tuple(channel_indexes) if channel_indexes is not None else None
        self.margin = margin
Example #10
    def __init__(
        self,
        spatial_size=None,
        normalized: bool = False,
        mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
        padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.ZEROS,
        align_corners: bool = False,
        reverse_indexing: bool = True,
    ):
        """
        Apply affine transformations with a batch of affine matrices.

        When `normalized=False` and `reverse_indexing=True`,
        it does the commonly used resampling in the 'pull' direction
        following the ``scipy.ndimage.affine_transform`` convention.
        In this case `theta` is equivalent to the (ndim+1, ndim+1) input ``matrix`` of ``scipy.ndimage.affine_transform``
        and operates on homogeneous coordinates.
        See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html

        When `normalized=True` and `reverse_indexing=False`,
        it applies `theta` to the normalized coordinates (coords. in the range of [-1, 1]) directly.
        This is often used with `align_corners=False` to achieve resolution-agnostic resampling,
        and is thus useful as part of trainable modules such as spatial transformer networks.
        See also: https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

        Args:
            spatial_size (list or tuple of int): output spatial shape, the full output shape will be
                `[N, C, *spatial_size]` where N and C are inferred from the `src` input of `self.forward`.
            normalized: indicating whether the provided affine matrix `theta` is defined
                for the normalized coordinates. If `normalized=False`, `theta` will be converted
                to operate on normalized coordinates, as PyTorch's `affine_grid` works with
                normalized coordinates.
            mode: {``"bilinear"``, ``"nearest"``}
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"zeros"``.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            align_corners: see also https://pytorch.org/docs/stable/nn.functional.html#grid-sample.
            reverse_indexing: whether to reverse the spatial indexing of image and coordinates.
                Set to `False` if `theta` follows PyTorch's default "D, H, W" convention.
                Set to `True` if `theta` follows `scipy.ndimage`'s default "i, j, k" convention.
        """
        super().__init__()
        self.spatial_size = ensure_tuple(spatial_size) if spatial_size is not None else None
        self.normalized = normalized
        self.mode: GridSampleMode = GridSampleMode(mode)
        self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode)
        self.align_corners = align_corners
        self.reverse_indexing = reverse_indexing
Example #11
    def __call__(self, filename):
        """
        Args:
            filename (str, list, tuple, file): file path, file-like object, or a list of files.
        """
        filename = ensure_tuple(filename)
        img_array = list()
        compatible_meta = dict()
        for name in filename:
            img = nib.load(name)
            img = correct_nifti_header_if_necessary(img)
            header = dict(img.header)
            header["filename_or_obj"] = name
            header["affine"] = img.affine
            header["original_affine"] = img.affine.copy()
            header["as_closest_canonical"] = self.as_closest_canonical
            ndim = img.header["dim"][0]
            spatial_rank = min(ndim, 3)
            header["spatial_shape"] = img.header["dim"][1:spatial_rank + 1]

            if self.as_closest_canonical:
                img = nib.as_closest_canonical(img)
                header["affine"] = img.affine

            img_array.append(np.array(img.get_fdata(dtype=self.dtype)))
            img.uncache()

            if self.image_only:
                continue

            if not compatible_meta:
                for meta_key in header:
                    meta_datum = header[meta_key]
                    # pytype: disable=attribute-error
                    if (type(meta_datum).__name__ == "ndarray"
                            and np_str_obj_array_pattern.search(
                                meta_datum.dtype.str) is not None):
                        continue
                    # pytype: enable=attribute-error
                    compatible_meta[meta_key] = meta_datum
            else:
                assert np.allclose(
                    header["affine"], compatible_meta["affine"]
                ), "affine data of all images should be same."

        img_array = np.stack(img_array,
                             axis=0) if len(img_array) > 1 else img_array[0]
        if self.image_only:
            return img_array
        return img_array, compatible_meta
Example #12
    def __init__(self, applied_labels, independent: bool = True, connectivity: Optional[int] = None):
        """
        Args:
            applied_labels (int, list or tuple of int): labels to apply the connected component analysis on.
                If the data has only one channel, pixels whose values are not in this list will remain unchanged.
                If the data is in one-hot format, this is used to determine which channels to apply to.
            independent (bool): whether to consider several labels as a whole or independently, default is `True`.
                An example use case: if label 1 is liver and label 2 is liver tumor,
                "independent" should be set to False.
            connectivity: maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
                Accepted values range from 1 to input.ndim. If ``None``, a full
                connectivity of ``input.ndim`` is used.
        """
        super().__init__()
        self.applied_labels = ensure_tuple(applied_labels)
        self.independent = independent
        self.connectivity = connectivity
Example #13
def create_rotate(spatial_dims: int, radians):
    """
    create a 2D or 3D rotation matrix

    Args:
        spatial_dims: {``2``, ``3``} spatial rank
        radians (float or a sequence of floats): rotation radians
            when spatial_dims == 3, the `radians` sequence corresponds to
            rotation in the 1st, 2nd, and 3rd dim respectively.

    Raises:
        ValueError: create_rotate got spatial_dims={spatial_dims}, radians={radians}.

    """
    radians = ensure_tuple(radians)
    if spatial_dims == 2:
        if len(radians) >= 1:
            sin_, cos_ = np.sin(radians[0]), np.cos(radians[0])
            return np.array([[cos_, -sin_, 0.0], [sin_, cos_, 0.0], [0.0, 0.0, 1.0]])

    if spatial_dims == 3:
        affine = None
        if len(radians) >= 1:
            sin_, cos_ = np.sin(radians[0]), np.cos(radians[0])
            affine = np.array(
                [[1.0, 0.0, 0.0, 0.0], [0.0, cos_, -sin_, 0.0], [0.0, sin_, cos_, 0.0], [0.0, 0.0, 0.0, 1.0]]
            )
        if len(radians) >= 2:
            sin_, cos_ = np.sin(radians[1]), np.cos(radians[1])
            affine = affine @ np.array(
                [[cos_, 0.0, sin_, 0.0], [0.0, 1.0, 0.0, 0.0], [-sin_, 0.0, cos_, 0.0], [0.0, 0.0, 0.0, 1.0]]
            )
        if len(radians) >= 3:
            sin_, cos_ = np.sin(radians[2]), np.cos(radians[2])
            affine = affine @ np.array(
                [[cos_, -sin_, 0.0, 0.0], [sin_, cos_, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
            )
        return affine

    raise ValueError(f"create_rotate got spatial_dims={spatial_dims}, radians={radians}.")
Example #14
    def __init__(
        self,
        keys: KeysCollection,
        source_key: str,
        select_fn: Callable = lambda x: x > 0,
        channel_indexes: Optional[IndexSelection] = None,
        margin: int = 0,
    ):
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`ponai.transforms.compose.MapTransform`
            source_key: data source to generate the bounding box of foreground, can be image or label, etc.
            select_fn: function to select expected foreground, default is to select values > 0.
            channel_indexes: if defined, select foreground only on the specified channels
                of image. if None, select foreground on the whole image.
            margin: add margin to all dims of the bounding box.
        """
        super().__init__(keys)
        self.source_key = source_key
        self.select_fn = select_fn
        self.channel_indexes = ensure_tuple(channel_indexes) if channel_indexes is not None else None
        self.margin = margin
Example #15
    def __call__(self, filename):
        """
        Args:
            filename (str, list, tuple, file): file path, file-like object, or a list of files.
        """
        filename = ensure_tuple(filename)
        img_array = list()
        compatible_meta = None
        for name in filename:
            img = Image.open(name)
            data = np.asarray(img)
            if self.dtype:
                data = data.astype(self.dtype)
            img_array.append(data)
            meta = dict()
            meta["filename_or_obj"] = name
            meta["spatial_shape"] = data.shape[:2]
            meta["format"] = img.format
            meta["mode"] = img.mode
            meta["width"] = img.width
            meta["height"] = img.height
            meta["info"] = img.info

            if self.image_only:
                continue

            if not compatible_meta:
                compatible_meta = meta
            else:
                assert np.allclose(
                    meta["spatial_shape"], compatible_meta["spatial_shape"]
                ), "all the images in the list should have same spatial shape."

        img_array = np.stack(img_array,
                             axis=0) if len(img_array) > 1 else img_array[0]
        return img_array if self.image_only else (img_array, compatible_meta)
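The stacking rule shared by both readers above, shown standalone: multiple images gain a new leading axis, while a single image is returned without the extra dim.

    import numpy as np

    arrays = [np.zeros((4, 4)), np.zeros((4, 4))]
    out = np.stack(arrays, axis=0) if len(arrays) > 1 else arrays[0]
    print(out.shape)  # (2, 4, 4); a single-element list would give (4, 4)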
Example #16
    def __init__(
        self,
        device,
        max_epochs: int,
        amp: bool,
        data_loader,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        post_transform=None,
        key_metric=None,
        additional_metrics=None,
        handlers=None,
    ):
        # pytype: disable=invalid-directive
        # pytype: disable=wrong-arg-count
        super().__init__(iteration_update
                         if iteration_update is not None else self._iteration)
        # pytype: enable=invalid-directive
        # pytype: enable=wrong-arg-count
        # FIXME:
        if amp:
            self.logger.info(
                "Will add AMP support when PyTorch v1.6 released.")
        if not isinstance(device, torch.device):
            raise ValueError("device must be PyTorch device object.")
        if not isinstance(data_loader,
                          torch.utils.data.DataLoader):  # type: ignore
            raise ValueError("data_loader must be PyTorch DataLoader.")

        # set all sharable data for the workflow based on Ignite engine.state
        self.state = State(
            seed=0,
            iteration=0,
            epoch=0,
            max_epochs=max_epochs,
            epoch_length=-1,
            output=None,
            batch=None,
            metrics={},
            dataloader=None,
            device=device,
            amp=amp,
            key_metric_name=None,  # we can set many metrics, but only key_metric is used to compare and save the best model
            best_metric=-1,
            best_metric_epoch=-1,
        )
        self.data_loader = data_loader
        self.prepare_batch = prepare_batch

        if post_transform is not None:

            @self.on(Events.ITERATION_COMPLETED)
            def run_post_transform(engine):
                engine.state.output = apply_transform(post_transform,
                                                      engine.state.output)

        if key_metric is not None:

            if not isinstance(key_metric, dict):
                raise ValueError("key_metric must be a dict object.")
            self.state.key_metric_name = list(key_metric.keys())[0]
            metrics = key_metric
            if additional_metrics is not None and len(additional_metrics) > 0:
                if not isinstance(additional_metrics, dict):
                    raise ValueError(
                        "additional_metrics must be a dict object.")
                metrics.update(additional_metrics)
            for name, metric in metrics.items():
                metric.attach(self, name)

            @self.on(Events.EPOCH_COMPLETED)
            def _compare_metrics(engine):
                if engine.state.key_metric_name is not None:
                    current_val_metric = engine.state.metrics[
                        engine.state.key_metric_name]
                    if current_val_metric > engine.state.best_metric:
                        self.logger.info(
                            f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}"
                        )
                        engine.state.best_metric = current_val_metric
                        engine.state.best_metric_epoch = engine.state.epoch

        if handlers is not None:
            handlers = ensure_tuple(handlers)
            for handler in handlers:
                handler.attach(self)
Example #17
    def forward(self, src, theta, spatial_size=None):
        """
        ``theta`` must be an affine transformation matrix with shape
        3x3 or Nx3x3 or Nx2x3 or 2x3 for spatial 2D transforms,
        4x4 or Nx4x4 or Nx3x4 or 3x4 for spatial 3D transforms,
        where `N` is the batch size. `theta` will be converted into float Tensor for the computation.

        Args:
            src (array_like): image in spatial 2D or 3D (N, C, *spatial_dims),
                where N is the batch dim, C is the number of channels.
            theta (array_like): Nx3x3, Nx2x3, 3x3, 2x3 for spatial 2D inputs,
                Nx4x4, Nx3x4, 3x4, 4x4 for spatial 3D inputs. When the batch dimension is omitted,
                `theta` will be repeated N times, N is the batch dim of `src`.
            spatial_size (list or tuple of int): output spatial shape, the full output shape will be
                `[N, C, *spatial_size]` where N and C are inferred from the `src`.

        Raises:
            TypeError: both src and theta must be torch Tensor, got {type(src).__name__}, {type(theta).__name__}.
            ValueError: affine must be Nxdxd or dxd.
            ValueError: affine must be Nx3x3 or Nx4x4, got: {theta.shape}.
            ValueError: src must be spatially 2D or 3D.
            ValueError: batch dimension of affine and image does not match, got affine: {} and image: {}.

        """
        # validate `theta`
        if not torch.is_tensor(theta) or not torch.is_tensor(src):
            raise TypeError(
                f"both src and theta must be torch Tensor, got {type(src).__name__}, {type(theta).__name__}."
            )
        if theta.ndim not in (2, 3):
            raise ValueError("affine must be Nxdxd or dxd.")
        if theta.ndim == 2:
            theta = theta[None]  # adds a batch dim.
        theta = theta.clone()  # no in-place change of theta
        theta_shape = tuple(theta.shape[1:])
        if theta_shape in ((2, 3), (3, 4)):  # needs padding to dxd
            pad_affine = torch.tensor([0, 0, 1] if theta_shape[0] == 2 else [0, 0, 0, 1])
            pad_affine = pad_affine.repeat(theta.shape[0], 1, 1).to(theta)
            pad_affine.requires_grad = False
            theta = torch.cat([theta, pad_affine], dim=1)
        if tuple(theta.shape[1:]) not in ((3, 3), (4, 4)):
            raise ValueError(f"affine must be Nx3x3 or Nx4x4, got: {theta.shape}.")

        # validate `src`
        sr = src.ndim - 2  # input spatial rank
        if sr not in (2, 3):
            raise ValueError("src must be spatially 2D or 3D.")

        # set output shape
        src_size = tuple(src.shape)
        dst_size = src_size  # default to the src shape
        if self.spatial_size is not None:
            dst_size = src_size[:2] + self.spatial_size
        if spatial_size is not None:
            dst_size = src_size[:2] + ensure_tuple(spatial_size)

        # reverse and normalise theta if needed
        if not self.normalized:
            theta = to_norm_affine(
                affine=theta, src_size=src_size[2:], dst_size=dst_size[2:], align_corners=self.align_corners
            )
        if self.reverse_indexing:
            rev_idx = torch.as_tensor(range(sr - 1, -1, -1), device=src.device)
            theta[:, :sr] = theta[:, rev_idx]
            theta[:, :, :sr] = theta[:, :, rev_idx]
        if (theta.shape[0] == 1) and src_size[0] > 1:
            # adds a batch dim to `theta` in order to match `src`
            theta = theta.repeat(src_size[0], 1, 1)
        if theta.shape[0] != src_size[0]:
            raise ValueError(
                "batch dimension of affine and image does not match, got affine: {} and image: {}.".format(
                    theta.shape[0], src_size[0]
                )
            )

        grid = nn.functional.affine_grid(theta=theta[:, :sr], size=dst_size, align_corners=self.align_corners)
        dst = nn.functional.grid_sample(
            input=src.contiguous(),
            grid=grid,
            mode=self.mode.value,
            padding_mode=self.padding_mode.value,
            align_corners=self.align_corners,
        )
        return dst
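A hedged usage sketch; `AffineTransform` is an assumed name for the module whose `__init__` and `forward` are shown above. An identity theta reproduces the input up to interpolation:

    import torch

    xform = AffineTransform(spatial_size=(4, 4))  # assumed class name
    src = torch.rand(2, 1, 4, 4)                  # N=2, C=1, spatially 2D
    theta = torch.eye(3)[None]                    # 1x3x3; repeated to match the batch dim of src
    dst = xform(src, theta)                       # shape (2, 1, 4, 4)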
Example #18
    def __init__(self, transforms=None) -> None:
        if transforms is None:
            transforms = []
        self.transforms = ensure_tuple(transforms)
        self.set_random_state(seed=get_seed())
Example #19
    def __call__(self, engine):
        args = ensure_tuple(self.step_transform(engine))
        self.lr_scheduler.step(*args)
        if self.print_lr:
            self.logger.info(f"Current learning rate: {self.lr_scheduler._last_lr[0]}")