def _evaluate_patch_locations(self, sample):
    """
    Build one sample per foreground pixel of a tissue mask computed at ``self.mask_level``.

    The whole slide is read at ``self.mask_level`` and a foreground mask is produced with an Otsu
    threshold on the HSV saturation channel. Every non-zero mask pixel is mapped to the level-0
    location of a patch whose centre is that pixel's centre.

    Args:
        sample: data dictionary describing one whole-slide image. It is updated in place with the
            patch size/level and probability-map metadata (name, count, mask shape).

    Returns:
        A list with one shallow copy of ``sample`` per foreground mask pixel. Each copy carries the
        level-0 patch location and the mask-level pixel location.
    """
    size = self._get_size(sample)
    level = self._get_level(sample)
    slide = self._get_wsi_object(sample)

    # read the complete slide at the mask resolution
    slide_data, _ = self.wsi_reader.get_data(slide, level=self.mask_level)

    # tissue mask, then the index of every non-zero pixel (one row per pixel, one column per axis)
    foreground = ForegroundMask(hsv_threshold={"S": "otsu"})(slide_data)
    mask = np.squeeze(convert_to_dst_type(foreground, dst=slide_data)[0])
    mask_locations = np.vstack(mask.nonzero()).T

    # scale pixel centres from mask level to level 0, then shift by half a (level-0) patch
    mask_ratio = self.wsi_reader.get_downsample_ratio(slide, self.mask_level)
    patch_ratio = self.wsi_reader.get_downsample_ratio(slide, level)
    size_at_level0 = np.array([dim * patch_ratio for dim in size])
    centres_at_level0 = (mask_locations + 0.5) * float(mask_ratio)
    patch_locations = np.round(centres_at_level0 - size_at_level0 // 2).astype(int)

    # shared metadata, written into the input sample before it is copied per patch
    sample[WSIPatchKeys.SIZE.value] = size
    sample[WSIPatchKeys.LEVEL.value] = level
    sample[ProbMapKeys.NAME.value] = os.path.basename(sample[CommonKeys.IMAGE])
    sample[ProbMapKeys.COUNT.value] = len(patch_locations)
    sample[ProbMapKeys.SIZE.value] = mask.shape

    patches = []
    for patch_loc, mask_loc in zip(patch_locations, mask_locations):
        entry = dict(sample)
        entry[WSIPatchKeys.LOCATION.value] = np.array(patch_loc)
        entry[ProbMapKeys.LOCATION.value] = mask_loc
        patches.append(entry)
    return patches
def sliding_window_inference(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    progress: bool = False,
    roi_weight_map: Union[torch.Tensor, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
    """
    Sliding window inference on `inputs` with `predictor`.

    The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
    Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
    e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
    could be ([128,64,256], [64,32,128]).
    In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is
    still an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the
    parameters so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).

    When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D])
        roi_size: the spatial window size for inferences.
            When its components have None or non-positives, the corresponding inputs dimension will be used.
            For example, `roi_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size
            of img is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor ``patch_data`` in shape NCHW[D],
            The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
            with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e.
            NM'H'W'[D']; where H'W'[D'] represents the output patch's spatial size, M is the number of output
            channels, N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
            the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
            to ensure the scaled output ROI sizes are still integers.
            If the `predictor`'s input and output spatial sizes are different,
            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each
            dimension.
        overlap: Amount of overlap between scans, in ``[0, 1)``.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant``": gives equal weight to all predictions.
            - ``"gaussian``": gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI.
            If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

    Returns:
        The stitched prediction: a tensor when ``predictor`` returns a tensor, a dict (keys sorted) when it
        returns a mapping, otherwise a tuple. Each output is cropped back to the (zoomed) spatial size of the
        original ``inputs``. When ``inputs`` is a MetaTensor, the result is converted to match it.

    Raises:
        ValueError: when ``overlap`` is not in ``[0, 1)``.
        RuntimeError: when computing the importance map fails (typically out of memory).

    Note:
        - input must be channel-first and have a batch dim, supports N-D sliding window.

    """
    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    batch_size, _, *image_size_ = inputs.shape

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    # F.pad takes (before, after) pairs starting from the LAST dimension, hence the reversed walk;
    # pad_size[2 * sp] therefore belongs to spatial dim (num_spatial_dims - 1 - sp)
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)

    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)

    # Store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows

    # Create window-level importance map
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
    else:
        try:
            importance_map = compute_importance_map(
                valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device
            )
        except Exception as e:  # not BaseException: let KeyboardInterrupt/SystemExit propagate unchanged
            raise RuntimeError(
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0]  # type: ignore
    # handle non-positive weights: clamp to the smallest non-zero weight (at least 1e-3) so every window
    # contributes a positive weight; an all-zero map would otherwise make `.min()` fail on an empty tensor
    non_zero_weights = importance_map[importance_map != 0]
    min_non_zero = max(non_zero_weights.min().item(), 1e-3) if non_zero_weights.numel() > 0 else 1e-3
    importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)

    # Perform predictions
    dict_key, output_image_list, count_map_list = None, [], []
    _initialized_ss = -1
    is_tensor_output = True  # whether the predictor's output is a tensor (instead of dict/tuple)

    # for each batch of windows
    window_starts = range(0, total_slices, sw_batch_size)
    for slice_g in tqdm(window_starts) if progress else window_starts:
        slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
        # global window index -> (batch item, window within that item)
        unravel_slice = [
            [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win])
            for idx in slice_range
        ]
        window_data = torch.cat(
            [convert_data_type(inputs[tuple(win_slice)], torch.Tensor)[0] for win_slice in unravel_slice]
        ).to(sw_device)
        seg_prob_out = predictor(window_data, *args, **kwargs)  # batched patch segmentation

        # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out,)
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(seg_prob_out.keys())  # track predictor's output keys
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
        else:
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False

        # for each output in multi-output list
        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN

            # compute zoom scale: out_roi_size/in_roi_size
            zoom_scale = []
            for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
                zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
            ):
                _scale = out_w_i / float(in_w_i)
                if not (img_s_i * _scale).is_integer():
                    warnings.warn(
                        f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
                        f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
                    )
                zoom_scale.append(_scale)

            if _initialized_ss < ss:  # init. the ss-th buffer at the first iteration
                # construct multi-resolution outputs
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
                ]
                # allocate memory to store the full output and the count for overlapping parts
                output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
                count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
                _initialized_ss += 1

            # resize the importance map to this output's window size; it is the same for every window
            # of the batch, so compute it once here rather than inside the per-window loop
            resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)
            importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)

            # store the result in the proper location of the full output. Apply weights from importance map.
            for idx, original_idx in zip(slice_range, unravel_slice):
                # zoom roi
                original_idx_zoom = list(original_idx)  # 4D for 2D image, 5D for 3D image
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
                        warnings.warn(
                            f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
                            f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
                            f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
                            f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
                            f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
                            "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
                        )
                    original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
                zoomed_index = tuple(original_idx_zoom)
                # store results and weights
                output_image_list[ss][zoomed_index] += importance_map_zoom * seg_prob[idx - slice_g]
                # the count map has a singleton batch dim: for batch items > 0 the batch slice is empty, so the
                # weights are accumulated once (from the first item) and shared by the whole batch
                count_map_list[ss][zoomed_index] += (
                    importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][zoomed_index].shape)
                )

    # account for any overlapping sections
    for ss in range(len(output_image_list)):
        output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)

    # remove padding if image_size smaller than roi_size
    for ss, output_i in enumerate(output_image_list):
        if torch.isnan(output_i).any() or torch.isinf(output_i).any():
            warnings.warn("Sliding window inference results contain NaN or Inf.")

        # the output buffer spans `image_size` (the padded size) scaled by the per-axis zoom
        zoom_scale = [
            out_shape_d / image_size_d for out_shape_d, image_size_d in zip(output_i.shape[2:], image_size)
        ]

        final_slicing: List[slice] = []
        for sp in range(num_spatial_dims):
            si = num_spatial_dims - sp - 1  # spatial dim padded by pad_size[2 * sp]
            slice_dim = slice(
                int(round(pad_size[sp * 2] * zoom_scale[si])),
                int(round((image_size_[si] + pad_size[sp * 2]) * zoom_scale[si])),
            )
            final_slicing.insert(0, slice_dim)
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[tuple(final_slicing)]

    if dict_key is not None:  # if output of predictor is a dict
        final_output = dict(zip(dict_key, output_image_list))
    else:
        final_output = tuple(output_image_list)  # type: ignore
    final_output = final_output[0] if is_tensor_output else final_output  # type: ignore

    if isinstance(inputs, MetaTensor):
        final_output = convert_to_dst_type(final_output, inputs)[0]  # type: ignore
    return final_output
def grid_pull(
    input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True
) -> torch.Tensor:
    """
    Sample an image with respect to a deformation field.

    `interpolation` can be an int, a string or an InterpolationType.
    Possible values are::

        - 0 or 'nearest'    or InterpolationType.nearest
        - 1 or 'linear'     or InterpolationType.linear
        - 2 or 'quadratic'  or InterpolationType.quadratic
        - 3 or 'cubic'      or InterpolationType.cubic
        - 4 or 'fourth'     or InterpolationType.fourth
        - 5 or 'fifth'      or InterpolationType.fifth
        - 6 or 'sixth'      or InterpolationType.sixth
        - 7 or 'seventh'    or InterpolationType.seventh

    A list of values can be provided, in the order [W, H, D],
    to specify dimension-specific interpolation orders.

    `bound` can be an int, a string or a BoundType.
    Possible values are::

        - 0 or 'replicate' or 'nearest'      or BoundType.replicate or 'border'
        - 1 or 'dct1'      or 'mirror'       or BoundType.dct1
        - 2 or 'dct2'      or 'reflect'      or BoundType.dct2
        - 3 or 'dst1'      or 'antimirror'   or BoundType.dst1
        - 4 or 'dst2'      or 'antireflect'  or BoundType.dst2
        - 5 or 'dft'       or 'wrap'         or BoundType.dft
        - 7 or 'zero'      or 'zeros'        or BoundType.zero

    A list of values can be provided, in the order [W, H, D],
    to specify dimension-specific boundary conditions.
    `sliding` is a specific condition that only applies to flow fields
    (with as many channels as dimensions). It cannot be dimension-specific.
    Note that:

        - `dft` corresponds to circular padding
        - `dct2` corresponds to Neumann boundary conditions (symmetric)
        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)

    See Also:
        - https://en.wikipedia.org/wiki/Discrete_cosine_transform
        - https://en.wikipedia.org/wiki/Discrete_sine_transform
        - ``help(monai._C.BoundType)``
        - ``help(monai._C.InterpolationType)``

    Args:
        input: Input image. `(B, C, Wi, Hi, Di)`.
        grid: Deformation field. `(B, Wo, Ho, Do, 1|2|3)`.
        interpolation (int, str, InterpolationType, or a list of them, optional): Interpolation order.
            Defaults to `'linear'`.
        bound (int, str, BoundType, or a list of them, optional): Boundary conditions.
            Defaults to `'zero'`.
        extrapolate: Extrapolate out-of-bound data.
            Defaults to `True`.

    Returns:
        output (torch.Tensor): Deformed image `(B, C, Wo, Ho, Do)`. When `input` is a MetaTensor,
        the output is converted to match it.

    """
    # Strings are looked up by enum member name; anything else goes through the enum constructor.
    bound_types = []
    for b in ensure_tuple(bound):
        bound_types.append(_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b))

    interpolation_types = []
    for i in ensure_tuple(interpolation):
        if isinstance(i, str):
            interpolation_types.append(_C.InterpolationType.__members__[i])
        else:
            interpolation_types.append(_C.InterpolationType(i))

    out: torch.Tensor = _GridPull.apply(input, grid, interpolation_types, bound_types, extrapolate)
    if not isinstance(input, monai.data.MetaTensor):
        return out
    return convert_to_dst_type(out, dst=input)[0]
def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    Tile a channel-first 2D image ``(C, H, W)`` on a regular grid and select a subset of tiles.

    Steps, driven by the instance configuration:

    1. ``self.randomize`` is called with the image shape. When ``self.random_offset`` is set and the
       resulting ``self.offset`` is non-zero, that many leading rows/columns are cropped away.
    2. When ``self.pad_full`` is set, the image is padded with ``self.background_val`` so that both
       spatial sizes are multiples of ``self.tile_size``. The padding is split evenly between both sides.
    3. ``tile_size x tile_size`` tiles are extracted every ``self.step`` pixels.
    4. Selection:

       - ``self.tile_count is None``: ``filter_mode="min"`` keeps tiles whose pixel sum is below 99.9% of
         an all-background tile, ``"max"`` keeps the complementary tiles, and any other mode keeps every
         tile. The number of returned tiles is therefore variable.
       - more tiles than ``tile_count``: ``"min"``/``"max"`` keep the ``tile_count`` tiles with the
         smallest/largest pixel sums. Any other mode indexes with ``self.random_idxs`` when it is set.
       - fewer tiles than ``tile_count``: background-filled tiles are appended up to ``tile_count``.

    Args:
        image: channel-first image with shape ``(C, H, W)``.

    Returns:
        Tiles stacked as ``(N, C, tile_size, tile_size)``, converted back to the type and dtype of ``image``.
    """
    img_np, *_ = convert_data_type(image, np.ndarray)

    # add random offset
    self.randomize(img_size=img_np.shape)

    if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0):
        img_np = img_np[:, self.offset[0]:, self.offset[1]:]

    # pad to full size, divisible by tile_size
    if self.pad_full:
        _, h, w = img_np.shape
        pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
        pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
        img_np = np.pad(  # type: ignore
            img_np,
            [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]],
            constant_values=self.background_val,
        )

    # extract tiles through a read-only strided view with axes (tile_row, tile_col, C, h_tile, w_tile)
    x_step, y_step = self.step, self.step
    h_tile, w_tile = self.tile_size, self.tile_size
    c_image, h_image, w_image = img_np.shape
    c_stride, x_stride, y_stride = img_np.strides
    llw = as_strided(
        img_np,
        shape=((h_image - h_tile) // x_step + 1, (w_image - w_tile) // y_step + 1, c_image, h_tile, w_tile),
        strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),
        writeable=False,
    )
    img_np = llw.reshape(-1, c_image, h_tile, w_tile)  # type: ignore

    if self.tile_count is None:
        # retain only patches with significant foreground content to speed up inference
        # FYI, this returns a variable number of tiles, so the batch_size must be 1 (per gpu), e.g during inference
        # 99.9% of the pixel sum of an all-background tile; uses the real channel count rather than
        # assuming 3 (RGB), which mis-filtered images with any other number of channels
        thresh = 0.999 * c_image * self.background_val * self.tile_size * self.tile_size
        if self.filter_mode == "min":
            # default, keep non-background tiles (small values)
            idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) < thresh)
            img_np = img_np[idxs.reshape(-1)]
        elif self.filter_mode == "max":
            idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) >= thresh)
            img_np = img_np[idxs.reshape(-1)]
    elif len(img_np) > self.tile_count:
        if self.filter_mode == "min":
            # default, keep non-background tiles (smallest values)
            idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[: self.tile_count]
            img_np = img_np[idxs]
        elif self.filter_mode == "max":
            order = np.argsort(img_np.sum(axis=(1, 2, 3)))
            # explicit start index: `order[-0:]` would keep every tile when tile_count == 0
            idxs = order[len(order) - self.tile_count:]
            img_np = img_np[idxs]
        elif self.random_idxs is not None:
            # random subset (more appropriate for WSIs without distinct background)
            img_np = img_np[self.random_idxs]
    elif len(img_np) < self.tile_count:
        img_np = np.pad(  # type: ignore
            img_np,
            [[0, self.tile_count - len(img_np)], [0, 0], [0, 0], [0, 0]],
            constant_values=self.background_val,
        )

    image, *_ = convert_to_dst_type(src=img_np, dst=image, dtype=image.dtype)

    return image
def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Optional[ImageReader] = None):
    """
    Load image file(s) and metadata from the given filename(s).

    When ``reader`` is not given, the registered readers in ``self.readers`` are tried from the last to
    the first. With ``self.auto_select`` the first reader whose ``verify_suffix`` accepts the filename is
    used. Otherwise the first reader whose ``read`` succeeds is used, and failures are logged and collected.

    Args:
        filename: path file or file-like object or a list of files.
            The (first) filename is stored in the metadata under the key `filename_or_obj`.
            If a list of files is given, they are passed together to the reader, which is expected to
            stack them as multi-channel data.
            If a directory path is given instead of a file path, it is treated as a DICOM image series.
        reader: runtime reader to load image file and metadata.

    Returns:
        The loaded MetaTensor when ``self.image_only`` is set, otherwise ``(image, image.meta)``.

    Raises:
        RuntimeError: when no reader could load ``filename``.
        ValueError: when the reader returns metadata that is not a dict.
    """
    # expand "~" and normalise every entry to ``str`` (this also accepts ``Path`` objects)
    filename = tuple(f"{Path(s).expanduser()}" for s in ensure_tuple(filename))
    img = None
    err = []
    if reader is not None:
        img = reader.read(filename)  # runtime specified reader
    else:
        for reader in self.readers[::-1]:
            if self.auto_select:
                # rely on the filename extension to choose the reader
                if not reader.verify_suffix(filename):
                    continue
                img = reader.read(filename)
                break
            # try the user designated readers
            try:
                img = reader.read(filename)
            except Exception as e:
                err.append(traceback.format_exc())
                logger = logging.getLogger(self.__class__.__name__)
                logger.debug(e, exc_info=True)
                logger.info(f"{reader.__class__.__name__}: unable to load {filename}.\n")
                continue
            err = []
            break

    if img is None or reader is None:
        shown_name = filename[0] if len(filename) == 1 else filename
        msg = "\n".join(str(e) for e in err)
        raise RuntimeError(
            f"{self.__class__.__name__} cannot find a suitable reader for file: {shown_name}.\n"
            "    Please install the reader libraries, see also the installation instructions:\n"
            "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
            f"   The current registered: {self.readers}.\n{msg}"
        )

    img_array: NdarrayOrTensor
    img_array, meta_data = reader.get_data(img)
    img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
    if not isinstance(meta_data, dict):
        raise ValueError("`meta_data` must be a dict.")
    # make sure all elements in metadata are little endian
    meta_data = switch_endianness(meta_data, "<")

    # store the filename as a plain string so the data loader can collate it
    meta_data[Key.FILENAME_OR_OBJ] = f"{filename[0]}"
    img = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data, self.simple_keys)
    if self.ensure_channel_first:
        img = EnsureChannelFirst()(img)
    # the tuple form is kept for compatibility
    return img if self.image_only else (img, img.meta)