def test_fast_global_registration(self):
    """FGR recovers a random rigid transform from clean correspondences.

    ``b`` is ``a`` rotated by a random Euler-angle rotation and translated, so
    ``a[i] <-> b[i]`` are exact correspondences and the estimate should match
    the 4x4 ground truth ``T_gt``.
    """
    a = torch.randn(100, 3)
    R_gt = euler_angles_to_rotation_matrix(torch.rand(3) * np.pi)
    t_gt = torch.rand(3)
    T_gt = torch.eye(4)
    T_gt[:3, :3] = R_gt
    T_gt[:3, 3] = t_gt
    b = a.mm(R_gt.T) + t_gt
    T_pred = fast_global_registration(a, b, mu_init=1, num_iter=20)
    # T_gt has an exact-zero bottom row and random rotation entries that can
    # be arbitrarily close to zero; with the default atol=0 a relative
    # tolerance alone degenerates into an exact-equality check there and makes
    # the test flaky, so allow a small absolute slack as well.
    npt.assert_allclose(T_pred.numpy(), T_gt.numpy(), rtol=1e-3, atol=1e-4)
def test_fast_global_registration_with_outliers(self):
    """FGR recovers a random rigid transform despite 10% gross outliers.

    ``b`` is ``a`` under a random rigid transform, after which 10 of the 100
    target points are scaled by 42, turning those correspondences into
    outliers that a robust estimator should ignore.
    """
    a = torch.randn(100, 3)
    R_gt = euler_angles_to_rotation_matrix(torch.rand(3) * np.pi)
    t_gt = torch.rand(3)
    T_gt = torch.eye(4)
    T_gt[:3, :3] = R_gt
    T_gt[:3, 3] = t_gt
    b = a.mm(R_gt.T) + t_gt
    # Corrupt 10 correspondences by pushing their targets far away.
    b[[1, 5, 20, 32, 74, 17, 27, 77, 88, 89]] *= 42
    T_pred = fast_global_registration(a, b, mu_init=1, num_iter=20)
    # rtol alone (default atol=0) demands exact equality on the zero / near-
    # zero entries of T_gt, which makes the test flaky; add a small atol.
    npt.assert_allclose(T_pred.numpy(), T_gt.numpy(), rtol=1e-3, atol=1e-4)
def track(self, model: model_interface.TrackerInterface, **kwargs):
    """Accumulate registration metrics for every sample of the current batch.

    ``super().track(model)`` is always called; nothing else is recorded when
    ``self._stage == "train"``. Otherwise, for each sample ``b`` of the batch:

    * the sample's points and features are selected with ``batch_idx == b``;
    * its correspondence indices (``ind`` / ``ind_target`` of the model input,
      stored with batch-global offsets) are shifted back to sample-local
      indices and used to estimate ``T_gt`` via ``estimate_transfo``;
    * at most ``self.num_points`` random descriptors per cloud are matched
      with ``get_matches`` and registered with ``fast_global_registration``;
    * hit ratio, feature-match flag, translation/rotation errors, RRE/RTE
      flags and scaled registration error are added to the tracker's
      accumulators.

    Parameters
    ----------
    model : model_interface.TrackerInterface
        Model exposing ``get_batch``, ``get_input`` and ``get_output``.
    **kwargs
        Accepted for interface compatibility; not used here.
    """
    super().track(model)
    if self._stage == "train":
        return

    batch_idx, batch_idx_target = model.get_batch()
    # Named `data*` rather than `input*` to avoid shadowing the builtin.
    data, data_target = model.get_input()
    batch_xyz, batch_xyz_target = data.pos, data_target.pos
    batch_ind, batch_ind_target = data.ind, data_target.ind
    batch_size_ind = data.size
    batch_feat, batch_feat_target = model.get_output()

    nb_batches = int(batch_idx.max().item()) + 1
    # Points of all samples are concatenated, so correspondence indices are
    # batch-global: `offset*` counts the points of the samples before b.
    offset, offset_target = 0, 0
    # Sample b owns the index slice [begin, begin + batch_size_ind[b]).
    begin = 0
    for b in range(nb_batches):
        mask = batch_idx == b
        mask_target = batch_idx_target == b
        xyz = batch_xyz[mask]
        xyz_target = batch_xyz_target[mask_target]
        feat = batch_feat[mask]
        feat_target = batch_feat_target[mask_target]

        end = begin + batch_size_ind[b].item()
        ind = batch_ind[begin:end] - offset
        ind_target = batch_ind_target[begin:end] - offset_target
        begin = end
        offset += len(xyz)
        offset_target += len(xyz_target)

        # Random subsets used for descriptor matching (drawn before the
        # ground-truth estimate, as in the original call order).
        rand = torch.randperm(len(feat))[: self.num_points]
        rand_target = torch.randperm(len(feat_target))[: self.num_points]

        matches_gt = torch.stack([ind, ind_target]).transpose(0, 1)
        T_gt = estimate_transfo(xyz[matches_gt[:, 0]], xyz_target[matches_gt[:, 1]])

        matches_pred = get_matches(feat[rand], feat_target[rand_target])
        # Matched point pairs, shared by registration and hit-ratio.
        src = xyz[rand][matches_pred[:, 0]]
        tgt = xyz_target[rand_target][matches_pred[:, 1]]
        T_pred = fast_global_registration(src, tgt)

        hit_ratio = compute_hit_ratio(src, tgt, T_gt, self.tau_1).item()
        trans_error, rot_error = compute_transfo_error(T_pred, T_gt)
        trans_error, rot_error = trans_error.item(), rot_error.item()
        sr_err = compute_scaled_registration_error(xyz, T_gt, T_pred)

        self._hit_ratio.add(hit_ratio)
        self._feat_match_ratio.add(float(hit_ratio > self.tau_2))
        self._trans_error.add(trans_error)
        self._rot_error.add(rot_error)
        # Success flags are added as floats, like _feat_match_ratio above.
        self._rre.add(float(rot_error < self.rot_thresh))
        self._rte.add(float(trans_error < self.trans_thresh))
        self._sr_err.add(sr_err.item())
# Demo: register a source/target point-cloud pair with a pretrained model and
# fast global registration, then visualise the clouds before and after.
# `data_s`, `pcd_t` and `transform` are defined outside this excerpt.

# Wrap the target cloud in a Batch whose `batch` vector is all zeros (every
# point belongs to sample 0) and apply the same `transform`.
data_t = transform(
    Batch(pos=torch.from_numpy(pcd_t).float(),
          batch=torch.zeros(pcd_t.shape[0]).long()))
# Load the pretrained "minkowski-registration-kitti" model onto the GPU.
model = PretainedRegistry.from_pretrained(
    "minkowski-registration-kitti").cuda()
# Inference only (no autograd): one forward pass per cloud.
with torch.no_grad():
    model.set_input(data_s, "cuda")
    output_s = model.forward()
    model.set_input(data_t, "cuda")
    output_t = model.forward()
# Draw 5000 random indices per cloud; torch.randint samples with replacement,
# so a point may be picked more than once.
rand_s = torch.randint(0, len(output_s), (5000, ))
rand_t = torch.randint(0, len(output_t), (5000, ))
# Match the sampled descriptors, then estimate the transform from the
# positions of the matched points.
# NOTE(review): `data_s.pos` / `data_t.pos` are indexed with `rand_*` and with
# `matches` derived from the model outputs; confirm these live on compatible
# devices after `set_input(..., "cuda")`.
matches = get_matches(output_s[rand_s], output_t[rand_t])
T_est = fast_global_registration(data_s.pos[rand_s][matches[:, 0]],
                                 data_t.pos[rand_t][matches[:, 1]])
# Build Open3D clouds with distinct uniform colours (source, then target).
o3d_pcd_s = o3d.geometry.PointCloud()
o3d_pcd_s.points = o3d.utility.Vector3dVector(data_s.pos.cpu().numpy())
o3d_pcd_s.paint_uniform_color([0.9, 0.7, 0.1])
o3d_pcd_t = o3d.geometry.PointCloud()
o3d_pcd_t.points = o3d.utility.Vector3dVector(data_t.pos.cpu().numpy())
o3d_pcd_t.paint_uniform_color([0.1, 0.7, 0.9])
# First window: the input clouds as given.
o3d.visualization.draw_geometries([o3d_pcd_s, o3d_pcd_t])
# Second window: the source cloud after applying T_est, next to the target.
# NOTE: per the Open3D docs, `transform` modifies `o3d_pcd_s` in place.
o3d.visualization.draw_geometries(
    [o3d_pcd_s.transform(T_est.cpu().numpy()), o3d_pcd_t])
def _registration_metrics(
    name,
    T_est,
    elapsed,
    T_gt,
    xyz,
    rot_thresh,
    trans_thresh,
    registration_recall_thresh,
    xyz_gt,
    xyz_target_gt,
):
    """Return the metrics of one estimated transform, keys suffixed by ``name``.

    Produces ``trans_error_<name>``, ``rot_error_<name>``, ``rre_<name>``,
    ``rte_<name>``, ``sr_<name>``, ``time_<name>`` (``elapsed``) and, when both
    ``xyz_gt`` and ``xyz_target_gt`` are given,
    ``registration_recall_<name>``.
    """
    trans_error, rot_error = compute_transfo_error(T_est, T_gt)
    trans_error, rot_error = trans_error.item(), rot_error.item()
    metrics = {
        f"trans_error_{name}": trans_error,
        f"rot_error_{name}": rot_error,
        f"rre_{name}": float(rot_error < rot_thresh),
        f"rte_{name}": float(trans_error < trans_thresh),
        f"sr_{name}": compute_scaled_registration_error(xyz, T_gt, T_est).item(),
        f"time_{name}": elapsed,
    }
    if xyz_gt is not None and xyz_target_gt is not None:
        metrics[f"registration_recall_{name}"] = compute_registration_recall(
            xyz_gt, xyz_target_gt, T_est, registration_recall_thresh
        )
    return metrics


def compute_metrics(
    xyz,
    xyz_target,
    feat,
    feat_target,
    T_gt,
    sym=False,
    tau_1=0.1,
    tau_2=0.05,
    rot_thresh=5,
    trans_thresh=5,
    use_ransac=False,
    ransac_thresh=0.02,
    use_teaser=False,
    noise_bound_teaser=0.1,
    registration_recall_thresh=0.2,
    xyz_gt=None,
    xyz_target_gt=None,
):
    """Compute descriptor-matching and registration metrics for one pair.

    Descriptors are matched with ``get_matches``. The matched point pairs are
    used to compute the hit ratio and the feature-match ratio. They are then
    registered with fast global registration (always) and with TEASER++
    (when ``use_teaser``). For each registration method the translation and
    rotation errors, RRE/RTE success flags, scaled registration error,
    registration time and, optionally, registration recall are reported.

    Parameters
    ----------
    xyz : torch tensor of size N x 3
    xyz_target : torch tensor of size N x 3
    feat : torch tensor of size N x C
    feat_target : torch tensor of size N x C
    T_gt : 4 x 4 ground-truth transformation matrix
    sym : bool
        Forwarded to ``get_matches``.
    tau_1 : float
        Threshold forwarded to ``compute_hit_ratio``.
    tau_2 : float
        A pair counts as a feature match when ``hit_ratio > tau_2``.
    rot_thresh, trans_thresh : float
        RRE / RTE succeed when the rotation / translation error returned by
        ``compute_transfo_error`` is below these values.
    use_ransac : bool
        RANSAC is not implemented; ``True`` raises ``NotImplementedError``.
    ransac_thresh : float
        Currently unused.
    use_teaser : bool
        Also evaluate TEASER++ registration.
    noise_bound_teaser : float
        ``noise_bound`` forwarded to ``teaser_pp_registration``.
    registration_recall_thresh : float
        Threshold forwarded to ``compute_registration_recall``.
    xyz_gt, xyz_target_gt : optional
        When both are given, registration recall is computed per method.

    Returns
    -------
    dict
        Keys ``hit_ratio``, ``feat_match_ratio``, ``time_get_matches`` plus,
        per method suffix (``fgr``, ``teaser``): ``trans_error_*``,
        ``rot_error_*``, ``rre_*``, ``rte_*``, ``sr_*``, ``time_*`` and
        optionally ``registration_recall_*``.

    Raises
    ------
    NotImplementedError
        If ``use_ransac`` is true.
    """
    if use_ransac:
        # Fail fast rather than computing every other metric first.
        raise NotImplementedError

    t0 = time.perf_counter()
    matches_pred = get_matches(feat, feat_target, sym=sym)
    time_get_matches = time.perf_counter() - t0

    # Matched point pairs, shared by the hit ratio and every registration.
    src = xyz[matches_pred[:, 0]]
    tgt = xyz_target[matches_pred[:, 1]]

    hit_ratio = compute_hit_ratio(src, tgt, T_gt, tau_1).item()
    res = {
        "hit_ratio": hit_ratio,
        "feat_match_ratio": float(hit_ratio > tau_2),
        "time_get_matches": time_get_matches,
    }

    common = dict(
        T_gt=T_gt,
        xyz=xyz,
        rot_thresh=rot_thresh,
        trans_thresh=trans_thresh,
        registration_recall_thresh=registration_recall_thresh,
        xyz_gt=xyz_gt,
        xyz_target_gt=xyz_target_gt,
    )

    # Fast global registration. Only the registration call is timed, like
    # time_get_matches, so the metric evaluation does not inflate time_fgr.
    t0 = time.perf_counter()
    T_fgr = fast_global_registration(src, tgt)
    res.update(_registration_metrics("fgr", T_fgr, time.perf_counter() - t0, **common))

    # TEASER++
    if use_teaser:
        t0 = time.perf_counter()
        T_teaser = teaser_pp_registration(src, tgt, noise_bound=noise_bound_teaser)
        res.update(
            _registration_metrics("teaser", T_teaser, time.perf_counter() - t0, **common)
        )

    return res