Example #1
    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Args:
            img: torch tensor data from which to extract the contour, with shape: [channels, height, width[, depth]]

        Raises:
            ValueError: When ``img`` ndim is not one of [3, 4].

        Returns:
            A torch tensor with the same shape as ``img``. Note:
                1. it is a binary map of whether each pixel lies on an edge.
                2. padding is used by default so the output keeps the original shape of the mask image.
                3. the edge detection is only approximate, due to defects inherent to the Laplacian kernel;
                   ideally the edges would be one pixel thin, but here they have some thickness.

        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_: torch.Tensor = convert_to_tensor(img, track_meta=False)
        spatial_dims = len(img_.shape) - 1
        img_ = img_.unsqueeze(0)  # adds a batch dim
        if spatial_dims == 2:
            # 2D Laplacian kernel: -1 for the 8 neighbors, 8 at the center
            kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]],
                                  dtype=torch.float32)
        elif spatial_dims == 3:
            # 3D Laplacian kernel: -1 for the 26 neighbors, 26 at the center
            kernel = -1.0 * torch.ones(3, 3, 3, dtype=torch.float32)
            kernel[1, 1, 1] = 26.0
        else:
            raise ValueError(
                f"{self.__class__} can only handle 2D or 3D images.")
        contour_img = apply_filter(img_, kernel)
        contour_img.clamp_(min=0.0, max=1.0)
        output, *_ = convert_to_dst_type(contour_img.squeeze(0), img)
        return output
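
A minimal usage sketch follows; it assumes the method above is MONAI's LabelToContour transform (the class name is not shown in the snippet):

# Usage sketch -- assumes the method above is monai.transforms.LabelToContour.
import torch
from monai.transforms import LabelToContour

mask = torch.zeros(1, 64, 64)      # [channels, height, width]
mask[0, 16:48, 16:48] = 1.0        # a filled square as foreground
contour = LabelToContour()(mask)   # binary edge map, same shape as the input
print(contour.shape)               # torch.Size([1, 64, 64])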
Example #2
    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Filter the image on the `applied_labels`.

        Args:
            img: Pytorch tensor or numpy array of any shape.

        Raises:
            NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.

        Returns:
            Pytorch tensor or numpy array of the same shape as the input.
        """
        if not isinstance(img, (np.ndarray, torch.Tensor)):
            raise NotImplementedError(
                f"{self.__class__} cannot handle data of type {type(img)}.")

        if isinstance(img, torch.Tensor):
            img = convert_to_tensor(img, track_meta=get_track_meta())
            img_ = convert_to_tensor(img, track_meta=False)
            if hasattr(torch, "isin"):  # `isin` is new in torch 1.10.0
                appl_lbls = torch.as_tensor(self.applied_labels,
                                            device=img_.device)
                out = torch.where(torch.isin(img_, appl_lbls), img_,
                                  torch.tensor(0.0).to(img_))
                return convert_to_dst_type(out, dst=img)[0]
            # older torch without `isin`: fall back to the numpy branch below
            out: NdarrayOrTensor = self(
                img_.detach().cpu().numpy())  # type: ignore
            out = convert_to_dst_type(out, img)[0]  # type: ignore
            return out
        return np.asarray(np.where(np.isin(img, self.applied_labels), img, 0))
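
A minimal usage sketch follows; it assumes the method above is MONAI's LabelFilter transform, constructed with the labels to keep:

# Usage sketch -- assumes the method above is monai.transforms.LabelFilter.
import numpy as np
from monai.transforms import LabelFilter

seg = np.array([[0, 1, 2],
                [3, 2, 1]])
keep = LabelFilter(applied_labels=[1, 3])
print(keep(seg))   # labels 1 and 3 survive; label 2 is zeroed out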
Example #3
    def __call__(self, img: NdarrayOrTensor) -> torch.Tensor:
        """
        Apply the transform to `img` and make it contiguous.
        """
        return convert_to_tensor(img,
                                 dtype=self.dtype,
                                 device=self.device,
                                 wrap_sequence=True)  # type: ignore
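
A minimal usage sketch follows; it assumes the method above matches MONAI's ToTensor transform:

# Usage sketch -- assumes the method above is monai.transforms.ToTensor.
import numpy as np
import torch
from monai.transforms import ToTensor

arr = np.ones((2, 3), dtype=np.float32)
t = ToTensor(dtype=torch.float32)(arr)   # contiguous torch.Tensor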
Example #4
    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Args:
            img: shape must be (C, spatial_dim1[, spatial_dim2, ...]).

        Returns:
            An array with shape (C, spatial_dim1[, spatial_dim2, ...]).
        """
        is_onehot = img.shape[0] > 1 if self.is_onehot is None else self.is_onehot
        if self.applied_labels is not None:
            applied_labels = self.applied_labels
        else:
            applied_labels = tuple(get_unique_labels(img, is_onehot,
                                                     discard=0))
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_: torch.Tensor = convert_to_tensor(img, track_meta=False)
        if self.independent:  # keep the largest component of each label separately
            for i in applied_labels:
                foreground = img_[i] > 0 if is_onehot else img_[0] == i
                mask = get_largest_connected_component_mask(
                    foreground, self.connectivity)
                if is_onehot:
                    img_[i][foreground != mask] = 0
                else:
                    img_[0][foreground != mask] = 0
            return convert_to_dst_type(img_, dst=img)[0]
        if not is_onehot:  # not one-hot, union of labels
            labels, *_ = convert_to_dst_type(applied_labels,
                                             dst=img_,
                                             wrap_sequence=True)
            foreground = (img_[..., None] == labels).any(-1)[0]
            mask = get_largest_connected_component_mask(
                foreground, self.connectivity)
            img_[0][foreground != mask] = 0
            return convert_to_dst_type(img_, dst=img)[0]
        # one-hot, union of labels
        foreground = (img_[applied_labels, ...] == 1).any(0)
        mask = get_largest_connected_component_mask(foreground,
                                                    self.connectivity)
        for i in applied_labels:
            img_[i][foreground != mask] = 0
        return convert_to_dst_type(img_, dst=img)[0]
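
A minimal usage sketch follows; it assumes the method above is MONAI's KeepLargestConnectedComponent transform:

# Usage sketch -- assumes the method above is
# monai.transforms.KeepLargestConnectedComponent.
import torch
from monai.transforms import KeepLargestConnectedComponent

seg = torch.tensor([[[0, 1, 1, 0, 0],
                     [0, 1, 1, 0, 0],
                     [0, 0, 0, 0, 1]]])   # shape (C=1, H, W)
klcc = KeepLargestConnectedComponent(applied_labels=[1])
print(klcc(seg)[0])   # the isolated single-pixel island of label 1 is removed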
Example #5
    def get_target_spacing(self,
                           spacing_key: str = "affine",
                           anisotropic_threshold: int = 3,
                           percentile: float = 10.0):
        """
        Calculate the target spacing according to all spacings.
        If the target spacing is very anisotropic,
        decrease the spacing value of the maximum axis according to percentile.
        The spacing is computed from `affine_to_spacing(data[spacing_key][0], 3)` if `data[spacing_key]` is a matrix,
        otherwise, the `data[spacing_key]` must be a vector of pixdim values.

        Args:
            spacing_key: key of the affine used to compute spacing in metadata (default: ``affine``).
            anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``).
            percentile: for anisotropic target spacing, use the percentile of all spacings of the anisotropic axis to
                replace that axis.

        """
        if len(self.all_meta_data) == 0:
            self.collect_meta_data()
        if spacing_key not in self.all_meta_data[0]:
            raise ValueError(
                "The provided spacing_key is not in self.all_meta_data.")
        spacings = []
        for data in self.all_meta_data:
            spacing_vals = convert_to_tensor(data[spacing_key][0],
                                             track_meta=False,
                                             wrap_sequence=True)
            if spacing_vals.ndim == 1:  # vector
                spacings.append(spacing_vals[:3][None])
            elif spacing_vals.ndim == 2:  # matrix
                spacings.append(affine_to_spacing(spacing_vals, 3)[None])
            else:
                raise ValueError(
                    "data[spacing_key] must be a vector or a matrix.")
        all_spacings = concatenate(to_cat=spacings, axis=0)
        all_spacings, *_ = convert_data_type(data=all_spacings,
                                             output_type=np.ndarray,
                                             wrap_sequence=True)

        target_spacing = np.median(all_spacings, axis=0)
        if max(target_spacing) / min(target_spacing) >= anisotropic_threshold:
            largest_axis = np.argmax(target_spacing)
            target_spacing[largest_axis] = np.percentile(
                all_spacings[:, largest_axis], percentile)

        return tuple(target_spacing)
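
A minimal numeric sketch of the anisotropy rule above, standalone and with hypothetical spacing values:

import numpy as np

all_spacings = np.array([[0.8, 0.8, 5.0],
                         [0.7, 0.7, 4.0],
                         [0.9, 0.9, 6.0]])
target = np.median(all_spacings, axis=0)   # [0.8, 0.8, 5.0]
if target.max() / target.min() >= 3:       # anisotropic_threshold
    axis = int(np.argmax(target))
    # replace the coarsest axis with the 10th percentile of its spacings
    target[axis] = np.percentile(all_spacings[:, axis], 10.0)
print(tuple(target))                       # (0.8, 0.8, 4.2)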
Example #6
    def __call__(
        self, data: Mapping[Hashable, NdarrayOrTensor]
    ) -> Mapping[Hashable, NdarrayOrTensor]:
        self.randomize()
        d = dict(data)
        if not self._do_transform:
            # even when skipping, convert to tensor so metadata tracking stays consistent
            for key in self.key_iterator(d):
                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
            return d

        for idx, key in enumerate(self.key_iterator(d)):
            # cycle through the configured modes, one per key
            self.trans.set_mode(self.mode[idx % len(self.mode)])
            d[key] = self.trans(d[key], False)

        return d
Example #7
    def __call__(
        self,
        img: NdarrayOrTensor,
        sigmoid: Optional[bool] = None,
        softmax: Optional[bool] = None,
        other: Optional[Callable] = None,
    ) -> NdarrayOrTensor:
        """
        Args:
            sigmoid: whether to apply the sigmoid function to the model output.
                Defaults to ``self.sigmoid``.
            softmax: whether to apply the softmax function to the model output.
                Defaults to ``self.softmax``.
            other: a callable to apply any other activation function, for example
                ``other = torch.tanh``. Defaults to ``self.other``.

        Raises:
            ValueError: When ``sigmoid=True`` and ``softmax=True``. Incompatible values.
            TypeError: When ``other`` is not an ``Optional[Callable]``.
            ValueError: When ``self.other=None`` and ``other=None``. Incompatible values.

        """
        if sigmoid and softmax:
            raise ValueError(
                "Incompatible values: sigmoid=True and softmax=True.")
        if other is not None and not callable(other):
            raise TypeError(
                f"other must be None or callable but is {type(other).__name__}."
            )

        # convert to float as activation must operate on float tensor
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)
        if sigmoid or self.sigmoid:
            img_t = torch.sigmoid(img_t)
        if softmax or self.softmax:
            img_t = torch.softmax(img_t, dim=0)

        act_func = self.other if other is None else other
        if act_func is not None:
            img_t = act_func(img_t)
        out, *_ = convert_to_dst_type(img_t, img)
        return out
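
A minimal usage sketch follows; it assumes the method above is MONAI's Activations transform:

# Usage sketch -- assumes the method above is monai.transforms.Activations.
import torch
from monai.transforms import Activations

logits = torch.randn(4, 8, 8)               # (num_classes, H, W) model output
probs = Activations(softmax=True)(logits)   # softmax over the channel dim
assert torch.allclose(probs.sum(dim=0), torch.ones(8, 8), atol=1e-6)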
Example #8
    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Fill the holes in the provided image.

        Note:
            The value 0 is assumed to be the background label.

        Args:
            img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].

        Raises:
            NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.

        Returns:
            Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_np, *_ = convert_data_type(img, np.ndarray)
        out_np: np.ndarray = fill_holes(img_np, self.applied_labels,
                                        self.connectivity)
        out, *_ = convert_to_dst_type(out_np, img)
        return out
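
A minimal usage sketch follows; it assumes the method above is MONAI's FillHoles transform:

# Usage sketch -- assumes the method above is monai.transforms.FillHoles.
import numpy as np
from monai.transforms import FillHoles

seg = np.array([[[1, 1, 1],
                 [1, 0, 1],
                 [1, 1, 1]]])      # a background hole enclosed by label 1
print(FillHoles()(seg)[0])         # the enclosed 0 is filled with 1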
Example #9
    def __call__(
            self,
            img: NdarrayOrTensor,
            argmax: Optional[bool] = None,
            to_onehot: Optional[int] = None,
            threshold: Optional[float] = None,
            rounding: Optional[str] = None,
            n_classes: Optional[int] = None,  # deprecated
            num_classes: Optional[int] = None,  # deprecated
            logit_thresh: Optional[float] = None,  # deprecated
            threshold_values: Optional[bool] = None,  # deprecated
    ) -> NdarrayOrTensor:
        """
        Args:
            img: the input tensor data to convert; if there is no channel dimension
                when converting to one-hot, one will be added automatically.
            argmax: whether to execute the argmax function on the input data before transform.
                Defaults to ``self.argmax``.
            to_onehot: if not None, convert the input data into one-hot format with the specified number of classes.
                Defaults to ``self.to_onehot``.
            threshold: if not None, threshold the float values to 0 or 1 with the specified threshold value.
                Defaults to ``self.threshold``.
            rounding: if not None, round the data according to the specified option,
                available options: ["torchrounding"].

        .. deprecated:: 0.6.0
            ``n_classes`` is deprecated, use ``to_onehot`` instead.

        .. deprecated:: 0.7.0
            ``num_classes`` is deprecated, use ``to_onehot`` instead.
            ``logit_thresh`` is deprecated, use ``threshold`` instead.
            ``threshold_values`` is deprecated, use ``threshold`` instead.

        """
        if isinstance(to_onehot, bool):
            warnings.warn(
                "`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead."
            )
            to_onehot = num_classes if to_onehot else None
        if isinstance(threshold, bool):
            warnings.warn(
                "`threshold_values=True/False` is deprecated, please use `threshold=value` instead."
            )
            threshold = logit_thresh if threshold else None
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_t, *_ = convert_data_type(img, torch.Tensor)
        if argmax or self.argmax:
            img_t = torch.argmax(img_t, dim=0, keepdim=True)

        to_onehot = self.to_onehot if to_onehot is None else to_onehot
        if to_onehot is not None:
            if not isinstance(to_onehot, int):
                raise AssertionError(
                    "the number of classes for One-Hot must be an integer.")
            img_t = one_hot(img_t, num_classes=to_onehot, dim=0)

        threshold = self.threshold if threshold is None else threshold
        if threshold is not None:
            img_t = img_t >= threshold

        rounding = self.rounding if rounding is None else rounding
        if rounding is not None:
            look_up_option(rounding, ["torchrounding"])
            img_t = torch.round(img_t)

        img, *_ = convert_to_dst_type(img_t, img, dtype=torch.float)
        return img
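
A minimal usage sketch follows; it assumes the method above is MONAI's AsDiscrete transform:

# Usage sketch -- assumes the method above is monai.transforms.AsDiscrete.
import torch
from monai.transforms import AsDiscrete

logits = torch.tensor([[[0.1, 0.9]],
                       [[0.8, 0.2]],
                       [[0.1, 0.6]]])                # (num_classes=3, H=1, W=2)
pred = AsDiscrete(argmax=True, to_onehot=3)(logits)  # argmax, then one-hot
print(pred.shape)                                    # torch.Size([3, 1, 2])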
Example #10
    def resample_if_needed(
        cls,
        data_array: NdarrayOrTensor,
        affine: Optional[NdarrayOrTensor] = None,
        target_affine: Optional[NdarrayOrTensor] = None,
        output_spatial_shape: Union[Sequence[int], int, None] = None,
        mode: str = GridSampleMode.BILINEAR,
        padding_mode: str = GridSamplePadMode.BORDER,
        align_corners: bool = False,
        dtype: DtypeLike = np.float64,
    ):
        """
        Convert the ``data_array`` into the coordinate system specified by
        ``target_affine``, from the current coordinate definition of ``affine``.

        If the transform between ``affine`` and ``target_affine`` could be
        achieved by simply transposing and flipping ``data_array``, no resampling
        will happen.  Otherwise, this function resamples ``data_array`` using the
        transformation computed from ``affine`` and ``target_affine``.

        This function assumes the NIfTI dimension notations. Spatially it
        supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D
        respectively. When saving multiple time steps or multiple channels,
        time and/or modality axes should be appended after the first three
        dimensions. For example, shape of 2D eight-class segmentation
        probabilities to be saved could be `(64, 64, 1, 8)`. Also, data in
        shape `(64, 64, 8)` or `(64, 64, 8, 1)` will be considered as a
        single-channel 3D image. The ``convert_to_channel_last`` method can be
        used to convert the data to the format described here.

        Note that the shape of the resampled ``data_array`` may be subject to
        some rounding errors. For example, resampling a 20x20-pixel image from pixel
        size (1.5, 1.5)-mm to (3.0, 3.0)-mm space will return a 10x10-pixel
        image. However, resampling a 20x20-pixel image from pixel size (2.0,
        2.0)-mm to (3.0, 3.0)-mm space will output a 14x14-pixel image, where
        the image shape is rounded from 13.333x13.333 pixels. In this case
        ``output_spatial_shape`` could be specified so that this function
        writes image data to a designated shape.

        Args:
            data_array: input data array to be converted.
            affine: the current affine of ``data_array``. Defaults to identity.
            target_affine: the designated affine of ``data_array``.
                The actual output affine might be different from this value due to precision changes.
            output_spatial_shape: spatial shape of the output image.
                This option is used when resampling is needed.
            mode: available options are {``"bilinear"``, ``"nearest"``, ``"bicubic"``}.
                This option is used when resampling is needed.
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            padding_mode: available options are {``"zeros"``, ``"border"``, ``"reflection"``}.
                This option is used when resampling is needed.
                Padding mode for outside grid values. Defaults to ``"border"``.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            align_corners: boolean option of ``grid_sample`` to handle the corner convention.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            dtype: data type for resampling computation. Defaults to
                ``np.float64`` for best precision. If ``None``, use the data type of input data.
                The output data type of this method is always ``np.float32``.
        """
        orig_type = type(data_array)
        data_array = convert_to_tensor(data_array, track_meta=True)
        if affine is not None:
            data_array.affine = convert_to_tensor(
                affine, track_meta=False)  # type: ignore
        resampler = SpatialResample(mode=mode,
                                    padding_mode=padding_mode,
                                    align_corners=align_corners,
                                    dtype=dtype)
        output_array = resampler(data_array[None],
                                 dst_affine=target_affine,
                                 spatial_size=output_spatial_shape)
        # convert back at the end
        if isinstance(output_array, MetaTensor):
            output_array.applied_operations = []
        data_array, *_ = convert_data_type(
            output_array, output_type=orig_type)  # type: ignore
        affine, *_ = convert_data_type(output_array.affine,
                                       output_type=orig_type)  # type: ignore
        return data_array[0], affine
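
A minimal numeric sketch of the shape-rounding note above (assumption: the output grid rounds up to cover the full physical extent, consistent with the 14x14 figure in the docstring):

import numpy as np

in_shape, in_spacing, out_spacing = 20, 2.0, 3.0
exact = in_shape * in_spacing / out_spacing   # 13.333... pixels
print(int(np.ceil(exact)))                    # 14, the reported output size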
Example #11
    def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
        d = dict(data)
        for (
                key,
                orig_key,
                meta_key,
                orig_meta_key,
                meta_key_postfix,
                nearest_interp,
                to_tensor,
                device,
                post_func,
        ) in self.key_iterator(
                d,
                self.orig_keys,
                self.meta_keys,
                self.orig_meta_keys,
                self.meta_key_postfix,
                self.nearest_interp,
                self.to_tensor,
                self.device,
                self.post_func,
        ):
            if isinstance(d[key], MetaTensor):
                if orig_key not in d:
                    warnings.warn(
                        f"transform info of `{orig_key}` is not available in MetaTensor {key}."
                    )
                    continue
            else:
                transform_key = InvertibleTransform.trace_key(orig_key)
                if transform_key not in d:
                    warnings.warn(
                        f"transform info of `{orig_key}` is not available or no InvertibleTransform applied."
                    )
                    continue

            orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}"
            if orig_key in d and isinstance(d[orig_key], MetaTensor):
                transform_info = d[orig_key].applied_operations
                meta_info = d[orig_key].meta
            else:
                transform_info = d[InvertibleTransform.trace_key(orig_key)]
                meta_info = d.get(orig_meta_key, {})
            if nearest_interp:
                transform_info = convert_applied_interp_mode(
                    trans_info=transform_info,
                    mode="nearest",
                    align_corners=None)

            inputs = d[key]
            if isinstance(inputs, torch.Tensor):
                inputs = inputs.detach()

            if not isinstance(inputs, MetaTensor):
                inputs = convert_to_tensor(inputs, track_meta=True)
            inputs.applied_operations = deepcopy(transform_info)
            inputs.meta = deepcopy(meta_info)

            # construct the input dict data
            input_dict = {orig_key: inputs}
            if config.USE_META_DICT:
                input_dict[InvertibleTransform.trace_key(
                    orig_key)] = transform_info
                input_dict[PostFix.meta(orig_key)] = meta_info
            with allow_missing_keys_mode(self.transform):  # type: ignore
                inverted = self.transform.inverse(input_dict)

            # save the inverted data
            if to_tensor and not isinstance(inverted[orig_key], MetaTensor):
                inverted_data = self._totensor(inverted[orig_key])
            else:
                inverted_data = inverted[orig_key]
            d[key] = post_func(inverted_data.to(device))
            # save the inverted applied_operations if it's in the source dict
            if InvertibleTransform.trace_key(orig_key) in d:
                d[InvertibleTransform.trace_key(
                    orig_key)] = inverted_data.applied_operations
            # save the inverted meta dict if it's in the source dict
            if orig_meta_key in d:
                meta_key = meta_key or f"{key}_{meta_key_postfix}"
                d[meta_key] = inverted.get(orig_meta_key)
        return d
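
A minimal usage sketch follows; it assumes the method above is MONAI's Invertd transform, which undoes the recorded preprocessing on model predictions:

# Usage sketch -- assumes the method above is monai.transforms.Invertd;
# the pipeline and keys below are hypothetical.
from monai.transforms import Compose, Invertd, LoadImaged, Spacingd

pre = Compose([LoadImaged(keys="image"),
               Spacingd(keys="image", pixdim=(2.0, 2.0, 2.0))])
post = Invertd(keys="pred", transform=pre, orig_keys="image")
# data = pre({"image": "case.nii.gz"})
# data["pred"] = model(data["image"][None])[0]   # add/remove the batch dim
# data = post(data)                              # pred back in original space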
Example #12

    def __call__(self, x):
        # `_apply` is assumed to map the given function over nested containers
        return _apply(x, lambda t: convert_to_tensor(t).cuda())