def resample_and_clip(
    cls,
    data_array: NdarrayOrTensor,
    output_spatial_shape: Optional[Sequence[int]] = None,
    mode: str = InterpolateMode.BICUBIC,
):
    """
    Resample ``data_array`` to ``output_spatial_shape`` when a target shape is given.

    The input is assumed to be 'channel-last', i.e. ``(H, W)`` or ``(H, W, C)``.
    Unless ``mode`` is nearest-neighbour, the resampled values are clipped back to the
    ``[min, max]`` range of the input, removing interpolation overshoot.

    Args:
        data_array: input data array. This method assumes the 'channel-last' format.
        output_spatial_shape: output spatial shape, passed through ``ensure_tuple_rep(..., 2)``.
            If ``None``, the data is only converted to a numpy array.
        mode: interpolation mode, default is ``InterpolateMode.BICUBIC``.

    Returns:
        The (possibly resampled) data as a numpy array.
    """
    data: np.ndarray = convert_data_type(data_array, np.ndarray)[0]
    if output_spatial_shape is None:
        return data

    target_shape = ensure_tuple_rep(output_spatial_shape, 2)
    mode = look_up_option(mode, InterpolateMode)
    # ``align_corners`` is left unset for nearest/area and disabled for the interpolating modes
    corners = None if mode in (InterpolateMode.NEAREST, InterpolateMode.AREA) else False
    resizer = Resize(spatial_size=target_shape, mode=mode, align_corners=corners)
    lo, hi = np.min(data), np.max(data)

    if data.ndim == 3:
        # (H, W, C): Resize works channel-first, so move the channel axis to the front and back
        resized = convert_data_type(resizer(np.moveaxis(data, -1, 0)), np.ndarray)[0]  # type: ignore
        data = np.moveaxis(resized, 0, -1)
    else:
        # (H, W): add a temporary channel axis, then drop it again
        data = convert_data_type(resizer(data[None]), np.ndarray)[0][0]  # type: ignore

    if mode != InterpolateMode.NEAREST:
        data = np.clip(data, lo, hi)
    return data
def create_backend_obj(cls, data_array: NdarrayOrTensor, **kwargs) -> np.ndarray:
    """
    Return a backend-specific representation of ``data_array``.

    This default implementation only converts the input to a numpy array; subclasses are
    expected to override it. It is used by ``cls.write`` and the input ``data_array`` is
    assumed to be 'channel-last'. Extra ``kwargs`` are ignored here.
    """
    array, *_ = convert_data_type(data_array, np.ndarray)
    return array
def create_backend_obj(
    cls, data_array: NdarrayOrTensor, affine: Optional[NdarrayOrTensor] = None, dtype: DtypeLike = None, **kwargs
):
    """
    Build a ``nib.nifti1.Nifti1Image`` from a 'channel-last' ``data_array``.

    Args:
        data_array: input data array.
        affine: affine matrix of the data array; ``np.eye(4)`` is used when ``None``.
            The matrix is passed through ``to_affine_nd(r=3, ...)``.
        dtype: if given, the array is cast to this data type (``copy=False``).
        kwargs: keyword arguments. ``header``, ``extra`` and ``file_map`` are popped from this
            dictionary and forwarded to ``nib.nifti1.Nifti1Image``.

    See also:
        - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.Nifti1Image
    """
    array = super().create_backend_obj(data_array)
    if dtype is not None:
        array = array.astype(dtype, copy=False)

    affine_np = convert_data_type(affine, np.ndarray)[0]
    affine_np = to_affine_nd(r=3, affine=np.eye(4) if affine_np is None else affine_np)

    header = kwargs.pop("header", None)
    extra = kwargs.pop("extra", None)
    file_map = kwargs.pop("file_map", None)
    return nib.nifti1.Nifti1Image(array, affine_np, header=header, extra=extra, file_map=file_map)
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    Extract a contour map of ``img`` by filtering it with a Laplacian-style kernel.

    Args:
        img: data to extract the contour from, with shape ``[channels, height, width[, depth]]``.

    Raises:
        ValueError: When ``img`` ndim is not one of [3, 4].

    Returns:
        Data with the same shape and type as ``img``. Note:
            1. the filter response is clamped to ``[0, 1]``, i.e. it marks whether a pixel is an edge.
            2. padding is used so the original shape of the mask image is kept.
            3. the edge detection is approximate because of the Laplace kernel: edges have a
               thickness rather than being ideally thin.
    """
    img_t: torch.Tensor = convert_data_type(img, torch.Tensor)[0]
    n_spatial = img_t.ndim - 1
    if n_spatial == 2:
        kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32)
    elif n_spatial == 3:
        kernel = torch.full((3, 3, 3), -1.0, dtype=torch.float32)
        kernel[1, 1, 1] = 26.0
    else:
        raise ValueError(f"{self.__class__} can only handle 2D or 3D images.")

    edges = apply_filter(img_t.unsqueeze(0), kernel)  # unsqueeze adds a batch dim
    edges.clamp_(min=0.0, max=1.0)
    output, *_ = convert_to_dst_type(edges.squeeze(0), img)
    return output
def _image3_animated_gif(
    tag: str, image: Union[np.ndarray, torch.Tensor], writer, frame_dim: int = 0, scale_factor: float = 1.0
):
    """Encode a 3D image as an animated GIF and wrap it in a summary.

    Args:
        tag: Data identifier
        image: 3D image tensors expected to be in `HWD` format
        writer: the tensorboard writer to plot image; the tensorboardX ``Summary`` type is used
            when tensorboardX is available and ``writer`` is a tensorboardX ``SummaryWriter``.
        frame_dim: the dimension used as frames for GIF image, expect data shape as `HWD`, default to `0`.
        scale_factor: amount to multiply values by. if the image data is between 0 and 1,
            using 255 for this value will scale it to displayable range.
            Values are cast to ``uint8`` after scaling.

    Returns:
        A summary holding one image value whose encoded bytes are the GIF stream.
    """
    if len(image.shape) != 3:
        raise AssertionError("3D image tensors expected to be in `HWD` format, len(image.shape) != 3")

    image_np, *_ = convert_data_type(image, output_type=np.ndarray)
    frames = [
        GifImage.fromarray((frame * scale_factor).astype(np.uint8, copy=False))
        for frame in np.moveaxis(image_np, frame_dim, 0)
    ]

    chunks = list(PIL.GifImagePlugin.getheader(frames[0])[0])
    # "NETSCAPE2.0" application extension block (loop count 0)
    chunks.append(b"\x21\xFF\x0B\x4E\x45\x54\x53\x43\x41\x50" b"\x45\x32\x2E\x30\x03\x01\x00\x00\x00")
    for frame in frames:
        chunks.extend(PIL.GifImagePlugin.getdata(frame))
    chunks.append(b"\x3B")  # GIF trailer byte
    img_str = b"".join(chunks)

    summary = SummaryX if has_tensorboardx and isinstance(writer, SummaryWriterX) else Summary
    summary_image_str = summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str)
    image_summary = summary.Value(tag=tag, image=summary_image_str)
    return summary(value=[image_summary])
def extend(self, *data) -> None:
    """
    Add a "batch" of samples to the local buffers, one buffer per ``data`` item.

    Unlike ``self.append``, each item is split along dim 0 and every slice is stored as a
    separate sample. Buffers are allocated on the first call.

    Args:
        data: each item can be a "batch-first" tensor or a list of "channel-first" tensors.
            they will be concatenated at the 0-th dimension when `get_buffer()` is called.

    Raises:
        TypeError: when an item cannot be split along dim 0.
    """
    if self._buffers is None:
        self._buffers = [[] for _ in data]
    for buffer, item in zip(self._buffers, data):
        # converting to pytorch tensors so that the distributed API can be used
        item_t, *_ = convert_data_type(item, output_type=torch.Tensor, wrap_sequence=True)
        try:
            # build the full list first so a failure leaves the buffer untouched
            samples = [chunk[0] for chunk in torch.split(item_t, 1, dim=0)]
            buffer.extend(samples)
        except (AttributeError, IndexError, RuntimeError) as e:
            raise TypeError(
                f"{e}. `data` should be a batch-first tensor or"
                f" a list of channel-first tensors, got {type(item_t)}"
            ) from e
    self._synced = False
def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0):
    """
    Compute a target spacing as the per-axis median of all collected spacings.

    If the median spacing is very anisotropic (``max / min >= anisotropic_threshold``), the value
    of the largest axis is replaced by the ``percentile`` of that axis' spacings.
    So far, this only supports NIFTI images which store spacings in headers with key "pixdim".
    After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size
    `(batch_size, 8)`; columns 1-3 are used as the spacings.

    Args:
        spacing_key: key of spacing in meta data (default: ``pixdim``).
        anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``).
        percentile: for anisotropic target spacing, use the percentile of all spacings of the anisotropic axis to
            replace that axis.

    Returns:
        A tuple with one target spacing value per spatial axis.

    Raises:
        ValueError: if ``spacing_key`` is not in the collected meta data.
    """
    if len(self.all_meta_data) == 0:
        self.collect_meta_data()
    if spacing_key not in self.all_meta_data[0]:
        raise ValueError("The provided spacing_key is not in self.all_meta_data.")

    per_batch = [meta[spacing_key][:, 1:4] for meta in self.all_meta_data]
    spacings, *_ = convert_data_type(
        data=concatenate(to_cat=per_batch, axis=0), output_type=np.ndarray, wrap_sequence=True
    )
    target = np.median(spacings, axis=0)
    if max(target) / min(target) >= anisotropic_threshold:
        axis = np.argmax(target)
        target[axis] = np.percentile(spacings[:, axis], percentile)
    return tuple(target)
def _create_itk_obj(array, affine):
    """
    Create an ITK image from ``array`` and ``affine`` via ``ITKWriter.create_backend_obj``.

    ``array`` is deep-copied and converted to a numpy array first (presumably so the ITK object
    never aliases the caller's data). No channel dimension is used (``channel_dim=None``) and the
    affine is converted from "LPS" to "RAS" (``affine_lps_to_ras=True``).
    """
    itk_img = convert_data_type(deepcopy(array), np.ndarray)[0]
    # ``ITKWriter.create_backend_obj`` takes ``affine_lps_to_ras``; the former ``affine_lps=True``
    # landed in ``**kwargs`` and was silently ignored.
    return ITKWriter.create_backend_obj(itk_img, channel_dim=None, affine=affine, affine_lps_to_ras=True)
def __call__(
    self, img: NdarrayOrTensor, meta_data: Optional[Dict] = None, mask: Optional[np.ndarray] = None
) -> Tuple[NdarrayOrTensor, Dict]:
    """
    Compute statistics for the intensity of input image.

    Args:
        img: input image to compute intensity stats.
        meta_data: meta data dictionary to store the statistics data, if None, will create an empty dictionary.
        mask: if not None, mask the image to extract only the interested area to compute statistics.
            mask must have the same shape as input `img`.

    Returns:
        ``(img, meta_data)``: the input image unchanged and the dictionary updated in place with one
        entry per op, keyed ``self.key_prefix + "_" + op_name`` for predefined ops and
        ``self.key_prefix + "_custom_" + str(i)`` for callables. With ``self.channel_wise`` each
        entry is a list with one value per item along the first axis of ``img`` (the channels).

    Raises:
        TypeError: when ``mask`` is not a bool array with the same shape as ``img``.
        ValueError: when an op is neither a supported key string nor a callable.
    """
    img_np: np.ndarray
    img_np, *_ = convert_data_type(img, np.ndarray)  # type: ignore
    if meta_data is None:
        meta_data = {}

    data = img_np
    if mask is not None:
        if mask.shape != img_np.shape or mask.dtype != bool:
            raise TypeError("mask must be bool array with the same shape as input `img`.")
        if self.channel_wise:
            # mask every channel separately: boolean indexing of the whole array flattens all
            # channels together, and the channel-wise loop would then iterate single voxels.
            data = [channel[channel_mask] for channel, channel_mask in zip(img_np, mask)]
        else:
            data = img_np[mask]

    supported_ops = {
        "mean": np.nanmean,
        "median": np.nanmedian,
        "max": np.nanmax,
        "min": np.nanmin,
        "std": np.nanstd,
    }

    def _compute(op: Callable, values):
        if self.channel_wise:
            return [op(c) for c in values]
        return op(values)

    custom_index = 0
    for o in self.ops:
        if isinstance(o, str):
            o = look_up_option(o, supported_ops.keys())
            meta_data[self.key_prefix + "_" + o] = _compute(supported_ops[o], data)  # type: ignore
        elif callable(o):
            meta_data[self.key_prefix + "_custom_" + str(custom_index)] = _compute(o, data)
            custom_index += 1
        else:
            raise ValueError("ops must be key string for predefined operations or callable function.")

    return img, meta_data
def calculate_percentiles(
    self,
    foreground_threshold: int = 0,
    sampling_flag: bool = True,
    interval: int = 10,
    min_percentile: float = 0.5,
    max_percentile: float = 99.5,
):
    """
    Compute intensity percentiles and the median of the foreground voxels of the dataset.

    All selected voxel values are accumulated in memory; to reduce memory usage only every
    ``interval``-th foreground voxel of each batch is kept when ``sampling_flag`` is set.
    Results are stored in ``self.data_min_percentile``, ``self.data_max_percentile`` and
    ``self.data_median``.

    Args:
        foreground_threshold: the threshold to distinguish if a voxel belongs to foreground, this parameter
            is used to select the foreground of images for calculation. Normally, `label > 0` means the corresponding
            voxel belongs to foreground, thus if you need to calculate the statistics for whole images, you can set
            the threshold to ``-1`` (default: ``0``).
        sampling_flag: whether to sample only a part of the voxels (default: ``True``).
        interval: the sampling interval for accumulating voxels (default: ``10``).
        min_percentile: minimal percentile (default: ``0.5``).
        max_percentile: maximal percentile (default: ``99.5``).
    """
    collected = []
    for batch in self.data_loader:
        if self.image_key and self.label_key:
            image, label = batch[self.image_key], batch[self.label_key]
        else:
            image, label = batch
        image, *_ = convert_data_type(data=image, output_type=torch.Tensor)
        label, *_ = convert_data_type(data=label, output_type=torch.Tensor)
        foreground = image[torch.where(label > foreground_threshold)].tolist()
        collected.extend(foreground[::interval] if sampling_flag else foreground)

    self.data_min_percentile, self.data_max_percentile = np.percentile(collected, [min_percentile, max_percentile])
    self.data_median = np.median(collected)
def __call__(self, img: NdarrayOrTensor):
    """
    Run ``self.trans`` on ``img`` as a PyTorch tensor and return the result in ``img``'s type.

    Args:
        img: PyTorch Tensor data for the TorchVision transform.
    """
    as_tensor = convert_data_type(img, torch.Tensor)[0]  # type: ignore
    transformed = self.trans(as_tensor)
    return convert_to_dst_type(src=transformed, dst=img)[0]
def calculate_statistics(self, foreground_threshold: int = 0):
    """
    This function is used to calculate the maximum, minimum, mean and standard deviation of intensities of
    the input dataset.

    Results are stored in ``self.data_max``, ``self.data_min``, ``self.data_mean`` and ``self.data_std``
    (python floats). Batches without any foreground voxel are skipped.

    Args:
        foreground_threshold: the threshold to distinguish if a voxel belongs to foreground, this parameter
            is used to select the foreground of images for calculation. Normally, `label > 0` means the corresponding
            voxel belongs to foreground, thus if you need to calculate the statistics for whole images, you can set
            the threshold to ``-1`` (default: ``0``).

    Raises:
        ValueError: when no voxel of the whole dataset is above ``foreground_threshold``.
    """
    voxel_sum = 0.0
    voxel_square_sum = 0.0
    voxel_max, voxel_min = [], []
    voxel_ct = 0
    for data in self.data_loader:
        if self.image_key and self.label_key:
            image, label = data[self.image_key], data[self.label_key]
        else:
            image, label = data
        image, *_ = convert_data_type(data=image, output_type=torch.Tensor)
        label, *_ = convert_data_type(data=label, output_type=torch.Tensor)

        image_foreground = image[torch.where(label > foreground_threshold)]
        if image_foreground.numel() == 0:
            # max()/min() of an empty tensor raise; this batch contributes nothing anyway
            continue
        voxel_max.append(image_foreground.max().item())
        voxel_min.append(image_foreground.min().item())
        voxel_ct += image_foreground.numel()
        # accumulate in float64 to limit round-off over many voxels
        foreground64 = image_foreground.double()
        voxel_sum += foreground64.sum().item()
        voxel_square_sum += torch.square(foreground64).sum().item()

    if voxel_ct == 0:
        raise ValueError("No foreground voxels found (label > foreground_threshold) in the dataset.")

    self.data_max, self.data_min = max(voxel_max), min(voxel_min)
    self.data_mean = voxel_sum / voxel_ct
    # E[x^2] - E[x]^2 can be slightly negative from round-off; clamp to avoid a NaN std
    variance = max(voxel_square_sum / voxel_ct - self.data_mean**2, 0.0)
    self.data_std = variance**0.5
def get_target_spacing(self, spacing_key: str = "affine", anisotropic_threshold: int = 3, percentile: float = 10.0):
    """
    Compute a target spacing as the per-axis median of all collected spacings.

    If the median spacing is very anisotropic (``max / min >= anisotropic_threshold``), the value of
    the largest axis is replaced by the ``percentile`` of that axis' spacings.
    For each meta data item, the spacing is computed with `affine_to_spacing(data[spacing_key][0], 3)`
    when `data[spacing_key][0]` is a matrix; when it is a vector, its first three values are used.

    Args:
        spacing_key: key of the affine used to compute spacing in metadata (default: ``affine``).
        anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``).
        percentile: for anisotropic target spacing, use the percentile of all spacings of the anisotropic axis to
            replace that axis.

    Returns:
        A tuple with one target spacing value per spatial axis.

    Raises:
        ValueError: if ``spacing_key`` is missing, or its value is neither a vector nor a matrix.
    """
    if len(self.all_meta_data) == 0:
        self.collect_meta_data()
    if spacing_key not in self.all_meta_data[0]:
        raise ValueError("The provided spacing_key is not in self.all_meta_data.")

    def _spacing_row(meta):
        vals = convert_to_tensor(meta[spacing_key][0], track_meta=False, wrap_sequence=True)
        if vals.ndim == 1:  # vector of spacing values
            return vals[:3][None]
        if vals.ndim == 2:  # affine matrix
            return affine_to_spacing(vals, 3)[None]
        raise ValueError("data[spacing_key] must be a vector or a matrix.")

    rows = [_spacing_row(meta) for meta in self.all_meta_data]
    spacings, *_ = convert_data_type(
        data=concatenate(to_cat=rows, axis=0), output_type=np.ndarray, wrap_sequence=True
    )
    target = np.median(spacings, axis=0)
    if max(target) / min(target) >= anisotropic_threshold:
        axis = np.argmax(target)
        target[axis] = np.percentile(spacings[:, axis], percentile)
    return tuple(target)
def create_backend_obj(
    cls,
    data_array: NdarrayOrTensor,
    channel_dim: Optional[int] = 0,
    affine: Optional[NdarrayOrTensor] = None,
    dtype: DtypeLike = np.float32,
    affine_lps_to_ras: bool = True,
    **kwargs,
):
    """
    Create an ITK object from ``data_array``.

    This method assumes a 'channel-last' ``data_array``.

    Args:
        data_array: input data array.
        channel_dim: channel dimension of the data array. This is used to create a Vector Image if it is not ``None``.
        affine: affine matrix of the data array. This is used to compute `spacing`, `direction` and `origin`.
            An identity matrix is used when ``None``.
        dtype: output data type.
        affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``.
            Set to ``True`` to be consistent with ``NibabelWriter``,
            otherwise the affine matrix is assumed already in the ITK convention.
        kwargs: keyword arguments. Current `itk.GetImageFromArray` will read ``ttype`` from this dictionary.

    see also:

        - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L389
    """
    array = super().create_backend_obj(data_array)
    is_vector = channel_dim is not None
    if is_vector:
        array = np.moveaxis(array, -1, 0)  # channel-last -> channel-first before the transpose
    # reverse the axis order (and make a C-contiguous copy) before handing the array to ITK
    array = array.T.astype(dtype, copy=True, order="C")
    itk_obj = itk.GetImageFromArray(array, is_vector=is_vector, ttype=kwargs.pop("ttype", None))
    ndim = len(itk.size(itk_obj))

    if affine is None:
        affine = np.eye(ndim + 1, dtype=np.float64)
    affine_np = convert_data_type(affine, np.ndarray)[0]
    if affine_lps_to_ras:
        affine_np = orientation_ras_lps(to_affine_nd(ndim, affine_np))
    spacing = affine_to_spacing(affine_np, r=ndim)
    # direction: linear part of the affine with each column divided by its spacing
    direction: np.ndarray = affine_np[:ndim, :ndim] @ np.diag(1 / spacing)

    itk_obj.SetSpacing(spacing.tolist())
    itk_obj.SetOrigin(affine_np[:ndim, -1].tolist())
    itk_obj.SetDirection(itk.GetMatrixFromArray(direction))
    return itk_obj
def __call__(self, data: NdarrayOrTensor):
    """
    Convert ``data`` to a torch tensor (``self.data_type == "tensor"``) or otherwise a numpy array.

    Args:
        data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
            will ensure Tensor, Numpy array, float, int, bool as Tensors or numpy arrays, strings and
            objects keep the original. for dictionary, list or tuple, ensure every item as expected type
            if applicable. ``self.dtype`` and ``self.device`` are forwarded to the conversion.
    """
    target_type = np.ndarray if self.data_type != "tensor" else torch.Tensor
    converted = convert_data_type(data, output_type=target_type, dtype=self.dtype, device=self.device)[0]
    return converted
def __call__(self, img: NdarrayOrTensor):
    """
    Replace every value of ``self.orig_labels`` in ``img`` with the matching ``self.target_labels`` value.

    Values not listed in ``self.orig_labels`` are kept (cast to ``self.dtype``). If that cast raises
    ``ValueError``, unmapped values become 0 instead. The result has the type of ``img`` and dtype
    ``self.dtype``.
    """
    img_np, *_ = convert_data_type(img, np.ndarray)
    flat_in = img_np.flatten()
    try:
        flat_out = flat_in.astype(self.dtype)  # astype returns a new array
    except ValueError:
        # can't copy unchanged labels as the expected dtype is not supported, must map all the label values
        flat_out = np.zeros(shape=flat_in.shape, dtype=self.dtype)

    for orig, target in zip(self.orig_labels, self.target_labels):
        if orig == target:
            continue
        # match against the input so earlier replacements are never re-mapped
        np.place(flat_out, flat_in == orig, target)

    mapped = flat_out.reshape(img_np.shape)
    return convert_to_dst_type(src=mapped, dst=img, dtype=self.dtype)[0]
def __call__(
    self,
    img: NdarrayOrTensor,
    argmax: Optional[bool] = None,
    to_onehot: Optional[int] = None,
    threshold: Optional[float] = None,
    rounding: Optional[str] = None,
) -> NdarrayOrTensor:
    """
    Args:
        img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
            will automatically add it.
        argmax: whether to execute argmax function on input data before transform.
            Defaults to ``self.argmax``; an explicit ``False`` disables it for this call.
        to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
            Defaults to ``self.to_onehot``.
        threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
            Defaults to ``self.threshold``.
        rounding: if not None, round the data according to the specified option,
            available options: ["torchrounding"]. Defaults to ``self.rounding``.

    Raises:
        ValueError: when the resolved ``to_onehot`` is not an integer.

    Returns:
        Data of the same type as ``img``, with dtype ``self.kwargs.get("dtype", torch.float)``.
    """
    img_t: torch.Tensor
    img_t, *_ = convert_data_type(img, torch.Tensor)  # type: ignore

    # per-call arguments override the instance defaults only when explicitly given
    argmax = self.argmax if argmax is None else argmax
    if argmax:
        img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True))

    to_onehot = self.to_onehot if to_onehot is None else to_onehot
    if to_onehot is not None:
        if not isinstance(to_onehot, int):
            raise ValueError("the number of classes for One-Hot must be an integer.")
        img_t = one_hot(
            img_t, num_classes=to_onehot, dim=self.kwargs.get("dim", 0), dtype=self.kwargs.get("dtype", torch.float)
        )

    threshold = self.threshold if threshold is None else threshold
    if threshold is not None:
        img_t = img_t >= threshold

    rounding = self.rounding if rounding is None else rounding
    if rounding is not None:
        look_up_option(rounding, ["torchrounding"])
        img_t = torch.round(img_t)

    img, *_ = convert_to_dst_type(img_t, img, dtype=self.kwargs.get("dtype", torch.float))
    return img
def append(self, *data) -> None:
    """
    Add one sample to each local cumulative buffer (one buffer per ``data`` item).

    Unlike ``self.extend``, every item is stored as a single sample rather than a "batch".
    Buffers are allocated on the first call.

    Args:
        data: each item will be converted into a torch tensor.
            they will be stacked at the 0-th dim with a new dimension when `get_buffer()` is called.
    """
    if self._buffers is None:
        self._buffers = [[] for _ in data]
    for buffer, item in zip(self._buffers, data):
        # torch tensors are needed so that the distributed API can be used
        as_tensor = convert_data_type(item, output_type=torch.Tensor, wrap_sequence=True)[0]
        buffer.append(as_tensor)
    self._synced = False
def aggregate(self):
    """
    Sync data from all the ranks and compute the average value with previous sum value.

    Buffered values are summed over dim 0 with NaNs counted as 0, and the number of non-NaN
    entries is counted. Both are added to the running ``self.sum`` / ``self.not_nans``. The local
    buffer is then cleared, and ``self.sum / self.not_nans`` is returned.
    """
    data = self.get_buffer()
    nan_mask = isnan(data)
    valid_count = convert_data_type((~nan_mask), dtype=torch.float32)[0].sum(0)
    data[nan_mask] = 0
    batch_sum = data.sum(0)

    super().reset()  # clear the buffer for the next update

    self.sum = batch_sum if self.sum is None else (self.sum + batch_sum)
    self.not_nans = valid_count if self.not_nans is None else (self.not_nans + valid_count)
    return self.sum / self.not_nans
def __call__(
    self, img: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None
) -> NdarrayOrTensor:
    """
    Apply the transform to `img`, assuming `img` is a numpy array or PyTorch Tensor.

    The container type of ``img`` is kept; only the element data type changes.

    Args:
        img: numpy array or PyTorch Tensor to cast.
        dtype: convert image to this data type, default is `self.dtype`.

    Raises:
        TypeError: When ``img`` type is not in ``Union[numpy.ndarray, torch.Tensor]``.
    """
    target_dtype = dtype or self.dtype
    return convert_data_type(img, output_type=type(img), dtype=target_dtype)[0]
def __call__(
    self,
    img: NdarrayOrTensor,
    sigmoid: Optional[bool] = None,
    softmax: Optional[bool] = None,
    other: Optional[Callable] = None,
) -> NdarrayOrTensor:
    """
    Args:
        img: input data; it is converted to a float tensor before any activation is applied.
        sigmoid: whether to execute sigmoid function on model output before transform.
            Defaults to ``self.sigmoid``; an explicit ``False`` disables it for this call.
        softmax: whether to execute softmax function (over dim 0) on model output before transform.
            Defaults to ``self.softmax``; an explicit ``False`` disables it for this call.
        other: callable function to execute other activation layers, for example:
            `other = torch.tanh`. Defaults to ``self.other``.

    Raises:
        ValueError: When ``sigmoid=True`` and ``softmax=True`` are both passed to this call. Incompatible values.
        TypeError: When ``other`` is not an ``Optional[Callable]``.

    Returns:
        The activated data, converted back to the type of ``img``.
    """
    if sigmoid and softmax:
        raise ValueError("Incompatible values: sigmoid=True and softmax=True.")
    if other is not None and not callable(other):
        raise TypeError(f"other must be None or callable but is {type(other).__name__}.")

    # per-call arguments override the instance defaults only when explicitly given
    sigmoid = self.sigmoid if sigmoid is None else sigmoid
    softmax = self.softmax if softmax is None else softmax
    act_func = self.other if other is None else other

    # convert to float as activation must operate on float tensor
    img = convert_to_tensor(img, track_meta=get_track_meta())
    img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)
    if sigmoid:
        img_t = torch.sigmoid(img_t)
    if softmax:
        img_t = torch.softmax(img_t, dim=0)
    if act_func is not None:
        img_t = act_func(img_t)
    out, *_ = convert_to_dst_type(img_t, img)
    return out
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    Fill the holes in the provided image.

    Note:
        The value 0 is assumed as background label.

    Args:
        img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
            It is first converted with ``convert_to_tensor(img, track_meta=get_track_meta())``.

    Returns:
        A tensor of shape [C, spatial_dim1[, spatial_dim2, ...]] with the type of the converted
        input, i.e. a tensor is returned even for numpy input.
    """
    img = convert_to_tensor(img, track_meta=get_track_meta())
    img_np, *_ = convert_data_type(img, np.ndarray)
    out_np: np.ndarray = fill_holes(img_np, self.applied_labels, self.connectivity)
    # ``img`` is the converted tensor here, so the output takes its (tensor) type
    out, *_ = convert_to_dst_type(out_np, img)
    return out
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    Fill the holes in the provided image.

    Note:
        The value 0 is assumed as background label.

    Args:
        img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].

    Raises:
        NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.

    Returns:
        Data of the same type as ``img`` with shape [C, spatial_dim1[, spatial_dim2, ...]].
    """
    if isinstance(img, (np.ndarray, torch.Tensor)):
        img_np: np.ndarray
        img_np, *_ = convert_data_type(img, np.ndarray)  # type: ignore
        filled: np.ndarray = fill_holes(img_np, self.applied_labels, self.connectivity)
        return convert_to_dst_type(filled, img)[0]
    raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.")
def compute_surface_dice(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    class_thresholds: List[float],
    include_background: bool = False,
    distance_metric: str = "euclidean",
):
    r"""
    This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as
    :math:`\hat{Y}`) and `y` (referred to as :math:`Y`). This metric determines which fraction of a segmentation
    boundary is correctly predicted. A boundary element is considered correctly predicted if the closest distance to
    the reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of
    deviation in pixels. The NSD is bounded between 0 and 1.

    This implementation supports multi-class tasks with an individual threshold :math:`\tau_c` for each class :math:`c`.
    The class-specific NSD for batch index :math:`b`, :math:`\operatorname {NSD}_{b,c}`, is computed using the function:

    .. math::
        \operatorname {NSD}_{b,c} \left(Y_{b,c}, \hat{Y}_{b,c}\right) = \frac{\left|\mathcal{D}_{Y_{b,c}}^{'}\right| +
        \left| \mathcal{D}_{\hat{Y}_{b,c}}^{'} \right|}{\left|\mathcal{D}_{Y_{b,c}}\right| +
        \left|\mathcal{D}_{\hat{Y}_{b,c}}\right|}
        :label: nsd

    with :math:`\mathcal{D}_{Y_{b,c}}` and :math:`\mathcal{D}_{\hat{Y}_{b,c}}` being two sets of nearest-neighbor
    distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference
    segmentation boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and
    :math:`\mathcal{D}_{\hat{Y}_{b,c}}^{'}` refer to the subsets of distances that are smaller or equal to the
    acceptable distance :math:`\tau_c`:

    .. math::
        \mathcal{D}_{Y_{b,c}}^{'} = \{ d \in \mathcal{D}_{Y_{b,c}} \, | \, d \leq \tau_c \}.

    In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation, a
    nan value will be returned for this class. In the case of a class being present in only one of predicted
    segmentation or reference segmentation, the class NSD will be 0.

    This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D images.
    Be aware that the computation of boundaries is different from DeepMind's implementation
    https://github.com/deepmind/surface-distance. In this implementation, the length of a segmentation boundary is
    interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
    depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).

    Args:
        y_pred: Predicted segmentation, typically segmentation model output.
            It must be a one-hot encoded, batch-first tensor [B,C,H,W].
        y: Reference segmentation.
            It must be a one-hot encoded, batch-first tensor [B,C,H,W].
        class_thresholds: List of class-specific thresholds.
            The thresholds relate to the acceptable amount of deviation in the segmentation boundary in pixels.
            Each threshold needs to be a finite, non-negative number. One threshold is needed per channel that is
            evaluated (i.e. per channel remaining after the background is removed when
            ``include_background=False``).
        include_background: Whether to include the surface dice computation on the first channel of
            the predicted output. If ``False``, the first channel of both inputs is skipped.
            Defaults to ``False``.
        distance_metric: The metric used to compute surface distances.
            One of [``"euclidean"``, ``"chessboard"``, ``"taxicab"``].
            Defaults to ``"euclidean"``.

    Raises:
        ValueError: If `y_pred` and/or `y` are not PyTorch tensors.
        ValueError: If `y_pred` and/or `y` do not have four dimensions.
        ValueError: If `y_pred` and/or `y` have different shapes.
        ValueError: If `y_pred` and/or `y` are not one-hot encoded
        ValueError: If the number of channels of `y_pred` and/or `y` is different from the number of class thresholds.
        ValueError: If any class threshold is not finite.
        ValueError: If any class threshold is negative.

    Returns:
        Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch index
        :math:`b` and class :math:`c`.
    """
    # validate the raw inputs before ``ignore_background`` slices them, so malformed inputs get
    # the documented ValueError instead of an indexing error
    if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
        raise ValueError("y_pred and y must be PyTorch Tensor.")
    if y_pred.ndimension() != 4 or y.ndimension() != 4:
        raise ValueError("y_pred and y should have four dimensions: [B,C,H,W].")
    if y_pred.shape != y.shape:
        raise ValueError(
            f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)."
        )

    if not include_background:
        y_pred, y = ignore_background(y_pred=y_pred, y=y)

    if not torch.all(y_pred.byte() == y_pred) or not torch.all(y.byte() == y):
        raise ValueError("y_pred and y should be binarized tensors (e.g. torch.int64).")
    if torch.any(y_pred > 1) or torch.any(y > 1):
        raise ValueError("y_pred and y should be one-hot encoded.")

    y = y.float()
    y_pred = y_pred.float()

    batch_size, n_class = y_pred.shape[:2]

    if n_class != len(class_thresholds):
        raise ValueError(
            f"number of classes ({n_class}) does not match number of class thresholds ({len(class_thresholds)})."
        )

    if any(~np.isfinite(class_thresholds)):
        raise ValueError("All class thresholds need to be finite.")

    if any(np.array(class_thresholds) < 0):
        raise ValueError("All class thresholds need to be >= 0.")

    nsd = np.empty((batch_size, n_class))

    for b, c in np.ndindex(batch_size, n_class):
        (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=False)
        if not np.any(edges_gt):
            warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
        if not np.any(edges_pred):
            warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")

        distances_pred_gt = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric)
        distances_gt_pred = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric)

        boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)
        boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum(
            distances_gt_pred <= class_thresholds[c]
        )

        if boundary_complete == 0:
            # the class is neither present in the prediction, nor in the reference segmentation
            nsd[b, c] = np.nan
        else:
            nsd[b, c] = boundary_correct / boundary_complete

    return convert_data_type(nsd, torch.Tensor)[0]
def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    Split a channel-first ``(C, H, W)`` ``image`` into ``tile_size`` x ``tile_size`` tiles on a grid.

    Steps:
        1. optionally crop a random offset (``self.random_offset``, offsets from ``self.randomize``);
        2. optionally pad to a multiple of ``tile_size`` with ``self.background_val`` (``self.pad_full``);
        3. extract tiles with stride ``self.step``;
        4. select tiles:

           - ``self.tile_count is None``: keep tiles whose intensity sum is below (``filter_mode="min"``)
             or at/above (``filter_mode="max"``) 0.999 of the sum of an all-background tile; any other
             mode keeps every tile. The number of returned tiles varies.
           - otherwise, with more than ``tile_count`` tiles: keep the ``tile_count`` tiles with the
             smallest (``"min"``) / largest (``"max"``) sums, or index with ``self.random_idxs`` when it
             is set. With fewer tiles: pad with tiles filled with ``self.background_val``.

    Returns:
        Tiles of shape ``(N, C, tile_size, tile_size)`` in the type and dtype of ``image``.
    """
    img_np, *_ = convert_data_type(image, np.ndarray)

    # add random offset
    self.randomize(img_size=img_np.shape)

    if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0):
        img_np = img_np[:, self.offset[0] :, self.offset[1] :]

    # pad to full size, divisible by tile_size
    if self.pad_full:
        _, h, w = img_np.shape
        pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
        pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
        img_np = np.pad(  # type: ignore
            img_np,
            [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]],
            constant_values=self.background_val,
        )

    # extract tiles through a strided (read-only) view of img_np
    x_step, y_step = self.step, self.step
    h_tile, w_tile = self.tile_size, self.tile_size
    c_image, h_image, w_image = img_np.shape
    c_stride, x_stride, y_stride = img_np.strides
    llw = as_strided(
        img_np,
        shape=((h_image - h_tile) // x_step + 1, (w_image - w_tile) // y_step + 1, c_image, h_tile, w_tile),
        strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),
        writeable=False,
    )
    img_np = llw.reshape(-1, c_image, h_tile, w_tile)  # type: ignore

    if self.tile_count is None:
        # retain only patches with significant foreground content to speed up inference
        # FYI, this returns a variable number of tiles, so the batch_size must be 1 (per gpu), e.g during inference
        # threshold = 0.999 * sum of an all-background tile; use the real channel count
        # (a hard-coded 3 was only correct for 3-channel input)
        thresh = 0.999 * c_image * self.background_val * self.tile_size * self.tile_size
        if self.filter_mode == "min":
            # default, keep non-background tiles (small values)
            idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) < thresh)
            img_np = img_np[idxs.reshape(-1)]
        elif self.filter_mode == "max":
            idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) >= thresh)
            img_np = img_np[idxs.reshape(-1)]
    elif len(img_np) > self.tile_count:
        if self.filter_mode == "min":
            # default, keep non-background tiles (smallest values)
            idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[: self.tile_count]
            img_np = img_np[idxs]
        elif self.filter_mode == "max":
            idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[-self.tile_count :]
            img_np = img_np[idxs]
        elif self.random_idxs is not None:
            # random subset (more appropriate for WSIs without distinct background)
            img_np = img_np[self.random_idxs]
    elif len(img_np) < self.tile_count:
        img_np = np.pad(  # type: ignore
            img_np,
            [[0, self.tile_count - len(img_np)], [0, 0], [0, 0], [0, 0]],
            constant_values=self.background_val,
        )

    image, *_ = convert_to_dst_type(src=img_np, dst=image, dtype=image.dtype)
    return image
def get_mask_edges(seg_pred, seg_gt, label_idx: int = 1, crop: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute the boundary voxels of a predicted and a ground-truth mask.

    Each mask's edge is its XOR with its own binary erosion (``scipy.ndimage.binary_erosion``).
    The result is typically fed to surface-distance metrics such as Average Surface Distance
    and Hausdorff Distance.

    Non-boolean inputs are treated as label fields and binarised with ``== label_idx``.
    Boolean inputs are used as they are. Both inputs must have the same shape and are assumed
    to share the same spatial frame (spacing, orientation, ...).

    Args:
        seg_pred: predicted binary mask or label field (numpy array or torch tensor).
        seg_gt: ground-truth binary mask or label field (numpy array or torch tensor).
        label_idx: label value that marks foreground in non-boolean inputs.
        crop: if ``True``, both masks are first cropped to the bounding box of their union
            (``seg_pred | seg_gt``) so they keep matching shapes. If the union is empty,
            all-``False`` arrays of the input shape are returned immediately.

    Returns:
        ``(edges_pred, edges_gt)`` as numpy boolean arrays.

    Raises:
        ValueError: if ``seg_pred`` and ``seg_gt`` have different shapes.
    """

    def _to_numpy(arr):
        # tensors are detached and moved to host memory; anything else is passed through
        return arr.detach().cpu().numpy() if isinstance(arr, torch.Tensor) else arr

    seg_pred = _to_numpy(seg_pred)
    seg_gt = _to_numpy(seg_gt)

    if seg_pred.shape != seg_gt.shape:
        raise ValueError(f"seg_pred and seg_gt should have same shapes, got {seg_pred.shape} and {seg_gt.shape}.")

    def _binarize(arr):
        # boolean masks are already binary; label fields are thresholded on label_idx
        return arr if arr.dtype == bool else arr == label_idx

    seg_pred = _binarize(seg_pred)
    seg_gt = _binarize(seg_gt)

    if crop:
        if not np.any(seg_pred | seg_gt):
            return np.zeros_like(seg_pred), np.zeros_like(seg_gt)

        # the bounding-box helper and cropper work on channel-first data, so add a leading axis
        channel_dim = 0
        pred_ch = np.expand_dims(seg_pred, axis=channel_dim)
        gt_ch = np.expand_dims(seg_gt, axis=channel_dim)
        box_start, box_end = generate_spatial_bounding_box(np.asarray(pred_ch | gt_ch))
        cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
        seg_pred, seg_gt = (
            convert_data_type(np.squeeze(cropper(mask), axis=channel_dim), np.ndarray)[0]
            for mask in (pred_ch, gt_ch)
        )

    edges_pred, edges_gt = (binary_erosion(mask) ^ mask for mask in (seg_pred, seg_gt))
    return edges_pred, edges_gt
def __call__(
    self,
    img: NdarrayOrTensor,
    argmax: Optional[bool] = None,
    to_onehot: Optional[int] = None,
    threshold: Optional[float] = None,
    rounding: Optional[str] = None,
    n_classes: Optional[int] = None,  # deprecated
    num_classes: Optional[int] = None,  # deprecated
    logit_thresh: Optional[float] = None,  # deprecated
    threshold_values: Optional[bool] = None,  # deprecated
) -> NdarrayOrTensor:
    """
    Discretise a channel-first input: optional argmax, one-hot, threshold and rounding,
    applied in that order. The result is returned as float data of the same type as ``img``.

    Args:
        img: the input tensor data to convert. If it has no channel dimension when converting to
            One-Hot, the channel dimension is added automatically.
        argmax: whether to execute argmax over dim 0 (keeping the dim) before the other steps.
            ``None`` falls back to ``self.argmax``. An explicit ``False`` disables argmax even if
            ``self.argmax`` is ``True``.
        to_onehot: if not None, convert input data into the one-hot format with the specified
            number of classes. Defaults to ``self.to_onehot``.
        threshold: if not None, threshold the values to 0 or 1 with ``img >= threshold``.
            Defaults to ``self.threshold``.
        rounding: if not None, round the data according to the specified option. Available
            options: ["torchrounding"]. Defaults to ``self.rounding``.

    Raises:
        AssertionError: if the resolved ``to_onehot`` is not an integer.

    .. deprecated:: 0.6.0
        ``n_classes`` is deprecated, use ``to_onehot`` instead.

    .. deprecated:: 0.7.0
        ``num_classes`` is deprecated, use ``to_onehot`` instead.
        ``logit_thresh`` is deprecated, use ``threshold`` instead.
        ``threshold_values`` is deprecated, use ``threshold`` instead.
    """
    # NOTE(review): `n_classes` and `threshold_values` are not read in this body; presumably they
    # are remapped onto `num_classes`/`threshold` by a deprecation decorator outside this view.
    if isinstance(to_onehot, bool):
        # legacy boolean form: the class count comes from the deprecated `num_classes`
        warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
        to_onehot = num_classes if to_onehot else None
    if isinstance(threshold, bool):
        # legacy boolean form: the threshold value comes from the deprecated `logit_thresh`
        warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.")
        threshold = logit_thresh if threshold else None

    img = convert_to_tensor(img, track_meta=get_track_meta())
    img_t, *_ = convert_data_type(img, torch.Tensor)

    # resolve every option the same way: a call-time value overrides the instance default
    argmax = self.argmax if argmax is None else argmax
    if argmax:
        img_t = torch.argmax(img_t, dim=0, keepdim=True)

    to_onehot = self.to_onehot if to_onehot is None else to_onehot
    if to_onehot is not None:
        if not isinstance(to_onehot, int):
            raise AssertionError("the number of classes for One-Hot must be an integer.")
        img_t = one_hot(img_t, num_classes=to_onehot, dim=0)

    threshold = self.threshold if threshold is None else threshold
    if threshold is not None:
        img_t = img_t >= threshold

    rounding = self.rounding if rounding is None else rounding
    if rounding is not None:
        look_up_option(rounding, ["torchrounding"])
        img_t = torch.round(img_t)

    img, *_ = convert_to_dst_type(img_t, img, dtype=torch.float)
    return img
def resample_if_needed(
    cls,
    data_array: NdarrayOrTensor,
    affine: Optional[NdarrayOrTensor] = None,
    target_affine: Optional[NdarrayOrTensor] = None,
    output_spatial_shape: Union[Sequence[int], int, None] = None,
    mode: str = GridSampleMode.BILINEAR,
    padding_mode: str = GridSamplePadMode.BORDER,
    align_corners: bool = False,
    dtype: DtypeLike = np.float64,
):
    """
    Map ``data_array`` from the coordinate system of ``affine`` into that of ``target_affine``.

    The work is delegated to ``SpatialResample``. If the transform between the two affines
    reduces to transposes and flips, no interpolation is needed. Otherwise the data is resampled.

    Data layout follows the NIfTI convention. Up to three spatial dimensions come first
    (H, HW or HWD), and any time/modality axes are appended after them. For example,
    2D eight-class probabilities would be saved as ``(64, 64, 1, 8)``, while ``(64, 64, 8)`` or
    ``(64, 64, 8, 1)`` are read as a single-channel 3D image. ``convert_to_channel_last`` can
    produce this layout.

    Resampled shapes are subject to rounding. A 20x20 image at (2.0, 2.0) mm resampled to
    (3.0, 3.0) mm yields 14x14 (rounded from 13.333). Pass ``output_spatial_shape`` to force a
    specific shape.

    Args:
        data_array: input data array to be converted.
        affine: current affine of ``data_array``. When ``None``, it is not overwritten, so any
            affine already attached to ``data_array`` is used.
            NOTE(review): otherwise presumably the MetaTensor default (identity); confirm.
        target_affine: desired affine. The actual output affine may differ slightly due to
            precision changes.
        output_spatial_shape: spatial shape of the output, used when resampling is needed.
        mode: interpolation mode, {``"bilinear"``, ``"nearest"``, ``"bicubic"``}.
            Defaults to ``"bilinear"``. See also:
            https://pytorch.org/docs/stable/nn.functional.html#grid-sample
        padding_mode: out-of-grid padding, {``"zeros"``, ``"border"``, ``"reflection"``}.
            Defaults to ``"border"``. See also:
            https://pytorch.org/docs/stable/nn.functional.html#grid-sample
        align_corners: ``grid_sample`` corner convention. See also:
            https://pytorch.org/docs/stable/nn.functional.html#grid-sample
        dtype: computation dtype passed to ``SpatialResample``. Defaults to ``np.float64`` for
            precision.

    Returns:
        ``(data, affine)``. Both are converted back to the container type of the input
        ``data_array``, and ``data`` has the leading axis added for resampling removed again.
        NOTE(review): the element dtype of ``data`` is whatever ``SpatialResample`` produces; confirm.
    """
    input_container = type(data_array)

    # wrap as a metadata-tracking tensor so an affine can be attached and read back
    img = convert_to_tensor(data_array, track_meta=True)
    if affine is not None:
        img.affine = convert_to_tensor(affine, track_meta=False)  # type: ignore

    spatial_resampler = SpatialResample(
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        dtype=dtype,
    )
    # a leading axis is added for the resampler and stripped again on return
    resampled = spatial_resampler(img[None], dst_affine=target_affine, spatial_size=output_spatial_shape)

    # clear the recorded operations before converting back
    if isinstance(resampled, MetaTensor):
        resampled.applied_operations = []

    out_data = convert_data_type(resampled, output_type=input_container)[0]  # type: ignore
    out_affine = convert_data_type(resampled.affine, output_type=input_container)[0]  # type: ignore
    return out_data[0], out_affine
def sliding_window_inference(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    progress: bool = False,
    roi_weight_map: Union[torch.Tensor, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
    """
    Sliding window inference on `inputs` with `predictor`.

    The outputs of `predictor` can be a tensor, a tuple, or a dictionary of tensors.
    Each output in the tuple or dict may have a different resolution from the input.
    For example, with an input patch size of [128,128,128], the output (a tuple of two patches)
    patch sizes could be ([128,64,256], [64,32,128]). In this case `overlap` and `roi_size`
    must be chosen so that the output ROI is still an integer. If the predictor's input and
    output spatial sizes differ, choose parameters so that
    `overlap*roi_size*output_size/input_size` is an integer for each spatial dimension.

    When roi_size is larger than the inputs' spatial size, the input image is padded during
    inference. The output is cropped back to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D]).
        roi_size: the spatial window size for inferences.
            Components that are None or non-positive are replaced by the corresponding input
            dimension, e.g. `roi_size=(32, -1)` becomes `(32, 64)` when the second spatial
            dimension of the input is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor ``patch_data`` in shape NCHW[D],
            ``predictor(patch_data)`` should return a tensor, a tuple, or a dictionary with
            Tensor values. Each output must have the same batch size, i.e. NM'H'W'[D'], where
            H'W'[D'] is the output patch's spatial size, M the number of output channels and
            N is `sw_batch_size`. E.g. for input shape (7, 1, 128,128,128) the output could be
            a tuple of two tensors with shapes ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            Then `overlap` and `roi_size` must be chosen so that the scaled output ROI sizes are
            integers; a good rule is that ``overlap*roi_size*zoom_scale`` is an integer for
            each dimension.
        overlap: amount of overlap between scans, in ``[0, 1)``.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is
            ``"gaussian"``. Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            A sequence gives a separate sigma_scale per spatial dimension.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs`` when ``roi_size`` is larger than inputs.
            Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should match the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Setting e.g. device=torch.device('cpu') keeps GPU memory consumption low and
            independent of the `inputs` and `roi_size`. The output is on `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI. It is used only
            when the valid patch size equals ``roi_size``; otherwise a map is computed from
            ``mode``/``sigma_scale``.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

    Raises:
        ValueError: if ``overlap`` is not in ``[0, 1)``.
        RuntimeError: if computing the importance map fails (most often out of memory).

    Note:
        - input must be channel-first and have a batch dim, supports N-D sliding window.

    """
    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    batch_size, _, *image_size_ = inputs.shape

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    # F.pad takes (before, after) pairs starting from the LAST dim, so walk the dims backwards
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)

    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)

    # Store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows

    # Create window-level importance map
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
    else:
        try:
            importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
        except Exception as e:  # not BaseException: KeyboardInterrupt/SystemExit must propagate unchanged
            raise RuntimeError(
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0]  # type: ignore
    # handle non-positive weights: clamp to the smallest non-zero weight (at least 1e-3) so no
    # voxel ends up with a zero total weight
    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
    importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)

    # Perform predictions
    dict_key, output_image_list, count_map_list = None, [], []
    _initialized_ss = -1
    is_tensor_output = True  # whether the predictor's output is a tensor (instead of dict/tuple)

    # for each batch of windows
    batch_starts = range(0, total_slices, sw_batch_size)
    for slice_g in tqdm(batch_starts) if progress else batch_starts:
        slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
        # one index tuple per window: (batch item, all channels, *spatial window slices)
        unravel_slice = [
            (slice(idx // num_win, idx // num_win + 1), slice(None)) + tuple(slices[idx % num_win])
            for idx in slice_range
        ]
        window_data = torch.cat(
            [convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice]
        ).to(sw_device)
        seg_prob_out = predictor(window_data, *args, **kwargs)  # batched patch segmentation

        # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out,)
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(seg_prob_out.keys())  # track predictor's output keys
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
        else:
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False

        # for each output in multi-output list
        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN

            # compute zoom scale: out_roi_size/in_roi_size
            zoom_scale = []
            for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
                zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
            ):
                _scale = out_w_i / float(in_w_i)
                if not (img_s_i * _scale).is_integer():
                    warnings.warn(
                        f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
                        f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
                    )
                zoom_scale.append(_scale)

            if _initialized_ss < ss:  # init. the ss-th buffer at the first iteration
                # construct multi-resolution outputs
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
                ]
                # allocate memory to store the full output and the count for overlapping parts
                output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
                count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
                _initialized_ss += 1

            # resize the importance map to this output's patch size; it is identical for every
            # window in the batch, so compute it once here rather than once per window
            resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)
            importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)

            # store the result in the proper location of the full output. Apply weights from importance map.
            for idx, original_idx in zip(slice_range, unravel_slice):
                # zoom roi
                original_idx_zoom = list(original_idx)  # 4D for 2D image, 5D for 3D image
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
                        warnings.warn(
                            f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
                            f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
                            f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
                            f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
                            f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
                            "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
                        )
                    original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
                zoomed_idx = tuple(original_idx_zoom)
                # store results and weights
                output_image_list[ss][zoomed_idx] += importance_map_zoom * seg_prob[idx - slice_g]
                # The count map has a size-1 batch axis, so for batch items > 0 the batch slice is
                # empty and nothing is added: weights are accumulated once and shared by the batch.
                count_map_list[ss][zoomed_idx] += (
                    importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][zoomed_idx].shape)
                )

    # account for any overlapping sections
    for ss in range(len(output_image_list)):
        output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)

    # remove padding if image_size smaller than roi_size
    for ss, output_i in enumerate(output_image_list):
        if torch.isnan(output_i).any() or torch.isinf(output_i).any():
            warnings.warn("Sliding window inference results contain NaN or Inf.")

        # scale from the (padded) input grid to this output's grid
        zoom_scale = [out_d / img_d for out_d, img_d in zip(output_i.shape[2:], image_size)]

        final_slicing: List[slice] = []
        for sp in range(num_spatial_dims):
            # pad_size is ordered last-dim-first, so spatial dim (num_spatial_dims - sp - 1) uses pad_size[2 * sp]
            slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
            slice_dim = slice(
                int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
                int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
            )
            final_slicing.insert(0, slice_dim)
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[tuple(final_slicing)]

    if dict_key is not None:  # if output of predictor is a dict
        final_output = dict(zip(dict_key, output_image_list))
    else:
        final_output = tuple(output_image_list)  # type: ignore
    final_output = final_output[0] if is_tensor_output else final_output  # type: ignore

    if isinstance(inputs, MetaTensor):
        final_output = convert_to_dst_type(final_output, inputs)[0]  # type: ignore
    return final_output
def compute_average_surface_distance(
    y_pred: Union[np.ndarray, torch.Tensor],
    y: Union[np.ndarray, torch.Tensor],
    include_background: bool = False,
    symmetric: bool = False,
    distance_metric: str = "euclidean",
):
    """
    Average Surface Distance from ``y_pred`` to ``y``, per batch item and per channel.

    With ``symmetric=True``, the distances from ``y`` to ``y_pred`` are appended before averaging,
    which gives the average symmetric surface distance. Follows
    `DeepMind's implementation <https://github.com/deepmind/surface-distance>`_.

    Args:
        y_pred: binarised, one-hot prediction with the batch as the first dim,
            e.g. shape ``[16, 3, 32, 32]``.
        y: binarised, one-hot ground truth with the same shape as ``y_pred``.
        include_background: whether to include the first channel in the computation.
            If ``False`` (default), the background channel is removed before computing.
        symmetric: whether to compute the symmetric average surface distance.
            Defaults to ``False``.
        distance_metric: one of [``"euclidean"``, ``"chessboard"``, ``"taxicab"``], the metric
            used to compute surface distances. Defaults to ``"euclidean"``.

    Returns:
        A ``(batch, channels)`` tensor of mean distances. An entry is NaN when no surface
        distances were obtained for that batch item and channel.

    Raises:
        ValueError: if ``y_pred`` and ``y`` have different shapes.
    """
    if not include_background:
        y_pred, y = ignore_background(y_pred=y_pred, y=y)

    def _as_float(data):
        # only torch tensors are cast; numpy inputs pass through unchanged
        return data.float() if isinstance(data, torch.Tensor) else data

    y = _as_float(y)
    y_pred = _as_float(y_pred)

    if y.shape != y_pred.shape:
        raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

    n_batch, n_channels = y_pred.shape[:2]
    result = np.empty((n_batch, n_channels))

    for b, c in np.ndindex(n_batch, n_channels):
        edges_pred, edges_gt = get_mask_edges(y_pred[b, c], y[b, c])
        if not np.any(edges_gt):
            warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
        if not np.any(edges_pred):
            warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")

        distances = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric)
        if symmetric:
            reverse = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric)
            distances = np.concatenate([distances, reverse])

        # an empty distance set has no meaningful mean; record NaN instead
        if distances.shape == (0,):
            result[b, c] = np.nan
        else:
            result[b, c] = distances.mean()

    return convert_data_type(result, torch.Tensor)[0]