def test_lennard_jones_neighbor_list_force(self, spatial_dimension, dtype, format):
    """Neighbor-list Lennard-Jones forces should match all-pairs forces.

    Random positions are drawn in a periodic box of side 15. Forces from
    ``energy.lennard_jones_neighbor_list`` (built with the given neighbor-list
    ``format``) are compared against ``quantity.force`` of
    ``energy.lennard_jones_pair`` evaluated on the same configuration.
    """
    key = random.PRNGKey(1)

    box_size = f32(15.0)
    displacement, _ = space.periodic(box_size)
    exact_force_fn = quantity.force(energy.lennard_jones_pair(displacement))

    r = box_size * random.uniform(
        key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
        displacement, box_size, format=format)
    force_fn = quantity.force(energy_fn)

    nbrs = neighbor_fn.allocate(r)

    # float32 combined with the OrderedSparse format is compared with an
    # explicit, looser tolerance; every other combination uses
    # assertAllClose's default tolerances.
    if dtype == f32 and format is partition.OrderedSparse:
        tolerance = dict(atol=5e-5, rtol=5e-5)
    else:
        tolerance = {}

    self.assertAllClose(
        np.array(exact_force_fn(r), dtype=dtype),
        force_fn(r, nbrs),
        **tolerance)
def test_nve_neighbor_list(self, spatial_dimension, dtype):
    """NVE dynamics with neighbor-list LJ should track all-pairs LJ dynamics.

    Particles start on a square/cubic lattice (``Nx`` per side, spacing
    1.25) in a periodic box of side ``Nx * spacing``. Two NVE integrators
    (dt = 1e-3) are advanced side by side: one uses
    ``energy.lennard_jones_neighbor_list``, the other
    ``energy.lennard_jones_pair``.

    The run is split into up to 20 chunks of 100 steps, each run under
    ``lax.fori_loop``. A chunk whose neighbor list reports
    ``did_buffer_overflow`` is discarded and the neighbor list is rebuilt
    before the next chunk. The final positions must keep ``dtype`` and
    agree to within ``tol``.
    """
    Nx = 8
    spacing = f32(1.25)

    # Tight tolerance in float64, looser otherwise.
    tol = 5e-12 if dtype == np.float64 else 5e-3

    L = Nx * spacing
    # Lattice sites: every integer index tuple in [0, Nx)^spatial_dimension,
    # scaled by the spacing. This yields the same grid the explicit 2D/3D
    # construction produced, and is also defined for other dimensions
    # instead of leaving R unbound.
    R = np.stack(
        [np.array(r) for r in onp.ndindex(*(Nx,) * spatial_dimension)]
    ) * spacing
    R = np.array(R, dtype)

    displacement, shift = space.periodic(L)

    neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
        displacement, L)
    exact_energy_fn = energy.lennard_jones_pair(displacement)

    init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
    exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift, 1e-3)

    nbrs = neighbor_fn(R)
    state = init_fn(random.PRNGKey(0), R, neighbor=nbrs)
    exact_state = exact_init_fn(random.PRNGKey(0), R)

    def body_fn(i, carry):
        # One step of both integrators. The neighbor list is refreshed from
        # the current positions before the neighbor-list step.
        sim_state, sim_nbrs, ref_state = carry
        sim_nbrs = neighbor_fn(sim_state.position, sim_nbrs)
        sim_state = apply_fn(sim_state, neighbor=sim_nbrs)
        return sim_state, sim_nbrs, exact_apply_fn(ref_state)

    for _ in range(20):
        new_state, nbrs, new_exact_state = lax.fori_loop(
            0, 100, body_fn, (state, nbrs, exact_state))
        if nbrs.did_buffer_overflow:
            # The neighbor buffer overflowed during this chunk. Drop the
            # candidate states and rebuild the list from the last accepted
            # positions; the next pass retries the chunk.
            nbrs = neighbor_fn(state.position)
        else:
            state = new_state
            exact_state = new_exact_state

    self.assertEqual(state.position.dtype, dtype)
    self.assertAllClose(state.position, exact_state.position, atol=tol, rtol=tol)
def get_energy_function(self) -> Tuple[NeighborFn, EnergyFn]:
    """Build the Lennard-Jones neighbor-list energy for this system.

    ``self.ro`` and ``self.rc`` are divided by ``self.sigma`` before being
    forwarded as ``r_onset`` and ``r_cutoff``. ``sigma``, ``epsilon`` and both
    radii are cast to ``self.global_dtype``. The call requests
    ``per_particle=True`` and the dense neighbor-list format, and forwards
    ``self.dr_threshold`` unchanged.

    Returns:
        Whatever ``energy.lennard_jones_neighbor_list`` returns, annotated
        here as a ``(NeighborFn, EnergyFn)`` pair.
    """
    def as_global(value):
        # Wrap a scalar as a JAX array in the configured global dtype.
        return jnp.array(value, dtype=self.global_dtype)

    onset_in_sigma = self.ro / self.sigma
    cutoff_in_sigma = self.rc / self.sigma

    return energy.lennard_jones_neighbor_list(
        self.displacement,
        self.box,
        sigma=as_global(self.sigma),
        epsilon=as_global(self.epsilon),
        r_onset=as_global(onset_in_sigma),
        r_cutoff=as_global(cutoff_in_sigma),
        per_particle=True,
        dr_threshold=self.dr_threshold,
        format=NeighborListFormat.Dense,
    )
def test_lennard_jones_small_neighbor_list_energy(self, spatial_dimension, dtype):
    """Neighbor-list LJ energy should match all-pairs LJ energy (10 particles).

    Ten random particles are placed in a periodic box of side 5. The
    energy from ``energy.lennard_jones_neighbor_list`` is compared with
    ``energy.lennard_jones_pair`` using assertAllClose's default
    tolerances.
    """
    key = random.PRNGKey(1)

    box_size = f32(5.0)
    displacement, _ = space.periodic(box_size)
    exact_energy_fn = energy.lennard_jones_pair(displacement)

    R = box_size * random.uniform(key, (10, spatial_dimension), dtype=dtype)
    neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
        displacement, box_size)

    nbrs = neighbor_fn(R)
    self.assertAllClose(
        np.array(exact_energy_fn(R), dtype=dtype),
        energy_fn(R, nbrs))
def test_lennard_jones_cell_neighbor_list_energy(self, spatial_dimension, dtype):
    """Neighbor-list LJ energy should match all-pairs LJ energy.

    ``PARTICLE_COUNT`` random particles are placed in a periodic box of
    side 15. The energy from ``energy.lennard_jones_neighbor_list`` is
    compared with ``energy.lennard_jones_pair``.
    """
    key = random.PRNGKey(1)

    box_size = f32(15)
    displacement, _ = space.periodic(box_size)
    exact_energy_fn = energy.lennard_jones_pair(displacement)

    R = box_size * random.uniform(
        key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    # NOTE(review): R is passed as the third *positional* argument here,
    # whereas the sibling neighbor-list tests call this factory with only
    # (displacement, box_size). Confirm which parameter of
    # lennard_jones_neighbor_list this binds to in the targeted version.
    neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
        displacement, box_size, R)

    nbrs = neighbor_fn(R)
    # NOTE(review): the positional True is presumably `check_dtypes`;
    # confirm against assertAllClose's signature.
    self.assertAllClose(
        np.array(exact_energy_fn(R), dtype=dtype),
        energy_fn(R, nbrs),
        True)
def test_lennard_jones_neighbor_list_force(self, spatial_dimension, dtype):
    """Neighbor-list Lennard-Jones forces should match all-pairs forces.

    ``PARTICLE_COUNT`` random particles are placed in a periodic box of
    side 15. Forces from ``energy.lennard_jones_neighbor_list`` are compared
    against ``quantity.force`` of ``energy.lennard_jones_pair`` using
    assertAllClose's default tolerances.
    """
    # NOTE(review): another method with this exact name (taking an extra
    # `format` parameter) also exists. If both are defined in the same
    # class, the later definition replaces the earlier one and only one of
    # them runs. Confirm they live in different classes or files.
    key = random.PRNGKey(1)

    box_size = f32(15.0)
    displacement, _ = space.periodic(box_size)
    exact_force_fn = quantity.force(energy.lennard_jones_pair(displacement))

    r = box_size * random.uniform(
        key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
        displacement, box_size)
    force_fn = quantity.force(energy_fn)

    nbrs = neighbor_fn(r)
    self.assertAllClose(
        np.array(exact_force_fn(r), dtype=dtype),
        force_fn(r, nbrs))