Example #1
0
def compute_geo_costs(rot, trans, Ex, Kinv, hp0, hp1, tau, Kinv_n=None):
    """Compute per-correspondence geometric costs for a rotation/translation hypothesis.

    Args:
        rot: angle-axis rotation, passed to ``kornia.angle_axis_to_rotation_matrix``.
        trans: translation, indexed as (B, 3) (``trans[:, :, np.newaxis]``).
        Ex: presumably the essential matrix; it is transposed on its last two
            axes and passed to ``sampson_err``.
        Kinv: presumably the inverse intrinsics of view 0 (batched). Its
            inverse is used as the forward intrinsics below.
        hp0, hp1: homogeneous pixel coordinates. They are permuted with
            ``(0, 2, 1)`` before use, so presumably they are given as (B, N, 3).
            TODO: confirm.
        tau: scale applied to ``hp0`` when forming the 3D parallax.
        Kinv_n: inverse intrinsics of view 1. Defaults to ``Kinv``.

    Returns:
        ``(mcost00, mcost01, mcost1, mcost2, mcost3, mcost4, p3dmag, mcost10)``:
        - forward and backward rotation-homography transfer errors;
        - Sampson error;
        - 2D angular cost;
        - 3D distance cost;
        - 3D angular cost;
        - 3D parallax magnitude;
        - 2D distance cost.
    """
    if Kinv_n is None:
        Kinv_n = Kinv
    # Hoisted: this inverse was previously computed twice.
    K0 = Kinv.inverse()
    hp0_t = hp0.permute(0, 2, 1)
    hp1_t = hp1.permute(0, 2, 1)

    # Rotation-only homography from view 1 pixels into view 0 pixels.
    R01 = kornia.angle_axis_to_rotation_matrix(rot)
    H01 = K0.matmul(R01).matmul(Kinv_n)
    comp_hp1 = H01.matmul(hp1_t)
    foe = (comp_hp1 - tau * hp0_t)
    parallax3d = Kinv.matmul(foe)
    p3dmag = parallax3d.norm(2, 1)[:, np.newaxis]
    parallax2d = (comp_hp1 / comp_hp1[:, -1:] - hp0_t)[:, :2]
    p2dmag = parallax2d.norm(2, 1)[:, np.newaxis]
    p2dnorm = parallax2d / (1e-9 + p2dmag)

    # Focus of expansion: the translation projected into the image.
    foe_cam = K0.matmul(trans[:, :, np.newaxis])
    foe_cam = foe_cam[:, :2] / (1e-9 + foe_cam[:, -1:])
    direct = foe_cam - hp0_t[:, :2]
    directn = direct / (1e-9 + direct.norm(2, 1)[:, np.newaxis])

    # Cost metrics:
    #   0) R-homography + symmetric transfer error
    #   1) Sampson error
    #   2) 2D angular (P+P)
    #   3) 3D distance
    #   4) 3D angular (P+P)
    # TODO: validate
    comp_hp0 = H01.inverse().matmul(hp0_t)
    mcost00 = parallax2d.norm(2, 1)
    mcost01 = (comp_hp0 / comp_hp0[:, -1:] - hp1_t)[:, :2].norm(2, 1)
    # Move Ex onto the inputs' device rather than hard-coding .cuda(), which
    # failed on CPU and picked the wrong GPU on multi-device setups.
    mcost1 = sampson_err(Kinv.matmul(hp0_t),
                         Kinv_n.matmul(hp1_t),
                         Ex.to(hp0.device).permute(0, 2, 1))  # variable K
    mcost2 = -(trans[:, -1:, np.newaxis]).sign() * (directn * p2dnorm).sum(
        1, keepdims=True)
    mcost4 = -(trans[:, :, np.newaxis] * parallax3d).sum(1, keepdims=True) / (
        p3dmag + 1e-9)
    # sin(angle) * magnitude, signed by the angular term.
    mcost3 = torch.clamp(1 - mcost4.pow(2), 0,
                         1).sqrt() * p3dmag * mcost4.sign()
    mcost10 = torch.clamp(1 - mcost2.pow(2), 0,
                          1).sqrt() * p2dmag * mcost2.sign()
    return mcost00, mcost01, mcost1, mcost2, mcost3, mcost4, p3dmag, mcost10
Example #2
0
def generate_scene(num_views: int, num_points: int) -> Dict[str, torch.Tensor]:
    """Build a random synthetic scene observed by ``num_views`` cameras.

    Args:
        num_views: number of cameras N.
        num_points: number of 3D points M.

    Returns:
        Dict with the following keys:
        - ``K``: shared intrinsics from ``epipolar.random_intrinsics``.
        - ``R``: Nx3x3 rotations.
        - ``t``: Nx3x1 translations.
        - ``P``: projection matrices from ``epipolar.projection_from_KRt``.
        - ``points3d``: 1xMx3 points.
        - ``points2d``: the points transformed by each ``P``.
    """
    # One shared point cloud, uniform in the unit cube: 1xMx3.
    points3d = torch.rand(1, num_points, 3)

    # Single random camera matrix shared by every view.
    K = epipolar.random_intrinsics(0.0, 100.0)

    # Per-view rotation: a random direction scaled by an angle in [0, 2*pi).
    angles = torch.rand(num_views, 1) * kornia.pi * 2.0
    axes = torch.rand(num_views, 3)
    rvec = angles * axes / torch.norm(axes, dim=1, keepdim=True)  # Nx3
    rot_mat = kornia.angle_axis_to_rotation_matrix(rvec)  # Nx3x3

    # Per-view translation. The generator is consumed in x, y, z order.
    tx, ty, tz = (
        torch.empty(num_views).uniform_(low, high)
        for low, high in ((-0.5, 0.5), (-0.5, 0.5), (-1.0, 2.0))
    )
    tvec = torch.stack((tx, ty, tz), dim=1).unsqueeze(-1)  # Nx3x1

    # Push each camera back along z until its nearest point has depth >= 1.
    depths = ((rot_mat @ points3d.transpose(-2, -1)) + tvec)[:, 2]
    min_dist = depths.min(dim=1).values
    tvec[:, 2, 0] = torch.where(min_dist < 0, tz - min_dist + 1.0, tz)

    P = epipolar.projection_from_KRt(K, rot_mat, tvec)
    points2d = kornia.transform_points(P, points3d.expand(num_views, -1, -1))

    return {
        "K": K,
        "R": rot_mat,
        "t": tvec,
        "P": P,
        "points3d": points3d,
        "points2d": points2d,
    }
Example #3
0
def get_projective_transform(center: torch.Tensor, angles: torch.Tensor,
                             scales: torch.Tensor) -> torch.Tensor:
    r"""Calculate the projection matrix for a scaled 3D rotation about a center.

    .. warning::
        This API signature is experimental and might change in the future.

    The function computes the projection matrix from the center, the per-axis
    angles and the per-axis scales. The scaling is applied before the rotation
    (``R @ diag(scales)``). The combined transform is then conjugated by a
    translation so that it acts about ``center``.

    Args:
        center: center of the rotation (x,y,z) in the source with shape :math:`(B, 3)`.
        angles: angle axis vector containing the rotation angles in degrees in the form
            of (rx, ry, rz) with shape :math:`(B, 3)`. Internally it calls Rodrigues to compute
            the rotation matrix from axis-angle.
        scales: scale factor for x-y-z-directions with shape :math:`(B, 3)`.

    Returns:
        the projection matrix of 3D rotation with shape :math:`(B, 3, 4)`.

    Raises:
        AssertionError: if ``center`` or ``angles`` is not :math:`(B, 3)`, or if
            they differ in device or dtype.

    .. note::
        This function is often used in conjunction with :func:`warp_affine3d`.
    """
    if not (len(center.shape) == 2 and center.shape[-1] == 3):
        raise AssertionError(center.shape)
    if not (len(angles.shape) == 2 and angles.shape[-1] == 3):
        raise AssertionError(angles.shape)
    if center.device != angles.device:
        raise AssertionError(center.device, angles.device)
    if center.dtype != angles.dtype:
        raise AssertionError(center.dtype, angles.dtype)

    # create rotation matrix
    angle_axis_rad: torch.Tensor = K.deg2rad(angles)
    rmat: torch.Tensor = K.angle_axis_to_rotation_matrix(
        angle_axis_rad)  # Bx3x3
    # Fold per-axis scaling into the linear part: R @ diag(scales).
    scaling_matrix: torch.Tensor = K.eye_like(3, rmat)
    scaling_matrix = scaling_matrix * scales.unsqueeze(dim=1)
    rmat = rmat @ scaling_matrix.to(rmat)

    # define matrix to move forth and back to origin
    from_origin_mat = torch.eye(4)[None].repeat(rmat.shape[0], 1,
                                                1).type_as(center)  # Bx4x4
    from_origin_mat[..., :3, -1] += center
    to_origin_mat = _torch_inverse_cast(from_origin_mat)

    # append translation with zeros
    proj_mat = projection_from_Rt(rmat,
                                  torch.zeros_like(center)[..., None])  # Bx3x4

    # chain 4x4 transforms
    proj_mat = convert_affinematrix_to_homography3d(proj_mat)  # Bx4x4
    proj_mat = from_origin_mat @ proj_mat @ to_origin_mat

    return proj_mat[..., :3, :]  # Bx3x4
Example #4
0
 def aa2matrot(pose):
     """Convert per-joint angle-axis vectors to flattened rotation matrices.

     :param pose: angle-axis tensor documented as Nx1xnum_jointsx3. It is
         flattened to (-1, 3), so every trailing triple is one rotation.
     :return: pose_matrot of shape Nx1xnum_jointsx9. Each 3x3 rotation is
         flattened row-major, and the joint dimension is inferred from the size.
     """
     batch_size = pose.size(0)
     rotations = kornia.angle_axis_to_rotation_matrix(pose.reshape(-1, 3))
     # Keep only the top-left 3x3 block, in case a larger (e.g. homogeneous
     # 4x4) matrix is returned. Make it contiguous so .view() is valid.
     rotations = rotations[:, :3, :3].contiguous()
     return rotations.view(batch_size, 1, -1, 9)
Example #5
0
def getDiffMap_tensor(hyp, objectCoordinates, sampling, camMat):
    """Per-pixel reprojection error of object coordinates under a pose hypothesis.

    Args:
        hyp: pose hypothesis. ``hyp[0, :3]`` is an angle-axis rotation and
            ``hyp[0, 3:6]`` is the translation.
        objectCoordinates: object-space points with xyz on the last axis,
            indexed as ``[:, :, k]``. Presumably (H, W, 3). TODO: confirm.
        sampling: pixel coordinates with x/y on the last axis and the same
            leading dimensions.
        camMat: pinhole intrinsics. Reads a single focal length
            ``camMat[0, 0]`` (used for both axes) and the principal point
            ``(camMat[0, 2], camMat[1, 2])``.

    Returns:
        Euclidean pixel distance between the projected points and ``sampling``,
        clamped to [0, 100], with a leading singleton dimension.
    """
    rot_mrx = kornia.angle_axis_to_rotation_matrix(hyp[0, :3].view(1, -1)).view(
        [3, 3]).float()
    f = camMat[0, 0]
    ppx = camMat[0, 2]
    ppy = camMat[1, 2]

    # Object -> camera frame: X_cam = R @ X_obj + t, written for row vectors.
    obj_cam_tensor = torch.matmul(objectCoordinates,
                                  torch.transpose(rot_mrx, 0, 1))
    # In-place adds keep obj_cam_tensor's dtype regardless of hyp's dtype.
    obj_cam_tensor[:, :, 0] += hyp[0, 3]
    obj_cam_tensor[:, :, 1] += hyp[0, 4]
    obj_cam_tensor[:, :, 2] += hyp[0, 5]

    # Pinhole projection. No guard against zero depth, matching prior behavior.
    depth = obj_cam_tensor[:, :, 2]
    project_x = f * obj_cam_tensor[:, :, 0] / depth + ppx
    project_y = f * obj_cam_tensor[:, :, 1] / depth + ppy
    reproject_x = project_x - sampling[:, :, 0]
    reproject_y = project_y - sampling[:, :, 1]
    diff = torch.sqrt(reproject_x**2 + reproject_y**2)
    # Cap large errors at 100 px.
    diffMap = torch.clamp(diff, 0., 100.)
    return torch.unsqueeze(diffMap, 0)
Example #6
0
def get_skew_mat(transx, rotx):
    """Return ``R^T @ [t]_x`` for batched angle-axis rotations and translations.

    Args:
        transx: translations, indexed as (B, 3).
        rotx: angle-axis rotations, passed to ``kornia.angle_axis_to_rotation_matrix``.

    Returns:
        (B, 3, 3) tensor: the transposed rotation times the cross-product
        (skew-symmetric) matrix of ``transx``.
    """
    rot_t = kornia.angle_axis_to_rotation_matrix(rotx).permute(0, 2, 1)
    # Allocate on transx's device and dtype. A default CPU/float32 buffer made
    # the matmul below fail for GPU or float64 inputs.
    tx = torch.zeros(transx.shape[0], 3, 3,
                     dtype=transx.dtype, device=transx.device)
    tx[:, 0, 1] = -transx[:, 2]
    tx[:, 0, 2] = transx[:, 1]
    tx[:, 1, 0] = transx[:, 2]
    tx[:, 1, 2] = -transx[:, 0]
    tx[:, 2, 0] = -transx[:, 1]
    tx[:, 2, 1] = transx[:, 0]
    return rot_t.matmul(tx)
Example #7
0
def get_res(pts2d, pts3d, K, P):
    """Total reprojection error for a 6-DoF pose, plus a feasibility flag.

    Args:
        pts2d: observed 2D points, (n, 2).
        pts3d: 3D points, (n, 3).
        K: 3x3 intrinsics.
        P: pose with ``P[0, 0:3]`` an angle-axis rotation and ``P[0, 3:6]``
            the translation.

    Returns:
        ``(error, feasible)``:
        - ``error``: the sum of per-point L2 reprojection errors, as a Python float.
        - ``feasible``: True when the translation z is > 0 and every
          transformed point has z >= 0.
    """
    n = pts2d.size(0)
    pose = P[0]
    tz_positive = pose[5].item() > 0

    rot = kn.angle_axis_to_rotation_matrix(pose[0:3].view(1, 3))
    Rt = torch.cat((rot[0, 0:3, 0:3].view(3, 3), pose[3:6].view(3, 1)), dim=-1)

    homog = torch.cat((pts3d, torch.ones(n, 1, device=pts3d.device)), dim=-1)
    cam = homog.mm(Rt.transpose(0, 1))
    # Both checks are evaluated eagerly (no short-circuit), as before.
    depth_ok = (cam[:, 2].min().item() >= 0)

    proj = cam.mm(K.transpose(0, 1))
    depth = proj[:, 2].view(n, 1)
    residual = pts2d - proj[:, 0:2].div(depth)
    return torch.norm(residual, dim=1).sum().item(), tz_positive and depth_ok
Example #8
0
    def test_triplet_amq(self, axis, device, dtype, atol, rtol):
        # Round trip: angle-axis -> rotation matrix -> quaternion (WXYZ) ->
        # angle-axis, for a quarter turn about one coordinate axis.
        values = [0.0, 0.0, 0.0]
        values[axis] = kornia.pi / 2.0
        angle_axis = torch.tensor([values], device=device, dtype=dtype)
        assert angle_axis.shape[-1] == 3

        rot_m = kornia.angle_axis_to_rotation_matrix(angle_axis)
        assert rot_m.shape[-2:] == (3, 3)

        wxyz = QuaternionCoeffOrder.WXYZ
        quaternion = kornia.rotation_matrix_to_quaternion(rot_m, order=wxyz)
        assert quaternion.shape[-1] == 4

        recovered = kornia.quaternion_to_angle_axis(quaternion, order=wxyz)
        assert_close(recovered, angle_axis, atol=atol, rtol=rtol)
Example #9
0
def test_angle_axis_to_rotation_matrix(batch_size, device, dtype):
    """Matrices built from random angle-axis vectors are orthonormal and pass gradcheck."""
    angle_axis = torch.rand(batch_size, 3, device=device, dtype=dtype)
    expected_identity = create_eye_batch(batch_size, 3, device=device, dtype=dtype)

    rotation_matrix = kornia.angle_axis_to_rotation_matrix(angle_axis)

    # For a proper rotation, R @ R^T is the identity.
    product = rotation_matrix @ rotation_matrix.transpose(1, 2)
    assert_allclose(product, expected_identity, atol=1e-4, rtol=1e-4)

    # Numerical vs analytical gradient check.
    grad_input = tensor_to_gradcheck_var(angle_axis)
    assert gradcheck(kornia.angle_axis_to_rotation_matrix, (grad_input, ),
                     raise_exception=True)
Example #10
0
def batch_project(P, pts3d, K, angle_axis=True):
    """Project one set of 3D points through a batch of camera poses.

    Args:
        P: the camera poses, interpreted according to ``angle_axis``:
            - True: (bs, 6) poses, with an angle-axis rotation in ``P[:, 0:3]``
              and the translation in ``P[:, 3:6]``.
            - False: ready-made (bs, 3, 4) ``[R | t]`` matrices.
        pts3d: (n, 3) points.
        K: (3, 3) intrinsics.
        angle_axis: selects how ``P`` is interpreted.

    Returns:
        (bs, n, 2) pixel coordinates after perspective division.
    """
    bs, n = P.size(0), pts3d.size(0)
    homog = torch.cat((pts3d, torch.ones(n, 1, device=P.device)), dim=-1)

    if not angle_axis:
        PM = P
    else:
        rot = kn.angle_axis_to_rotation_matrix(P[:, 0:3].view(bs, 3))
        PM = torch.cat((rot[:, 0:3, 0:3], P[:, 3:6].view(bs, 3, 1)), dim=-1)

    cam = homog.matmul(PM.transpose(1, 2))
    pix = cam.matmul(K.t())
    depth = pix[:, :, 2].view(bs, n, 1)
    return pix[:, :, 0:2].div(depth)
Example #11
0
def test_angle_axis_to_rotation_matrix(batch_size, device_type):
    """R @ R^T equals the identity reference, and the conversion passes gradcheck.

    NOTE(review): the identity reference here is 4x4. This presumes a kornia
    version whose ``angle_axis_to_rotation_matrix`` returns 4x4 matrices.
    TODO: confirm.
    """
    device = torch.device(device_type)
    angle_axis = torch.rand(batch_size, 3).to(device)
    expected_identity = utils.create_eye_batch(batch_size, 4).to(device)

    rotation_matrix = kornia.angle_axis_to_rotation_matrix(angle_axis)

    product = rotation_matrix @ rotation_matrix.transpose(1, 2)
    assert check_equal_torch(product, expected_identity)

    # Numerical vs analytical gradient check.
    grad_input = utils.tensor_to_gradcheck_var(angle_axis)
    assert gradcheck(kornia.angle_axis_to_rotation_matrix, (grad_input,),
                     raise_exception=True)
Example #12
0
    def test_triplet_qam_xyzw(self, axis, device, dtype, atol, rtol):
        # Round trip: quaternion (XYZW) -> angle-axis -> rotation matrix ->
        # quaternion, starting from a quaternion with a single unit component.
        # Each XYZW conversion is expected to emit a UserWarning.
        coeffs = [0.0, 0.0, 0.0, 0.0]
        coeffs[axis] = 1.0
        quaternion = torch.tensor([coeffs], device=device, dtype=dtype)
        assert quaternion.shape[-1] == 4

        xyzw = QuaternionCoeffOrder.XYZW
        with pytest.warns(UserWarning):
            angle_axis = kornia.quaternion_to_angle_axis(quaternion, order=xyzw)
        assert angle_axis.shape[-1] == 3

        rot_m = kornia.angle_axis_to_rotation_matrix(angle_axis)
        assert rot_m.shape[-2:] == (3, 3)

        with pytest.warns(UserWarning):
            quaternion_hat = kornia.rotation_matrix_to_quaternion(rot_m, order=xyzw)
        assert_close(quaternion_hat, quaternion, atol=atol, rtol=rtol)
Example #13
0
def get_projective_transform(center: torch.Tensor,
                             angles: torch.Tensor) -> torch.Tensor:
    r"""Calculate the projection matrix for a 3D rotation about a given center.

    The function computes the projection matrix given the center and angles per axis.
    The rotation is conjugated by a translation so that it acts about ``center``.

    Args:
        center (torch.Tensor): center of the rotation in the source with shape :math:`(B, 3)`.
        angles (torch.Tensor): angle axis vector containing the rotation angles in degrees in the form
            of (rx, ry, rz) with shape :math:`(B, 3)`. Internally it calls Rodrigues to compute
            the rotation matrix from axis-angle.

    Returns:
        torch.Tensor: the projection matrix of 3D rotation with shape :math:`(B, 3, 4)`.

    Raises:
        AssertionError: if ``center`` or ``angles`` is not :math:`(B, 3)`, or if
            they differ in device or dtype.
    """
    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    # AssertionError and the messages are kept for existing callers.
    if not (len(center.shape) == 2 and center.shape[-1] == 3):
        raise AssertionError(center.shape)
    if not (len(angles.shape) == 2 and angles.shape[-1] == 3):
        raise AssertionError(angles.shape)
    if center.device != angles.device:
        raise AssertionError((center.device, angles.device))
    if center.dtype != angles.dtype:
        raise AssertionError((center.dtype, angles.dtype))

    # create rotation matrix
    angle_axis_rad: torch.Tensor = K.deg2rad(angles)
    rmat: torch.Tensor = K.angle_axis_to_rotation_matrix(
        angle_axis_rad)  # Bx3x3

    # define matrix to move forth and back to origin
    from_origin_mat = torch.eye(4)[None].repeat(rmat.shape[0], 1,
                                                1).type_as(center)  # Bx4x4
    from_origin_mat[..., :3, -1] += center
    to_origin_mat = from_origin_mat.inverse()

    # append translation with zeros
    proj_mat = projection_from_Rt(rmat,
                                  torch.zeros_like(center)[..., None])  # Bx3x4

    # chain 4x4 transforms
    proj_mat = matrix_to_homogeneous(proj_mat)  # Bx4x4
    proj_mat = (from_origin_mat @ proj_mat @ to_origin_mat)

    return proj_mat[..., :3, :]  # Bx3x4
    def optimize_nodes(
        self, match_count: int, optimized_node_count: int,
        batch_edge_count: int, graph_nodes_i: torch.Tensor,
        source_anchors: torch.Tensor, source_weights: torch.Tensor,
        source_points_filtered: torch.Tensor,
        source_colors_filtered: torch.Tensor,
        correspondence_weights_filtered: torch.Tensor,
        xy_pixels_warped_filtered: torch.Tensor,
        target_matches_filtered: torch.Tensor,
        graph_edge_pairs_filtered: torch.Tensor,
        graph_edge_weights_pairs: torch.Tensor, num_neighbors: int,
        fx: torch.Tensor, fy: torch.Tensor, cx: torch.Tensor, cy: torch.Tensor,
        batch_convergence_info
    ) -> Tuple[bool, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Gauss-Newton optimization of per-node rotations and translations.

        Runs up to ``self.gn_num_iter`` iterations. Each iteration:
        1. Stacks the data residual/Jacobian with the ARAP residual/Jacobian
           (the ARAP part only when ``batch_edge_count > 0``).
        2. Solves for an increment with ``self.solve_linear_system``, stopping
           early if the solver reports failure.
        3. Applies the increment: the rotation part is an angle-axis update that
           is left-multiplied, and the translation part is added.

        Args:
            match_count: number of data correspondences.
            optimized_node_count: number of nodes being optimized (N).
            batch_edge_count: number of ARAP edges. 0 disables the ARAP term.
            batch_convergence_info: dict of lists. Per-iteration losses are
                appended under ``"data"`` and ``"total"``, and under ``"arap"``
                when edges exist. It is also passed to the linear solver.
            Remaining arguments are forwarded unchanged to the residual and
            Jacobian helpers.

        Returns:
            ``(ill_posed_system, residuals, rotations_current, translations_current)``:
            - whether the last linear solve failed;
            - the last stacked residual vector (``None`` if no iteration ran);
            - the (N, 3, 3) rotations;
            - the (N, 3, 1) translations.
        """
        # NOTE(review): if vec_to_skew_mat is an nn.Module this moves it in
        # place; if it is a plain tensor the result is discarded. Confirm.
        self.vec_to_skew_mat.to(source_anchors.device)

        float_dtype = source_weights.dtype
        device = source_weights.device
        gauss_newton_iteration_count = self.gn_num_iter

        def strided_indices(offset: int, count: int) -> torch.Tensor:
            # Returns [offset, offset + 3, ..., offset + 3 * (count - 1)] as int64 on `device`.
            # It was previously built with out=torch.cuda.LongTensor(), which
            # fails on CPU-only builds and conflicts with a non-CUDA `device`.
            return torch.arange(offset, count * 3, 3, dtype=torch.long, device=device)

        # The parameters in GN solver are 3 parameters for rotation and 3 parameters for
        # translation for every node. All node rotation parameters are listed first, and
        # then all node translation parameters are listed.
        #                        transform_delta = [rotations_current, translations_current]
        rotations_current = torch.eye(3, dtype=float_dtype,
                                      device=device).view(1, 3, 3).repeat(
                                          optimized_node_count, 1, 1)
        translations_current = torch.zeros((optimized_node_count, 3, 1),
                                           dtype=float_dtype,
                                           device=device)

        if self.gn_debug:
            print(
                f"\tMatch count: {match_count} || Node count: {optimized_node_count} || Edges count: {batch_edge_count}"
            )

        # Initialize helper structures: per-correspondence row indices for the
        # x/y/z rows of each 3-row residual group.
        data_increment_vec_0_3 = strided_indices(0, match_count)  # (match_count)
        data_increment_vec_1_3 = strided_indices(1, match_count)  # (match_count)
        data_increment_vec_2_3 = strided_indices(2, match_count)  # (match_count)

        arap_increment_vec_0_3 = None
        arap_increment_vec_1_3 = None
        arap_increment_vec_2_3 = None
        arap_one_vec = None
        if batch_edge_count > 0:
            arap_increment_vec_0_3 = strided_indices(0, batch_edge_count)  # (batch_edge_count)
            arap_increment_vec_1_3 = strided_indices(1, batch_edge_count)  # (batch_edge_count)
            arap_increment_vec_2_3 = strided_indices(2, batch_edge_count)  # (batch_edge_count)
            arap_one_vec = torch.ones(batch_edge_count,
                                      dtype=float_dtype,
                                      device=device)

        ill_posed_system = False
        residuals = None

        for i_iteration in range(gauss_newton_iteration_count):
            residual_data, jacobian_data = \
                self.compute_data_residual_and_jacobian(data_increment_vec_0_3, data_increment_vec_1_3, data_increment_vec_2_3,
                                                        match_count, optimized_node_count, graph_nodes_i,
                                                        source_anchors, source_weights, source_points_filtered, source_colors_filtered,
                                                        correspondence_weights_filtered, xy_pixels_warped_filtered, target_matches_filtered,
                                                        i_iteration, fx, fy, cx, cy, rotations_current, translations_current)
            loss_arap = None
            if batch_edge_count > 0:
                residual_arap, jacobian_arap = \
                    self.compute_arap_residual_and_jacobian(arap_increment_vec_0_3, arap_increment_vec_1_3, arap_increment_vec_2_3, arap_one_vec,
                                                            graph_edge_pairs_filtered, graph_edge_weights_pairs,
                                                            graph_nodes_i, batch_edge_count, optimized_node_count, num_neighbors,
                                                            rotations_current, translations_current)
                loss_arap = torch.norm(residual_arap).item()
                residuals = torch.cat((residual_data, residual_arap), 0)
                jacobian = torch.cat((jacobian_data, jacobian_arap), 0)
            else:
                residuals = residual_data
                jacobian = jacobian_data

            success, transform_delta = self.solve_linear_system(
                residuals, jacobian, batch_convergence_info)
            ill_posed_system = not success
            if ill_posed_system:
                break

            # Increment the current rotation and translation.
            rotation_increments = kornia.angle_axis_to_rotation_matrix(
                transform_delta[:optimized_node_count * 3].view(
                    optimized_node_count, 3))
            translation_increments = transform_delta[optimized_node_count *
                                                     3:].view(
                                                         optimized_node_count,
                                                         3, 1)

            rotations_current = torch.matmul(rotation_increments,
                                             rotations_current)
            translations_current = translations_current + translation_increments

            loss_data = torch.norm(residual_data).item()
            loss_total = torch.norm(residuals).item()

            batch_convergence_info["data"].append(loss_data)
            batch_convergence_info["total"].append(loss_total)

            if batch_edge_count > 0:
                batch_convergence_info["arap"].append(loss_arap)

            if self.gn_debug:
                if batch_edge_count > 0:
                    print(
                        f"\t\t-->Iteration: {i_iteration}. "
                        f"Loss: \tdata = {loss_data:.3f}, \tarap = {loss_arap:.3f}, \ttotal = {loss_total:.3f}"
                    )
                else:
                    print(
                        f"\t\t-->Iteration: {i_iteration}. Loss: \tdata = {loss_data:.3f}, \ttotal = {loss_total:.3f}"
                    )
        return ill_posed_system, residuals, rotations_current, translations_current
Example #15
0
 def angle_axis_to_rot_matrix(euler):
     """Convert angle-axis vectors to rotation matrices via kornia.

     Despite its name, ``euler`` is forwarded as kornia's ``angle_axis``
     argument. It is therefore interpreted as an angle-axis (Rodrigues)
     vector, not as Euler angles.
     """
     angle_axis = euler
     return kornia.angle_axis_to_rotation_matrix(angle_axis=angle_axis)
Example #16
0
 def test_smoke(self, device):
     """A single unbatched zero angle-axis vector maps to one 3x3 matrix."""
     # Allocate on the `device` fixture. Previously the fixture was ignored, so
     # the test always ran on torch's default device.
     angle_axis = torch.zeros(3, device=device)
     rotation_matrix = kornia.angle_axis_to_rotation_matrix(angle_axis)
     assert rotation_matrix.shape == (3, 3)
Example #17
0
    def backward(ctx, grad_output):
        """Backward pass via the implicit function theorem.

        For each batch element, this uses autograd to build the Jacobians of
        the m = 6 conditions f_j. Each f_j is a ``get_coefs``-weighted sum of
        algebraic reprojection residuals ``pts2d * depth - proj``. The Jacobians
        are taken with respect to the pose ``P_6d`` (y), the 2D points (x), the
        3D points (z) and the intrinsics ``K``. They are then combined as
        dy/d* = -(df/dy)^{-1} df/d* and chained with ``grad_output``.

        Returns:
            ``(grad_x, grad_z, grad_K, None)``: gradients for the saved
            ``pts2d``, ``pts3d`` and ``K``, plus ``None``. How these map onto
            forward's argument order is not visible here.
        """
        pts2d, P_6d, pts3d, K = ctx.saved_tensors
        device = pts2d.device
        # Jacobian buffers share grad_output's dtype so the torch.mm calls
        # below do not mix dtypes. float32 buffers broke float64 backward.
        dtype = grad_output.dtype
        bs = pts2d.size(0)
        n = pts2d.size(1)
        m = 6

        grad_x = torch.zeros_like(pts2d)
        grad_z = torch.zeros_like(pts3d)
        grad_K = torch.zeros_like(K)

        for i in range(bs):
            J_fy = torch.zeros(m, m, device=device, dtype=dtype)
            J_fx = torch.zeros(m, 2 * n, device=device, dtype=dtype)
            J_fz = torch.zeros(m, 3 * n, device=device, dtype=dtype)
            J_fK = torch.zeros(m, 9, device=device, dtype=dtype)

            coefs = get_coefs(P_6d[i].view(1, 6), pts3d, K)

            # Fresh leaf copies so each f_j can be differentiated independently.
            pts2d_flat = pts2d[i].clone().view(-1).detach().requires_grad_()
            P_6d_flat = P_6d[i].clone().view(-1).detach().requires_grad_()
            pts3d_flat = pts3d.clone().view(-1).detach().requires_grad_()
            K_flat = K.clone().view(-1).detach().requires_grad_()

            # Autograd typically runs backward() with grad mode disabled, so
            # enable it locally. The previous torch.set_grad_enabled(True) was
            # never undone.
            with torch.enable_grad():
                for j in range(m):
                    if j > 0:
                        # .grad accumulates across backward() calls; reset it.
                        pts2d_flat.grad.zero_()
                        P_6d_flat.grad.zero_()
                        pts3d_flat.grad.zero_()
                        K_flat.grad.zero_()

                    R = kn.angle_axis_to_rotation_matrix(P_6d_flat[0:m - 3].view(1, 3))

                    P = torch.cat((R[0, 0:3, 0:3].view(3, 3), P_6d_flat[m - 3:m].view(3, 1)), dim=-1)
                    KP = torch.mm(K_flat.view(3, 3), P)
                    pts2d_i = pts2d_flat.view(n, 2).transpose(0, 1)
                    pts3d_i = torch.cat((pts3d_flat.view(n, 3), pts3d_flat.new_ones(n, 1)), dim=-1).t()
                    proj_i = KP.mm(pts3d_i)
                    Si = proj_i[2, :].view(1, n)

                    r = pts2d_i * Si - proj_i[0:2, :]
                    coef = coefs[:, :, j].transpose(0, 1)  # size: [2,n]
                    fj = (coef * r).sum()
                    fj.backward()
                    J_fy[j, :] = P_6d_flat.grad.clone()
                    J_fx[j, :] = pts2d_flat.grad.clone()
                    J_fz[j, :] = pts3d_flat.grad.clone()
                    J_fK[j, :] = K_flat.grad.clone()

            inv_J_fy = torch.inverse(J_fy)

            # Implicit function theorem: dy/d* = -(df/dy)^{-1} df/d*.
            J_yx = (-1) * torch.mm(inv_J_fy, J_fx)
            J_yz = (-1) * torch.mm(inv_J_fy, J_fz)
            J_yK = (-1) * torch.mm(inv_J_fy, J_fK)

            grad_x[i] = grad_output[i].view(1, m).mm(J_yx).view(n, 2)
            grad_z += grad_output[i].view(1, m).mm(J_yz).view(n, 3)
            grad_K += grad_output[i].view(1, m).mm(J_yK).view(3, 3)

        return grad_x, grad_z, grad_K, None
Example #18
0
def P6d2PM(P6d):
    """Convert (bs, 6) poses to (bs, 3, 4) ``[R | t]`` matrices.

    ``P6d[:, 0:3]`` is an angle-axis rotation and ``P6d[:, 3:6]`` is the
    translation.
    """
    bs = P6d.size(0)
    rot = kn.angle_axis_to_rotation_matrix(P6d[:, 0:3].view(bs, 3))
    rot = rot[:, 0:3, 0:3].view(bs, 3, 3)
    trans = P6d[:, 3:6].view(bs, 3, 1)
    return torch.cat((rot, trans), dim=-1)