def test_covariance(self, tau, num_channels, maxl, sample_batch):
    """Check that the Cormorant atom level commutes with a random rotation.

    The layer is evaluated twice: once on the original inputs, with the
    output rotated afterwards through ``apply_wigner(D)``, and once on
    inputs that were rotated beforehand.  Both results must agree to
    within 1E-5 for every index in ``range(maxl)``.
    """
    batch, __, __ = sample_batch
    positions = batch['positions']
    device, dtype = positions.device, positions.dtype

    sph_harms = SphericalHarmonicsRel(maxl - 1, conj=True, device=device,
                                      dtype=dtype, cg_dict=None)
    D, R, __ = rot.gen_rot(maxl, device=device, dtype=dtype)

    # The layer under test: the same multiplicity list is passed twice.
    multiplicities = [tau] * maxl
    print(multiplicities)
    atom_lvl = CormorantAtomLevel(multiplicities, multiplicities, maxl,
                                  num_channels, 1, 'rand',
                                  device=device, dtype=dtype, cg_dict=None)

    def edge_reps_for(coords):
        # Spherical harmonics of the coordinates against themselves, each
        # part repeated ``tau`` times along axis -3.  The norms returned
        # alongside are not needed by this test.
        harmonics, _ = sph_harms(coords, coords)
        return SO3Vec([torch.cat([part] * tau, axis=-3) for part in harmonics])

    atom_rep, atom_mask, edge_scalars, edge_mask, atom_positions = prep_input(
        batch, tau, maxl)
    atom_positions_rot = rot.rotate_cart_vec(R, atom_positions)

    # Path 1: run the layer on the unrotated inputs, then rotate its output.
    edge_reps = edge_reps_for(atom_positions)
    print(edge_reps.shapes)
    print(atom_rep.shapes)
    rotated_after = atom_lvl(atom_rep, edge_reps, atom_mask).apply_wigner(D)

    # Path 2: rotate the inputs first, then run the layer.
    atom_rep_rot = atom_rep.apply_wigner(D)
    edge_reps_rot = edge_reps_for(atom_positions_rot)
    rotated_before = atom_lvl(atom_rep_rot, edge_reps_rot, atom_mask)

    for ell in range(maxl):
        deviation = (rotated_before[ell] - rotated_after[ell]).abs().max()
        assert deviation < 1E-5
def verify_alms(self, atoms):
    """Check that the agent's SO(3) coefficients rotate with the input.

    The agent is stepped on ``atoms`` and on a rotated copy of ``atoms``,
    with the seeds reset to 0 before each step.  The coefficients obtained
    from the unrotated structure are rotated with the matching Wigner-D
    matrices and must equal the coefficients obtained from the rotated
    structure to within 1e-5 (max absolute difference per part).

    ``atoms`` itself is left untouched: the rotation is applied to a copy,
    as in ``verify_probs`` and ``verify_invariance``.
    """
    observation = self.observation_space.build(atoms, formula=self.formula)
    util.set_seeds(0)
    action = self.agent.step([observation])
    so3_dist = action['dists'][-1]

    # Rotate a copy so the caller's structure is not modified in place.
    wigner_d, rot_mat, _angles = rotations.gen_rot(self.agent.max_sh,
                                                   dtype=self.agent.dtype)
    atoms_rotated = atoms.copy()
    atoms_rotated.positions = np.einsum('ij,...j->...i', rot_mat,
                                        atoms.positions)

    observation = self.observation_space.build(atoms_rotated,
                                               formula=self.formula)
    util.set_seeds(0)
    action = self.agent.step([observation])
    so3_dist_rot = action['dists'][-1]

    rotated_b_lms = so3_dist.coefficients.apply_wigner(wigner_d)

    for part1, part2 in zip(so3_dist_rot.coefficients, rotated_b_lms):
        max_delta = torch.max(torch.abs(part1 - part2))
        self.assertTrue(max_delta < 1e-5)
def verify_probs(self, atoms):
    """Check that the extremes of the log-probability survive a rotation.

    Log-probabilities are evaluated on a 100_000-point Fibonacci grid for
    the distribution produced from ``atoms`` and for the one produced from
    a rotated copy.  The maximum and the minimum over the grid axis
    (dim 0) must match between the two within ``atol=5e-3``.
    """
    grid_points = torch.tensor(
        generate_fibonacci_grid(n=100_000),
        dtype=torch.float,
        device=self.device,
    ).unsqueeze(-2)

    def last_distribution(structure):
        # Build the observation, reset the seeds, step the agent and
        # return the last entry of the action's 'dists'.
        obs = self.observation_space.build(structure, formula=self.formula)
        util.set_seeds(0)
        return self.agent.step([obs])['dists'][-1]

    so3_dist = last_distribution(atoms)

    # Rotate a copy of the structure; the input atoms stay as they are.
    wigner_d, rot_mat, angles = rotations.gen_rot(self.agent.max_sh,
                                                  dtype=self.agent.dtype)
    atoms_rotated = atoms.copy()
    atoms_rotated.positions = np.einsum('ij,...j->...i', rot_mat,
                                        atoms.positions)
    so3_dist_rot = last_distribution(atoms_rotated)

    # Annotated as (samples, batches) in the original code.
    log_probs = so3_dist.log_prob(grid_points)
    log_probs_rot = so3_dist_rot.log_prob(grid_points)

    # Reduce over the grid axis; only the extreme values are compared.
    highest = log_probs.max(dim=0).values
    lowest = log_probs.min(dim=0).values
    highest_rot = log_probs_rot.max(dim=0).values
    lowest_rot = log_probs_rot.min(dim=0).values

    self.assertTrue(torch.allclose(highest, highest_rot, atol=5e-3))
    self.assertTrue(torch.allclose(lowest, lowest_rot, atol=5e-3))
def verify_invariance(self, atoms):
    """Check that the atomic scalars do not change under a rotation.

    ``AtomicScalars`` is applied to the coefficients of the distribution
    produced from ``atoms`` and from a rotated copy of ``atoms``; the two
    results must agree within ``atol=1e-05``.
    """
    atomic_scalars = AtomicScalars(maxl=self.agent.max_sh)

    def last_distribution(structure):
        # Build the observation, reset the seeds, step the agent and
        # return the last entry of the action's 'dists'.
        obs = self.observation_space.build(structure, formula=self.formula)
        util.set_seeds(0)
        return self.agent.step([obs])['dists'][-1]

    scalars = atomic_scalars(last_distribution(atoms).coefficients)

    # Rotate a copy of the structure; the input atoms stay as they are.
    wigner_d, rot_mat, angles = rotations.gen_rot(self.agent.max_sh,
                                                  dtype=self.agent.dtype)
    atoms_rotated = atoms.copy()
    atoms_rotated.positions = np.einsum('ij,...j->...i', rot_mat,
                                        atoms.positions)

    scalars_rot = atomic_scalars(last_distribution(atoms_rotated).coefficients)

    self.assertTrue(torch.allclose(scalars, scalars_rot, atol=1e-05))
def test_CGProduct(self, batch, maxl1, maxl2, maxl, channels):
    """Check that the CG product commutes with a random rotation.

    Two random ``SO3Vec`` inputs are multiplied with ``CGProduct`` both
    before and after being rotated with ``apply_wigner(D, dir='left')``.
    The product of the rotated inputs must match the rotated product of
    the original inputs to within 1e-6 for every part.
    """
    top_l = max(maxl1, maxl2, maxl)
    D, R, _ = rot.gen_rot(top_l)

    cg_dict = CGDict(maxl=top_l, dtype=torch.double)
    cg_prod = CGProduct(maxl=maxl, dtype=torch.double, cg_dict=cg_dict)

    # Random inputs with ``channels`` copies of every ell up to maxl1 / maxl2.
    vec1 = SO3Vec.randn(SO3Tau([channels] * (maxl1 + 1)), batch,
                        dtype=torch.double)
    vec2 = SO3Vec.randn(SO3Tau([channels] * (maxl2 + 1)), batch,
                        dtype=torch.double)

    vec1_rot = vec1.apply_wigner(D, dir='left')
    vec2_rot = vec2.apply_wigner(D, dir='left')

    product = cg_prod(vec1, vec2)
    product_of_rotated = cg_prod(vec1_rot, vec2_rot)
    rotated_product = product.apply_wigner(D, dir='left')

    deviations = [
        (lhs - rhs).abs().max()
        for lhs, rhs in zip(product_of_rotated, rotated_product)
    ]
    assert all(dev < 1e-6 for dev in deviations)
def test_spherical_harmonics(self, maxl, channels, conj):
    """Check that spherical harmonics transform with the Wigner-D matrices.

    Harmonics are computed for random positions and for the same positions
    after a Cartesian rotation.  Applying the Wigner-D matrices to the
    unrotated harmonics must reproduce the harmonics of the rotated
    positions to within 1e-6 for every part.  When ``conj`` is false the
    rotation matrix is transposed and the Wigner-D matrices are applied
    from the right instead of the left.
    """
    D, R, angles = rot.gen_rot(maxl, dtype=torch.double)
    D = SO3WignerD(D)

    if conj:
        side = 'left'
    else:
        R = R.t()
        side = 'right'

    pos = torch.randn((channels, 3), dtype=torch.double)
    pos_rotated = rot.rotate_cart_vec(R, pos)

    cg_dict = CGDict(maxl, dtype=torch.double)

    harmonics = spherical_harmonics(cg_dict, pos, maxl, conj=conj)
    harmonics_of_rotated = spherical_harmonics(cg_dict, pos_rotated, maxl,
                                               conj=conj)
    rotated_harmonics = so3_torch.apply_wigner(D, harmonics, dir=side)

    diff = [
        (lhs - rhs).abs().max()
        for lhs, rhs in zip(harmonics_of_rotated, rotated_harmonics)
    ]
    print(diff)
    assert all(dev < 1e-6 for dev in diff)
def covariance_test(model, data):
    """Log how well ``model`` respects rotations on one batch of ``data``.

    The model is run on ``data`` and on a copy whose ``'positions'`` entry
    has been rotated.  Two quantities are logged:

    * the norm of the difference between the two outputs (invariance), and
    * the largest absolute difference between the representations of the
      rotated input and the Wigner-rotated representations of the original
      input (covariance).

    If the largest covariance deviation exceeds 1e-5, the norm, mean and
    max deviation of every (level, ell) pair are logged as warnings.

    Args:
        model: callable invoked as ``model(data, covariance_test=True)``
            and returning a 3-tuple whose first two items are the outputs
            and the per-level representations; must expose ``maxl``,
            ``device`` and ``dtype``.
        data: mapping with a tensor under ``'positions'``.

    Returns:
        None.  Results are reported through ``logging`` only.
    """
    logging.info('Beginning covariance test!')

    device, dtype = data['positions'].device, data['positions'].dtype

    D, R, _ = rot.gen_rot(model.maxl, device=device, dtype=dtype)
    D = SO3WignerD(D).to(model.device, model.dtype)

    data_rotout = data

    # Tensors are cloned so the rotation below does not touch ``data``.
    # NOTE(review): non-tensor entries are replaced by None in the rotated
    # copy -- confirm the model never reads them.
    data_rotin = {
        key: val.clone() if torch.is_tensor(val) else None
        for key, val in data.items()
    }
    data_rotin['positions'] = rot.rotate_cart_vec(R, data_rotin['positions'])

    outputs_rotout, reps_rotout, _ = model(data_rotout, covariance_test=True)
    outputs_rotin, reps_rotin, _ = model(data_rotin, covariance_test=True)

    invariance_test = (outputs_rotout - outputs_rotin).norm().item()

    # Rotate the representations of the unrotated input so that they can be
    # compared against the representations of the rotated input.
    reps_rotout = [reps.apply_wigner(D) for reps in reps_rotout]

    rep_diff = [
        (level_in - level_out).abs()
        for (level_in, level_out) in zip(reps_rotin, reps_rotout)
    ]

    # One tensor per level, holding one statistic per part of that level.
    covariance_test_norm = [
        torch.tensor([diff.norm() for diff in diffs_lvl])
        for diffs_lvl in rep_diff
    ]
    covariance_test_mean = [
        torch.tensor([diff.mean() for diff in diffs_lvl])
        for diffs_lvl in rep_diff
    ]
    covariance_test_max = [
        torch.tensor([diff.max() for diff in diffs_lvl])
        for diffs_lvl in rep_diff
    ]

    covariance_test_max_all = max(
        max(lvl) for lvl in covariance_test_max).item()

    logging.info('Rotation Invariance test: {:0.5g}'.format(invariance_test))
    logging.info('Largest deviation in covariance test : {:0.5g}'.format(
        covariance_test_max_all))

    # If the largest deviation in the covariance test is greater than 1e-5,
    # display the norm, mean and max of each irrep along each level.
    if covariance_test_max_all > 1e-5:
        logging.warning(
            'Largest deviation in covariance test {:0.5g} detected! '
            'Detailed summary:'.format(covariance_test_max_all))
        for lvl_idx, (lvl_norm, lvl_mean, lvl_max) in enumerate(
                zip(covariance_test_norm, covariance_test_mean,
                    covariance_test_max)):
            for ell_idx, (ell_norm, ell_mean, ell_max) in enumerate(
                    zip(lvl_norm, lvl_mean, lvl_max)):
                logging.warning(
                    '(lvl, ell) = ({}, {}) -> '
                    '{:0.5g} (norm) {:0.5g} (mean) {:0.5g} (max)'.format(
                        lvl_idx, ell_idx, ell_norm, ell_mean, ell_max))
def test_covariance(self, tau, num_channels, maxl, basis, edge_net_type,
                    sample_batch):
    """Check that the Cormorant edge level output ignores input rotations.

    The edge level is evaluated with the original atom representations and
    with the same representations rotated through ``apply_wigner(D)``.
    The two outputs must agree to within 1E-5 for every index in
    ``range(maxl)``.

    Args:
        tau: multiplicity used for every entry of the atom/edge type lists.
        num_channels: channel count passed to the radial filters and the
            edge level.
        maxl: number of entries in the type lists and of outputs checked.
        basis: basis argument forwarded (twice) to ``RadialFilters``.
        edge_net_type: ``None`` for no incoming edge representations or
            ``'rand'`` for random ones.
        sample_batch: 3-tuple whose first item is a mapping with a tensor
            under ``'positions'``.

    Raises:
        ValueError: if ``edge_net_type`` is neither ``None`` nor ``'rand'``.
    """
    data, __, __ = sample_batch
    device, dtype = data['positions'].device, data['positions'].dtype
    sph_harms = SphericalHarmonicsRel(maxl - 1, conj=True, device=device,
                                      dtype=dtype, cg_dict=None)

    batch_size, natoms = data['positions'].shape[:2]

    D, R, __ = rot.gen_rot(maxl, device=device, dtype=dtype)

    # Setup input
    atom_reps, atom_mask, edge_scalars, edge_mask, atom_positions = prep_input(
        data, tau, maxl)
    atom_positions_rot = rot.rotate_cart_vec(R, atom_positions)
    atom_reps_rot = atom_reps.apply_wigner(D)

    # Calculate spherical harmonics and radial functions
    __, norms = sph_harms(atom_positions, atom_positions)
    # NOTE(review): norms_rot is computed but never used; both edge-level
    # calls below receive ``norms``.  Confirm whether the rotated call was
    # meant to use norms_rot instead.
    __, norms_rot = sph_harms(atom_positions_rot, atom_positions_rot)
    rad_funcs = RadialFilters([maxl - 1], [basis, basis], [num_channels], 1,
                              device=device, dtype=dtype)
    rad_func_levels = rad_funcs(norms, edge_mask * (norms > 0))
    tau_pos = rad_funcs.tau[0]

    # Build the initial edge network
    if edge_net_type is None:
        edge_reps = None
    elif edge_net_type == 'rand':
        # Create the random edge reps on the same device/dtype as the rest
        # of the inputs so the test also works off the CPU/float32 default.
        reps = [
            torch.randn((batch_size, natoms, natoms, tau, 2),
                        device=device, dtype=dtype)
            for _ in range(maxl)
        ]
        edge_reps = SO3Scalar(reps)
    else:
        raise ValueError(
            'Unsupported edge_net_type: {!r}'.format(edge_net_type))

    # Build Edge layer
    tlist = [tau] * maxl
    tau_atoms = tlist
    tau_edge = [] if edge_net_type is None else tlist

    edge_lvl = CormorantEdgeLevel(tau_atoms, tau_edge, tau_pos, num_channels,
                                  maxl, cutoff_type='soft', device=device,
                                  dtype=dtype, hard_cut_rad=1.73,
                                  soft_cut_rad=1.73, soft_cut_width=0.2)

    output_edge_reps = edge_lvl(edge_reps, atom_reps, rad_func_levels[0],
                                edge_mask, norms)
    output_edge_reps_rot = edge_lvl(edge_reps, atom_reps_rot,
                                    rad_func_levels[0], edge_mask, norms)

    for i in range(maxl):
        assert (torch.max(
            torch.abs(output_edge_reps[i] - output_edge_reps_rot[i])) < 1E-5)