def test_smoke(self, device, dtype):
    """Masked solve agrees with the plain cast solver on the valid entries."""
    lhs = torch.randn(2, 3, 1, 4, 4, device=device, dtype=dtype)
    rhs = torch.randn(2, 3, 1, 4, 6, device=device, dtype=dtype)

    sol_masked, _, valid = safe_solve_with_mask(rhs, lhs)
    sol_plain, _ = _torch_solve_cast(rhs, lhs)

    # half precision needs a much looser tolerance than fp32/fp64
    tol_val: float = 1e-1 if dtype == torch.float16 else 1e-4

    # only compare entries the masked solver reported as solvable
    if valid.sum() > 0:
        assert_close(sol_masked[valid], sol_plain[valid], atol=tol_val, rtol=tol_val)
def test_smoke(self, device, dtype):
    """Solving A X = B leaves a residual ||B - A X|| close to zero."""
    mat = torch.randn(2, 3, 1, 4, 4, device=device, dtype=dtype)
    rhs = torch.randn(2, 3, 1, 4, 6, device=device, dtype=dtype)

    sol, _ = _torch_solve_cast(rhs, mat)
    residual = torch.dist(rhs, mat.matmul(sol))

    # half precision needs a much looser tolerance
    tol_val: float = 1e-1 if dtype == torch.float16 else 1e-4
    assert_close(residual, torch.zeros_like(residual), atol=tol_val, rtol=tol_val)
def get_tps_transform(points_src: torch.Tensor, points_dst: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute the TPS transform parameters that warp source points to target points.

    The input to this function is a tensor of :math:`(x, y)` source points
    :math:`(B, N, 2)` and a corresponding tensor of target :math:`(x, y)` points
    :math:`(B, N, 2)`.

    Args:
        points_src: batch of source points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.
        points_dst: batch of target points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.

    Returns:
        :math:`(B, N, 2)` tensor of kernel weights and :math:`(B, 3, 2)` tensor of affine weights.
        The last dimension contains the x-transform and y-transform weights as separate columns.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> kernel_weights, affine_weights = get_tps_transform(points_src, points_dst)

    .. note::
        This function is often used in conjunction with :func:`warp_points_tps`, :func:`warp_image_tps`.
    """
    if not isinstance(points_src, torch.Tensor):
        raise TypeError(f"Input points_src is not torch.Tensor. Got {type(points_src)}")

    if not isinstance(points_dst, torch.Tensor):
        raise TypeError(f"Input points_dst is not torch.Tensor. Got {type(points_dst)}")

    if not len(points_src.shape) == 3:
        raise ValueError(f"Invalid shape for points_src, expected BxNx2. Got {points_src.shape}")

    if not len(points_dst.shape) == 3:
        raise ValueError(f"Invalid shape for points_dst, expected BxNx2. Got {points_dst.shape}")

    device, dtype = points_src.device, points_src.dtype
    batch_size, num_points = points_src.shape[:2]

    # Assemble and solve the TPS linear system:
    #   [K    P] [w]   [dst]
    #   [P^T  0] [a] = [ 0 ]
    # where K is the kernelized pairwise distance and P = [1 | src].
    sq_dist: torch.Tensor = _pair_square_euclidean(points_src, points_dst)
    kernel: torch.Tensor = _kernel_distance(sq_dist)

    zeros_3x3: torch.Tensor = torch.zeros(batch_size, 3, 3, device=device, dtype=dtype)
    ones_col: torch.Tensor = torch.ones(batch_size, num_points, 1, device=device, dtype=dtype)

    # right-hand side: target points padded with three zero rows for the constraints
    rhs: torch.Tensor = torch.cat((points_dst, zeros_3x3[:, :, :2]), 1)

    p_mat: torch.Tensor = torch.cat((ones_col, points_src), -1)
    p_mat_t: torch.Tensor = torch.cat((p_mat, zeros_3x3), 1).transpose(1, 2)

    # block-assemble the full system matrix
    top_rows: torch.Tensor = torch.cat((kernel, p_mat), -1)
    system: torch.Tensor = torch.cat((top_rows, p_mat_t), 1)

    weights, _ = _torch_solve_cast(rhs, system)

    # first N rows are the kernel weights, last 3 rows the affine part
    return (weights[:, :-3], weights[:, -3:])
def get_perspective_transform3d(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    r"""Calculate a 3d perspective transform from five pairs of the corresponding points.

    The function calculates the matrix of a perspective transform so that:

    .. math ::
        \begin{bmatrix}
        t_{i}x_{i}^{'} \\
        t_{i}y_{i}^{'} \\
        t_{i}z_{i}^{'} \\
        t_{i} \\
        \end{bmatrix}
        =
        \textbf{map_matrix} \cdot
        \begin{bmatrix}
        x_{i} \\
        y_{i} \\
        z_{i} \\
        1 \\
        \end{bmatrix}

    where

    .. math ::
        dst(i) = (x_{i}^{'},y_{i}^{'},z_{i}^{'}), src(i) = (x_{i}, y_{i}, z_{i}), i = 0,1,2,5,7

    Each correspondence contributes three rows (for :math:`u_i`, :math:`v_i`, :math:`w_i`)
    to a :math:`15 \times 15` linear system whose solution gives the first 15 entries
    of the 4x4 map matrix; the last entry is fixed to 1.

    Args:
        src: coordinates of quadrangle vertices in the source image with shape :math:`(B, 8, 3)`.
        dst: coordinates of the corresponding quadrangle vertices in the destination image
            with shape :math:`(B, 8, 3)`.

    Returns:
        the perspective transformation with shape :math:`(B, 4, 4)`.

    Raises:
        TypeError: if ``src`` or ``dst`` is not a torch.Tensor.
        ValueError: if the input shapes are not :math:`(B, 8, 3)` or do not match.
        AssertionError: if ``src`` and ``dst`` differ in device or dtype.

    .. note::
        This function is often used in conjunction with :func:`warp_perspective3d`.
    """
    if not isinstance(src, (torch.Tensor)):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(src)}")

    if not isinstance(dst, (torch.Tensor)):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(dst)}")

    if not src.shape[-2:] == (8, 3):
        raise ValueError(f"Inputs must be a Bx8x3 tensor. Got {src.shape}")

    if not src.shape == dst.shape:
        raise ValueError(f"Inputs must have the same shape. Got {dst.shape}")

    if not (src.shape[0] == dst.shape[0]):
        raise ValueError(
            f"Inputs must have same batch size dimension. Expect {src.shape} but got {dst.shape}"
        )

    if not (src.device == dst.device and src.dtype == dst.dtype):
        # BUGFIX: the device part of this message previously printed the dtypes
        # (src.dtype, dst.dtype) instead of the devices.
        raise AssertionError(
            f"Expect `src` and `dst` to be in the same device (Got {src.device}, {dst.device}) "
            f"with the same dtype (Got {src.dtype}, {dst.dtype}).")

    # We build matrix A using 5 point correspondences — cube corners with
    # indices 0, 1, 2, 5, 7 (i.e. 000, 100, 110, 101, 011). The linear system
    # is solved with the least square method, so here we could even pass more
    # correspondences.
    p = []
    for i in [0, 1, 2, 5, 7]:
        p.append(_build_perspective_param3d(src[:, i], dst[:, i], 'x'))
        p.append(_build_perspective_param3d(src[:, i], dst[:, i], 'y'))
        p.append(_build_perspective_param3d(src[:, i], dst[:, i], 'z'))

    # A is Bx15x15
    A = torch.stack(p, dim=1)

    # b is Bx15x1: destination coordinates of the same 5 points, in the same
    # (x, y, z) row order as A. Points 3, 4 and 6 are intentionally skipped.
    b = torch.stack(
        [
            dst[:, 0:1, 0], dst[:, 0:1, 1], dst[:, 0:1, 2],
            dst[:, 1:2, 0], dst[:, 1:2, 1], dst[:, 1:2, 2],
            dst[:, 2:3, 0], dst[:, 2:3, 1], dst[:, 2:3, 2],
            dst[:, 5:6, 0], dst[:, 5:6, 1], dst[:, 5:6, 2],
            dst[:, 7:8, 0], dst[:, 7:8, 1], dst[:, 7:8, 2],
        ],
        dim=1,
    )

    # solve the system Ax = b
    X, _ = _torch_solve_cast(b, A)

    # the 16th (bottom-right) element is fixed to 1; fill in the solved 15
    batch_size = src.shape[0]
    M = torch.ones(batch_size, 16, device=src.device, dtype=src.dtype)
    M[..., :15] = torch.squeeze(X, dim=-1)

    return M.view(-1, 4, 4)  # Bx4x4
def get_perspective_transform(src, dst):
    r"""Calculate a perspective transform from four pairs of the corresponding points.

    The function calculates the matrix of a perspective transform so that:

    .. math ::
        \begin{bmatrix}
        t_{i}x_{i}^{'} \\
        t_{i}y_{i}^{'} \\
        t_{i} \\
        \end{bmatrix}
        =
        \textbf{map_matrix} \cdot
        \begin{bmatrix}
        x_{i} \\
        y_{i} \\
        1 \\
        \end{bmatrix}

    where

    .. math ::
        dst(i) = (x_{i}^{'},y_{i}^{'}), src(i) = (x_{i}, y_{i}), i = 0,1,2,3

    Args:
        src: coordinates of quadrangle vertices in the source image with shape :math:`(B, 4, 2)`.
        dst: coordinates of the corresponding quadrangle vertices in the destination
            image with shape :math:`(B, 4, 2)`.

    Returns:
        the perspective transformation with shape :math:`(B, 3, 3)`.

    .. note::
        This function is often used in conjunction with :func:`warp_perspective`.
    """
    if not isinstance(src, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(src)}")

    if not isinstance(dst, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(dst)}")

    if not src.shape[-2:] == (4, 2):
        raise ValueError(f"Inputs must be a Bx4x2 tensor. Got {src.shape}")

    if not src.shape == dst.shape:
        raise ValueError(f"Inputs must have the same shape. Got {dst.shape}")

    if not (src.shape[0] == dst.shape[0]):
        raise ValueError(
            f"Inputs must have same batch size dimension. Expect {src.shape} but got {dst.shape}"
        )

    # We build matrix A from the 4 point correspondences (x then y row per
    # point). The linear system is solved with the least square method, so
    # here we could even pass more correspondences.
    params = [
        _build_perspective_param(src[:, i], dst[:, i], axis)
        for i in range(4)
        for axis in ('x', 'y')
    ]

    # A is Bx8x8
    A = torch.stack(params, dim=1)

    # b is Bx8x1: destination coordinates interleaved as x0, y0, x1, y1, ...
    b = torch.stack(
        [dst[:, i:i + 1, axis] for i in range(4) for axis in (0, 1)],
        dim=1,
    )

    # solve the system Ax = b
    X, _ = _torch_solve_cast(b, A)

    # the 9th (bottom-right) element is fixed to 1; fill in the solved 8
    batch_size = src.shape[0]
    M = torch.ones(batch_size, 9, device=src.device, dtype=src.dtype)
    M[..., :8] = X.squeeze(-1)

    return M.view(-1, 3, 3)  # Bx3x3
def conv_quad_interp3d(input: torch.Tensor, strict_maxima_bonus: float = 10.0, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute the single iteration of quadratic interpolation of the extremum (max or min).

    Args:
        input: the given heatmap with shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
        strict_maxima_bonus: pixels, which are strict maxima will score (1 + strict_maxima_bonus) * value.
          This is needed for mimic behavior of strict NMS in classic local features
        eps: parameter to control the hessian matrix ill-condition number.

    Returns:
        the location and value per each 3x3x3 window which contains strict extremum, similar to one done is SIFT.
        :math:`(N, C, 3, D_{out}, H_{out}, W_{out})`, :math:`(N, C, D_{out}, H_{out}, W_{out})`,

        where

         .. math::
             D_{out} = \left\lfloor\frac{D_{in}  + 2 \times \text{padding}[0] -
             (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

         .. math::
             H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[1] -
             (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

         .. math::
             W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[2] -
             (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_quad_interp3d(input, 1.0)
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not len(input.shape) == 5:
        raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")

    B, CH, D, H, W = input.shape
    # integer (d, h, w) coordinates of every voxel, laid out as (1, 3, D, H, W)
    grid_global: torch.Tensor = create_meshgrid3d(D, H, W, False, device=input.device).permute(
        0, 4, 1, 2, 3)
    grid_global = grid_global.to(input.dtype)

    # to determine the location we are solving system of linear equations Ax = b, where b is 1st order gradient
    # and A is Hessian matrix
    b: torch.Tensor = kornia.filters.spatial_gradient3d(input, order=1, mode='diff')
    # flatten to one (3, 1) gradient vector per voxel across the whole batch
    b = b.permute(0, 1, 3, 4, 5, 2).reshape(-1, 3, 1)
    A: torch.Tensor = kornia.filters.spatial_gradient3d(input, order=2, mode='diff')
    # flatten to one row of 6 unique second derivatives per voxel
    A = A.permute(0, 1, 3, 4, 5, 2).reshape(-1, 6)
    dxx = A[..., 0]
    dyy = A[..., 1]
    dss = A[..., 2]
    dxy = 0.25 * A[..., 3]  # normalization to match OpenCV implementation
    dys = 0.25 * A[..., 4]  # normalization to match OpenCV implementation
    dxs = 0.25 * A[..., 5]  # normalization to match OpenCV implementation

    # assemble the symmetric 3x3 Hessian per voxel from its 6 unique entries
    Hes = torch.stack([dxx, dxy, dxs,
                       dxy, dyy, dys,
                       dxs, dys, dss], dim=-1).view(-1, 3, 3)

    # The following is needed to avoid singular cases: a tiny random (but
    # shared across the batch) perturbation keeps the Hessian invertible
    Hes += torch.rand(Hes[0].size(), device=Hes.device).abs()[None] * eps

    # boolean mask of strict local maxima in each 3x3x3 window
    nms_mask: torch.Tensor = kornia.feature.nms3d(input, (3, 3, 3), True)
    x_solved: torch.Tensor = torch.zeros_like(b)
    # solve Hes * x = b only at the maxima to save work; elsewhere offset is 0
    x_solved_masked, _ = _torch_solve_cast(b[nms_mask.view(-1)], Hes[nms_mask.view(-1)])
    x_solved.masked_scatter_(nms_mask.view(-1, 1, 1), x_solved_masked)
    dx: torch.Tensor = -x_solved

    # Ignore ones, which are far from window center: sub-voxel offsets larger
    # than 0.7 mean the true extremum lies in a neighboring voxel
    mask1 = dx.abs().max(dim=1, keepdim=True)[0] > 0.7
    dx.masked_fill_(mask1.expand_as(dx), 0)
    # second-order Taylor value change at the interpolated extremum
    dy: torch.Tensor = 0.5 * torch.bmm(b.permute(0, 2, 1), dx)
    y_max = input + dy.view(B, CH, D, H, W)
    if strict_maxima_bonus > 0:
        y_max += strict_maxima_bonus * nms_mask.to(input.dtype)

    # dx rows are ordered (x, y, s); flip to match the (d, h, w) grid layout
    # before adding to the integer coordinates
    dx_res: torch.Tensor = dx.flip(1).reshape(B, CH, D, H, W, 3).permute(0, 1, 5, 2, 3, 4)
    coords_max: torch.Tensor = grid_global.repeat(B, 1, 1, 1, 1).unsqueeze(1)
    coords_max = coords_max + dx_res

    return coords_max, y_max