コード例 #1
0
ファイル: nn_test.py プロジェクト: scnlong/jax-md
    def test_radial_symmetry_functions_neighbor_list(self, N_types, N_etas,
                                                     dtype, dim):
        """Neighbor-list and dense radial symmetry functions should agree.

        Draws 128 uniformly random positions in a periodic box of side 12,
        plus random species labels drawn with ``random.choice`` over
        ``N_types`` options. It then evaluates
        ``nn.radial_symmetry_functions`` on the positions directly and
        ``nn.radial_symmetry_functions_neighbor_list`` on a neighbor list
        built with the same cutoff. The two outputs must match to a tolerance
        that is tightened when ``FLAGS.jax_enable_x64`` is set.
        """
        n_particles = 128
        box_size = 12.0
        r_cutoff = 3.

        # Only the displacement function is needed; the shift function is
        # discarded.
        displacement, _ = space.periodic(box_size)

        pos_key, species_key = random.split(random.PRNGKey(0))
        R = box_size * random.uniform(pos_key, (n_particles, dim))
        species = random.choice(species_key, N_types, (n_particles, ))

        # NOTE(review): the trailing 0. is passed positionally; presumably it
        # is the neighbor-list skin / dr_threshold. Confirm against partition.
        neighbor_fn = partition.neighbor_list(displacement, box_size,
                                              r_cutoff, 0.)

        # Both implementations receive the same eta grid.
        etas = np.linspace(1.0, 2.0, N_etas, dtype=dtype)
        dense_fn = nn.radial_symmetry_functions(
            displacement, species, etas, r_cutoff)
        sparse_fn = nn.radial_symmetry_functions_neighbor_list(
            displacement, species, etas, r_cutoff)

        nbrs = neighbor_fn(R)
        expected = dense_fn(R)
        actual = sparse_fn(R, neighbor=nbrs)

        if FLAGS.jax_enable_x64:
            tol = 1e-13
        else:
            tol = 1e-6
        self.assertAllClose(expected, actual, atol=tol, rtol=tol)
コード例 #2
0
ファイル: nn_test.py プロジェクト: saridout/jax-md
 def test_radial_symmetry_functions(self, N_types, N_etas, dtype):
   """Check the output shape and one reference entry of the radial features.

   Builds ``nn.radial_symmetry_functions`` in free space for three particles
   at (0, 0, 0), (1, 1, 1) and (1, 1, 0). It uses species labels
   ``[1, 1, N_types]``, ``N_etas`` etas evenly spaced in [1, 2] and a cutoff
   of 4. It then asserts:
     * the output has shape ``(3, N_types * N_etas)``;
     * entry ``[2, 0]`` equals the reference value 0.411717 (within 1e-6).
   """
   # Only the displacement function is needed; the shift function is discarded.
   displacement, _ = space.free()
   # With N_types == 1 all three labels are equal; otherwise the last particle
   # gets a label distinct from the first two. NOTE(review): the expected
   # column count N_types * N_etas presumably assumes N_types is 1 or 2.
   # TODO confirm against the test parameterization.
   species = np.array([1, 1, N_types])
   etas = np.linspace(1.0, 2.0, N_etas, dtype=dtype)
   gr = nn.radial_symmetry_functions(displacement, species, etas, 4)
   R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
   gr_out = gr(R)
   # Shapes are integer tuples: compare them exactly rather than with a
   # floating-point tolerance.
   self.assertEqual(gr_out.shape, (3, N_types * N_etas))
   self.assertAllClose(gr_out[2, 0], dtype(0.411717), rtol=1e-6, atol=1e-6)