Example 1
 def test_so3_cos_angle(self, batch_size: int = 100):
     """
     Check that `so3_relative_angle(R1, R2, cos_angle=False).cos()`
     is the same as `so3_relative_angle(R1, R2, cos_angle=True)`
     for batches of randomly generated rotation matrices `R1` and `R2`.
     """
     rot1 = TestSO3.init_rot(batch_size=batch_size)
     rot2 = TestSO3.init_rot(batch_size=batch_size)
     angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos()
     angles_ = so3_relative_angle(rot1, rot2, cos_angle=True)
     self.assertTrue(torch.allclose(angles, angles_))
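
Under the hood, `so3_relative_angle` essentially reduces to the trace formula for the angle of the relative rotation `R1^T R2`. A minimal pure-torch sketch of that formula (the helper name `relative_angle` is ours, not part of PyTorch3D):

 import torch

 def relative_angle(R1: torch.Tensor, R2: torch.Tensor) -> torch.Tensor:
     # for any rotation matrix R, trace(R) = 1 + 2*cos(angle)
     R12 = torch.bmm(R1.permute(0, 2, 1), R2)
     trace = R12.diagonal(dim1=-2, dim2=-1).sum(-1)
     cos_angle = ((trace - 1.0) * 0.5).clamp(-1.0, 1.0)  # guard acos domain
     return torch.acos(cos_angle)
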
Example 2
 def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
     """
     Check that `so3_exponential_map(so3_log_map(R))==R` for
     a batch of randomly generated rotation matrices `R`.
     """
     rot = TestSO3.init_rot(batch_size=batch_size)
     rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
     angles = so3_relative_angle(rot, rot_)
     # TODO: a lot of precision lost here ...
     self.assertClose(angles, torch.zeros_like(angles), atol=0.1)
Example 3
 def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
     """
     Check that `so3_exponential_map(so3_log_map(R))==R` for
     a batch of randomly generated rotation matrices `R`.
     """
     rot = TestSO3.init_rot(batch_size=batch_size)
     rot_ = so3_exponential_map(so3_log_map(rot))
     angles = so3_relative_angle(rot, rot_)
     max_angle = angles.max()
     # a lot of precision lost here :(
     # TODO: fix this test??
     self.assertTrue(np.allclose(float(max_angle), 0.0, atol=0.1))
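
Examples 2 and 3 are two revisions of the same exp/log round trip. A standalone sketch of the identity they exercise, assuming only that pytorch3d is installed (the sample values are ours):

 import torch
 from pytorch3d.transforms import so3_exp_map, so3_log_map, so3_relative_angle

 log_rot = torch.randn(4, 3) * 0.5     # small random axis-angle vectors
 rot = so3_exp_map(log_rot)            # batch of (3, 3) rotation matrices
 rot_ = so3_exp_map(so3_log_map(rot))
 print(so3_relative_angle(rot, rot_))  # angles close to zero
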
Example 4
 def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
     """
     Check that `so3_exp_map(so3_log_map(R))==R` for
     a batch of randomly generated rotation matrices `R`.
     """
     rot = TestSO3.init_rot(batch_size=batch_size)
     non_singular = (so3_rotation_angle(rot) - math.pi).abs() > 1e-2
     rot = rot[non_singular]
     rot_ = so3_exp_map(so3_log_map(rot, eps=1e-8, cos_bound=1e-8),
                        eps=1e-8)
     self.assertClose(rot_, rot, atol=0.1)
     angles = so3_relative_angle(rot, rot_, cos_bound=1e-4)
     self.assertClose(angles, torch.zeros_like(angles), atol=0.1)
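
The `non_singular` mask above filters out rotations whose angle is close to pi: there the log map is ill-conditioned, since `R - R^T` vanishes and the rotation axis can no longer be recovered stably. A small illustration of that failure mode (assuming pytorch3d; the exact output depends on the library version):

 import math
 import torch
 from pytorch3d.transforms import so3_exp_map, so3_log_map

 # a rotation by exactly pi about the z-axis sits on the singularity
 R_pi = so3_exp_map(torch.tensor([[0.0, 0.0, math.pi]]))
 # the recovered axis-angle vector may be badly wrong here, even though
 # R_pi itself is a perfectly valid rotation matrix
 print(so3_log_map(R_pi))
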
Example 5
 def calc_camera_distance(cam_1, cam_2):
     """
     Calculates the divergence of a batch of pairs of cameras cam_1, cam_2.
     The distance is composed of one minus the cosine of the relative angle
     between the rotation components of the camera extrinsics and the mean
     squared l2 distance between the translation vectors.
     """
     # rotation distance
     R_distance = (
         1. - so3_relative_angle(cam_1.R, cam_2.R, cos_angle=True)).mean()
     # translation distance
     T_distance = ((cam_1.T - cam_2.T)**2).sum(1).mean()
     # the final distance is the sum
     return R_distance + T_distance
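
A minimal usage sketch for `calc_camera_distance`, assuming pytorch3d's `PerspectiveCameras` and `random_rotations`; the batch size and random inputs are ours:

 import torch
 from pytorch3d.renderer import PerspectiveCameras
 from pytorch3d.transforms import random_rotations

 cam_a = PerspectiveCameras(R=random_rotations(8), T=torch.randn(8, 3))
 cam_b = PerspectiveCameras(R=random_rotations(8), T=torch.randn(8, 3))
 print(calc_camera_distance(cam_a, cam_b))  # scalar tensor, zero only for identical cameras
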
Example 6
 def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
     """
     Check that
     `so3_exp_map(so3_log_map(so3_exp_map(log_rot)))
     == so3_exp_map(log_rot)`
     for a randomly generated batch of rotation matrix logarithms `log_rot`.
     Unlike `test_so3_log_to_exp_to_log`, this test checks the
     correctness of converting a `log_rot` which contains values > math.pi.
     """
     log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
     # check also the singular cases where rot. angle = {0, 2pi}
     log_rot[:2] = 0
     log_rot[1, 0] = 2.0 * math.pi - 1e-6
     rot = so3_exp_map(log_rot, eps=1e-4)
     rot_ = so3_exp_map(so3_log_map(rot, eps=1e-4, cos_bound=1e-6),
                        eps=1e-6)
     self.assertClose(rot, rot_, atol=0.01)
     angles = so3_relative_angle(rot, rot_, cos_bound=1e-6)
     self.assertClose(angles, torch.zeros_like(angles), atol=0.01)
Example 7
 def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
     """
     Check that
     `so3_exponential_map(so3_log_map(so3_exponential_map(log_rot)))
     == so3_exponential_map(log_rot)`
     for a randomly generated batch of rotation matrix logarithms `log_rot`.
     Unlike `test_so3_log_to_exp_to_log`, this test checks the
     correctness of converting a `log_rot` which contains values > math.pi.
     """
     log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
     # check also the singular cases where rot. angle = {0, pi, 2pi, 3pi}
     log_rot[:3] = 0
     log_rot[1, 0] = math.pi
     log_rot[2, 0] = 2.0 * math.pi
     log_rot[3, 0] = 3.0 * math.pi
     rot = so3_exponential_map(log_rot, eps=1e-8)
     rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
     angles = so3_relative_angle(rot, rot_)
     self.assertClose(angles, torch.zeros_like(angles), atol=0.01)
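
These tests pass for values > math.pi because axis-angle vectors that differ by 2*pi about the same axis describe the same element of SO(3); `so3_log_map` simply returns the representative whose angle lies in [0, pi]. A quick check of that equivalence (values ours):

 import math
 import torch
 from pytorch3d.transforms import so3_exp_map

 axis = torch.tensor([[0.0, 1.0, 0.0]])
 log_a = axis * (0.5 * math.pi)                  # rotation by pi/2
 log_b = axis * (0.5 * math.pi + 2.0 * math.pi)  # same rotation, angle + 2*pi
 print(torch.allclose(so3_exp_map(log_a), so3_exp_map(log_b), atol=1e-5))  # True
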
Example 8
    def _corresponding_cameras_alignment_test_case(
        self,
        cameras,
        R_align_gt,
        T_align_gt,
        s_align_gt,
        estimate_scale,
        mode,
        add_noise,
    ):
        batch_size = cameras.R.shape[0]

        # get target camera centers
        R_new = torch.bmm(R_align_gt[None].expand_as(cameras.R), cameras.R)
        T_new = (
            torch.bmm(T_align_gt[None, None].repeat(batch_size, 1, 1), cameras.R)[:, 0]
            + cameras.T
        ) * s_align_gt

        if add_noise != 0.0:
            R_new = torch.bmm(
                R_new, so3_exponential_map(torch.randn_like(T_new) * add_noise)
            )
            T_new += torch.randn_like(T_new) * add_noise

        # create new cameras from R_new and T_new
        cameras_tgt = cameras.clone()
        cameras_tgt.R = R_new
        cameras_tgt.T = T_new

        # align cameras and cameras_tgt
        cameras_aligned = corresponding_cameras_alignment(
            cameras, cameras_tgt, estimate_scale=estimate_scale, mode=mode
        )

        if batch_size <= 2 and mode == "centers":
            # underdetermined case - check only the center alignment error
            # since the rotation and translation are ambiguous here
            self.assertClose(
                cameras_aligned.get_camera_center(),
                cameras_tgt.get_camera_center(),
                atol=max(add_noise * 7.0, 1e-4),
            )

        else:

            def _rmse(a):
                return (torch.norm(a, dim=1, p=2) ** 2).mean().sqrt()

            if add_noise != 0.0:
                # in a noisy case check mean rotation/translation error for
                # extrinsic alignment and root mean center error for center alignment
                if mode == "centers":
                    self.assertNormsClose(
                        cameras_aligned.get_camera_center(),
                        cameras_tgt.get_camera_center(),
                        _rmse,
                        atol=max(add_noise * 10.0, 1e-4),
                    )
                elif mode == "extrinsics":
                    angle_err = so3_relative_angle(
                        cameras_aligned.R, cameras_tgt.R
                    ).mean()
                    self.assertClose(
                        angle_err, torch.zeros_like(angle_err), atol=add_noise * 10.0
                    )
                    self.assertNormsClose(
                        cameras_aligned.T, cameras_tgt.T, _rmse, atol=add_noise * 7.0
                    )
                else:
                    raise ValueError(mode)

            else:
                # compare the rotations and translations of cameras
                self.assertClose(cameras_aligned.R, cameras_tgt.R, atol=3e-4)
                self.assertClose(cameras_aligned.T, cameras_tgt.T, atol=3e-4)
                # compare the centers
                self.assertClose(
                    cameras_aligned.get_camera_center(),
                    cameras_tgt.get_camera_center(),
                    atol=3e-4,
                )
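
A condensed sketch of the noise-free path this helper exercises: build target cameras by applying one global transform to a source batch, then recover the alignment with `corresponding_cameras_alignment` (the batch size and tolerance are ours):

 import torch
 from pytorch3d.ops import corresponding_cameras_alignment
 from pytorch3d.renderer import PerspectiveCameras
 from pytorch3d.transforms import random_rotations

 cams = PerspectiveCameras(R=random_rotations(10), T=torch.randn(10, 3))
 R_g = random_rotations(1)   # global rotation (1, 3, 3)
 T_g = torch.randn(1, 1, 3)  # global translation
 cams_tgt = PerspectiveCameras(
     R=torch.bmm(R_g.expand(10, -1, -1), cams.R),
     T=torch.bmm(T_g.expand(10, -1, -1), cams.R)[:, 0] + cams.T,
 )
 cams_aligned = corresponding_cameras_alignment(
     cams, cams_tgt, estimate_scale=False, mode="extrinsics"
 )
 print(torch.allclose(cams_aligned.R, cams_tgt.R, atol=1e-3))  # True (up to numerics)
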
Example 9
def eval_pose(pred, gt):
    # ScanNet convention for mapping between x_cam and x_world
    # cam->world: right-multiply by pose[:3, :3].T, then add pose[:3, 3]
    # i.e.: x_world = x_cam @ pose[:3, :3].t() + pose[:3, 3]
    # i.e.: x_cam = (x_world - pose[:3, 3]) @ pose[:3, :3]

    import numpy as np
    import torch

    from pytorch3d import ops as pt3ops
    from pytorch3d.transforms import so3

    ok_gt = torch.isfinite(gt.mean((1, 2)))  # some GT poses are NaN
    if not ok_gt.any():
        return 'NO_GT'
    orig_C_pred = pred[ok_gt, :3, 3].clone()
    orig_R_pred = pred[ok_gt, :3, :3].clone()
    orig_C_gt = gt[ok_gt, :3, 3].clone()
    orig_R_gt = gt[ok_gt, :3, :3].clone()
    n_frames = orig_C_pred.shape[0]

    result = {}

    for interpolate in (True, False):
        registered = torch.isfinite(orig_C_pred.mean(1))
        if not registered.any():
            return None
        if interpolate:  # interpolate NaN cameras
            C_pred, R_pred = interpolate_cameras(orig_C_pred, orig_R_pred)
            R_gt = orig_R_gt.clone()
            C_gt = orig_C_gt.clone()
        else:  # remove NaN cameras
        C_pred = orig_C_pred.clone()[registered]
        R_pred = orig_R_pred.clone()[registered]
        C_gt = orig_C_gt.clone()[registered]
        R_gt = orig_R_gt.clone()[registered]

        for align_cams in (True, False):
            for estimate_scale in ((True, False) if align_cams else (False,)):
                if align_cams:
                    # estimate the rigid alignment
                    align_result = pt3ops.corresponding_points_alignment(
                        C_pred[None], C_gt[None], estimate_scale=estimate_scale)
                    # align centers and rotations
                    C_pred_align = (
                        align_result.s *
                        C_pred @ align_result.R[0] +
                        align_result.T[0])
                    R_pred_align = torch.bmm(
                        align_result.R.permute(
                            0, 2, 1).expand_as(R_pred), R_pred)
                else:
                    C_pred_align = C_pred.clone()
                    R_pred_align = R_pred.clone()

                # compute the rotation errors and camera center errors
                cam_center_error = (C_pred_align - C_gt).norm(dim=1).mean()
                cam_angle_error = so3.so3_relative_angle(
                    R_pred_align, R_gt).median() * 180 / np.pi

                # store the errors
                postfix = ''
                if not align_cams:
                    postfix += '_noalign'
                if interpolate:
                    postfix += '_interp'
                if estimate_scale:
                    postfix += '_scale'
                result['cam_center_err' + postfix] = float(cam_center_error)
                result['cam_angle_err' + postfix] = float(cam_angle_error)

                if estimate_scale and not interpolate and align_cams:
                    result['best_scale'] = float(align_result.s)

    return result
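
The alignment step in `eval_pose` hinges on `corresponding_points_alignment`, which returns a `SimilarityTransform` namedtuple `(R, T, s)` such that `s * X @ R + T` approximates the targets. A tiny sanity check of that convention (toy data ours):

 import torch
 from pytorch3d.ops import corresponding_points_alignment

 X = torch.randn(1, 50, 3)  # batched source points
 Y = X * 2.0 + 1.0          # scaled and shifted copy of the sources
 R, T, s = corresponding_points_alignment(X, Y, estimate_scale=True)
 print(float(s), T[0])      # ~2.0 and approximately [1., 1., 1.]
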