def test_quaternion_multiplication(self):
    """Quaternion and matrix multiplication are equivalent.

    Two quaternion batches with shapes (3, 5, 4) and (7, 3, 1, 4) are
    multiplied directly with ``quaternion_multiply`` and, independently, by
    multiplying their rotation matrices and converting the product back.
    The two quaternion results must agree.
    """
    left = random_quaternions(15, torch.float64).reshape((3, 5, 4))
    right = random_quaternions(21, torch.float64).reshape((7, 3, 1, 4))

    product = quaternion_multiply(left, right)
    # The two input batch shapes broadcast to this combined shape.
    self.assertEqual(product.shape, (7, 3, 5, 4))

    matrix_product = torch.matmul(
        quaternion_to_matrix(left), quaternion_to_matrix(right)
    )
    self._assert_quaternions_close(product, matrix_to_quaternion(matrix_product))
def test_quaternion_application(self):
    """Applying a quaternion is the same as applying the matrix.

    Rotates three points with ``quaternion_apply`` and with the equivalent
    rotation matrices, checks the results agree, and checks that gradients
    back to both the points and the quaternions are finite.
    """
    quats = random_quaternions(3, torch.float64, requires_grad=True)
    rotations = quaternion_to_matrix(quats)
    pts = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)

    via_quaternion = quaternion_apply(quats, pts)
    # Treat each point as a column vector, multiply, then drop the size-1 axis.
    via_matrix = (rotations @ pts.unsqueeze(-1)).squeeze(-1)
    self.assertTrue(torch.allclose(via_quaternion, via_matrix))

    grad_points, grad_quats = torch.autograd.grad(
        via_quaternion.sum(), [pts, quats]
    )
    for grad in (grad_points, grad_quats):
        self.assertTrue(torch.isfinite(grad).all())
def test_from_quat(self):
    """quat -> mtx -> quat

    Round-trips 13 random float64 quaternions through their rotation
    matrices and checks that the recovered quaternions match the originals.
    """
    data = random_quaternions(13, dtype=torch.float64)
    mdata = matrix_to_quaternion(quaternion_to_matrix(data))
    # q and -q describe the same rotation, so a rotation matrix fixes a
    # quaternion only up to sign. A plain torch.allclose would fail if
    # matrix_to_quaternion ever returned the opposite sign. Compare with the
    # same helper the multiplication test uses instead.
    # NOTE(review): assumes _assert_quaternions_close tolerates a sign flip — confirm.
    self._assert_quaternions_close(data, mdata)