Example 1
    def test_bounding_boxes_dim_inferring(self, device, dtype):
        boxes = torch.tensor(
            [
                [[0, 1, 2], [10, 1, 2], [10, 21, 2], [0, 21, 2], [0, 1, 32], [10, 1, 32], [10, 21, 32], [0, 21, 32]],
                [[3, 4, 5], [43, 4, 5], [43, 54, 5], [3, 54, 5], [3, 4, 65], [43, 4, 65], [43, 54, 65], [3, 54, 65]],
            ],
            device=device,
            dtype=dtype,
        )  # 2x8x3
        d, h, w = infer_bbox_shape3d(boxes)

        assert_allclose(d, torch.tensor([31.0, 61.0], device=device, dtype=dtype))
        assert_allclose(h, torch.tensor([21.0, 51.0], device=device, dtype=dtype))
        assert_allclose(w, torch.tensor([11.0, 41.0], device=device, dtype=dtype))
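
The expected values pin down the convention used by infer_bbox_shape3d: the extent along each axis is an inclusive voxel count, max - min + 1 (e.g. z running from 2 to 32 gives depth 31). Below is a minimal sketch of that computation, as a hypothetical re-implementation rather than the library's actual code:

import torch

def infer_shape3d_sketch(boxes: torch.Tensor):
    """Infer (depth, height, width) of (B, 8, 3) boxes with (x, y, z) coordinates."""
    mins = boxes.amin(dim=1)  # per-box minimum corner, shape (B, 3)
    maxs = boxes.amax(dim=1)  # per-box maximum corner, shape (B, 3)
    d = maxs[:, 2] - mins[:, 2] + 1  # inclusive extent along z
    h = maxs[:, 1] - mins[:, 1] + 1  # inclusive extent along y
    w = maxs[:, 0] - mins[:, 0] + 1  # inclusive extent along x
    return d, h, w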
Example 2
def crop_by_boxes3d(
    tensor: torch.Tensor,
    src_box: torch.Tensor,
    dst_box: torch.Tensor,
    interpolation: str = 'bilinear',
    align_corners: bool = False,
) -> torch.Tensor:
    """Perform crop transform on 3D volumes (5D tensor) by bounding boxes.

    Given an input tensor, this function selected the interested areas by the provided bounding boxes (src_box).
    Then the selected areas would be fitted into the targeted bounding boxes (dst_box) by a perspective transformation.
    So far, the ragged tensor is not supported by PyTorch right now. This function hereby requires the bounding boxes
    in a batch must be rectangles with same width, height and depth.

    Args:
        tensor: the 3D volume tensor with shape (B, C, D, H, W).
        src_box: a tensor with shape (B, 8, 3) containing the coordinates of the bounding boxes
            to be extracted. Each box is defined in the clockwise order: front-top-left,
            front-top-right, front-bottom-right, front-bottom-left, back-top-left, back-top-right,
            back-bottom-right, back-bottom-left. The coordinates must be in x, y, z order.
        dst_box: a tensor with shape (B, 8, 3) containing the coordinates of the bounding boxes
            into which the extracted patches are placed, defined in the same clockwise corner
            order and x, y, z coordinate order as src_box.
        interpolation: interpolation mode used for resampling, e.g. 'bilinear' or 'nearest'.
        align_corners: mode for grid generation.

    Returns:
        the output tensor containing the cropped patches.

    Examples:
        >>> input = torch.tensor([[[
        ...         [[ 0.,  1.,  2.,  3.],
        ...          [ 4.,  5.,  6.,  7.],
        ...          [ 8.,  9., 10., 11.],
        ...          [12., 13., 14., 15.]],
        ...         [[16., 17., 18., 19.],
        ...          [20., 21., 22., 23.],
        ...          [24., 25., 26., 27.],
        ...          [28., 29., 30., 31.]],
        ...         [[32., 33., 34., 35.],
        ...          [36., 37., 38., 39.],
        ...          [40., 41., 42., 43.],
        ...          [44., 45., 46., 47.]]]]])
        >>> src_box = torch.tensor([[
        ...     [1., 1., 1.],
        ...     [3., 1., 1.],
        ...     [3., 3., 1.],
        ...     [1., 3., 1.],
        ...     [1., 1., 2.],
        ...     [3., 1., 2.],
        ...     [3., 3., 2.],
        ...     [1., 3., 2.],
        ... ]])  # 1x8x3
        >>> dst_box = torch.tensor([[
        ...     [0., 0., 0.],
        ...     [2., 0., 0.],
        ...     [2., 2., 0.],
        ...     [0., 2., 0.],
        ...     [0., 0., 1.],
        ...     [2., 0., 1.],
        ...     [2., 2., 1.],
        ...     [0., 2., 1.],
        ... ]])  # 1x8x3
        >>> crop_by_boxes3d(input, src_box, dst_box, interpolation='nearest', align_corners=True)
        tensor([[[[[21., 22., 23.],
                   [25., 26., 27.],
                   [29., 30., 31.]],
        <BLANKLINE>
                  [[37., 38., 39.],
                   [41., 42., 43.],
                   [45., 46., 47.]]]]])
    """
    validate_bbox3d(src_box)
    validate_bbox3d(dst_box)

    if tensor.dim() != 5:
        raise AssertionError(
            f"Only tensors with shape (B, C, D, H, W) are supported. Got {tensor.shape}."
        )

    # compute the transformation between the boxes and warp.
    # Note: the boxes must be floating point, since "solve_cpu" is not implemented for 'Long'.
    dst_trans_src: torch.Tensor = get_perspective_transform3d(src_box.to(tensor.dtype), dst_box.to(tensor.dtype))
    # simulate broadcasting over the batch dimension
    dst_trans_src = dst_trans_src.expand(tensor.shape[0], -1, -1).type_as(tensor)

    bbox = infer_bbox_shape3d(dst_box)
    if not ((bbox[0] == bbox[0][0]).all() and (bbox[1] == bbox[1][0]).all() and (bbox[2] == bbox[2][0]).all()):
        raise AssertionError(
            "Cropping depth, height and width must be exactly the same in a batch. "
            f"Got depth {bbox[0]}, height {bbox[1]} and width {bbox[2]}."
        )

    patches: torch.Tensor = crop_by_transform_mat3d(
        tensor,
        dst_trans_src,
        (int(bbox[0][0].item()), int(bbox[1][0].item()), int(bbox[2][0].item())),  # (depth, height, width)
        mode=interpolation,
        align_corners=align_corners,
    )

    return patches
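
Since dst_box determines the output patch size via the same inclusive max - min + 1 convention, a common pattern is to build it from a desired (depth, height, width). The helper below is a hypothetical convenience, not part of the code above:

import torch

def make_dst_box3d(depth: int, height: int, width: int) -> torch.Tensor:
    """Build a (1, 8, 3) destination box covering [0, w-1] x [0, h-1] x [0, d-1]."""
    x, y, z = float(width - 1), float(height - 1), float(depth - 1)
    return torch.tensor([
        [0.0, 0.0, 0.0], [x, 0.0, 0.0], [x, y, 0.0], [0.0, y, 0.0],  # front face, clockwise
        [0.0, 0.0, z], [x, 0.0, z], [x, y, z], [0.0, y, z],  # back face, clockwise
    ]).unsqueeze(0)

For the doctest above, make_dst_box3d(2, 3, 3) reproduces dst_box, yielding a patch of shape (1, 1, 2, 3, 3).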