def test_sample(self):
    """Batched sampling has the expected shape and matches unbatched angle means.

    A distribution built from two concatenated coefficient sets is sampled and
    its per-element mean spherical angles are compared against those of two
    separately constructed, unbatched distributions.
    """
    # Fixed seed so the sample means (and thus the tolerance check) are reproducible.
    torch.manual_seed(1)
    num_samples = (2048, )

    batched = SO3Distribution(
        a_lms=concat_so3vecs([self.a_lms_1, self.a_lms_2]),
        sphs=self.sphs,
    )
    batched_draws = batched.sample(num_samples)
    self.assertEqual(
        batched_draws.shape,
        num_samples + batched.batch_shape + batched.event_shape,
    )

    batched_angles = cartesian_to_spherical(to_numpy(batched_draws))  # [S, B, 2]
    batched_means = np.mean(batched_angles, axis=0)  # [B, 2]
    self.assertEqual(batched_means.shape, (2, 2))

    # Reference distributions are sampled one after the other, after the
    # batched draw, so the RNG stream is consumed in the same order.
    reference_means = []
    for coeffs in (self.a_lms_1, self.a_lms_2):
        single = SO3Distribution(a_lms=coeffs, sphs=self.sphs)
        single_draws = single.sample(num_samples)
        single_angles = cartesian_to_spherical(to_numpy(single_draws))  # [S, 1, 2]
        reference_means.append(np.mean(single_angles, axis=0))  # [1, 2]

    # Batching must not change the distribution of any individual element.
    for idx, expected in enumerate(reference_means):
        self.assertTrue(np.allclose(batched_means[idx], expected, atol=0.1))
def test_max_sample(self):
    """``argmax`` over a batch of two distributions returns an array of shape ``(2, 3)``."""
    distr = SO3Distribution(
        a_lms=concat_so3vecs([self.a_lms_1, self.a_lms_2]),
        sphs=self.sphs,
        dtype=torch.float,
    )
    # One row per concatenated coefficient set.
    maxima = distr.argmax(count=17)
    self.assertEqual(maxima.shape, (2, 3))
def test_concat(self):
    """Concatenating three copies of one coefficient set yields a leading size of 3."""
    # A single reference direction, given as (theta, phi) = (pi/2, pi/2).
    reference_angles = np.array([[np.pi / 2, np.pi / 2]])
    reference_xyz = torch.tensor(
        spherical_to_cartesian(reference_angles), dtype=torch.float)

    coeffs = estimate_alms(self.sphs_conj.forward(reference_xyz))
    stacked = concat_so3vecs([coeffs] * 3)

    # Every component of the concatenated result must carry all three copies.
    self.assertTrue(all(part.shape[0] == 3 for part in stacked))
def test_normalization(self):
    """The density of each batch element integrates (numerically) to one over the sphere."""
    distr = SO3Distribution(
        a_lms=concat_so3vecs([self.a_lms_1, self.a_lms_2]),
        sphs=self.sphs,
        dtype=torch.float,
    )

    # Quadrature points from generate_fibonacci_grid, assumed to be roughly
    # uniform on the unit sphere. unsqueeze(1) inserts a singleton axis,
    # presumably so the points broadcast across the batch.
    points = torch.tensor(generate_fibonacci_grid(n=1024), dtype=torch.float).unsqueeze(1)
    density = distr.prob(points)

    # Sphere area (4*pi) times the mean density approximates the integral.
    total_mass = 4 * np.pi * torch.mean(density, dim=0)
    self.assertTrue(np.allclose(to_numpy(total_mass), 1.0))
def test_prob(self):
    """``log_prob`` yields one value per batch element and rejects mismatched batches."""
    distr = ExpSO3Distribution(
        a_lms=concat_so3vecs([self.a_lms_1, self.a_lms_2, self.a_lms_1]),
        sphs=self.sphs,
        beta=100,
    )
    points = torch.tensor([
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, 1.0, 0.0],
    ])
    expected_shape = (3, )

    # One point per batch element.
    self.assertEqual(distr.log_prob(points).shape, expected_shape)
    # A single point is evaluated against every batch element.
    self.assertEqual(distr.log_prob(points[[0]]).shape, expected_shape)
    # Two points cannot be paired with a batch of three.
    with self.assertRaises(RuntimeError):
        distr.log_prob(points[:2])
def test_max_prob(self):
    """``SO3Distribution.get_max_prob`` on a batch of three returns shape ``(3,)``."""
    # NOTE: this method needs a name that is unique within its test class. A
    # second ``def test_max`` in the same class body rebinds the name, so this
    # check would be silently dropped and never run.
    a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2, self.a_lms_1])
    so3_distr = SO3Distribution(a_lms=a_lms, sphs=self.sphs)
    self.assertEqual(so3_distr.get_max_prob().shape, (3, ))
def test_max_log_prob(self):
    """``ExpSO3Distribution.get_max_log_prob`` on a batch of three returns shape ``(3,)``."""
    # NOTE: this method needs a name that is unique within its test class. A
    # duplicate ``def test_max`` in the same class body would shadow (or be
    # shadowed by) another test, so one of the two checks would never run.
    a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2, self.a_lms_1])
    distr = ExpSO3Distribution(a_lms=a_lms, sphs=self.sphs, beta=100)
    self.assertEqual(distr.get_max_log_prob().shape, (3, ))