Example #1
def make_nifti_image(array: NdarrayOrTensor,
                     affine=None,
                     dir=None,
                     fname=None,
                     suffix=".nii.gz",
                     verbose=False):
    """
    Create a temporary nifti image on the disk and return the image name.
    User is responsible for deleting the temporary file when done with it.
    """
    if isinstance(array, torch.Tensor):
        array, *_ = convert_data_type(array, np.ndarray)
    if isinstance(affine, torch.Tensor):
        affine, *_ = convert_data_type(affine, np.ndarray)
    if affine is None:
        affine = np.eye(4)
    test_image = nib.Nifti1Image(array, affine)

    # if dir not given, create random. Else, make sure it exists.
    if dir is None:
        dir = tempfile.mkdtemp()
    else:
        os.makedirs(dir, exist_ok=True)

    # If fname not given, get random one. Else, concat dir, fname and suffix.
    if fname is None:
        temp_f, fname = tempfile.mkstemp(suffix=suffix, dir=dir)
        os.close(temp_f)
    else:
        fname = os.path.join(dir, fname + suffix)

    nib.save(test_image, fname)
    if verbose:
        print(f"File written: {fname}.")
    return fname
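A minimal usage sketch for the helper above (not part of the original example): it assumes numpy, nibabel and the function defined here are importable, and that the caller removes the temporary file afterwards.

import os
import numpy as np

# write a small dummy volume to a temporary NIfTI file, then clean it up
img_name = make_nifti_image(np.zeros((8, 8, 8), dtype=np.float32), verbose=True)
assert os.path.exists(img_name)
os.remove(img_name)  # the caller is responsible for deleting the temporary file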
Example #2
def select_labels(labels: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor],
                  keep: NdarrayOrTensor) -> Union[Tuple, NdarrayOrTensor]:
    """
    For each element in ``labels``, select the indices specified by ``keep``.

    Args:
        labels: Sequence of arrays. Each element represents classification labels or scores
            corresponding to ``boxes``, sized (N,).
        keep: the indices to keep, same length as each element in ``labels``.

    Return:
        selected labels, does not share memory with original labels.
    """
    labels_tuple = ensure_tuple(labels, True)

    labels_select_list = []
    keep_t: torch.Tensor = convert_data_type(keep, torch.Tensor)[0]
    for i in range(len(labels_tuple)):
        labels_t: torch.Tensor = convert_data_type(labels_tuple[i],
                                                   torch.Tensor)[0]
        labels_t = labels_t[keep_t, ...]
        labels_select_list.append(
            convert_to_dst_type(src=labels_t, dst=labels_tuple[i])[0])

    if isinstance(labels, (torch.Tensor, np.ndarray)):
        return labels_select_list[0]  # type: ignore

    return tuple(labels_select_list)
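A short, hedged usage sketch for select_labels; the values below are illustrative and assume the converters used above behave as documented.

import numpy as np
import torch

labels = np.array([0, 1, 2, 1])                  # labels for 4 boxes
keep = torch.tensor([True, False, True, False])  # boolean mask, same length as labels
selected = select_labels(labels, keep)           # -> array([0, 2]), same type as the input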
Example #3
def fftn_centered_t(im: Tensor,
                    spatial_dims: int,
                    is_complex: bool = True) -> Tensor:
    """
    Pytorch-based fft for spatial_dims-dim signals. "centered" means this function automatically takes care
    of the required ifft and fft shifts.
    This is equivalent to performing the fft in numpy using numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifftshift.

    Args:
        im: image that can be
            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or
            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.
        spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume)
        is_complex: if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels)

    Returns:
        "out" which is the output kspace (fourier of im)

    Example:

        .. code-block:: python

            import torch
            im = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts
            # output1 and output2 will be identical
            output1 = torch.fft.fftn(torch.view_as_complex(torch.fft.ifftshift(im,dim=(-3,-2))), dim=(-2,-1), norm="ortho")
            output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) )

            output2 = fftn_centered(im, spatial_dims=2, is_complex=True)
    """
    # define spatial dims to perform ifftshift, fftshift, and fft
    shift = tuple(range(-spatial_dims, 0))
    if is_complex:
        if im.shape[-1] != 2:
            raise ValueError(f"img.shape[-1] is not 2 ({im.shape[-1]}).")
        shift = tuple(range(-spatial_dims - 1, -1))
    dims = tuple(range(-spatial_dims, 0))

    # apply fft
    if hasattr(torch.fft, "ifftshift"):  # ifftshift was added in pytorch 1.8
        x = torch.fft.ifftshift(im, dim=shift)
    else:
        x = ifftshift(im, shift)

    if is_complex:
        x = torch.view_as_real(
            torch.fft.fftn(torch.view_as_complex(x), dim=dims, norm="ortho"))
    else:
        x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho"))

    if hasattr(torch.fft, "fftshift"):
        out = convert_data_type(torch.fft.fftshift(x, dim=shift),
                                torch.Tensor)[0]
    else:
        out = convert_data_type(fftshift(x, shift), torch.Tensor)[0]

    return out
Example #4
    def astype(self, dtype, device=None, *unused_args, **unused_kwargs):
        """
        Cast to ``dtype``, sharing data whenever possible.

        Args:
            dtype: dtypes such as np.float32, torch.float, "np.float32", float.
            device: the device if `dtype` is a torch data type.
            unused_args: additional args (currently unused).
            unused_kwargs: additional kwargs (currently unused).

        Returns:
            data array instance
        """
        if isinstance(dtype, str):
            mod_str, *dtype = dtype.split(".", 1)
            dtype = mod_str if not dtype else dtype[0]
        else:
            mod_str = getattr(dtype, "__module__", "torch")
        mod_str = look_up_option(mod_str, {"torch", "numpy", "np"},
                                 default="numpy")
        if mod_str == "torch":
            out_type = torch.Tensor
        elif mod_str in ("numpy", "np"):
            out_type = np.ndarray
        else:
            out_type = None
        return convert_data_type(self,
                                 output_type=out_type,
                                 device=device,
                                 dtype=dtype,
                                 wrap_sequence=True)[0]
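A usage sketch for astype, assuming the method lives on a tensor-like MONAI class such as MetaTensor (an assumption for illustration only):

import numpy as np
import torch
from monai.data import MetaTensor  # assumed host class for this method

t = MetaTensor(torch.ones(2, 2, dtype=torch.float32))
arr = t.astype(np.float32)      # numpy array, float32
half = t.astype(torch.float16)  # torch tensor, float16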
Example #5
 def get_label_rgb(cmap: str, label: NdarrayOrTensor):
     _cmap = cm.get_cmap(cmap)
     label_np, *_ = convert_data_type(label, np.ndarray)
     label_rgb_np = _cmap(label_np[0])
     label_rgb_np = np.moveaxis(label_rgb_np, -1, 0)[:3]
     label_rgb, *_ = convert_to_dst_type(label_rgb_np, label)
     return label_rgb
Example #6
def boxes_center_distance(
    boxes1: NdarrayOrTensor,
    boxes2: NdarrayOrTensor,
    euclidean: bool = True
) -> Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]:
    """
    Distance of center points between two sets of boxes

    Args:
        boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        euclidean: if True, compute the Euclidean distance; otherwise, compute the L1 distance.

    Returns:
        - The pairwise distances for every element in boxes1 and boxes2,
          with size of (N,M) and same data type as ``boxes1``.
        - Center points of boxes1, with size of (N,spatial_dims) and same data type as ``boxes1``.
        - Center points of boxes2, with size of (M,spatial_dims) and same data type as ``boxes1``.

    Reference:
        https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/ops.py

    """

    if not isinstance(boxes1, type(boxes2)):
        warnings.warn(
            f"boxes1 is {type(boxes1)}, while boxes2 is {type(boxes2)}. The result will be {type(boxes1)}."
        )

    # convert numpy to tensor if needed
    boxes1_t, *_ = convert_data_type(boxes1, torch.Tensor)
    boxes2_t, *_ = convert_data_type(boxes2, torch.Tensor)

    center1 = box_centers(boxes1_t.to(COMPUTE_DTYPE))  # (N, spatial_dims)
    center2 = box_centers(boxes2_t.to(COMPUTE_DTYPE))  # (M, spatial_dims)

    if euclidean:
        dists = (center1[:, None] - center2[None]).pow(2).sum(-1).sqrt()
    else:
        # before sum: (N, M, spatial_dims)
        dists = (center1[:, None] - center2[None]).sum(-1)

    # convert tensor back to numpy if needed
    (dists, center1, center2), *_ = convert_to_dst_type(src=(dists, center1,
                                                             center2),
                                                        dst=boxes1)
    return dists, center1, center2
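To make the shapes concrete, here is a hedged usage sketch (values illustrative; boxes are in 2D StandardMode, i.e. [xmin, ymin, xmax, ymax]):

import torch

boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0], [4.0, 4.0, 6.0, 6.0]])
boxes2 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
dists, c1, c2 = boxes_center_distance(boxes1, boxes2)
# dists: (2, 1) pairwise center distances, c1: (2, 2) centers, c2: (1, 2) centers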
Example #7
def make_nifti_image(array: NdarrayOrTensor, affine=None):
    """
    Create a temporary nifti image on the disk and return the image name.
    User is responsible for deleting the temporary file when done with it.
    """
    if isinstance(array, torch.Tensor):
        array, *_ = convert_data_type(array, np.ndarray)
    if isinstance(affine, torch.Tensor):
        affine, *_ = convert_data_type(affine, np.ndarray)
    if affine is None:
        affine = np.eye(4)
    test_image = nib.Nifti1Image(array, affine)

    temp_f, image_name = tempfile.mkstemp(suffix=".nii.gz")
    nib.save(test_image, image_name)
    os.close(temp_f)
    return image_name
Example #8
def convert_mask_to_box(
        boxes_mask: NdarrayOrTensor,
        bg_label: int = -1,
        box_dtype=torch.float32,
        label_dtype=torch.long) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]:
    """
    Convert an int16 mask image to boxes; the mask has the same spatial size as the input image.

    Args:
        boxes_mask: int16 array, sized (num_box, H, W). Each channel represents a box.
            The foreground region in channel c has intensity of labels[c].
            The background intensity is bg_label.
        bg_label: background labels for the boxes_mask
        box_dtype: output dtype for boxes
        label_dtype: output dtype for labels

    Return:
        - bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``.
        - classification foreground(fg) labels, dtype should be int, sized (N,).
    """
    look_up_option(len(boxes_mask.shape), [3, 4])
    spatial_size = list(boxes_mask.shape[1:])
    spatial_dims = get_spatial_dims(spatial_size=spatial_size)

    boxes_mask_np, *_ = convert_data_type(boxes_mask, np.ndarray)

    boxes_list = []
    labels_list = []
    for b in range(boxes_mask_np.shape[0]):
        fg_indices = np.nonzero(boxes_mask_np[b, ...] - bg_label)
        if fg_indices[0].shape[0] == 0:
            continue
        boxes_b = []
        for fd_i in fg_indices:
            boxes_b.append(min(fd_i))  # top left corner
        for fd_i in fg_indices:
            boxes_b.append(max(fd_i) + 1 - TO_REMOVE)  # bottom right corner
        boxes_list.append(boxes_b)
        if spatial_dims == 2:
            labels_list.append(boxes_mask_np[b, fg_indices[0][0],
                                             fg_indices[1][0]])
        if spatial_dims == 3:
            labels_list.append(boxes_mask_np[b, fg_indices[0][0],
                                             fg_indices[1][0],
                                             fg_indices[2][0]])

    if len(boxes_list) == 0:
        boxes_np, labels_np = np.zeros([0, 2 * spatial_dims]), np.zeros([0])
    else:
        boxes_np, labels_np = np.asarray(boxes_list), np.asarray(labels_list)
    boxes, *_ = convert_to_dst_type(src=boxes_np,
                                    dst=boxes_mask,
                                    dtype=box_dtype)
    labels, *_ = convert_to_dst_type(src=labels_np,
                                     dst=boxes_mask,
                                     dtype=label_dtype)
    return boxes, labels
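A small hedged sketch of how the mask-to-box conversion above can be exercised (assuming TO_REMOVE is 0, as in the box utilities this code comes from):

import numpy as np

mask = -1 * np.ones((1, 4, 4), dtype=np.int16)  # one channel, background label -1
mask[0, 1:3, 1:3] = 7                           # a 2x2 foreground patch with label 7
boxes, labels = convert_mask_to_box(mask)
# boxes -> [[1., 1., 3., 3.]], labels -> [7]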
Example #9
def convert_box_mode(
    boxes: NdarrayOrTensor,
    src_mode: Union[str, BoxMode, Type[BoxMode], None] = None,
    dst_mode: Union[str, BoxMode, Type[BoxMode], None] = None,
) -> NdarrayOrTensor:
    """
    This function converts the boxes in src_mode to the dst_mode.

    Args:
        boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray.
        src_mode: source box mode. If it is not given, this func will assume it is ``StandardMode()``.
            It follows the same format with ``mode`` in :func:`~monai.data.box_utils.get_boxmode`.
        dst_mode: target box mode. If it is not given, this func will assume it is ``StandardMode()``.
            It follows the same format with ``mode`` in :func:`~monai.data.box_utils.get_boxmode`.

    Returns:
        bounding boxes with target mode, with same data type as ``boxes``, does not share memory with ``boxes``

    Example:
        .. code-block:: python

            boxes = torch.ones(10,4)
            # The following three lines are equivalent
            # They convert boxes with format [xmin, ymin, xmax, ymax] to [xcenter, ycenter, xsize, ysize].
            convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode="ccwh")
            convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode)
            convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode())
    """
    src_boxmode = get_boxmode(src_mode)
    dst_boxmode = get_boxmode(dst_mode)

    # if mode not changed, deepcopy the original boxes
    if isinstance(src_boxmode, type(dst_boxmode)):
        return deepcopy(boxes)

    # convert box mode
    # convert numpy to tensor if needed
    boxes_t, *_ = convert_data_type(boxes, torch.Tensor)

    # convert boxes to corners
    corners = src_boxmode.boxes_to_corners(boxes_t)

    # check validity of corners
    spatial_dims = get_spatial_dims(boxes=boxes_t)
    for axis in range(0, spatial_dims):
        if (corners[spatial_dims + axis] < corners[axis]).sum() > 0:
            warnings.warn(
                "Given boxes has invalid values. The box size must be non-negative."
            )

    # convert corners to boxes
    boxes_t_dst = dst_boxmode.corners_to_boxes(corners)

    # convert tensor back to numpy if needed
    boxes_dst, *_ = convert_to_dst_type(src=boxes_t_dst, dst=boxes)
    return boxes_dst
Example #10
 def test_convert_data_type(self, in_image, im_out):
     converted_im, orig_type, orig_device = convert_data_type(in_image, type(im_out))
     # check input is unchanged
     self.assertEqual(type(in_image), orig_type)
     if isinstance(in_image, torch.Tensor):
         self.assertEqual(in_image.device, orig_device)
     # check output is desired type
     self.assertEqual(type(converted_im), type(im_out))
     # check dtype is unchanged
     if isinstance(in_image, (np.ndarray, torch.Tensor)):
         self.assertEqual(converted_im.dtype, im_out.dtype)
Example #11
 def test_convert_list(self, in_image, im_out, wrap):
     output_type = type(im_out) if wrap else type(im_out[0])
     converted_im, *_ = convert_data_type(in_image, output_type, wrap_sequence=wrap)
     # check output is desired type
     if not wrap:
         converted_im = converted_im[0]
         im_out = im_out[0]
     self.assertEqual(type(converted_im), type(im_out))
     # check dtype is unchanged
     if isinstance(in_image, (np.ndarray, torch.Tensor)):
         self.assertEqual(converted_im.dtype, im_out.dtype)
Example #12
def box_iou(boxes1: NdarrayOrTensor,
            boxes2: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    Compute the intersection over union (IoU) of two sets of boxes.

    Args:
        boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``

    Returns:
        IoU, with size of (N,M) and same data type as ``boxes1``

    """

    if not isinstance(boxes1, type(boxes2)):
        warnings.warn(
            f"boxes1 is {type(boxes1)}, while boxes2 is {type(boxes2)}. The result will be {type(boxes1)}."
        )

    # convert numpy to tensor if needed
    boxes1_t, *_ = convert_data_type(boxes1, torch.Tensor)
    boxes2_t, *_ = convert_data_type(boxes2, torch.Tensor)

    # we do computation with compute_dtype to avoid overflow
    box_dtype = boxes1_t.dtype

    inter, union = _box_inter_union(boxes1_t,
                                    boxes2_t,
                                    compute_dtype=COMPUTE_DTYPE)

    # compute IoU and convert back to original box_dtype
    iou_t = inter / (union + torch.finfo(COMPUTE_DTYPE).eps)  # (N,M)
    iou_t = iou_t.to(dtype=box_dtype)

    # check if NaN or Inf
    if torch.isnan(iou_t).any() or torch.isinf(iou_t).any():
        raise ValueError("Box IoU is NaN or Inf.")

    # convert tensor back to numpy if needed
    iou, *_ = convert_to_dst_type(src=iou_t, dst=boxes1)
    return iou
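A brief usage sketch for box_iou (boxes in 2D StandardMode; numbers are illustrative):

import torch

boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
boxes2 = torch.tensor([[0.0, 0.0, 2.0, 2.0], [1.0, 0.0, 3.0, 2.0]])
iou = box_iou(boxes1, boxes2)  # shape (1, 2), approximately [[1.0, 0.333]]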
Example #13
def mode(x: NdarrayTensor, dim: int = -1, to_long: bool = True) -> NdarrayTensor:
    """`torch.mode` with equivalent implementation for numpy.

    Args:
        x: array/tensor
        dim: dimension along which to perform `mode` (referred to as `axis` by numpy)
        to_long: convert input to long before performing mode.
    """
    dtype = torch.int64 if to_long else None
    x_t, *_ = convert_data_type(x, torch.Tensor, dtype=dtype)
    o_t = torch.mode(x_t, dim).values
    o, *_ = convert_to_dst_type(o_t, x)
    return o
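A quick sketch of mode on both tensor and numpy inputs (outputs follow torch.mode semantics, i.e. the most frequent value along dim):

import numpy as np
import torch

mode(torch.tensor([[1, 2, 2], [3, 3, 4]]))  # tensor([2, 3]) along the last dim
mode(np.array([1, 1, 2]))                   # -> array(1); numpy in, numpy out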
Example #14
    def test_skipped_transform_consistency(self, im, in_dtype):
        t1 = RandAffine(prob=0)
        t2 = RandAffine(prob=1, spatial_size=(10, 11))

        im, *_ = convert_data_type(im, dtype=in_dtype)

        out1 = t1(im)
        out2 = t2(im)

        # check same type
        self.assertEqual(type(out1), type(out2))
        # check matching dtype
        self.assertEqual(out1.dtype, out2.dtype)
Example #15
def batched_nms(
    boxes: NdarrayOrTensor,
    scores: NdarrayOrTensor,
    labels: NdarrayOrTensor,
    nms_thresh: float,
    max_proposals: int = -1,
    box_overlap_metric: Callable = box_iou,
) -> NdarrayOrTensor:
    """
    Performs non-maximum suppression in a batched fashion.
    Each label value corresponds to a category, and NMS will not be applied between elements of different categories.

    Adapted from https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/nms.py

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        scores: prediction scores of the boxes, sized (N,). This function keeps boxes with higher scores.
        labels: indices of the categories for each one of the boxes, sized (N,), value range is (0, num_classes).
        nms_thresh: threshold of NMS. Discards all overlapping boxes with box_overlap > nms_thresh.
        max_proposals: maximum number of boxes it keeps.
            If ``max_proposals`` = -1, there is no limit on the number of boxes that are kept.
        box_overlap_metric: the metric to compute overlap between boxes.

    Returns:
        Indexes of ``boxes`` that are kept after NMS.
    """
    # returns empty array if boxes is empty
    if boxes.shape[0] == 0:
        return convert_to_dst_type(src=np.array([]),
                                   dst=boxes,
                                   dtype=torch.long)[0]

    # convert numpy to tensor if needed
    boxes_t, *_ = convert_data_type(boxes, torch.Tensor, dtype=torch.float32)
    scores_t, *_ = convert_to_dst_type(scores, boxes_t)
    labels_t, *_ = convert_to_dst_type(labels, boxes_t, dtype=torch.long)

    # strategy: in order to perform NMS independently per class,
    # we add an offset to all the boxes. The offset is dependent
    # only on the class idx, and is large enough so that boxes
    # from different classes do not overlap
    max_coordinate = boxes_t.max()
    offsets = labels_t.to(boxes_t) * (max_coordinate + 1)
    boxes_for_nms = boxes_t + offsets[:, None]
    keep = non_max_suppression(boxes_for_nms, scores_t, nms_thresh,
                               max_proposals, box_overlap_metric)

    # convert tensor back to numpy if needed
    return convert_to_dst_type(src=keep, dst=boxes, dtype=keep.dtype)[0]
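A hedged end-to-end sketch of batched_nms: two identical boxes of the same class and one box of another class, so only the same-class duplicate is suppressed (assumes box_iou and non_max_suppression from the same module, as above):

import torch

boxes = torch.tensor([[0.0, 0.0, 2.0, 2.0]] * 3)
scores = torch.tensor([0.9, 0.8, 0.7])
labels = torch.tensor([0, 0, 1])  # the third box belongs to a different class
keep = batched_nms(boxes, scores, labels, nms_thresh=0.5)  # roughly tensor([0, 2])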
Example #16
def box_area(boxes: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    This function computes the area (2D) or volume (3D) of each box.
    Half precision is not recommended for this function as it may cause overflow, especially for 3D images.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``

    Returns:
        area (2D) or volume (3D) of boxes, with size of (N,).

    Example:
        .. code-block:: python

            boxes = torch.ones(10,6)
            # we do computation with torch.float32 to avoid overflow
            compute_dtype = torch.float32
            area = box_area(boxes=boxes.to(dtype=compute_dtype))  # torch.float32, size of (10,)
    """

    if not is_valid_box_values(boxes):
        raise ValueError(
            "Given boxes has invalid values. The box size must be non-negative."
        )

    spatial_dims = get_spatial_dims(boxes=boxes)

    area = boxes[:, spatial_dims] - boxes[:, 0] + TO_REMOVE
    for axis in range(1, spatial_dims):
        area = area * (boxes[:, axis + spatial_dims] - boxes[:, axis] +
                       TO_REMOVE)

    # convert numpy to tensor if needed
    area_t, *_ = convert_data_type(area, torch.Tensor)

    # check if NaN or Inf, especially for half precision
    if area_t.isnan().any() or area_t.isinf().any():
        if area_t.dtype is torch.float16:
            raise ValueError(
                "Box area is NaN or Inf. boxes is float16. Please change to float32 and test it again."
            )
        else:
            raise ValueError("Box area is NaN or Inf.")

    return area
Example #17
def apply_affine_to_boxes(boxes: NdarrayOrTensor,
                          affine: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    This function applies affine matrices to the boxes

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode
        affine: affine matrix to be applied to the box coordinates, sized (spatial_dims+1,spatial_dims+1)

    Returns:
        returned affine transformed boxes, with same data type as ``boxes``, does not share memory with ``boxes``
    """

    # convert numpy to tensor if needed
    boxes_t, *_ = convert_data_type(boxes, torch.Tensor)

    # some operation does not support torch.float16
    # convert to float32

    boxes_t = boxes_t.to(dtype=COMPUTE_DTYPE)
    affine_t, *_ = convert_to_dst_type(src=affine, dst=boxes_t)

    spatial_dims = get_spatial_dims(boxes=boxes_t)

    # affine transform the left-top and right-bottom points
    # they might be flipped, so lt may no longer be the left-top corner
    lt: torch.Tensor = _apply_affine_to_points(boxes_t[:, :spatial_dims],
                                               affine_t,
                                               include_shift=True)
    rb: torch.Tensor = _apply_affine_to_points(boxes_t[:, spatial_dims:],
                                               affine_t,
                                               include_shift=True)

    # make sure lt_new is left top, and rb_new is bottom right
    lt_new, _ = torch.min(torch.stack([lt, rb], dim=2), dim=2)
    rb_new, _ = torch.max(torch.stack([lt, rb], dim=2), dim=2)

    boxes_t_affine = torch.cat([lt_new, rb_new], dim=1)

    # convert tensor back to numpy if needed
    boxes_affine: NdarrayOrTensor
    boxes_affine, *_ = convert_to_dst_type(src=boxes_t_affine, dst=boxes)
    return boxes_affine
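A short sketch of applying a 2D affine (scale x by 2, shift y by 5) to a single StandardMode box; the numbers are illustrative only:

import torch

boxes = torch.tensor([[1.0, 1.0, 3.0, 3.0]])
affine = torch.tensor([[2.0, 0.0, 0.0],
                       [0.0, 1.0, 5.0],
                       [0.0, 0.0, 1.0]])
apply_affine_to_boxes(boxes, affine)  # expected: tensor([[2., 6., 6., 8.]])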
Example #18
def fftn_centered(im: NdarrayOrTensor,
                  spatial_dims: int,
                  is_complex: bool = True) -> NdarrayOrTensor:
    """
    Pytorch-based fft for spatial_dims-dim signals. "centered" means this function automatically takes care
    of the required ifft and fft shifts. This function calls monai.networks.blocks.fft_utils_t.fftn_centered_t.
    This is equivalent to performing the fft in numpy using numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifftshift.

    Args:
        im: image that can be
            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or
            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.
        spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume)
        is_complex: if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels)

    Returns:
        "out" which is the output kspace (fourier of im)

    Example:

        .. code-block:: python

            import torch
            im = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts
            # output1 and output2 will be identical
            output1 = torch.fft.fftn(torch.view_as_complex(torch.fft.ifftshift(im,dim=(-3,-2))), dim=(-2,-1), norm="ortho")
            output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) )

            output2 = fftn_centered(im, spatial_dims=2, is_complex=True)
    """
    # handle numpy format
    im_t, *_ = convert_data_type(im, torch.Tensor)

    # compute fftn
    out_t = fftn_centered_t(im_t,
                            spatial_dims=spatial_dims,
                            is_complex=is_complex)

    # handle numpy format
    out, *_ = convert_to_dst_type(src=out_t, dst=im)
    return out
Example #19
    def get_array(self,
                  output_type=np.ndarray,
                  dtype=None,
                  device=None,
                  *_args,
                  **_kwargs):
        """
        Returns a new array in `output_type`, the array shares the same underlying storage when the output is a
        numpy array. Changes to self tensor will be reflected in the ndarray and vice versa.

        Args:
            output_type: output type, see also: :py:func:`monai.utils.convert_data_type`.
            dtype: dtype of output data. Converted to correct library type (e.g.,
                `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
                If left blank, it remains unchanged.
            device: if the output is a `torch.Tensor`, select device (if `None`, unchanged).
            _args: currently unused parameters.
            _kwargs: currently unused parameters.
        """
        return convert_data_type(self,
                                 output_type=output_type,
                                 dtype=dtype,
                                 device=device,
                                 wrap_sequence=True)[0]
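A usage sketch for get_array, again assuming the method is exposed on a MONAI MetaTensor-like class (illustrative, not taken from the original example):

import numpy as np
import torch
from monai.data import MetaTensor  # assumed host class

t = MetaTensor(torch.arange(6).reshape(2, 3))
a = t.get_array()                                             # numpy array, shares storage when possible
b = t.get_array(output_type=torch.Tensor, dtype=np.float16)   # torch tensor, float16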
Example #20
def box_pair_giou(boxes1: NdarrayOrTensor,
                  boxes2: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    Compute the generalized intersection over union (GIoU) of a pair of boxes.
    The two inputs should have the same shape, and the function returns an (N,) array
    (in contrast to :func:`~monai.data.box_utils.box_giou`, which does not require the inputs to have the same
    shape and returns an ``NxM`` matrix).

    Args:
        boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be ``StandardMode``

    Returns:
        paired GIoU, with size of (N,) and same data type as ``boxes1``

    Reference:
        https://giou.stanford.edu/GIoU.pdf

    """

    if not isinstance(boxes1, type(boxes2)):
        warnings.warn(
            f"boxes1 is {type(boxes1)}, while boxes2 is {type(boxes2)}. The result will be {type(boxes1)}."
        )

    # convert numpy to tensor if needed
    boxes1_t, *_ = convert_data_type(boxes1, torch.Tensor)
    boxes2_t, *_ = convert_data_type(boxes2, torch.Tensor)

    if boxes1_t.shape != boxes2_t.shape:
        raise ValueError(
            "boxes1 and boxes2 should be paired and have same shape.")

    spatial_dims = get_spatial_dims(boxes=boxes1_t)

    # we do computation with compute_dtype to avoid overflow
    box_dtype = boxes1_t.dtype

    # compute area
    area1 = box_area(boxes=boxes1_t.to(dtype=COMPUTE_DTYPE))  # (N,)
    area2 = box_area(boxes=boxes2_t.to(dtype=COMPUTE_DTYPE))  # (N,)

    # Intersection
    # get the left top and right bottom points for the boxes pair
    lt = torch.max(boxes1_t[:, :spatial_dims], boxes2_t[:, :spatial_dims]).to(
        dtype=COMPUTE_DTYPE)  # (N,spatial_dims) left top
    rb = torch.min(boxes1_t[:, spatial_dims:], boxes2_t[:, spatial_dims:]).to(
        dtype=COMPUTE_DTYPE)  # (N,spatial_dims) right bottom

    # compute size for the intersection region for the boxes pair
    wh = (rb - lt + TO_REMOVE).clamp(min=0)  # (N,spatial_dims)
    inter = torch.prod(wh, dim=-1, keepdim=False)  # (N,)

    # compute IoU and convert back to original box_dtype
    union = area1 + area2 - inter
    iou = inter / (union + torch.finfo(COMPUTE_DTYPE).eps)  # (N,)

    # Enclosure
    # get the left top and right bottom points for the boxes pair
    lt = torch.min(boxes1_t[:, :spatial_dims], boxes2_t[:, :spatial_dims]).to(
        dtype=COMPUTE_DTYPE)  # (N,spatial_dims) left top
    rb = torch.max(boxes1_t[:, spatial_dims:], boxes2_t[:, spatial_dims:]).to(
        dtype=COMPUTE_DTYPE)  # (N,spatial_dims) right bottom

    # compute size for the enclose region for the boxes pair
    wh = (rb - lt + TO_REMOVE).clamp(min=0)  # (N,spatial_dims)
    enclosure = torch.prod(wh, dim=-1, keepdim=False)  # (N,)

    giou_t = iou - (enclosure - union) / (enclosure +
                                          torch.finfo(COMPUTE_DTYPE).eps)
    giou_t = giou_t.to(dtype=box_dtype)  # (N,)
    if torch.isnan(giou_t).any() or torch.isinf(giou_t).any():
        raise ValueError("Box GIoU is NaN or Inf.")

    # convert tensor back to numpy if needed
    giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1)
    return giou
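A paired-GIoU sketch: element i of boxes1 is compared only with element i of boxes2 (values illustrative; identical boxes give 1.0, adjacent non-overlapping boxes give about 0.0):

import torch

boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0], [0.0, 0.0, 2.0, 2.0]])
boxes2 = torch.tensor([[0.0, 0.0, 2.0, 2.0], [2.0, 0.0, 4.0, 2.0]])
box_pair_giou(boxes1, boxes2)  # shape (2,), approximately [1.0, 0.0]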
Example #21
def non_max_suppression(
    boxes: NdarrayOrTensor,
    scores: NdarrayOrTensor,
    nms_thresh: float,
    max_proposals: int = -1,
    box_overlap_metric: Callable = box_iou,
) -> NdarrayOrTensor:
    """
    Non-maximum suppression (NMS).

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        scores: prediction scores of the boxes, sized (N,). This function keeps boxes with higher scores.
        nms_thresh: threshold of NMS. Discards all overlapping boxes with box_overlap > nms_thresh.
        max_proposals: maximum number of boxes it keeps.
            If ``max_proposals`` = -1, there is no limit on the number of boxes that are kept.
        box_overlap_metric: the metric to compute overlap between boxes.

    Returns:
        Indexes of ``boxes`` that are kept after NMS.

    Example:
        .. code-block:: python

            boxes = torch.ones(10,6)
            scores = torch.ones(10)
            keep = non_max_suppression(boxes, scores, nms_thresh=0.1)
            boxes_after_nms = boxes[keep]
    """

    # returns empty array if boxes is empty
    if boxes.shape[0] == 0:
        return convert_to_dst_type(src=np.array([]),
                                   dst=boxes,
                                   dtype=torch.long)[0]

    if boxes.shape[0] != scores.shape[0]:
        raise ValueError(
            f"boxes and scores should have same length, got boxes shape {boxes.shape}, scores shape {scores.shape}"
        )

    # convert numpy to tensor if needed
    boxes_t, *_ = convert_data_type(boxes, torch.Tensor)
    scores_t, *_ = convert_to_dst_type(scores, boxes_t)

    # sort boxes in descending order according to the scores
    sort_idxs = torch.argsort(scores_t, dim=0, descending=True)
    boxes_sort = deepcopy(boxes_t)[sort_idxs, :]

    # initialize the list of picked indexes
    pick = []
    idxs = torch.Tensor(list(range(0, boxes_sort.shape[0]))).to(torch.long)

    # keep looping while some indexes still remain in the indexes list
    while len(idxs) > 0:
        # pick the first index in the indexes list and add the index value to the list of picked indexes
        i = int(idxs[0].item())
        pick.append(i)
        if len(pick) >= max_proposals >= 1:
            break

        # compute the IoU between the rest of the boxes and the box just picked
        box_overlap = box_overlap_metric(boxes_sort[idxs, :],
                                         boxes_sort[i:i + 1, :])

        # keep only indexes from the index list that have overlap < nms_thresh
        to_keep_idx = (box_overlap <= nms_thresh).flatten()
        to_keep_idx[0] = False  # always remove idxs[0]
        idxs = idxs[to_keep_idx]

    # return only the bounding boxes that were picked using the integer data type
    pick_idx = sort_idxs[pick]

    # convert tensor back to numpy if needed
    return convert_to_dst_type(src=pick_idx, dst=boxes,
                               dtype=pick_idx.dtype)[0]
Example #22
def box_giou(boxes1: NdarrayOrTensor,
             boxes2: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    Compute the generalized intersection over union (GIoU) of two sets of boxes.
    The two inputs can have different shapes, and the function returns an NxM matrix
    (in contrast to :func:`~monai.data.box_utils.box_pair_giou`, which requires the inputs to have the same
    shape and returns ``N`` values).

    Args:
        boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``

    Returns:
        GIoU, with size of (N,M) and same data type as ``boxes1``

    Reference:
        https://giou.stanford.edu/GIoU.pdf

    """

    if not isinstance(boxes1, type(boxes2)):
        warnings.warn(
            f"boxes1 is {type(boxes1)}, while boxes2 is {type(boxes2)}. The result will be {type(boxes1)}."
        )

    # convert numpy to tensor if needed
    boxes1_t, *_ = convert_data_type(boxes1, torch.Tensor)
    boxes2_t, *_ = convert_data_type(boxes2, torch.Tensor)

    spatial_dims = get_spatial_dims(boxes=boxes1_t)

    # we do computation with compute_dtype to avoid overflow
    box_dtype = boxes1_t.dtype

    inter, union = _box_inter_union(boxes1_t,
                                    boxes2_t,
                                    compute_dtype=COMPUTE_DTYPE)
    iou = inter / (union + torch.finfo(COMPUTE_DTYPE).eps)  # (N,M)

    # Enclosure
    # get the left top and right bottom points for the NxM combinations
    lt = torch.min(boxes1_t[:, None, :spatial_dims],
                   boxes2_t[:, :spatial_dims]).to(
                       dtype=COMPUTE_DTYPE)  # (N,M,spatial_dims) left top
    rb = torch.max(boxes1_t[:, None, spatial_dims:],
                   boxes2_t[:, spatial_dims:]).to(
                       dtype=COMPUTE_DTYPE)  # (N,M,spatial_dims) right bottom

    # compute size for the enclosure region for the NxM combinations
    wh = (rb - lt + TO_REMOVE).clamp(min=0)  # (N,M,spatial_dims)
    enclosure = torch.prod(wh, dim=-1, keepdim=False)  # (N,M)

    # GIoU
    giou_t = iou - (enclosure - union) / (enclosure +
                                          torch.finfo(COMPUTE_DTYPE).eps)
    giou_t = giou_t.to(dtype=box_dtype)
    if torch.isnan(giou_t).any() or torch.isinf(giou_t).any():
        raise ValueError("Box GIoU is NaN or Inf.")

    # convert tensor back to numpy if needed
    giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1)
    return giou
Example #23

import unittest

import torch
from parameterized import parameterized

from monai.apps.reconstruction.complex_utils import complex_abs, complex_conj, complex_mul, convert_to_tensor_complex
from monai.utils.type_conversion import convert_data_type
from tests.utils import TEST_NDARRAYS, assert_allclose

# test case for convert_to_tensor_complex
im_complex = [[1.0 + 1.0j, 1.0 + 1.0j], [1.0 + 1.0j, 1.0 + 1.0j]]
expected_shape = convert_data_type((2, 2, 2), torch.Tensor)[0]
TESTS = [(im_complex, expected_shape)]
for p in TEST_NDARRAYS:
    TESTS.append((p(im_complex), expected_shape))

# test case for complex_abs
im = [[3.0, 4.0], [3.0, 4.0]]
res = [5.0, 5.0]
TESTSC = []
for p in TEST_NDARRAYS:
    TESTSC.append((p(im), p(res)))

# test case for complex_mul
x = [[1.0, 2.0], [3.0, 4.0]]
y = [[1.0, 1.0], [1.0, 1.0]]
res = [[-1.0, 3.0], [-1.0, 7.0]]  # type: ignore
Example #24
 def test_neg_stride(self):
     _ = convert_data_type(np.array((1, 2))[::-1], torch.Tensor)
Example #25
 def _convert(x):
     if isinstance(x, (MetaTensor, torch.Tensor, tuple, list)):
         return convert_data_type(x,
                                  output_type=np.ndarray,
                                  wrap_sequence=False)[0]
     return x
Example #26
def matshow3d(
    volume,
    fig=None,
    title: Optional[str] = None,
    figsize=(10, 10),
    frames_per_row: Optional[int] = None,
    frame_dim: int = -3,
    channel_dim: Optional[int] = None,
    vmin=None,
    vmax=None,
    every_n: int = 1,
    interpolation: str = "none",
    show=False,
    fill_value=np.nan,
    margin: int = 1,
    dtype=np.float32,
    **kwargs,
):
    """
    Create a 3D volume figure as a grid of images.

    Args:
        volume: 3D volume to display. data shape can be `BCHWD`, `CHWD` or `HWD`.
            Higher dimensional arrays will be reshaped into (-1, H, W, [C]), `C` depends on `channel_dim` arg.
            A list of channel-first (C, H[, W, D]) arrays can also be passed in,
            in which case they will be displayed as a padded and stacked volume.
        fig: matplotlib figure or Axes to use. If None, a new figure will be created.
        title: title of the figure.
        figsize: size of the figure.
        frames_per_row: number of frames to display in each row. If None, sqrt(firstdim) will be used.
        frame_dim: for higher dimensional arrays, which dimension from (`-1`, `-2`, `-3`) is moved to
            the `-3` dimension and reshaped to (-1, H, W) to construct frames. Defaults to `-3`.
        channel_dim: if not None, explicitly specify the channel dimension to be transposed to the
            last dimension as shape (-1, H, W, C). This can be used to plot RGB color images.
            If None, the channel dimension will be flattened with `frame_dim` and `batch_dim` as shape (-1, H, W).
            Note that it only supports 3D input images. Defaults to None.
        vmin: `vmin` for the matplotlib `imshow`.
        vmax: `vmax` for the matplotlib `imshow`.
        every_n: factor to subsample the frames so that only every n-th frame is displayed.
        interpolation: interpolation to use for the matplotlib `matshow`.
        show: if True, show the figure.
        fill_value: value to use for the empty part of the grid.
        margin: margin to use for the grid.
        dtype: data type of the output stacked frames.
        kwargs: additional keyword arguments to matplotlib `matshow` and `imshow`.

    See Also:
        - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html
        - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.matshow.html

    Example:

        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from monai.visualize import matshow3d
        # create a figure of a 3D volume
        >>> volume = np.random.rand(10, 10, 10)
        >>> fig = plt.figure()
        >>> matshow3d(volume, fig=fig, title="3D Volume")
        >>> plt.show()
        # create a figure of a list of channel-first 3D volumes
        >>> volumes = [np.random.rand(1, 10, 10, 10), np.random.rand(1, 10, 10, 10)]
        >>> fig = plt.figure()
        >>> matshow3d(volumes, fig=fig, title="List of Volumes")
        >>> plt.show()

    """
    vol = convert_data_type(data=volume, output_type=np.ndarray)[0]
    if channel_dim is not None:
        if channel_dim not in [0, 1
                               ] or vol.shape[channel_dim] not in [1, 3, 4]:
            raise ValueError(
                "channel_dim must be: None, 0 or 1, and channels of image must be 1, 3 or 4."
            )

    if isinstance(vol, (list, tuple)):
        # a sequence of channel-first volumes
        if not isinstance(vol[0], np.ndarray):
            raise ValueError("volume must be a list of arrays.")
        pad_size = np.max(np.asarray([v.shape for v in vol]), axis=0)
        pad = SpatialPad(
            pad_size[1:])  # assuming channel-first for item in vol
        vol = np.concatenate([pad(v) for v in vol], axis=0)
    else:  # ndarray
        while len(vol.shape) < 3:
            vol = np.expand_dims(
                vol, 0)  # type: ignore  # so that we display 2d as well

    if channel_dim is not None:  # move the expected dim to construct frames with `B` dim
        vol = np.moveaxis(vol, frame_dim, -4)  # type: ignore
        vol = vol.reshape((-1, vol.shape[-3], vol.shape[-2], vol.shape[-1]))
    else:
        vol = np.moveaxis(vol, frame_dim, -3)  # type: ignore
        vol = vol.reshape((-1, vol.shape[-2], vol.shape[-1]))
    vmin = np.nanmin(vol) if vmin is None else vmin
    vmax = np.nanmax(vol) if vmax is None else vmax

    # subsample every_n-th frame of the 3D volume
    vol = vol[::max(every_n, 1)]
    if not frames_per_row:
        frames_per_row = int(np.ceil(np.sqrt(len(vol))))
    # create the grid of frames
    cols = max(min(len(vol), frames_per_row), 1)
    rows = int(np.ceil(len(vol) / cols))
    width = [[0, cols * rows - len(vol)]]
    if channel_dim is not None:
        width += [[0, 0]]  # add pad width for the channel dim
    width += [[margin, margin]] * 2
    vol = np.pad(vol.astype(dtype, copy=False),
                 width,
                 mode="constant",
                 constant_values=fill_value)  # type: ignore
    im = np.block([[vol[i * cols + j] for j in range(cols)]
                   for i in range(rows)])
    if channel_dim is not None:
        # move channel dim to the end
        im = np.moveaxis(im, 0, -1)

    # figure related configurations
    if isinstance(fig, plt.Axes):
        ax = fig
    else:
        if fig is None:
            fig = plt.figure(tight_layout=True)
        if not fig.axes:
            fig.add_subplot(111)
        ax = fig.axes[0]
    ax.matshow(im, vmin=vmin, vmax=vmax, interpolation=interpolation, **kwargs)
    ax.axis("off")

    if title is not None:
        ax.set_title(title)
    if figsize is not None and hasattr(fig, "set_size_inches"):
        fig.set_size_inches(figsize)
    if show:
        plt.show()
    return fig, im
Example #27
    def test_value(self, input_data, mode2, expected_box, expected_area):
        expected_box = convert_data_type(expected_box, dtype=np.float32)[0]
        boxes1 = convert_data_type(input_data["boxes"], dtype=np.float32)[0]
        mode1 = input_data["mode"]
        half_bool = input_data["half"]
        spatial_size = input_data["spatial_size"]

        # test float16
        if half_bool:
            boxes1 = convert_data_type(boxes1, dtype=np.float16)[0]
            expected_box = convert_data_type(expected_box, dtype=np.float16)[0]

        # test convert_box_mode, convert_box_to_standard_mode
        result2 = convert_box_mode(boxes=boxes1, src_mode=mode1, dst_mode=mode2)
        assert_allclose(result2, expected_box, type_test=True, device_test=True, atol=0.0)

        result1 = convert_box_mode(boxes=result2, src_mode=mode2, dst_mode=mode1)
        assert_allclose(result1, boxes1, type_test=True, device_test=True, atol=0.0)

        result_standard = convert_box_to_standard_mode(boxes=boxes1, mode=mode1)
        expected_box_standard = convert_box_to_standard_mode(boxes=expected_box, mode=mode2)
        assert_allclose(result_standard, expected_box_standard, type_test=True, device_test=True, atol=0.0)

        # test box_area, box_iou, box_giou, box_pair_giou
        assert_allclose(box_area(result_standard), expected_area, type_test=True, device_test=True, atol=0.0)
        iou_metrics = (box_iou, box_giou)
        for p in iou_metrics:
            self_iou = p(boxes1=result_standard[1:2, :], boxes2=result_standard[1:1, :])
            assert_allclose(self_iou, np.array([[]]), type_test=False)

            self_iou = p(boxes1=result_standard[1:2, :], boxes2=result_standard[1:2, :])
            assert_allclose(self_iou, np.array([[1.0]]), type_test=False)

        self_iou = box_pair_giou(boxes1=result_standard[1:1, :], boxes2=result_standard[1:1, :])
        assert_allclose(self_iou, np.array([]), type_test=False)

        self_iou = box_pair_giou(boxes1=result_standard[1:2, :], boxes2=result_standard[1:2, :])
        assert_allclose(self_iou, np.array([1.0]), type_test=False)

        # test box_centers, centers_in_boxes, boxes_center_distance
        result_standard_center = box_centers(result_standard)
        expected_center = convert_box_mode(boxes=boxes1, src_mode=mode1, dst_mode="cccwhd")[:, :3]
        assert_allclose(result_standard_center, expected_center, type_test=True, device_test=True, atol=0.0)

        center = expected_center
        center[2, :] += 10
        result_centers_in_boxes = centers_in_boxes(centers=center, boxes=result_standard)
        assert_allclose(result_centers_in_boxes, np.array([False, True, False]), type_test=False)

        center_dist, _, _ = boxes_center_distance(boxes1=result_standard[1:2, :], boxes2=result_standard[1:1, :])
        assert_allclose(center_dist, np.array([[]]), type_test=False)
        center_dist, _, _ = boxes_center_distance(boxes1=result_standard[1:2, :], boxes2=result_standard[1:2, :])
        assert_allclose(center_dist, np.array([[0.0]]), type_test=False)
        center_dist, _, _ = boxes_center_distance(boxes1=result_standard[0:1, :], boxes2=result_standard[0:1, :])
        assert_allclose(center_dist, np.array([[0.0]]), type_test=False)

        # test clip_boxes_to_image
        clipped_boxes, keep = clip_boxes_to_image(expected_box_standard, spatial_size, remove_empty=True)
        assert_allclose(
            expected_box_standard[keep, :], expected_box_standard[1:, :], type_test=True, device_test=True, atol=0.0
        )
        assert_allclose(
            id(clipped_boxes) != id(expected_box_standard), True, type_test=False, device_test=False, atol=0.0
        )

        # test non_max_suppression
        nms_box = non_max_suppression(
            boxes=result_standard, scores=boxes1[:, 1] / 2.0, nms_thresh=1.0, box_overlap_metric=box_giou
        )
        assert_allclose(nms_box, [1, 2, 0], type_test=False)

        nms_box = non_max_suppression(
            boxes=result_standard, scores=boxes1[:, 1] / 2.0, nms_thresh=-1.0, box_overlap_metric=box_iou
        )
        assert_allclose(nms_box, [1], type_test=False)
Example #28
def spatial_crop_boxes(
    boxes: NdarrayOrTensor,
    roi_start: Union[Sequence[int], NdarrayOrTensor],
    roi_end: Union[Sequence[int], NdarrayOrTensor],
    remove_empty: bool = True,
) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]:
    """
    This function generates the new boxes when the corresponding image is cropped to the given ROI.
    When ``remove_empty=True``, it makes sure the bounding boxes are within the new cropped image.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        roi_start: voxel coordinates for start of the crop ROI, negative values allowed.
        roi_end: voxel coordinates for end of the crop ROI, negative values allowed.
        remove_empty: whether to remove the boxes that are actually empty

    Returns:
        - cropped boxes, boxes[keep], does not share memory with original boxes
        - ``keep``, it indicates whether each box in ``boxes`` is kept when ``remove_empty=True``.
    """

    # convert numpy to tensor if needed
    boxes_t = convert_data_type(boxes, torch.Tensor)[0].clone()

    # convert to float32 since torch.clamp_ does not support float16
    boxes_t = boxes_t.to(dtype=COMPUTE_DTYPE)

    roi_start_t = convert_to_dst_type(src=roi_start,
                                      dst=boxes_t,
                                      wrap_sequence=True)[0].to(torch.int16)
    roi_end_t = convert_to_dst_type(src=roi_end,
                                    dst=boxes_t,
                                    wrap_sequence=True)[0].to(torch.int16)
    roi_end_t = torch.maximum(roi_end_t, roi_start_t)

    # makes sure the bounding boxes are within the patch
    spatial_dims = get_spatial_dims(boxes=boxes, spatial_size=roi_end)
    for axis in range(0, spatial_dims):
        boxes_t[:, axis] = boxes_t[:,
                                   axis].clamp(min=roi_start_t[axis],
                                               max=roi_end_t[axis] - TO_REMOVE)
        boxes_t[:,
                axis + spatial_dims] = boxes_t[:, axis + spatial_dims].clamp(
                    min=roi_start_t[axis], max=roi_end_t[axis] - TO_REMOVE)
        boxes_t[:, axis] -= roi_start_t[axis]
        boxes_t[:, axis + spatial_dims] -= roi_start_t[axis]

    # remove the boxes that are actually empty
    if remove_empty:
        keep_t = boxes_t[:, spatial_dims] >= boxes_t[:, 0] + 1 - TO_REMOVE
        for axis in range(1, spatial_dims):
            keep_t = keep_t & (boxes_t[:, axis + spatial_dims] >=
                               boxes_t[:, axis] + 1 - TO_REMOVE)
        boxes_t = boxes_t[keep_t]
    else:
        keep_t = torch.full_like(boxes_t[:, 0],
                                 fill_value=True,
                                 dtype=torch.bool)

    # convert tensor back to numpy if needed
    boxes_keep, *_ = convert_to_dst_type(src=boxes_t, dst=boxes)
    keep, *_ = convert_to_dst_type(src=keep_t, dst=boxes, dtype=keep_t.dtype)

    return boxes_keep, keep
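A hedged sketch of cropping boxes to an ROI: the second box lies outside the ROI and is dropped when remove_empty=True (coordinates are shifted by roi_start; values illustrative):

import torch

boxes = torch.tensor([[0.0, 0.0, 4.0, 4.0], [6.0, 6.0, 8.0, 8.0]])
cropped, keep = spatial_crop_boxes(boxes, roi_start=[2, 2], roi_end=[5, 5])
# cropped -> tensor([[0., 0., 2., 2.]]), keep -> tensor([True, False])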
Example #29

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.apps.reconstruction.transforms.array import EquispacedKspaceMask, RandomKspaceMask
from monai.utils.type_conversion import convert_data_type

# test case for apply_mask
ksp, *_ = convert_data_type(np.ones([50, 50, 2]), torch.Tensor)
TESTSM = [(ksp, )]


class TestMRIUtils(unittest.TestCase):
    @parameterized.expand(TESTSM)
    def test_mask(self, test_data):
        # random mask
        masker = RandomKspaceMask(center_fractions=[0.08],
                                  accelerations=[4.0],
                                  spatial_dims=1,
                                  is_complex=True)
        masker.set_random_state(seed=0)
        result, _ = masker(test_data)
        mask = masker.mask
        result = result[..., mask.squeeze() == 0, :].sum()
Example #30
def write_nifti(
    data: NdarrayOrTensor,
    file_name: str,
    affine: Optional[NdarrayOrTensor] = None,
    target_affine: Optional[np.ndarray] = None,
    resample: bool = True,
    output_spatial_shape: Union[Sequence[int], np.ndarray, None] = None,
    mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
    padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
    align_corners: bool = False,
    dtype: DtypeLike = np.float64,
    output_dtype: DtypeLike = np.float32,
) -> None:
    """
    Write numpy data into NIfTI files to disk.  This function converts data
    into the coordinate system defined by `target_affine` when `target_affine`
    is specified.

    If the coordinate transform between `affine` and `target_affine` could be
    achieved by simply transposing and flipping `data`, no resampling will
    happen.  Otherwise, this function will resample `data` using the coordinate
    transform computed from `affine` and `target_affine`.  Note that the shape
    of the resampled `data` may be subject to some rounding errors. For example,
    resampling a 20x20-pixel image from pixel size (1.5, 1.5)-mm to (3.0,
    3.0)-mm space will return a 10x10-pixel image.  However, resampling a
    20x20-pixel image from pixel size (2.0, 2.0)-mm to (3.0, 3.0)-mm space
    will output a 14x14-pixel image, where the image shape is rounded from
    13.333x13.333 pixels. In this case `output_spatial_shape` could be specified so
    that this function writes image data to a designated shape.

    The saved `affine` matrix follows:
    - If `affine` equals to `target_affine`, save the data with `target_affine`.
    - If `resample=False`, transform `affine` to `new_affine` based on the orientation
    of `target_affine` and save the data with `new_affine`.
    - If `resample=True`, save the data with `target_affine`; if the `output_spatial_shape` is explicitly
    specified, the shape of the saved data is not computed from `target_affine`.
    - If `target_affine` is None, set `target_affine=affine` and save.
    - If `affine` and `target_affine` are None, the data will be saved with an identity
    matrix as the image affine.

    This function assumes the NIfTI dimension notations.
    Spatially it supports up to three dimensions, that is, H, HW, HWD for
    1D, 2D, 3D respectively.
    When saving multiple time steps or multiple channels `data`, time and/or
    modality axes should be appended after the first three dimensions.  For
    example, shape of 2D eight-class segmentation probabilities to be saved
    could be `(64, 64, 1, 8)`. Also, data in shape (64, 64, 8), (64, 64, 8, 1)
    will be considered as a single-channel 3D image.

    Args:
        data: input data to write to file.
        file_name: expected file name that saved on disk.
        affine: the current affine of `data`. Defaults to `np.eye(4)`
        target_affine: before saving
            the (`data`, `affine`) as a Nifti1Image,
            transform the data into the coordinates defined by `target_affine`.
        resample: whether to run resampling when the target affine
            could not be achieved by swapping/flipping data axes.
        output_spatial_shape: spatial shape of the output image.
            This option is used when resample = True.
        mode: {``"bilinear"``, ``"nearest"``}
            This option is used when ``resample = True``.
            Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
            See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
            This option is used when ``resample = True``.
            Padding mode for outside grid values. Defaults to ``"border"``.
            See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
        align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
            See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
        dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
            If None, use the data type of input data.
        output_dtype: data type for saving data. Defaults to ``np.float32``.
    """
    if isinstance(data, torch.Tensor):
        data, *_ = convert_data_type(data, np.ndarray)
    if isinstance(affine, torch.Tensor):
        affine, *_ = convert_data_type(affine, np.ndarray)
    if not isinstance(data, np.ndarray):
        raise AssertionError("input data must be numpy array or torch tensor.")
    dtype = dtype or data.dtype
    sr = min(data.ndim, 3)
    if affine is None:
        affine = np.eye(4, dtype=np.float64)
    affine = to_affine_nd(sr, affine)  # type: ignore

    if target_affine is None:
        target_affine = affine
    target_affine = to_affine_nd(sr, target_affine)

    if np.allclose(affine, target_affine, atol=1e-3):
        # no affine changes, save (data, affine)
        results_img = nib.Nifti1Image(data.astype(output_dtype),
                                      to_affine_nd(3, target_affine))
        nib.save(results_img, file_name)
        return

    # resolve orientation
    start_ornt = nib.orientations.io_orientation(affine)
    target_ornt = nib.orientations.io_orientation(target_affine)
    ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt)
    data_shape = data.shape
    data = nib.orientations.apply_orientation(data, ornt_transform)
    _affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform,
                                                     data_shape)
    if np.allclose(_affine, target_affine, atol=1e-3) or not resample:
        results_img = nib.Nifti1Image(data.astype(output_dtype),
                                      to_affine_nd(3, _affine))  # type: ignore
        nib.save(results_img, file_name)
        return

    # need resampling
    affine_xform = AffineTransform(normalized=False,
                                   mode=mode,
                                   padding_mode=padding_mode,
                                   align_corners=align_corners,
                                   reverse_indexing=True)
    transform = np.linalg.inv(_affine) @ target_affine
    if output_spatial_shape is None:
        output_spatial_shape, _ = compute_shape_offset(data.shape, _affine,
                                                       target_affine)
    output_spatial_shape_ = list(
        output_spatial_shape) if output_spatial_shape is not None else []
    if data.ndim > 3:  # multi channel, resampling each channel
        while len(output_spatial_shape_) < 3:
            output_spatial_shape_ = output_spatial_shape_ + [1]
        spatial_shape, channel_shape = data.shape[:3], data.shape[3:]
        data_np: np.ndarray = data.reshape(list(spatial_shape) +
                                           [-1])  # type: ignore
        data_np = np.moveaxis(data_np, -1, 0)  # channel first for pytorch
        data_torch = affine_xform(
            torch.as_tensor(
                np.ascontiguousarray(data_np).astype(dtype)).unsqueeze(0),
            torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)),
            spatial_size=output_spatial_shape_[:3],
        )
        data_np = data_torch.squeeze(0).detach().cpu().numpy()
        data_np = np.moveaxis(data_np, 0, -1)  # channel last for nifti
        data_np = data_np.reshape(
            list(data_np.shape[:3]) + list(channel_shape))
    else:  # single channel image, need to expand to have batch and channel
        while len(output_spatial_shape_) < len(data.shape):
            output_spatial_shape_ = output_spatial_shape_ + [1]
        data_torch = affine_xform(
            torch.as_tensor(
                np.ascontiguousarray(data).astype(dtype)[None, None]),
            torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)),
            spatial_size=output_spatial_shape_[:len(data.shape)],
        )
        data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy()

    results_img = nib.Nifti1Image(data_np.astype(output_dtype),
                                  to_affine_nd(3, target_affine))
    nib.save(results_img, file_name)
    return
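A minimal write_nifti sketch (the output path and spacing are illustrative, not part of the original example): a 3D volume with 2 mm spacing is resampled into the identity target space before saving.

import numpy as np

data = np.random.rand(16, 16, 8).astype(np.float32)
affine = np.diag([2.0, 2.0, 2.0, 1.0])  # 2 mm isotropic spacing
write_nifti(data, "out.nii.gz", affine=affine, target_affine=np.eye(4))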