Ejemplo n.º 1
0
  def test_nvt_nose_hoover_jammed(self, dtype, sy_steps):
    """NVT Nose-Hoover dynamics on the jammed test state conserve the invariant.

    Loads the jammed test state, runs ``DYNAMICS_STEPS`` jitted NVT
    Nose-Hoover steps at ``kT = 1e-3`` with the given ``sy_steps``, and checks
    that:

    * the Nose-Hoover invariant at every step matches its initial value
      within a dtype-dependent tolerance, and
    * positions keep the requested ``dtype``.
    """
    key = random.PRNGKey(0)

    state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
    displacement_fn, shift_fn = space.periodic(state.box[0, 0])

    E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
    invariant = partial(simulate.nvt_nose_hoover_invariant, E)

    kT = 1e-3
    init_fn, apply_fn = simulate.nvt_nose_hoover(E, shift_fn, 1e-3,
                                                 kT=kT, sy_steps=sy_steps)
    apply_fn = jit(apply_fn)

    state = init_fn(key, state.real_position)

    E_initial = invariant(state, kT) * np.ones((DYNAMICS_STEPS,))

    def step_fn(i, state_and_energies):
      state, energies = state_and_energies
      state = apply_fn(state)
      # `jax.ops.index_update` no longer exists in JAX; `.at[i].set` is the
      # supported functional update. The accumulator is called `energies`
      # so it does not shadow the `energy` module used above.
      energies = energies.at[i].set(invariant(state, kT))
      return state, energies

    Es = np.zeros((DYNAMICS_STEPS,))
    state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

    # Looser tolerance for float32 (presumably its lower precision).
    tol = 1e-3 if dtype is f32 else 1e-7
    self.assertEqual(state.position.dtype, dtype)
    self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
Ejemplo n.º 2
0
  def test_nve_jammed_periodic_general(self, dtype, coords):
    """NVE dynamics in a general periodic box conserve total energy.

    ``coords`` selects which stored positions are used (the
    ``<coords>_position`` attribute of the loaded state). It also selects
    whether ``space.periodic_general`` treats positions as fractional
    (``coords == 'fractional'``).

    Runs ``DYNAMICS_STEPS`` jitted NVE steps and checks that potential plus
    kinetic energy at every step matches the initial value within a
    dtype-dependent tolerance, and that positions keep the requested dtype.
    """
    key = random.PRNGKey(0)

    state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
    displacement_fn, shift_fn = space.periodic_general(state.box,
                                                       coords == 'fractional')

    E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)

    init_fn, apply_fn = simulate.nve(E, shift_fn, 1e-3)
    apply_fn = jit(apply_fn)

    state = init_fn(key, getattr(state, coords + '_position'), kT=1e-3)

    def E_T(state):
      """Total energy: potential energy plus kinetic energy."""
      return (E(state.position) +
              quantity.kinetic_energy(state.velocity, state.mass))

    E_initial = E_T(state) * np.ones((DYNAMICS_STEPS,))

    def step_fn(i, state_and_energies):
      state, energies = state_and_energies
      state = apply_fn(state)
      # `jax.ops.index_update` no longer exists in JAX; `.at[i].set` is the
      # supported functional update. The accumulator is called `energies`
      # so it does not shadow the `energy` module used above.
      energies = energies.at[i].set(E_T(state))
      return state, energies

    Es = np.zeros((DYNAMICS_STEPS,))
    state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

    # Looser tolerance for float32 (presumably its lower precision).
    tol = 1e-3 if dtype is f32 else 1e-7
    self.assertEqual(state.position.dtype, dtype)
    self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
Ejemplo n.º 3
0
  def test_pressure_jammed_periodic(self, dtype):
    """Pressure of the jammed state matches the stored reference value.

    Builds a periodic space from the diagonal of ``state.box``. It then
    evaluates ``quantity.pressure`` for a soft-sphere pair energy and
    compares the result to ``state.pressure`` with a dtype-dependent
    tolerance.
    """
    state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
    # Only the displacement function is needed to build the energy; nothing
    # is integrated, so the shift function is discarded.
    displacement_fn, _ = space.periodic(jnp.diag(state.box))

    E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
    pos = state.real_position

    # Tight tolerance for float64, looser otherwise.
    tol = 1e-7 if dtype is f64 else 2e-5

    self.assertAllClose(quantity.pressure(E, pos, state.box), state.pressure,
                        atol=tol, rtol=tol)
Ejemplo n.º 4
0
  def test_pressure_jammed_periodic_general(self, dtype, coords):
    """Pressure in a general periodic box matches the stored reference value.

    ``coords`` selects which stored positions are used (the
    ``<coords>_position`` attribute of the loaded state). It also selects
    whether ``space.periodic_general`` treats positions as fractional
    (``coords == 'fractional'``).

    The pressure from ``quantity.pressure`` is compared to ``state.pressure``
    with a dtype-dependent tolerance.
    """
    state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
    # Only the displacement function is needed to build the energy; nothing
    # is integrated, so the shift function is discarded.
    displacement_fn, _ = space.periodic_general(state.box,
                                                coords == 'fractional')
    E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
    pos = getattr(state, coords + '_position')

    # Tight tolerance for float64, looser otherwise.
    tol = 1e-7 if dtype is f64 else 2e-5

    self.assertAllClose(quantity.pressure(E, pos, state.box), state.pressure,
                        atol=tol, rtol=tol)
Ejemplo n.º 5
0
  def test_swap_mc_jammed(self, dtype):
    """Hybrid swap Monte Carlo on the jammed state thermalizes and swaps.

    Runs ``DYNAMICS_STEPS`` hybrid swap MC steps at ``kT = 1e-2`` and records
    the MD temperature after each step. After a short burn-in it checks:

    * every recorded temperature is loosely close to ``kT``,
    * the mean recorded temperature is tightly close to ``kT``, and
    * the per-particle ``sigma`` values differ from the initial ones
      (presumably due to accepted swap moves).
    """
    key = random.PRNGKey(0)

    state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
    space_fn = space.periodic(state.box[0, 0])
    displacement_fn, shift_fn = space_fn

    # Per-particle sigma: the diagonal of `state.sigma` indexed by species
    # (assumes `state.sigma` is a species-by-species matrix -- TODO confirm).
    sigma = np.diag(state.sigma)[state.species]

    energy_fn = lambda dr, sigma: energy.soft_sphere(dr, sigma=sigma)
    neighbor_fn = partition.neighbor_list(displacement_fn,
                                          state.box[0, 0],
                                          np.max(sigma) + 0.1,
                                          dr_threshold=0.5)

    kT = 1e-2
    t_md = 0.1
    N_swap = 10
    init_fn, apply_fn = simulate.hybrid_swap_mc(space_fn,
                                                energy_fn,
                                                neighbor_fn,
                                                1e-3,
                                                kT,
                                                t_md,
                                                N_swap)
    state = init_fn(key, state.real_position, sigma)

    Ts = np.zeros((DYNAMICS_STEPS,))

    def step_fn(i, state_and_temp):
      state, temp = state_and_temp
      state = apply_fn(state)
      temp = temp.at[i].set(quantity.temperature(state.md.velocity))
      return state, temp

    state, Ts = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Ts))

    # Initial steps are excluded from the temperature checks.
    burn_in = 10
    Ts_equilibrated = Ts[burn_in:]

    tol = 5e-4
    self.assertAllClose(Ts_equilibrated,
                        kT * np.ones((DYNAMICS_STEPS - burn_in,)),
                        rtol=5e-1,
                        atol=5e-3)
    self.assertAllClose(np.mean(Ts_equilibrated), kT, rtol=tol, atol=tol)
    self.assertFalse(np.all(state.sigma == sigma))
Ejemplo n.º 6
0
    def test_npt_nose_hoover_jammed(self, dtype, sy_steps):
        """NPT Nose-Hoover dynamics on the jammed state conserve the invariant.

        Runs ``DYNAMICS_STEPS`` jitted NPT Nose-Hoover steps at the state's
        stored pressure and ``kT = 1e-3``. ``sy_steps`` is forwarded to both
        chains. The test checks that:

        * the NPT Nose-Hoover invariant stays at its initial value within a
          dtype-dependent tolerance,
        * positions keep the requested dtype, and
        * the measured pressure stays within ``rtol=atol=0.05`` of the target.
        """
        key = random.PRNGKey(0)

        state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
        displacement_fn, shift_fn = space.periodic_general(state.box)

        E = energy.soft_sphere_pair(displacement_fn, state.species,
                                    state.sigma)
        invariant = partial(simulate.npt_nose_hoover_invariant, E)
        pressure_fn = partial(quantity.pressure, E)

        # The key must be the string 'sy_steps'. Writing `{sy_steps: sy_steps}`
        # keys the dict by the integer value, so the parameterized `sy_steps`
        # would not be passed through as the `sy_steps` option.
        nhc_kwargs = {'sy_steps': sy_steps}
        kT = 1e-3
        P = state.pressure
        init_fn, apply_fn = simulate.npt_nose_hoover(E, shift_fn, 1e-3, P, kT,
                                                     nhc_kwargs, nhc_kwargs)
        apply_fn = jit(apply_fn)

        state = init_fn(key, state.fractional_position, state.box)

        E_initial = invariant(state, P, kT) * np.ones((DYNAMICS_STEPS, ))
        P_target = P * np.ones((DYNAMICS_STEPS, ))

        def step_fn(i, state_energies_pressures):
            state, energies, pressures = state_energies_pressures
            state = apply_fn(state)
            # `energies` rather than `energy`, to avoid shadowing the module.
            energies = energies.at[i].set(invariant(state, P, kT))
            # The box evolves under NPT, so recompute it from the state.
            box = simulate.npt_box(state)
            KE = quantity.kinetic_energy(state.velocity, state.mass)
            p = pressure_fn(state.position, box, KE)
            pressures = pressures.at[i].set(p)
            return state, energies, pressures

        Es = np.zeros((DYNAMICS_STEPS, ))
        Ps = np.zeros((DYNAMICS_STEPS, ))
        state, Es, Ps = lax.fori_loop(0, DYNAMICS_STEPS, step_fn,
                                      (state, Es, Ps))

        # Looser tolerance for float32 (presumably its lower precision).
        tol = 1e-3 if dtype is f32 else 1e-7
        self.assertEqual(state.position.dtype, dtype)
        self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
        self.assertAllClose(Ps, P_target, rtol=0.05, atol=0.05)