def test_quaternion_multiplication(self):
    """Composing quaternions agrees with multiplying their rotation matrices.

    The operands are reshaped to (3, 5, 4) and (7, 3, 1, 4), so the product
    is expected to broadcast to (7, 3, 5, 4).
    """
    lhs = random_quaternions(15, torch.float64).reshape((3, 5, 4))
    rhs = random_quaternions(21, torch.float64).reshape((7, 3, 1, 4))

    product = quaternion_multiply(lhs, rhs)
    self.assertEqual(product.shape, (7, 3, 5, 4))

    # Form the same product through rotation matrices, then convert back.
    product_via_matrices = matrix_to_quaternion(
        torch.matmul(quaternion_to_matrix(lhs), quaternion_to_matrix(rhs))
    )
    self._assert_quaternions_close(product, product_via_matrices)
    def test_quaternion_application(self):
        """Rotating points via quaternion_apply matches the matrix product.

        Also checks that the gradients of the rotated points with respect to
        both the input points and the quaternions are finite.
        """
        quats = random_quaternions(3, torch.float64, requires_grad=True)
        rotations = quaternion_to_matrix(quats)
        pts = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)

        by_quaternion = quaternion_apply(quats, pts)
        # Treat each point as a column vector for the matrix product.
        by_matrix = torch.matmul(rotations, pts.unsqueeze(-1)).squeeze(-1)
        self.assertTrue(torch.allclose(by_quaternion, by_matrix))

        grad_points, grad_quats = torch.autograd.grad(
            by_quaternion.sum(), [pts, quats]
        )
        for grad in (grad_points, grad_quats):
            self.assertTrue(torch.isfinite(grad).all())
 def test_from_quat(self):
     """Round trip: quaternion -> rotation matrix -> quaternion.

     The recovered quaternions are compared with the same helper that
     test_quaternion_multiplication uses for matrix_to_quaternion output.
     q and -q describe the same rotation, so an element-wise
     torch.allclose would also pin a sign convention this test does not
     intend to check.
     """
     data = random_quaternions(13, dtype=torch.float64)
     mdata = matrix_to_quaternion(quaternion_to_matrix(data))
     # NOTE(review): assumes _assert_quaternions_close treats q and -q as
     # equal (it is defined outside this view) -- confirm.
     self._assert_quaternions_close(data, mdata)