def test_nve_neighbor_list(self, spatial_dimension, dtype):
  """NVE with a neighbor-list LJ energy must track NVE with the dense LJ energy.

  Particles start on a square/cubic lattice with spacing 1.25 inside a
  periodic box of side `Nx * spacing`. Two NVE integrators are started from
  identical initial conditions (same PRNG key and positions): one whose
  Lennard-Jones energy uses a neighbor list, and one using the dense pair
  energy. After the run, their positions must agree to a dtype-dependent
  tolerance.

  The run is split into 20 chunks of 100 steps. If the neighbor list reports
  a buffer overflow during a chunk, that chunk is discarded and the list is
  rebuilt from the last accepted positions.
  """
  Nx = 8  # particles per side of the lattice
  spacing = f32(1.25)
  tol = 5e-12 if dtype == np.float64 else 5e-3
  L = Nx * spacing

  # Integer lattice points in [0, Nx)^d scaled by the spacing. For d = 2 and 3
  # this matches enumerating onp.ndindex(Nx, Nx) / onp.ndindex(Nx, Nx, Nx).
  R = np.stack(
      [np.array(r) for r in onp.ndindex(*([Nx] * spatial_dimension))]) * spacing
  R = np.array(R, dtype)

  displacement, shift = space.periodic(L)

  neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(displacement, L)
  exact_energy_fn = energy.lennard_jones_pair(displacement)

  init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
  exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift, 1e-3)

  nbrs = neighbor_fn(R)
  state = init_fn(random.PRNGKey(0), R, neighbor=nbrs)
  exact_state = exact_init_fn(random.PRNGKey(0), R)

  def body_fn(i, carry):
    # Advance both systems by one step, refreshing the neighbor list first.
    state, nbrs, exact_state = carry
    nbrs = neighbor_fn(state.position, nbrs)
    state = apply_fn(state, neighbor=nbrs)
    return state, nbrs, exact_apply_fn(exact_state)

  accepted_chunks = 0
  for _ in range(20):
    new_state, nbrs, new_exact_state = lax.fori_loop(
        0, 100, body_fn, (state, nbrs, exact_state))
    if nbrs.did_buffer_overflow:
      # The chunk may have used an incomplete neighbor list: throw it away and
      # rebuild the list from scratch at the last accepted positions.
      nbrs = neighbor_fn(state.position)
    else:
      state = new_state
      exact_state = new_exact_state
      accepted_chunks += 1

  # Guard against a vacuous pass: if every chunk overflowed, both states would
  # still be the initial conditions and would trivially agree.
  self.assertGreater(accepted_chunks, 0)
  self.assertEqual(state.position.dtype, dtype)
  self.assertAllClose(state.position, exact_state.position, atol=tol, rtol=tol)
def test_nve_ensemble(self, spatial_dimension, dtype):
  """NVE on a harmonic tether potential conserves total energy to within 1%.

  Each particle is tethered to its own random center `R0` by
  `E = sum((R - R0)**2)` in free space, and carries a random mass drawn
  uniformly from [0.1, 5.0). After every step, the total (potential plus
  kinetic) energy must stay within 1% of its initial value.
  """
  key = random.PRNGKey(0)
  pos_key, center_key, vel_key, mass_key = random.split(key, 4)

  R = random.normal(
      pos_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
  R0 = random.normal(
      center_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
  mass = random.uniform(
      mass_key, (PARTICLE_COUNT,), minval=0.1, maxval=5.0, dtype=dtype)
  _, shift = space.free()

  E = lambda R, **kwargs: np.sum((R - R0)**2)

  init_fn, apply_fn = simulate.nve(E, shift, 1e-3)
  apply_fn = jit(apply_fn)

  state = init_fn(vel_key, R, mass=mass)

  def E_T(state):
    # Total energy: potential plus mass-weighted kinetic energy.
    return E(state.position) + quantity.kinetic_energy(
        state.velocity, state.mass)

  E_initial = E_T(state)

  for _ in range(DYNAMICS_STEPS):
    state = apply_fn(state)
    E_total = E_T(state)
    # self.assert* instead of bare `assert`: still enforced under `python -O`
    # and reports the offending values when it fails.
    self.assertLess(np.abs(E_total - E_initial), E_initial * 0.01)
  self.assertEqual(state.position.dtype, dtype)
def test_nve_ensemble_time_dependence(self, spatial_dimension, dtype):
  """NVE driven with an explicit `t` keyword conserves energy to within 1%.

  A free-space soft-sphere system with random masses in [0.1, 5.0) is
  integrated for SHORT_DYNAMICS_STEPS steps. Each step calls `apply_fn` with
  `t=step * dt`, exercising the path that passes a time keyword through the
  integrator. The total energy must stay within 1% of its initial value
  after every step.
  """
  dt = 1e-3

  key = random.PRNGKey(0)
  pos_key, center_key, vel_key, mass_key = random.split(key, 4)

  R = random.normal(
      pos_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
  # The center key is split off to keep the key stream identical to
  # test_nve_ensemble; R0 is not used by the soft-sphere energy.
  R0 = random.normal(
      center_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
  mass = random.uniform(
      mass_key, (PARTICLE_COUNT,), minval=0.1, maxval=5.0, dtype=dtype)
  displacement, shift = space.free()

  E = energy.soft_sphere_pair(displacement)

  init_fn, apply_fn = simulate.nve(E, shift, dt)
  apply_fn = jit(apply_fn)

  state = init_fn(vel_key, R, mass=mass)

  def E_T(state):
    # Total energy: potential plus mass-weighted kinetic energy.
    return E(state.position) + quantity.kinetic_energy(
        state.velocity, state.mass)

  E_initial = E_T(state)

  for t in range(SHORT_DYNAMICS_STEPS):
    state = apply_fn(state, t=t * dt)
    E_total = E_T(state)
    # self.assert* instead of bare `assert`: still enforced under `python -O`
    # and reports the offending values when it fails.
    self.assertLess(np.abs(E_total - E_initial), E_initial * 0.01)
  self.assertEqual(state.position.dtype, dtype)
def test_nve_jammed_periodic_general(self, dtype, coords):
  """NVE from a stored jammed packing in a general periodic box conserves energy.

  The packing is loaded from 'simulation_test_state.npy'. The displacement
  and shift functions come from `space.periodic_general(box, fractional)`,
  where `fractional` is `coords == 'fractional'`. The initial positions are
  the matching `<coords>_position` attribute of the loaded state. The total
  energy is recorded after each of DYNAMICS_STEPS steps and must match the
  initial total energy within a dtype-dependent tolerance.
  """
  key = random.PRNGKey(0)

  state = test_util.load_test_state('simulation_test_state.npy', dtype)
  displacement_fn, shift_fn = space.periodic_general(
      state.box, coords == 'fractional')

  E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)

  init_fn, apply_fn = simulate.nve(E, shift_fn, 1e-3)
  apply_fn = jit(apply_fn)

  # Selects e.g. `fractional_position` or `real_position` on the loaded state.
  state = init_fn(key, getattr(state, coords + '_position'), kT=1e-3)

  def E_T(state):
    # Total energy: potential plus kinetic energy.
    return E(state.position) + quantity.kinetic_energy(
        state.velocity, state.mass)

  E_initial = E_T(state) * np.ones((DYNAMICS_STEPS,))

  def step_fn(i, carry):
    # `energies` (not `energy`) avoids shadowing the `energy` module.
    state, energies = carry
    state = apply_fn(state)
    # Functional indexed update; `jax.ops.index_update` was removed from JAX.
    energies = energies.at[i].set(E_T(state))
    return state, energies

  Es = np.zeros((DYNAMICS_STEPS,))
  state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

  tol = 1e-3 if dtype is f32 else 1e-7
  self.assertEqual(state.position.dtype, dtype)
  self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
def test_nve_jammed(self, spatial_dimension, dtype):
  """NVE started from a stored jammed packing conserves total energy.

  The packing is loaded from 'simulation_test_state.npy' and simulated with
  a soft-sphere pair potential in a periodic box of side `box[0, 0]`. The
  total energy is recorded after each of DYNAMICS_STEPS steps and compared
  with the initial total energy, using a tolerance of 1e-3 for float32 and
  1e-7 otherwise.
  """
  jammed = test_util.load_jammed_state('simulation_test_state.npy', dtype)
  displacement_fn, shift_fn = space.periodic(jammed.box[0, 0])
  pair_energy = energy.soft_sphere_pair(
      displacement_fn, jammed.species, jammed.sigma)

  init_fn, apply_fn = simulate.nve(pair_energy, shift_fn, 1e-3)
  apply_fn = jit(apply_fn)
  state = init_fn(random.PRNGKey(0), jammed.real_position, kT=1e-3)

  def total_energy(s):
    # Potential plus kinetic energy of a simulation state.
    return pair_energy(s.position) + quantity.kinetic_energy(s.velocity, s.mass)

  E_initial = total_energy(state) * np.ones((DYNAMICS_STEPS, ))

  def record_step(i, carry):
    # One integrator step, then store the new total energy at slot i.
    current, trace = carry
    current = apply_fn(current)
    return current, trace.at[i].set(total_energy(current))

  state, energy_trace = lax.fori_loop(
      0, DYNAMICS_STEPS, record_step, (state, np.zeros((DYNAMICS_STEPS, ))))

  tolerance = 1e-7 if dtype is not f32 else 1e-3
  self.assertEqual(state.position.dtype, dtype)
  self.assertAllClose(energy_trace, E_initial, rtol=tolerance, atol=tolerance)
edge_length = pow(args.parts / args.dense, 1.0 / 3.0) # edge_length*=2 spatial_dimension = 3 box_size = onp.asarray([edge_length] * spatial_dimension) displacement_fn, shift_fn = space.periodic(box_size) key = random.PRNGKey(0) R = random.uniform(key, (args.parts, spatial_dimension), minval=0.0, maxval=box_size[0], dtype=np.float64) # print(R) energy_fn = energy.lennard_jones_pair(displacement_fn) print('E = {}'.format(energy_fn(R))) force_fn = quantity.force(energy_fn) print('Total Squared Force = {}'.format(np.sum(force_fn(R)**2))) init, apply = simulate.nve(energy_fn, shift_fn, args.time / args.steps) apply = jit(apply) state = init(key, R, velocity_scale=0.0) PE = [] KE = [] print_every = args.log old_time = time.perf_counter() print('Step\tKE\tPE\tTotal Energy\ttime/step') print('----------------------------------------') for i in range(args.steps // print_every): state = lax.fori_loop(0, print_every, lambda _, state: apply(state), state) PE += [energy_fn(state.position)] KE += [quantity.kinetic_energy(state.velocity)]