예제 #1
0
    def test_fast_global_registration(self):
        """Recover a random rigid transform from exact correspondences.

        ``b`` is ``a`` mapped by a random rotation ``R_gt`` (built from random
        Euler angles) and a random translation ``t_gt``; the 4x4 homogeneous
        matrix returned by ``fast_global_registration`` must match ``T_gt``.
        """
        a = torch.randn(100, 3)

        R_gt = euler_angles_to_rotation_matrix(torch.rand(3) * np.pi)
        t_gt = torch.rand(3)
        T_gt = torch.eye(4)
        T_gt[:3, :3] = R_gt
        T_gt[:3, 3] = t_gt
        # Points are rows, so b_i = R_gt @ a_i + t_gt becomes a @ R_gt^T + t_gt.
        b = a.mm(R_gt.T) + t_gt
        T_pred = fast_global_registration(a, b, mu_init=1, num_iter=20)
        # assert_allclose defaults to atol=0; entries of a random rotation can
        # be arbitrarily close to zero, where a purely relative tolerance
        # demands near-exact float32 equality. A small atol keeps it stable.
        npt.assert_allclose(T_pred.numpy(), T_gt.numpy(), rtol=1e-3, atol=1e-5)
예제 #2
0
 def test_fast_global_registration_with_outliers(self):
     """Recover a random rigid transform when 10 of 100 targets are outliers.

     After applying the ground-truth transform, ten target points are scaled
     by 42 so they no longer correspond to their source points; the estimate
     from ``fast_global_registration`` must still match ``T_gt``.
     """
     a = torch.randn(100, 3)
     R_gt = euler_angles_to_rotation_matrix(torch.rand(3) * np.pi)
     t_gt = torch.rand(3)
     T_gt = torch.eye(4)
     T_gt[:3, :3] = R_gt
     T_gt[:3, 3] = t_gt
     # Points are rows, so b_i = R_gt @ a_i + t_gt becomes a @ R_gt^T + t_gt.
     b = a.mm(R_gt.T) + t_gt
     # Corrupt a fixed subset of targets to act as outliers.
     b[[1, 5, 20, 32, 74, 17, 27, 77, 88, 89]] *= 42
     T_pred = fast_global_registration(a, b, mu_init=1, num_iter=20)
     # assert_allclose defaults to atol=0; rotation entries near zero would
     # otherwise require near-exact equality, making the test flaky.
     npt.assert_allclose(T_pred.numpy(), T_gt.numpy(), rtol=1e-3, atol=1e-5)
예제 #3
0
    def track(self, model: model_interface.TrackerInterface, **kwargs):
        """Accumulate per-sample registration metrics for one step.

        During the ``"train"`` stage only the parent ``track`` runs. Otherwise
        each sample of the (source, target) batch is handled independently:

        * ``T_gt`` is estimated with ``estimate_transfo`` from the
          ground-truth correspondences (``ind`` / ``ind_target``);
        * up to ``self.num_points`` random points per side are matched in
          feature space with ``get_matches`` and registered with
          ``fast_global_registration`` to give ``T_pred``;
        * hit ratio, feature-match ratio, translation/rotation errors,
          RRE/RTE success flags and the scaled registration error are added
          to the corresponding meters.

        ``**kwargs`` is accepted for interface compatibility and is unused.
        """
        super().track(model)
        if self._stage == "train":
            return

        batch_idx, batch_idx_target = model.get_batch()
        src_input, tgt_input = model.get_input()
        batch_xyz = src_input.pos
        batch_xyz_target = tgt_input.pos
        batch_ind = src_input.ind
        batch_ind_target = tgt_input.ind
        # Number of ground-truth correspondences per sample; source and target
        # index tensors are sliced with the same bounds.
        batch_size_ind = src_input.size
        batch_feat, batch_feat_target = model.get_output()

        num_samples = batch_idx.max() + 1
        # The concatenated ``ind`` tensors index into the whole batch, so they
        # are shifted back by the number of points of the preceding samples.
        src_offset = 0
        tgt_offset = 0
        # [lo, hi) is the slice of the concatenated indices for this sample.
        lo, hi = 0, batch_size_ind[0].item()
        for sample_id in range(num_samples):
            in_src = batch_idx == sample_id
            in_tgt = batch_idx_target == sample_id
            xyz = batch_xyz[in_src]
            xyz_target = batch_xyz_target[in_tgt]
            feat = batch_feat[in_src]
            feat_target = batch_feat_target[in_tgt]

            ind = batch_ind[lo:hi] - src_offset
            ind_target = batch_ind_target[lo:hi] - tgt_offset
            if sample_id < num_samples - 1:
                lo, hi = hi, hi + batch_size_ind[sample_id + 1].item()
            src_offset += len(xyz)
            tgt_offset += len(xyz_target)

            rand = torch.randperm(len(feat))[: self.num_points]
            rand_target = torch.randperm(len(feat_target))[: self.num_points]

            matches_gt = torch.stack([ind, ind_target]).transpose(0, 1)
            T_gt = estimate_transfo(
                xyz[matches_gt[:, 0]], xyz_target[matches_gt[:, 1]]
            )

            matches_pred = get_matches(feat[rand], feat_target[rand_target])
            sampled_xyz = xyz[rand]
            sampled_xyz_target = xyz_target[rand_target]
            T_pred = fast_global_registration(
                sampled_xyz[matches_pred[:, 0]],
                sampled_xyz_target[matches_pred[:, 1]],
            )
            hit_ratio = compute_hit_ratio(
                sampled_xyz[matches_pred[:, 0]],
                sampled_xyz_target[matches_pred[:, 1]],
                T_gt,
                self.tau_1,
            )
            trans_error, rot_error = compute_transfo_error(T_pred, T_gt)
            sr_err = compute_scaled_registration_error(xyz, T_gt, T_pred)

            hit = hit_ratio.item()
            t_err = trans_error.item()
            r_err = rot_error.item()
            self._hit_ratio.add(hit)
            self._feat_match_ratio.add(float(hit > self.tau_2))
            self._trans_error.add(t_err)
            self._rot_error.add(r_err)
            self._rre.add(r_err < self.rot_thresh)
            self._rte.add(t_err < self.trans_thresh)
            self._sr_err.add(sr_err.item())
예제 #4
0
    # Wrap the raw target cloud as a one-sample Batch (every point gets batch
    # index 0) and apply ``transform``.
    target_pos = torch.from_numpy(pcd_t).float()
    target_batch_idx = torch.zeros(pcd_t.shape[0]).long()
    data_t = transform(Batch(pos=target_pos, batch=target_batch_idx))

    model = PretainedRegistry.from_pretrained("minkowski-registration-kitti")
    model = model.cuda()

    # Forward both clouds without tracking gradients; the outputs are used as
    # matching features below.
    with torch.no_grad():
        model.set_input(data_s, "cuda")
        output_s = model.forward()
        model.set_input(data_t, "cuda")
        output_t = model.forward()

    # Draw a fixed number of random indices (with replacement) per cloud,
    # match their features and register the matched coordinates.
    num_samples = 5000
    rand_s = torch.randint(0, len(output_s), (num_samples,))
    rand_t = torch.randint(0, len(output_t), (num_samples,))
    matches = get_matches(output_s[rand_s], output_t[rand_t])
    matched_src = data_s.pos[rand_s][matches[:, 0]]
    matched_tgt = data_t.pos[rand_t][matches[:, 1]]
    T_est = fast_global_registration(matched_src, matched_tgt)

    # Build one uniformly colored Open3D cloud per input.
    o3d_pcd_s = o3d.geometry.PointCloud()
    o3d_pcd_t = o3d.geometry.PointCloud()
    for o3d_pcd, pcd_data, pcd_color in (
        (o3d_pcd_s, data_s, [0.9, 0.7, 0.1]),
        (o3d_pcd_t, data_t, [0.1, 0.7, 0.9]),
    ):
        o3d_pcd.points = o3d.utility.Vector3dVector(pcd_data.pos.cpu().numpy())
        o3d_pcd.paint_uniform_color(pcd_color)

    # First window: the clouds as loaded; second: source after applying T_est.
    o3d.visualization.draw_geometries([o3d_pcd_s, o3d_pcd_t])
    aligned_s = o3d_pcd_s.transform(T_est.cpu().numpy())
    o3d.visualization.draw_geometries([aligned_s, o3d_pcd_t])
예제 #5
0
def compute_metrics(
    xyz,
    xyz_target,
    feat,
    feat_target,
    T_gt,
    sym=False,
    tau_1=0.1,
    tau_2=0.05,
    rot_thresh=5,
    trans_thresh=5,
    use_ransac=False,
    ransac_thresh=0.02,
    use_teaser=False,
    noise_bound_teaser=0.1,
    registration_recall_thresh=0.2,
    xyz_gt=None,
    xyz_target_gt=None,
):
    """
    Compute feature-matching and registration metrics for one pair.

    Features are matched with ``get_matches``; the matched coordinates are
    then used to compute:

    - ``hit_ratio`` (``compute_hit_ratio`` with threshold ``tau_1``) and
      ``feat_match_ratio`` (1.0 if ``hit_ratio > tau_2`` else 0.0);
    - for Fast Global Registration (always) and TEASER++ (if
      ``use_teaser``): translation error, rotation error, rre
      (rotation error < ``rot_thresh``), rte (translation error <
      ``trans_thresh``), scaled registration error, registration run time,
      and registration recall when ``xyz_gt`` and ``xyz_target_gt`` are
      both given.

    Timings (``time_get_matches``, ``time_fgr``, ``time_teaser``) cover only
    the matching / registration call itself, not the metric computations.

    Parameters
    ----------
    xyz: torch tensor of size N x 3
    xyz_target: torch tensor of size N x 3
    feat: torch tensor of size N x C
    feat_target: torch tensor of size N x C
    T_gt: 4 x 4 matrix
    sym: forwarded to ``get_matches``
    use_ransac: RANSAC is not implemented; ``True`` raises immediately
    ransac_thresh: currently unused
    noise_bound_teaser: ``noise_bound`` passed to ``teaser_pp_registration``

    Returns
    -------
    dict mapping metric names (suffixed ``_fgr`` / ``_teaser`` for the
    registration metrics) to their values.

    Raises
    ------
    NotImplementedError
        If ``use_ransac`` is True (raised before any work is done).
    """
    # Fail fast instead of running matching and registration first only to
    # throw the results away.
    if use_ransac:
        raise NotImplementedError("RANSAC registration is not implemented yet")

    res = dict()

    def add_registration_metrics(method, T_est, elapsed):
        # Record error metrics of ``T_est`` against ``T_gt`` under keys
        # suffixed with ``method``.
        trans_error, rot_error = compute_transfo_error(T_est, T_gt)
        trans_error, rot_error = trans_error.item(), rot_error.item()
        res[f"trans_error_{method}"] = trans_error
        res[f"rot_error_{method}"] = rot_error
        res[f"rre_{method}"] = float(rot_error < rot_thresh)
        res[f"rte_{method}"] = float(trans_error < trans_thresh)
        res[f"sr_{method}"] = compute_scaled_registration_error(xyz, T_gt, T_est).item()
        res[f"time_{method}"] = elapsed
        if xyz_gt is not None and xyz_target_gt is not None:
            res[f"registration_recall_{method}"] = compute_registration_recall(
                xyz_gt, xyz_target_gt, T_est, registration_recall_thresh
            )

    t0 = time.perf_counter()
    matches_pred = get_matches(feat, feat_target, sym=sym)
    time_get_matches = time.perf_counter() - t0

    # Coordinates of the predicted correspondences, shared by all metrics.
    matched_xyz = xyz[matches_pred[:, 0]]
    matched_xyz_target = xyz_target[matches_pred[:, 1]]

    hit_ratio = compute_hit_ratio(matched_xyz, matched_xyz_target, T_gt, tau_1).item()
    res["hit_ratio"] = hit_ratio
    res["feat_match_ratio"] = float(hit_ratio > tau_2)
    res["time_get_matches"] = time_get_matches

    # fast global registration
    t0 = time.perf_counter()
    T_fgr = fast_global_registration(matched_xyz, matched_xyz_target)
    add_registration_metrics("fgr", T_fgr, time.perf_counter() - t0)

    # teaser pp
    if use_teaser:
        t0 = time.perf_counter()
        T_teaser = teaser_pp_registration(
            matched_xyz, matched_xyz_target, noise_bound=noise_bound_teaser
        )
        add_registration_metrics("teaser", T_teaser, time.perf_counter() - t0)

    return res