def test_angular_symmetry_functions(self, N_types, N_etas, dtype):
    """Check shape and one regression value of `nn.angular_symmetry_functions`.

    Three particles in free space are featurized with `N_etas` identical
    (eta, lambda, zeta) triples. The test asserts two things:
    * the output has one row per particle and
      `N_etas * N_types * (N_types + 1) // 2` columns;
    * entry `[2, 0]` matches a stored regression value.
    """
    displacement, _ = space.free()

    # NOTE(review): species labels are [1, 1, N_types]. The expected column
    # count assumes this yields N_types*(N_types+1)//2 species pairs per eta.
    # Confirm against the species handling in nn.angular_symmetry_functions.
    species = np.array([1, 1, N_types])

    # NOTE(review): 0.529177 looks like the Bohr radius in Angstrom, presumably
    # converting an eta of 1e-4 Bohr^-2 to Angstrom^-2. Confirm.
    eta = 1e-4 / (0.529177 ** 2)

    gr = nn.angular_symmetry_functions(
        displacement,
        species,
        etas=np.array([eta] * N_etas, dtype),
        lambdas=np.array([-1.0] * N_etas, dtype),
        zetas=np.array([1.0] * N_etas, dtype),
        cutoff_distance=8.0)

    R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
    gr_out = gr(R)

    # Shapes are integer tuples: compare exactly rather than with a tolerance.
    n_type_pairs = N_types * (N_types + 1) // 2
    self.assertEqual(tuple(gr_out.shape), (3, N_etas * n_type_pairs))

    # Regression value; its provenance is not visible in this file.
    self.assertAllClose(gr_out[2, 0], dtype(1.577944), rtol=1e-6, atol=1e-6)
def test_angular_symmetry_functions_neighbor_list(self, N_types, N_etas, dtype, dim):
    """Dense and neighbor-list angular symmetry functions must agree.

    The test draws 128 uniformly random positions in a periodic box of side
    12 and assigns random species labels in [0, N_types). It evaluates both
    `nn.angular_symmetry_functions` and
    `nn.angular_symmetry_functions_neighbor_list` with the same
    hyper-parameters, then checks that the two outputs are close.
    """
    n_particles = 128
    box_size = 12.0
    r_cutoff = 3.

    displacement, _ = space.periodic(box_size)

    position_key, species_key = random.split(random.PRNGKey(0))
    R = box_size * random.uniform(position_key, (n_particles, dim))
    species = random.choice(species_key, N_types, (n_particles, ))

    # NOTE(review): the trailing 0. is passed positionally to
    # partition.neighbor_list; confirm its meaning (skin / capacity slack?).
    neighbor_fn = partition.neighbor_list(displacement, box_size, r_cutoff, 0.)

    # Both implementations receive exactly the same hyper-parameters.
    shared_kwargs = dict(
        etas=np.linspace(1., 2., N_etas, dtype=dtype),
        lambdas=np.array([-1.0] * N_etas, dtype),
        zetas=np.array([1.0] * N_etas, dtype),
        cutoff_distance=r_cutoff,
    )
    dense_fn = nn.angular_symmetry_functions(
        displacement, species, **shared_kwargs)
    sparse_fn = nn.angular_symmetry_functions_neighbor_list(
        displacement, species, **shared_kwargs)

    neighbors = neighbor_fn(R)
    expected = dense_fn(R)
    actual = sparse_fn(R, neighbor=neighbors)

    # The tight tolerance applies only when JAX 64-bit mode is on.
    if FLAGS.jax_enable_x64:
        tol = 1e-13
    else:
        tol = 1e-6
    self.assertAllClose(expected, actual, atol=tol, rtol=tol)