Example #1
0
    def test_sample(self):
        """Check SO3Distribution.sample shapes and that batching is harmless.

        Builds one distribution from the concatenation of ``self.a_lms_1`` and
        ``self.a_lms_2`` and two distributions from each coefficient set
        alone. Asserts the sample shape, then asserts that the mean spherical
        angles of each batch member match the corresponding unbatched
        distribution within ``atol=0.1``.
        """
        # Fixed seed so the stochastic comparison below is reproducible.
        torch.manual_seed(1)
        samples_shape = (2048, )

        a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2])
        so3_distr = SO3Distribution(a_lms=a_lms, sphs=self.sphs)
        samples = so3_distr.sample(samples_shape)

        # Sample shape = requested sample shape + batch shape + event shape.
        self.assertEqual(
            samples.shape,
            samples_shape + so3_distr.batch_shape + so3_distr.event_shape)

        angles = cartesian_to_spherical(to_numpy(samples))  # [S, B, 2]
        # NOTE(review): a plain arithmetic mean of an angle that wraps around
        # (e.g. at +/-pi) can be misleading if samples straddle the branch
        # cut -- confirm the output range of cartesian_to_spherical.
        mean_angles = np.mean(angles, axis=0)  # [B, 2]

        self.assertEqual(mean_angles.shape, (2, 2))

        so3_distr_1 = SO3Distribution(a_lms=self.a_lms_1, sphs=self.sphs)
        samples_1 = so3_distr_1.sample(samples_shape)
        angles_1 = cartesian_to_spherical(to_numpy(samples_1))  # [S, 1, 2]
        mean_angles_1 = np.mean(angles_1, axis=0)  # [1, 2]

        so3_distr_2 = SO3Distribution(a_lms=self.a_lms_2, sphs=self.sphs)
        samples_2 = so3_distr_2.sample(samples_shape)
        angles_2 = cartesian_to_spherical(to_numpy(samples_2))  # [S, 1, 2]
        mean_angles_2 = np.mean(angles_2, axis=0)  # [1, 2]

        # Assert that batching does not affect the result. assert_allclose
        # prints the mismatching values on failure (assertTrue(np.allclose())
        # does not). It also checks shapes strictly, so the singleton batch
        # axis of the unbatched means is indexed away to give shape [2].
        # rtol=1e-5 and equal_nan=False match np.allclose's defaults.
        np.testing.assert_allclose(
            mean_angles[0], mean_angles_1[0],
            rtol=1e-5, atol=0.1, equal_nan=False)
        np.testing.assert_allclose(
            mean_angles[1], mean_angles_2[0],
            rtol=1e-5, atol=0.1, equal_nan=False)
Example #2
0
 def test_cycle_2(self):
     """Spherical -> Cartesian -> spherical should recover the input angles."""
     theta_phi = np.array([[0.3, -1.2]])
     xyz = spherical_to_cartesian(theta_phi)
     theta_phi_2 = cartesian_to_spherical(xyz)
     # Same tolerances as np.isclose's defaults. Unlike
     # assertTrue(np.all(np.isclose(...))), assert_allclose reports the
     # mismatching values when the round trip fails.
     np.testing.assert_allclose(theta_phi_2, theta_phi, rtol=1e-5, atol=1e-8)
Example #3
0
 def test_cycle(self):
     """Cartesian -> spherical -> Cartesian should recover the unit vector."""
     xyz = np.array([[0.0, -1.0, 0.0]])
     theta_phi = cartesian_to_spherical(xyz)
     xyz_new = spherical_to_cartesian(theta_phi)
     # atol is required: the expected x and z components are exactly 0, and
     # trig round-off leaves tiny non-zero values that a purely relative
     # tolerance would reject. Values match np.isclose's defaults.
     np.testing.assert_allclose(xyz_new, xyz, rtol=1e-5, atol=1e-8)
Example #4
0
 def test_cartesian_to_spherical(self):
     """The -y unit vector should map to (theta, phi) = (pi/2, -pi/2)."""
     xyz = np.array([[0.0, -1.0, 0.0]])
     theta_phi = cartesian_to_spherical(xyz)
     # Same tolerances as np.isclose's defaults; assert_allclose also
     # reports the actual angles on failure.
     np.testing.assert_allclose(
         theta_phi, np.array([[np.pi / 2, -np.pi / 2]]),
         rtol=1e-5, atol=1e-8)