Example #1
0
    def test_sample(self):
        """Sampling a batched distribution agrees with sampling each member.

        Draws 2048 samples from a two-member batched ``SO3Distribution`` and
        2048 samples from each single-member distribution, then checks that
        the per-member mean spherical angles agree within ``atol=0.1``.
        """
        torch.manual_seed(1)
        n_draws = (2048, )

        batched = SO3Distribution(
            a_lms=concat_so3vecs([self.a_lms_1, self.a_lms_2]),
            sphs=self.sphs,
        )
        draws = batched.sample(n_draws)
        expected_shape = n_draws + batched.batch_shape + batched.event_shape
        self.assertEqual(draws.shape, expected_shape)

        # Angles are [S, B, 2]; averaging over the sample axis leaves [B, 2].
        batched_angles = cartesian_to_spherical(to_numpy(draws))
        batched_means = np.mean(batched_angles, axis=0)
        self.assertEqual(batched_means.shape, (2, 2))

        # Build and sample each member on its own, in the same order as the
        # batch, so the RNG stream is consumed identically.
        single_means = []
        for member_a_lms in (self.a_lms_1, self.a_lms_2):
            member = SO3Distribution(a_lms=member_a_lms, sphs=self.sphs)
            member_draws = member.sample(n_draws)
            member_angles = cartesian_to_spherical(to_numpy(member_draws))
            single_means.append(np.mean(member_angles, axis=0))  # [1, 2]

        # Batching must not change the result
        for batch_mean, member_mean in zip(batched_means, single_means):
            self.assertTrue(np.allclose(batch_mean, member_mean, atol=0.1))
Example #2
0
 def test_max_sample(self):
     """``argmax`` on a two-member float32 batch returns a (2, 3) tensor."""
     batched_distr = SO3Distribution(
         a_lms=concat_so3vecs([self.a_lms_1, self.a_lms_2]),
         sphs=self.sphs,
         dtype=torch.float,
     )
     # NOTE(review): count=17 is forwarded to argmax unchanged; its meaning
     # (presumably number of candidate points searched) is defined elsewhere.
     best = batched_distr.argmax(count=17)
     self.assertEqual(best.shape, (2, 3))
Example #3
0
    def test_concat(self):
        """Concatenating three copies of an a_lm vector gives a leading dim of 3."""
        # One reference point with (theta, phi) = (pi/2, pi/2).
        reference_angles = np.array([[np.pi / 2, np.pi / 2]])
        reference_xyz = torch.tensor(spherical_to_cartesian(reference_angles),
                                     dtype=torch.float)
        y_lms_conj = self.sphs_conj.forward(reference_xyz)

        single = estimate_alms(y_lms_conj)
        stacked = concat_so3vecs([single, single, single])

        leading_dims_ok = all(a_lm.shape[0] == 3 for a_lm in stacked)
        self.assertTrue(leading_dims_ok)
Example #4
0
 def test_normalization(self):
     """Each batch member's density integrates to one over the sphere.

     The integral is approximated as 4*pi times the mean density evaluated on
     a 1024-point Fibonacci grid (assumed to be near-uniform on the sphere).
     """
     distr = SO3Distribution(
         a_lms=concat_so3vecs([self.a_lms_1, self.a_lms_2]),
         sphs=self.sphs,
         dtype=torch.float,
     )
     # unsqueeze(1) inserts a singleton axis, presumably so every grid point
     # broadcasts against the batch dimension of the distribution.
     points = torch.tensor(generate_fibonacci_grid(n=1024),
                           dtype=torch.float).unsqueeze(1)
     densities = distr.prob(points)
     sphere_area = 4 * np.pi
     integral = sphere_area * torch.mean(densities, dim=0)
     self.assertTrue(np.allclose(to_numpy(integral), 1.0))
Example #5
0
    def test_prob(self):
        """``log_prob`` yields one value per batch member and rejects bad shapes."""
        distr = ExpSO3Distribution(
            a_lms=concat_so3vecs([self.a_lms_1, self.a_lms_2, self.a_lms_1]),
            sphs=self.sphs,
            beta=100,
        )
        points = torch.tensor([
            [1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
        ])
        batch_shape = (3, )

        # One point per batch member.
        self.assertEqual(distr.log_prob(points).shape, batch_shape)
        # A single point is broadcast across all three batch members.
        self.assertEqual(distr.log_prob(points[[0]]).shape, batch_shape)

        # Two points can be neither matched nor broadcast against three members.
        with self.assertRaises(RuntimeError):
            distr.log_prob(points[:2])
Example #6
0
 def test_max(self):
     """``get_max_prob`` returns one value per member of a 3-member batch."""
     three_members = concat_so3vecs([self.a_lms_1, self.a_lms_2, self.a_lms_1])
     distr = SO3Distribution(a_lms=three_members, sphs=self.sphs)
     max_prob = distr.get_max_prob()
     self.assertEqual(max_prob.shape, (3, ))
Example #7
0
 def test_max(self):
     """``get_max_log_prob`` returns one value per member of a 3-member batch."""
     three_members = concat_so3vecs([self.a_lms_1, self.a_lms_2, self.a_lms_1])
     distr = ExpSO3Distribution(a_lms=three_members, sphs=self.sphs, beta=100)
     max_log_prob = distr.get_max_log_prob()
     self.assertEqual(max_log_prob.shape, (3, ))