Exemplo n.º 1
    def __init__(
        self,
        spatial_size: Optional[Union[Sequence[int], int]] = None,
        normalized: bool = False,
        mode: str = GridSampleMode.BILINEAR,
        padding_mode: str = GridSamplePadMode.ZEROS,
        align_corners: bool = False,
        reverse_indexing: bool = True,
        zero_centered: Optional[bool] = None,
    ) -> None:
        """
        Apply affine transformations with a batch of affine matrices.

        When `normalized=False` and `reverse_indexing=True`,
        it does the commonly used resampling in the 'pull' direction
        following the ``scipy.ndimage.affine_transform`` convention.
        In this case `theta` is equivalent to (ndim+1, ndim+1) input ``matrix`` of ``scipy.ndimage.affine_transform``,
        operates on homogeneous coordinates.
        See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html

        When `normalized=True` and `reverse_indexing=False`,
        it applies `theta` to the normalized coordinates (coords. in the range of [-1, 1]) directly.
        This is often used with `align_corners=False` to achieve resolution-agnostic resampling,
        thus useful as a part of trainable modules such as the spatial transformer networks.
        See also: https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

        Args:
            spatial_size: output spatial shape, the full output shape will be
                `[N, C, *spatial_size]` where N and C are inferred from the `src` input of `self.forward`.
            normalized: indicating whether the provided affine matrix `theta` is defined
                for the normalized coordinates. If `normalized=False`, `theta` will be converted
                to operate on normalized coordinates as pytorch affine_grid works with the normalized
                coordinates.
            mode: {``"bilinear"``, ``"nearest"``}
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"zeros"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            align_corners: see also https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html.
            reverse_indexing: whether to reverse the spatial indexing of image and coordinates.
                set to `False` if `theta` follows pytorch's default "D, H, W" convention.
                set to `True` if `theta` follows `scipy.ndimage` default "i, j, k" convention.
            zero_centered: whether the affine is applied to coordinates in a zero-centered value range.
                With `zero_centered=True`, for example, the center of rotation will be the
                spatial center of the input; with `zero_centered=False`, the center of rotation will be the
                origin of the input. This option is only available when `normalized=False`,
                where the default behaviour is `False` if unspecified.
                See also: :py:func:`monai.networks.utils.normalize_transform`.

        Raises:
            ValueError: When ``zero_centered`` is specified (not None) together with ``normalized=True``.
        """
        # NOTE(review): the enclosing class is not visible from here; if its base class
        # needs initialising, `super().__init__()` has to run before these assignments.
        self.spatial_size = ensure_tuple(spatial_size) if spatial_size is not None else None
        self.normalized = normalized
        # `look_up_option` validates the option against the enum of supported values.
        self.mode: str = look_up_option(mode, GridSampleMode)
        self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode)
        self.align_corners = align_corners
        self.reverse_indexing = reverse_indexing
        # An explicit `zero_centered` (even `False`) is rejected for normalized affines.
        if zero_centered is not None and self.normalized:
            raise ValueError("`normalized=True` is not compatible with the `zero_centered` option.")
        self.zero_centered = zero_centered if zero_centered is not None else False
Exemplo n.º 2
def convert_mask_to_box(
        boxes_mask: NdarrayOrTensor,
        bg_label: int = -1,
        label_dtype=torch.long) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]:
    Convert int16 mask image to box, which has the same size as the input image

        boxes_mask: int16 array, sized (num_box, H, W). Each channel represents a box.
            The foreground region in channel c has intensity of labels[c].
            The background intensity is bg_label.
        bg_label: background labels for the boxes_mask
        box_dtype: output dtype for boxes
        label_dtype: output dtype for labels

        - bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``.
        - classification foreground(fg) labels, dtype should be int, sized (N,).
    look_up_option(len(boxes_mask.shape), [3, 4])
    spatial_size = list(boxes_mask.shape[1:])
    spatial_dims = get_spatial_dims(spatial_size=spatial_size)

    boxes_mask_np, *_ = convert_data_type(boxes_mask, np.ndarray)

    boxes_list = []
    labels_list = []
    for b in range(boxes_mask_np.shape[0]):
        fg_indices = np.nonzero(boxes_mask_np[b, ...] - bg_label)
        if fg_indices[0].shape[0] == 0:
        boxes_b = []
        for fd_i in fg_indices:
            boxes_b.append(min(fd_i))  # top left corner
        for fd_i in fg_indices:
            boxes_b.append(max(fd_i) + 1 - TO_REMOVE)  # bottom right corner
        if spatial_dims == 2:
            labels_list.append(boxes_mask_np[b, fg_indices[0][0],
        if spatial_dims == 3:
            labels_list.append(boxes_mask_np[b, fg_indices[0][0],

    if len(boxes_list) == 0:
        boxes_np, labels_np = np.zeros([0, 2 * spatial_dims]), np.zeros([0])
        boxes_np, labels_np = np.asarray(boxes_list), np.asarray(labels_list)
    boxes, *_ = convert_to_dst_type(src=boxes_np,
    labels, *_ = convert_to_dst_type(src=labels_np,
    return boxes, labels
Exemplo n.º 3
    def __call__(
        self,
        img: torch.Tensor,
        argmax: Optional[bool] = None,
        to_onehot: Optional[bool] = None,
        num_classes: Optional[int] = None,
        threshold_values: Optional[bool] = None,
        logit_thresh: Optional[float] = None,
        rounding: Optional[str] = None,
        n_classes: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Discretise ``img`` by optionally applying argmax, one-hot encoding, thresholding and rounding,
        in that order, and return the result as a float tensor.

        Args:
            img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
                will automatically add it.
            argmax: whether to execute argmax function on input data before transform.
                The step runs when either this argument or ``self.argmax`` is truthy, so passing
                ``False`` does not disable an instance-level ``True``.
            to_onehot: whether to convert input data into the one-hot format.
                The step runs when either this argument or ``self.to_onehot`` is truthy.
            num_classes: the number of classes to convert to One-Hot format.
                Defaults to ``self.num_classes``.
            threshold_values: whether threshold the float value to int number 0 or 1.
                The step runs when either this argument or ``self.threshold_values`` is truthy.
            logit_thresh: the threshold value for thresholding operation.
                Defaults to ``self.logit_thresh``.
            rounding: if not None, round the data according to the specified option,
                available options: ["torchrounding"]. Defaults to ``self.rounding``.
            n_classes: deprecated alias of ``num_classes``; only used when ``num_classes`` is None.

        Raises:
            AssertionError: When one-hot conversion is requested and neither ``num_classes``
                nor ``self.num_classes`` is an integer.

        .. deprecated:: 0.6.0
            ``n_classes`` is deprecated, use ``num_classes`` instead.

        """
        # in case the new num_classes is default but you still call deprecated n_classes
        if n_classes is not None and num_classes is None:
            num_classes = n_classes

        if argmax or self.argmax:
            # collapse dim 0 to the index of its maximum, keeping a singleton dim
            img = torch.argmax(img, dim=0, keepdim=True)

        if to_onehot or self.to_onehot:
            _nclasses = self.num_classes if num_classes is None else num_classes
            if not isinstance(_nclasses, int):
                raise AssertionError("One of self.num_classes or num_classes must be an integer")
            img = one_hot(img, num_classes=_nclasses, dim=0)

        if threshold_values or self.threshold_values:
            img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh)

        rounding = self.rounding if rounding is None else rounding
        if rounding is not None:
            # validates the option name; "torchrounding" is the only accepted value
            look_up_option(rounding, ["torchrounding"])
            img = torch.round(img)

        return img.float()
Exemplo n.º 4
def resolve_writer(ext_name, error_if_not_found=True) -> Sequence:
    Resolves to a tuple of available ``ImageWriter`` in ``SUPPORTED_WRITERS``
    according to the filename extension key ``ext_name``.

        ext_name: the filename extension of the image.
            As an indexing key it will be converted to a lower case string.
        error_if_not_found: whether to raise an error if no suitable image writer is found.
            If True, raise an ``OptionalImportError``, otherwise return an empty tuple. Default is ``True``.
    fmt = f"{ext_name}".lower()
    if fmt.startswith("."):
        fmt = fmt[1:]
    avail_writers = []
    default_writers = SUPPORTED_WRITERS.get(EXT_WILDCARD, ())
    for _writer in look_up_option(fmt,
            )  # this triggers `monai.utils.module.require_pkg` to check the system availability
        except OptionalImportError:
        except Exception:  # other writer init errors indicating it exists
    if not avail_writers and error_if_not_found:
        raise OptionalImportError(f"No ImageWriter backend found for {fmt}.")
    writer_tuple = ensure_tuple(avail_writers)
    SUPPORTED_WRITERS[fmt] = writer_tuple
    return writer_tuple
Exemplo n.º 5
    def astype(self, dtype, device=None, *_args, **_kwargs):
        Cast to ``dtype``, sharing data whenever possible.

            dtype: dtypes such as np.float32, torch.float, "np.float32", float.
            device: the device if `dtype` is a torch data type.
            _args: additional args (currently unused).
            _kwargs: additional kwargs (currently unused).

            data array instance
        if isinstance(dtype, str):
            mod_str, *dtype = dtype.split(".", 1)
            dtype = mod_str if not dtype else dtype[0]
            mod_str = getattr(dtype, "__module__", "torch")
        mod_str = look_up_option(mod_str, {"torch", "numpy", "np"},
        if mod_str == "torch":
            out_type = torch.Tensor
        elif mod_str in ("numpy", "np"):
            out_type = np.ndarray
            out_type = None
        return self.get_array(output_type=out_type, dtype=dtype, device=device)
    def __init__(
        self,
        patch_size: Sequence[int],
        start_pos: Sequence[int] = (),
        mode: Union[NumpyPadMode, str] = NumpyPadMode.WRAP,
        **pad_opts: Dict,
    ):
        """
        Store the patch geometry and padding options used to generate patches.

        Args:
            patch_size: size of patches to generate slices for, 0/None selects whole dimension
            start_pos: starting position in the array, default is 0 for each dimension
            mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
                ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
                One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
            pad_opts: padding options, see numpy.pad

        Note:
            The `patch_size` is the size of the
            patch to sample from the input arrays. It is assumed the arrays first dimension is the channel dimension which
            will be yielded in its entirety so this should not be specified in `patch_size`. For example, for an input 3D
            array with 1 channel of size (1, 20, 20, 20) a regular grid sampling of eight patches (1, 10, 10, 10) would be
            specified by a `patch_size` of (10, 10, 10).

        """
        # A leading `None` entry is prepended so the first (channel) dimension is taken whole.
        self.patch_size = (None,) + tuple(patch_size)
        self.start_pos = ensure_tuple(start_pos)
        # validate `mode` against the supported numpy padding modes
        self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode)
        self.pad_opts = pad_opts
Exemplo n.º 7
def check_hash(filepath: PathLike,
               val: Optional[str] = None,
               hash_type: str = "md5") -> bool:
    Verify hash signature of specified file.

        filepath: path of source file to verify hash value.
        val: expected hash value of the file.
        hash_type: type of hash algorithm to use, default is `"md5"`.
            The supported hash types are `"md5"`, `"sha1"`, `"sha256"`, `"sha512"`.
            See also: :py:data:`monai.apps.utils.SUPPORTED_HASH_TYPES`.

    if val is None:
            f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}."
        return True
    actual_hash_func = look_up_option(hash_type.lower(), SUPPORTED_HASH_TYPES)
    actual_hash = actual_hash_func()
        with open(filepath, "rb") as f:
            for chunk in iter(lambda: f.read(1024 * 1024), b""):
    except Exception as e:
        logger.error(f"Exception in check_hash: {e}")
        return False
    if val != actual_hash.hexdigest():
        logger.error(f"check_hash failed {actual_hash.hexdigest()}.")
        return False

    logger.info(f"Verified '{_basename(filepath)}', {hash_type}: {val}.")
    return True
Exemplo n.º 8
    def resample_and_clip(
        data_array: NdarrayOrTensor,
        output_spatial_shape: Optional[Sequence[int]] = None,
        mode: str = InterpolateMode.BICUBIC,
        Resample ``data_array`` to ``output_spatial_shape`` if needed.
            data_array: input data array. This method assumes the 'channel-last' format.
            output_spatial_shape: output spatial shape.
            mode: interpolation mode, default is ``InterpolateMode.BICUBIC``.

        data: np.ndarray = convert_data_type(data_array, np.ndarray)[0]
        if output_spatial_shape is not None:
            output_spatial_shape_ = ensure_tuple_rep(output_spatial_shape, 2)
            mode = look_up_option(mode, InterpolateMode)
            align_corners = None if mode in (InterpolateMode.NEAREST, InterpolateMode.AREA) else False
            xform = Resize(spatial_size=output_spatial_shape_, mode=mode, align_corners=align_corners)
            _min, _max = np.min(data), np.max(data)
            if len(data.shape) == 3:
                data = np.moveaxis(data, -1, 0)  # to channel first
                data = convert_data_type(xform(data), np.ndarray)[0]  # type: ignore
                data = np.moveaxis(data, 0, -1)
            else:  # (H, W)
                data = np.expand_dims(data, 0)  # make a channel
                data = convert_data_type(xform(data), np.ndarray)[0][0]  # type: ignore
            if mode != InterpolateMode.NEAREST:
                data = np.clip(data, _min, _max)
        return data
Exemplo n.º 9
def do_metric_reduction(f: torch.Tensor, reduction: Union[MetricReduction, str] = MetricReduction.MEAN):
    """
    This function is to do the metric reduction for calculated `not-nan` metrics of each sample's each class.
    The function also returns `not_nans`, which counts the number of not nans for the metric.

    Args:
        f: a tensor that contains the calculated metric scores per batch and
            per class. The first two dims should be batch and class.
            Unless ``reduction`` is ``"none"``, NaN entries of ``f`` are overwritten with 0 in place.
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``.
            if "none", return the input f tensor and not_nans.

    Returns:
        a tuple ``(f, not_nans)``: the reduced metric and the matching count of non-NaN entries
        (for ``"mean"``, the number of batch items that had at least one non-NaN class).

    Raises:
        ValueError: When ``reduction`` is not one of
            ["mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel", "none"].

    """
    # some elements might be Nan (if ground truth y was missing (zeros))
    # we need to account for it
    nans = torch.isnan(f)
    not_nans = (~nans).float()

    # fallback value used wherever a mean would divide by a zero count
    t_zero = torch.zeros(1, device=f.device, dtype=f.dtype)
    reduction = look_up_option(reduction, MetricReduction)
    if reduction == MetricReduction.NONE:
        return f, not_nans

    # zero the NaNs (in place) so they do not contribute to the sums below
    f[nans] = 0
    if reduction == MetricReduction.MEAN:
        # 2 steps, first, mean by channel (accounting for nans), then by batch
        not_nans = not_nans.sum(dim=1)
        f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average

        not_nans = (not_nans > 0).float().sum(dim=0)
        f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero)  # batch average

    elif reduction == MetricReduction.SUM:
        not_nans = not_nans.sum(dim=[0, 1])
        f = torch.sum(f, dim=[0, 1])  # sum over the batch and channel dims
    elif reduction == MetricReduction.MEAN_BATCH:
        not_nans = not_nans.sum(dim=0)
        f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero)  # batch average
    elif reduction == MetricReduction.SUM_BATCH:
        not_nans = not_nans.sum(dim=0)
        f = f.sum(dim=0)  # the batch sum
    elif reduction == MetricReduction.MEAN_CHANNEL:
        not_nans = not_nans.sum(dim=1)
        f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average
    elif reduction == MetricReduction.SUM_CHANNEL:
        not_nans = not_nans.sum(dim=1)
        f = f.sum(dim=1)  # the channel sum
    else:
        # `NONE` already returned above, so anything reaching here is unsupported
        raise ValueError(
            f"Unsupported reduction: {reduction}, available options are "
            '["mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel", "none"].'
        )
    return f, not_nans
Exemplo n.º 10
                def _compute_op(op: str, d: np.ndarray):
                    """
                    Apply the operation named ``op`` to ``d``.

                    Names ending in ``"percentile"`` are all served by the ``"90percentile"``
                    entry of ``supported_ops``, which is called with a ``(data, threshold)`` tuple
                    where the threshold is the integer prefix of ``op``. Any other name is
                    resolved through ``look_up_option`` and called with ``d`` alone.
                    """
                    if op.endswith("percentile"):
                        # the text in front of "percentile" is the requested threshold
                        prefix = op.split("percentile")[0]
                        level = int(prefix)
                        percentile_fn = supported_ops["90percentile"]
                        return percentile_fn((d, level))  # type: ignore

                    reducer = look_up_option(op, supported_ops)
                    return reducer(d)
Exemplo n.º 11
    def __call__(
        self,
        img: NdarrayOrTensor,
        argmax: Optional[bool] = None,
        to_onehot: Optional[int] = None,
        threshold: Optional[float] = None,
        rounding: Optional[str] = None,
    ) -> NdarrayOrTensor:
        """
        Discretise ``img`` by optionally applying argmax, one-hot encoding, thresholding and rounding,
        in that order, then convert the result back to the container type of the input.

        Args:
            img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
                will automatically add it.
            argmax: whether to execute argmax function on input data before transform.
                The step runs when either this argument or ``self.argmax`` is truthy, so passing
                ``False`` does not disable an instance-level ``True``.
            to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
                Defaults to ``self.to_onehot``.
            threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
                Defaults to ``self.threshold``.
            rounding: if not None, round the data according to the specified option,
                available options: ["torchrounding"]. Defaults to ``self.rounding``.

        Returns:
            the processed data, converted to match ``img`` with dtype ``self.kwargs["dtype"]``
            (``torch.float`` when that key is absent).

        Raises:
            ValueError: When the effective ``to_onehot`` value is not None and not an integer.

        """
        img_t: torch.Tensor
        img_t, *_ = convert_data_type(img, torch.Tensor)  # type: ignore
        if argmax or self.argmax:
            img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True))

        to_onehot = self.to_onehot if to_onehot is None else to_onehot
        if to_onehot is not None:
            if not isinstance(to_onehot, int):
                raise ValueError("the number of classes for One-Hot must be an integer.")
            img_t = one_hot(
                img_t, num_classes=to_onehot, dim=self.kwargs.get("dim", 0), dtype=self.kwargs.get("dtype", torch.float)
            )

        threshold = self.threshold if threshold is None else threshold
        if threshold is not None:
            img_t = img_t >= threshold

        rounding = self.rounding if rounding is None else rounding
        if rounding is not None:
            # validates the option name; "torchrounding" is the only accepted value
            look_up_option(rounding, ["torchrounding"])
            img_t = torch.round(img_t)

        img, *_ = convert_to_dst_type(img_t, img, dtype=self.kwargs.get("dtype", torch.float))
        return img
Exemplo n.º 12
    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Optional[Callable] = None,
        w_type: Union[Weight, str] = Weight.SQUARE,
        reduction: Union[LossReduction, str] = LossReduction.MEAN,
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        batch: bool = False,
    ) -> None:
        """
        Validate and store the loss configuration.

        Args:
            include_background: If False channel index 0 (background category) is excluded from the calculation.
            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
            sigmoid: If True, apply a sigmoid function to the prediction.
            softmax: If True, apply a softmax function to the prediction.
            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
                other activation layers, Defaults to ``None``. for example:
                `other_act = torch.tanh`.
            w_type: {``"square"``, ``"simple"``, ``"uniform"``}
                Type of function to transform ground truth volume to a weight factor. Defaults to ``"square"``.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.
            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.
            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                Defaults to False, intersection over union is computed from each item in the batch.

        Raises:
            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
                Incompatible values.

        """
        # NOTE(review): `reduction` is not consumed anywhere in this body and the enclosing
        # class is not visible; presumably a base-class initialiser is meant to receive it -- confirm.
        if other_act is not None and not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        # at most one activation may be selected
        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")

        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.other_act = other_act

        self.w_type = look_up_option(w_type, Weight)

        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)
        self.batch = batch
Exemplo n.º 13
 def __init__(
     self,
     data_type: str = "tensor",
     dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
     device: Optional[torch.device] = None,
 ) -> None:
     """
     Store the target container type, dtype and device.

     Args:
         data_type: target container type, ``"tensor"`` or ``"numpy"`` (matched case-insensitively).
             Defaults to ``"tensor"``.
         dtype: target data type, stored unchanged. Defaults to None.
         device: target device, stored unchanged. Defaults to None.
     """
     # lower-case first so that e.g. "Tensor" is accepted, then validate the option
     self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"})
     self.dtype = dtype
     self.device = device
Exemplo n.º 14
    def __call__(
        self,
        img: NdarrayOrTensor,
        meta_data: Optional[Dict] = None,
        mask: Optional[np.ndarray] = None,
    ) -> Tuple[NdarrayOrTensor, Dict]:
        """
        Compute statistics for the intensity of input image.

        Args:
            img: input image to compute intensity stats.
            meta_data: meta data dictionary to store the statistics data, if None, will create an empty dictionary.
            mask: if not None, mask the image to extract only the interested area to compute statistics.
                mask must have the same shape as input `img`.

        Returns:
            the unchanged ``img`` and ``meta_data`` updated with one entry per operation in ``self.ops``:
            ``"<key_prefix>_<op>"`` for predefined operations and ``"<key_prefix>_custom_<i>"`` for callables,
            where ``i`` counts the callables in order of appearance.

        Raises:
            TypeError: When ``mask`` is not a bool array with the same shape as ``img``.
            ValueError: When an item of ``self.ops`` is neither a string nor a callable.

        """
        img_np: np.ndarray
        img_np, *_ = convert_data_type(img, np.ndarray)  # type: ignore
        if meta_data is None:
            meta_data = {}

        if mask is not None:
            if mask.shape != img_np.shape or mask.dtype != bool:
                raise TypeError("mask must be bool array with the same shape as input `img`.")
            # boolean indexing keeps only the selected values
            img_np = img_np[mask]

        # NaN-aware numpy reductions available by name
        supported_ops = {
            "mean": np.nanmean,
            "median": np.nanmedian,
            "max": np.nanmax,
            "min": np.nanmin,
            "std": np.nanstd,
        }

        def _compute(op: Callable, data: np.ndarray):
            # with `channel_wise`, apply `op` to each item along the first dimension
            if self.channel_wise:
                return [op(c) for c in data]
            return op(data)

        custom_index = 0
        for o in self.ops:
            if isinstance(o, str):
                o = look_up_option(o, supported_ops.keys())
                meta_data[self.key_prefix + "_" + o] = _compute(supported_ops[o], img_np)  # type: ignore
            elif callable(o):
                meta_data[self.key_prefix + "_custom_" + str(custom_index)] = _compute(o, img_np)
                custom_index += 1
            else:
                raise ValueError("ops must be key string for predefined operations or callable function.")

        return img, meta_data
Exemplo n.º 15
def compute_importance_map(
    patch_size: Tuple[int, ...],
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    device: Union[torch.device, int, str] = "cpu",
) -> torch.Tensor:
    """Get importance map for different weight modes.

    Args:
        patch_size: Size of the required importance map. This should be either H, W [,D].
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: Sigma_scale to calculate sigma for each dimension
            (sigma = sigma_scale * dim_size). Used for gaussian mode only.
        device: Device to put importance map on.

    Raises:
        ValueError: When ``mode`` is not one of ["constant", "gaussian"].

    Returns:
        Tensor of size patch_size.

    """
    mode = look_up_option(mode, BlendMode)
    device = torch.device(device)  # type: ignore[arg-type]
    if mode == BlendMode.CONSTANT:
        importance_map = torch.ones(patch_size, device=device).float()
    elif mode == BlendMode.GAUSSIAN:
        center_coords = [i // 2 for i in patch_size]
        sigma_scale = ensure_tuple_rep(sigma_scale, len(patch_size))
        sigmas = [i * sigma_s for i, sigma_s in zip(patch_size, sigma_scale)]

        # smooth a unit impulse placed at the patch centre
        importance_map = torch.zeros(patch_size, device=device)
        importance_map[tuple(center_coords)] = 1
        # NOTE(review): this call was truncated after its first argument in the source;
        # `sigmas` is the only otherwise-unused local, and moving the filter to `device`
        # keeps it on the same device as its input -- confirm against the upstream code.
        pt_gaussian = GaussianFilter(len(patch_size), sigmas).to(device=device, dtype=torch.float)
        # the filter is fed a tensor with two leading singleton dims, removed again afterwards
        importance_map = pt_gaussian(importance_map.unsqueeze(0).unsqueeze(0))
        importance_map = importance_map.squeeze(0).squeeze(0)
        importance_map = importance_map / torch.max(importance_map)
        importance_map = importance_map.float()

        # importance_map cannot be 0, otherwise we may end up with nans!
        min_non_zero = importance_map[importance_map != 0].min().item()
        importance_map = torch.clamp(importance_map, min=min_non_zero)
    else:
        raise ValueError(
            f"Unsupported mode: {mode}, available options are [{BlendMode.CONSTANT}, {BlendMode.GAUSSIAN}]."
        )

    return importance_map
Exemplo n.º 16
    def __init__(self, submodule, dim: int = 1, mode: Union[str, SkipMode] = "cat") -> None:
        """
        Store the wrapped branch and how its output is to be combined.

        Args:
            submodule: the module defines the trainable branch.
            dim: the dimension over which the tensors are concatenated.
                Used when mode is ``"cat"``.
            mode: ``"cat"``, ``"add"``, ``"mul"``. defaults to ``"cat"``.
        """
        # NOTE(review): the enclosing class is not visible from here; if its base class
        # needs initialising, `super().__init__()` has to run before these assignments.
        self.submodule = submodule
        self.dim = dim
        # validate against `SkipMode` and keep the enum member's `.value`
        self.mode = look_up_option(mode, SkipMode).value
Exemplo n.º 17
    def get_constructor(self, factory_name: str, *args) -> Any:
        """
        Get the constructor for the given factory name and arguments.

        Args:
            factory_name: name of the registered factory; it is upper-cased before the lookup
                in ``self.factories``.
            args: positional arguments forwarded to the factory function.

        Returns:
            whatever the looked-up factory function returns when called with ``*args``.

        Raises:
            TypeError: When ``factory_name`` is not a ``str``.

        """
        if not isinstance(factory_name, str):
            raise TypeError(f"factory_name must be a str but is {type(factory_name).__name__}.")

        func = look_up_option(factory_name.upper(), self.factories)
        return func(*args)
Exemplo n.º 18
    def __init__(
        self,
        output_dir: PathLike = "./",
        output_postfix: str = "seg",
        output_ext: str = ".png",
        resample: bool = True,
        mode: Union[InterpolateMode, str] = InterpolateMode.NEAREST,
        scale: Optional[int] = None,
        data_root_dir: PathLike = "",
        separate_folder: bool = True,
        print_log: bool = True,
    ) -> None:
        """
        Store the output location, naming and post-processing options of the saver.

        Args:
            output_dir: output image directory.
            output_postfix: a string appended to all output file names.
            output_ext: output file extension name.
            resample: whether to resample and resize if providing spatial_shape in the metadata.
            mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
                The interpolation mode. Defaults to ``"nearest"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
            scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling
                [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.
            data_root_dir: if not empty, it specifies the beginning parts of the input file's
                absolute path. it's used to compute `input_file_rel_path`, the relative path to the file from
                `data_root_dir` to preserve folder structure when saving in case there are files in different
                folders with the same file names. for example:
                input_file_name: /foo/bar/test1/image.png,
                postfix: seg
                output_ext: png
                output_dir: /output,
                data_root_dir: /foo/bar,
                output will be: /output/test1/image/image_seg.png
            separate_folder: whether to save every file in a separate folder, for example: if input filename is
                `image.png`, postfix is `seg` and folder_path is `output`, if `True`, save as:
                `output/image/image_seg.png`, if `False`, save as `output/image_seg.png`. default to `True`.
            print_log: whether to print log about the saved PNG file path, etc. default to `True`.

        """
        self.output_dir = output_dir
        self.output_postfix = output_postfix
        self.output_ext = output_ext
        self.resample = resample
        # validate `mode` against the supported interpolation modes
        self.mode: InterpolateMode = look_up_option(mode, InterpolateMode)
        self.scale = scale
        self.data_root_dir = data_root_dir
        self.separate_folder = separate_folder
        self.print_log = print_log

        # counter initialised to zero; it is not used further in this method
        self._data_index = 0
Exemplo n.º 19
    def export_config_file(cls, config: Dict, filepath: PathLike, fmt="json", **kwargs):
        """
        Export the config content to the specified file path (currently support JSON and YAML files).

        Args:
            config: source config content to export.
            filepath: target file path to save.
            fmt: format of config content, currently support ``"json"`` and ``"yaml"``
                (matched case-insensitively).
            kwargs: other arguments for ``json.dump`` or ``yaml.safe_dump``, depends on the file format.

        Returns:
            the return value of ``json.dump`` or ``yaml.safe_dump``.

        Raises:
            ValueError: When the resolved format is neither ``"json"`` nor ``"yaml"``.

        """
        _filepath: str = str(Path(filepath))
        # validate the requested format before the target file is opened for writing
        writer = look_up_option(fmt.lower(), {"json", "yaml"})
        with open(_filepath, "w") as f:
            if writer == "json":
                return json.dump(config, f, **kwargs)
            if writer == "yaml":
                return yaml.safe_dump(config, f, **kwargs)
            raise ValueError(f"only support JSON or YAML config file so far, got {writer}.")
Exemplo n.º 20
 def __init__(
     include_background: bool = True,
     reduction: Union[MetricReduction, str] = MetricReduction.MEAN_BATCH,
     weight_type: Union[Weight, str] = Weight.SQUARE,
 ) -> None:
     self.include_background = include_background
     reduction_options = [
     self.reduction = reduction
     if self.reduction not in reduction_options:
         raise ValueError(f"reduction must be one of {reduction_options}")
     self.weight_type = look_up_option(weight_type, Weight)
Exemplo n.º 21
    def __init__(
        hidden_size: int,
        mlp_dim: int,
        dropout_rate: float = 0.0,
        act: Union[Tuple, str] = "GELU",
    ) -> None:
            hidden_size: dimension of hidden layer.
            mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
            dropout_rate: fraction of the input units to drop.
            act: activation type and arguments. Defaults to GELU.
            dropout_mode: dropout mode, can be "vit" or "swin".
                "vit" mode uses two dropout instances as implemented in
                "swin" corresponds to one instance as implemented in



        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")
        mlp_dim = mlp_dim or hidden_size
        self.linear1 = nn.Linear(hidden_size, mlp_dim)
        self.linear2 = nn.Linear(mlp_dim, hidden_size)
        self.fn = get_act_layer(act)
        self.drop1 = nn.Dropout(dropout_rate)
        dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
        if dropout_opt == "vit":
            self.drop2 = nn.Dropout(dropout_rate)
        elif dropout_opt == "swin":
            self.drop2 = self.drop1
            raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")
Exemplo n.º 22
def register_writer(ext_name, *im_writers):
    """
    Register ``ImageWriter``, so that writing a file with filename extension ``ext_name``
    could be resolved to a tuple of potentially appropriate ``ImageWriter``.
    The customised writers could be registered by:

    .. code-block:: python

        from monai.data import register_writer
        # `MyWriter` must implement `ImageWriter` interface
        register_writer("nii", MyWriter)

    Args:
        ext_name: the filename extension of the image.
            As an indexing key, it will be converted to a lower case string;
            a single leading ``"."`` is removed.
        im_writers: one or multiple ImageWriter classes with high priority ones first.
    """
    # Normalise the key: stringify, lower-case, and drop one leading dot (".nii" -> "nii").
    fmt = f"{ext_name}".lower()
    if fmt.startswith("."):
        fmt = fmt[1:]
    # Writers already registered for this extension; an empty tuple when there are none.
    existing = look_up_option(fmt, SUPPORTED_WRITERS, default=())
    # New writers go first so that they take priority over the existing ones.
    all_writers = im_writers + existing
    SUPPORTED_WRITERS[fmt] = all_writers
Exemplo n.º 23
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        mode: Union[ChannelMatching, str] = ChannelMatching.PAD,
    ):
        """
        Configure how a tensor with ``in_channels`` channels is matched to ``out_channels`` channels.

        Exactly one of ``self.project`` / ``self.pad`` is set when the channel counts differ;
        both stay ``None`` when they are equal.

        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of input channels.
            out_channels: number of output channels.
            mode: {``"pad"``, ``"project"``}
                Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``.

                - ``"pad"``: with zero padding.
                - ``"project"``: with a trainable conv with kernel size one.

        Raises:
            ValueError: When ``mode`` is ``"pad"`` and ``in_channels > out_channels``.
        """
        # NOTE(review): the enclosing class is not visible here; it presumably derives from
        # `torch.nn.Module`, which requires this call before sub-modules are assigned.
        super().__init__()
        self.project = None
        self.pad = None
        if in_channels == out_channels:
            # Channel counts already agree: no projection or padding is needed.
            return
        mode = look_up_option(mode, ChannelMatching)
        if mode == ChannelMatching.PROJECT:
            conv_type = Conv[Conv.CONV, spatial_dims]
            self.project = conv_type(in_channels, out_channels, kernel_size=1)
            return
        if mode == ChannelMatching.PAD:
            if in_channels > out_channels:
                raise ValueError('Incompatible values: channel_matching="pad" and in_channels > out_channels.')
            # Split the missing channels as evenly as possible; the extra one goes to the end.
            pad_1 = (out_channels - in_channels) // 2
            pad_2 = out_channels - in_channels - pad_1
            # Pairs of (before, after) amounts: zeros for every spatial dim, then the channel
            # pair, then zeros for the batch dim -- presumably consumed by a
            # last-dimension-first padding call; confirm against the forward pass.
            pad = [0, 0] * spatial_dims + [pad_1, pad_2] + [0, 0]
            self.pad = tuple(pad)
Exemplo n.º 24
    # Initializer of a detection network: records the configuration, checks the feature
    # extractor and creates a classification head and a box-regression head.
    # NOTE(review): this excerpt appears truncated -- the `self` and `feature_extractor`
    # parameters, the end of the signature, and the arguments of `ensure_tuple_rep`,
    # `RetinaNetClassificationHead` and `RetinaNetRegressionHead` are not visible here.
    def __init__(
        spatial_dims: int,
        num_classes: int,
        num_anchors: int,
        size_divisible: Union[Sequence[int], int] = 1,

        # Only 1, 2 or 3 spatial dimensions are accepted.
        self.spatial_dims = look_up_option(spatial_dims, supported=[1, 2, 3])
        self.num_classes = num_classes
        self.size_divisible = ensure_tuple_rep(size_divisible,

        # The heads are sized from `feature_extractor.out_channels`, so it must exist.
        if not hasattr(feature_extractor, "out_channels"):
            raise ValueError(
                "feature_extractor should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)")
        self.feature_extractor = feature_extractor

        self.feature_map_channels: int = self.feature_extractor.out_channels
        self.num_anchors = num_anchors
        self.classification_head = RetinaNetClassificationHead(
        self.regression_head = RetinaNetRegressionHead(

        # String keys stored for later use -- presumably the keys of the network's
        # output dictionary; verify against the forward pass.
        self.cls_key: str = "classification"
        self.box_reg_key: str = "box_regression"
Exemplo n.º 25
    # Initializer of a hierarchical transformer backbone: stores the configuration, creates
    # the patch embedding, a per-block drop-path schedule and four stage containers.
    # NOTE(review): this excerpt appears truncated -- the `self` and `downsample`
    # parameters, the docstring delimiters, most arguments of `PatchEmbed` and
    # `BasicLayer`, several closing brackets and the bodies of the per-stage
    # `if`/`elif` branches are not visible here.
    def __init__(
        in_chans: int,
        embed_dim: int,
        window_size: Sequence[int],
        patch_size: Sequence[int],
        depths: Sequence[int],
        num_heads: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: Type[LayerNorm] = nn.LayerNorm,
        patch_norm: bool = False,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
    ) -> None:
            in_chans: dimension of input channels.
            embed_dim: number of linear projection output channels.
            window_size: local window size.
            patch_size: patch size.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            drop_path_rate: stochastic depth rate.
            norm_layer: normalization layer.
            patch_norm: add normalization after patch embedding.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: spatial dimension.
            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
                The default is currently `"merging"` (the original version defined in v0.9.0).

        # One stage per entry of `depths`.
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.window_size = window_size
        self.patch_size = patch_size
        # The patch embedding only receives a normalization layer when `patch_norm` is set.
        self.patch_embed = PatchEmbed(
            norm_layer=norm_layer if self.patch_norm else None,  # type: ignore
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Drop-path rates rise linearly from 0 to `drop_path_rate`, one value per block
        # across all stages.
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        self.layers1 = nn.ModuleList()
        self.layers2 = nn.ModuleList()
        self.layers3 = nn.ModuleList()
        self.layers4 = nn.ModuleList()
        # A string `downsample` is resolved through `MERGING_MODE`; anything else is used as given.
        down_sample_mod = look_up_option(downsample,
                                         MERGING_MODE) if isinstance(
                                             downsample, str) else downsample
        for i_layer in range(self.num_layers):
            # The channel width doubles at every stage; each stage gets its own slice of `dpr`.
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
            # NOTE(review): the branch bodies are missing; presumably each stage is appended
            # to the matching `self.layersN` container -- TODO confirm.
            if i_layer == 0:
            elif i_layer == 1:
            elif i_layer == 2:
            elif i_layer == 3:
        # Channel width of the last stage.
        self.num_features = int(embed_dim * 2**(self.num_layers - 1))
Exemplo n.º 26
def iter_patch(
    arr: np.ndarray,
    patch_size: Union[Sequence[int], int] = 0,
    start_pos: Sequence[int] = (),
    copy_back: bool = True,
    mode: Union[NumpyPadMode, str] = NumpyPadMode.WRAP,
    **pad_opts: Dict,
):
    """
    Yield successive patches from `arr` of size `patch_size`. The iteration can start from position `start_pos` in `arr`
    but drawing from a padded array extended by the `patch_size` in each dimension (so these coordinates can be negative
    to start in the padded region). If `copy_back` is True the values from each patch are written back to `arr`.

    Args:
        arr: array to iterate over
        patch_size: size of patches to generate slices for, 0 or None selects whole dimension
        start_pos: starting position in the array, default is 0 for each dimension
        copy_back: if True data from the yielded patches is copied back to `arr` once the generator completes
        mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
            ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
            One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
        pad_opts: padding options, see `numpy.pad`

    Yields:
        Patches of array data from `arr` which are views into a padded array which can be modified, if `copy_back` is
        True these changes will be reflected in `arr` once the iteration completes.
        Each patch is yielded together with its coordinates, expressed relative to the unpadded `arr`.

    Note:
        coordinate format is:

            [[1st_dim_start, 1st_dim_end],
             [2nd_dim_start, 2nd_dim_end],
             ...,
             [Nth_dim_start, Nth_dim_end]]

    """
    # ensure patch_size and start_pos are the right length
    patch_size_ = get_valid_patch_size(arr.shape, patch_size)
    start_pos = ensure_tuple_size(start_pos, arr.ndim)

    # pad image by maximum values needed to ensure patches are taken from inside an image
    arrpad = np.pad(arr, tuple((p, p) for p in patch_size_),
                    look_up_option(mode, NumpyPadMode).value, **pad_opts)

    # choose a start position in the padded image
    start_pos_padded = tuple(s + p for s, p in zip(start_pos, patch_size_))

    # choose a size to iterate over which is smaller than the actual padded image to prevent producing
    # patches which are only in the padded regions
    iter_size = tuple(s + p for s, p in zip(arr.shape, patch_size_))

    for slices in iter_patch_slices(iter_size, patch_size_, start_pos_padded):
        # compensate original image padding
        coords_no_pad = tuple((coord.start - p, coord.stop - p)
                              for coord, p in zip(slices, patch_size_))
        # data and coords (in numpy; works with torch loader)
        yield arrpad[slices], np.asarray(coords_no_pad)

    # copy back data from the padded image if required
    if copy_back:
        slices = tuple(slice(p, p + s) for p, s in zip(patch_size_, arr.shape))
        arr[...] = arrpad[slices]
Exemplo n.º 27
def compute_roc_auc(y_pred: torch.Tensor, y: torch.Tensor, average: Union[Average, str] = Average.MACRO):
    """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to:
    `sklearn.metrics.roc_auc_score <https://scikit-learn.org/stable/modules/generated/
    sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.

    Args:
        y_pred: input data to compute, typical classification model output.
            the first dim must be batch, if multi-classes, it must be in One-Hot format.
            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
        y: ground truth to compute ROC AUC metric, the first dim must be batch.
            if multi-classes, it must be in One-Hot format.
            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
                weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
                indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    Returns:
        A single score, or a list with one score per class when ``average`` is ``"none"``.

    Raises:
        ValueError: When ``y_pred`` dimension is not one of [1, 2].
        ValueError: When ``y`` dimension is not one of [1, 2].
        ValueError: When the shapes of ``y_pred`` and ``y`` do not match (multi-class case).
        ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].

    Note:
        ROCAUC expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values.

    """
    y_pred_ndim = y_pred.ndimension()
    y_ndim = y.ndimension()
    if y_pred_ndim not in (1, 2):
        raise ValueError(
            f"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}."
        )
    if y_ndim not in (1, 2):
        raise ValueError(f"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.")
    # A trailing singleton class dimension is treated as the binary case.
    if y_pred_ndim == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(dim=-1)
        y_pred_ndim = 1
    if y_ndim == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=-1)

    # Binary classification: `average` is not consulted.
    if y_pred_ndim == 1:
        return _calculate(y_pred, y)

    if y.shape != y_pred.shape:
        raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.")

    average = look_up_option(average, Average)
    if average == Average.MICRO:
        # Every element of the label indicator matrix counts as one sample.
        return _calculate(y_pred.flatten(), y.flatten())
    # Move the class dimension first so that iterating yields one row per class.
    y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
    auc_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)]
    if average == Average.NONE:
        return auc_values
    if average == Average.MACRO:
        return np.mean(auc_values)
    if average == Average.WEIGHTED:
        # Weight each class by its support, i.e. the number of positive targets.
        weights = [sum(y_) for y_ in y]
        return np.average(auc_values, weights=weights)
    raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].')
Exemplo n.º 28
# Runs `predictor` on overlapping windows of `inputs`, blends the weighted window outputs
# into full-size buffers and crops away any padding that was added.
# NOTE(review): this excerpt appears truncated -- the docstring delimiters and section
# headers, several closing brackets, the trailing arguments of a number of calls and some
# whole statements (e.g. `try:`, `else:`, warning/raise calls around bare message strings)
# are not visible here. Restore them from the upstream source before relying on this code.
def sliding_window_inference(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor],
                                   Dict[Any, torch.Tensor]]],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    progress: bool = False,
    roi_weight_map: Union[torch.Tensor, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
    Sliding window inference on `inputs` with `predictor`.

    The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
    Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
    e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
    could be ([128,64,256], [64,32,128]).
    In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
    an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
    so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).

    When roi_size is larger than the inputs' spatial size, the input image are padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

        inputs: input image to be processed (assuming NCHW[D])
        roi_size: the spatial window size for inferences.
            When its components have None or non-positives, the corresponding inputs dimension will be used.
            if the components of the `roi_size` are non-positive values, the transform will use the
            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor ``patch_data`` in shape NCHW[D],
            The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
            with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
            where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
            N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
            the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
            to ensure the scaled output ROI sizes are still integers.
            If the `predictor`'s input and output spatial sizes are different,
            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
        overlap: Amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant``": gives equal weight to all predictions.
            - ``"gaussian``": gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI.
            If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

        - input must be channel-first and have a batch dim, supports N-D sliding window.

    compute_dtype = inputs.dtype
    # The first two dimensions of `inputs` are batch and channel; the rest are spatial.
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    batch_size, _, *image_size_ = inputs.shape

    # Both devices default to the device of `inputs`.
    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(
        max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    # Padding amounts are collected from the last spatial dimension to the first, as a
    # (before, after) pair per dimension; an odd difference puts the extra element after.
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    # NOTE(review): the remaining arguments of this call are not visible in this excerpt.
    inputs = F.pad(inputs,
                   mode=look_up_option(padding_mode, PytorchPadMode),

    # NOTE(review): the remaining arguments of this call are not visible in this excerpt.
    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims,

    # Store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows

    # Create window-level importance map
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    # A caller-supplied weight map is only used when the valid patch size equals `roi_size`.
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
            # NOTE(review): presumably inside a missing `else:` / `try:`; the call's
            # remaining arguments are not visible -- TODO confirm.
            importance_map = compute_importance_map(valid_patch_size,
        except BaseException as e:
            raise RuntimeError(
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map = convert_data_type(importance_map, torch.Tensor, device,
                                       compute_dtype)[0]  # type: ignore
    # handle non-positive weights
    # The smallest non-zero weight, but never below 1e-3.
    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
    # NOTE(review): the remaining arguments of this call are not visible in this excerpt.
    importance_map = torch.clamp(importance_map.to(torch.float32),

    # Perform predictions
    dict_key, output_image_list, count_map_list = None, [], []
    # Index of the last output whose buffers have been set up (-1: none yet).
    _initialized_ss = -1
    is_tensor_output = True  # whether the predictor's output is a tensor (instead of dict/tuple)

    # for each patch
    for slice_g in tqdm(range(0, total_slices,
                              sw_batch_size)) if progress else range(
                                  0, total_slices, sw_batch_size):
        # NOTE(review): the end of this `range(...)` call is not visible in this excerpt.
        slice_range = range(slice_g, min(slice_g + sw_batch_size,
        # A flat window index maps to (image index, window index within that image);
        # each entry is [batch slice, all channels, *spatial slices].
        unravel_slice = [
            [slice(int(idx / num_win),
                   int(idx / num_win) + 1),
             slice(None)] + list(slices[idx % num_win]) for idx in slice_range
        # Stack the selected windows into one batch for the predictor.
        window_data = torch.cat([
            convert_data_type(inputs[win_slice], torch.Tensor)[0]
            for win_slice in unravel_slice
        seg_prob_out = predictor(window_data, *args,
                                 **kwargs)  # batched patch segmentation

        # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out, )
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(
                    seg_prob_out.keys())  # track predictor's output keys
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
            # NOTE(review): the two lines below presumably belong to a missing `else:`
            # branch (sequence output) -- TODO confirm.
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False

        # for each output in multi-output list
        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN

            # compute zoom scale: out_roi_size/in_roi_size
            zoom_scale = []
            for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
                    zip(image_size, seg_prob.shape[2:],
                _scale = out_w_i / float(in_w_i)
                # The scaled image size must be a whole number of elements.
                if not (img_s_i * _scale).is_integer():
                        f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
                        f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."

            if _initialized_ss < ss:  # init. the ss-th buffer at the first iteration
                # construct multi-resolution outputs
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d)
                    for image_size_d, zoom_scale_d in zip(
                        image_size, zoom_scale)
                # allocate memory to store the full output and the count for overlapping parts
                    torch.zeros([1, 1] + output_shape[2:],
                _initialized_ss += 1

            # resizing the importance_map
            resizer = Resize(spatial_size=seg_prob.shape[2:],

            # store the result in the proper location of the full output. Apply weights from importance map.
            for idx, original_idx in zip(slice_range, unravel_slice):
                # zoom roi
                original_idx_zoom = list(
                    original_idx)  # 4D for 2D image, 5D for 3D image
                # Only the spatial entries (index 2 onwards) are rescaled.
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis -
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    if not zoomed_start.is_integer() or (
                            not zoomed_end.is_integer()):
                            f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
                            f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
                            f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
                            f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
                            f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
                            "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
                    original_idx_zoom[axis] = slice(int(zoomed_start),
                                                    int(zoomed_end), None)
                importance_map_zoom = resizer(
                # store results and weights
                    original_idx_zoom] += importance_map_zoom * seg_prob[
                        idx - slice_g]
                count_map_list[ss][original_idx_zoom] += (

    # account for any overlapping sections
    for ss in range(len(output_image_list)):
        # NOTE(review): the divisor is not visible in this excerpt -- presumably the
        # matching entry of `count_map_list`; TODO confirm.
        output_image_list[ss] = (output_image_list[ss] /

    # remove padding if image_size smaller than roi_size
    for ss, output_i in enumerate(output_image_list):
        if torch.isnan(output_i).any() or torch.isinf(output_i).any():
                "Sliding window inference results contain NaN or Inf.")

        # Per-dimension ratio between this output's spatial size and `roi_size`.
        zoom_scale = [
            seg_prob_map_shape_d / roi_size_d
            for seg_prob_map_shape_d, roi_size_d in zip(
                output_i.shape[2:], roi_size)

        final_slicing: List[slice] = []
        # `pad_size` is ordered last spatial dimension first, hence the reversed indexing
        # and the insertion at the front of `final_slicing`.
        for sp in range(num_spatial_dims):
            slice_dim = slice(
                pad_size[sp * 2],
                image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
            slice_dim = slice(
                    round(slice_dim.start *
                          zoom_scale[num_spatial_dims - sp - 1])),
                    round(slice_dim.stop *
                          zoom_scale[num_spatial_dims - sp - 1])),
            final_slicing.insert(0, slice_dim)
        # Keep every leading (non-spatial) dimension whole.
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[final_slicing]

    if dict_key is not None:  # if output of predictor is a dict
        final_output = dict(zip(dict_key, output_image_list))
        # NOTE(review): the line below presumably belongs to a missing `else:` -- TODO confirm.
        final_output = tuple(output_image_list)  # type: ignore
    # A plain-tensor predictor output is returned as a tensor rather than a 1-tuple.
    final_output = final_output[
        0] if is_tensor_output else final_output  # type: ignore
    if isinstance(inputs, MetaTensor):
        final_output = convert_to_dst_type(final_output,
                                           inputs)[0]  # type: ignore
    return final_output
Exemplo n.º 29
    # Applies the optional post-processing steps in a fixed order -- argmax, one-hot
    # conversion, thresholding, rounding -- and returns the result converted back to the
    # type of `img` with a float dtype. Per-call arguments left as None fall back to the
    # values stored on the instance.
    # NOTE(review): this excerpt appears truncated -- the `self` parameter, the docstring
    # delimiters and section headers, and the calls wrapping the two deprecation messages
    # are not visible here. `n_classes` and `threshold_values` are not read in the visible
    # body; presumably they are handled outside this method -- TODO confirm.
    def __call__(
            img: NdarrayOrTensor,
            argmax: Optional[bool] = None,
            to_onehot: Optional[int] = None,
            threshold: Optional[float] = None,
            rounding: Optional[str] = None,
            n_classes: Optional[int] = None,  # deprecated
            num_classes: Optional[int] = None,  # deprecated
            logit_thresh: Optional[float] = None,  # deprecated
            threshold_values: Optional[bool] = None,  # deprecated
    ) -> NdarrayOrTensor:
            img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
                will automatically add it.
            argmax: whether to execute argmax function on input data before transform.
                Defaults to ``self.argmax``.
            to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
                Defaults to ``self.to_onehot``.
            threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
                Defaults to ``self.threshold``.
            rounding: if not None, round the data according to the specified option,
                available options: ["torchrounding"].

        .. deprecated:: 0.6.0
            ``n_classes`` is deprecated, use ``to_onehot`` instead.

        .. deprecated:: 0.7.0
            ``num_classes`` is deprecated, use ``to_onehot`` instead.
            ``logit_thresh`` is deprecated, use ``threshold`` instead.
            ``threshold_values`` is deprecated, use ``threshold`` instead.

        # Legacy boolean `to_onehot`: True maps to `num_classes`, False to None.
        if isinstance(to_onehot, bool):
                "`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead."
            to_onehot = num_classes if to_onehot else None
        # Legacy boolean `threshold`: True maps to `logit_thresh`, False to None.
        if isinstance(threshold, bool):
                "`threshold_values=True/False` is deprecated, please use `threshold=value` instead."
            threshold = logit_thresh if threshold else None
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_t, *_ = convert_data_type(img, torch.Tensor)
        # argmax runs when requested either per call or on the instance; it reduces
        # dimension 0 and keeps that dimension with size 1.
        if argmax or self.argmax:
            img_t = torch.argmax(img_t, dim=0, keepdim=True)

        to_onehot = self.to_onehot if to_onehot is None else to_onehot
        if to_onehot is not None:
            if not isinstance(to_onehot, int):
                raise AssertionError(
                    "the number of classes for One-Hot must be an integer.")
            img_t = one_hot(img_t, num_classes=to_onehot, dim=0)

        threshold = self.threshold if threshold is None else threshold
        if threshold is not None:
            # Element-wise comparison; values at or above the threshold become True.
            img_t = img_t >= threshold

        rounding = self.rounding if rounding is None else rounding
        if rounding is not None:
            # Only validates the option name; "torchrounding" is the sole supported value.
            look_up_option(rounding, ["torchrounding"])
            img_t = torch.round(img_t)

        img, *_ = convert_to_dst_type(img_t, img, dtype=torch.float)
        return img
Exemplo n.º 30
def compute_generalized_dice(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    include_background: bool = True,
    weight_type: Union[Weight, str] = Weight.SQUARE,
) -> torch.Tensor:
    """Computes the Generalized Dice Score and returns a tensor with its per image values.

    Args:
        y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format
            and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
            remaining are the spatial dimensions.
        y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same
            shape as `y_pred`.
        include_background (bool, optional): whether to include the first channel in the score computation.
            When False, both inputs are passed through `ignore_background` before scoring. Defaults to True.
        weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
            transform ground truth volume into a weight factor. Defaults to ``"square"``.

    Returns:
        torch.Tensor: one Generalized Dice Score per batch item, i.e., with the shape [batch_size]
            (the weighted per-class terms are summed along the channel axis before the division).

    Raises:
        ValueError: if `y_pred` has less than three dimensions, or `y_pred` and `y` don't have the same shape.
            NOTE(review): `is_binary_tensor` is defined elsewhere and presumably rejects non-tensor
            inputs as well; confirm against its implementation.
    """
    # Ensure tensors are binarized
    is_binary_tensor(y_pred, "y_pred")
    is_binary_tensor(y, "y")

    # Ensure tensors have at least 3 dimensions and have the same shape
    dims = y_pred.dim()
    if dims < 3:
        raise ValueError(
            f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}."
        )
    if y.shape != y_pred.shape:
        raise ValueError(
            f"y_pred - {y_pred.shape} - and y - {y.shape} - should have the same shapes."
        )

    # Ignore background, if needed
    if not include_background:
        y_pred, y = ignore_background(y_pred=y_pred, y=y)

    # Reducing only spatial dimensions (not batch nor channels), compute the intersection and non-weighted denominator
    reduce_axis = list(range(2, y_pred.dim()))
    intersection = torch.sum(y * y_pred, dim=reduce_axis)
    y_o = torch.sum(y, dim=reduce_axis)
    y_pred_o = torch.sum(y_pred, dim=reduce_axis)
    denominator = y_o + y_pred_o

    # Set the class weights from the ground-truth volume of each class
    weight_type = look_up_option(weight_type, Weight)
    if weight_type == Weight.SIMPLE:
        w = torch.reciprocal(y_o.float())
    elif weight_type == Weight.SQUARE:
        w = torch.reciprocal(y_o.float() * y_o.float())
    else:
        # Remaining option: every class gets the same weight
        w = torch.ones_like(y_o.float())

    # Replace infinite values for non-appearing classes by the maximum weight.
    # Iterating over `w` yields per-batch-item views, so the in-place writes below update `w` itself.
    for b in w:
        infs = torch.isinf(b)
        # Zero the infinities first so they do not win the max() that follows
        b[infs] = 0
        b[infs] = torch.max(b)

    # Compute the weighted numerator and denominator, summing along the class axis
    numer = 2.0 * (intersection * w).sum(dim=1)
    denom = (denominator * w).sum(dim=1)

    # Compute the score (entries with denom == 0 are overwritten below)
    generalized_dice_score = numer / denom

    # Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.
    # Where denom == 0 but the prediction volume is not 0, score is 0
    y_pred_o = y_pred_o.sum(dim=-1)
    denom_zeros = denom == 0
    # Build the fill values on the same device as the scores so the result never mixes devices
    score_device = generalized_dice_score.device
    generalized_dice_score[denom_zeros] = torch.where(
        (y_pred_o == 0)[denom_zeros],
        torch.tensor(1.0, device=score_device),
        torch.tensor(0.0, device=score_device),
    )

    return generalized_dice_score