Пример #1
0
    def apply_transform(self, p0, transform_mat):
        """Apply ``transform_mat`` to a point cloud, rotating normals if present.

        Args:
            p0: Point array whose first three columns are transformed with
                ``se3.transform``. When it has exactly six columns, columns 3:6
                are treated as normals and rotated with ``so3.transform``.
            transform_mat: Transform applied to the points; its top-left 3x3
                block is used for the normals.

        Returns:
            Tuple ``(p1, gt, igt)``: the transformed points (with rotated
            normals appended when present), ``se3.inverse(transform_mat)``, and
            ``transform_mat`` itself.
        """
        transformed = se3.transform(transform_mat, p0[:, :3])
        has_normals = p0.shape[1] == 6
        if has_normals:
            # Normals are directions, so only the rotation block is applied.
            rotation = transform_mat[:3, :3]
            rotated_normals = so3.transform(rotation, p0[:, 3:6])
            transformed = np.concatenate((transformed, rotated_normals), axis=-1)

        igt = transform_mat
        return transformed, se3.inverse(igt), igt
Пример #2
0
def chamfer_distance(source_pts, target_pts, raw_pts, T_pred, T_gt):
    """Compute a modified Chamfer distance for one registration result.

    Two directions are measured:

    * ``source_pts`` moved by ``T_pred`` versus ``raw_pts``;
    * ``target_pts`` versus ``raw_pts`` moved by the ``se3`` composition of
      ``T_pred`` with the inverse of ``T_gt``.

    Each direction takes, for every point, the squared distance to its nearest
    neighbour in the other cloud and averages those. The two means are summed.

    Args:
        source_pts: (N, 3) source points.
        target_pts: Reference points compared against the transformed raw cloud.
        raw_pts: Clean point cloud used on both sides of the distance.
        T_pred: Predicted transform as a 4x4 matrix. Assumed rigid, i.e. last
            row ``[0, 0, 0, 1]``; confirm with callers.
        T_gt: Ground-truth transform; only its top three rows are used.

    Returns:
        Scalar float32 ``torch.Tensor`` holding the summed mean squared
        nearest-neighbour distances.
    """

    def square_distance(src, dst):
        # All-pairs squared Euclidean distances, (N, D) x (M, D) -> (N, M).
        # float32 matches what the legacy ``torch.Tensor(...)`` constructor gave.
        src = torch.as_tensor(src, dtype=torch.float32)
        dst = torch.as_tensor(dst, dtype=torch.float32)
        return torch.sum((src[:, None, :] - dst[None, :, :]) ** 2, dim=-1)

    # Rigid transform of the source points, p' = R p + t.
    # This is equivalent to the former ``o3d.PointCloud().transform(T_pred)``
    # round-trip for a rigid T_pred. It avoids the legacy top-level Open3D API
    # (absent in newer Open3D releases) and the deep copy of a temporary cloud.
    source_pts = np.asarray(source_pts, dtype=np.float64)
    T_pred_np = np.asarray(T_pred)
    src_transformed = source_pts @ T_pred_np[:3, :3].T + T_pred_np[:3, 3]

    ref_clean = raw_pts
    src_clean = se3.transform(
        se3.concatenate(T_pred[:3, :], se3.inverse(T_gt[:3, :])), raw_pts)

    # Nearest-neighbour squared distance for each point, in both directions.
    dist_src = torch.min(square_distance(src_transformed, ref_clean), dim=-1)[0]
    dist_ref = torch.min(square_distance(target_pts, src_clean), dim=-1)[0]
    chamfer_dist = torch.mean(dist_src, dim=0) + torch.mean(dist_ref, dim=0)
    return chamfer_dist
Пример #3
0
def compute_metrics(data: Dict, pred_transforms, perm_matrices=None) -> Dict:
    """Compute the registration metrics reported in the paper.

    Args:
        data: Batch dictionary.
            Always read: ``'transform_gt'``, ``'points_src'``, ``'points_ref'``,
            ``'points_raw'``. Only the first three channels of each point tensor
            are used.
            Also read when ``perm_matrices`` is given: ``'corr_mat'``.
        pred_transforms: Batched predicted transforms. Rotation is read from
            ``[:, :3, :3]`` and translation from ``[:, :3, 3]``; presumably
            shape (B, 3, 4), to be confirmed against callers.
        perm_matrices: Optional predicted correspondence scores, (B, M, N) per
            the original annotations. When ``None``, the correspondence metric
            is skipped.

    Returns:
        Dict of per-sample numpy arrays with these keys:

        * ``'r_mse'``, ``'r_mae'``: Euler-angle errors.
        * ``'t_mse'``, ``'t_mae'``: translation errors.
        * ``'err_r_deg'``, ``'err_t'``: isotropic rotation/translation errors.
        * ``'chamfer_dist'``: modified Chamfer distance.
        * ``'correct_corr'``: present only when ``perm_matrices`` is given.
    """

    def square_distance(src, dst):
        # Batched pairwise squared distances: (B, N, D) x (B, M, D) -> (B, N, M).
        return torch.sum((src[:, :, None, :] - dst[:, None, :, :]) ** 2, dim=-1)

    with torch.no_grad():
        gt_transforms = data['transform_gt']
        points_src = data['points_src'][..., :3]
        points_ref = data['points_ref'][..., :3]
        points_raw = data['points_raw'][..., :3]

        # Euler angles, individual translation errors (Deep Closest Point convention)
        # TODO Change rotation to torch operations
        r_gt_euler_deg = dcm2euler(gt_transforms[:, :3, :3].detach().cpu().numpy(), seq='xyz')
        r_pred_euler_deg = dcm2euler(pred_transforms[:, :3, :3].detach().cpu().numpy(), seq='xyz')
        r_diff = r_gt_euler_deg - r_pred_euler_deg
        r_mse = np.mean(r_diff ** 2, axis=1)
        r_mae = np.mean(np.abs(r_diff), axis=1)
        t_diff = gt_transforms[:, :3, 3] - pred_transforms[:, :3, 3]
        t_mse = torch.mean(t_diff ** 2, dim=1)
        t_mae = torch.mean(torch.abs(t_diff), dim=1)

        # Rotation, translation errors (isotropic, i.e. doesn't depend on error
        # direction, which is more representative of the actual error)
        concatenated = se3.concatenate(se3.inverse(gt_transforms), pred_transforms)
        rot_trace = concatenated[:, 0, 0] + concatenated[:, 1, 1] + concatenated[:, 2, 2]
        # Clamp keeps acos defined when round-off pushes the cosine outside [-1, 1].
        residual_rotdeg = torch.acos(torch.clamp(0.5 * (rot_trace - 1), min=-1.0, max=1.0)) * 180.0 / np.pi
        residual_transmag = concatenated[:, :, 3].norm(dim=-1)

        # Modified Chamfer distance
        src_transformed = se3.transform(pred_transforms, points_src)
        ref_clean = points_raw
        src_clean = se3.transform(se3.concatenate(pred_transforms, se3.inverse(gt_transforms)), points_raw)
        dist_src = torch.min(square_distance(src_transformed, ref_clean), dim=-1)[0]
        dist_ref = torch.min(square_distance(points_ref, src_clean), dim=-1)[0]
        chamfer_dist = torch.mean(dist_src, dim=1) + torch.mean(dist_ref, dim=1)

        metrics = {
            'r_mse': r_mse,
            'r_mae': r_mae,
            't_mse': to_numpy(t_mse),
            't_mae': to_numpy(t_mae),
            'err_r_deg': to_numpy(residual_rotdeg),
            'err_t': to_numpy(residual_transmag),
            'chamfer_dist': to_numpy(chamfer_dist),
        }

        # Fraction (per batch item) of rows whose argmax column matches the
        # ground-truth correspondence matrix. Computed only when predictions are
        # supplied. Previously `correct_corr` was unbound when perm_matrices was
        # None, so the default call path crashed.
        if perm_matrices is not None:
            col_idx_pred = np.argmax(perm_matrices.detach().cpu().numpy(), axis=-1)  # b,m
            col_idx_gt = np.argmax(data['corr_mat'].detach().cpu().numpy(), axis=-1)  # b,m
            metrics['correct_corr'] = np.mean(col_idx_gt == col_idx_pred, axis=1)  # b

    return metrics