def test_nvt_nose_hoover_jammed(self, dtype, sy_steps):
  """Checks that the NVT Nose-Hoover invariant is conserved on a jammed state.

  Loads the stored jammed configuration and builds a soft-sphere pair energy
  in a periodic box. It then runs `DYNAMICS_STEPS` steps of
  `simulate.nvt_nose_hoover` with the parameterized `sy_steps` and asserts
  that the invariant recorded after every step matches its initial value.
  """
  key = random.PRNGKey(0)
  state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
  displacement_fn, shift_fn = space.periodic(state.box[0, 0])

  E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
  invariant = partial(simulate.nvt_nose_hoover_invariant, E)

  kT = 1e-3
  init_fn, apply_fn = simulate.nvt_nose_hoover(E, shift_fn, 1e-3,
                                               kT=kT, sy_steps=sy_steps)
  apply_fn = jit(apply_fn)

  state = init_fn(key, state.real_position)
  E_initial = invariant(state, kT) * np.ones((DYNAMICS_STEPS,))

  def step_fn(i, state_and_energies):
    # Named `energies` so the `energy` module is not shadowed.
    state, energies = state_and_energies
    state = apply_fn(state)
    # Functional indexed update (`ops.index_update` was removed from JAX).
    energies = energies.at[i].set(invariant(state, kT))
    return state, energies

  Es = np.zeros((DYNAMICS_STEPS,))
  state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

  tol = 1e-3 if dtype is f32 else 1e-7
  self.assertEqual(state.position.dtype, dtype)
  self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
def test_nve_jammed_periodic_general(self, dtype, coords):
  """Checks NVE total-energy conservation with `space.periodic_general`.

  `coords` selects which stored positions are used (`'<coords>_position'` on
  the loaded state). It also selects whether the space is built with
  fractional coordinates (`coords == 'fractional'`). The total energy
  (potential + kinetic) recorded after each of `DYNAMICS_STEPS` steps must
  match its initial value.
  """
  key = random.PRNGKey(0)
  state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
  displacement_fn, shift_fn = space.periodic_general(state.box,
                                                     coords == 'fractional')

  E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)

  init_fn, apply_fn = simulate.nve(E, shift_fn, 1e-3)
  apply_fn = jit(apply_fn)

  state = init_fn(key, getattr(state, coords + '_position'), kT=1e-3)

  def E_T(state):
    # Total energy: potential plus kinetic.
    return E(state.position) + quantity.kinetic_energy(state.velocity,
                                                       state.mass)

  E_initial = E_T(state) * np.ones((DYNAMICS_STEPS,))

  def step_fn(i, state_and_energies):
    # Named `energies` so the `energy` module is not shadowed.
    state, energies = state_and_energies
    state = apply_fn(state)
    # Functional indexed update (`ops.index_update` was removed from JAX).
    energies = energies.at[i].set(E_T(state))
    return state, energies

  Es = np.zeros((DYNAMICS_STEPS,))
  state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

  tol = 1e-3 if dtype is f32 else 1e-7
  self.assertEqual(state.position.dtype, dtype)
  self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
def test_pressure_jammed_periodic(self, dtype):
  """Checks `quantity.pressure` against the stored pressure (periodic box).

  Uses `space.periodic` with the diagonal of the stored box and the state's
  real-space positions.
  """
  state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
  # Only the displacement function is needed to build the energy.
  displacement_fn, _ = space.periodic(jnp.diag(state.box))

  E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
  pos = state.real_position

  tol = 1e-7 if dtype is f64 else 2e-5
  self.assertAllClose(quantity.pressure(E, pos, state.box), state.pressure,
                      atol=tol, rtol=tol)
def test_pressure_jammed_periodic_general(self, dtype, coords):
  """Checks `quantity.pressure` against the stored pressure (general box).

  `coords` selects the stored positions (`'<coords>_position'`) and whether
  `space.periodic_general` is built with fractional coordinates.
  """
  state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
  # Only the displacement function is needed to build the energy.
  displacement_fn, _ = space.periodic_general(state.box,
                                              coords == 'fractional')

  E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
  pos = getattr(state, coords + '_position')

  tol = 1e-7 if dtype is f64 else 2e-5
  self.assertAllClose(quantity.pressure(E, pos, state.box), state.pressure,
                      atol=tol, rtol=tol)
def test_swap_mc_jammed(self, dtype):
  """Checks hybrid swap Monte Carlo on the jammed state.

  Runs `simulate.hybrid_swap_mc` with a neighbor list for `DYNAMICS_STEPS`
  steps and records the MD temperature after each step. After discarding
  the first `burn_in` steps, asserts that:
    * each temperature is loosely close to `kT`,
    * the mean temperature is tightly close to `kT`,
    * at least one particle's `sigma` changed (i.e. swaps were accepted).
  """
  key = random.PRNGKey(0)
  state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
  space_fn = space.periodic(state.box[0, 0])
  displacement_fn, shift_fn = space_fn

  # Per-particle diameters taken from the species-pair sigma matrix diagonal.
  sigma = np.diag(state.sigma)[state.species]

  energy_fn = lambda dr, sigma: energy.soft_sphere(dr, sigma=sigma)
  neighbor_fn = partition.neighbor_list(displacement_fn,
                                        state.box[0, 0],
                                        np.max(sigma) + 0.1,
                                        dr_threshold=0.5)

  kT = 1e-2
  t_md = 0.1
  N_swap = 10
  init_fn, apply_fn = simulate.hybrid_swap_mc(space_fn,
                                              energy_fn,
                                              neighbor_fn,
                                              1e-3,
                                              kT,
                                              t_md,
                                              N_swap)
  state = init_fn(key, state.real_position, sigma)

  Ts = np.zeros((DYNAMICS_STEPS,))

  def step_fn(i, state_and_temps):
    state, temps = state_and_temps
    state = apply_fn(state)
    temps = temps.at[i].set(quantity.temperature(state.md.velocity))
    return state, temps

  state, Ts = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Ts))

  # Early steps are excluded from the temperature checks.
  burn_in = 10
  tol = 5e-4
  self.assertAllClose(Ts[burn_in:],
                      kT * np.ones((DYNAMICS_STEPS - burn_in,)),
                      rtol=5e-1,
                      atol=5e-3)
  self.assertAllClose(np.mean(Ts[burn_in:]), kT, rtol=tol, atol=tol)
  self.assertFalse(np.all(state.sigma == sigma))
def test_npt_nose_hoover_jammed(self, dtype, sy_steps):
  """Checks NPT Nose-Hoover conservation and pressure control on a jammed state.

  Runs `simulate.npt_nose_hoover` targeting the stored pressure of the jammed
  configuration. The parameterized `sy_steps` is passed in both the barostat
  and thermostat chain kwargs. Asserts that the NPT invariant stays at its
  initial value and that the measured pressure stays close to the target.
  """
  key = random.PRNGKey(0)
  state = test_util.load_jammed_state('simulation_test_state.npy', dtype)
  displacement_fn, shift_fn = space.periodic_general(state.box)

  E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
  invariant = partial(simulate.npt_nose_hoover_invariant, E)
  pressure_fn = partial(quantity.pressure, E)

  # Key by the option *name*. `{sy_steps: sy_steps}` would key by the integer
  # value (e.g. {3: 3}), so the chains would not receive the requested
  # `sy_steps`.
  nhc_kwargs = {'sy_steps': sy_steps}
  kT = 1e-3
  P = state.pressure

  init_fn, apply_fn = simulate.npt_nose_hoover(E, shift_fn, 1e-3, P, kT,
                                               nhc_kwargs, nhc_kwargs)
  apply_fn = jit(apply_fn)
  state = init_fn(key, state.fractional_position, state.box)

  E_initial = invariant(state, P, kT) * np.ones((DYNAMICS_STEPS,))
  P_target = P * np.ones((DYNAMICS_STEPS,))

  def step_fn(i, state_energy_pressure):
    # Named `energies` so the `energy` module is not shadowed.
    state, energies, pressures = state_energy_pressure
    state = apply_fn(state)
    energies = energies.at[i].set(invariant(state, P, kT))

    # Pressure measured in the current box, including the kinetic term.
    box = simulate.npt_box(state)
    KE = quantity.kinetic_energy(state.velocity, state.mass)
    pressures = pressures.at[i].set(pressure_fn(state.position, box, KE))
    return state, energies, pressures

  Es = np.zeros((DYNAMICS_STEPS,))
  Ps = np.zeros((DYNAMICS_STEPS,))
  state, Es, Ps = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es, Ps))

  tol = 1e-3 if dtype is f32 else 1e-7
  self.assertEqual(state.position.dtype, dtype)
  self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
  self.assertAllClose(Ps, P_target, rtol=0.05, atol=0.05)