Example 1
    def resample_and_clip(
        cls,
        data_array: NdarrayOrTensor,
        output_spatial_shape: Optional[Sequence[int]] = None,
        mode: str = InterpolateMode.BICUBIC,
    ):
        """
        Resample ``data_array`` to ``output_spatial_shape`` if needed.
        Args:
            data_array: input data array. This method assumes the 'channel-last' format.
            output_spatial_shape: output spatial shape.
            mode: interpolation mode, default is ``InterpolateMode.BICUBIC``.
        """

        data: np.ndarray = convert_data_type(data_array, np.ndarray)[0]
        if output_spatial_shape is not None:
            output_spatial_shape_ = ensure_tuple_rep(output_spatial_shape, 2)
            mode = look_up_option(mode, InterpolateMode)
            align_corners = None if mode in (InterpolateMode.NEAREST, InterpolateMode.AREA) else False
            xform = Resize(spatial_size=output_spatial_shape_, mode=mode, align_corners=align_corners)
            _min, _max = np.min(data), np.max(data)
            if len(data.shape) == 3:
                data = np.moveaxis(data, -1, 0)  # to channel first
                data = convert_data_type(xform(data), np.ndarray)[0]  # type: ignore
                data = np.moveaxis(data, 0, -1)
            else:  # (H, W)
                data = np.expand_dims(data, 0)  # make a channel
                data = convert_data_type(xform(data), np.ndarray)[0][0]  # type: ignore
            if mode != InterpolateMode.NEAREST:
                data = np.clip(data, _min, _max)
        return data
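
As a point of comparison, the following is a minimal, purely illustrative sketch of the same clip-after-resize idea on a channel-last array, using `torch.nn.functional.interpolate` directly instead of MONAI's `Resize` transform and `convert_data_type` (all names and sizes here are made up for the example):

import numpy as np
import torch
import torch.nn.functional as F

data = np.random.rand(64, 64, 3).astype(np.float32)   # channel-last (H, W, C)
_min, _max = data.min(), data.max()

t = torch.from_numpy(np.moveaxis(data, -1, 0).copy())[None]            # (1, C, H, W)
t = F.interpolate(t, size=(32, 32), mode="bicubic", align_corners=False)
out = np.moveaxis(t[0].numpy(), 0, -1)
out = np.clip(out, _min, _max)   # bicubic can overshoot, so restore the original intensity range
print(out.shape)                 # (32, 32, 3)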
Example 2
 def create_backend_obj(cls, data_array: NdarrayOrTensor,
                        **kwargs) -> np.ndarray:
     """
     Subclass should implement this method to return a backend-specific data representation object.
     This method is used by ``cls.write`` and the input ``data_array`` is assumed 'channel-last'.
     """
     return convert_data_type(data_array, np.ndarray)[0]
Example 3
    def create_backend_obj(cls,
                           data_array: NdarrayOrTensor,
                           affine: Optional[NdarrayOrTensor] = None,
                           dtype: DtypeLike = None,
                           **kwargs):
        """
        Create an Nifti1Image object from ``data_array``. This method assumes a 'channel-last' ``data_array``.

        Args:
            data_array: input data array.
            affine: affine matrix of the data array.
            dtype: output data type.
            kwargs: keyword arguments. Current ``nib.nifti1.Nifti1Image`` will read
                ``header``, ``extra``, ``file_map`` from this dictionary.

        See also:

            - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.Nifti1Image
        """
        data_array = super().create_backend_obj(data_array)
        if dtype is not None:
            data_array = data_array.astype(dtype, copy=False)
        affine = convert_data_type(affine, np.ndarray)[0]
        if affine is None:
            affine = np.eye(4)
        affine = to_affine_nd(r=3, affine=affine)
        return nib.nifti1.Nifti1Image(
            data_array,
            affine,
            header=kwargs.pop("header", None),
            extra=kwargs.pop("extra", None),
            file_map=kwargs.pop("file_map", None),
        )
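
For context, a minimal usage sketch (assuming `nibabel` is installed) of constructing a `Nifti1Image` from a channel-last array with the 4x4 identity affine that the method above falls back to when no affine is given:

import numpy as np
import nibabel as nib

data = np.random.rand(16, 16, 8).astype(np.float32)
img = nib.nifti1.Nifti1Image(data, np.eye(4))   # identity affine, as used when affine is None
print(img.shape, img.affine.shape)              # (16, 16, 8) (4, 4)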
Example 4
    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Args:
            img: torch tensor data to extract the contour, with shape: [channels, height, width[, depth]]

        Raises:
            ValueError: When ``img`` ndim is not one of [3, 4].

        Returns:
            A torch tensor with the same shape as img, note:
                1. it's the binary classification result of whether a pixel is an edge or not.
                2. in order to keep the original shape of the mask image, we use padding as default.
                3. the edge detection is only approximate because of defects inherent to the Laplace kernel;
                   ideally the edge should be one pixel thin, but here it has a thickness.

        """
        img_: torch.Tensor = convert_data_type(img, torch.Tensor)[0]
        spatial_dims = len(img_.shape) - 1
        img_ = img_.unsqueeze(0)  # adds a batch dim
        if spatial_dims == 2:
            kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]],
                                  dtype=torch.float32)
        elif spatial_dims == 3:
            kernel = -1.0 * torch.ones(3, 3, 3, dtype=torch.float32)
            kernel[1, 1, 1] = 26.0
        else:
            raise ValueError(
                f"{self.__class__} can only handle 2D or 3D images.")
        contour_img = apply_filter(img_, kernel)
        contour_img.clamp_(min=0.0, max=1.0)
        output, *_ = convert_to_dst_type(contour_img.squeeze(0), img)
        return output
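
A standalone 2D sketch of the same Laplace-kernel contour idea, using `torch.nn.functional.conv2d` directly instead of `apply_filter` (illustrative only, toy mask):

import torch
import torch.nn.functional as F

mask = torch.zeros(1, 1, 8, 8)            # (N, C, H, W) binary mask
mask[..., 2:6, 2:6] = 1.0
kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32)
contour = F.conv2d(mask, kernel[None, None], padding=1).clamp_(0, 1)
print(contour[0, 0].int())                # ones along the square's border only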
Example 5
def _image3_animated_gif(
    tag: str, image: Union[np.ndarray, torch.Tensor], writer, frame_dim: int = 0, scale_factor: float = 1.0
):
    """Function to actually create the animated gif.

    Args:
        tag: Data identifier
        image: 3D image tensors expected to be in `HWD` format
        writer: the tensorboard writer to plot image
        frame_dim: the dimension used as frames for GIF image, expect data shape as `HWD`, default to `0`.
        scale_factor: amount to multiply values by. If the image data is between 0 and 1, using 255 for this value will
            scale it to a displayable range
    """
    if len(image.shape) != 3:
        raise AssertionError("3D image tensors expected to be in `HWD` format, len(image.shape) != 3")

    image_np, *_ = convert_data_type(image, output_type=np.ndarray)
    ims = [(i * scale_factor).astype(np.uint8, copy=False) for i in np.moveaxis(image_np, frame_dim, 0)]
    ims = [GifImage.fromarray(im) for im in ims]
    img_str = b""
    for b_data in PIL.GifImagePlugin.getheader(ims[0])[0]:
        img_str += b_data
    img_str += b"\x21\xFF\x0B\x4E\x45\x54\x53\x43\x41\x50" b"\x45\x32\x2E\x30\x03\x01\x00\x00\x00"
    for i in ims:
        for b_data in PIL.GifImagePlugin.getdata(i):
            img_str += b_data
    img_str += b"\x3B"

    summary = SummaryX if has_tensorboardx and isinstance(writer, SummaryWriterX) else Summary
    summary_image_str = summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str)
    image_summary = summary.Value(tag=tag, image=summary_image_str)
    return summary(value=[image_summary])
Example 6
    def extend(self, *data) -> None:
        """
        Extend the local buffers with new ("batch-first") data.
        A buffer will be allocated for each `data` item.
        Compared with `self.append`, this method adds a "batch" of data to the local buffers.

        Args:
            data: each item can be a "batch-first" tensor or a list of "channel-first" tensors.
                They will be concatenated at the 0-th dimension when `get_buffer()` is called.
        """
        if self._buffers is None:
            self._buffers = [[] for _ in data]
        for b, d in zip(self._buffers, data):
            # converting to pytorch tensors so that we can use the distributed API
            d_t, *_ = convert_data_type(d,
                                        output_type=torch.Tensor,
                                        wrap_sequence=True)
            try:
                b.extend([x[0] for x in torch.split(d_t, 1, dim=0)])
            except (AttributeError, IndexError, RuntimeError) as e:
                raise TypeError(
                    f"{e}. `data` should be a batch-first tensor or"
                    f" a list of channel-first tensors, got {type(d_t)}"
                ) from e
        self._synced = False
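
A small sketch of the batch-splitting behaviour described above, outside the buffer class: a "batch-first" tensor is broken into per-sample tensors that can later be re-assembled along dim 0 (illustrative values):

import torch

batch = torch.arange(6.0).reshape(3, 2)           # a "batch-first" tensor with 3 samples
samples = [x[0] for x in torch.split(batch, 1, dim=0)]
print(len(samples), samples[0].shape)             # 3 torch.Size([2])
print(torch.stack(samples, dim=0).equal(batch))   # True: stacking restores the original batch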
Example 7
    def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0):
        """
        Calculate the target spacing according to all spacings.
        If the target spacing is very anisotropic,
        decrease the spacing value of the maximum axis according to percentile.
        So far, this function only supports NIFTI images which store spacings in headers with key "pixdim".
        After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`.

        Args:
            spacing_key: key of spacing in meta data (default: ``pixdim``).
            anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``).
            percentile: for anisotropic target spacing, use the percentile of all spacings of the anisotropic axis to
                replace that axis.

        """
        if len(self.all_meta_data) == 0:
            self.collect_meta_data()
        if spacing_key not in self.all_meta_data[0]:
            raise ValueError("The provided spacing_key is not in self.all_meta_data.")
        all_spacings = concatenate(to_cat=[data[spacing_key][:, 1:4] for data in self.all_meta_data], axis=0)
        all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True)

        target_spacing = np.median(all_spacings, axis=0)
        if max(target_spacing) / min(target_spacing) >= anisotropic_threshold:
            largest_axis = np.argmax(target_spacing)
            target_spacing[largest_axis] = np.percentile(all_spacings[:, largest_axis], percentile)

        output = list(target_spacing)

        return tuple(output)
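
A numeric sketch of the anisotropy handling described above, assuming a toy set of spacings: with `anisotropic_threshold=3`, the z axis (median 5.0) is replaced by the 10th percentile of the z spacings:

import numpy as np

all_spacings = np.array([[1.0, 1.0, 5.0], [1.0, 1.0, 4.0], [1.0, 1.0, 6.0]])
target_spacing = np.median(all_spacings, axis=0)                  # [1., 1., 5.]
if max(target_spacing) / min(target_spacing) >= 3:
    largest_axis = int(np.argmax(target_spacing))
    target_spacing[largest_axis] = np.percentile(all_spacings[:, largest_axis], 10.0)
print(tuple(target_spacing))                                      # (1.0, 1.0, 4.2)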
Example 8
def _create_itk_obj(array, affine):
    itk_img = deepcopy(array)
    itk_img = convert_data_type(itk_img, np.ndarray)[0]
    itk_obj = ITKWriter.create_backend_obj(itk_img,
                                           channel_dim=None,
                                           affine=affine,
                                           affine_lps=True)
    return itk_obj
Example 9
    def __call__(
            self,
            img: NdarrayOrTensor,
            meta_data: Optional[Dict] = None,
            mask: Optional[np.ndarray] = None) -> Tuple[NdarrayOrTensor, Dict]:
        """
        Compute statistics for the intensity of input image.

        Args:
            img: input image to compute intensity stats.
            meta_data: meta data dictionary to store the statistics data, if None, will create an empty dictionary.
            mask: if not None, mask the image to extract only the interested area to compute statistics.
                mask must have the same shape as input `img`.

        """
        img_np: np.ndarray
        img_np, *_ = convert_data_type(img, np.ndarray)  # type: ignore
        if meta_data is None:
            meta_data = {}

        if mask is not None:
            if mask.shape != img_np.shape or mask.dtype != bool:
                raise TypeError(
                    "mask must be bool array with the same shape as input `img`."
                )
            img_np = img_np[mask]

        supported_ops = {
            "mean": np.nanmean,
            "median": np.nanmedian,
            "max": np.nanmax,
            "min": np.nanmin,
            "std": np.nanstd,
        }

        def _compute(op: Callable, data: np.ndarray):
            if self.channel_wise:
                return [op(c) for c in data]
            return op(data)

        custom_index = 0
        for o in self.ops:
            if isinstance(o, str):
                o = look_up_option(o, supported_ops.keys())
                meta_data[self.key_prefix + "_" + o] = _compute(
                    supported_ops[o], img_np)  # type: ignore
            elif callable(o):
                meta_data[self.key_prefix + "_custom_" +
                          str(custom_index)] = _compute(o, img_np)
                custom_index += 1
            else:
                raise ValueError(
                    "ops must be key string for predefined operations or callable function."
                )

        return img, meta_data
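
For illustration only, a standalone sketch of the statistics collection (the `orig_` prefix and the values are made up): NaN-aware ops applied to a masked image, stored under prefixed keys much like the meta dictionary above:

import numpy as np

img = np.array([[1.0, 2.0], [3.0, np.nan]])
mask = np.array([[True, True], [True, False]])
vals = img[mask]                                   # masked, flattened intensities
stats = {"orig_" + name: float(op(vals)) for name, op in
         {"mean": np.nanmean, "max": np.nanmax, "min": np.nanmin}.items()}
print(stats)                                       # {'orig_mean': 2.0, 'orig_max': 3.0, 'orig_min': 1.0}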
Example 10
    def calculate_percentiles(
        self,
        foreground_threshold: int = 0,
        sampling_flag: bool = True,
        interval: int = 10,
        min_percentile: float = 0.5,
        max_percentile: float = 99.5,
    ):
        """
        This function is used to calculate the percentiles of intensities (and median) of the input dataset. To get
        the required values, all voxels need to be accumulated. To reduce the memory used, this function can be set
        to accumulate only a part of the voxels.

        Args:
            foreground_threshold: the threshold to distinguish if a voxel belongs to foreground, this parameter
                is used to select the foreground of images for calculation. Normally, `label > 0` means the corresponding
                voxel belongs to foreground, thus if you need to calculate the statistics for whole images, you can set
                the threshold to ``-1`` (default: ``0``).
            sampling_flag: whether to sample only a part of the voxels (default: ``True``).
            interval: the sampling interval for accumulating voxels (default: ``10``).
            min_percentile: minimal percentile (default: ``0.5``).
            max_percentile: maximal percentile (default: ``99.5``).

        """
        all_intensities = []
        for data in self.data_loader:
            if self.image_key and self.label_key:
                image, label = data[self.image_key], data[self.label_key]
            else:
                image, label = data
            image, *_ = convert_data_type(data=image, output_type=torch.Tensor)
            label, *_ = convert_data_type(data=label, output_type=torch.Tensor)

            intensities = image[torch.where(
                label > foreground_threshold)].tolist()
            if sampling_flag:
                intensities = intensities[::interval]
            all_intensities.append(intensities)

        all_intensities = list(chain(*all_intensities))
        self.data_min_percentile, self.data_max_percentile = np.percentile(
            all_intensities, [min_percentile, max_percentile])
        self.data_median = np.median(all_intensities)
Example 11
    def __call__(self, img: NdarrayOrTensor):
        """
        Args:
            img: PyTorch Tensor data for the TorchVision transform.

        """
        img_t, *_ = convert_data_type(img, torch.Tensor)  # type: ignore
        out = self.trans(img_t)
        out, *_ = convert_to_dst_type(src=out, dst=img)
        return out
Example 12
    def calculate_statistics(self, foreground_threshold: int = 0):
        """
        This function is used to calculate the maximum, minimum, mean and standard deviation of intensities of
        the input dataset.

        Args:
            foreground_threshold: the threshold to distinguish if a voxel belongs to foreground, this parameter
                is used to select the foreground of images for calculation. Normally, `label > 0` means the corresponding
                voxel belongs to foreground, thus if you need to calculate the statistics for whole images, you can set
                the threshold to ``-1`` (default: ``0``).

        """
        voxel_sum = torch.as_tensor(0.0)
        voxel_square_sum = torch.as_tensor(0.0)
        voxel_max, voxel_min = [], []
        voxel_ct = 0

        for data in self.data_loader:
            if self.image_key and self.label_key:
                image, label = data[self.image_key], data[self.label_key]
            else:
                image, label = data
            image, *_ = convert_data_type(data=image, output_type=torch.Tensor)
            label, *_ = convert_data_type(data=label, output_type=torch.Tensor)

            image_foreground = image[torch.where(label > foreground_threshold)]

            voxel_max.append(image_foreground.max().item())
            voxel_min.append(image_foreground.min().item())
            voxel_ct += len(image_foreground)
            voxel_sum += image_foreground.sum()
            voxel_square_sum += torch.square(image_foreground).sum()

        self.data_max, self.data_min = max(voxel_max), min(voxel_min)
        self.data_mean = (voxel_sum / voxel_ct).item()
        self.data_std = (torch.sqrt(voxel_square_sum / voxel_ct -
                                    self.data_mean**2)).item()
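
The standard deviation above is recovered from the accumulated sums as sqrt(E[x^2] - mean^2); a small worked check of that identity on a toy foreground vector:

import torch

foreground = torch.tensor([1.0, 2.0, 3.0, 4.0])      # stand-in for the collected foreground voxels
voxel_ct = foreground.numel()
voxel_sum = foreground.sum()
voxel_square_sum = torch.square(foreground).sum()
mean = (voxel_sum / voxel_ct).item()
std = torch.sqrt(voxel_square_sum / voxel_ct - mean ** 2).item()
print(mean, round(std, 3))                            # 2.5 1.118 (population std)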
Example 13
    def get_target_spacing(self,
                           spacing_key: str = "affine",
                           anisotropic_threshold: int = 3,
                           percentile: float = 10.0):
        """
        Calculate the target spacing according to all spacings.
        If the target spacing is very anisotropic,
        decrease the spacing value of the maximum axis according to percentile.
        The spacing is computed from `affine_to_spacing(data[spacing_key][0], 3)` if `data[spacing_key]` is a matrix,
        otherwise, the `data[spacing_key]` must be a vector of pixdim values.

        Args:
            spacing_key: key of the affine used to compute spacing in metadata (default: ``affine``).
            anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``).
            percentile: for anisotropic target spacing, use the percentile of all spacings of the anisotropic axis to
                replace that axis.

        """
        if len(self.all_meta_data) == 0:
            self.collect_meta_data()
        if spacing_key not in self.all_meta_data[0]:
            raise ValueError(
                "The provided spacing_key is not in self.all_meta_data.")
        spacings = []
        for data in self.all_meta_data:
            spacing_vals = convert_to_tensor(data[spacing_key][0],
                                             track_meta=False,
                                             wrap_sequence=True)
            if spacing_vals.ndim == 1:  # vector
                spacings.append(spacing_vals[:3][None])
            elif spacing_vals.ndim == 2:  # matrix
                spacings.append(affine_to_spacing(spacing_vals, 3)[None])
            else:
                raise ValueError(
                    "data[spacing_key] must be a vector or a matrix.")
        all_spacings = concatenate(to_cat=spacings, axis=0)
        all_spacings, *_ = convert_data_type(data=all_spacings,
                                             output_type=np.ndarray,
                                             wrap_sequence=True)

        target_spacing = np.median(all_spacings, axis=0)
        if max(target_spacing) / min(target_spacing) >= anisotropic_threshold:
            largest_axis = np.argmax(target_spacing)
            target_spacing[largest_axis] = np.percentile(
                all_spacings[:, largest_axis], percentile)

        output = list(target_spacing)

        return tuple(output)
Example 14
    def create_backend_obj(
        cls,
        data_array: NdarrayOrTensor,
        channel_dim: Optional[int] = 0,
        affine: Optional[NdarrayOrTensor] = None,
        dtype: DtypeLike = np.float32,
        affine_lps_to_ras: bool = True,
        **kwargs,
    ):
        """
        Create an ITK object from ``data_array``. This method assumes a 'channel-last' ``data_array``.

        Args:
            data_array: input data array.
            channel_dim: channel dimension of the data array. This is used to create a Vector Image if it is not ``None``.
            affine: affine matrix of the data array. This is used to compute `spacing`, `direction` and `origin`.
            dtype: output data type.
            affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``.
                Set to ``True`` to be consistent with ``NibabelWriter``,
                otherwise the affine matrix is assumed already in the ITK convention.
            kwargs: keyword arguments. Current `itk.GetImageFromArray` will read ``ttype`` from this dictionary.

        see also:

            - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L389
        """
        data_array = super().create_backend_obj(data_array)
        _is_vec = channel_dim is not None
        if _is_vec:
            data_array = np.moveaxis(data_array, -1,
                                     0)  # from channel last to channel first
        data_array = data_array.T.astype(dtype, copy=True, order="C")
        itk_obj = itk.GetImageFromArray(data_array,
                                        is_vector=_is_vec,
                                        ttype=kwargs.pop("ttype", None))

        d = len(itk.size(itk_obj))
        if affine is None:
            affine = np.eye(d + 1, dtype=np.float64)
        _affine = convert_data_type(affine, np.ndarray)[0]
        if affine_lps_to_ras:
            _affine = orientation_ras_lps(to_affine_nd(d, _affine))
        spacing = affine_to_spacing(_affine, r=d)
        _direction: np.ndarray = np.diag(1 / spacing)
        _direction = _affine[:d, :d] @ _direction
        itk_obj.SetSpacing(spacing.tolist())
        itk_obj.SetOrigin(_affine[:d, -1].tolist())
        itk_obj.SetDirection(itk.GetMatrixFromArray(_direction))
        return itk_obj
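
A rough numeric sketch of how spacing, direction and origin relate to the affine in the code above; here spacing is taken as the Euclidean norm of each column of the 3x3 rotation block (assumed to match what `affine_to_spacing` computes for well-formed affines), and the direction is the rotation block with spacing factored out:

import numpy as np

affine = np.diag([2.0, 3.0, 1.5, 1.0])                 # a toy affine
affine[:3, 3] = [10.0, -5.0, 0.0]
spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0))   # per-axis column norms
direction = affine[:3, :3] @ np.diag(1 / spacing)      # rotation with spacing factored out
origin = affine[:3, -1]
print(spacing, origin)                                  # [2.  3.  1.5] [10. -5.  0.]
print(direction)                                        # identity for this diagonal affine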
Example 15
    def __call__(self, data: NdarrayOrTensor):
        """
        Args:
            data: input data can be a PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
                Tensors, numpy arrays, floats, ints and bools are converted to the expected type, while strings
                and other objects keep their original type. For a dictionary, list or tuple, every item is
                converted to the expected type where applicable.

        """
        output_type = torch.Tensor if self.data_type == "tensor" else np.ndarray
        out, *_ = convert_data_type(data,
                                    output_type=output_type,
                                    dtype=self.dtype,
                                    device=self.device)
        return out
Example 16
    def __call__(self, img: NdarrayOrTensor):
        img_np, *_ = convert_data_type(img, np.ndarray)
        img_flat = img_np.flatten()
        try:
            out_flat = np.copy(img_flat).astype(self.dtype)
        except ValueError:
            # can't copy unchanged labels as the expected dtype is not supported, must map all the label values
            out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype)

        for o, t in zip(self.orig_labels, self.target_labels):
            if o == t:
                continue
            np.place(out_flat, img_flat == o, t)

        out = out_flat.reshape(img_np.shape)
        out, *_ = convert_to_dst_type(src=out, dst=img, dtype=self.dtype)
        return out
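
A toy run of the remapping logic above (hypothetical label values): comparisons are made against the original flattened values, so already-remapped entries are not remapped twice:

import numpy as np

img = np.array([[0, 1], [2, 1]])
orig_labels, target_labels = [1, 2], [10, 20]
img_flat = img.flatten()
out_flat = img_flat.astype(np.int64).copy()
for o, t in zip(orig_labels, target_labels):
    if o == t:
        continue
    np.place(out_flat, img_flat == o, t)      # match against the *original* values
print(out_flat.reshape(img.shape))            # [[ 0 10] [20 10]]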
Example 17
    def __call__(
        self,
        img: NdarrayOrTensor,
        argmax: Optional[bool] = None,
        to_onehot: Optional[int] = None,
        threshold: Optional[float] = None,
        rounding: Optional[str] = None
    ) -> NdarrayOrTensor:
        """
        Args:
            img: the input tensor data to convert; if there is no channel dimension when converting to
                `One-Hot`, one will be added automatically.
            argmax: whether to execute argmax function on input data before transform.
                Defaults to ``self.argmax``.
            to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
                Defaults to ``self.to_onehot``.
            threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
                Defaults to ``self.threshold``.
            rounding: if not None, round the data according to the specified option,
                available options: ["torchrounding"].
        """

        img_t: torch.Tensor
        img_t, *_ = convert_data_type(img, torch.Tensor)  # type: ignore
        if argmax or self.argmax:
            img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True))

        to_onehot = self.to_onehot if to_onehot is None else to_onehot
        if to_onehot is not None:
            if not isinstance(to_onehot, int):
                raise ValueError("the number of classes for One-Hot must be an integer.")
            img_t = one_hot(
                img_t, num_classes=to_onehot, dim=self.kwargs.get("dim", 0), dtype=self.kwargs.get("dtype", torch.float)
            )

        threshold = self.threshold if threshold is None else threshold
        if threshold is not None:
            img_t = img_t >= threshold

        rounding = self.rounding if rounding is None else rounding
        if rounding is not None:
            look_up_option(rounding, ["torchrounding"])
            img_t = torch.round(img_t)

        img, *_ = convert_to_dst_type(img_t, img, dtype=self.kwargs.get("dtype", torch.float))
        return img
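
An illustrative walk through the argmax/one-hot steps on a toy 2-class logit map, using `torch.nn.functional.one_hot` rather than MONAI's `one_hot` helper (values are made up):

import torch
import torch.nn.functional as F

logits = torch.tensor([[[0.2, 0.9]], [[0.8, 0.1]]])         # (C=2, H=1, W=2), channel-first
labels = torch.argmax(logits, dim=0, keepdim=True)          # (1, 1, 2), values 1 and 0
onehot = F.one_hot(labels[0], num_classes=2).permute(2, 0, 1).float()   # back to (2, 1, 2)
print(labels)
print(onehot)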
Example 18
    def append(self, *data) -> None:
        """
        Add samples to the local cumulative buffers.
        A buffer will be allocated for each `data` item.
        Compared with `self.extend`, this method adds a single sample (instead
        of a "batch") to the local buffers.

        Args:
            data: each item will be converted into a torch tensor.
                They will be stacked at the 0-th dim with a new dimension when `get_buffer()` is called.

        """
        if self._buffers is None:
            self._buffers = [[] for _ in data]
        for b, d in zip(self._buffers, data):
            # converting to pytorch tensors so that we can use the distributed API
            d_t, *_ = convert_data_type(d, output_type=torch.Tensor, wrap_sequence=True)
            b.append(d_t)
        self._synced = False
Example 19
    def aggregate(self):
        """
        Sync data from all the ranks and compute the average value with previous sum value.

        """
        data = self.get_buffer()

        # compute SUM across the batch dimension
        nans = isnan(data)
        not_nans = convert_data_type((~nans), dtype=torch.float32)[0].sum(0)
        data[nans] = 0
        f = data.sum(0)

        # clear the buffer for next update
        super().reset()
        self.sum = f if self.sum is None else (self.sum + f)
        self.not_nans = not_nans if self.not_nans is None else (self.not_nans + not_nans)

        return self.sum / self.not_nans
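
A minimal sketch of the NaN-aware reduction performed above: NaNs are zeroed, a per-element count of valid entries is kept, and `sum / not_nans` yields the mean:

import torch

data = torch.tensor([[1.0, float("nan")], [3.0, 4.0]])    # a toy buffer of 2 samples
nans = torch.isnan(data)
not_nans = (~nans).float().sum(0)                          # valid entries per element: [2., 1.]
data = torch.where(nans, torch.zeros_like(data), data)     # zero out the NaNs
print(data.sum(0) / not_nans)                              # tensor([2., 4.])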
Example 20
    def __call__(
        self,
        img: NdarrayOrTensor,
        dtype: Optional[Union[DtypeLike,
                              torch.dtype]] = None) -> NdarrayOrTensor:
        """
        Apply the transform to `img`, assuming `img` is a numpy array or PyTorch Tensor.

        Args:
            dtype: convert image to this data type, default is `self.dtype`.

        Raises:
            TypeError: When ``img`` type is not in ``Union[numpy.ndarray, torch.Tensor]``.

        """
        img_out, *_ = convert_data_type(img,
                                        output_type=type(img),
                                        dtype=dtype or self.dtype)
        return img_out
Example 21
    def __call__(
        self,
        img: NdarrayOrTensor,
        sigmoid: Optional[bool] = None,
        softmax: Optional[bool] = None,
        other: Optional[Callable] = None,
    ) -> NdarrayOrTensor:
        """
        Args:
            sigmoid: whether to execute sigmoid function on model output before transform.
                Defaults to ``self.sigmoid``.
            softmax: whether to execute softmax function on model output before transform.
                Defaults to ``self.softmax``.
            other: callable function to execute other activation layers, for example:
                `other = torch.tanh`. Defaults to ``self.other``.

        Raises:
            ValueError: When ``sigmoid=True`` and ``softmax=True``. Incompatible values.
            TypeError: When ``other`` is not an ``Optional[Callable]``.
            ValueError: When ``self.other=None`` and ``other=None``. Incompatible values.

        """
        if sigmoid and softmax:
            raise ValueError(
                "Incompatible values: sigmoid=True and softmax=True.")
        if other is not None and not callable(other):
            raise TypeError(
                f"other must be None or callable but is {type(other).__name__}."
            )

        # convert to float as activation must operate on float tensor
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)
        if sigmoid or self.sigmoid:
            img_t = torch.sigmoid(img_t)
        if softmax or self.softmax:
            img_t = torch.softmax(img_t, dim=0)

        act_func = self.other if other is None else other
        if act_func is not None:
            img_t = act_func(img_t)
        out, *_ = convert_to_dst_type(img_t, img)
        return out
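
A short sketch of the activation step on channel-first logits: softmax over the channel dimension (dim 0), as applied above when `softmax=True` (toy values):

import torch

logits = torch.tensor([[[2.0, -1.0]], [[0.0, 1.0]]])    # (C=2, H=1, W=2) channel-first logits
probs = torch.softmax(logits, dim=0)                     # softmax over the channel dim
print(probs)
print(probs.sum(dim=0))                                  # each spatial location sums to 1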
Example 22
    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Fill the holes in the provided image.

        Note:
            The value 0 is assumed as background label.

        Args:
            img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].

        Raises:
            NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.

        Returns:
            Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_np, *_ = convert_data_type(img, np.ndarray)
        out_np: np.ndarray = fill_holes(img_np, self.applied_labels,
                                        self.connectivity)
        out, *_ = convert_to_dst_type(out_np, img)
        return out
Example 23
    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Fill the holes in the provided image.

        Note:
            The value 0 is assumed as background label.

        Args:
            img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].

        Raises:
            NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.

        Returns:
            Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
        """
        if not isinstance(img, (np.ndarray, torch.Tensor)):
            raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.")
        img_np: np.ndarray
        img_np, *_ = convert_data_type(img, np.ndarray)  # type: ignore
        out_np: np.ndarray = fill_holes(img_np, self.applied_labels, self.connectivity)
        out, *_ = convert_to_dst_type(out_np, img)
        return out
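
As an illustrative stand-in (not MONAI's `fill_holes`, which additionally supports `applied_labels` and `connectivity`), `scipy.ndimage.binary_fill_holes` shows the basic effect on a binary mask with an enclosed background pixel:

import numpy as np
from scipy.ndimage import binary_fill_holes

mask = np.zeros((5, 5), dtype=bool)
mask[1:4, 1:4] = True
mask[2, 2] = False                        # a one-pixel hole inside the square
filled = binary_fill_holes(mask)
print(filled[2, 2], filled.sum())         # True 9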
Example 24
def compute_surface_dice(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    class_thresholds: List[float],
    include_background: bool = False,
    distance_metric: str = "euclidean",
):
    r"""
    This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as
    :math:`\hat{Y}`) and `y` (referred to as :math:`Y`). This metric determines which fraction of a segmentation
    boundary is correctly predicted. A boundary element is considered correctly predicted if the closest distance to the
    reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation in
    pixels. The NSD is bounded between 0 and 1.

    This implementation supports multi-class tasks with an individual threshold :math:`\tau_c` for each class :math:`c`.
    The class-specific NSD for batch index :math:`b`, :math:`\operatorname {NSD}_{b,c}`, is computed using the function:

    .. math::
        \operatorname {NSD}_{b,c} \left(Y_{b,c}, \hat{Y}_{b,c}\right) = \frac{\left|\mathcal{D}_{Y_{b,c}}^{'}\right| +
        \left| \mathcal{D}_{\hat{Y}_{b,c}}^{'} \right|}{\left|\mathcal{D}_{Y_{b,c}}\right| +
        \left|\mathcal{D}_{\hat{Y}_{b,c}}\right|}
        :label: nsd

    with :math:`\mathcal{D}_{Y_{b,c}}` and :math:`\mathcal{D}_{\hat{Y}_{b,c}}` being two sets of nearest-neighbor
    distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference segmentation
    boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and
    :math:`\mathcal{D}_{\hat{Y}_{b,c}}^{'}` refer to the subsets of distances that are smaller or equal to the
    acceptable distance :math:`\tau_c`:

    .. math::
        \mathcal{D}_{Y_{b,c}}^{'} = \{ d \in \mathcal{D}_{Y_{b,c}} \, | \, d \leq \tau_c \}.


    In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation, a nan value
    will be returned for this class. In the case of a class being present in only one of predicted segmentation or
    reference segmentation, the class NSD will be 0.

    This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D images.
    Be aware that the computation of boundaries is different from DeepMind's implementation
    https://github.com/deepmind/surface-distance. In this implementation, the length of a segmentation boundary is
    interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
    depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).

    Args:
        y_pred: Predicted segmentation, typically segmentation model output.
            It must be a one-hot encoded, batch-first tensor [B,C,H,W].
        y: Reference segmentation.
            It must be a one-hot encoded, batch-first tensor [B,C,H,W].
        class_thresholds: List of class-specific thresholds.
            The thresholds relate to the acceptable amount of deviation in the segmentation boundary in pixels.
            Each threshold needs to be a finite, non-negative number.
        include_background: Whether to skip the surface dice computation on the first channel of
            the predicted output. Defaults to ``False``.
        distance_metric: The metric used to compute surface distances.
            One of [``"euclidean"``, ``"chessboard"``, ``"taxicab"``].
            Defaults to ``"euclidean"``.

    Raises:
        ValueError: If `y_pred` and/or `y` are not PyTorch tensors.
        ValueError: If `y_pred` and/or `y` do not have four dimensions.
        ValueError: If `y_pred` and/or `y` have different shapes.
        ValueError: If `y_pred` and/or `y` are not one-hot encoded
        ValueError: If the number of channels of `y_pred` and/or `y` is different from the number of class thresholds.
        ValueError: If any class threshold is not finite.
        ValueError: If any class threshold is negative.

    Returns:
        Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch index
        :math:`b` and class :math:`c`.
    """

    if not include_background:
        y_pred, y = ignore_background(y_pred=y_pred, y=y)

    if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
        raise ValueError("y_pred and y must be PyTorch Tensor.")

    if y_pred.ndimension() != 4 or y.ndimension() != 4:
        raise ValueError(
            "y_pred and y should have four dimensions: [B,C,H,W].")

    if y_pred.shape != y.shape:
        raise ValueError(
            f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)."
        )

    if not torch.all(y_pred.byte() == y_pred) or not torch.all(y.byte() == y):
        raise ValueError(
            "y_pred and y should be binarized tensors (e.g. torch.int64).")
    if torch.any(y_pred > 1) or torch.any(y > 1):
        raise ValueError("y_pred and y should be one-hot encoded.")

    y = y.float()
    y_pred = y_pred.float()

    batch_size, n_class = y_pred.shape[:2]

    if n_class != len(class_thresholds):
        raise ValueError(
            f"number of classes ({n_class}) does not match number of class thresholds ({len(class_thresholds)})."
        )

    if any(~np.isfinite(class_thresholds)):
        raise ValueError("All class thresholds need to be finite.")

    if any(np.array(class_thresholds) < 0):
        raise ValueError("All class thresholds need to be >= 0.")

    nsd = np.empty((batch_size, n_class))

    for b, c in np.ndindex(batch_size, n_class):
        (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c],
                                                y[b, c],
                                                crop=False)
        if not np.any(edges_gt):
            warnings.warn(
                f"the ground truth of class {c} is all 0, this may result in nan/inf distance."
            )
        if not np.any(edges_pred):
            warnings.warn(
                f"the prediction of class {c} is all 0, this may result in nan/inf distance."
            )

        distances_pred_gt = get_surface_distance(
            edges_pred, edges_gt, distance_metric=distance_metric)
        distances_gt_pred = get_surface_distance(
            edges_gt, edges_pred, distance_metric=distance_metric)

        boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)
        boundary_correct = np.sum(
            distances_pred_gt <= class_thresholds[c]) + np.sum(
                distances_gt_pred <= class_thresholds[c])

        if boundary_complete == 0:
            # the class is neither present in the prediction, nor in the reference segmentation
            nsd[b, c] = np.nan
        else:
            nsd[b, c] = boundary_correct / boundary_complete

    return convert_data_type(nsd, torch.Tensor)[0]
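
A worked instance of the final ratio (hypothetical distances, threshold tau = 1): three of the four boundary distances are within tolerance, so the NSD for this class is 3/4:

import numpy as np

distances_pred_gt = np.array([0.0, 2.0])    # predicted boundary -> reference boundary
distances_gt_pred = np.array([1.0, 0.0])    # reference boundary -> predicted boundary
tau = 1.0                                   # class threshold
boundary_correct = (distances_pred_gt <= tau).sum() + (distances_gt_pred <= tau).sum()
boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)
print(boundary_correct / boundary_complete)   # 0.75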
Example 25
    def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
        img_np, *_ = convert_data_type(image, np.ndarray)

        # add random offset
        self.randomize(img_size=img_np.shape)

        if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0):
            img_np = img_np[:, self.offset[0]:, self.offset[1]:]

        # pad to full size, divisible by tile_size
        if self.pad_full:
            c, h, w = img_np.shape
            pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
            pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
            img_np = np.pad(  # type: ignore
                img_np,
                [[0, 0], [pad_h // 2, pad_h - pad_h // 2],
                 [pad_w // 2, pad_w - pad_w // 2]],
                constant_values=self.background_val,
            )

        # extract tiles
        x_step, y_step = self.step, self.step
        h_tile, w_tile = self.tile_size, self.tile_size
        c_image, h_image, w_image = img_np.shape
        c_stride, x_stride, y_stride = img_np.strides
        llw = as_strided(
            img_np,
            shape=((h_image - h_tile) // x_step + 1,
                   (w_image - w_tile) // y_step + 1, c_image, h_tile, w_tile),
            strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride,
                     y_stride),
            writeable=False,
        )
        img_np = llw.reshape(-1, c_image, h_tile, w_tile)  # type: ignore

        # if keeping all patches
        if self.tile_count is None:
            # retain only patches with significant foreground content to speed up inference
            # FYI, this returns a variable number of tiles, so the batch_size must be 1 (per gpu), e.g during inference
            thresh = 0.999 * 3 * self.background_val * self.tile_size * self.tile_size
            if self.filter_mode == "min":
                # default, keep non-background tiles (small values)
                idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) < thresh)
                img_np = img_np[idxs.reshape(-1)]
            elif self.filter_mode == "max":
                idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) >= thresh)
                img_np = img_np[idxs.reshape(-1)]

        else:
            if len(img_np) > self.tile_count:

                if self.filter_mode == "min":
                    # default, keep non-background tiles (smallest values)
                    idxs = np.argsort(img_np.sum(axis=(1, 2,
                                                       3)))[:self.tile_count]
                    img_np = img_np[idxs]
                elif self.filter_mode == "max":
                    idxs = np.argsort(img_np.sum(axis=(1, 2,
                                                       3)))[-self.tile_count:]
                    img_np = img_np[idxs]
                else:
                    # random subset (more appropriate for WSIs without distinct background)
                    if self.random_idxs is not None:
                        img_np = img_np[self.random_idxs]

            elif len(img_np) < self.tile_count:
                img_np = np.pad(  # type: ignore
                    img_np,
                    [[0, self.tile_count - len(img_np)], [0, 0], [0, 0],
                     [0, 0]],
                    constant_values=self.background_val,
                )

        image, *_ = convert_to_dst_type(src=img_np,
                                        dst=image,
                                        dtype=image.dtype)

        return image
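
For the special case step == tile_size (non-overlapping tiles), the `as_strided` view above should be equivalent to a plain reshape/transpose; a simplified sketch with toy sizes:

import numpy as np

c, h, w, tile = 3, 8, 8, 4
img = np.arange(c * h * w).reshape(c, h, w)
tiles = (img.reshape(c, h // tile, tile, w // tile, tile)
            .transpose(1, 3, 0, 2, 4)        # (num_h, num_w, C, tile, tile)
            .reshape(-1, c, tile, tile))
print(tiles.shape)                           # (4, 3, 4, 4)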
Example 26
def get_mask_edges(seg_pred,
                   seg_gt,
                   label_idx: int = 1,
                   crop: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """
    Do binary erosion and use XOR for input to get the edges. This
    function is helpful to further calculate metrics such as Average Surface
    Distance and Hausdorff Distance.
    The input images can be binary or labelfield images. If labelfield images
    are supplied, they are converted to binary images using `label_idx`.

    `scipy`'s binary erosion is used to calculate the edges of the binary
    labelfield.

    To improve computational efficiency, the images can be cropped to keep
    only the foreground before the edges are extracted, unless ``crop=False``
    is specified.

    We require that images are the same size, and assume that they occupy the
    same space (spacing, orientation, etc.).

    Args:
        seg_pred: the predicted binary or labelfield image.
        seg_gt: the actual binary or labelfield image.
        label_idx: for labelfield images, convert to binary with
            `seg_pred = seg_pred == label_idx`.
        crop: crop input images and only keep the foregrounds. In order to
            maintain two inputs' shapes, here the bounding box is achieved
            by ``(seg_pred | seg_gt)`` which represents the union set of two
            images. Defaults to ``True``.
    """

    # Get both labelfields as np arrays
    if isinstance(seg_pred, torch.Tensor):
        seg_pred = seg_pred.detach().cpu().numpy()
    if isinstance(seg_gt, torch.Tensor):
        seg_gt = seg_gt.detach().cpu().numpy()

    if seg_pred.shape != seg_gt.shape:
        raise ValueError(
            f"seg_pred and seg_gt should have same shapes, got {seg_pred.shape} and {seg_gt.shape}."
        )

    # If not binary images, convert them
    if seg_pred.dtype != bool:
        seg_pred = seg_pred == label_idx
    if seg_gt.dtype != bool:
        seg_gt = seg_gt == label_idx

    if crop:
        if not np.any(seg_pred | seg_gt):
            return np.zeros_like(seg_pred), np.zeros_like(seg_gt)

        channel_dim = 0
        seg_pred, seg_gt = np.expand_dims(
            seg_pred, axis=channel_dim), np.expand_dims(seg_gt,
                                                        axis=channel_dim)
        box_start, box_end = generate_spatial_bounding_box(
            np.asarray(seg_pred | seg_gt))
        cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
        seg_pred = convert_data_type(
            np.squeeze(cropper(seg_pred), axis=channel_dim), np.ndarray)[0]
        seg_gt = convert_data_type(
            np.squeeze(cropper(seg_gt), axis=channel_dim), np.ndarray)[0]

    # Do binary erosion and use XOR to get edges
    edges_pred = binary_erosion(seg_pred) ^ seg_pred
    edges_gt = binary_erosion(seg_gt) ^ seg_gt

    return edges_pred, edges_gt
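
The core edge extraction is simply "erode, then XOR with the original"; a tiny standalone sketch using `scipy.ndimage.binary_erosion` on a toy mask:

import numpy as np
from scipy.ndimage import binary_erosion

mask = np.zeros((5, 5), dtype=bool)
mask[1:4, 1:4] = True
edges = binary_erosion(mask) ^ mask
print(edges.astype(int))      # a ring of ones around the eroded centre pixel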
Example 27
    def __call__(
            self,
            img: NdarrayOrTensor,
            argmax: Optional[bool] = None,
            to_onehot: Optional[int] = None,
            threshold: Optional[float] = None,
            rounding: Optional[str] = None,
            n_classes: Optional[int] = None,  # deprecated
            num_classes: Optional[int] = None,  # deprecated
            logit_thresh: Optional[float] = None,  # deprecated
            threshold_values: Optional[bool] = None,  # deprecated
    ) -> NdarrayOrTensor:
        """
        Args:
            img: the input tensor data to convert; if there is no channel dimension when converting to
                `One-Hot`, one will be added automatically.
            argmax: whether to execute argmax function on input data before transform.
                Defaults to ``self.argmax``.
            to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
                Defaults to ``self.to_onehot``.
            threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
                Defaults to ``self.threshold``.
            rounding: if not None, round the data according to the specified option,
                available options: ["torchrounding"].

        .. deprecated:: 0.6.0
            ``n_classes`` is deprecated, use ``to_onehot`` instead.

        .. deprecated:: 0.7.0
            ``num_classes`` is deprecated, use ``to_onehot`` instead.
            ``logit_thresh`` is deprecated, use ``threshold`` instead.
            ``threshold_values`` is deprecated, use ``threshold`` instead.

        """
        if isinstance(to_onehot, bool):
            warnings.warn(
                "`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead."
            )
            to_onehot = num_classes if to_onehot else None
        if isinstance(threshold, bool):
            warnings.warn(
                "`threshold_values=True/False` is deprecated, please use `threshold=value` instead."
            )
            threshold = logit_thresh if threshold else None
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_t, *_ = convert_data_type(img, torch.Tensor)
        if argmax or self.argmax:
            img_t = torch.argmax(img_t, dim=0, keepdim=True)

        to_onehot = self.to_onehot if to_onehot is None else to_onehot
        if to_onehot is not None:
            if not isinstance(to_onehot, int):
                raise AssertionError(
                    "the number of classes for One-Hot must be an integer.")
            img_t = one_hot(img_t, num_classes=to_onehot, dim=0)

        threshold = self.threshold if threshold is None else threshold
        if threshold is not None:
            img_t = img_t >= threshold

        rounding = self.rounding if rounding is None else rounding
        if rounding is not None:
            look_up_option(rounding, ["torchrounding"])
            img_t = torch.round(img_t)

        img, *_ = convert_to_dst_type(img_t, img, dtype=torch.float)
        return img
Example 28
    def resample_if_needed(
        cls,
        data_array: NdarrayOrTensor,
        affine: Optional[NdarrayOrTensor] = None,
        target_affine: Optional[NdarrayOrTensor] = None,
        output_spatial_shape: Union[Sequence[int], int, None] = None,
        mode: str = GridSampleMode.BILINEAR,
        padding_mode: str = GridSamplePadMode.BORDER,
        align_corners: bool = False,
        dtype: DtypeLike = np.float64,
    ):
        """
        Convert the ``data_array`` into the coordinate system specified by
        ``target_affine``, from the current coordinate definition of ``affine``.

        If the transform between ``affine`` and ``target_affine`` could be
        achieved by simply transposing and flipping ``data_array``, no resampling
        will happen.  Otherwise, this function resamples ``data_array`` using the
        transformation computed from ``affine`` and ``target_affine``.

        This function assumes the NIfTI dimension notations. Spatially it
        supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D
        respectively. When saving multiple time steps or multiple channels,
        time and/or modality axes should be appended after the first three
        dimensions. For example, shape of 2D eight-class segmentation
        probabilities to be saved could be `(64, 64, 1, 8)`. Also, data in
        shape `(64, 64, 8)` or `(64, 64, 8, 1)` will be considered as a
        single-channel 3D image. The ``convert_to_channel_last`` method can be
        used to convert the data to the format described here.

        Note that the shape of the resampled ``data_array`` may be subject to some
        rounding errors. For example, resampling a 20x20 pixel image from pixel
        size (1.5, 1.5)-mm to (3.0, 3.0)-mm space will return a 10x10-pixel
        image. However, resampling a 20x20-pixel image from pixel size (2.0,
        2.0)-mm to (3.0, 3.0)-mm space will output a 14x14-pixel image, where
        the image shape is rounded from 13.333x13.333 pixels. In this case
        ``output_spatial_shape`` could be specified so that this function
        writes image data to a designated shape.

        Args:
            data_array: input data array to be converted.
            affine: the current affine of ``data_array``. Defaults to identity
            target_affine: the designated affine of ``data_array``.
                The actual output affine might be different from this value due to precision changes.
            output_spatial_shape: spatial shape of the output image.
                This option is used when resampling is needed.
            mode: available options are {``"bilinear"``, ``"nearest"``, ``"bicubic"``}.
                This option is used when resampling is needed.
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            padding_mode: available options are {``"zeros"``, ``"border"``, ``"reflection"``}.
                This option is used when resampling is needed.
                Padding mode for outside grid values. Defaults to ``"border"``.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            align_corners: boolean option of ``grid_sample`` to handle the corner convention.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            dtype: data type for resampling computation. Defaults to
                ``np.float64`` for best precision. If ``None``, use the data type of input data.
                The output data type of this method is always ``np.float32``.
        """
        orig_type = type(data_array)
        data_array = convert_to_tensor(data_array, track_meta=True)
        if affine is not None:
            data_array.affine = convert_to_tensor(
                affine, track_meta=False)  # type: ignore
        resampler = SpatialResample(mode=mode,
                                    padding_mode=padding_mode,
                                    align_corners=align_corners,
                                    dtype=dtype)
        output_array = resampler(data_array[None],
                                 dst_affine=target_affine,
                                 spatial_size=output_spatial_shape)
        # convert back at the end
        if isinstance(output_array, MetaTensor):
            output_array.applied_operations = []
        data_array, *_ = convert_data_type(
            output_array, output_type=orig_type)  # type: ignore
        affine, *_ = convert_data_type(output_array.affine,
                                       output_type=orig_type)  # type: ignore
        return data_array[0], affine
Example 29
def sliding_window_inference(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor],
                                   Dict[Any, torch.Tensor]]],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    progress: bool = False,
    roi_weight_map: Union[torch.Tensor, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
    """
    Sliding window inference on `inputs` with `predictor`.

    The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
    Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
    e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
    could be ([128,64,256], [64,32,128]).
    In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
    an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
    so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).

    When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D])
        roi_size: the spatial window size for inferences.
            When its components are None or non-positive, the corresponding dimensions of the input
            image will be used instead. For example, `roi_size=(32, -1)` will be adapted to `(32, 64)`
            if the second spatial dimension size of the image is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor ``patch_data`` in shape NCHW[D],
            The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
            with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
            where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
            N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
            the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
            to ensure the scaled output ROI sizes are still integers.
            If the `predictor`'s input and output spatial sizes are different,
            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
        overlap: Amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant``": gives equal weight to all predictions.
            - ``"gaussian``": gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If, for example,
            it is set to ``torch.device('cpu')``, GPU memory consumption is reduced and becomes independent
            of `inputs` and `roi_size`. The output is stored on `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI.
            If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

    Note:
        - the input must be channel-first and have a batch dimension; N-D sliding windows are supported.

    """
    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    batch_size, _, *image_size_ = inputs.shape

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(
        max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
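    # note: F.pad expects the padding amounts ordered from the last spatial dim backwards,
    # which is why the loop above walks the axes in reverse when building the (before, after) pairs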
    inputs = F.pad(inputs,
                   pad=pad_size,
                   mode=look_up_option(padding_mode, PytorchPadMode),
                   value=cval)

    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims,
                                       overlap)
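    # scan_interval is the stride between window start positions along each spatial axis,
    # roughly roi_size * (1 - overlap) (see the `_get_scan_interval` helper, not shown in this snippet)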

    # Store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows

    # Create window-level importance map
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
    else:
        try:
            importance_map = compute_importance_map(valid_patch_size,
                                                    mode=mode,
                                                    sigma_scale=sigma_scale,
                                                    device=device)
        except BaseException as e:
            raise RuntimeError(
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map = convert_data_type(importance_map, torch.Tensor, device,
                                       compute_dtype)[0]  # type: ignore
    # handle non-positive weights
    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
    importance_map = torch.clamp(importance_map.to(torch.float32),
                                 min=min_non_zero).to(compute_dtype)
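    # the clamp also keeps every weight strictly positive, which helps ensure that the final
    # division by the accumulated count_map does not divide by zero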

    # Perform predictions
    dict_key, output_image_list, count_map_list = None, [], []
    _initialized_ss = -1
    is_tensor_output = True  # whether the predictor's output is a tensor (instead of dict/tuple)

    # for each patch
    for slice_g in (tqdm(range(0, total_slices, sw_batch_size)) if progress
                    else range(0, total_slices, sw_batch_size)):
        slice_range = range(slice_g, min(slice_g + sw_batch_size,
                                         total_slices))
        unravel_slice = [
            [slice(int(idx / num_win),
                   int(idx / num_win) + 1),
             slice(None)] + list(slices[idx % num_win]) for idx in slice_range
        ]
        window_data = torch.cat([
            convert_data_type(inputs[win_slice], torch.Tensor)[0]
            for win_slice in unravel_slice
        ]).to(sw_device)
        seg_prob_out = predictor(window_data, *args,
                                 **kwargs)  # batched patch segmentation

        # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out, )
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(
                    seg_prob_out.keys())  # track predictor's output keys
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
        else:
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False

        # for each output in multi-output list
        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN

            # compute zoom scale: out_roi_size/in_roi_size
            zoom_scale = []
            for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
                    zip(image_size, seg_prob.shape[2:],
                        window_data.shape[2:])):
                _scale = out_w_i / float(in_w_i)
                if not (img_s_i * _scale).is_integer():
                    warnings.warn(
                        f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
                        f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
                    )
                zoom_scale.append(_scale)

            if _initialized_ss < ss:  # init. the ss-th buffer at the first iteration
                # construct multi-resolution outputs
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d)
                    for image_size_d, zoom_scale_d in zip(
                        image_size, zoom_scale)
                ]
                # allocate memory to store the full output and the count for overlapping parts
                output_image_list.append(
                    torch.zeros(output_shape,
                                dtype=compute_dtype,
                                device=device))
                count_map_list.append(
                    torch.zeros([1, 1] + output_shape[2:],
                                dtype=compute_dtype,
                                device=device))
                _initialized_ss += 1

            # resizing the importance_map
            resizer = Resize(spatial_size=seg_prob.shape[2:],
                             mode="nearest",
                             anti_aliasing=False)

            # store the result in the proper location of the full output. Apply weights from importance map.
            for idx, original_idx in zip(slice_range, unravel_slice):
                # zoom roi
                original_idx_zoom = list(
                    original_idx)  # 4D for 2D image, 5D for 3D image
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis -
                                                                         2]
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    if not zoomed_start.is_integer() or (
                            not zoomed_end.is_integer()):
                        warnings.warn(
                            f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
                            f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
                            f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
                            f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
                            f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
                            "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
                        )
                    original_idx_zoom[axis] = slice(int(zoomed_start),
                                                    int(zoomed_end), None)
                importance_map_zoom = resizer(
                    importance_map.unsqueeze(0))[0].to(compute_dtype)
                # store results and weights
                output_image_list[ss][
                    original_idx_zoom] += importance_map_zoom * seg_prob[
                        idx - slice_g]
                count_map_list[ss][original_idx_zoom] += (
                    importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(
                        count_map_list[ss][original_idx_zoom].shape))

    # account for any overlapping sections
    for ss in range(len(output_image_list)):
        output_image_list[ss] = (output_image_list[ss] /
                                 count_map_list.pop(0)).to(compute_dtype)

    # remove padding if image_size smaller than roi_size
    for ss, output_i in enumerate(output_image_list):
        if torch.isnan(output_i).any() or torch.isinf(output_i).any():
            warnings.warn(
                "Sliding window inference results contain NaN or Inf.")

        zoom_scale = [
            seg_prob_map_shape_d / roi_size_d
            for seg_prob_map_shape_d, roi_size_d in zip(
                output_i.shape[2:], roi_size)
        ]

        final_slicing: List[slice] = []
        for sp in range(num_spatial_dims):
            slice_dim = slice(
                pad_size[sp * 2],
                image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
            slice_dim = slice(
                int(
                    round(slice_dim.start *
                          zoom_scale[num_spatial_dims - sp - 1])),
                int(
                    round(slice_dim.stop *
                          zoom_scale[num_spatial_dims - sp - 1])),
            )
            final_slicing.insert(0, slice_dim)
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[final_slicing]

    if dict_key is not None:  # if output of predictor is a dict
        final_output = dict(zip(dict_key, output_image_list))
    else:
        final_output = tuple(output_image_list)  # type: ignore
    final_output = final_output[
        0] if is_tensor_output else final_output  # type: ignore
    if isinstance(inputs, MetaTensor):
        final_output = convert_to_dst_type(final_output,
                                           inputs)[0]  # type: ignore
    return final_output
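
The code above follows the signature of ``monai.inferers.sliding_window_inference``; a minimal usage
sketch (assuming that function) is shown below. The toy ``net``, tensor sizes and parameter values are
illustrative assumptions, not taken from the snippet above:

import torch
from monai.inferers import sliding_window_inference

# toy "predictor" that keeps the spatial size and maps 1 input channel to 2 output channels
net = torch.nn.Conv3d(1, 2, kernel_size=3, padding=1)
image = torch.rand(1, 1, 197, 161, 183)  # N=1, C=1, arbitrary 3D spatial size

with torch.no_grad():
    seg = sliding_window_inference(
        inputs=image,
        roi_size=(96, 96, 96),   # window size; a -1 component would fall back to the image size
        sw_batch_size=4,         # number of windows sent through the predictor at once
        predictor=net,
        overlap=0.25,            # 25% overlap between adjacent windows
        mode="gaussian",         # blend overlapping predictions with Gaussian weights
    )
print(seg.shape)  # torch.Size([1, 2, 197, 161, 183]): output cropped back to the input spatial size
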
Esempio n. 30
0
def compute_average_surface_distance(
    y_pred: Union[np.ndarray, torch.Tensor],
    y: Union[np.ndarray, torch.Tensor],
    include_background: bool = False,
    symmetric: bool = False,
    distance_metric: str = "euclidean",
):
    """
    By default, this function computes the Average Surface Distance from `y_pred` to `y`.
    If ``symmetric = True`` is set, the average symmetric surface distance between
    these two inputs will be returned instead.
    The implementation refers to `DeepMind's implementation <https://github.com/deepmind/surface-distance>`_.
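    Concretely, the one-directional distance is the mean, over all surface voxels of `y_pred`, of the
    distance to the closest surface voxel of `y`; with ``symmetric = True`` the surface distances
    computed in both directions are pooled before averaging.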

    Args:
        y_pred: input data to compute, typical segmentation model output.
            It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
            should be binarized.
        y: ground truth used to compute the mean distance. It must be in one-hot format and the first dim is batch.
            The values should be binarized.
        include_background: whether to include the distance computation for the first (background) channel
            of the prediction. Defaults to ``False``, i.e. the background channel is skipped.
        symmetric: whether to calculate the symmetric average surface distance between
            `y_pred` and `y`. Defaults to ``False``.
        distance_metric: [``"euclidean"``, ``"chessboard"``, ``"taxicab"``]
            the metric used to compute surface distance. Defaults to ``"euclidean"``.
    """

    if not include_background:
        y_pred, y = ignore_background(y_pred=y_pred, y=y)

    if isinstance(y, torch.Tensor):
        y = y.float()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.float()

    if y.shape != y_pred.shape:
        raise ValueError(
            f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}."
        )

    batch_size, n_class = y_pred.shape[:2]
    asd = np.empty((batch_size, n_class))

    for b, c in np.ndindex(batch_size, n_class):
        (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c])
        if not np.any(edges_gt):
            warnings.warn(
                f"the ground truth of class {c} is all 0, this may result in nan/inf distance."
            )
        if not np.any(edges_pred):
            warnings.warn(
                f"the prediction of class {c} is all 0, this may result in nan/inf distance."
            )
        surface_distance = get_surface_distance(
            edges_pred, edges_gt, distance_metric=distance_metric)
        if symmetric:
            surface_distance_2 = get_surface_distance(
                edges_gt, edges_pred, distance_metric=distance_metric)
            surface_distance = np.concatenate(
                [surface_distance, surface_distance_2])
        asd[b, c] = np.nan if surface_distance.shape == (
            0, ) else surface_distance.mean()

    return convert_data_type(asd, torch.Tensor)[0]
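
A minimal usage sketch, assuming the function above is ``monai.metrics.compute_average_surface_distance``;
the toy one-hot masks below are made up purely for illustration:

import torch
from monai.metrics import compute_average_surface_distance

# batch of 1, two channels (background + foreground), 32x32 binarized one-hot masks
y_pred = torch.zeros(1, 2, 32, 32)
y = torch.zeros(1, 2, 32, 32)
y_pred[0, 1, 8:24, 8:24] = 1      # predicted foreground square
y[0, 1, 10:26, 10:26] = 1         # ground-truth square, shifted by 2 pixels
y_pred[0, 0] = 1 - y_pred[0, 1]   # background channel
y[0, 0] = 1 - y[0, 1]

asd = compute_average_surface_distance(y_pred, y, include_background=False, symmetric=True)
print(asd.shape)  # torch.Size([1, 1]): one value per (batch, foreground class)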