Esempio n. 1
0
    def test_covariance(self, tau, num_channels, maxl, sample_batch):
        """Check that CormorantAtomLevel is covariant under rotations.

        Applying the atom level and then rotating its output with the Wigner-D
        matrix ``D`` must give the same result (to within 1e-5 for each of the
        ``maxl`` parts) as applying the atom level to Wigner-rotated atom
        representations and to spherical harmonics of rotated positions.
        """
        data, __, __ = sample_batch
        device, dtype = data['positions'].device, data['positions'].dtype
        sph_harms = SphericalHarmonicsRel(maxl - 1,
                                          conj=True,
                                          device=device,
                                          dtype=dtype,
                                          cg_dict=None)
        D, R, __ = rot.gen_rot(maxl, device=device, dtype=dtype)

        # Build atom layer with ``tau`` channels for each of the ``maxl`` parts.
        tlist = [tau] * maxl
        atom_lvl = CormorantAtomLevel(tlist,
                                      tlist,
                                      maxl,
                                      num_channels,
                                      1,
                                      'rand',
                                      device=device,
                                      dtype=dtype,
                                      cg_dict=None)

        # Setup input; only atom reps, mask and positions are needed here.
        atom_rep, atom_mask, __, __, atom_positions = prep_input(
            data, tau, maxl)
        atom_positions_rot = rot.rotate_cart_vec(R, atom_positions)

        def edge_reps_from(positions):
            # Spherical harmonics of the pairwise geometry, each part
            # repeated ``tau`` times along dim -3, wrapped as an SO3Vec.
            harmonics, _ = sph_harms(positions, positions)
            return SO3Vec([torch.cat([sph_l] * tau, axis=-3)
                           for sph_l in harmonics])

        # Path 1: run the layer on the unrotated input, then rotate the output.
        output = atom_lvl(atom_rep, edge_reps_from(atom_positions), atom_mask)
        output = output.apply_wigner(D)

        # Path 2: rotate the input (atom reps and positions), then run the layer.
        atom_rep_rot = atom_rep.apply_wigner(D)
        output_from_rot = atom_lvl(atom_rep_rot,
                                   edge_reps_from(atom_positions_rot),
                                   atom_mask)

        for i in range(maxl):
            max_err = torch.max(torch.abs(output_from_rot[i] - output[i]))
            assert max_err < 1E-5, (i, max_err)
Esempio n. 2
0
    def verify_alms(self, atoms):
        """Check that the agent's SO(3) coefficients rotate covariantly.

        The agent's last distribution is computed for ``atoms`` and for a
        rotated copy of ``atoms`` (seeds reset to 0 before each step). The
        unrotated coefficients, rotated with the matching Wigner-D matrix,
        must match the coefficients from the rotated structure to within 1e-5
        in every part.

        ``atoms`` itself is not modified; the rotation is applied to a copy.
        """
        observation = self.observation_space.build(atoms, formula=self.formula)
        util.set_seeds(0)
        action = self.agent.step([observation])
        so3_dist = action['dists'][-1]

        # Rotate a copy so the caller's structure is left untouched
        # (consistent with verify_probs / verify_invariance).
        wigner_d, rot_mat, _ = rotations.gen_rot(self.agent.max_sh,
                                                 dtype=self.agent.dtype)
        atoms_rotated = atoms.copy()
        atoms_rotated.positions = np.einsum('ij,...j->...i', rot_mat,
                                            atoms.positions)

        observation = self.observation_space.build(atoms_rotated,
                                                   formula=self.formula)
        util.set_seeds(0)
        action = self.agent.step([observation])
        so3_dist_rot = action['dists'][-1]

        rotated_b_lms = so3_dist.coefficients.apply_wigner(wigner_d)
        for part1, part2 in zip(so3_dist_rot.coefficients, rotated_b_lms):
            max_delta = torch.max(torch.abs(part1 - part2))
            self.assertTrue(max_delta < 1e-5)
Esempio n. 3
0
    def verify_probs(self, atoms):
        """Assert that log-probability extrema over the sphere are rotation invariant.

        A Fibonacci grid of 100,000 points is scored under the agent's last
        distribution for ``atoms`` and for a rotated copy of ``atoms``. The
        per-batch maximum and minimum log-probabilities over the grid must
        agree within ``atol=5e-3``. ``atoms`` itself is not modified.
        """
        points = torch.tensor(generate_fibonacci_grid(n=100_000),
                              dtype=torch.float,
                              device=self.device).unsqueeze(-2)

        def last_dist(structure):
            # Build the observation, reset seeds, and take the final distribution.
            obs = self.observation_space.build(structure, formula=self.formula)
            util.set_seeds(0)
            return self.agent.step([obs])['dists'][-1]

        dist = last_dist(atoms)

        # Rotate a copy of the atoms with a matrix from rotations.gen_rot.
        _, rot_mat, _ = rotations.gen_rot(self.agent.max_sh,
                                          dtype=self.agent.dtype)
        rotated = atoms.copy()
        rotated.positions = np.einsum('ij,...j->...i', rot_mat,
                                      atoms.positions)

        dist_rot = last_dist(rotated)

        # Both log-prob tensors are (samples, batches); reduce over samples.
        log_p = dist.log_prob(points)
        log_p_rot = dist_rot.log_prob(points)

        for reduce in (torch.max, torch.min):
            extreme = reduce(log_p, dim=0).values
            extreme_rot = reduce(log_p_rot, dim=0).values
            self.assertTrue(torch.allclose(extreme, extreme_rot, atol=5e-3))
Esempio n. 4
0
    def verify_invariance(self, atoms):
        """Assert that AtomicScalars of the agent's coefficients are rotation invariant.

        The scalars extracted from the agent's last distribution for ``atoms``
        must match (``atol=1e-05``) those obtained for a rotated copy of
        ``atoms``. Seeds are reset to 0 before each agent step, and ``atoms``
        itself is not modified.
        """
        to_scalars = AtomicScalars(maxl=self.agent.max_sh)

        def invariants_of(structure):
            # Observation -> agent step -> last distribution -> scalars.
            obs = self.observation_space.build(structure, formula=self.formula)
            util.set_seeds(0)
            dist = self.agent.step([obs])['dists'][-1]
            return to_scalars(dist.coefficients)

        reference = invariants_of(atoms)

        # Rotate a copy of the atoms with a matrix from rotations.gen_rot.
        _, rot_mat, _ = rotations.gen_rot(self.agent.max_sh,
                                          dtype=self.agent.dtype)
        rotated = atoms.copy()
        rotated.positions = np.einsum('ij,...j->...i', rot_mat,
                                      atoms.positions)

        rotated_scalars = invariants_of(rotated)
        self.assertTrue(torch.allclose(reference, rotated_scalars, atol=1e-05))
Esempio n. 5
0
    def test_CGProduct(self, batch, maxl1, maxl2, maxl, channels):
        """Check that CGProduct commutes with rotations.

        Rotating both inputs with a Wigner-D matrix (applied from the left)
        before the Clebsch-Gordan product must give the same result as
        rotating the product of the unrotated inputs, to within 1e-6 per part.
        """
        maxl_all = max(maxl1, maxl2, maxl)
        # Request double precision explicitly so D matches the double-precision
        # vectors and CG coefficients used below (as the other tests do).
        D, _, _ = rot.gen_rot(maxl_all, dtype=torch.double)

        cg_dict = CGDict(maxl=maxl_all, dtype=torch.double)
        cg_prod = CGProduct(maxl=maxl, dtype=torch.double, cg_dict=cg_dict)

        tau1 = SO3Tau([channels] * (maxl1 + 1))
        tau2 = SO3Tau([channels] * (maxl2 + 1))

        vec1 = SO3Vec.randn(tau1, batch, dtype=torch.double)
        vec2 = SO3Vec.randn(tau2, batch, dtype=torch.double)

        # Rotated inputs.
        vec1i = vec1.apply_wigner(D, dir='left')
        vec2i = vec2.apply_wigner(D, dir='left')

        vec_prod = cg_prod(vec1, vec2)
        veci_prod = cg_prod(vec1i, vec2i)

        # Rotate the product of the unrotated inputs.
        vecf_prod = vec_prod.apply_wigner(D, dir='left')

        diff = [(p1 - p2).abs().max() for p1, p2 in zip(veci_prod, vecf_prod)]
        assert all(d < 1e-6 for d in diff), diff
Esempio n. 6
0
    def test_spherical_harmonics(self, maxl, channels, conj):
        """Check that spherical harmonics transform covariantly under rotation.

        Harmonics of rotated positions must match the Wigner-D rotation of the
        harmonics of the original positions, to within 1e-6 per part. For
        ``conj=True`` D is applied from the left. For ``conj=False`` the
        Cartesian rotation matrix is transposed and D is applied from the right.
        """
        D, R, _ = rot.gen_rot(maxl, dtype=torch.double)
        D = SO3WignerD(D)

        if not conj:
            R = R.t()

        pos = torch.randn((channels, 3), dtype=torch.double)
        posr = rot.rotate_cart_vec(R, pos)

        cg_dict = CGDict(maxl, dtype=torch.double)

        sph_harms = spherical_harmonics(cg_dict, pos, maxl, conj=conj)
        sph_harmsr = spherical_harmonics(cg_dict, posr, maxl, conj=conj)

        # Named ``direction`` to avoid shadowing the builtin ``dir``.
        direction = 'left' if conj else 'right'
        sph_harmsd = so3_torch.apply_wigner(D, sph_harms, dir=direction)

        diff = [(p1 - p2).abs().max()
                for p1, p2 in zip(sph_harmsr, sph_harmsd)]
        # Report the per-part deviations on failure instead of printing them.
        assert all(d < 1e-6 for d in diff), diff
Esempio n. 7
0
def covariance_test(model, data, tol=1e-5):
    """Log how far ``model`` is from rotation invariance and covariance on ``data``.

    The model is run on ``data`` and on a copy whose ``'positions'`` entry is
    rotated by the Cartesian rotation ``R`` from ``rot.gen_rot``. Two
    quantities are logged:

    * the norm of the difference between the two model outputs (should
      vanish for a rotation-invariant output);
    * the largest element-wise deviation between the representations from the
      rotated input and the original-input representations rotated with the
      matching Wigner-D matrix (should vanish for covariant representations).

    If that deviation exceeds ``tol``, a per-level, per-``ell`` breakdown
    (norm, mean and max of the absolute deviation) is logged as warnings.

    Args:
        model: Called as ``model(data, covariance_test=True)`` and expected to
            return ``(outputs, reps, _)``; must expose ``maxl``, ``device`` and
            ``dtype`` attributes.
        data: Dict of model inputs containing a ``'positions'`` tensor. It is
            not modified.
        tol: Deviation above which the detailed breakdown is logged
            (default ``1e-5``).
    """
    logging.info('Beginning covariance test!')

    device, dtype = data['positions'].device, data['positions'].dtype

    D, R, _ = rot.gen_rot(model.maxl, device=device, dtype=dtype)

    D = SO3WignerD(D).to(model.device, model.dtype)

    data_rotout = data

    # Clone tensors so rotating positions cannot touch the caller's data.
    # Non-tensor entries are passed through unchanged so both forward passes
    # see identical non-positional inputs.
    data_rotin = {
        key: val.clone() if torch.is_tensor(val) else val
        for key, val in data.items()
    }
    data_rotin['positions'] = rot.rotate_cart_vec(R, data_rotin['positions'])

    outputs_rotout, reps_rotout, _ = model(data_rotout, covariance_test=True)
    outputs_rotin, reps_rotin, _ = model(data_rotin, covariance_test=True)

    invariance_test = (outputs_rotout - outputs_rotin).norm().item()

    # Rotate the original-input representations so they can be compared with
    # the representations computed from the rotated input.
    reps_rotout = [reps.apply_wigner(D) for reps in reps_rotout]

    rep_diff = [(level_in - level_out).abs()
                for level_in, level_out in zip(reps_rotin, reps_rotout)]

    # One 1-D tensor per level, with one entry per part (ell) of that level.
    covariance_test_norm = [
        torch.tensor([diff.norm() for diff in diffs_lvl])
        for diffs_lvl in rep_diff
    ]
    covariance_test_mean = [
        torch.tensor([diff.mean() for diff in diffs_lvl])
        for diffs_lvl in rep_diff
    ]
    covariance_test_max = [
        torch.tensor([diff.max() for diff in diffs_lvl])
        for diffs_lvl in rep_diff
    ]

    covariance_test_max_all = max(
        lvl.max() for lvl in covariance_test_max).item()

    logging.info('Rotation Invariance test: {:0.5g}'.format(invariance_test))
    logging.info('Largest deviation in covariance test : {:0.5g}'.format(
        covariance_test_max_all))

    # If the largest deviation exceeds ``tol``, log the norm, mean and max of
    # the deviation for each irrep (ell) of each level.
    if covariance_test_max_all > tol:
        logging.warning(
            'Largest deviation in covariance test {:0.5g} detected! '
            'Detailed summary:'.format(covariance_test_max_all))
        per_level = zip(covariance_test_norm, covariance_test_mean,
                        covariance_test_max)
        for lvl_idx, (lvl_norm, lvl_mean, lvl_max) in enumerate(per_level):
            per_ell = zip(lvl_norm, lvl_mean, lvl_max)
            for ell_idx, (ell_norm, ell_mean, ell_max) in enumerate(per_ell):
                logging.warning(
                    '(lvl, ell) = ({}, {}) -> '
                    '{:0.5g} (norm) {:0.5g} (mean) {:0.5g} (max)'.format(
                        lvl_idx, ell_idx, float(ell_norm), float(ell_mean),
                        float(ell_max)))
Esempio n. 8
0
    def test_covariance(self, tau, num_channels, maxl, basis, edge_net_type,
                        sample_batch):
        """Check that CormorantEdgeLevel output is invariant under rotation.

        The edge level is evaluated twice:

        * with the original atom representations, pairwise norms and radial
          functions;
        * with atom representations rotated by the Wigner-D matrix ``D``,
          plus norms and radial functions recomputed from the correspondingly
          rotated positions.

        Each of the ``maxl`` output parts must agree to within 1e-5.

        ``edge_net_type`` selects the incoming edge representation: ``None``
        (no edge input) or ``'rand'`` (random ``SO3Scalar`` edges shared by
        both passes). Any other value raises ``ValueError``.
        """
        data, __, __ = sample_batch
        device, dtype = data['positions'].device, data['positions'].dtype
        sph_harms = SphericalHarmonicsRel(maxl - 1,
                                          conj=True,
                                          device=device,
                                          dtype=dtype,
                                          cg_dict=None)
        batch_size, natoms = data['positions'].shape[:2]
        D, R, __ = rot.gen_rot(maxl, device=device, dtype=dtype)

        # Setup input, plus rotated positions and atom reps.
        atom_reps, atom_mask, edge_scalars, edge_mask, atom_positions = prep_input(
            data, tau, maxl)
        atom_positions_rot = rot.rotate_cart_vec(R, atom_positions)
        atom_reps_rot = atom_reps.apply_wigner(D)

        # Pairwise norms for the original and the rotated geometry.
        __, norms = sph_harms(atom_positions, atom_positions)
        __, norms_rot = sph_harms(atom_positions_rot, atom_positions_rot)

        # Radial functions for both geometries, so the rotated pass really
        # sees rotated input rather than reusing the unrotated quantities.
        rad_funcs = RadialFilters([maxl - 1], [basis, basis], [num_channels],
                                  1,
                                  device=device,
                                  dtype=dtype)
        rad_func_levels = rad_funcs(norms, edge_mask * (norms > 0))
        rad_func_levels_rot = rad_funcs(norms_rot,
                                        edge_mask * (norms_rot > 0))
        tau_pos = rad_funcs.tau[0]

        # Build the initial edge network
        if edge_net_type is None:
            edge_reps = None
        elif edge_net_type == 'rand':
            # Match the test's device/dtype so the edge level sees
            # consistently-typed inputs.
            reps = [
                torch.randn((batch_size, natoms, natoms, tau, 2),
                            device=device,
                            dtype=dtype)
                for _ in range(maxl)
            ]
            edge_reps = SO3Scalar(reps)
        else:
            raise ValueError(
                'Unknown edge_net_type: {!r}'.format(edge_net_type))

        # Build Edge layer
        tlist = [tau] * maxl
        tau_atoms = tlist
        tau_edge = tlist
        if edge_net_type is None:
            tau_edge = []

        edge_lvl = CormorantEdgeLevel(tau_atoms,
                                      tau_edge,
                                      tau_pos,
                                      num_channels,
                                      maxl,
                                      cutoff_type='soft',
                                      device=device,
                                      dtype=dtype,
                                      hard_cut_rad=1.73,
                                      soft_cut_rad=1.73,
                                      soft_cut_width=0.2)

        output_edge_reps = edge_lvl(edge_reps, atom_reps, rad_func_levels[0],
                                    edge_mask, norms)
        output_edge_reps_rot = edge_lvl(edge_reps, atom_reps_rot,
                                        rad_func_levels_rot[0], edge_mask,
                                        norms_rot)

        for i in range(maxl):
            max_err = torch.max(
                torch.abs(output_edge_reps[i] - output_edge_reps_rot[i]))
            assert max_err < 1E-5, (i, max_err)