Пример #1
0
 def test_normalization(self):
     """Check that the combined SO(3) distribution integrates to one.

     The integral over the sphere is approximated as the sphere's area (4*pi)
     times the mean probability over a 1024-point Fibonacci grid, i.e. an
     equal-weight quadrature rule.
     """
     a_lms = concat_so3vecs([self.a_lms_1, self.a_lms_2])
     so3_distr = SO3Distribution(a_lms=a_lms,
                                 sphs=self.sphs,
                                 dtype=torch.float)
     grid = generate_fibonacci_grid(n=1024)
     # Extra axis at position 1 — presumably so each grid point broadcasts
     # against the distribution's batch dimension (TODO confirm).
     grid_t = torch.tensor(grid, dtype=torch.float).unsqueeze(1)
     probs = so3_distr.prob(grid_t)
     # Equal-weight quadrature: area of the sphere times the mean density.
     integral = 4 * np.pi * torch.mean(probs, dim=0)
     # Same tolerances as np.allclose's defaults, but a failure reports the
     # offending integral values instead of just "False is not true".
     np.testing.assert_allclose(to_numpy(integral), 1.0, rtol=1e-5, atol=1e-8)
Пример #2
0
    def verify_probs(self, atoms):
        """Assert that the agent's last SO(3) distribution respects rotations.

        The agent is stepped on ``atoms`` and on a copy whose positions are
        rotated by the matrix returned from ``rotations.gen_rot``. The global
        seed is reset to 0 before each step. The last entry of
        ``action['dists']`` from each run is evaluated on the same dense
        Fibonacci grid. Because the grid itself is not rotated, pointwise
        log-probabilities differ between the two runs. Over a dense grid,
        however, the per-batch maximum and minimum log-probabilities should
        agree, so those extrema are compared with ``atol=5e-3``.

        Args:
            atoms: Structure to test. Its ``copy()`` and ``positions`` are
                used to build the rotated variant.
        """
        grid_points = torch.tensor(generate_fibonacci_grid(n=100_000),
                                   dtype=torch.float,
                                   device=self.device)
        # Extra axis — presumably lets the grid broadcast against the batch
        # dimension, so log_prob yields (samples, batches) (TODO confirm).
        grid_points = grid_points.unsqueeze(-2)

        def last_dist(structure):
            # Build the observation, reset the seed so both runs see identical
            # randomness, step the agent and return its last distribution.
            observation = self.observation_space.build(structure,
                                                       formula=self.formula)
            util.set_seeds(0)
            action = self.agent.step([observation])
            return action['dists'][-1]

        so3_dist = last_dist(atoms)

        # Rotate atoms. Only the rotation matrix is used. gen_rot stays after
        # the first step because it may draw from the global RNG.
        _, rot_mat, _ = rotations.gen_rot(self.agent.max_sh,
                                          dtype=self.agent.dtype)
        atoms_rotated = atoms.copy()
        atoms_rotated.positions = np.einsum('ij,...j->...i', rot_mat,
                                            atoms.positions)
        so3_dist_rot = last_dist(atoms_rotated)

        log_probs = so3_dist.log_prob(grid_points)  # (samples, batches)
        log_probs_rot = so3_dist_rot.log_prob(grid_points)  # (samples, batches)

        # Extrema over grid points (dim 0); the argmax/argmin are not needed.
        maximum, _ = torch.max(log_probs, dim=0)
        minimum, _ = torch.min(log_probs, dim=0)
        maximum_rot, _ = torch.max(log_probs_rot, dim=0)
        minimum_rot, _ = torch.min(log_probs_rot, dim=0)

        self.assertTrue(torch.allclose(maximum, maximum_rot, atol=5e-3),
                        msg=f'max log-prob changed under rotation: '
                        f'{maximum} vs {maximum_rot}')
        self.assertTrue(torch.allclose(minimum, minimum_rot, atol=5e-3),
                        msg=f'min log-prob changed under rotation: '
                        f'{minimum} vs {minimum_rot}')
Пример #3
0
 def test_empty(self):
     """Requesting zero points yields a grid of shape (0, 3)."""
     expected_shape = (0, 3)
     self.assertEqual(generate_fibonacci_grid(n=0).shape, expected_shape)