def test_radial_symmetry_functions_neighbor_list(self, N_types, N_etas, dtype, dim):
    """Compare neighbor-list radial symmetry functions with the direct ones.

    Random positions and species are drawn inside a periodic box. Both
    `nn.radial_symmetry_functions` and
    `nn.radial_symmetry_functions_neighbor_list` are built from the same
    displacement function, species, etas and cutoff, and their outputs on
    the same positions are asserted to agree within a tolerance.

    Args:
      N_types: Upper bound passed to `random.choice` when drawing species.
      N_etas: Number of eta values generated by `np.linspace(1.0, 2.0, ...)`.
      dtype: dtype used for the eta values.
      dim: Number of columns of the random position array.
    """
    particle_count = 128
    box_size = 12.0
    r_cutoff = 3.

    # Only the displacement function is needed; the second return value of
    # `space.periodic` is discarded.
    displacement, _ = space.periodic(box_size)

    position_key, species_key = random.split(random.PRNGKey(0))
    # NOTE(review): no dtype is passed to `random.uniform`, so positions use
    # its default dtype; `dtype` only reaches the etas below -- confirm this
    # is intended.
    R = box_size * random.uniform(position_key, (particle_count, dim))
    species = random.choice(species_key, N_types, (particle_count, ))

    # Both feature functions receive identical eta values.
    etas = np.linspace(1.0, 2.0, N_etas, dtype=dtype)
    dense_fn = nn.radial_symmetry_functions(
        displacement, species, etas, r_cutoff)
    neighbor_list_fn = nn.radial_symmetry_functions_neighbor_list(
        displacement, species, etas, r_cutoff)

    neighbor_fn = partition.neighbor_list(
        displacement, box_size, r_cutoff, 0.)
    nbrs = neighbor_fn(R)

    gr_exact = dense_fn(R)
    gr_nbrs = neighbor_list_fn(R, neighbor=nbrs)

    # Tolerance follows the global x64 flag, not the `dtype` argument.
    if FLAGS.jax_enable_x64:
        tol = 1e-13
    else:
        tol = 1e-6
    self.assertAllClose(gr_exact, gr_nbrs, atol=tol, rtol=tol)
def test_radial_symmetry_functions(self, N_types, N_etas, dtype):
    """Check the output shape and one reference value on three fixed points.

    Builds `nn.radial_symmetry_functions` with the displacement function
    from `space.free()`, species `[1, 1, N_types]`, `N_etas` eta values
    from `np.linspace(1.0, 2.0, ...)` and a cutoff of 4, then evaluates it
    on three hard-coded 3D positions.

    Args:
      N_types: Used as the third species label and in the expected number
        of output columns.
      N_etas: Number of eta values; also used in the expected number of
        output columns.
      dtype: dtype of the eta values, the positions and the reference value.
    """
    # Only the displacement function is needed; the second return value of
    # `space.free` is discarded.
    displacement, _ = space.free()

    species = np.array([1, 1, N_types])
    etas = np.linspace(1.0, 2.0, N_etas, dtype=dtype)
    cutoff = 4
    symmetry_fn = nn.radial_symmetry_functions(
        displacement, species, etas, cutoff)

    positions = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
    features = symmetry_fn(positions)

    # One row per position, N_types * N_etas columns.
    expected_shape = (3, N_types * N_etas)
    self.assertAllClose(features.shape, expected_shape)

    # Hard-coded reference value for the third position, first column.
    observed = features[2, 0]
    reference = dtype(0.411717)
    self.assertAllClose(observed, reference, rtol=1e-6, atol=1e-6)