Example #1
    def __call__(self, batch: Any):
        # data is either list of dicts or list of lists
        is_list_of_dicts = isinstance(batch[0], dict)
        # loop over items inside of each element in a batch
        for key_or_idx in batch[0].keys() if is_list_of_dicts else range(len(batch[0])):
            # calculate max size of each dimension
            max_shapes = []
            for elem in batch:
                if not isinstance(elem[key_or_idx],
                                  (torch.Tensor, np.ndarray)):
                    break
                max_shapes.append(elem[key_or_idx].shape[1:])
            # len > 0 if objects were arrays, else skip as no padding to be done
            if len(max_shapes) == 0:
                continue
            max_shape = np.array(max_shapes).max(axis=0)
            # If all same size, skip
            if np.all(np.array(max_shapes).min(axis=0) == max_shape):
                continue
            # Do we need to convert output to Tensor?
            output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)

            # Use `SpatialPad` to match sizes
            # Default params are central padding, padding with 0's
            # For dict input, the padding is recorded below via `push_transform`
            padder = SpatialPad(max_shape, self.method, self.mode)  # type: ignore
            transform = padder if not output_to_tensor else Compose([padder, ToTensor()])

            for idx in range(len(batch)):
                im = batch[idx][key_or_idx]
                orig_size = im.shape[1:]
                padded = transform(batch[idx][key_or_idx])
                batch = replace_element(padded, batch, idx, key_or_idx)

                # For dict data, record the applied padding so it can be inverted later
                if is_list_of_dicts:
                    self.push_transform(batch[idx],
                                        key_or_idx,
                                        orig_size=orig_size)

        # After padding, use default list collator
        return list_data_collate(batch)
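A minimal usage sketch for this collate method, assuming it belongs to MONAI's `PadListDataCollate` (output container types may vary across MONAI versions):

import torch
from monai.transforms import PadListDataCollate

# two channel-first images (C, H, W) with different spatial sizes
batch = [
    {"image": torch.zeros(1, 10, 10)},
    {"image": torch.zeros(1, 12, 8)},
]
collate = PadListDataCollate()  # defaults: central padding, pad value 0
collated = collate(batch)
print(collated["image"].shape)  # expected: torch.Size([2, 1, 12, 10])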
Example #2
 def __init__(self,
              keys,
              spatial_size,
              method='symmetric',
              mode='constant'):
     """
     Args:
         keys (hashable items): keys of the corresponding items to be transformed.
             See also: :py:class:`monai.transforms.compose.MapTransform`
         spatial_size (list): the spatial size of output data after padding.
          method (str): pad the image symmetrically on every side, or only pad at the end of each dimension. Default is 'symmetric'.
         mode (str): one of the following string values or a user supplied function: {'constant', 'edge',
             'linear_ramp', 'maximum', 'mean', 'median', 'minimum', 'reflect', 'symmetric',
             'wrap', 'empty', <function>}
             for more details, please check: https://docs.scipy.org/doc/numpy/reference/generated/numpy.pad.html
     """
     super().__init__(keys)
     self.padder = SpatialPad(spatial_size, method, mode)
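A short usage sketch for the dictionary transform initialised above (MONAI's `SpatialPadd`); the shapes are illustrative and assume a channel-first NumPy input:

import numpy as np
from monai.transforms import SpatialPadd

pad = SpatialPadd(keys=["image"], spatial_size=[32, 32])
data = {"image": np.zeros((1, 28, 28))}  # channel-first (C, H, W)
padded = pad(data)
print(padded["image"].shape)  # (1, 32, 32), centrally padded with 0's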
Example #3
    def __call__(self, batch: Any):
        """
        Args:
            batch: batch of data to pad-collate
        """
        # data is either list of dicts or list of lists
        is_list_of_dicts = isinstance(batch[0], dict)
        # loop over items inside of each element in a batch
        batch_item = tuple(batch[0].keys()) if is_list_of_dicts else range(
            len(batch[0]))
        for key_or_idx in batch_item:
            # calculate max size of each dimension
            max_shapes = []
            for elem in batch:
                if not isinstance(elem[key_or_idx],
                                  (torch.Tensor, np.ndarray)):
                    break
                max_shapes.append(elem[key_or_idx].shape[1:])
            # len > 0 if objects were arrays, else skip as no padding to be done
            if not max_shapes:
                continue
            max_shape = np.array(max_shapes).max(axis=0)
            # If all same size, skip
            if np.all(np.array(max_shapes).min(axis=0) == max_shape):
                continue

            # Use `SpatialPad` to match sizes; default params are central padding with 0's
            padder = SpatialPad(spatial_size=max_shape,
                                method=self.method,
                                mode=self.mode,
                                **self.kwargs)
            for idx, batch_i in enumerate(batch):
                orig_size = batch_i[key_or_idx].shape[1:]
                padded = padder(batch_i[key_or_idx])
                batch = replace_element(padded, batch, idx, key_or_idx)

                # For dict data, record the applied padding so it can be inverted later
                if is_list_of_dicts:
                    self.push_transform(batch[idx],
                                        key_or_idx,
                                        orig_size=orig_size)

        # After padding, use default list collator
        return list_data_collate(batch)
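This variant handles list-style batch items as well as dicts. A hedged sketch, assuming this `__call__` is from a newer MONAI `PadListDataCollate`:

import torch
from monai.transforms import PadListDataCollate

# (image, label) pairs where the images differ in spatial size
batch = [
    [torch.zeros(1, 10, 10), 0],
    [torch.zeros(1, 12, 8), 1],
]
collated = PadListDataCollate()(batch)
print(collated[0].shape)  # expected: torch.Size([2, 1, 12, 10])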
Example #4
def matshow3d(
    volume,
    fig=None,
    title: Optional[str] = None,
    figsize=(10, 10),
    frames_per_row: Optional[int] = None,
    frame_dim: int = -3,
    channel_dim: Optional[int] = None,
    vmin=None,
    vmax=None,
    every_n: int = 1,
    interpolation: str = "none",
    show=False,
    fill_value=np.nan,
    margin: int = 1,
    dtype=np.float32,
    **kwargs,
):
    """
    Create a 3D volume figure as a grid of images.

    Args:
        volume: 3D volume to display. data shape can be `BCHWD`, `CHWD` or `HWD`.
            Higher dimensional arrays will be reshaped into (-1, H, W, [C]), `C` depends on `channel_dim` arg.
            A list of channel-first (C, H[, W, D]) arrays can also be passed in,
            in which case they will be displayed as a padded and stacked volume.
        fig: matplotlib figure or Axes to use. If None, a new figure will be created.
        title: title of the figure.
        figsize: size of the figure.
        frames_per_row: number of frames to display in each row. If None, sqrt(number of frames) will be used.
        frame_dim: for higher dimensional arrays, which dimension (`-1`, `-2`, or `-3`) is moved to
            the `-3` position and then flattened together with the leading dimensions into `(-1, H, W)`
            frames. Defaults to `-3`.
        channel_dim: if not None, explicitly specify the channel dimension to be transposed to the
            last dimension as shape (-1, H, W, C). This can be used to plot RGB color images.
            If None, the channel dimension will be flattened with `frame_dim` and `batch_dim` as shape (-1, H, W).
            Note that this option only supports 3D input images. Default is None.
        vmin: `vmin` for the matplotlib `imshow`.
        vmax: `vmax` for the matplotlib `imshow`.
        every_n: factor to subsample the frames so that only every n-th frame is displayed.
        interpolation: interpolation to use for the matplotlib `matshow`.
        show: if True, show the figure.
        fill_value: value to use for the empty part of the grid.
        margin: margin to use for the grid.
        dtype: data type of the output stacked frames.
        kwargs: additional keyword arguments to matplotlib `matshow` and `imshow`.

    See Also:
        - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html
        - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.matshow.html

    Example:

        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from monai.visualize import matshow3d
        # create a figure of a 3D volume
        >>> volume = np.random.rand(10, 10, 10)
        >>> fig = plt.figure()
        >>> matshow3d(volume, fig=fig, title="3D Volume")
        >>> plt.show()
        # create a figure of a list of channel-first 3D volumes
        >>> volumes = [np.random.rand(1, 10, 10, 10), np.random.rand(1, 10, 10, 10)]
        >>> fig = plt.figure()
        >>> matshow3d(volumes, fig=fig, title="List of Volumes")
        >>> plt.show()

    """
    vol = convert_data_type(data=volume, output_type=np.ndarray)[0]
    if channel_dim is not None:
        if channel_dim not in [0, 1] or vol.shape[channel_dim] not in [1, 3, 4]:
            raise ValueError("channel_dim must be: None, 0 or 1, and channels of image must be 1, 3 or 4.")

    if isinstance(vol, (list, tuple)):
        # a sequence of channel-first volumes
        if not isinstance(vol[0], np.ndarray):
            raise ValueError("volume must be a list of arrays.")
        pad_size = np.max(np.asarray([v.shape for v in vol]), axis=0)
        pad = SpatialPad(
            pad_size[1:])  # assuming channel-first for item in vol
        vol = np.concatenate([pad(v) for v in vol], axis=0)
    else:  # ndarray
        while len(vol.shape) < 3:
            vol = np.expand_dims(
                vol, 0)  # type: ignore  # so that we display 2d as well

    if channel_dim is not None:  # move the expected dim to construct frames with `B` dim
        vol = np.moveaxis(vol, frame_dim, -4)  # type: ignore
        vol = vol.reshape((-1, vol.shape[-3], vol.shape[-2], vol.shape[-1]))
    else:
        vol = np.moveaxis(vol, frame_dim, -3)  # type: ignore
        vol = vol.reshape((-1, vol.shape[-2], vol.shape[-1]))
    vmin = np.nanmin(vol) if vmin is None else vmin
    vmax = np.nanmax(vol) if vmax is None else vmax

    # subsample every_n-th frame of the 3D volume
    vol = vol[::max(every_n, 1)]
    if not frames_per_row:
        frames_per_row = int(np.ceil(np.sqrt(len(vol))))
    # create the grid of frames
    cols = max(min(len(vol), frames_per_row), 1)
    rows = int(np.ceil(len(vol) / cols))
    width = [[0, cols * rows - len(vol)]]
    if channel_dim is not None:
        width += [[0, 0]]  # add pad width for the channel dim
    width += [[margin, margin]] * 2
    vol = np.pad(vol.astype(dtype, copy=False),
                 width,
                 mode="constant",
                 constant_values=fill_value)  # type: ignore
    im = np.block([[vol[i * cols + j] for j in range(cols)]
                   for i in range(rows)])
    if channel_dim is not None:
        # move channel dim to the end
        im = np.moveaxis(im, 0, -1)

    # figure related configurations
    if isinstance(fig, plt.Axes):
        ax = fig
    else:
        if fig is None:
            fig = plt.figure(tight_layout=True)
        if not fig.axes:
            fig.add_subplot(111)
        ax = fig.axes[0]
    ax.matshow(im, vmin=vmin, vmax=vmax, interpolation=interpolation, **kwargs)
    ax.axis("off")

    if title is not None:
        ax.set_title(title)
    if figsize is not None and hasattr(fig, "set_size_inches"):
        fig.set_size_inches(figsize)
    if show:
        plt.show()
    return fig, im
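The docstring above already demonstrates grayscale volumes; this hedged sketch exercises the `channel_dim` path for an RGB volume (shapes are illustrative, and `fill_value=0` is used since NaN is not meaningful in RGB data):

import numpy as np
import matplotlib.pyplot as plt
from monai.visualize import matshow3d

# channel-first RGB volume: (C, D, H, W) with C == 3
rgb_volume = np.random.rand(3, 8, 10, 10)
fig, grid = matshow3d(rgb_volume, channel_dim=0, fill_value=0, title="RGB frames")
print(grid.shape)  # one (H', W', 3) image containing all frames
plt.show()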
Example #5
def pad_list_data_collate(
    batch: Sequence,
    method: Union[Method, str] = Method.SYMMETRIC,
    mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
):
    """
    Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest
    tensor in each dimension.

    Note:
        Use this collate function when the applied transforms may generate batch items
        of differing shapes, so that they can be padded before the default collation.

    Args:
        batch: batch of data to pad-collate
        method: padding method (see :py:class:`monai.transforms.SpatialPad`)
        mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)
    """
    list_of_dicts = isinstance(batch[0], dict)
    for key_or_idx in batch[0].keys() if list_of_dicts else range(len(batch[0])):
        max_shapes = []
        for elem in batch:
            if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)):
                break
            max_shapes.append(elem[key_or_idx].shape[1:])
        # len > 0 if objects were arrays
        if len(max_shapes) == 0:
            continue
        max_shape = np.array(max_shapes).max(axis=0)
        # If all same size, skip
        if np.all(np.array(max_shapes).min(axis=0) == max_shape):
            continue
        # Do we need to convert output to Tensor?
        output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)

        # Use `SpatialPadd` or `SpatialPad` to match sizes
        # Default params are central padding, padding with 0's
        # If input is dictionary, use the dictionary version so that the transformation is recorded
        padder: Union[SpatialPadd, SpatialPad]
        if list_of_dicts:
            from monai.transforms.croppad.dictionary import SpatialPadd  # needs to be here to avoid circular import

            padder = SpatialPadd(key_or_idx, max_shape, method,
                                 mode)  # type: ignore

        else:
            from monai.transforms.croppad.array import SpatialPad  # needs to be here to avoid circular import

            padder = SpatialPad(max_shape, method, mode)  # type: ignore

        for idx in range(len(batch)):
            padded = padder(batch[idx])[key_or_idx] if list_of_dicts else padder(batch[idx][key_or_idx])
            # since tuple is immutable we'll have to recreate
            if isinstance(batch[idx], tuple):
                batch[idx] = list(batch[idx])  # type: ignore
                batch[idx][key_or_idx] = padded
                batch[idx] = tuple(batch[idx])  # type: ignore
            # else, replace with the already-padded value
            else:
                batch[idx][key_or_idx] = padded

            if output_to_tensor:
                batch[idx][key_or_idx] = torch.Tensor(batch[idx][key_or_idx])

    # After padding, use default list collator
    return list_data_collate(batch)
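A hedged sketch of the intended use, assuming this is `pad_list_data_collate` from `monai.data`: pass it as the `collate_fn` of a PyTorch DataLoader so that differently sized samples can still be batched:

import torch
from torch.utils.data import DataLoader
from monai.data import pad_list_data_collate

dataset = [{"image": torch.zeros(1, 10, 10)},
           {"image": torch.zeros(1, 12, 8)}]
loader = DataLoader(dataset, batch_size=2, collate_fn=pad_list_data_collate)
print(next(iter(loader))["image"].shape)  # expected: torch.Size([2, 1, 12, 10])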
Example #6
def pad_images(
    input_images: Union[List[Tensor], Tensor],
    spatial_dims: int,
    size_divisible: Union[int, Sequence[int]],
    mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    **kwargs,
) -> Tuple[Tensor, List[List[int]]]:
    """
    Pad the input images, so that the output spatial sizes are divisible by `size_divisible`.
    It pads them at the end to create a (B, C, H, W) or (B, C, H, W, D) Tensor.
    Padded size (H, W) or (H, W, D) is divisible by size_divisible.
    Default padding uses constant padding with value 0.0.

    Args:
        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),
            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).
        spatial_dims: number of spatial dimensions of the images, 2D or 3D.
        size_divisible: int or Sequence[int], is the expected pattern on the input image shape.
            If an int, the same `size_divisible` will be applied to all the input spatial dimensions.
        mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        kwargs: other arguments for the `torch.nn.functional.pad` function.

    Return:
        - images, a (B, C, H, W) or (B, C, H, W, D) Tensor
        - image_sizes, the original spatial size of each image
    """
    size_divisible = ensure_tuple_rep(size_divisible, spatial_dims)

    # If input_images: Tensor
    if isinstance(input_images, Tensor):
        orig_size = list(input_images.shape[-spatial_dims:])
        new_size = compute_divisible_spatial_size(spatial_shape=orig_size,
                                                  k=size_divisible)
        all_pad_width = [(0, max(sp_i - orig_size[i], 0))
                         for i, sp_i in enumerate(new_size)]
        pt_pad_width = [
            val for sublist in all_pad_width for val in sublist[::-1]
        ][::-1]
        if max(pt_pad_width) == 0:
            # if there is no need to pad
            return input_images, [orig_size] * input_images.shape[0]
        mode_: str = convert_pad_mode(dst=input_images, mode=mode)
        return F.pad(input_images, pt_pad_width, mode=mode_,
                     **kwargs), [orig_size] * input_images.shape[0]

    # If input_images: List[Tensor])
    image_sizes = [img.shape[-spatial_dims:] for img in input_images]
    in_channels = input_images[0].shape[0]
    dtype = input_images[0].dtype
    device = input_images[0].device

    # compute max_spatial_size
    image_sizes_t = torch.tensor(image_sizes)
    max_spatial_size_t, _ = torch.max(image_sizes_t, dim=0)

    if len(max_spatial_size_t) != spatial_dims or len(size_divisible) != spatial_dims:
        raise ValueError("Require len(max_spatial_size_t) == spatial_dims == len(size_divisible).")

    max_spatial_size = compute_divisible_spatial_size(
        spatial_shape=list(max_spatial_size_t), k=size_divisible)

    # allocate memory for the padded images
    images = torch.zeros([len(image_sizes), in_channels] + max_spatial_size,
                         dtype=dtype,
                         device=device)

    # Use `SpatialPad` to match sizes, padding in the end will not affect boxes
    padder = SpatialPad(spatial_size=max_spatial_size,
                        method="end",
                        mode=mode,
                        **kwargs)
    for idx, img in enumerate(input_images):
        images[idx, ...] = padder(img)

    return images, [list(ss) for ss in image_sizes]
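A usage sketch, assuming this is `pad_images` from MONAI's detection utilities (`monai.apps.detection.utils.detector_utils`); the shapes below are illustrative:

import torch
from monai.apps.detection.utils.detector_utils import pad_images

# two 2D images (C, H_i, W_i) with different spatial sizes
imgs = [torch.zeros(1, 30, 25), torch.zeros(1, 27, 32)]
batch, image_sizes = pad_images(imgs, spatial_dims=2, size_divisible=16)
print(batch.shape)   # torch.Size([2, 1, 32, 32]): max size rounded up to a multiple of 16
print(image_sizes)   # [[30, 25], [27, 32]]: the original spatial sizes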