def __init__(
    self,
    spatial_size: Optional[Union[Sequence[int], int]] = None,
    normalized: bool = False,
    mode: str = GridSampleMode.BILINEAR,
    padding_mode: str = GridSamplePadMode.ZEROS,
    align_corners: bool = False,
    reverse_indexing: bool = True,
    zero_centered: Optional[bool] = None,
) -> None:
    """
    Apply affine transformations with a batch of affine matrices.

    With `normalized=False` and `reverse_indexing=True` the resampling is done in the 'pull'
    direction, following the ``scipy.ndimage.affine_transform`` convention: `theta` plays the role
    of the (ndim+1, ndim+1) ``matrix`` argument and operates on homogeneous coordinates.
    See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html

    With `normalized=True` and `reverse_indexing=False`, `theta` is applied directly to the
    normalized coordinates (in the range of [-1, 1]). This is often combined with
    `align_corners=False` for resolution-agnostic resampling, e.g. in spatial transformer networks.
    See also: https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

    Args:
        spatial_size: output spatial shape; the full output shape will be `[N, C, *spatial_size]`
            where N and C are inferred from the `src` input of `self.forward`.
        normalized: whether the provided affine matrix `theta` is defined for the normalized
            coordinates. If `False`, `theta` will be converted to operate on normalized
            coordinates, as pytorch affine_grid works with the normalized coordinates.
        mode: {``"bilinear"``, ``"nearest"``}
            Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
            Padding mode for outside grid values. Defaults to ``"zeros"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        align_corners: see also
            https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html.
        reverse_indexing: whether to reverse the spatial indexing of image and coordinates.
            Set to `False` if `theta` follows pytorch's default "D, H, W" convention,
            set to `True` if `theta` follows the `scipy.ndimage` default "i, j, k" convention.
        zero_centered: whether the affine is applied to coordinates in a zero-centered value range.
            With `zero_centered=True`, for example, the center of rotation will be the spatial
            center of the input; with `zero_centered=False` it will be the origin of the input.
            Only available when `normalized=False`; treated as `False` when unspecified.
            See also: :py:func:`monai.networks.utils.normalize_transform`.

    Raises:
        ValueError: When ``zero_centered`` is given together with ``normalized=True``.
    """
    super().__init__()
    self.spatial_size = None if spatial_size is None else ensure_tuple(spatial_size)
    self.normalized = normalized
    self.mode: str = look_up_option(mode, GridSampleMode)
    self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode)
    self.align_corners = align_corners
    self.reverse_indexing = reverse_indexing

    # an explicit `zero_centered` (True or False) only makes sense for un-normalized affines
    zero_centered_given = zero_centered is not None
    if zero_centered_given and self.normalized:
        raise ValueError("`normalized=True` is not compatible with the `zero_centered` option.")
    self.zero_centered = zero_centered if zero_centered_given else False
def convert_mask_to_box(
    boxes_mask: NdarrayOrTensor, bg_label: int = -1, box_dtype=torch.float32, label_dtype=torch.long
) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]:
    """
    Convert an int16 mask image into boxes and their labels.

    Args:
        boxes_mask: int16 array, sized (num_box, H, W) or (num_box, H, W, D). Each channel
            represents a box. The foreground region in channel c has intensity of labels[c].
            The background intensity is bg_label.
        bg_label: background label of the boxes_mask.
        box_dtype: output dtype for boxes.
        label_dtype: output dtype for labels.

    Return:
        - bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be
          ``StandardMode``.
        - classification foreground(fg) labels, dtype should be int, sized (N,).

    Channels without any foreground element are skipped, so N may be smaller than num_box.
    """
    # only 3-D (2-D spatial) or 4-D (3-D spatial) masks are accepted
    look_up_option(len(boxes_mask.shape), [3, 4])
    spatial_size = list(boxes_mask.shape[1:])
    spatial_dims = get_spatial_dims(spatial_size=spatial_size)

    mask_np, *_ = convert_data_type(boxes_mask, np.ndarray)

    boxes_list = []
    labels_list = []
    for b, channel in enumerate(mask_np):
        # positions where the channel differs from the background value
        fg_indices = np.nonzero(channel - bg_label)
        if fg_indices[0].shape[0] == 0:
            continue
        top_left = [min(fd_i) for fd_i in fg_indices]
        bottom_right = [max(fd_i) + 1 - TO_REMOVE for fd_i in fg_indices]
        boxes_list.append(top_left + bottom_right)
        # the label is read from the first foreground position of this channel
        if spatial_dims == 2:
            labels_list.append(mask_np[b, fg_indices[0][0], fg_indices[1][0]])
        elif spatial_dims == 3:
            labels_list.append(mask_np[b, fg_indices[0][0], fg_indices[1][0], fg_indices[2][0]])

    if boxes_list:
        boxes_np, labels_np = np.asarray(boxes_list), np.asarray(labels_list)
    else:
        boxes_np, labels_np = np.zeros([0, 2 * spatial_dims]), np.zeros([0])

    boxes, *_ = convert_to_dst_type(src=boxes_np, dst=boxes_mask, dtype=box_dtype)
    labels, *_ = convert_to_dst_type(src=labels_np, dst=boxes_mask, dtype=label_dtype)
    return boxes, labels
def __call__(
    self,
    img: torch.Tensor,
    argmax: Optional[bool] = None,
    to_onehot: Optional[bool] = None,
    num_classes: Optional[int] = None,
    threshold_values: Optional[bool] = None,
    logit_thresh: Optional[float] = None,
    rounding: Optional[str] = None,
    n_classes: Optional[int] = None,
) -> torch.Tensor:
    """
    Discretise ``img`` by optional argmax, one-hot conversion, thresholding and rounding.

    Args:
        img: the input tensor data to convert, if no channel dimension when converting to
            `One-Hot`, will automatically add it.
        argmax: whether to execute argmax function on input data before transform.
            Defaults to ``self.argmax``.
        to_onehot: whether to convert input data into the one-hot format.
            Defaults to ``self.to_onehot``.
        num_classes: the number of classes to convert to One-Hot format.
            Defaults to ``self.num_classes``.
        threshold_values: whether threshold the float value to int number 0 or 1.
            Defaults to ``self.threshold_values``.
        logit_thresh: the threshold value for thresholding operation.
            Defaults to ``self.logit_thresh``.
        rounding: if not None, round the data according to the specified option,
            available options: ["torchrounding"]. Defaults to ``self.rounding``.
        n_classes: deprecated alias of ``num_classes``, only used when ``num_classes`` is None.

    Returns:
        the processed tensor, converted to float.

    Raises:
        AssertionError: When one-hot conversion is requested and the resolved number of classes
            is not an integer.

    .. deprecated:: 0.6.0
        ``n_classes`` is deprecated, use ``num_classes`` instead.

    """
    # fall back to the deprecated argument only when the new one was left at its default
    if num_classes is None and n_classes is not None:
        num_classes = n_classes

    if argmax or self.argmax:
        img = torch.argmax(img, dim=0, keepdim=True)

    if to_onehot or self.to_onehot:
        _nclasses = num_classes if num_classes is not None else self.num_classes
        if not isinstance(_nclasses, int):
            raise AssertionError("One of self.num_classes or num_classes must be an integer")
        img = one_hot(img, num_classes=_nclasses, dim=0)

    if threshold_values or self.threshold_values:
        cutoff = logit_thresh if logit_thresh is not None else self.logit_thresh
        img = img >= cutoff

    if rounding is None:
        rounding = self.rounding
    if rounding is not None:
        # validates the option name; "torchrounding" is the only accepted value
        look_up_option(rounding, ["torchrounding"])
        img = torch.round(img)

    return img.float()
def resolve_writer(ext_name, error_if_not_found=True) -> Sequence:
    """
    Resolves to a tuple of available ``ImageWriter`` in ``SUPPORTED_WRITERS``
    according to the filename extension key ``ext_name``.

    The resolved tuple is stored back into ``SUPPORTED_WRITERS`` under the normalised key.

    Args:
        ext_name: the filename extension of the image.
            As an indexing key it will be converted to a lower case string
            (a single leading ``"."`` is dropped).
        error_if_not_found: whether to raise an error if no suitable image writer is found.
            if True, raise an ``OptionalImportError``, otherwise return an empty tuple.
            Default is ``True``.

    Raises:
        OptionalImportError: When no writer is available and ``error_if_not_found`` is true.
    """
    if not SUPPORTED_WRITERS:
        init()

    fmt = f"{ext_name}".lower()
    fmt = fmt[1:] if fmt.startswith(".") else fmt

    default_writers = SUPPORTED_WRITERS.get(EXT_WILDCARD, ())
    candidates = look_up_option(fmt, SUPPORTED_WRITERS, default=default_writers)

    avail_writers = []
    for _writer in candidates:
        try:
            _writer()  # this triggers `monai.utils.module.require_pkg` to check the system availability
        except OptionalImportError:
            # the backend package is missing, this writer is not usable
            continue
        except Exception:  # other writer init errors indicating it exists
            pass
        avail_writers.append(_writer)

    if error_if_not_found and not avail_writers:
        raise OptionalImportError(f"No ImageWriter backend found for {fmt}.")

    writer_tuple = ensure_tuple(avail_writers)
    SUPPORTED_WRITERS[fmt] = writer_tuple
    return writer_tuple
def astype(self, dtype, device=None, *_args, **_kwargs):
    """
    Cast to ``dtype``, sharing data whenever possible.

    The output container type is chosen from the module part of ``dtype``: ``"torch"`` gives a
    ``torch.Tensor``; ``"numpy"``/``"np"`` and any unrecognised module give an ``np.ndarray``.

    Args:
        dtype: dtypes such as np.float32, torch.float, "np.float32", float.
        device: the device if `dtype` is a torch data type.
        _args: additional args (currently unused).
        _kwargs: additional kwargs (currently unused).

    Returns:
        data array instance
    """
    if isinstance(dtype, str):
        # "np.float32" -> module "np", dtype "float32"; a bare "float" is used for both parts
        mod_str, dot, remainder = dtype.partition(".")
        dtype = remainder if dot else mod_str
    else:
        mod_str = getattr(dtype, "__module__", "torch")

    mod_str = look_up_option(mod_str, {"torch", "numpy", "np"}, default="numpy")
    container_types = {"torch": torch.Tensor, "numpy": np.ndarray, "np": np.ndarray}
    out_type = container_types.get(mod_str)
    return self.get_array(output_type=out_type, dtype=dtype, device=device)
def __init__(
    self,
    patch_size: Sequence[int],
    start_pos: Sequence[int] = (),
    mode: Union[NumpyPadMode, str] = NumpyPadMode.WRAP,
    **pad_opts: Dict,
):
    """
    Configure the patch size, start position and padding used for patch iteration.

    Args:
        patch_size: size of patches to generate slices for, 0/None selects whole dimension
        start_pos: starting position in the array, default is 0 for each dimension
        mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
            ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
            One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
        pad_opts: padding options, see numpy.pad

    Note:
        The `patch_size` is the size of the patch to sample from the input arrays. The first
        dimension of the arrays is assumed to be the channel dimension, which is yielded in its
        entirety, so it should not be specified in `patch_size`. For example, for an input 3D
        array with 1 channel of size (1, 20, 20, 20), a regular grid sampling of eight patches
        (1, 10, 10, 10) would be specified by a `patch_size` of (10, 10, 10).

    """
    # a leading None stands for the channel dimension, which is never specified by the caller
    self.patch_size = (None, *patch_size)
    self.start_pos = ensure_tuple(start_pos)
    self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode)
    self.pad_opts = pad_opts
def check_hash(filepath: PathLike, val: Optional[str] = None, hash_type: str = "md5") -> bool:
    """
    Verify hash signature of specified file.

    Args:
        filepath: path of source file to verify hash value.
        val: expected hash value of the file. If None, the check is skipped and True is returned.
        hash_type: type of hash algorithm to use, default is `"md5"`.
            The supported hash types are `"md5"`, `"sha1"`, `"sha256"`, `"sha512"`.
            See also: :py:data:`monai.apps.utils.SUPPORTED_HASH_TYPES`.

    Returns:
        True when the check is skipped or the digest equals ``val``; False when the file cannot
        be read or the digest differs (both cases are logged as errors).

    """
    if val is None:
        logger.info(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.")
        return True

    hasher_factory = look_up_option(hash_type.lower(), SUPPORTED_HASH_TYPES)
    hasher = hasher_factory()
    chunk_size = 1024 * 1024  # read in 1 MiB pieces to avoid loading the whole file
    try:
        with open(filepath, "rb") as f:
            while True:
                chunk = f.read(chunk_size)
                if not chunk:
                    break
                hasher.update(chunk)
    except Exception as e:
        logger.error(f"Exception in check_hash: {e}")
        return False

    digest = hasher.hexdigest()
    if val != digest:
        logger.error(f"check_hash failed {digest}.")
        return False

    logger.info(f"Verified '{_basename(filepath)}', {hash_type}: {val}.")
    return True
def resample_and_clip(
    cls,
    data_array: NdarrayOrTensor,
    output_spatial_shape: Optional[Sequence[int]] = None,
    mode: str = InterpolateMode.BICUBIC,
):
    """
    Resample ``data_array`` to ``output_spatial_shape`` if needed.

    For every mode except nearest, the resampled values are clipped back to the
    [min, max] range of the input data.

    Args:
        data_array: input data array. This method assumes the 'channel-last' format.
        output_spatial_shape: output spatial shape. If None, the data is returned unresampled.
        mode: interpolation mode, default is ``InterpolateMode.BICUBIC``.

    Returns:
        the (possibly resampled) data as a numpy array.
    """
    data: np.ndarray = convert_data_type(data_array, np.ndarray)[0]
    if output_spatial_shape is None:
        return data

    output_spatial_shape_ = ensure_tuple_rep(output_spatial_shape, 2)
    mode = look_up_option(mode, InterpolateMode)
    if mode in (InterpolateMode.NEAREST, InterpolateMode.AREA):
        align_corners = None
    else:
        align_corners = False
    xform = Resize(spatial_size=output_spatial_shape_, mode=mode, align_corners=align_corners)

    # remember the input range before resampling so the result can be clipped to it
    _min, _max = np.min(data), np.max(data)
    if data.ndim == 3:
        channel_first = np.moveaxis(data, -1, 0)  # to channel first
        resized = convert_data_type(xform(channel_first), np.ndarray)[0]  # type: ignore
        data = np.moveaxis(resized, 0, -1)
    else:  # (H, W)
        with_channel = np.expand_dims(data, 0)  # make a channel
        data = convert_data_type(xform(with_channel), np.ndarray)[0][0]  # type: ignore

    if mode != InterpolateMode.NEAREST:
        data = np.clip(data, _min, _max)
    return data
def do_metric_reduction(f: torch.Tensor, reduction: Union[MetricReduction, str] = MetricReduction.MEAN):
    """
    This function is to do the metric reduction for calculated `not-nan` metrics of each sample's
    each class. The function also returns `not_nans`, which counts the number of not nans for the
    metric.

    Args:
        f: a tensor that contains the calculated metric scores per batch and per class.
            The first two dims should be batch and class.
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan`
            values, available reduction modes: {``"none"``, ``"mean"``, ``"sum"``,
            ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``},
            default to ``"mean"``. if "none", return the input f tensor and not_nans.

    Returns:
        a tuple ``(f, not_nans)``: the reduced metric and the matching count of not-nan entries.

    Raises:
        ValueError: When ``reduction`` is not one of
            ["mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel", "none"].

    Note:
        For every reduction other than ``"none"``, the nan entries of the input ``f`` are
        overwritten with 0 in place.
    """

    # some elements might be Nan (if ground truth y was missing (zeros))
    # we need to account for it
    nans = torch.isnan(f)
    not_nans = (~nans).float()

    # single-element zero used as the fallback value where nothing valid was counted
    t_zero = torch.zeros(1, device=f.device, dtype=f.dtype)
    reduction = look_up_option(reduction, MetricReduction)
    if reduction == MetricReduction.NONE:
        return f, not_nans

    # zero out nans so that they do not contribute to the sums below
    f[nans] = 0
    if reduction == MetricReduction.MEAN:
        # 2 steps, first, mean by channel (accounting for nans), then by batch
        not_nans = not_nans.sum(dim=1)
        f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average

        not_nans = (not_nans > 0).float().sum(dim=0)
        f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero)  # batch average

    elif reduction == MetricReduction.SUM:
        not_nans = not_nans.sum(dim=[0, 1])
        f = torch.sum(f, dim=[0, 1])  # sum over the batch and channel dims
    elif reduction == MetricReduction.MEAN_BATCH:
        not_nans = not_nans.sum(dim=0)
        f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero)  # batch average
    elif reduction == MetricReduction.SUM_BATCH:
        not_nans = not_nans.sum(dim=0)
        f = f.sum(dim=0)  # the batch sum
    elif reduction == MetricReduction.MEAN_CHANNEL:
        not_nans = not_nans.sum(dim=1)
        f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average
    elif reduction == MetricReduction.SUM_CHANNEL:
        not_nans = not_nans.sum(dim=1)
        f = f.sum(dim=1)  # the channel sum
    elif reduction != MetricReduction.NONE:
        # fixed: the option list was missing the comma between "sum_channel" and "none"
        raise ValueError(
            f"Unsupported reduction: {reduction}, available options are "
            '["mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel", "none"].'
        )
    return f, not_nans
def _compute_op(op: str, d: np.ndarray):
    """
    Apply the operation named ``op`` to the data ``d``.

    Names ending in ``"percentile"`` (e.g. ``"10percentile"``) are routed to the
    ``"90percentile"`` entry of ``supported_ops``, which is called with a ``(d, threshold)``
    tuple where ``threshold`` is the integer prefix of the name. Any other name is looked up in
    ``supported_ops`` and called with ``d``.

    NOTE(review): ``supported_ops`` is not defined in this block; presumably it is a mapping of
    op names to callables provided by the enclosing scope - confirm.
    """
    if op.endswith("percentile"):
        # text before the first "percentile" is the requested threshold
        threshold = int(op.partition("percentile")[0])
        return supported_ops["90percentile"]((d, threshold))  # type: ignore

    c_op = look_up_option(op, supported_ops)
    return c_op(d)
def __call__(
    self,
    img: NdarrayOrTensor,
    argmax: Optional[bool] = None,
    to_onehot: Optional[int] = None,
    threshold: Optional[float] = None,
    rounding: Optional[str] = None,
) -> NdarrayOrTensor:
    """
    Discretise ``img`` by optional argmax, one-hot conversion, thresholding and rounding.

    The ``dim``, ``keepdim`` and ``dtype`` entries of ``self.kwargs`` (defaults ``0``, ``True``
    and ``torch.float``) configure the argmax / one-hot calls and the output dtype.

    Args:
        img: the input tensor data to convert, if no channel dimension when converting to
            `One-Hot`, will automatically add it.
        argmax: whether to execute argmax function on input data before transform.
            Defaults to ``self.argmax``.
        to_onehot: if not None, convert input data into the one-hot format with specified number
            of classes. Defaults to ``self.to_onehot``.
        threshold: if not None, threshold the float values to int number 0 or 1 with specified
            threshold value. Defaults to ``self.threshold``.
        rounding: if not None, round the data according to the specified option,
            available options: ["torchrounding"]. Defaults to ``self.rounding``.

    Returns:
        the processed data, converted back to the container type of the input ``img``.

    Raises:
        ValueError: When the resolved ``to_onehot`` is neither None nor an integer.

    """
    img_t: torch.Tensor
    img_t, *_ = convert_data_type(img, torch.Tensor)  # type: ignore
    opts = self.kwargs

    if argmax or self.argmax:
        img_t = torch.argmax(img_t, dim=opts.get("dim", 0), keepdim=opts.get("keepdim", True))

    if to_onehot is None:
        to_onehot = self.to_onehot
    if to_onehot is not None:
        if not isinstance(to_onehot, int):
            raise ValueError("the number of classes for One-Hot must be an integer.")
        img_t = one_hot(
            img_t, num_classes=to_onehot, dim=opts.get("dim", 0), dtype=opts.get("dtype", torch.float)
        )

    if threshold is None:
        threshold = self.threshold
    if threshold is not None:
        img_t = img_t >= threshold

    if rounding is None:
        rounding = self.rounding
    if rounding is not None:
        # validates the option name; "torchrounding" is the only accepted value
        look_up_option(rounding, ["torchrounding"])
        img_t = torch.round(img_t)

    img, *_ = convert_to_dst_type(img_t, img, dtype=opts.get("dtype", torch.float))
    return img
def __init__(
    self,
    include_background: bool = True,
    to_onehot_y: bool = False,
    sigmoid: bool = False,
    softmax: bool = False,
    other_act: Optional[Callable] = None,
    w_type: Union[Weight, str] = Weight.SQUARE,
    reduction: Union[LossReduction, str] = LossReduction.MEAN,
    smooth_nr: float = 1e-5,
    smooth_dr: float = 1e-5,
    batch: bool = False,
) -> None:
    """
    Args:
        include_background: If False channel index 0 (background category) is excluded from the
            calculation.
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        sigmoid: If True, apply a sigmoid function to the prediction.
        softmax: If True, apply a softmax function to the prediction.
        other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to
            execute other activation layers, Defaults to ``None``. for example:
            `other_act = torch.tanh`.
        w_type: {``"square"``, ``"simple"``, ``"uniform"``}
            Type of function to transform ground truth volume to a weight factor.
            Defaults to ``"square"``.
        reduction: {``"none"``, ``"mean"``, ``"sum"``}
            Specifies the reduction to apply to the output. Defaults to ``"mean"``.

            - ``"none"``: no reduction will be applied.
            - ``"mean"``: the sum of the output will be divided by the number of elements in
              the output.
            - ``"sum"``: the output will be summed.
        smooth_nr: a small constant added to the numerator to avoid zero.
        smooth_dr: a small constant added to the denominator to avoid nan.
        batch: whether to sum the intersection and union areas over the batch dimension before
            the dividing. Defaults to False, intersection over union is computed from each item
            in the batch.

    Raises:
        TypeError: When ``other_act`` is not an ``Optional[Callable]``.
        ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``,
            ``other_act is not None``]. Incompatible values.

    """
    super().__init__(reduction=LossReduction(reduction).value)

    has_other_act = other_act is not None
    if has_other_act and not callable(other_act):
        raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")

    # at most one activation may be selected
    n_activations = int(sigmoid) + int(softmax) + int(has_other_act)
    if n_activations > 1:
        raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")

    self.include_background = include_background
    self.to_onehot_y = to_onehot_y
    self.sigmoid = sigmoid
    self.softmax = softmax
    self.other_act = other_act

    self.w_type = look_up_option(w_type, Weight)

    self.smooth_nr = float(smooth_nr)
    self.smooth_dr = float(smooth_dr)
    self.batch = batch
def __init__(
    self,
    data_type: str = "tensor",
    dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
    device: Optional[torch.device] = None,
) -> None:
    """
    Store the target data type, dtype and device.

    Args:
        data_type: target container name; lower-cased and validated against
            ``{"tensor", "numpy"}``. Defaults to ``"tensor"``.
        dtype: stored unchanged on the instance.
        device: stored unchanged on the instance.
    """
    normalized_type = data_type.lower()
    self.data_type = look_up_option(normalized_type, {"tensor", "numpy"})
    self.dtype = dtype
    self.device = device
def __call__(
    self, img: NdarrayOrTensor, meta_data: Optional[Dict] = None, mask: Optional[np.ndarray] = None
) -> Tuple[NdarrayOrTensor, Dict]:
    """
    Compute statistics for the intensity of input image.

    Each entry of ``self.ops`` is either the name of a predefined operation (``"mean"``,
    ``"median"``, ``"max"``, ``"min"``, ``"std"``; the nan-aware numpy versions are used) or a
    callable. Results are stored in ``meta_data`` under ``<key_prefix>_<name>`` for predefined
    operations and ``<key_prefix>_custom_<index>`` for callables.

    Args:
        img: input image to compute intensity stats.
        meta_data: meta data dictionary to store the statistics data, if None, will create an
            empty dictionary.
        mask: if not None, mask the image to extract only the interested area to compute
            statistics. mask must be a bool array with the same shape as input `img`.

    Returns:
        the unchanged input ``img`` and the updated ``meta_data`` dictionary.

    Raises:
        TypeError: When ``mask`` is not a bool array of the same shape as ``img``.
        ValueError: When an entry of ``self.ops`` is neither a string nor a callable.

    """
    supported_ops = {
        "mean": np.nanmean,
        "median": np.nanmedian,
        "max": np.nanmax,
        "min": np.nanmin,
        "std": np.nanstd,
    }

    img_np: np.ndarray
    img_np, *_ = convert_data_type(img, np.ndarray)  # type: ignore
    meta_data = {} if meta_data is None else meta_data

    if mask is not None:
        if mask.shape != img_np.shape or mask.dtype != bool:
            raise TypeError("mask must be bool array with the same shape as input `img`.")
        img_np = img_np[mask]

    def _compute(op: Callable, data: np.ndarray):
        # channel-wise mode applies the operation to every item along the first axis
        return [op(c) for c in data] if self.channel_wise else op(data)

    custom_index = 0
    for item in self.ops:
        if isinstance(item, str):
            op_name = look_up_option(item, supported_ops.keys())
            stat_key = self.key_prefix + "_" + op_name
            meta_data[stat_key] = _compute(supported_ops[op_name], img_np)  # type: ignore
            continue
        if not callable(item):
            raise ValueError("ops must be key string for predefined operations or callable function.")
        stat_key = self.key_prefix + "_custom_" + str(custom_index)
        meta_data[stat_key] = _compute(item, img_np)
        custom_index += 1

    return img, meta_data
def compute_importance_map(
    patch_size: Tuple[int, ...],
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    device: Union[torch.device, int, str] = "cpu",
) -> torch.Tensor:
    """Get importance map for different weight modes.

    Args:
        patch_size: Size of the required importance map. This should be either H, W [,D].
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: Sigma_scale to calculate sigma for each dimension
            (sigma = sigma_scale * dim_size). Used for gaussian mode only.
        device: Device to put importance map on.

    Raises:
        ValueError: When ``mode`` is not one of ["constant", "gaussian"].

    Returns:
        Tensor of size patch_size.

    """
    mode = look_up_option(mode, BlendMode)
    device = torch.device(device)  # type: ignore[arg-type]
    if mode == BlendMode.CONSTANT:
        importance_map = torch.ones(patch_size, device=device).float()
    elif mode == BlendMode.GAUSSIAN:
        center_coords = [i // 2 for i in patch_size]
        sigma_scale = ensure_tuple_rep(sigma_scale, len(patch_size))
        sigmas = [i * sigma_s for i, sigma_s in zip(patch_size, sigma_scale)]

        # filter a unit impulse placed at the patch center, then normalise by the maximum
        importance_map = torch.zeros(patch_size, device=device)
        importance_map[tuple(center_coords)] = 1
        pt_gaussian = GaussianFilter(len(patch_size), sigmas).to(device=device, dtype=torch.float)
        importance_map = pt_gaussian(importance_map.unsqueeze(0).unsqueeze(0))
        importance_map = importance_map.squeeze(0).squeeze(0)
        importance_map = importance_map / torch.max(importance_map)
        importance_map = importance_map.float()

        # importance_map cannot be 0, otherwise we may end up with nans!
        min_non_zero = importance_map[importance_map != 0].min().item()
        importance_map = torch.clamp(importance_map, min=min_non_zero)
    else:
        # fixed: the message used to list BlendMode.CONSTANT twice and never mentioned GAUSSIAN
        raise ValueError(
            f"Unsupported mode: {mode}, available options are [{BlendMode.CONSTANT}, {BlendMode.GAUSSIAN}]."
        )
    return importance_map
def __init__(self, submodule, dim: int = 1, mode: Union[str, SkipMode] = "cat") -> None:
    """
    Wrap ``submodule`` as the trainable branch of a skip connection.

    Args:
        submodule: the module defines the trainable branch.
        dim: the dimension over which the tensors are concatenated.
            Used when mode is ``"cat"``.
        mode: ``"cat"``, ``"add"``, ``"mul"``. defaults to ``"cat"``.
            The value of the resolved ``SkipMode`` member is stored in ``self.mode``.

    """
    super().__init__()
    self.submodule = submodule
    self.dim = dim
    skip_mode = look_up_option(mode, SkipMode)
    self.mode = skip_mode.value
def get_constructor(self, factory_name: str, *args) -> Any:
    """
    Get the constructor for the given factory name and arguments.

    The name is upper-cased before it is looked up in ``self.factories``; the factory function
    found is then called with ``args`` and its result is returned.

    Raises:
        TypeError: When ``factory_name`` is not a ``str``.

    """
    if isinstance(factory_name, str):
        factory_fn = look_up_option(factory_name.upper(), self.factories)
        return factory_fn(*args)
    raise TypeError(f"factory_name must a str but is {type(factory_name).__name__}.")
def __init__(
    self,
    output_dir: PathLike = "./",
    output_postfix: str = "seg",
    output_ext: str = ".png",
    resample: bool = True,
    mode: Union[InterpolateMode, str] = InterpolateMode.NEAREST,
    scale: Optional[int] = None,
    data_root_dir: PathLike = "",
    separate_folder: bool = True,
    print_log: bool = True,
) -> None:
    """
    Args:
        output_dir: output image directory.
        output_postfix: a string appended to all output file names.
        output_ext: output file extension name.
        resample: whether to resample and resize if providing spatial_shape in the metadata.
        mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``,
            ``"area"``}
            The interpolation mode. Defaults to ``"nearest"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
        scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling
            [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.
        data_root_dir: if not empty, it specifies the beginning parts of the input file's
            absolute path. it's used to compute `input_file_rel_path`, the relative path to the
            file from `data_root_dir` to preserve folder structure when saving in case there are
            files in different folders with the same file names. for example:
            input_file_name: /foo/bar/test1/image.png,
            postfix: seg
            output_ext: png
            output_dir: /output,
            data_root_dir: /foo/bar,
            output will be: /output/test1/image/image_seg.png
        separate_folder: whether to save every file in a separate folder, for example: if input
            filename is `image.png`, postfix is `seg` and folder_path is `output`, if `True`,
            save as: `output/image/image_seg.png`, if `False`, save as `output/image_seg.png`.
            default to `True`.
        print_log: whether to print log about the saved PNG file path, etc. default to `True`.

    """
    # where and under which name to write
    self.output_dir = output_dir
    self.output_postfix = output_postfix
    self.output_ext = output_ext
    self.data_root_dir = data_root_dir
    self.separate_folder = separate_folder

    # how to post-process the data before writing
    self.resample = resample
    self.mode: InterpolateMode = look_up_option(mode, InterpolateMode)
    self.scale = scale

    self.print_log = print_log
    self._data_index = 0
def export_config_file(cls, config: Dict, filepath: PathLike, fmt="json", **kwargs):
    """
    Export the config content to the specified file path (currently support JSON and YAML files).

    Args:
        config: source config content to export.
        filepath: target file path to save.
        fmt: format of config content, currently support ``"json"`` and ``"yaml"``
            (matched case-insensitively).
        kwargs: other arguments for ``json.dump`` or ``yaml.safe_dump``, depends on the file
            format.

    Returns:
        whatever the selected dump function returns.

    Raises:
        ValueError: When the resolved format is neither ``"json"`` nor ``"yaml"``.

    """
    _filepath: str = str(Path(filepath))
    writer = look_up_option(fmt.lower(), {"json", "yaml"})
    with open(_filepath, "w") as f:
        if writer == "json":
            dumped = json.dump(config, f, **kwargs)
        elif writer == "yaml":
            dumped = yaml.safe_dump(config, f, **kwargs)
        else:
            raise ValueError(f"only support JSON or YAML config file so far, got {writer}.")
    return dumped
def __init__(
    self,
    include_background: bool = True,
    reduction: Union[MetricReduction, str] = MetricReduction.MEAN_BATCH,
    weight_type: Union[Weight, str] = Weight.SQUARE,
) -> None:
    """
    Store the metric configuration and validate ``reduction`` and ``weight_type``.

    Args:
        include_background: stored unchanged on the instance.
        reduction: must be one of ``"none"``, ``"mean_batch"``, ``"sum_batch"`` or the matching
            ``MetricReduction`` members. Defaults to ``MetricReduction.MEAN_BATCH``.
        weight_type: resolved against ``Weight``. Defaults to ``Weight.SQUARE``.

    Raises:
        ValueError: When ``reduction`` is not one of the accepted options.
    """
    super().__init__()
    self.include_background = include_background

    # only the "none" and the batch-wise reductions are accepted here
    reduction_options = [
        "none",
        "mean_batch",
        "sum_batch",
        MetricReduction.NONE,
        MetricReduction.MEAN_BATCH,
        MetricReduction.SUM_BATCH,
    ]
    self.reduction = reduction
    is_supported = self.reduction in reduction_options
    if not is_supported:
        raise ValueError(f"reduction must be one of {reduction_options}")

    self.weight_type = look_up_option(weight_type, Weight)
def __init__(
    self,
    hidden_size: int,
    mlp_dim: int,
    dropout_rate: float = 0.0,
    act: Union[Tuple, str] = "GELU",
    dropout_mode="vit",
) -> None:
    """
    Args:
        hidden_size: dimension of hidden layer.
        mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
        dropout_rate: fraction of the input units to drop.
        act: activation type and arguments. Defaults to GELU.
        dropout_mode: dropout mode, can be "vit" or "swin".
            "vit" mode uses two dropout instances as implemented in
            https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
            "swin" corresponds to one instance as implemented in
            https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23

    Raises:
        ValueError: When ``dropout_rate`` is outside [0, 1] or ``dropout_mode`` is unsupported.

    """
    super().__init__()

    if not (0 <= dropout_rate <= 1):
        raise ValueError("dropout_rate should be between 0 and 1.")

    # a falsy mlp_dim (e.g. 0) falls back to the hidden size
    mlp_dim = mlp_dim if mlp_dim else hidden_size
    self.linear1 = nn.Linear(hidden_size, mlp_dim)
    self.linear2 = nn.Linear(mlp_dim, hidden_size)
    self.fn = get_act_layer(act)
    self.drop1 = nn.Dropout(dropout_rate)

    resolved_mode = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
    if resolved_mode == "vit":
        # separate second dropout instance
        self.drop2 = nn.Dropout(dropout_rate)
    elif resolved_mode == "swin":
        # the single dropout instance is shared
        self.drop2 = self.drop1
    else:
        raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")
def register_writer(ext_name, *im_writers):
    """
    Register ``ImageWriter``, so that writing a file with filename extension ``ext_name``
    could be resolved to a tuple of potentially appropriate ``ImageWriter``.
    The customised writers could be registered by:

    .. code-block:: python

        from monai.data import register_writer
        # `MyWriter` must implement `ImageWriter` interface
        register_writer("nii", MyWriter)

    The new writers are placed in front of any writers already registered for the extension.

    Args:
        ext_name: the filename extension of the image.
            As an indexing key, it will be converted to a lower case string
            (a single leading ``"."`` is dropped).
        im_writers: one or multiple ImageWriter classes with high priority ones first.

    """
    fmt = f"{ext_name}".lower()
    fmt = fmt[1:] if fmt.startswith(".") else fmt
    previously_registered = look_up_option(fmt, SUPPORTED_WRITERS, default=())
    SUPPORTED_WRITERS[fmt] = im_writers + previously_registered
def __init__(
    self,
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    mode: Union[ChannelMatching, str] = ChannelMatching.PAD,
):
    """
    Prepare how a channel mismatch between ``in_channels`` and ``out_channels`` is handled.

    When both channel counts are equal, neither ``self.project`` nor ``self.pad`` is set and
    ``mode`` is not validated.

    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of input channels.
        out_channels: number of output channels.
        mode: {``"pad"``, ``"project"``}
            Specifies handling residual branch and conv branch channel mismatches.
            Defaults to ``"pad"``.

            - ``"pad"``: with zero padding.
            - ``"project"``: with a trainable conv with kernel size one.

    Raises:
        ValueError: When ``mode`` is ``"pad"`` and ``in_channels > out_channels``.
    """
    super().__init__()
    self.project = None
    self.pad = None
    if in_channels == out_channels:
        return

    mode = look_up_option(mode, ChannelMatching)
    if mode == ChannelMatching.PROJECT:
        conv_type = Conv[Conv.CONV, spatial_dims]
        self.project = conv_type(in_channels, out_channels, kernel_size=1)
    elif mode == ChannelMatching.PAD:
        if in_channels > out_channels:
            raise ValueError('Incompatible values: channel_matching="pad" and in_channels > out_channels.')
        # split the missing channels between the two pad slots, the second taking the remainder
        missing = out_channels - in_channels
        pad_1 = missing // 2
        pad_2 = missing - pad_1
        self.pad = (0, 0) * spatial_dims + (pad_1, pad_2) + (0, 0)
def __init__(
    self,
    spatial_dims: int,
    num_classes: int,
    num_anchors: int,
    feature_extractor,
    size_divisible: Union[Sequence[int], int] = 1,
):
    """
    Build the classification and regression heads on top of ``feature_extractor``.

    Args:
        spatial_dims: number of spatial dimensions, must be one of 1, 2 or 3.
        num_classes: number of classes, passed to the classification head.
        num_anchors: number of anchors, passed to both heads.
        feature_extractor: module exposing an ``out_channels`` attribute that gives the number
            of output channels (assumed to be the same for all the levels).
        size_divisible: a single value or one value per spatial dimension; expanded to a tuple
            of length ``spatial_dims``. Defaults to 1.

    Raises:
        ValueError: When ``feature_extractor`` has no ``out_channels`` attribute.
    """
    super().__init__()

    self.spatial_dims = look_up_option(spatial_dims, supported=[1, 2, 3])
    self.num_classes = num_classes
    self.size_divisible = ensure_tuple_rep(size_divisible, self.spatial_dims)

    has_out_channels = hasattr(feature_extractor, "out_channels")
    if not has_out_channels:
        raise ValueError(
            "feature_extractor should contain an attribute out_channels "
            "specifying the number of output channels (assumed to be the "
            "same for all the levels)"
        )
    self.feature_extractor = feature_extractor

    self.feature_map_channels: int = self.feature_extractor.out_channels
    self.num_anchors = num_anchors
    self.classification_head = RetinaNetClassificationHead(
        self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims
    )
    self.regression_head = RetinaNetRegressionHead(
        self.feature_map_channels, self.num_anchors, spatial_dims=self.spatial_dims
    )

    # key names for the two head outputs
    self.cls_key: str = "classification"
    self.box_reg_key: str = "box_regression"
def __init__(
    self,
    in_chans: int,
    embed_dim: int,
    window_size: Sequence[int],
    patch_size: Sequence[int],
    depths: Sequence[int],
    num_heads: Sequence[int],
    mlp_ratio: float = 4.0,
    qkv_bias: bool = True,
    drop_rate: float = 0.0,
    attn_drop_rate: float = 0.0,
    drop_path_rate: float = 0.0,
    norm_layer: Type[LayerNorm] = nn.LayerNorm,
    patch_norm: bool = False,
    use_checkpoint: bool = False,
    spatial_dims: int = 3,
    downsample="merging",
) -> None:
    """
    Build the patch embedding and one ``BasicLayer`` per entry of `depths`.

    The i-th stage uses ``embed_dim * 2**i`` channels. Stages 0-3 are stored in
    ``self.layers1`` ... ``self.layers4``; a stage with index 4 or higher is
    constructed but not stored in any of these lists.

    Args:
        in_chans: dimension of input channels.
        embed_dim: number of linear projection output channels.
        window_size: local window size.
        patch_size: patch size.
        depths: number of layers in each stage.
        num_heads: number of attention heads, indexed per stage.
        mlp_ratio: ratio of mlp hidden dim to embedding dim.
        qkv_bias: add a learnable bias to query, key, value.
        drop_rate: dropout rate.
        attn_drop_rate: attention dropout rate.
        drop_path_rate: stochastic depth rate; the per-block rates are spaced linearly
            from 0 to this value over all ``sum(depths)`` blocks.
        norm_layer: normalization layer.
        patch_norm: add normalization after patch embedding.
        use_checkpoint: use gradient checkpointing for reduced memory usage.
        spatial_dims: spatial dimension.
        downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
            user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
            The default is currently `"merging"` (the original version defined in v0.9.0).
    """
    super().__init__()
    self.num_layers = len(depths)
    self.embed_dim = embed_dim
    self.patch_norm = patch_norm
    self.window_size = window_size
    self.patch_size = patch_size

    embed_norm = norm_layer if self.patch_norm else None
    self.patch_embed = PatchEmbed(
        patch_size=self.patch_size,
        in_chans=in_chans,
        embed_dim=embed_dim,
        norm_layer=embed_norm,  # type: ignore
        spatial_dims=spatial_dims,
    )
    self.pos_drop = nn.Dropout(p=drop_rate)

    # one drop-path rate per block across all stages
    drop_path_rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

    self.layers1 = nn.ModuleList()
    self.layers2 = nn.ModuleList()
    self.layers3 = nn.ModuleList()
    self.layers4 = nn.ModuleList()
    stage_lists = (self.layers1, self.layers2, self.layers3, self.layers4)

    # a string selects a registered merging mode, anything else is used as given
    if isinstance(downsample, str):
        down_sample_mod = look_up_option(downsample, MERGING_MODE)
    else:
        down_sample_mod = downsample

    consumed = 0  # number of blocks that belong to the earlier stages
    for stage_idx, stage_depth in enumerate(depths):
        stage = BasicLayer(
            dim=int(embed_dim * 2**stage_idx),
            depth=stage_depth,
            num_heads=num_heads[stage_idx],
            window_size=self.window_size,
            drop_path=drop_path_rates[consumed:consumed + stage_depth],
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            norm_layer=norm_layer,
            downsample=down_sample_mod,
            use_checkpoint=use_checkpoint,
        )
        consumed += stage_depth
        if stage_idx < len(stage_lists):
            stage_lists[stage_idx].append(stage)

    self.num_features = int(embed_dim * 2**(self.num_layers - 1))
def iter_patch(
    arr: np.ndarray,
    patch_size: Union[Sequence[int], int] = 0,
    start_pos: Sequence[int] = (),
    copy_back: bool = True,
    mode: Union[NumpyPadMode, str] = NumpyPadMode.WRAP,
    **pad_opts: Dict,
):
    """
    Yield successive patches from `arr` of size `patch_size`. The iteration can start from position `start_pos`
    in `arr` but drawing from a padded array extended by the `patch_size` in each dimension (so these coordinates
    can be negative to start in the padded region). If `copy_back` is True the values from each patch are written
    back to `arr`.

    Args:
        arr: array to iterate over
        patch_size: size of patches to generate slices for, 0 or None selects whole dimension
        start_pos: starting position in the array, default is 0 for each dimension
        copy_back: if True data from the yielded patches is copied back to `arr` once the generator completes
        mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
            ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
            One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
        pad_opts: padding options, see `numpy.pad`

    Yields:
        A pair ``(patch, coords)``. ``patch`` is a view into the padded array and can be modified;
        if `copy_back` is True these changes will be reflected in `arr` once the iteration completes.
        ``coords`` is a numpy array holding one ``(start, end)`` pair per dimension, expressed
        relative to the unpadded `arr`.
    """
    # normalise the patch size and start position to one entry per dimension
    patch_size_ = get_valid_patch_size(arr.shape, patch_size)
    start_pos = ensure_tuple_size(start_pos, arr.ndim)

    # pad every dimension by a full patch on both sides so patches are always taken from inside an array
    pad_width = tuple((extent, extent) for extent in patch_size_)
    pad_mode = look_up_option(mode, NumpyPadMode).value
    arrpad = np.pad(arr, pad_width, pad_mode, **pad_opts)

    # shift the requested start into the padded coordinate system
    start_pos_padded = tuple(start + extent for start, extent in zip(start_pos, patch_size_))

    # iterate over a region smaller than the padded array so that no patch
    # lies entirely within the padding
    iter_size = tuple(length + extent for length, extent in zip(arr.shape, patch_size_))

    for patch_slices in iter_patch_slices(iter_size, patch_size_, start_pos_padded):
        # undo the padding offset to report coordinates relative to `arr`
        coords_no_pad = tuple(
            (sl.start - extent, sl.stop - extent) for sl, extent in zip(patch_slices, patch_size_)
        )
        # data and coords (in numpy; works with torch loader)
        yield arrpad[patch_slices], np.asarray(coords_no_pad)

    # copy back data from the padded image if required
    if copy_back:
        interior = tuple(slice(extent, extent + length) for extent, length in zip(patch_size_, arr.shape))
        arr[...] = arrpad[interior]
def compute_roc_auc(y_pred: torch.Tensor, y: torch.Tensor, average: Union[Average, str] = Average.MACRO):
    """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to:
    `sklearn.metrics.roc_auc_score <https://scikit-learn.org/stable/modules/generated/
    sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.

    Args:
        y_pred: input data to compute, typical classification model output.
            the first dim must be batch, if multi-classes, it must be in One-Hot format.
            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
        y: ground truth to compute ROC AUC metric, the first dim must be batch.
            if multi-classes, it must be in One-Hot format.
            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
                weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
                indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    Returns:
        The result of ``_calculate`` for 1-D (or single-column) predictions and for ``"micro"``;
        a list with one score per class for ``"none"``; otherwise the plain or weighted mean
        of the per-class scores.

    Raises:
        ValueError: When ``y_pred`` dimension is not one of [1, 2].
        ValueError: When ``y`` dimension is not one of [1, 2].
        ValueError: When multi-column ``y_pred`` and ``y`` differ in shape.
        ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].

    Note:
        ROCAUC expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values.

    """
    pred_rank = y_pred.ndimension()
    target_rank = y.ndimension()
    if pred_rank not in (1, 2):
        raise ValueError(
            f"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}."
        )
    if target_rank not in (1, 2):
        raise ValueError(f"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.")

    # a single-column matrix is handled as a plain vector
    if pred_rank == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(dim=-1)
        pred_rank = 1
    if target_rank == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=-1)

    # binary case: one score, `average` is not consulted
    if pred_rank == 1:
        return _calculate(y_pred, y)

    if y.shape != y_pred.shape:
        raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.")

    average = look_up_option(average, Average)
    if average == Average.MICRO:
        # pool every (sample, class) entry into one binary problem
        return _calculate(y_pred.flatten(), y.flatten())

    # class-major layout, so that iterating yields one row per class
    y_by_class = y.transpose(0, 1)
    pred_by_class = y_pred.transpose(0, 1)
    auc_values = []
    for class_pred, class_target in zip(pred_by_class, y_by_class):
        auc_values.append(_calculate(class_pred, class_target))

    if average == Average.NONE:
        return auc_values
    if average == Average.MACRO:
        return np.mean(auc_values)
    if average == Average.WEIGHTED:
        # support of each class = sum of its target column
        weights = [sum(class_target) for class_target in y_by_class]
        return np.average(auc_values, weights=weights)
    raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].')
def sliding_window_inference(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    progress: bool = False,
    roi_weight_map: Union[torch.Tensor, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
    """
    Sliding window inference on `inputs` with `predictor`.

    The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
    Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
    e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
    could be ([128,64,256], [64,32,128]).
    In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
    an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
    so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).

    When roi_size is larger than the inputs' spatial size, the input image are padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D])
        roi_size: the spatial window size for inferences.
            When its components have None or non-positives, the corresponding inputs dimension will be used.
            if the components of the `roi_size` are non-positive values, the transform will use the
            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor ``patch_data`` in shape NCHW[D],
            The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
            with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
            where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
            N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
            the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
            to ensure the scaled output ROI sizes are still integers.
            If the `predictor`'s input and output spatial sizes are different,
            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
        overlap: Amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI.
            It is used only when the valid patch size equals `roi_size`; otherwise the map is
            computed with ``compute_importance_map``.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

    Returns:
        The stitched prediction, in the same container form as the predictor's output:
        a tensor, a tuple of tensors, or a dict keyed by the predictor's (sorted) keys.

    Raises:
        ValueError: When ``overlap`` is not in ``[0, 1)``.
        RuntimeError: When computing the importance map fails (typically out of memory).

    Note:
        - input must be channel-first and have a batch dim, supports N-D sliding window.

    """
    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    batch_size, _, *image_size_ = inputs.shape

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    # F.pad expects the pad amounts starting from the last dimension
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)

    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)

    # Store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows

    # Create window-level importance map
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
    else:
        try:
            importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
        except Exception as e:
            # `Exception` rather than `BaseException`: KeyboardInterrupt / SystemExit must
            # propagate unchanged instead of being reported as an out-of-memory condition.
            raise RuntimeError(
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0]  # type: ignore
    # handle non-positive weights
    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
    importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)

    # Perform predictions
    dict_key, output_image_list, count_map_list = None, [], []
    _initialized_ss = -1
    is_tensor_output = True  # whether the predictor's output is a tensor (instead of dict/tuple)

    # for each patch
    for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
        slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
        # (batch item, all channels, *spatial window) for every window of this mini-batch;
        # kept as tuples because indexing a tensor with a list of slices is deprecated
        unravel_slice = [
            (slice(idx // num_win, idx // num_win + 1), slice(None)) + tuple(slices[idx % num_win])
            for idx in slice_range
        ]
        window_data = torch.cat(
            [convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice]
        ).to(sw_device)
        seg_prob_out = predictor(window_data, *args, **kwargs)  # batched patch segmentation

        # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out,)
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(seg_prob_out.keys())  # track predictor's output keys
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
        else:
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False

        # for each output in multi-output list
        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN

            # compute zoom scale: out_roi_size/in_roi_size
            zoom_scale = []
            for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
                zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
            ):
                _scale = out_w_i / float(in_w_i)
                if not (img_s_i * _scale).is_integer():
                    warnings.warn(
                        f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
                        f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
                    )
                zoom_scale.append(_scale)

            if _initialized_ss < ss:  # init. the ss-th buffer at the first iteration
                # construct multi-resolution outputs
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
                ]
                # allocate memory to store the full output and the count for overlapping parts
                output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
                count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
                _initialized_ss += 1

            # resize the importance map to this output's patch size; it is the same for every
            # window of the mini-batch, so it is computed once here rather than per window
            resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)
            importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)

            # store the result in the proper location of the full output. Apply weights from importance map.
            for idx, original_idx in zip(slice_range, unravel_slice):
                # zoom roi
                original_idx_zoom = list(original_idx)  # 4D for 2D image, 5D for 3D image
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
                        warnings.warn(
                            f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
                            f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
                            f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
                            f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
                            f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
                            "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
                        )
                    original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
                roi_index = tuple(original_idx_zoom)
                # store results and weights
                output_image_list[ss][roi_index] += importance_map_zoom * seg_prob[idx - slice_g]
                count_map_list[ss][roi_index] += (
                    importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][roi_index].shape)
                )

    # account for any overlapping sections
    for ss in range(len(output_image_list)):
        output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)

    # remove padding if image_size smaller than roi_size
    for ss, output_i in enumerate(output_image_list):
        if torch.isnan(output_i).any() or torch.isinf(output_i).any():
            warnings.warn("Sliding window inference results contain NaN or Inf.")

        zoom_scale = [
            seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
        ]

        final_slicing: List[slice] = []
        for sp in range(num_spatial_dims):
            # pad_size is ordered from the last spatial dimension to the first
            slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
            slice_dim = slice(
                int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
                int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
            )
            final_slicing.insert(0, slice_dim)
        # keep the leading (batch, channel) dimensions whole
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[tuple(final_slicing)]

    if dict_key is not None:  # if output of predictor is a dict
        final_output = dict(zip(dict_key, output_image_list))
    else:
        final_output = tuple(output_image_list)  # type: ignore
    final_output = final_output[0] if is_tensor_output else final_output  # type: ignore
    if isinstance(inputs, MetaTensor):
        final_output = convert_to_dst_type(final_output, inputs)[0]  # type: ignore
    return final_output
def __call__(
    self,
    img: NdarrayOrTensor,
    argmax: Optional[bool] = None,
    to_onehot: Optional[int] = None,
    threshold: Optional[float] = None,
    rounding: Optional[str] = None,
    n_classes: Optional[int] = None,  # deprecated
    num_classes: Optional[int] = None,  # deprecated
    logit_thresh: Optional[float] = None,  # deprecated
    threshold_values: Optional[bool] = None,  # deprecated
) -> NdarrayOrTensor:
    """
    Apply, in this order, the optional argmax, one-hot, threshold and rounding steps to `img`.

    Args:
        img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
            will automatically add it.
        argmax: whether to execute argmax function on input data before transform.
            The argmax step runs when either this argument or ``self.argmax`` is truthy.
        to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
            Defaults to ``self.to_onehot``.
        threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
            Defaults to ``self.threshold``.
        rounding: if not None, round the data according to the specified option,
            available options: ["torchrounding"]. Defaults to ``self.rounding``.
        n_classes: deprecated; not read by this method body.
        num_classes: deprecated; only used when `to_onehot` is passed as a bool.
        logit_thresh: deprecated; only used when `threshold` is passed as a bool.
        threshold_values: deprecated; not read by this method body.

    Returns:
        The converted data, cast to ``torch.float`` and converted back to the type of the input
        via ``convert_to_dst_type``.

    Raises:
        AssertionError: When the resolved `to_onehot` is neither None nor an integer.

    .. deprecated:: 0.6.0
        ``n_classes`` is deprecated, use ``to_onehot`` instead.

    .. deprecated:: 0.7.0
        ``num_classes`` is deprecated, use ``to_onehot`` instead.
        ``logit_thresh`` is deprecated, use ``threshold`` instead.
        ``threshold_values`` is deprecated, use ``threshold`` instead.

    """
    # legacy calling convention: boolean switches paired with the deprecated value arguments
    if isinstance(to_onehot, bool):
        warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
        to_onehot = num_classes if to_onehot else None
    if isinstance(threshold, bool):
        warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.")
        threshold = logit_thresh if threshold else None

    img = convert_to_tensor(img, track_meta=get_track_meta())
    data = convert_data_type(img, torch.Tensor)[0]

    if argmax or self.argmax:
        data = torch.argmax(data, dim=0, keepdim=True)

    # per-call values take precedence, instance settings are the fallback
    if to_onehot is None:
        to_onehot = self.to_onehot
    if to_onehot is not None:
        if not isinstance(to_onehot, int):
            raise AssertionError("the number of classes for One-Hot must be an integer.")
        data = one_hot(data, num_classes=to_onehot, dim=0)

    if threshold is None:
        threshold = self.threshold
    if threshold is not None:
        data = data >= threshold

    if rounding is None:
        rounding = self.rounding
    if rounding is not None:
        # validates the option name; "torchrounding" is the only accepted value
        look_up_option(rounding, ["torchrounding"])
        data = torch.round(data)

    result = convert_to_dst_type(data, img, dtype=torch.float)[0]
    return result
def compute_generalized_dice(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    include_background: bool = True,
    weight_type: Union[Weight, str] = Weight.SQUARE,
) -> torch.Tensor:
    """Computes the Generalized Dice Score and returns a tensor with its per image values.

    Args:
        y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format
            and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
            remaining are the spatial dimensions.
        y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
        include_background (bool, optional): whether to include the first channel in the score computation.
            When False, `y_pred` and `y` are passed through ``ignore_background`` first. Defaults to True.
        weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
            ground truth volume into a weight factor. Defaults to ``"square"``.

    Returns:
        torch.Tensor: one Generalized Dice Score per batch item, i.e., with the shape [batch_size]
        (the weighted numerator and denominator are summed over the class axis).

    Raises:
        ValueError: if `y_pred` has less than three dimensions, or `y_pred` and `y` don't have the
            same shape. ``is_binary_tensor`` performs further input checks that are not visible
            in this function.
    """
    # Ensure tensors are binarized
    is_binary_tensor(y_pred, "y_pred")
    is_binary_tensor(y, "y")

    # Ensure tensors have at least 3 dimensions and have the same shape
    dims = y_pred.dim()
    if dims < 3:
        raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.")
    if y.shape != y_pred.shape:
        raise ValueError(f"y_pred - {y_pred.shape} - and y - {y.shape} - should have the same shapes.")

    # Ignore background, if needed
    if not include_background:
        y_pred, y = ignore_background(y_pred=y_pred, y=y)

    # Reducing only spatial dimensions (not batch nor channels), compute the intersection and non-weighted denominator
    reduce_axis = list(range(2, y_pred.dim()))
    intersection = torch.sum(y * y_pred, dim=reduce_axis)
    y_o = torch.sum(y, dim=reduce_axis)
    y_pred_o = torch.sum(y_pred, dim=reduce_axis)
    denominator = y_o + y_pred_o

    # Set the class weights
    weight_type = look_up_option(weight_type, Weight)
    if weight_type == Weight.SIMPLE:
        w = torch.reciprocal(y_o.float())
    elif weight_type == Weight.SQUARE:
        w = torch.reciprocal(y_o.float() * y_o.float())
    else:
        w = torch.ones_like(y_o.float())

    # Replace infinite values for non-appearing classes by the maximum weight.
    # The infinities are zeroed first so that they do not take part in the max.
    for b in w:
        infs = torch.isinf(b)
        b[infs] = 0
        b[infs] = torch.max(b)

    # Compute the weighted numerator and denominator, summing along the class axis
    numer = 2.0 * (intersection * w).sum(dim=1)
    denom = (denominator * w).sum(dim=1)

    # Compute the score
    generalized_dice_score = numer / denom

    # Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.
    # Where denom == 0 but the prediction volume is not 0, score is 0
    y_pred_o = y_pred_o.sum(dim=-1)
    denom_zeros = denom == 0
    # build the fill values on the score's device and dtype so the masked assignment
    # also works for non-CPU and non-float32 inputs
    fill_kwargs = {"device": generalized_dice_score.device, "dtype": generalized_dice_score.dtype}
    generalized_dice_score[denom_zeros] = torch.where(
        (y_pred_o == 0)[denom_zeros],
        torch.tensor(1.0, **fill_kwargs),
        torch.tensor(0.0, **fill_kwargs),
    )

    return generalized_dice_score