예제 #1
    # NOTE(review): this excerpt looks truncated -- `self` is absent from the
    # parameter list and the docstring delimiters are missing; every original
    # line below is kept verbatim.
    # Initialiser of a dictionary-based padding transform: hands `keys` and
    # `allow_missing_keys` to the parent class, stores `mode` expanded to
    # len(self.keys), and builds the `SpatialPad` instance kept in `self.padder`.
    def __init__(
        keys: KeysCollection,
        spatial_size: Union[Sequence[int], int],
        method: Union[Method, str] = Method.SYMMETRIC,
        mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT,
        allow_missing_keys: bool = False,
    ) -> None:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            spatial_size: the spatial size of output data after padding.
                If its components have non-positive values, the corresponding size of input image will be used.
            method: {``"symmetric"``, ``"end"``}
                Pad image symmetric on every side or only pad at the end sides. Defaults to ``"symmetric"``.
            mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
                ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
                One of the listed string values or a user supplied function. Defaults to ``"constant"``.
                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
                It also can be a sequence of string, each element corresponds to a key in ``keys``.
            allow_missing_keys: don't raise exception if key is missing.

        super().__init__(keys, allow_missing_keys)
        # `mode` is sized to len(self.keys); presumably a single value is repeated
        # for every key by `ensure_tuple_rep` -- TODO confirm against that helper.
        self.mode = ensure_tuple_rep(mode, len(self.keys))
        # Only size and method are passed to the padder; `mode` is kept on self.
        self.padder = SpatialPad(spatial_size, method)
예제 #2
    # NOTE(review): this excerpt is incomplete -- the tail of the loop header,
    # the bodies of several `if` statements and the remaining `SpatialPad(...)`
    # arguments were lost in extraction; the original lines are kept unchanged.
    # Pads the tensor/ndarray items of a batch with `SpatialPad` so they share a
    # common shape, then passes the batch to `list_data_collate`.
    def __call__(self, batch: Any):
            batch: batch of data to pad-collate
        # data is either list of dicts or list of lists
        is_list_of_dicts = isinstance(batch[0], dict)
        # loop over items inside of each element in a batch
        for key_or_idx in batch[0].keys() if is_list_of_dicts else range(
            # calculate max size of each dimension
            max_shapes = []
            for elem in batch:
                if not isinstance(elem[key_or_idx],
                                  (torch.Tensor, np.ndarray)):
            # len > 0 if objects were arrays, else skip as no padding to be done
            if not max_shapes:
            # element-wise maximum over all collected shapes
            max_shape = np.array(max_shapes).max(axis=0)
            # If all same size, skip
            if np.all(np.array(max_shapes).min(axis=0) == max_shape):
            # Do we need to convert output to Tensor?
            output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)

            # Use `SpatialPadd` or `SpatialPad` to match sizes
            # Default params are central padding, padding with 0's
            # If input is dictionary, use the dictionary version so that the transformation is recorded

            padder = SpatialPad(spatial_size=max_shape,
            # when the first item was a torch.Tensor, chain `ToTensor` after the padder
            transform = padder if not output_to_tensor else Compose(
                [padder, ToTensor()])

            for idx, batch_i in enumerate(batch):
                im = batch_i[key_or_idx]
                # shape without the leading axis (presumably the channel axis -- TODO confirm)
                orig_size = im.shape[1:]
                padded = transform(batch_i[key_or_idx])
                # `batch` is rebound to whatever `replace_element` returns
                batch = replace_element(padded, batch, idx, key_or_idx)

                # If we have a dictionary of data, append to list
                if is_list_of_dicts:

        # After padding, use default list collator
        return list_data_collate(batch)
예제 #3
 # NOTE(review): this excerpt is truncated -- the rest of the signature, the
 # docstring delimiters and any parent-class initialisation are not visible;
 # the original lines are kept verbatim.
 # Initialiser that builds a `SpatialPad` from `spatial_size`, `method` and
 # `mode` and stores it as `self.padder`.
 def __init__(self,
         keys (hashable items): keys of the corresponding items to be transformed.
             See also: :py:class:`monai.transforms.compose.MapTransform`
         spatial_size (list): the spatial size of output data after padding.
         method (str): pad image symmetric on every side or only pad at the end sides. default is 'symmetric'.
         mode (str): one of the following string values or a user supplied function: {'constant', 'edge',
             'linear_ramp', 'maximum', 'mean', 'median', 'minimum', 'reflect', 'symmetric',
             'wrap', 'empty', <function>}
             for more details, please check: https://docs.scipy.org/doc/numpy/reference/generated/numpy.pad.html
     # here `mode` is passed straight to the padder as its third positional argument
     self.padder = SpatialPad(spatial_size, method, mode)
예제 #4
    # NOTE(review): this excerpt is incomplete -- the end of the `batch_item`
    # expression, the bodies of several `if` statements and the remaining
    # `SpatialPad(...)` arguments were lost; the original lines are kept unchanged.
    # Pads the tensor/ndarray items of a batch with `SpatialPad` so they share a
    # common shape, then passes the batch to `list_data_collate`.
    def __call__(self, batch: Any):
            batch: batch of data to pad-collate
        # data is either list of dicts or list of lists
        is_list_of_dicts = isinstance(batch[0], dict)
        # loop over items inside of each element in a batch
        # for dicts the keys are copied into a tuple before the loop starts
        batch_item = tuple(batch[0].keys()) if is_list_of_dicts else range(
        for key_or_idx in batch_item:
            # calculate max size of each dimension
            max_shapes = []
            for elem in batch:
                if not isinstance(elem[key_or_idx],
                                  (torch.Tensor, np.ndarray)):
            # len > 0 if objects were arrays, else skip as no padding to be done
            if not max_shapes:
            # element-wise maximum over all collected shapes
            max_shape = np.array(max_shapes).max(axis=0)
            # If all same size, skip
            if np.all(np.array(max_shapes).min(axis=0) == max_shape):

            # Use `SpatialPad` to match sizes, Default params are central padding, padding with 0's
            padder = SpatialPad(spatial_size=max_shape,
            for idx, batch_i in enumerate(batch):
                # shape without the leading axis (presumably the channel axis -- TODO confirm)
                orig_size = batch_i[key_or_idx].shape[1:]
                padded = padder(batch_i[key_or_idx])
                # `batch` is rebound to whatever `replace_element` returns
                batch = replace_element(padded, batch, idx, key_or_idx)

                # If we have a dictionary of data, append to list
                if is_list_of_dicts:

        # After padding, use default list collator
        return list_data_collate(batch)
예제 #5
# NOTE(review): this excerpt is truncated. The body uses `volume`, `fig`, `vmin`,
# `vmax`, `show`, `fill_value`, `dtype` and `kwargs`, none of which appear in the
# visible signature; the docstring delimiters, a few `else:` branches and some
# call arguments are also missing. The original lines are kept verbatim.
# Builds a 2D grid image out of the frames of a volume, draws it with
# `ax.matshow`, and returns the figure together with the grid image.
def matshow3d(
    title: Optional[str] = None,
    figsize=(10, 10),
    frames_per_row: Optional[int] = None,
    frame_dim: int = -3,
    channel_dim: Optional[int] = None,
    every_n: int = 1,
    interpolation: str = "none",
    margin: int = 1,
    Create a 3D volume figure as a grid of images.

        volume: 3D volume to display. data shape can be `BCHWD`, `CHWD` or `HWD`.
            Higher dimensional arrays will be reshaped into (-1, H, W, [C]), `C` depends on `channel_dim` arg.
            A list of channel-first (C, H[, W, D]) arrays can also be passed in,
            in which case they will be displayed as a padded and stacked volume.
        fig: matplotlib figure or Axes to use. If None, a new figure will be created.
        title: title of the figure.
        figsize: size of the figure.
        frames_per_row: number of frames to display in each row. If None, sqrt(firstdim) will be used.
        frame_dim: for higher dimensional arrays, which dimension from (`-1`, `-2`, `-3`) is moved to
            the `-3` dimension. dim and reshape to (-1, H, W) shape to construct frames, default to `-3`.
        channel_dim: if not None, explicitly specify the channel dimension to be transposed to the
            last dimensionas shape (-1, H, W, C). this can be used to plot RGB color image.
            if None, the channel dimension will be flattened with `frame_dim` and `batch_dim` as shape (-1, H, W).
            note that it can only support 3D input image. default is None.
        vmin: `vmin` for the matplotlib `imshow`.
        vmax: `vmax` for the matplotlib `imshow`.
        every_n: factor to subsample the frames so that only every n-th frame is displayed.
        interpolation: interpolation to use for the matplotlib `matshow`.
        show: if True, show the figure.
        fill_value: value to use for the empty part of the grid.
        margin: margin to use for the grid.
        dtype: data type of the output stacked frames.
        kwargs: additional keyword arguments to matplotlib `matshow` and `imshow`.

    See Also:
        - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html
        - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.matshow.html


        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from monai.visualize import matshow3d
        # create a figure of a 3D volume
        >>> volume = np.random.rand(10, 10, 10)
        >>> fig = plt.figure()
        >>> matshow3d(volume, fig=fig, title="3D Volume")
        >>> plt.show()
        # create a figure of a list of channel-first 3D volumes
        >>> volumes = [np.random.rand(1, 10, 10, 10), np.random.rand(1, 10, 10, 10)]
        >>> fig = plt.figure()
        >>> matshow3d(volumes, fig=fig, title="List of Volumes")
        >>> plt.show()

    # request an np.ndarray from `convert_data_type` and keep the first element of its result
    vol = convert_data_type(data=volume, output_type=np.ndarray)[0]
    if channel_dim is not None:
        # only axis 0 or 1 may be the channel axis, and it must hold 1, 3 or 4 channels
        if channel_dim not in [0, 1
                               ] or vol.shape[channel_dim] not in [1, 3, 4]:
            raise ValueError(
                "channel_dim must be: None, 0 or 1, and channels of image must be 1, 3 or 4."

    if isinstance(vol, (list, tuple)):
        # a sequence of channel-first volumes
        if not isinstance(vol[0], np.ndarray):
            raise ValueError("volume must be a list of arrays.")
        # element-wise maximum of the item shapes; its tail (without axis 0) is the pad target
        pad_size = np.max(np.asarray([v.shape for v in vol]), axis=0)
        pad = SpatialPad(
            pad_size[1:])  # assuming channel-first for item in vol
        # padded items are joined along their first axis
        vol = np.concatenate([pad(v) for v in vol], axis=0)
    else:  # ndarray
        while len(vol.shape) < 3:
            vol = np.expand_dims(
                vol, 0)  # type: ignore  # so that we display 2d as well

    # NOTE(review): the two moveaxis/reshape pairs below look like the two arms of
    # an if/else whose `else:` line was lost in this excerpt -- TODO confirm.
    if channel_dim is not None:  # move the expected dim to construct frames with `B` dim
        vol = np.moveaxis(vol, frame_dim, -4)  # type: ignore
        vol = vol.reshape((-1, vol.shape[-3], vol.shape[-2], vol.shape[-1]))
        vol = np.moveaxis(vol, frame_dim, -3)  # type: ignore
        vol = vol.reshape((-1, vol.shape[-2], vol.shape[-1]))
    # fall back to the NaN-ignoring data range when no explicit limits are given
    vmin = np.nanmin(vol) if vmin is None else vmin
    vmax = np.nanmax(vol) if vmax is None else vmax

    # subsample every_n-th frame of the 3D volume
    # max(..., 1) keeps the slice step at 1 or more
    vol = vol[::max(every_n, 1)]
    if not frames_per_row:
        frames_per_row = int(np.ceil(np.sqrt(len(vol))))
    # create the grid of frames
    cols = max(min(len(vol), frames_per_row), 1)
    rows = int(np.ceil(len(vol) / cols))
    # first entry: number of extra frames needed so that rows * cols frames exist
    width = [[0, cols * rows - len(vol)]]
    if channel_dim is not None:
        width += [[0, 0]]  # add pad width for the channel dim
    # two more entries of [margin, margin], one for each of the remaining axes
    width += [[margin, margin]] * 2
    # NOTE(review): the pad-width argument of this call is not visible in the excerpt
    vol = np.pad(vol.astype(dtype, copy=False),
                 constant_values=fill_value)  # type: ignore
    # lay the frames out as `rows` lists of `cols` frames and join them with np.block
    im = np.block([[vol[i * cols + j] for j in range(cols)]
                   for i in range(rows)])
    if channel_dim is not None:
        # move channel dim to the end
        im = np.moveaxis(im, 0, -1)

    # figure related configurations
    # NOTE(review): `else:` and the body of `if not fig.axes:` appear to be missing here
    if isinstance(fig, plt.Axes):
        ax = fig
        if fig is None:
            fig = plt.figure(tight_layout=True)
        if not fig.axes:
        ax = fig.axes[0]
    ax.matshow(im, vmin=vmin, vmax=vmax, interpolation=interpolation, **kwargs)

    # NOTE(review): the bodies of the three conditionals below are not visible
    if title is not None:
    if figsize is not None and hasattr(fig, "set_size_inches"):
    if show:
    return fig, im
예제 #6
# NOTE(review): this excerpt is truncated -- the end of the signature, the
# docstring delimiters, the tail of the loop header, several `if` bodies and the
# `else:` lines are missing; the original lines are kept verbatim.
# Pads the tensor/ndarray items of a batch to a common shape (using `SpatialPadd`
# for dict items, `SpatialPad` otherwise) and then calls `list_data_collate`.
def pad_list_data_collate(
    batch: Sequence,
    method: Union[Method, str] = Method.SYMMETRIC,
    mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
    Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest
    tensor in each dimension.

        Need to use this collate if apply some transforms that can generate batch data.

        batch: batch of data to pad-collate
        method: padding method (see :py:class:`monai.transforms.SpatialPad`)
        mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)
    # the first element decides whether items are addressed by key or by index
    list_of_dicts = isinstance(batch[0], dict)
    for key_or_idx in batch[0].keys() if list_of_dicts else range(len(
        max_shapes = []
        for elem in batch:
            if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)):
        # len > 0 if objects were arrays
        if len(max_shapes) == 0:
        # element-wise maximum over all collected shapes
        max_shape = np.array(max_shapes).max(axis=0)
        # If all same size, skip
        if np.all(np.array(max_shapes).min(axis=0) == max_shape):
        # Do we need to convert output to Tensor?
        output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)

        # Use `SpatialPadd` or `SpatialPad` to match sizes
        # Default params are central padding, padding with 0's
        # If input is dictionary, use the dictionary version so that the transformation is recorded
        padder: Union[SpatialPadd, SpatialPad]
        # NOTE(review): the `else:` separating the two imports/constructions below
        # is not visible in this excerpt
        if list_of_dicts:
            from monai.transforms.croppad.dictionary import SpatialPadd  # needs to be here to avoid circular import

            padder = SpatialPadd(key_or_idx, max_shape, method,
                                 mode)  # type: ignore

            from monai.transforms.croppad.array import SpatialPad  # needs to be here to avoid circular import

            padder = SpatialPad(max_shape, method, mode)  # type: ignore

        for idx in range(len(batch)):
            # dict version: pad the whole item, then pick out the key
            padded = padder(
                batch[idx])[key_or_idx] if list_of_dicts else padder(
            # since tuple is immutable we'll have to recreate
            if isinstance(batch[idx], tuple):
                batch[idx] = list(batch[idx])  # type: ignore
                batch[idx][key_or_idx] = padded
                batch[idx] = tuple(batch[idx])  # type: ignore
            # else, replace
                # NOTE(review): this line calls the padder again instead of reusing
                # `padded`; presumably it sits under a lost `else:` -- TODO confirm
                batch[idx][key_or_idx] = padder(batch[idx])[key_or_idx]

            if output_to_tensor:
                batch[idx][key_or_idx] = torch.Tensor(batch[idx][key_or_idx])

    # After padding, use default list collator
    return list_data_collate(batch)
예제 #7
# NOTE(review): this excerpt is truncated -- `**kwargs` is used in the body but is
# not in the visible signature, the docstring delimiters are missing, and several
# calls/brackets are left unclosed; the original lines are kept verbatim.
# Pads either one batched tensor or a list of image tensors and returns the
# padded tensor together with the original spatial size of each image.
def pad_images(
    input_images: Union[List[Tensor], Tensor],
    spatial_dims: int,
    size_divisible: Union[int, Sequence[int]],
    mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
) -> Tuple[Tensor, List[List[int]]]:
    Pad the input images, so that the output spatial sizes are divisible by `size_divisible`.
    It pads them at the end to create a (B, C, H, W) or (B, C, H, W, D) Tensor.
    Padded size (H, W) or (H, W, D) is divisible by size_divisible.
    Default padding uses constant padding with value 0.0

        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),
            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).
        spatial_dims: number of spatial dimensions of the images, 2D or 3D.
        size_divisible: int or Sequence[int], is the expected pattern on the input image shape.
            If an int, the same `size_divisible` will be applied to all the input spatial dimensions.
        mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        kwargs: other arguments for `torch.pad` function.

        - images, a (B, C, H, W) or (B, C, H, W, D) Tensor
        - image_sizes, the original spatial size of each image
    size_divisible = ensure_tuple_rep(size_divisible, spatial_dims)

    # If input_images: Tensor
    if isinstance(input_images, Tensor):
        # the trailing `spatial_dims` axes are treated as the spatial shape
        orig_size = list(input_images.shape[-spatial_dims:])
        new_size = compute_divisible_spatial_size(spatial_shape=orig_size,
        # per axis: nothing before, and a non-negative amount after
        all_pad_width = [(0, max(sp_i - orig_size[i], 0))
                         for i, sp_i in enumerate(new_size)]
        # NOTE(review): the closing part of this comprehension is not visible here
        pt_pad_width = [
            val for sublist in all_pad_width for val in sublist[::-1]
        if max(pt_pad_width) == 0:
            # if there is no need to pad
            return input_images, [orig_size] * input_images.shape[0]
        mode_: str = convert_pad_mode(dst=input_images, mode=mode)
        # the same original size is reported once per item along axis 0
        return F.pad(input_images, pt_pad_width, mode=mode_,
                     **kwargs), [orig_size] * input_images.shape[0]

    # If input_images: List[Tensor])
    image_sizes = [img.shape[-spatial_dims:] for img in input_images]
    # channel count, dtype and device are all read from the first image
    in_channels = input_images[0].shape[0]
    dtype = input_images[0].dtype
    device = input_images[0].device

    # compute max_spatial_size
    image_sizes_t = torch.tensor(image_sizes)
    max_spatial_size_t, _ = torch.max(image_sizes_t, dim=0)

    if len(max_spatial_size_t) != spatial_dims or len(
            size_divisible) != spatial_dims:
        raise ValueError(
            " Require len(max_spatial_size_t) == spatial_dims ==len(size_divisible)."

    max_spatial_size = compute_divisible_spatial_size(
        spatial_shape=list(max_spatial_size_t), k=size_divisible)

    # allocate memory for the padded images
    images = torch.zeros([len(image_sizes), in_channels] + max_spatial_size,

    # Use `SpatialPad` to match sizes, padding in the end will not affect boxes
    padder = SpatialPad(spatial_size=max_spatial_size,
    for idx, img in enumerate(input_images):
        # each padded image is written into its slot of the preallocated tensor
        images[idx, ...] = padder(img)

    # sizes are returned as plain lists, one per input image
    return images, [list(ss) for ss in image_sizes]