예제 #1
0
파일: energy_test.py 프로젝트: zizai/jax-md
 def test_simple_spring(self, spatial_dimension, dtype):
   """Check `energy.simple_spring_bond` against a closed-form energy.

   Two particles are placed at the origin and at the all-ones corner, so their
   separation is sqrt(spatial_dimension). For each stochastic sample a random
   rest length and exponent are drawn, and the bond energy is compared with
   (dist - length) ** alpha / alpha.
   """
   key = random.PRNGKey(0)
   disp, _ = space.free()
   # One particle at the origin and one at (1, ..., 1). For 2D/3D this is the
   # same configuration as before; other dimensions no longer fail with an
   # unbound `R` / `dist`.
   R = np.array([[0.] * spatial_dimension, [1.] * spatial_dimension],
                dtype=dtype)
   dist = np.sqrt(float(spatial_dimension))
   # A single bond between particle 0 and particle 1, one index pair per row.
   bonds = np.array([[0, 1]], np.int32)
   for _ in range(STOCHASTIC_SAMPLES):
     # Sample length and alpha from their own subkeys so they are independent
     # draws; `key` is only carried forward to be split on the next iteration.
     key, l_key, a_key = random.split(key, 3)
     length = random.uniform(l_key, (), minval=0.1, maxval=3.0)
     alpha = random.uniform(a_key, (), minval=2., maxval=4.)
     E = energy.simple_spring_bond(disp, bonds, length=length, alpha=alpha)
     # NOTE(review): when length > dist and alpha is non-integer the base is
     # negative, so E_exact is NaN. The check then presumably relies on E(R)
     # also being NaN and on assertAllClose treating NaN == NaN — confirm.
     E_exact = dtype((dist - length) ** alpha / alpha)
     self.assertAllClose(E(R), E_exact, check_dtypes=True)
예제 #2
0
    def test_nvt_nose_hoover(self, spatial_dimension, dtype, sy_steps):
        """Check the Nose-Hoover NVT integrator on a ring of spring bonds.

        Particles are placed uniformly at random in a periodic box and joined
        into a ring by `energy.simple_spring_bond`. For each stochastic sample
        the integrator runs `DYNAMICS_STEPS` steps, and the test checks that:
        the final temperature is within 10% of the target, the
        `simulate.nvt_nose_hoover_invariant` is conserved to a dtype-dependent
        relative tolerance, and positions keep the requested dtype.
        """
        key = random.PRNGKey(0)

        box_size = quantity.box_size_at_number_density(PARTICLE_COUNT,
                                                       f32(1.2),
                                                       spatial_dimension)
        displacement_fn, shift_fn = space.periodic(box_size)

        # Ring topology: particle i is bonded to particle i - 1 (mod N).
        # Stack along axis=1 so each row is one (i, j) pair, giving shape
        # (PARTICLE_COUNT, 2). Stacking along axis 0 gives (2, PARTICLE_COUNT),
        # which a row-per-bond reader (jax_md's bond API takes [n, 2] pairs)
        # would see as only two bonds.
        bonds_i = np.arange(PARTICLE_COUNT)
        bonds_j = np.roll(bonds_i, 1)
        bonds = np.stack([bonds_i, bonds_j], axis=1)

        E = energy.simple_spring_bond(displacement_fn, bonds)

        invariant = partial(simulate.nvt_nose_hoover_invariant, E)

        for _ in range(STOCHASTIC_SAMPLES):
            key, pos_key, vel_key, T_key, masses_key = random.split(key, 5)

            R = box_size * random.uniform(pos_key,
                                          (PARTICLE_COUNT, spatial_dimension),
                                          dtype=dtype)
            T = random.uniform(T_key, (), minval=0.3, maxval=1.4, dtype=dtype)
            # Masses drawn from [1, 2) so particles are not all identical.
            mass = 1 + random.uniform(masses_key, (PARTICLE_COUNT, ),
                                      dtype=dtype)
            # The target temperature is fixed at construction time, so the
            # integrator is rebuilt (and re-jitted) for each sample.
            init_fn, apply_fn = simulate.nvt_nose_hoover(E,
                                                         shift_fn,
                                                         1e-3,
                                                         T,
                                                         sy_steps=sy_steps)
            apply_fn = jit(apply_fn)

            state = init_fn(vel_key, R, mass=mass)

            initial = invariant(state, T)

            for _ in range(DYNAMICS_STEPS):
                state = apply_fn(state)

            T_final = quantity.temperature(state.velocity, state.mass)
            # A TestCase assertion, unlike a bare `assert`, is not stripped
            # when Python runs with -O.
            self.assertLess(np.abs(T_final - T) / T, 0.1)
            tol = 5e-4 if dtype is f32 else 1e-6
            self.assertAllClose(invariant(state, T), initial, rtol=tol)
            self.assertEqual(state.position.dtype, dtype)