def __call__(self, batch: Any):
    # data is either list of dicts or list of lists
    is_list_of_dicts = isinstance(batch[0], dict)
    # loop over items inside of each element in a batch
    for key_or_idx in batch[0].keys() if is_list_of_dicts else range(len(batch[0])):
        # calculate max size of each dimension
        max_shapes = []
        for elem in batch:
            if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)):
                break
            max_shapes.append(elem[key_or_idx].shape[1:])
        # len > 0 if objects were arrays, else skip as no padding to be done
        if len(max_shapes) == 0:
            continue
        max_shape = np.array(max_shapes).max(axis=0)
        # If all same size, skip
        if np.all(np.array(max_shapes).min(axis=0) == max_shape):
            continue

        # Do we need to convert output to Tensor?
        output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)

        # Use `SpatialPad` to match sizes; default params are central padding with 0's.
        # For dictionary data, the pad is recorded below via `push_transform` so it can be inverted.
        padder = SpatialPad(max_shape, self.method, self.mode)  # type: ignore
        transform = padder if not output_to_tensor else Compose([padder, ToTensor()])

        for idx in range(len(batch)):
            im = batch[idx][key_or_idx]
            orig_size = im.shape[1:]
            padded = transform(batch[idx][key_or_idx])
            batch = replace_element(padded, batch, idx, key_or_idx)

            # If we have a dictionary of data, record the applied padding for later inversion
            if is_list_of_dicts:
                self.push_transform(batch[idx], key_or_idx, orig_size=orig_size)

    # After padding, use default list collator
    return list_data_collate(batch)
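# Example (editor's sketch, not part of the source): the collate method above is
# typically wired into a DataLoader via `collate_fn`. This assumes MONAI exports
# `PadListDataCollate` from `monai.transforms`; the two-sample dataset is synthetic.
import numpy as np
from torch.utils.data import DataLoader
from monai.transforms import PadListDataCollate

data = [{"img": np.zeros((1, 8, 8))}, {"img": np.zeros((1, 10, 12))}]  # differing spatial sizes
loader = DataLoader(data, batch_size=2, collate_fn=PadListDataCollate())
batch = next(iter(loader))
print(batch["img"].shape)  # each item padded to (1, 10, 12), then stacked -> (2, 1, 10, 12)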
def __init__(self, keys, spatial_size, method='symmetric', mode='constant'):
    """
    Args:
        keys (hashable items): keys of the corresponding items to be transformed.
            See also: :py:class:`monai.transforms.compose.MapTransform`
        spatial_size (list): the spatial size of output data after padding.
        method (str): pad the image symmetrically on every side, or only pad at the end sides.
            default is 'symmetric'.
        mode (str): one of the following string values or a user supplied function:
            {'constant', 'edge', 'linear_ramp', 'maximum', 'mean', 'median', 'minimum',
            'reflect', 'symmetric', 'wrap', 'empty', <function>}
            for more details, please check:
            https://docs.scipy.org/doc/numpy/reference/generated/numpy.pad.html
    """
    super().__init__(keys)
    self.padder = SpatialPad(spatial_size, method, mode)
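# Example (editor's sketch): applying the dictionary-based pad transform above to a
# channel-first array. Assumes `SpatialPadd` is exported from `monai.transforms`;
# the input array is synthetic.
import numpy as np
from monai.transforms import SpatialPadd

padder = SpatialPadd(keys=["image"], spatial_size=[16, 16])
out = padder({"image": np.ones((1, 10, 12))})
print(out["image"].shape)  # spatial dims padded to (16, 16), channel untouched -> (1, 16, 16)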
def __call__(self, batch: Any):
    """
    Args:
        batch: batch of data to pad-collate
    """
    # data is either list of dicts or list of lists
    is_list_of_dicts = isinstance(batch[0], dict)
    # loop over items inside of each element in a batch
    batch_item = tuple(batch[0].keys()) if is_list_of_dicts else range(len(batch[0]))
    for key_or_idx in batch_item:
        # calculate max size of each dimension
        max_shapes = []
        for elem in batch:
            if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)):
                break
            max_shapes.append(elem[key_or_idx].shape[1:])
        # len > 0 if objects were arrays, else skip as no padding to be done
        if not max_shapes:
            continue
        max_shape = np.array(max_shapes).max(axis=0)
        # If all same size, skip
        if np.all(np.array(max_shapes).min(axis=0) == max_shape):
            continue

        # Use `SpatialPad` to match sizes; default params are central padding, padding with 0's
        padder = SpatialPad(spatial_size=max_shape, method=self.method, mode=self.mode, **self.kwargs)

        for idx, batch_i in enumerate(batch):
            orig_size = batch_i[key_or_idx].shape[1:]
            padded = padder(batch_i[key_or_idx])
            batch = replace_element(padded, batch, idx, key_or_idx)

            # If we have a dictionary of data, record the applied padding for later inversion
            if is_list_of_dicts:
                self.push_transform(batch[idx], key_or_idx, orig_size=orig_size)

    # After padding, use default list collator
    return list_data_collate(batch)
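# Example (editor's sketch): because the dictionary path above records each pad via
# `push_transform`, padding can be undone per-item after the batch has been processed.
# MONAI exposes this as `PadListDataCollate.inverse` together with `decollate_batch`;
# treat the exact calls below as assumptions about the installed version.
import numpy as np
from monai.data import decollate_batch
from monai.transforms import PadListDataCollate

collate = PadListDataCollate()
batch = collate([{"img": np.zeros((1, 8, 8))}, {"img": np.zeros((1, 10, 12))}])
for item in decollate_batch(batch):
    print(PadListDataCollate.inverse(item)["img"].shape)  # back to (1, 8, 8) and (1, 10, 12)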
def matshow3d(
    volume,
    fig=None,
    title: Optional[str] = None,
    figsize=(10, 10),
    frames_per_row: Optional[int] = None,
    frame_dim: int = -3,
    channel_dim: Optional[int] = None,
    vmin=None,
    vmax=None,
    every_n: int = 1,
    interpolation: str = "none",
    show=False,
    fill_value=np.nan,
    margin: int = 1,
    dtype=np.float32,
    **kwargs,
):
    """
    Create a 3D volume figure as a grid of images.

    Args:
        volume: 3D volume to display. data shape can be `BCHWD`, `CHWD` or `HWD`.
            Higher dimensional arrays will be reshaped into (-1, H, W, [C]), `C` depends on `channel_dim` arg.
            A list of channel-first (C, H[, W, D]) arrays can also be passed in,
            in which case they will be displayed as a padded and stacked volume.
        fig: matplotlib figure or Axes to use. If None, a new figure will be created.
        title: title of the figure.
        figsize: size of the figure.
        frames_per_row: number of frames to display in each row. If None, sqrt(firstdim) will be used.
        frame_dim: for higher dimensional arrays, which dimension from (`-1`, `-2`, `-3`) is moved to the
            `-3` dimension and reshaped into (-1, H, W) frames. default to `-3`.
        channel_dim: if not None, explicitly specify the channel dimension to be transposed to the last
            dimension as shape (-1, H, W, C). this can be used to plot RGB color image.
            if None, the channel dimension will be flattened with `frame_dim` and `batch_dim` as shape (-1, H, W).
            note that it can only support 3D input image. default is None.
        vmin: `vmin` for the matplotlib `imshow`.
        vmax: `vmax` for the matplotlib `imshow`.
        every_n: factor to subsample the frames so that only every n-th frame is displayed.
        interpolation: interpolation to use for the matplotlib `matshow`.
        show: if True, show the figure.
        fill_value: value to use for the empty part of the grid.
        margin: margin to use for the grid.
        dtype: data type of the output stacked frames.
        kwargs: additional keyword arguments to matplotlib `matshow` and `imshow`.

    See Also:
        - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html
        - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.matshow.html

    Example:

        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from monai.visualize import matshow3d
        # create a figure of a 3D volume
        >>> volume = np.random.rand(10, 10, 10)
        >>> fig = plt.figure()
        >>> matshow3d(volume, fig=fig, title="3D Volume")
        >>> plt.show()
        # create a figure of a list of channel-first 3D volumes
        >>> volumes = [np.random.rand(1, 10, 10, 10), np.random.rand(1, 10, 10, 10)]
        >>> fig = plt.figure()
        >>> matshow3d(volumes, fig=fig, title="List of Volumes")
        >>> plt.show()

    """
    vol = convert_data_type(data=volume, output_type=np.ndarray)[0]
    if channel_dim is not None:
        if channel_dim not in [0, 1] or vol.shape[channel_dim] not in [1, 3, 4]:
            raise ValueError("channel_dim must be: None, 0 or 1, and channels of image must be 1, 3 or 4.")

    if isinstance(vol, (list, tuple)):
        # a sequence of channel-first volumes
        if not isinstance(vol[0], np.ndarray):
            raise ValueError("volume must be a list of arrays.")
        pad_size = np.max(np.asarray([v.shape for v in vol]), axis=0)
        pad = SpatialPad(pad_size[1:])  # assuming channel-first for item in vol
        vol = np.concatenate([pad(v) for v in vol], axis=0)
    else:  # ndarray
        while len(vol.shape) < 3:
            vol = np.expand_dims(vol, 0)  # type: ignore  # so that we display 2d as well

    if channel_dim is not None:
        # move the expected dim to construct frames with `B` dim
        vol = np.moveaxis(vol, frame_dim, -4)  # type: ignore
        vol = vol.reshape((-1, vol.shape[-3], vol.shape[-2], vol.shape[-1]))
    else:
        vol = np.moveaxis(vol, frame_dim, -3)  # type: ignore
        vol = vol.reshape((-1, vol.shape[-2], vol.shape[-1]))
    vmin = np.nanmin(vol) if vmin is None else vmin
    vmax = np.nanmax(vol) if vmax is None else vmax

    # subsample every_n-th frame of the 3D volume
    vol = vol[:: max(every_n, 1)]
    if not frames_per_row:
        frames_per_row = int(np.ceil(np.sqrt(len(vol))))
    # create the grid of frames
    cols = max(min(len(vol), frames_per_row), 1)
    rows = int(np.ceil(len(vol) / cols))
    width = [[0, cols * rows - len(vol)]]
    if channel_dim is not None:
        width += [[0, 0]]  # add pad width for the channel dim
    width += [[margin, margin]] * 2
    vol = np.pad(vol.astype(dtype, copy=False), width, mode="constant", constant_values=fill_value)  # type: ignore
    im = np.block([[vol[i * cols + j] for j in range(cols)] for i in range(rows)])
    if channel_dim is not None:
        # move channel dim to the end
        im = np.moveaxis(im, 0, -1)

    # figure related configurations
    if isinstance(fig, plt.Axes):
        ax = fig
    else:
        if fig is None:
            fig = plt.figure(tight_layout=True)
        if not fig.axes:
            fig.add_subplot(111)
        ax = fig.axes[0]
    ax.matshow(im, vmin=vmin, vmax=vmax, interpolation=interpolation, **kwargs)
    ax.axis("off")

    if title is not None:
        ax.set_title(title)
    if figsize is not None and hasattr(fig, "set_size_inches"):
        fig.set_size_inches(figsize)
    if show:
        plt.show()
    return fig, im
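# Example (editor's sketch): a higher-dimensional call to `matshow3d` beyond the
# docstring examples, subsampling every 2nd frame of a channel-first CHWD volume.
# Shapes and values are illustrative only.
import numpy as np
import matplotlib.pyplot as plt
from monai.visualize import matshow3d

vol = np.random.rand(2, 16, 16, 16)  # CHWD: channel and depth are flattened into frames
fig, im_grid = matshow3d(vol, frame_dim=-3, every_n=2, fill_value=0, show=False)
print(im_grid.shape)  # the assembled frame grid, returned as an ndarray
plt.close(fig)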
def pad_list_data_collate(
    batch: Sequence,
    method: Union[Method, str] = Method.SYMMETRIC,
    mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
):
    """
    Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the
    biggest tensor in each dimension.

    Note:
        This collate function is needed if some of the applied transforms can generate batch data.

    Args:
        batch: batch of data to pad-collate
        method: padding method (see :py:class:`monai.transforms.SpatialPad`)
        mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)
    """
    list_of_dicts = isinstance(batch[0], dict)
    for key_or_idx in batch[0].keys() if list_of_dicts else range(len(batch[0])):
        max_shapes = []
        for elem in batch:
            if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)):
                break
            max_shapes.append(elem[key_or_idx].shape[1:])
        # len > 0 if objects were arrays
        if len(max_shapes) == 0:
            continue
        max_shape = np.array(max_shapes).max(axis=0)
        # If all same size, skip
        if np.all(np.array(max_shapes).min(axis=0) == max_shape):
            continue

        # Do we need to convert output to Tensor?
        output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)

        # Use `SpatialPadd` or `SpatialPad` to match sizes
        # Default params are central padding, padding with 0's
        # If input is dictionary, use the dictionary version so that the transformation is recorded
        padder: Union[SpatialPadd, SpatialPad]
        if list_of_dicts:
            from monai.transforms.croppad.dictionary import SpatialPadd  # needs to be here to avoid circular import

            padder = SpatialPadd(key_or_idx, max_shape, method, mode)  # type: ignore
        else:
            from monai.transforms.croppad.array import SpatialPad  # needs to be here to avoid circular import

            padder = SpatialPad(max_shape, method, mode)  # type: ignore

        for idx in range(len(batch)):
            padded = padder(batch[idx])[key_or_idx] if list_of_dicts else padder(batch[idx][key_or_idx])
            if output_to_tensor:
                # convert before re-inserting so that tuple elements are handled correctly below
                padded = torch.Tensor(padded)
            # since tuple is immutable we'll have to recreate it
            if isinstance(batch[idx], tuple):
                batch[idx] = list(batch[idx])  # type: ignore
                batch[idx][key_or_idx] = padded
                batch[idx] = tuple(batch[idx])  # type: ignore
            # else, replace in place
            else:
                batch[idx][key_or_idx] = padded

    # After padding, use default list collator
    return list_data_collate(batch)
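# Example (editor's sketch): the functional collate above also handles tuple samples,
# e.g. (image, label) pairs. Assumes it is importable from `monai.data.utils`;
# the data below is synthetic and the names are illustrative.
import numpy as np
from torch.utils.data import DataLoader
from monai.data.utils import pad_list_data_collate

pairs = [(np.zeros((1, 6, 6)), 0), (np.zeros((1, 9, 7)), 1)]
loader = DataLoader(pairs, batch_size=2, collate_fn=pad_list_data_collate)
imgs, labels = next(iter(loader))
print(imgs.shape)  # images padded to a common (1, 9, 7), then stacked -> (2, 1, 9, 7)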
def pad_images(
    input_images: Union[List[Tensor], Tensor],
    spatial_dims: int,
    size_divisible: Union[int, Sequence[int]],
    mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    **kwargs,
) -> Tuple[Tensor, List[List[int]]]:
    """
    Pad the input images, so that the output spatial sizes are divisible by `size_divisible`.
    It pads them at the end to create a (B, C, H, W) or (B, C, H, W, D) Tensor.
    Padded size (H, W) or (H, W, D) is divisible by size_divisible.
    Default padding uses constant padding with value 0.0.

    Args:
        input_images: It can be 1) a tensor sized (B, C, H, W) or (B, C, H, W, D),
            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or (C, H_i, W_i, D_i).
        spatial_dims: number of spatial dimensions of the images, 2D or 3D.
        size_divisible: int or Sequence[int], is the expected pattern on the input image shape.
            If an int, the same `size_divisible` will be applied to all the input spatial dimensions.
        mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        kwargs: other arguments for the `torch.nn.functional.pad` function.

    Return:
        - images, a (B, C, H, W) or (B, C, H, W, D) Tensor
        - image_sizes, the original spatial size of each image
    """
    size_divisible = ensure_tuple_rep(size_divisible, spatial_dims)

    # If input_images: Tensor
    if isinstance(input_images, Tensor):
        orig_size = list(input_images.shape[-spatial_dims:])
        new_size = compute_divisible_spatial_size(spatial_shape=orig_size, k=size_divisible)
        all_pad_width = [(0, max(sp_i - orig_size[i], 0)) for i, sp_i in enumerate(new_size)]
        pt_pad_width = [val for sublist in all_pad_width for val in sublist[::-1]][::-1]
        if max(pt_pad_width) == 0:
            # if there is no need to pad
            return input_images, [orig_size] * input_images.shape[0]
        mode_: str = convert_pad_mode(dst=input_images, mode=mode)
        return F.pad(input_images, pt_pad_width, mode=mode_, **kwargs), [orig_size] * input_images.shape[0]

    # If input_images: List[Tensor]
    image_sizes = [img.shape[-spatial_dims:] for img in input_images]
    in_channels = input_images[0].shape[0]
    dtype = input_images[0].dtype
    device = input_images[0].device

    # compute max_spatial_size
    image_sizes_t = torch.tensor(image_sizes)
    max_spatial_size_t, _ = torch.max(image_sizes_t, dim=0)

    if len(max_spatial_size_t) != spatial_dims or len(size_divisible) != spatial_dims:
        raise ValueError("Require len(max_spatial_size_t) == spatial_dims == len(size_divisible).")

    max_spatial_size = compute_divisible_spatial_size(spatial_shape=list(max_spatial_size_t), k=size_divisible)

    # allocate memory for the padded images
    images = torch.zeros([len(image_sizes), in_channels] + max_spatial_size, dtype=dtype, device=device)

    # Use `SpatialPad` to match sizes; padding at the end will not affect boxes
    padder = SpatialPad(spatial_size=max_spatial_size, method="end", mode=mode, **kwargs)
    for idx, img in enumerate(input_images):
        images[idx, ...] = padder(img)

    return images, [list(ss) for ss in image_sizes]
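# Example (editor's sketch): padding a list of differently-sized channel-first images
# into one batch tensor for a detector. Assumes `pad_images` lives in MONAI's detection
# utilities as imported below; shapes are illustrative only.
import torch
from monai.apps.detection.utils.detector_utils import pad_images

imgs = [torch.rand(1, 30, 33), torch.rand(1, 28, 35)]
batch, sizes = pad_images(imgs, spatial_dims=2, size_divisible=16)
print(batch.shape)  # max sizes (30, 35) rounded up to multiples of 16 -> (2, 1, 32, 48)
print(sizes)        # original spatial sizes: [[30, 33], [28, 35]]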