Example #1
    def test_rotate_y(self, device, dtype):
        input = torch.tensor(
            [[[
                [[0.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 0.0]],
                [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
                [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
            ]]],
            device=device,
            dtype=dtype,
        )

        expected = torch.tensor(
            [[[
                [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                [[0.0, 0.0, 0.0], [2.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
                [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
            ]]],
            device=device,
            dtype=dtype,
        )

        _, _, D, H, W = input.shape
        center = torch.tensor([[(W - 1) / 2, (H - 1) / 2, (D - 1) / 2]],
                              device=device,
                              dtype=dtype)

        angles = torch.tensor([[0.0, 90.0, 0.0]], device=device, dtype=dtype)

        scales: torch.Tensor = torch.ones_like(angles,
                                               device=device,
                                               dtype=dtype)
        P = proj.get_projective_transform(center, angles, scales)
        output = proj.warp_affine3d(input, P, (3, 3, 3))
        assert_close(output, expected, rtol=1e-4, atol=1e-4)
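
For reference outside the test harness, here is a minimal self-contained sketch of the same call pattern. It assumes that "proj" resolves to kornia.geometry.transform, where get_projective_transform and warp_affine3d live in the kornia versions these snippets are drawn from.

import torch
import kornia.geometry.transform as proj

# 1x1x3x3x3 volume: a 2 in the front slice, a 1 in the middle slice.
volume = torch.zeros(1, 1, 3, 3, 3)
volume[0, 0, 0, 1, 1] = 2.0
volume[0, 0, 1, 1, 1] = 1.0

center = torch.tensor([[1.0, 1.0, 1.0]])   # (x, y, z) rotation center
angles = torch.tensor([[0.0, 90.0, 0.0]])  # degrees about the (x, y, z) axes
scales = torch.ones_like(angles)

P = proj.get_projective_transform(center, angles, scales)  # (1, 3, 4)
out = proj.warp_affine3d(volume, P, (3, 3, 3))             # (1, 1, 3, 3, 3)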
Example #2
def crop_by_transform_mat3d(
    tensor: torch.Tensor,
    transform: torch.Tensor,
    out_size: Tuple[int, int, int],
    mode: str = 'bilinear',
    padding_mode: str = 'zeros',
    align_corners: Optional[bool] = None,
) -> torch.Tensor:
    """Perform crop transform on 3D volumes (5D tensor) given a perspective transformation matrix.

    Args:
        tensor: the 3D volume tensor with shape (B, C, D, H, W).
        transform: a perspective transformation matrix with shape (B, 4, 4).
        out_size: size of the output volume (depth, height, width).
        mode: interpolation mode to calculate output values
          ``'bilinear'`` | ``'nearest'``.
        padding_mode: padding mode for outside grid values
          ``'zeros'`` | ``'border'`` | ``'reflection'``.
        align_corners: mode for grid_generation.

    Returns:
        the output tensor with patches.
    """
    # simulate broadcasting
    dst_trans_src = transform.expand(tensor.shape[0], -1, -1)

    patches: torch.Tensor = warp_affine3d(tensor,
                                          dst_trans_src[:, :3, :],
                                          out_size,
                                          flags=mode,
                                          padding_mode=padding_mode,
                                          align_corners=align_corners)

    return patches
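
A minimal usage sketch, assuming crop_by_transform_mat3d is importable from kornia.geometry.transform (its home in the versions these snippets come from). With an identity 4x4 transform, the call simply resamples the volume into the requested output size.

import torch
from kornia.geometry.transform import crop_by_transform_mat3d

volume = torch.arange(27.0).view(1, 1, 3, 3, 3)
M = torch.eye(4)[None]  # (1, 4, 4) identity perspective transform
patch = crop_by_transform_mat3d(volume, M, (2, 2, 2))
print(patch.shape)  # torch.Size([1, 1, 2, 2, 2])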
Example #3
    def test_forth_back(self, device, dtype):
        out_shape = (3, 4, 5)
        input = torch.rand(2, 5, 3, 4, 5, device=device, dtype=dtype)
        P = torch.rand(2, 3, 4, device=device, dtype=dtype)
        P = kornia.geometry.convert_affinematrix_to_homography3d(P)
        P_hat = (P.inverse() @ P)[:, :3]
        output = proj.warp_affine3d(input, P_hat, out_shape, flags='nearest')
        assert_allclose(output, input, rtol=1e-4, atol=1e-4)
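
The round trip works because convert_affinematrix_to_homography3d pads the (B, 3, 4) affine matrix with a [0, 0, 0, 1] row so it becomes invertible; P.inverse() @ P is then the 4x4 identity, whose top three rows form an identity affine that warp_affine3d must map back to the input. A small sketch of just the padding step:

import torch
import kornia

P = torch.rand(2, 3, 4)
P_h = kornia.geometry.convert_affinematrix_to_homography3d(P)  # (2, 4, 4)
print(P_h[:, -1])  # each last row is [0., 0., 0., 1.]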
Example #4
def affine3d(
    tensor: torch.Tensor,
    matrix: torch.Tensor,
    mode: str = 'bilinear',
    padding_mode: str = 'zeros',
    align_corners: bool = False,
) -> torch.Tensor:
    r"""Apply an affine transformation to the 3d volume.

    Args:
        tensor (torch.Tensor): The image tensor to be warped in shapes of
            :math:`(D, H, W)`, :math:`(C, D, H, W)` and :math:`(B, C, D, H, W)`.
        matrix (torch.Tensor): The affine transformation matrix with shape :math:`(B, 3, 4)`.
        mode (str): interpolation mode to calculate output values
          'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (str): padding mode for outside grid values
          'zeros' | 'border' | 'reflection'. Default: 'zeros'.
        align_corners (bool, optional): interpolation flag. Default: False.

    Returns:
        torch.Tensor: The warped image.

    Example:
        >>> img = torch.rand(1, 2, 4, 3, 5)
        >>> aff = torch.eye(3, 4)[None]
        >>> out = affine3d(img, aff)
        >>> print(out.shape)
        torch.Size([1, 2, 4, 3, 5])
    """
    # warping needs data in the shape of BCDHW
    is_unbatched: bool = tensor.ndimension() == 4
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # we enforce broadcasting since grid_sample does not support it by default
    matrix = matrix.expand(tensor.shape[0], -1, -1)

    # warp the input tensor
    depth: int = tensor.shape[-3]
    height: int = tensor.shape[-2]
    width: int = tensor.shape[-1]
    warped: torch.Tensor = warp_affine3d(tensor, matrix,
                                         (depth, height, width), mode,
                                         padding_mode, align_corners)

    # return in the original shape
    if is_unbatched:
        warped = torch.squeeze(warped, dim=0)

    return warped
Example #5
    def test_rotate_y_large(self, device, dtype):
        """Rotates 90deg anti-clockwise."""
        input = torch.tensor(
            [[
                [
                    [[0.0, 4.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 2.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                ],
                [
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 9.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 6.0, 7.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 0.0]],
                ],
            ]],
            device=device,
            dtype=dtype,
        )

        expected = torch.tensor(
            [[
                [
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                    [[4.0, 2.0, 0.0], [3.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                ],
                [
                    [[0.0, 0.0, 0.0], [0.0, 7.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 6.0, 8.0], [9.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                ],
            ]],
            device=device,
            dtype=dtype,
        )

        _, _, D, H, W = input.shape
        center = torch.tensor([[(W - 1) / 2, (H - 1) / 2, (D - 1) / 2]],
                              device=device,
                              dtype=dtype)

        angles = torch.tensor([[0.0, 90.0, 0.0]], device=device, dtype=dtype)

        scales: torch.Tensor = torch.ones_like(angles,
                                               device=device,
                                               dtype=dtype)
        P = proj.get_projective_transform(center, angles, scales)
        output = proj.warp_affine3d(input, P, (3, 3, 3))
        assert_close(output, expected, rtol=1e-4, atol=1e-4)
Example #6
def apply_affine3d(input: torch.Tensor, params: Dict[str, torch.Tensor],
                   flags: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Random affine transformation of the image keeping center invariant.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (D, H, W), (C, D, H, W), (B, C, D, H, W).
        params (Dict[str, torch.Tensor]):
            - params['angles']: Degrees of rotation with the shape of :math:`(*, 3)` for yaw, pitch, roll.
            - params['translations']: Horizontal, vertical and depth-wise translations (dx, dy, dz).
            - params['center']: Rotation center (x, y, z).
            - params['scale']: Isotropic scaling params.
            - params['sxy']: Shear parameter along the x-y axis.
            - params['sxz']: Shear parameter along the x-z axis.
            - params['syx']: Shear parameter along the y-x axis.
            - params['syz']: Shear parameter along the y-z axis.
            - params['szx']: Shear parameter along the z-x axis.
            - params['szy']: Shear parameter along the z-y axis.
        flags (Dict[str, torch.Tensor]):
            - flags['resample']: Integer tensor. NEAREST = 0, BILINEAR = 1.
            - flags['align_corners']: Boolean tensor.

    Returns:
        torch.Tensor: The transformed input.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    input = _transform_input3d(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    # arrange input data
    x_data: torch.Tensor = input.view(-1, *input.shape[-4:])

    depth, height, width = x_data.shape[-3:]

    # concatenate transforms
    transform: torch.Tensor = compute_affine_transformation3d(input, params)

    resample_name: str = Resample(flags['resample'].item()).name.lower()
    align_corners: bool = cast(bool, flags['align_corners'].item())

    out_data: torch.Tensor = warp_affine3d(x_data,
                                           transform[:, :3, :],
                                           (depth, height, width),
                                           resample_name,
                                           align_corners=align_corners)
    return out_data.view_as(input)
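
The two flag reads above turn tensors back into the keyword arguments warp_affine3d expects. A standalone sketch of just that decoding step, assuming the Resample enum from kornia.constants (NEAREST = 0, BILINEAR = 1):

import torch
from kornia.constants import Resample

flags = {'resample': torch.tensor(1), 'align_corners': torch.tensor(True)}
mode = Resample(flags['resample'].item()).name.lower()  # -> 'bilinear'
align_corners = bool(flags['align_corners'].item())     # -> True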
Example #7
    def test_rotate_x(self, device, dtype):
        input = torch.tensor([[[[
            [0., 0., 0.],
            [0., 2., 0.],
            [0., 0., 0.],
        ], [
            [0., 0., 0.],
            [0., 1., 0.],
            [0., 0., 0.],
        ], [
            [0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.],
        ]]]],
                             device=device,
                             dtype=dtype)

        expected = torch.tensor([[[[
            [0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.],
        ], [
            [0., 0., 0.],
            [0., 1., 0.],
            [0., 2., 0.],
        ], [
            [0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.],
        ]]]],
                                device=device,
                                dtype=dtype)

        _, _, D, H, W = input.shape
        center = torch.tensor([[(W - 1) / 2, (H - 1) / 2, (D - 1) / 2]],
                              device=device,
                              dtype=dtype)

        angles = torch.tensor([[90., 0., 0.]], device=device, dtype=dtype)

        scales: torch.Tensor = torch.ones_like(angles,
                                               device=device,
                                               dtype=dtype)
        P = proj.get_projective_transform(center, angles, scales)
        output = proj.warp_affine3d(input, P, (3, 3, 3))
        assert_allclose(output, expected)
Example #8
def affine3d(tensor: torch.Tensor,
             matrix: torch.Tensor,
             mode: str = 'bilinear',
             align_corners: bool = False) -> torch.Tensor:
    r"""Apply an affine transformation to the 3d volume.

    Args:
        tensor (torch.Tensor): The image tensor to be warped in shapes of
            :math:`(D, H, W)`, :math:`(C, D, H, W)` and :math:`(B, C, D, H, W)`.
        matrix (torch.Tensor): The 3x4 affine transformation matrix.
        mode (str): 'bilinear' | 'nearest'
        align_corners (bool): interpolation flag. Default: False. See
            https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details.

    Returns:
        torch.Tensor: The warped image.
    """
    # warping needs data in the shape of BCDHW
    is_unbatched: bool = tensor.ndimension() == 4
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # we enforce broadcasting since grid_sample does not support it by default
    matrix = matrix.expand(tensor.shape[0], -1, -1)

    # warp the input tensor
    depth: int = tensor.shape[-3]
    height: int = tensor.shape[-2]
    width: int = tensor.shape[-1]
    warped: torch.Tensor = warp_affine3d(tensor,
                                         matrix, (depth, height, width),
                                         mode,
                                         align_corners=align_corners)

    # return in the original shape
    if is_unbatched:
        warped = torch.squeeze(warped, dim=0)

    return warped
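
As in the doctest of the variant in Example #4, an identity 3x4 matrix leaves the volume unchanged up to resampling. A minimal sketch that also exercises the unbatched (C, D, H, W) path handled by the unsqueeze/squeeze logic above; affine3d is assumed importable as defined here (it lived under kornia.geometry.transform in these versions):

import torch

img = torch.rand(2, 4, 3, 5)  # unbatched (C, D, H, W) input
aff = torch.eye(3, 4)[None]   # (1, 3, 4) identity affine
out = affine3d(img, aff)
print(out.shape)              # torch.Size([2, 4, 3, 5])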
Example #9
def crop_by_boxes3d(tensor: torch.Tensor,
                    src_box: torch.Tensor,
                    dst_box: torch.Tensor,
                    interpolation: str = 'bilinear',
                    align_corners: bool = False) -> torch.Tensor:
    """Perform crop transform on 3D volumes (5D tensor) by bounding boxes.

    Given an input tensor, this function selects the areas of interest specified by the provided bounding
    boxes (src_box). The selected areas are then fitted into the target bounding boxes (dst_box) by a
    perspective transformation. Since PyTorch does not currently support ragged tensors, this function
    requires that the bounding boxes in a batch be rectangles with the same width, height and depth.

    Args:
        tensor (torch.Tensor): the 3D volume tensor with shape (B, C, D, H, W).
        src_box (torch.Tensor): a tensor with shape (B, 8, 3) containing the coordinates of the bounding boxes
            to be extracted. The tensor must have the shape of Bx8x3, where each box is defined in the clockwise
            order: front-top-left, front-top-right, front-bottom-right, front-bottom-left, back-top-left,
            back-top-right, back-bottom-right, back-bottom-left. The coordinates must be in x, y, z order.
        dst_box (torch.Tensor): a tensor with shape (B, 8, 3) containing the coordinates of the bounding boxes
            to be placed. The tensor must have the shape of Bx8x3, where each box is defined in the clockwise
            order: front-top-left, front-top-right, front-bottom-right, front-bottom-left, back-top-left,
            back-top-right, back-bottom-right, back-bottom-left. The coordinates must be in x, y, z order.
        interpolation (str): Interpolation flag. Default: 'bilinear'.
        align_corners (bool): mode for grid_generation. Default: False. See
            https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details.

    Returns:
        torch.Tensor: the output tensor with patches.

    Examples:
        >>> input = torch.tensor([[[
        ...         [[ 0.,  1.,  2.,  3.],
        ...          [ 4.,  5.,  6.,  7.],
        ...          [ 8.,  9., 10., 11.],
        ...          [12., 13., 14., 15.]],
        ...         [[16., 17., 18., 19.],
        ...          [20., 21., 22., 23.],
        ...          [24., 25., 26., 27.],
        ...          [28., 29., 30., 31.]],
        ...         [[32., 33., 34., 35.],
        ...          [36., 37., 38., 39.],
        ...          [40., 41., 42., 43.],
        ...          [44., 45., 46., 47.]]]]])
        >>> src_box = torch.tensor([[
        ...     [1., 1., 1.],
        ...     [3., 1., 1.],
        ...     [3., 3., 1.],
        ...     [1., 3., 1.],
        ...     [1., 1., 2.],
        ...     [3., 1., 2.],
        ...     [3., 3., 2.],
        ...     [1., 3., 2.],
        ... ]])  # 1x8x3
        >>> dst_box = torch.tensor([[
        ...     [0., 0., 0.],
        ...     [2., 0., 0.],
        ...     [2., 2., 0.],
        ...     [0., 2., 0.],
        ...     [0., 0., 1.],
        ...     [2., 0., 1.],
        ...     [2., 2., 1.],
        ...     [0., 2., 1.],
        ... ]])  # 1x8x3
        >>> crop_by_boxes3d(input, src_box, dst_box, interpolation='nearest', align_corners=True)
        tensor([[[[[21., 22., 23.],
                   [25., 26., 27.],
                   [29., 30., 31.]],
        <BLANKLINE>
                  [[37., 38., 39.],
                   [41., 42., 43.],
                   [45., 46., 47.]]]]])

    """
    validate_bboxes3d(src_box)
    validate_bboxes3d(dst_box)

    assert len(
        tensor.shape
    ) == 5, f"Only tensor with shape (B, C, D, H, W) supported. Got {tensor.shape}."

    # compute transformation between points and warp
    # Note: Tensor.dtype must be float. "solve_cpu" not implemented for 'Long'
    dst_trans_src: torch.Tensor = get_perspective_transform3d(
        src_box.to(tensor.dtype), dst_box.to(tensor.dtype))
    # simulate broadcasting
    dst_trans_src = dst_trans_src.expand(tensor.shape[0], -1,
                                         -1).type_as(tensor)

    bbox = infer_box_shape3d(dst_box)
    assert (bbox[0] == bbox[0][0]).all() and (
        bbox[1] == bbox[1][0]).all() and (bbox[2] == bbox[2][0]).all(), (
            "Cropping depth, height and width must be exactly the same in a batch."
            f" Got depth {bbox[0]}, height {bbox[1]} and width {bbox[2]}.")
    patches: torch.Tensor = warp_affine3d(
        tensor,
        dst_trans_src[:, :3, :],
        # TODO: It will break the grads
        (int(bbox[0][0].item()), int(bbox[1][0].item()), int(bbox[2][0].item())
         ),
        flags=interpolation,
        align_corners=align_corners)

    return patches
Example #10
    def test_batch(self, batch_size, num_channels, out_shape, device, dtype):
        B, C = batch_size, num_channels
        input = torch.rand(B, C, 3, 4, 5, device=device, dtype=dtype)
        P = torch.rand(B, 3, 4, device=device, dtype=dtype)
        output = proj.warp_affine3d(input, P, out_shape)
        assert list(output.shape) == [B, C] + list(out_shape)
Example #11
    def test_smoke(self, device, dtype):
        input = torch.rand(1, 3, 3, 4, 5, device=device, dtype=dtype)
        P = torch.rand(1, 3, 4, device=device, dtype=dtype)
        output = proj.warp_affine3d(input, P, (3, 4, 5))
        assert output.shape == (1, 3, 3, 4, 5)
Example #12
    def test_rotate_y_large(self, device, dtype):
        """Rotates 90deg anti-clockwise."""
        input = torch.tensor(
            [[
                [
                    [[0., 4., 0.], [0., 3., 0.], [0., 0., 0.]],
                    [[0., 2., 0.], [0., 1., 0.], [0., 0., 0.]],
                    [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                ],
                [
                    [[0., 0., 0.], [0., 0., 0.], [0., 9., 0.]],
                    [[0., 0., 0.], [0., 6., 7.], [0., 0., 0.]],
                    [[0., 0., 0.], [0., 8., 0.], [0., 0., 0.]],
                ],
            ]],
            device=device,
            dtype=dtype)

        expected = torch.tensor(
            [[
                [
                    [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                    [[4., 2., 0.], [3., 1., 0.], [0., 0., 0.]],
                    [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                ],
                [
                    [[0., 0., 0.], [0., 7., 0.], [0., 0., 0.]],
                    [[0., 0., 0.], [0., 6., 8.], [9., 0., 0.]],
                    [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                ],
            ]],
            device=device,
            dtype=dtype)

        _, _, D, H, W = input.shape
        center = torch.tensor([[(W - 1) / 2, (H - 1) / 2, (D - 1) / 2]],
                              device=device,
                              dtype=dtype)

        angles = torch.tensor([[0., 90., 0.]], device=device, dtype=dtype)

        scales: torch.Tensor = torch.ones_like(angles,
                                               device=device,
                                               dtype=dtype)
        P = proj.get_projective_transform(center, angles, scales)
        output = proj.warp_affine3d(input, P, (3, 3, 3))
        assert_allclose(output, expected)