Beispiel #1
0
    def _evaluate_patch_locations(self, sample):
        """Expand one sample into a list of per-patch samples, one per foreground mask pixel.

        The slide is read at ``self.mask_level`` and passed through ``ForegroundMask``;
        every non-zero pixel of the resulting mask produces one patch whose location is
        expressed in level-0 coordinates. ``sample`` is updated in place with the patch
        size/level and the probability-map name, count and size. Each returned dict is a
        shallow copy of the updated ``sample`` plus the patch location and the mask location.
        """
        size = self._get_size(sample)
        level = self._get_level(sample)
        slide = self._get_wsi_object(sample)

        # read the whole slide at the mask level
        slide_image, _ = self.wsi_reader.get_data(slide, level=self.mask_level)

        # build the foreground mask; each row of `mask_locations` indexes one non-zero mask pixel
        foreground = ForegroundMask(hsv_threshold={"S": "otsu"})(slide_image)
        mask = np.squeeze(convert_to_dst_type(foreground, dst=slide_image)[0])
        mask_locations = np.vstack(mask.nonzero()).T

        # map mask pixel centres to level 0, then shift by half a (level-0) patch
        mask_ratio = self.wsi_reader.get_downsample_ratio(slide, self.mask_level)
        patch_ratio = self.wsi_reader.get_downsample_ratio(slide, level)
        half_patch_0 = np.array([dim * patch_ratio for dim in size]) // 2
        centres_0 = (mask_locations + 0.5) * float(mask_ratio)
        patch_locations = np.round(centres_0 - half_patch_0).astype(int)

        # metadata shared by every patch of this slide
        sample[WSIPatchKeys.SIZE.value] = size
        sample[WSIPatchKeys.LEVEL.value] = level
        sample[ProbMapKeys.NAME.value] = os.path.basename(sample[CommonKeys.IMAGE])
        sample[ProbMapKeys.COUNT.value] = len(patch_locations)
        sample[ProbMapKeys.SIZE.value] = mask.shape

        expanded = []
        for patch_loc, mask_loc in zip(patch_locations, mask_locations):
            expanded.append({
                **sample,
                WSIPatchKeys.LOCATION.value: np.array(patch_loc),
                ProbMapKeys.LOCATION.value: mask_loc,
            })
        return expanded
Beispiel #2
0
def sliding_window_inference(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor],
                                   Dict[Any, torch.Tensor]]],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    progress: bool = False,
    roi_weight_map: Union[torch.Tensor, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
    """
    Sliding window inference on `inputs` with `predictor`.

    The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
    Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
    e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
    could be ([128,64,256], [64,32,128]).
    In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
    an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
    so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).

    When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D])
        roi_size: the spatial window size for inferences.
            When its components have None or non-positives, the corresponding inputs dimension will be used.
            if the components of the `roi_size` are non-positive values, the transform will use the
            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor ``patch_data`` in shape NCHW[D],
            The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
            with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
            where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
            N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
            the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
            to ensure the scaled output ROI sizes are still integers.
            If the `predictor`'s input and output spatial sizes are different,
            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
        overlap: Amount of overlap between scans. Must satisfy ``0 <= overlap < 1``.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI.
            It is only used when the valid patch size equals ``roi_size``; otherwise the map is computed
            on the fly from ``mode`` and ``sigma_scale``.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

    Returns:
        The stitched prediction(s), in the same container kind the predictor returned: a single tensor,
        a tuple of tensors, or a dict keyed by the predictor's (sorted) output keys. When the (padded)
        ``inputs`` is a ``MetaTensor``, the result is passed through ``convert_to_dst_type`` with it.

    Raises:
        ValueError: when ``overlap`` is outside ``[0, 1)``.
        RuntimeError: when computing the importance map fails (the original error is chained).

    Note:
        - input must be channel-first and have a batch dim, supports N-D sliding window.

    """
    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    batch_size, _, *image_size_ = inputs.shape

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(
        max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    # F.pad expects (before, after) pairs starting from the LAST dimension, hence the reversed walk
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs,
                   pad=pad_size,
                   mode=look_up_option(padding_mode, PytorchPadMode),
                   value=cval)

    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims,
                                       overlap)

    # Store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows

    # Create window-level importance map
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
    else:
        try:
            importance_map = compute_importance_map(valid_patch_size,
                                                    mode=mode,
                                                    sigma_scale=sigma_scale,
                                                    device=device)
        except Exception as e:
            # `Exception` (not `BaseException`) so that KeyboardInterrupt/SystemExit are not
            # re-labelled as an out-of-memory problem.
            raise RuntimeError(
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map = convert_data_type(importance_map, torch.Tensor, device,
                                       compute_dtype)[0]  # type: ignore
    # handle non-positive weights: clamp to the smallest non-zero weight (at least 1e-3)
    # so that the later division by the accumulated weights stays finite
    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
    importance_map = torch.clamp(importance_map.to(torch.float32),
                                 min=min_non_zero).to(compute_dtype)

    # Perform predictions
    dict_key, output_image_list, count_map_list = None, [], []
    _initialized_ss = -1
    is_tensor_output = True  # whether the predictor's output is a tensor (instead of dict/tuple)

    # for each batch of windows
    window_starts = range(0, total_slices, sw_batch_size)
    for slice_g in tqdm(window_starts) if progress else window_starts:
        slice_range = range(slice_g, min(slice_g + sw_batch_size,
                                         total_slices))
        # flat window index -> [batch slice, all channels, *spatial slices]
        unravel_slice = [
            [slice(idx // num_win, idx // num_win + 1),
             slice(None)] + list(slices[idx % num_win]) for idx in slice_range
        ]
        # index with tuples: indexing a tensor with a list of slices is deprecated in PyTorch
        window_data = torch.cat([
            convert_data_type(inputs[tuple(win_slice)], torch.Tensor)[0]
            for win_slice in unravel_slice
        ]).to(sw_device)
        seg_prob_out = predictor(window_data, *args,
                                 **kwargs)  # batched patch segmentation

        # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out, )
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(
                    seg_prob_out.keys())  # track predictor's output keys
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
        else:
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False

        # for each output in multi-output list
        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN

            # compute zoom scale: out_roi_size/in_roi_size
            zoom_scale = []
            for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
                    zip(image_size, seg_prob.shape[2:],
                        window_data.shape[2:])):
                _scale = out_w_i / float(in_w_i)
                if not (img_s_i * _scale).is_integer():
                    warnings.warn(
                        f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
                        f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
                    )
                zoom_scale.append(_scale)

            if _initialized_ss < ss:  # init. the ss-th buffer at the first iteration
                # construct multi-resolution outputs
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d)
                    for image_size_d, zoom_scale_d in zip(
                        image_size, zoom_scale)
                ]
                # allocate memory to store the full output and the count for overlapping parts
                output_image_list.append(
                    torch.zeros(output_shape,
                                dtype=compute_dtype,
                                device=device))
                count_map_list.append(
                    torch.zeros([1, 1] + output_shape[2:],
                                dtype=compute_dtype,
                                device=device))
                _initialized_ss += 1

            # resize the importance_map to this output's patch size; it depends only on the
            # output patch shape, so it is computed once here rather than once per window
            resizer = Resize(spatial_size=seg_prob.shape[2:],
                             mode="nearest",
                             anti_aliasing=False)
            importance_map_zoom = resizer(
                importance_map.unsqueeze(0))[0].to(compute_dtype)

            # store the result in the proper location of the full output. Apply weights from importance map.
            for idx, original_idx in zip(slice_range, unravel_slice):
                # zoom roi
                original_idx_zoom = list(
                    original_idx)  # 4D for 2D image, 5D for 3D image
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis -
                                                                         2]
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    if not zoomed_start.is_integer() or (
                            not zoomed_end.is_integer()):
                        warnings.warn(
                            f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
                            f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
                            f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
                            f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
                            f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
                            "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
                        )
                    original_idx_zoom[axis] = slice(int(zoomed_start),
                                                    int(zoomed_end), None)
                zoomed_idx = tuple(original_idx_zoom)
                # store results and weights
                output_image_list[ss][
                    zoomed_idx] += importance_map_zoom * seg_prob[idx -
                                                                  slice_g]
                # NOTE: the count map has a batch dimension of size 1, so for batch indices > 0
                # the batch slice selects an empty view and nothing is added a second time.
                count_map_list[ss][zoomed_idx] += (
                    importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(
                        count_map_list[ss][zoomed_idx].shape))

    # account for any overlapping sections
    for ss in range(len(output_image_list)):
        # pop so that each count map can be released as soon as it has been used
        output_image_list[ss] = (output_image_list[ss] /
                                 count_map_list.pop(0)).to(compute_dtype)

    # remove padding if image_size smaller than roi_size
    for ss, output_i in enumerate(output_image_list):
        if torch.isnan(output_i).any() or torch.isinf(output_i).any():
            warnings.warn(
                "Sliding window inference results contain NaN or Inf.")

        zoom_scale = [
            seg_prob_map_shape_d / roi_size_d
            for seg_prob_map_shape_d, roi_size_d in zip(
                output_i.shape[2:], roi_size)
        ]

        # pad_size is ordered from the last spatial dim to the first, so walk it and
        # prepend each slice to end up in first-to-last order
        final_slicing: List[slice] = []
        for sp in range(num_spatial_dims):
            slice_dim = slice(
                pad_size[sp * 2],
                image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
            slice_dim = slice(
                int(
                    round(slice_dim.start *
                          zoom_scale[num_spatial_dims - sp - 1])),
                int(
                    round(slice_dim.stop *
                          zoom_scale[num_spatial_dims - sp - 1])),
            )
            final_slicing.insert(0, slice_dim)
        # keep batch and channel dimensions untouched
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[tuple(final_slicing)]

    if dict_key is not None:  # if output of predictor is a dict
        final_output = dict(zip(dict_key, output_image_list))
    else:
        final_output = tuple(output_image_list)  # type: ignore
    final_output = final_output[
        0] if is_tensor_output else final_output  # type: ignore
    if isinstance(inputs, MetaTensor):
        final_output = convert_to_dst_type(final_output,
                                           inputs)[0]  # type: ignore
    return final_output
Beispiel #3
0
def grid_pull(
    input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True
) -> torch.Tensor:
    """
    Resample `input` at the coordinates given by the deformation field `grid`.

    `interpolation` may be given as an int, a string or an InterpolationType.
    Accepted values::

        - 0 or 'nearest'    or InterpolationType.nearest
        - 1 or 'linear'     or InterpolationType.linear
        - 2 or 'quadratic'  or InterpolationType.quadratic
        - 3 or 'cubic'      or InterpolationType.cubic
        - 4 or 'fourth'     or InterpolationType.fourth
        - 5 or 'fifth'      or InterpolationType.fifth
        - 6 or 'sixth'      or InterpolationType.sixth
        - 7 or 'seventh'    or InterpolationType.seventh

    To use a different interpolation order per dimension, pass a list of values
    in the order [W, H, D].

    `bound` may be given as an int, a string or a BoundType.
    Accepted values::

        - 0 or 'replicate' or 'nearest'      or BoundType.replicate or 'border'
        - 1 or 'dct1'      or 'mirror'       or BoundType.dct1
        - 2 or 'dct2'      or 'reflect'      or BoundType.dct2
        - 3 or 'dst1'      or 'antimirror'   or BoundType.dst1
        - 4 or 'dst2'      or 'antireflect'  or BoundType.dst2
        - 5 or 'dft'       or 'wrap'         or BoundType.dft
        - 7 or 'zero'      or 'zeros'        or BoundType.zero

    To use a different boundary condition per dimension, pass a list of values
    in the order [W, H, D].
    `sliding` is a special condition that only applies to flow fields
    (with as many channels as dimensions). It cannot be dimension-specific.
    Note that:

        - `dft` corresponds to circular padding
        - `dct2` corresponds to Neumann boundary conditions (symmetric)
        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)

    See Also:
        - https://en.wikipedia.org/wiki/Discrete_cosine_transform
        - https://en.wikipedia.org/wiki/Discrete_sine_transform
        - ``help(monai._C.BoundType)``
        - ``help(monai._C.InterpolationType)``

    Args:
        input: Input image. `(B, C, Wi, Hi, Di)`.
        grid: Deformation field. `(B, Wo, Ho, Do, 1|2|3)`.
        interpolation (int or list[int] , optional): Interpolation order.
            Defaults to `'linear'`.
        bound (BoundType, or list[BoundType], optional): Boundary conditions.
            Defaults to `'zero'`.
        extrapolate: Extrapolate out-of-bound data.
            Defaults to `True`.

    Returns:
        output (torch.Tensor): Deformed image `(B, C, Wo, Ho, Do)`.
        When `input` is a ``monai.data.MetaTensor`` the result is converted
        with ``convert_to_dst_type`` using `input` as destination.

    """
    # Normalise the boundary conditions: strings are looked up by member name,
    # anything else is handed to the BoundType constructor.
    bound_types = []
    for b in ensure_tuple(bound):
        if isinstance(b, str):
            bound_types.append(_C.BoundType.__members__[b])
        else:
            bound_types.append(_C.BoundType(b))

    # Same normalisation for the interpolation orders.
    interp_types = []
    for i in ensure_tuple(interpolation):
        if isinstance(i, str):
            interp_types.append(_C.InterpolationType.__members__[i])
        else:
            interp_types.append(_C.InterpolationType(i))

    out: torch.Tensor = _GridPull.apply(input, grid, interp_types, bound_types, extrapolate)
    if not isinstance(input, monai.data.MetaTensor):
        return out
    return convert_to_dst_type(out, dst=input)[0]
Beispiel #4
0
    def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Split a 3-dimensional ``(c, h, w)`` image into square tiles stacked along a new first axis.

        Steps: convert to numpy, call ``self.randomize``, optionally crop by ``self.offset``,
        optionally pad so both spatial sides are divisible by ``self.tile_size``, extract
        ``tile_size`` x ``tile_size`` tiles every ``self.step`` pixels, then select tiles:

        - ``self.tile_count is None``: keep tiles whose summed intensity is below
          (``filter_mode == "min"``) or at/above (``filter_mode == "max"``) a background
          threshold; any other mode keeps all tiles.
        - otherwise: reduce to ``tile_count`` tiles (lowest sums for ``"min"``, highest for
          ``"max"``, ``self.random_idxs`` if set for other modes), or pad with
          ``self.background_val`` tiles when there are too few.

        Args:
            image: input image; unpacked as ``(c, h, w)``.

        Returns:
            the selected tiles, converted back to the type and dtype of ``image``.
        """
        img_np, *_ = convert_data_type(image, np.ndarray)

        # refresh the random state (offset / random tile indices) for this image size
        self.randomize(img_size=img_np.shape)

        if self.random_offset:
            off_h, off_w = self.offset[0], self.offset[1]
            if off_h > 0 or off_w > 0:
                img_np = img_np[:, off_h:, off_w:]

        tile = self.tile_size

        # pad both spatial axes up to the next multiple of the tile size, split evenly on both sides
        if self.pad_full:
            _, height, width = img_np.shape
            extra_h = (tile - height % tile) % tile
            extra_w = (tile - width % tile) % tile
            before_h, before_w = extra_h // 2, extra_w // 2
            img_np = np.pad(  # type: ignore
                img_np,
                [[0, 0], [before_h, extra_h - before_h], [before_w, extra_w - before_w]],
                constant_values=self.background_val,
            )

        # build a read-only strided view of shape (rows, cols, c, tile, tile), then flatten rows x cols
        step = self.step
        n_channels, height, width = img_np.shape
        stride_c, stride_h, stride_w = img_np.strides
        n_rows = (height - tile) // step + 1
        n_cols = (width - tile) // step + 1
        windows = as_strided(
            img_np,
            shape=(n_rows, n_cols, n_channels, tile, tile),
            strides=(stride_h * step, stride_w * step, stride_c, stride_h, stride_w),
            writeable=False,
        )
        img_np = windows.reshape(-1, n_channels, tile, tile)  # type: ignore

        wanted = self.tile_count
        selection = self.filter_mode
        if wanted is None:
            # variable number of tiles: only keep those passing the background threshold
            # (the factor 3 presumably assumes three channels -- TODO confirm)
            thresh = 0.999 * 3 * self.background_val * tile * tile
            if selection == "min":
                # keep non-background tiles (small sums)
                keep = np.argwhere(img_np.sum(axis=(1, 2, 3)) < thresh).ravel()
                img_np = img_np[keep]
            elif selection == "max":
                keep = np.argwhere(img_np.sum(axis=(1, 2, 3)) >= thresh).ravel()
                img_np = img_np[keep]
        elif len(img_np) > wanted:
            # too many tiles: reduce to `wanted`
            if selection == "min":
                order = np.argsort(img_np.sum(axis=(1, 2, 3)))
                img_np = img_np[order[:wanted]]
            elif selection == "max":
                order = np.argsort(img_np.sum(axis=(1, 2, 3)))
                img_np = img_np[order[-wanted:]]
            elif self.random_idxs is not None:
                # random subset chosen earlier and stored on the instance
                img_np = img_np[self.random_idxs]
        elif len(img_np) < wanted:
            # too few tiles: append background-valued tiles
            missing = wanted - len(img_np)
            img_np = np.pad(  # type: ignore
                img_np,
                [[0, missing], [0, 0], [0, 0], [0, 0]],
                constant_values=self.background_val,
            )

        image, *_ = convert_to_dst_type(src=img_np, dst=image, dtype=image.dtype)
        return image
Beispiel #5
0
    def __call__(self,
                 filename: Union[Sequence[PathLike], PathLike],
                 reader: Optional[ImageReader] = None):
        """
        Read an image and its metadata from one or several files.

        When no `reader` is passed, the registered readers in `self.readers` are tried
        from last to first: with `self.auto_select` the first reader whose
        `verify_suffix` accepts the filename is used; otherwise each reader is tried
        until one reads the file without raising.

        Args:
            filename: path file or file-like object or a list of files.
                The first filename is stored in the metadata under `Key.FILENAME_OR_OBJ`.
                If a list of files is provided, they are passed to the reader together
                (to be stacked as multi-channel data).
                If a directory path is provided instead of a file path, it is treated as
                a DICOM image series.
            reader: runtime reader to load image file and metadata.

        Returns:
            the loaded image if `self.image_only` is set, otherwise a tuple of the image
            and its `meta` attribute.

        Raises:
            RuntimeError: when no reader produced an image.
            ValueError: when the reader's metadata is not a dict.

        """
        # stringify every entry so that Path objects are accepted
        filename = tuple(f"{Path(s).expanduser()}" for s in ensure_tuple(filename))
        img = None
        err = []
        if reader is None:
            for reader in self.readers[::-1]:
                if self.auto_select:
                    # rely on the filename extension to choose the reader
                    if reader.verify_suffix(filename):
                        img = reader.read(filename)
                        break
                    continue
                # otherwise try the user designated readers one by one
                try:
                    img = reader.read(filename)
                except Exception as e:
                    err.append(traceback.format_exc())
                    log = logging.getLogger(self.__class__.__name__)
                    log.debug(e, exc_info=True)
                    log.info(
                        f"{reader.__class__.__name__}: unable to load {filename}.\n"
                    )
                    continue
                # success: earlier failures are no longer relevant
                err = []
                break
        else:
            img = reader.read(filename)  # runtime specified reader

        if img is None or reader is None:
            if isinstance(filename, tuple) and len(filename) == 1:
                filename = filename[0]
            msg = "\n".join(f"{e}" for e in err)
            raise RuntimeError(
                f"{self.__class__.__name__} cannot find a suitable reader for file: {filename}.\n"
                "    Please install the reader libraries, see also the installation instructions:\n"
                "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
                f"   The current registered: {self.readers}.\n{msg}")

        img_array: NdarrayOrTensor
        img_array, meta_data = reader.get_data(img)
        # same container type, only the dtype is changed to self.dtype
        img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
        if not isinstance(meta_data, dict):
            raise ValueError("`meta_data` must be a dict.")
        # make sure all elements in metadata are little endian
        meta_data = switch_endianness(meta_data, "<")

        # Path obj should be strings for data loader
        meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}"
        img = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data, self.simple_keys)
        if self.ensure_channel_first:
            img = EnsureChannelFirst()(img)
        # the (img, meta) tuple form is kept for compatibility purpose
        return img if self.image_only else (img, img.meta)