def relative_pose_error(kp1,
                        kp2,
                        kp1_desc,
                        kp2_desc,
                        shift_scale1,
                        shift_scale2,
                        intrinsics1,
                        intrinsics2,
                        extrinsics1,
                        extrinsics2,
                        px_thresh,
                        detailed=False):
    """
    Estimate the relative pose from mutual descriptor matches and measure its
    angular error against the ground-truth relative pose, once per threshold.

    :param kp1: B x N x 2
    :param kp2: B x N x 2
    :param kp1_desc: B x N x C
    :param kp2_desc: B x N x C
    :param shift_scale1: B x 4; used to revert the data transform of kp1
    :param shift_scale2: B x 4; used to revert the data transform of kp2
    :param intrinsics1: B x 3 x 3
    :param intrinsics2: B x 3 x 3
    :param extrinsics1: B x 4 x 4
    :param extrinsics2: B x 4 x 4
    :param px_thresh: list; one pose estimate is made per entry, each passed to
        ``prepare_rel_pose`` as its threshold
    :param detailed: bool; if True, also return the mutual-match mask and the
        nearest-neighbour descriptor ids
    :return: ``(R_err, t_err, est_inl_mask)`` where ``R_err`` / ``t_err`` are
        ``len(px_thresh) x B`` tensors computed by ``angle_mat`` / ``angle_vec``
        between estimated and ground-truth rotation / translation, and
        ``est_inl_mask`` is a ``len(px_thresh) x B x N`` bool tensor holding the
        inlier mask returned by ``prepare_rel_pose`` for each threshold.
        With ``detailed=True`` the tuple is extended with
        ``mutual_desc_matches_mask`` and ``nn_desc_ids``.
    """
    mutual_desc_matches_mask, nn_desc_ids = get_mutual_desc_matches(
        kp1_desc, kp2_desc, None, 0.9)

    # Bring keypoints back to the original image coordinates
    r_kp1 = revert_data_transform(kp1, shift_scale1)
    r_kp2 = revert_data_transform(kp2, shift_scale2)

    # For every keypoint of image 1, its nearest neighbour (by descriptor) in image 2
    nn_r_kp2 = select_kp(r_kp2, nn_desc_ids)

    num_thresh = len(px_thresh)
    b, n = kp1.shape[:2]

    gt_rel_pose = get_gt_rel_pose(extrinsics1, extrinsics2)

    R_err = torch.zeros(num_thresh, b)
    t_err = torch.zeros(num_thresh, b)

    est_inl_mask = torch.zeros(num_thresh, b, n,
                               dtype=torch.bool).to(kp1.device)

    for i, thresh in enumerate(px_thresh):
        i_est_rel_pose, i_est_inl_mask = prepare_rel_pose(
            r_kp1, nn_r_kp2, mutual_desc_matches_mask, intrinsics1,
            intrinsics2, thresh)

        R_err[i] = angle_mat(i_est_rel_pose[:, :3, :3], gt_rel_pose[:, :3, :3])
        t_err[i] = angle_vec(i_est_rel_pose[:, :3, 3], gt_rel_pose[:, :3, 3])

        # Previously this mask was discarded, so the returned est_inl_mask was
        # always all-False. Assumes prepare_rel_pose returns a B x N mask,
        # matching the preallocated shape above.
        est_inl_mask[i] = i_est_inl_mask

    if detailed:
        return R_err, t_err, est_inl_mask, mutual_desc_matches_mask, nn_desc_ids

    else:
        return R_err, t_err, est_inl_mask
def mutual_nn_matcher(descriptors1, descriptors2):
    """
    Match two descriptor sets by mutual nearest neighbours.

    Both inputs are given a leading batch dimension of size 1 before being
    passed to ``get_mutual_desc_matches`` (inverse cosine similarity, 0.9).

    :param descriptors1: descriptors of the first set; dimension 0 indexes
        the descriptors (presumably N1 x C — confirm with callers)
    :param descriptors2: descriptors of the second set (presumably N2 x C)
    :return: numpy array of shape K x 2; each row is
        ``(index into descriptors1, index into descriptors2)`` of a mutual match
    """
    batched1 = descriptors1.unsqueeze(0)
    batched2 = descriptors2.unsqueeze(0)
    mutual_mask, nn_ids = get_mutual_desc_matches(
        batched1, batched2, DescriptorDistance.INV_COS_SIM, 0.9)

    # Strip the singleton batch dimension that was added above
    keep = mutual_mask[0]
    all_src_ids = torch.arange(0, descriptors1.shape[0],
                               device=descriptors1.device)

    src_ids = all_src_ids[keep]
    dst_ids = nn_ids[0][keep]

    pairs = torch.stack([src_ids, dst_ids]).t()
    return pairs.data.cpu().numpy()
def relative_param_pose_error(kp1, kp2, kp1_desc, kp2_desc, shift_scale1,
                              shift_scale2, intrinsics1, intrinsics2,
                              extrinsics1, extrinsics2, px_thresh):
    """
    Compare an estimated parametrised relative pose against the ground-truth
    parametrisation, once per entry of ``px_thresh``.

    :param kp1: B x N x 2
    :param kp2: B x N x 2
    :param kp1_desc: B x N x C
    :param kp2_desc: B x N x C
    :param shift_scale1: B x 4
    :param shift_scale2: B x 4
    :param intrinsics1: B x 3 x 3
    :param intrinsics2: B x 3 x 3
    :param extrinsics1: B x 4 x 4
    :param extrinsics2: B x 4 x 4
    :param px_thresh: list; each entry is forwarded to ``prepare_param_rel_pose``
    :return: ``(R_param_err, t_param_err, success_mask)``; the first two are
        ``len(px_thresh) x B`` tensors holding the L2 norm of the difference
        between estimated and ground-truth parameters (columns ``:3`` for the
        rotation part, columns ``3:`` for the translation part); the last is a
        ``len(px_thresh) x B`` bool tensor with the success flags reported by
        ``prepare_param_rel_pose``.
    """
    match_mask, nn_ids = get_mutual_desc_matches(kp1_desc, kp2_desc, None, 0.9)

    orig_kp1 = revert_data_transform(kp1, shift_scale1)
    orig_kp2 = revert_data_transform(kp2, shift_scale2)

    # Pick each kp1's descriptor nearest neighbour in image 2, then re-express
    # those points with the intrinsics of image 1
    matched_kp2 = change_intrinsics(select_kp(orig_kp2, nn_ids),
                                    intrinsics2, intrinsics1)

    # NOTE(review): `E_param` is not defined in this view — presumably a
    # module-level constant imported elsewhere; confirm.
    gt_params = compose_gt_transform(intrinsics1, intrinsics2, extrinsics1,
                                     extrinsics2, E_param)

    n_thresh, batch_size = len(px_thresh), kp1.shape[0]

    R_param_err = torch.zeros(n_thresh, batch_size)
    t_param_err = torch.zeros(n_thresh, batch_size)
    success_mask = torch.zeros(n_thresh, batch_size, dtype=torch.bool)

    for idx, thresh in enumerate(px_thresh):
        est_params, est_success = prepare_param_rel_pose(
            orig_kp1, matched_kp2, match_mask, intrinsics1, intrinsics2,
            thresh)

        rot_diff = est_params[:, :3] - gt_params[:, :3]
        trans_diff = est_params[:, 3:] - gt_params[:, 3:]

        R_param_err[idx] = rot_diff.norm(dim=-1)
        t_param_err[idx] = trans_diff.norm(dim=-1)
        success_mask[idx] = est_success

    return R_param_err, t_param_err, success_mask
def epipolar_match_score(kp1,
                         kp2,
                         w_kp1,
                         w_kp2,
                         w_vis_kp1_mask,
                         w_vis_kp2_mask,
                         kp1_desc,
                         kp2_desc,
                         shift_scale1,
                         shift_scale2,
                         intrinsics1,
                         intrinsics2,
                         extrinsics1,
                         extrinsics2,
                         px_thresh,
                         dd_measure,
                         detailed=False):
    """
    Ratio of mutual descriptor matches that pass both the reprojection check
    and an epipolar-distance check, normalised by the number of visible
    ground-truth matches, for each pixel threshold.

    :param kp1: B x N x 2; keypoints on the first image
    :param kp2: B x N x 2; keypoints on the second image
    :param w_kp1: keypoints on the first image projected to the second
    :param w_kp2: keypoints on the second image projected to the first
    :param w_vis_kp1_mask: B x N; first-image keypoints visible on the second
    :param w_vis_kp2_mask: B x N; second-image keypoints visible on the first
    :param kp1_desc: B x N x C; descriptors of kp1
    :param kp2_desc: B x N x C; descriptors of kp2
    :param shift_scale1: used to revert the data transform of kp1
    :param shift_scale2: used to revert the data transform of kp2
    :param intrinsics1: intrinsics of the first camera
    :param intrinsics2: intrinsics of the second camera
    :param extrinsics1: extrinsics of the first camera
    :param extrinsics2: extrinsics of the second camera
    :param px_thresh: list of thresholds; the last entry is used for full match
        verification and is expected to be the largest
    :param dd_measure: descriptor distance measure forwarded to
        ``get_mutual_desc_matches``
    :param detailed: bool; return per-scene information instead of batch means
    :return: ``em_scores`` of length ``len(px_thresh)`` (mean over the batch),
        or with ``detailed=True`` the tuple ``(em_scores [T x B],
        num_matches [T x B], num_vis_gt_matches, nn_desc_ids,
        match_mask [T x B x N])`` where ``T = len(px_thresh)``.
    """
    mutual_desc_matches_mask, nn_desc_ids = get_mutual_desc_matches(
        kp1_desc, kp2_desc, dd_measure, None)

    # Verify matches by using the largest pixel threshold
    v_mutual_desc_matches_mask, nn_kp_values = verify_mutual_desc_matches(
        nn_desc_ids,
        kp1,
        kp2,
        w_kp1,
        w_kp2,
        w_vis_kp1_mask,
        w_vis_kp2_mask,
        px_thresh[-1],
        return_reproj=True)

    # Select minimum number of visible points for each scene
    num_vis_gt_matches = get_num_vis_gt_matches(w_vis_kp1_mask, w_vis_kp2_mask)

    # A scene without visible ground-truth matches would otherwise divide by
    # zero and turn the batch mean into NaN/inf. The counts are integral, so
    # clamping at 1 only changes empty scenes, whose score becomes 0/1-based.
    # Assumes get_num_vis_gt_matches returns a tensor of per-scene counts.
    safe_num_vis_gt_matches = num_vis_gt_matches.clamp(min=1)

    num_thresh = len(px_thresh)
    b, n = kp1.shape[:2]

    # Keypoints in original image coordinates
    o_kp1 = revert_data_transform(kp1, shift_scale1)
    o_kp2 = revert_data_transform(kp2, shift_scale2)

    nn_o_kp2 = select_kp(o_kp2, nn_desc_ids)

    # Ground-truth two-view transform (presumably the fundamental matrix)
    F = compose_gt_transform(intrinsics1, intrinsics2, extrinsics1,
                             extrinsics2)

    ep_dist = epipolar_distance(o_kp1, nn_o_kp2, F)

    if detailed:
        em_scores = torch.zeros(num_thresh, b)
        num_matches = torch.zeros(num_thresh, b)
        match_mask = torch.zeros(num_thresh, b, n)
    else:
        em_scores = torch.zeros(num_thresh)

    for i, thresh in enumerate(px_thresh):
        if i != num_thresh - 1:
            i_mutual_matches_mask = mutual_desc_matches_mask * nn_kp_values.le(
                thresh) * ep_dist.le(thresh)
        else:
            # The largest threshold reuses the full verification mask
            i_mutual_matches_mask = mutual_desc_matches_mask * v_mutual_desc_matches_mask * ep_dist.le(
                thresh)

        i_num_matches = i_mutual_matches_mask.sum(dim=-1).float()

        if detailed:
            em_scores[i] = i_num_matches / safe_num_vis_gt_matches
            num_matches[i] = i_num_matches
            match_mask[i] = i_mutual_matches_mask
        else:
            em_scores[i] = (i_num_matches / safe_num_vis_gt_matches).mean()

    if detailed:
        return em_scores, num_matches, num_vis_gt_matches, nn_desc_ids, match_mask
    else:
        return em_scores
def mean_matching_accuracy(kp1,
                           kp2,
                           w_kp1,
                           w_kp2,
                           w_vis_kp1_mask,
                           w_vis_kp2_mask,
                           kp1_desc,
                           kp2_desc,
                           px_thresh,
                           dd_measure,
                           detailed=False):
    """
    Fraction of mutual descriptor matches that are correct under each pixel
    threshold.

    :param kp1: B x N x 2; keypoints on the first image
    :param w_kp1: B x N x 2; keypoints on the first image projected to the second
    :param kp2: B x N x 2; keypoints on the second image
    :param w_kp2: B x N x 2; keypoints on the second image projected to the first
    :param w_vis_kp1_mask: B x N; keypoints on the first image which are visible on the second
    :param w_vis_kp2_mask: B x N; keypoints on the second image which are visible on the first
    :param kp1_desc: B x N x C; descriptors for keypoints on the first image
    :param kp2_desc: B x N x C; descriptors for keypoints on the second image
    :param px_thresh: list; keypoint distance thresholds. The last entry is
        checked with the full verification routine and is expected to be the largest
    :param dd_measure: measure of descriptor distance. Can be L2-norm or similarity measure
    :param detailed: bool; return per-scene information instead of batch means
    :return: ``mma_scores`` of length ``len(px_thresh)`` (mean over the batch),
        or with ``detailed=True`` the tuple ``(mma_scores [T x B],
        num_matches [T x B], num_vis_gt_matches [B], nn_desc_ids,
        match_mask [T x B x N])`` where ``T = len(px_thresh)`` and
        ``num_vis_gt_matches`` is the per-scene count of mutual descriptor
        matches (clamped to at least 1e-8).
    """
    desc_match_mask, nn_ids = get_mutual_desc_matches(
        kp1_desc, kp2_desc, dd_measure, None)

    # The last (largest) threshold gets the full verification mask; the
    # reprojection distances serve the smaller thresholds
    max_thresh_mask, nn_reproj_dist = verify_mutual_desc_matches(
        nn_ids, kp1, kp2, w_kp1, w_kp2, w_vis_kp1_mask, w_vis_kp2_mask,
        px_thresh[-1], return_reproj=True)

    # Denominator: mutual matches per scene, kept strictly positive
    num_vis_gt_matches = desc_match_mask.sum(dim=-1).float().clamp(min=1e-8)

    n_thresh = len(px_thresh)
    batch_size, num_kp = kp1.shape[:2]
    last_idx = n_thresh - 1

    if detailed:
        mma_scores = torch.zeros(n_thresh, batch_size)
        num_matches = torch.zeros(n_thresh, batch_size)
        match_mask = torch.zeros(n_thresh, batch_size, num_kp)
    else:
        mma_scores = torch.zeros(n_thresh)

    for idx, thresh in enumerate(px_thresh):
        within = max_thresh_mask if idx == last_idx else nn_reproj_dist.le(thresh)
        correct_mask = desc_match_mask * within

        correct_count = correct_mask.sum(dim=-1).float()
        ratio = correct_count / num_vis_gt_matches

        if not detailed:
            mma_scores[idx] = ratio.mean()
            continue

        mma_scores[idx] = ratio
        num_matches[idx] = correct_count
        match_mask[idx] = correct_mask

    if not detailed:
        return mma_scores
    return mma_scores, num_matches, num_vis_gt_matches, nn_ids, match_mask