def test_se3_log_to_exp_to_log(self, batch_size: int = 100):
    """
    Round trip in the log -> exp -> log direction.

    A random batch of SE(3) logarithms is exponentiated with `se3_exp_map`
    and mapped back with `se3_log_map`. The recovered logarithms must agree
    with the originals to within an absolute tolerance of 1e-1.
    """
    eps = 1e-8
    original_logs = TestSE3.init_log_transform(batch_size=batch_size)
    exponentiated = se3_exp_map(original_logs, eps=eps)
    recovered_logs = se3_log_map(exponentiated, eps=eps)
    self.assertClose(original_logs, recovered_logs, atol=1e-1)
def test_bad_se3_input_value_err(self):
    """
    Tests whether `se3_exp_map` and `se3_log_map` correctly raise a
    ValueError when called with an argument of incorrect shape, or with a
    tensor containing illegal values.

    NOTE(review): the device is hard-coded to `cuda:0`, so this test needs
    a CUDA device.
    """
    device = torch.device("cuda:0")

    # Shapes that se3_exp_map must reject.
    for size in ([5, 4], [3, 4, 5], [3, 5, 6]):
        # subTest reports which shape failed instead of stopping at the first.
        with self.subTest(fn="se3_exp_map", size=size):
            log_transform = torch.randn(size=size, device=device)
            with self.assertRaises(ValueError):
                se3_exp_map(log_transform)

    # Shapes that se3_log_map must reject (it expects a batch of 4x4 matrices).
    for size in ([5, 4], [3, 4, 5], [3, 5, 6], [2, 2, 3, 4]):
        with self.subTest(fn="se3_log_map", size=size):
            transform = torch.randn(size=size, device=device)
            with self.assertRaises(ValueError):
                se3_log_map(transform)

    # Test the case where transform[:, :3, 3] != 0.
    # The input is correctly shaped (5, 4, 4), but every entry is >= 0.1.
    # That includes the last column transform[:, :3, 3], which a valid
    # transform leaves at zero (translation lives in transform[:, 3, :3]),
    # so this input must be rejected.
    transform = torch.rand(size=[5, 4, 4], device=device) + 0.1
    with self.assertRaises(ValueError):
        se3_log_map(transform)
def test_compare_with_precomputed(self):
    """
    Regression check against stored reference values.

    The log of `self.precomputed_transform` must match
    `self.precomputed_log_transform`. Exponentiating that log must give back
    `self.precomputed_transform`. Both checks use atol=1e-4.
    """
    reference_transform = self.precomputed_transform
    reference_log = self.precomputed_log_transform
    self.assertClose(se3_log_map(reference_transform), reference_log, atol=1e-4)
    self.assertClose(reference_transform, se3_exp_map(reference_log), atol=1e-4)
def test_se3_exp_to_log_to_exp(self, batch_size: int = 10000):
    """
    Round trip in the exp -> log -> exp direction.

    For a batch of random SE(3) matrices `A`, `se3_exp_map(se3_log_map(A))`
    must reproduce `A` within an absolute tolerance of 0.02. Matrices whose
    rotation angle is close to pi are excluded from the batch first.
    """
    transform = TestSE3.init_transform(batch_size=batch_size)
    # The log map has a singularity at rotation angle ~= pi, so only keep
    # matrices whose angle stays clearly below it.
    rotation_angles = so3_rotation_angle(transform[:, :3, :3])
    transform = transform[rotation_angles < 3.134]
    logs = se3_log_map(transform, eps=1e-8, cos_bound=0.0)
    transform_ = se3_exp_map(logs, eps=1e-8)
    self.assertClose(transform, transform_, atol=0.02)
def test_se3_log_zero_translation(self, batch_size: int = 100):
    """
    With the translation row zeroed, `se3_log_map` reduces to the SO(3) case.

    Two properties are checked, both to atol=1e-4:

    * The last three entries of the SE(3) log equal the negated
      `so3_log_map` of the 3x3 rotation block.
    * The first three entries are zero.
    """
    transform = TestSE3.init_transform(batch_size=batch_size)
    transform[:, 3, :3] *= 0.0
    se3_log = se3_log_map(transform, eps=1e-8, cos_bound=1e-4)
    so3_log = so3_log_map(transform[:, :3, :3], eps=1e-8, cos_bound=1e-4)
    # Split the 6-vector log into its leading and trailing three entries.
    leading, trailing = se3_log[:, :3], se3_log[:, 3:]
    self.assertClose(trailing, -so3_log, atol=1e-4)
    self.assertClose(leading, torch.zeros_like(leading), atol=1e-4)
def test_se3_log_singularity(self, batch_size: int = 100): """ Tests whether the `se3_log_map` is robust to the input matrices whose rotation angles and translations are close to the numerically unstable region (i.e. matrices with low rotation angles and 0 translation). """ # generate random rotations with a tiny angle device = torch.device("cuda:0") identity = torch.eye(3, device=device) rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device) r = [identity, rot180] r.extend([ qr(identity + torch.randn_like(identity) * 1e-6)[0] + float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-8 # this adds random noise to the second half # of the random orthogonal matrices to generate # near-orthogonal matrices for i in range(batch_size - 2) ]) r = torch.stack(r) # tiny translations t = torch.randn(batch_size, 3, dtype=r.dtype, device=device) * 1e-6 # create the transform matrix transform = torch.zeros(batch_size, 4, 4, dtype=torch.float32, device=device) transform[:, :3, :3] = r transform[:, 3, :3] = t transform[:, 3, 3] = 1.0 transform.requires_grad = True # the log of the transform log_transform = se3_log_map(transform, eps=1e-4, cos_bound=1e-4) # tests whether all outputs are finite self.assertTrue(torch.isfinite(log_transform).all()) # tests whether all gradients are finite and not None loss = log_transform.sum() loss.backward() self.assertIsNotNone(transform.grad) self.assertTrue(torch.isfinite(transform.grad).all())
def compute_logs():
    """
    Run `se3_log_map` once on `log_transform`, then wait for CUDA work.

    `log_transform` is captured from an enclosing scope that is not visible
    here. The result of `se3_log_map` is discarded. Presumably this closure
    is handed to a benchmarking harness; TODO confirm.

    NOTE(review): `se3_log_map` takes batches of 4x4 transforms, so
    `log_transform` is presumably a batch of transforms despite its name.
    Confirm in the enclosing scope.
    """
    se3_log_map(log_transform)
    # Block until all queued CUDA kernels have finished, so the GPU work
    # launched above is complete when this function returns.
    torch.cuda.synchronize()