Beispiel #1
0
    def test_smoke(self, device, dtype):
        """Compare ``safe_solve_with_mask`` against ``_torch_solve_cast`` on random systems.

        Only the entries selected by the mask returned from ``safe_solve_with_mask``
        are compared; float16 uses a looser tolerance than other dtypes.
        """
        coeffs = torch.randn(2, 3, 1, 4, 4, device=device, dtype=dtype)
        rhs = torch.randn(2, 3, 1, 4, 6, device=device, dtype=dtype)

        solution, _, selected = safe_solve_with_mask(rhs, coeffs)
        reference, _ = _torch_solve_cast(rhs, coeffs)

        half_precision = dtype == torch.float16
        tol_val: float = 1e-1 if half_precision else 1e-4

        # Nothing to compare when the mask selects no entries.
        if not mask_is_empty(selected) if False else selected.sum() > 0:
            assert_close(solution[selected], reference[selected], atol=tol_val, rtol=tol_val)
Beispiel #2
0
    def test_smoke(self, device, dtype):
        """Check that ``_torch_solve_cast`` returns X whose product A @ X reproduces B.

        The residual is the distance between B and A @ X; float16 uses a looser
        tolerance than other dtypes.
        """
        lhs = torch.randn(2, 3, 1, 4, 4, device=device, dtype=dtype)
        rhs = torch.randn(2, 3, 1, 4, 6, device=device, dtype=dtype)

        solution, _ = _torch_solve_cast(rhs, lhs)
        reconstructed = lhs.matmul(solution)
        residual = torch.dist(rhs, reconstructed)

        tol_val: float = 1e-4
        if dtype == torch.float16:
            tol_val = 1e-1
        expected = torch.zeros_like(residual)
        assert_close(residual, expected, atol=tol_val, rtol=tol_val)
def get_tps_transform(points_src: torch.Tensor, points_dst: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute the TPS transform parameters that warp source points to target points.

    The input to this function is a tensor of :math:`(x, y)` source points :math:`(B, N, 2)` and a corresponding
    tensor of target :math:`(x, y)` points :math:`(B, N, 2)`.

    Args:
        points_src: batch of source points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.
        points_dst: batch of target points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.

    Returns:
        :math:`(B, N, 2)` tensor of kernel weights and :math:`(B, 3, 2)`
            tensor of affine weights. The last dimension contains the x-transform and y-transform weights
            as separate columns.

    Raises:
        TypeError: if ``points_src`` or ``points_dst`` is not a ``torch.Tensor``.
        ValueError: if either input is not a :math:`(B, N, 2)` tensor, or if the two inputs do not have
            the same shape.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> kernel_weights, affine_weights = get_tps_transform(points_src, points_dst)

    .. note::
        This function is often used in conjunction with :func:`warp_points_tps`, :func:`warp_image_tps`.
    """
    if not isinstance(points_src, torch.Tensor):
        raise TypeError(f"Input points_src is not torch.Tensor. Got {type(points_src)}")

    if not isinstance(points_dst, torch.Tensor):
        raise TypeError(f"Input points_dst is not torch.Tensor. Got {type(points_dst)}")

    # The system below only assembles for 2-D points (P has the 3 columns 1, x, y)
    # and for source/target sets of identical size. Reject anything else up front
    # instead of failing inside torch.cat with an opaque size-mismatch error.
    if len(points_src.shape) != 3 or points_src.shape[-1] != 2:
        raise ValueError(f"Invalid shape for points_src, expected BxNx2. Got {points_src.shape}")

    if len(points_dst.shape) != 3 or points_dst.shape[-1] != 2:
        raise ValueError(f"Invalid shape for points_dst, expected BxNx2. Got {points_dst.shape}")

    if points_src.shape != points_dst.shape:
        raise ValueError(
            f"points_src and points_dst must have the same shape. Got {points_src.shape} and {points_dst.shape}"
        )

    device, dtype = points_src.device, points_src.dtype
    batch_size, num_points = points_src.shape[:2]

    # set up and solve linear system
    # [K   P] [w] = [dst]
    # [P^T 0] [a]   [ 0 ]
    # NOTE(review): K is built from src-to-dst distances (not src-to-src); whatever
    # evaluates this TPS must use the same kernel centres -- confirm against warp_points_tps.
    pair_distance: torch.Tensor = _pair_square_euclidean(points_src, points_dst)
    k_matrix: torch.Tensor = _kernel_distance(pair_distance)

    zero_mat: torch.Tensor = torch.zeros(batch_size, 3, 3, device=device, dtype=dtype)
    one_mat: torch.Tensor = torch.ones(batch_size, num_points, 1, device=device, dtype=dtype)
    # right-hand side [dst; 0] with shape (B, N + 3, 2)
    dest_with_zeros: torch.Tensor = torch.cat((points_dst, zero_mat[:, :, :2]), 1)
    # P = [1, x, y] for every source point, shape (B, N, 3)
    p_matrix: torch.Tensor = torch.cat((one_mat, points_src), -1)
    # bottom block row [P^T 0], shape (B, 3, N + 3)
    p_matrix_t: torch.Tensor = torch.cat((p_matrix, zero_mat), 1).transpose(1, 2)
    # top block row [K P], then both rows stacked into L
    # (assuming K is (B, N, N), L is (B, N + 3, N + 3))
    l_matrix: torch.Tensor = torch.cat((k_matrix, p_matrix), -1)
    l_matrix = torch.cat((l_matrix, p_matrix_t), 1)

    weights, _ = _torch_solve_cast(dest_with_zeros, l_matrix)
    # the first N rows are the kernel weights w, the last 3 rows the affine weights a
    kernel_weights: torch.Tensor = weights[:, :-3]
    affine_weights: torch.Tensor = weights[:, -3:]

    return (kernel_weights, affine_weights)
Beispiel #4
0
def get_perspective_transform3d(src: torch.Tensor,
                                dst: torch.Tensor) -> torch.Tensor:
    r"""Calculate a 3d perspective transform from five pairs of the corresponding points.

    ``src`` and ``dst`` hold eight points each, but only the points with indices 0, 1, 2, 5 and 7
    are used. Five correspondences give the 15 equations needed for the 15 unknown entries of the
    :math:`4 \times 4` matrix; its last entry is fixed to 1.

    The function calculates the matrix of a perspective transform so that:

    .. math ::

        \begin{bmatrix}
        t_{i}x_{i}^{'} \\
        t_{i}y_{i}^{'} \\
        t_{i}z_{i}^{'} \\
        t_{i} \\
        \end{bmatrix}
        =
        \textbf{map_matrix} \cdot
        \begin{bmatrix}
        x_{i} \\
        y_{i} \\
        z_{i} \\
        1 \\
        \end{bmatrix}

    where

    .. math ::
        dst(i) = (x_{i}^{'},y_{i}^{'},z_{i}^{'}), src(i) = (x_{i}, y_{i}, z_{i}), i = 0,1,2,5,7

    Concrete math is as below:

    .. math ::

        \[ u_i =\frac{c_{00} * x_i + c_{01} * y_i + c_{02} * z_i + c_{03}}
            {c_{30} * x_i + c_{31} * y_i + c_{32} * z_i + c_{33}} \]
        \[ v_i =\frac{c_{10} * x_i + c_{11} * y_i + c_{12} * z_i + c_{13}}
            {c_{30} * x_i + c_{31} * y_i + c_{32} * z_i + c_{33}} \]
        \[ w_i =\frac{c_{20} * x_i + c_{21} * y_i + c_{22} * z_i + c_{23}}
            {c_{30} * x_i + c_{31} * y_i + c_{32} * z_i + c_{33}} \]

    The coefficient matrix of the linear system is shown below with its rows grouped by output
    coordinate. The implementation stacks the same rows interleaved per point: the x, y and z rows
    of point 0 come first, then those of point 1, and so on.

    .. math ::

        \begin{pmatrix}
        x_0 & y_0 & z_0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_0*u_0 & -y_0*u_0 & -z_0 * u_0 \\
        x_1 & y_1 & z_1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_1*u_1 & -y_1*u_1 & -z_1 * u_1 \\
        x_2 & y_2 & z_2 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_2*u_2 & -y_2*u_2 & -z_2 * u_2 \\
        x_5 & y_5 & z_5 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_5*u_5 & -y_5*u_5 & -z_5 * u_5 \\
        x_7 & y_7 & z_7 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_7*u_7 & -y_7*u_7 & -z_7 * u_7 \\
        0 & 0 & 0 & 0 & x_0 & y_0 & z_0 & 1 & 0 & 0 & 0 & 0 & -x_0*v_0 & -y_0*v_0 & -z_0 * v_0 \\
        0 & 0 & 0 & 0 & x_1 & y_1 & z_1 & 1 & 0 & 0 & 0 & 0 & -x_1*v_1 & -y_1*v_1 & -z_1 * v_1 \\
        0 & 0 & 0 & 0 & x_2 & y_2 & z_2 & 1 & 0 & 0 & 0 & 0 & -x_2*v_2 & -y_2*v_2 & -z_2 * v_2 \\
        0 & 0 & 0 & 0 & x_5 & y_5 & z_5 & 1 & 0 & 0 & 0 & 0 & -x_5*v_5 & -y_5*v_5 & -z_5 * v_5 \\
        0 & 0 & 0 & 0 & x_7 & y_7 & z_7 & 1 & 0 & 0 & 0 & 0 & -x_7*v_7 & -y_7*v_7 & -z_7 * v_7 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_0 & y_0 & z_0 & 1 & -x_0*w_0 & -y_0*w_0 & -z_0 * w_0 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_1 & y_1 & z_1 & 1 & -x_1*w_1 & -y_1*w_1 & -z_1 * w_1 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_2 & y_2 & z_2 & 1 & -x_2*w_2 & -y_2*w_2 & -z_2 * w_2 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_5 & y_5 & z_5 & 1 & -x_5*w_5 & -y_5*w_5 & -z_5 * w_5 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_7 & y_7 & z_7 & 1 & -x_7*w_7 & -y_7*w_7 & -z_7 * w_7 \\
        \end{pmatrix}

    Args:
        src: coordinates of quadrangle vertices in the source image with shape :math:`(B, 8, 3)`.
        dst: coordinates of the corresponding quadrangle vertices in
            the destination image with shape :math:`(B, 8, 3)`.

    Returns:
        the perspective transformation with shape :math:`(B, 4, 4)`.

    Raises:
        TypeError: if ``src`` or ``dst`` is not a ``torch.Tensor``.
        ValueError: if ``src`` is not a :math:`(B, 8, 3)` tensor or ``dst`` has a different shape.
        AssertionError: if ``src`` and ``dst`` differ in device or dtype.

    .. note::
        This function is often used in conjunction with :func:`warp_perspective3d`.
    """
    if not isinstance(src, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(src)}")

    if not isinstance(dst, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(dst)}")

    # Require exactly (B, 8, 3): a higher-rank input would pass a shape[-2:] check
    # alone, but ``src[:, i]`` below would then index the wrong dimension.
    if src.dim() != 3 or src.shape[-2:] != (8, 3):
        raise ValueError(f"Inputs must be a Bx8x3 tensor. Got {src.shape}")

    # Identical shapes also guarantee identical batch sizes.
    if src.shape != dst.shape:
        raise ValueError(f"Inputs must have the same shape. Got {dst.shape}")

    if not (src.device == dst.device and src.dtype == dst.dtype):
        raise AssertionError(
            f"Expect `src` and `dst` to be in the same device (Got {src.device}, {dst.device}) "
            f"with the same dtype (Got {src.dtype}, {dst.dtype}).")

    # Five of the eight correspondences are used (labelled 000, 100, 110, 101, 011).
    # A single index list drives both A and b, so their row orders cannot drift apart.
    indices = [0, 1, 2, 5, 7]

    # Three equations (x, y, z) per correspondence, interleaved per point.
    p = []
    for i in indices:
        for axis in ('x', 'y', 'z'):
            p.append(_build_perspective_param3d(src[:, i], dst[:, i], axis))

    # A is Bx15x15
    A = torch.stack(p, dim=1)

    # b is Bx15x1 with rows (x, y, z) of each selected point, matching the row order of A
    batch_size = src.shape[0]
    b = dst[:, indices].reshape(batch_size, 3 * len(indices), 1)

    # solve the square system Ax = b
    X, _ = _torch_solve_cast(b, A)

    # the 15 solved entries fill the matrix row-major; the last entry stays 1
    M = torch.ones(batch_size, 16, device=src.device, dtype=src.dtype)
    M[..., :15] = torch.squeeze(X, dim=-1)
    return M.view(-1, 4, 4)  # Bx4x4
Beispiel #5
0
def get_perspective_transform(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    r"""Calculate a perspective transform from four pairs of the corresponding points.

    The function calculates the matrix of a perspective transform so that:

    .. math ::

        \begin{bmatrix}
        t_{i}x_{i}^{'} \\
        t_{i}y_{i}^{'} \\
        t_{i} \\
        \end{bmatrix}
        =
        \textbf{map_matrix} \cdot
        \begin{bmatrix}
        x_{i} \\
        y_{i} \\
        1 \\
        \end{bmatrix}

    where

    .. math ::
        dst(i) = (x_{i}^{'},y_{i}^{'}), src(i) = (x_{i}, y_{i}), i = 0,1,2,3

    Args:
        src: coordinates of quadrangle vertices in the source image with shape :math:`(B, 4, 2)`.
        dst: coordinates of the corresponding quadrangle vertices in
            the destination image with shape :math:`(B, 4, 2)`.

    Returns:
        the perspective transformation with shape :math:`(B, 3, 3)`.

    Raises:
        TypeError: if ``src`` or ``dst`` is not a ``torch.Tensor``.
        ValueError: if ``src`` is not a :math:`(B, 4, 2)` tensor or ``dst`` has a different shape.

    .. note::
        This function is often used in conjunction with :func:`warp_perspective`.
    """
    if not isinstance(src, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(src)}")

    if not isinstance(dst, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(dst)}")

    # Require exactly (B, 4, 2): a higher-rank input would pass a shape[-2:] check
    # alone, but ``src[:, i]`` below would then index the wrong dimension.
    if src.dim() != 3 or src.shape[-2:] != (4, 2):
        raise ValueError(f"Inputs must be a Bx4x2 tensor. Got {src.shape}")

    # Identical shapes also guarantee identical batch sizes.
    if src.shape != dst.shape:
        raise ValueError(f"Inputs must have the same shape. Got {dst.shape}")

    # Two equations (x and y) per correspondence; the four correspondences give the
    # eight equations for the eight unknown entries of the 3x3 matrix (the last
    # entry is fixed to 1 below).
    p = []
    for i in range(4):
        p.append(_build_perspective_param(src[:, i], dst[:, i], 'x'))
        p.append(_build_perspective_param(src[:, i], dst[:, i], 'y'))

    # A is Bx8x8
    A = torch.stack(p, dim=1)

    # b is a Bx8x1 with rows (x0, y0, x1, y1, x2, y2, x3, y3), matching the row order of A
    b = torch.stack(
        [
            dst[:, 0:1, 0],
            dst[:, 0:1, 1],
            dst[:, 1:2, 0],
            dst[:, 1:2, 1],
            dst[:, 2:3, 0],
            dst[:, 2:3, 1],
            dst[:, 3:4, 0],
            dst[:, 3:4, 1],
        ],
        dim=1,
    )

    # solve the square system Ax = b
    X, _ = _torch_solve_cast(b, A)

    # the 8 solved entries fill the matrix row-major; the last entry stays 1
    batch_size = src.shape[0]
    M = torch.ones(batch_size, 9, device=src.device, dtype=src.dtype)
    M[..., :8] = torch.squeeze(X, dim=-1)

    return M.view(-1, 3, 3)  # Bx3x3
Beispiel #6
0
def conv_quad_interp3d(input: torch.Tensor,
                       strict_maxima_bonus: float = 10.0,
                       eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute the single iteration of quadratic interpolation of the extremum (max or min).

    For each element the code solves ``Hes @ x = b`` and takes the offset ``dx = -x``. Here ``b``
    holds the first-order and ``Hes`` the second-order finite-difference derivatives of ``input``.
    The system is solved only at positions selected by a 3x3x3 non-maximum suppression; every other
    position keeps a zero offset. Offsets with any component larger than 0.7 in magnitude are reset
    to zero. The value at the refined location is estimated as ``input + 0.5 * b^T dx``, similar to
    what is done in SIFT.

    Args:
        input: the given heatmap with shape :math:`(N, C, D, H, W)`.
        strict_maxima_bonus: if positive, this amount is *added* to the interpolated value at every
          position selected by the 3x3x3 NMS mask, so the score becomes ``value + strict_maxima_bonus``.
          This is needed to mimic the behavior of strict NMS in classic local features.
        eps: scale of the small random non-negative matrix added to every Hessian to keep the linear
          systems from being singular.

    Returns:
        A tuple ``(coords_max, y_max)``:

        - ``coords_max``: :math:`(N, C, 3, D, H, W)` refined coordinates. These are the grid
          coordinates from :func:`create_meshgrid3d` (called with ``False`` as its fourth argument,
          presumably un-normalized) plus the offsets ``dx``. The offset components are reversed
          (``dx.flip(1)``), presumably to match the grid's component order.
        - ``y_max``: :math:`(N, C, D, H, W)` interpolated values, including the strict maxima bonus.

        Both outputs keep the spatial size :math:`(D, H, W)` of the input.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: if ``input`` is not 5-dimensional.

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_quad_interp3d(input, 1.0)

    .. note::
        The regularization noise is drawn with :func:`torch.rand` from the global RNG, so results
        depend on the RNG state unless the seed is fixed.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not len(input.shape) == 5:
        raise ValueError(
            f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")

    B, CH, D, H, W = input.shape
    # Coordinate grid with the coordinate components moved to dim 1.
    # Presumably (1, 3, D, H, W) after the permute; create_meshgrid3d is defined elsewhere.
    grid_global: torch.Tensor = create_meshgrid3d(D,
                                                  H,
                                                  W,
                                                  False,
                                                  device=input.device).permute(
                                                      0, 4, 1, 2, 3)
    grid_global = grid_global.to(input.dtype)

    # to determine the location we are solving system of linear equations Ax = b, where b is 1st order gradient
    # and A is Hessian matrix
    b: torch.Tensor = kornia.filters.spatial_gradient3d(input,
                                                        order=1,
                                                        mode='diff')  # 6-D; derivative components on dim 2
    # move the derivative components last and flatten: one (3, 1) column per element
    b = b.permute(0, 1, 3, 4, 5, 2).reshape(-1, 3, 1)
    A: torch.Tensor = kornia.filters.spatial_gradient3d(input,
                                                        order=2,
                                                        mode='diff')
    # six second-order components per element: xx, yy, ss, xy, ys, xs (in this order)
    A = A.permute(0, 1, 3, 4, 5, 2).reshape(-1, 6)
    dxx = A[..., 0]
    dyy = A[..., 1]
    dss = A[..., 2]
    dxy = 0.25 * A[..., 3]  # normalization to match OpenCV implementation
    dys = 0.25 * A[..., 4]  # normalization to match OpenCV implementation
    dxs = 0.25 * A[..., 5]  # normalization to match OpenCV implementation

    # symmetric 3x3 Hessian per element
    Hes = torch.stack([dxx, dxy, dxs, dxy, dyy, dys, dxs, dys, dss],
                      dim=-1).view(-1, 3, 3)

    # The following is needed to avoid singular cases.
    # One random 3x3 matrix with entries in [0, eps) is shared by every element.
    # NOTE(review): the noise is drawn in torch's default dtype rather than Hes.dtype,
    # and because it is not symmetric, Hes is no longer exactly symmetric.
    Hes += torch.rand(Hes[0].size(), device=Hes.device).abs()[None] * eps

    # Solve only where the 3x3x3 NMS keeps a position; all other positions keep x = 0.
    nms_mask: torch.Tensor = kornia.feature.nms3d(input, (3, 3, 3), True)
    x_solved: torch.Tensor = torch.zeros_like(b)
    x_solved_masked, _ = _torch_solve_cast(b[nms_mask.view(-1)],
                                           Hes[nms_mask.view(-1)])
    x_solved.masked_scatter_(nms_mask.view(-1, 1, 1), x_solved_masked)
    dx: torch.Tensor = -x_solved

    # Ignore ones, which are far from window center:
    # zero the whole offset if any component exceeds 0.7 in magnitude.
    mask1 = dx.abs().max(dim=1, keepdim=True)[0] > 0.7
    dx.masked_fill_(mask1.expand_as(dx), 0)
    # change of the quadratic model's value at the offset: 0.5 * b^T dx
    dy: torch.Tensor = 0.5 * torch.bmm(b.permute(0, 2, 1), dx)
    y_max = input + dy.view(B, CH, D, H, W)
    if strict_maxima_bonus > 0:
        # additive bonus at NMS-selected positions (mask cast to 0/1)
        y_max += strict_maxima_bonus * nms_mask.to(input.dtype)

    # offsets reshaped to (B, CH, 3, D, H, W) with their component order reversed
    dx_res: torch.Tensor = dx.flip(1).reshape(B, CH, D, H, W,
                                              3).permute(0, 1, 5, 2, 3, 4)
    # grid broadcast over channels: (B, 1, 3, D, H, W) + (B, CH, 3, D, H, W)
    coords_max: torch.Tensor = grid_global.repeat(B, 1, 1, 1, 1).unsqueeze(1)
    coords_max = coords_max + dx_res

    return coords_max, y_max