示例#1
0
 def test_angular_symmetry_functions(self, N_types, N_etas, dtype):
   """Check the shape and one reference entry of the angular symmetry functions.

   Builds the functions from the displacement returned by `space.free()`,
   evaluates them on three fixed particle positions, and verifies that the
   output has `N_etas * N_types * (N_types + 1) // 2` columns and that entry
   `[2, 0]` matches a stored reference value.

   Args:
     N_types: Used as the third species label and in the expected width.
     N_etas: Length of the `etas`, `lambdas` and `zetas` arrays.
     dtype: Scalar type used for the parameter arrays and the positions.
   """
   displacement, _ = space.free()

   # NOTE(review): 0.529177 looks like the Bohr radius in Angstrom -- confirm.
   eta_value = 1e-4 / (0.529177 ** 2)
   species = np.array([1, 1, N_types])
   symmetry_fn = nn.angular_symmetry_functions(
       displacement,
       species,
       etas=np.array([eta_value] * N_etas, dtype),
       lambdas=np.array([-1.0] * N_etas, dtype),
       zetas=np.array([1.0] * N_etas, dtype),
       cutoff_distance=8.0)

   positions = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
   features = symmetry_fn(positions)

   # One row per particle; the column count is checked against this formula.
   expected_shape = (3, N_etas * N_types * (N_types + 1) // 2)
   self.assertAllClose(features.shape, expected_shape)
   self.assertAllClose(features[2, 0], dtype(1.577944), rtol=1e-6, atol=1e-6)
示例#2
0
文件: nn_test.py 项目: scnlong/jax-md
    def test_angular_symmetry_functions_neighbor_list(self, N_types, N_etas,
                                                      dtype, dim):
        """Compare the dense and neighbor-list angular symmetry functions.

        Draws 128 random positions and species in a periodic box of side 12,
        builds both `nn.angular_symmetry_functions` and
        `nn.angular_symmetry_functions_neighbor_list` with identical
        parameters and a cutoff of 3, and asserts that their outputs agree.

        Args:
            N_types: Upper bound handed to `random.choice` for the species.
            N_etas: Length of the `etas`, `lambdas` and `zetas` arrays.
            dtype: Scalar type used for the parameter arrays.
            dim: Number of coordinates per particle position.
        """
        particle_count = 128
        box_size = 12.0
        r_cutoff = 3.

        displacement, _ = space.periodic(box_size)

        position_key, species_key = random.split(random.PRNGKey(0))
        positions = box_size * random.uniform(position_key,
                                              (particle_count, dim))
        species = random.choice(species_key, N_types, (particle_count, ))

        neighbor_fn = partition.neighbor_list(
            displacement, box_size, r_cutoff, 0.)

        # Both variants are constructed from the same keyword arguments.
        params = dict(
            etas=np.linspace(1., 2., N_etas, dtype=dtype),
            lambdas=np.array([-1.0] * N_etas, dtype),
            zetas=np.array([1.0] * N_etas, dtype),
            cutoff_distance=r_cutoff)
        dense_fn = nn.angular_symmetry_functions(
            displacement, species, **params)
        neighbor_list_fn = nn.angular_symmetry_functions_neighbor_list(
            displacement, species, **params)

        neighbors = neighbor_fn(positions)
        dense_out = dense_fn(positions)
        neighbor_out = neighbor_list_fn(positions, neighbor=neighbors)

        # Tighter tolerance when the x64 flag is enabled.
        tolerance = 1e-13 if FLAGS.jax_enable_x64 else 1e-6
        self.assertAllClose(dense_out, neighbor_out,
                            atol=tolerance, rtol=tolerance)