def test_lennard_jones_small_neighbor_list_energy(
    self, spatial_dimension, dtype, format):
  """Neighbor-list LJ energy matches the exact pair energy on a tiny system.

  Ten particles are drawn uniformly at random in a periodic box of side 5.
  The energy from ``energy.lennard_jones_neighbor_list`` (built with the
  parameterized neighbor-list ``format``) must match the energy from
  ``energy.lennard_jones_pair``, cast to ``dtype``.

  Args:
    spatial_dimension: Dimensionality of the particle positions.
    dtype: Floating-point dtype used for the positions and the comparison.
    format: Neighbor-list format forwarded to the neighbor-list builder.
  """
  key = random.PRNGKey(1)
  box_size = f32(5.0)
  displacement, _ = space.periodic(box_size)
  exact_energy_fn = energy.lennard_jones_pair(displacement)

  R = box_size * random.uniform(key, (10, spatial_dimension), dtype=dtype)
  neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
      displacement, box_size, format=format)
  nbrs = neighbor_fn.allocate(R)

  self.assertAllClose(
      np.array(exact_energy_fn(R), dtype=dtype), energy_fn(R, nbrs))
def test_lennard_jones_neighbor_list_force(self, spatial_dimension, dtype):
  """Neighbor-list LJ forces match forces from the exact pair energy.

  ``PARTICLE_COUNT`` particles are drawn uniformly at random in a periodic
  box of side 15. The forces from the neighbor-list energy must match the
  forces from ``energy.lennard_jones_pair``, with dtypes checked.

  NOTE(review): another test method with this exact name (taking an extra
  ``format`` parameter) appears in this file. If both are defined in the same
  class, the later definition silently replaces this one; confirm which is
  intended.

  Args:
    spatial_dimension: Dimensionality of the particle positions.
    dtype: Floating-point dtype used for the positions and the comparison.
  """
  key = random.PRNGKey(1)
  box_size = f32(15.0)
  displacement, _ = space.periodic(box_size)
  exact_force_fn = quantity.force(energy.lennard_jones_pair(displacement))

  R = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
  # NOTE(review): R is passed positionally as the third argument; confirm this
  # matches the lennard_jones_neighbor_list signature of the jax_md version
  # this test targets.
  neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
      displacement, box_size, R)
  force_fn = quantity.force(energy_fn)
  idx = neighbor_fn(R)

  # check_dtypes is passed by keyword: newer assertAllClose implementations
  # make it keyword-only, so the old positional form raises TypeError there.
  self.assertAllClose(
      np.array(exact_force_fn(R), dtype=dtype), force_fn(R, idx),
      check_dtypes=True)
def test_nve_neighbor_list(self, spatial_dimension, dtype):
  """NVE with a neighbor-list LJ energy tracks NVE with the exact pair energy.

  Particles start on a regular lattice (8 per side, spacing 1.25) in a
  periodic box. Two NVE integrators (time step 1e-3) are advanced side by
  side: one uses the neighbor-list energy, the other uses the exact pair
  energy. They run in 20 chunks of 100 steps each.

  If a chunk ends with ``nbrs.did_buffer_overflow`` set, the chunk's results
  are discarded for both trajectories and the neighbor list is rebuilt from
  the last accepted positions. The loop then moves on to the next chunk. At
  the end, the two trajectories' positions must agree to a dtype-dependent
  tolerance.

  Args:
    spatial_dimension: Dimensionality of the lattice and positions.
    dtype: Floating-point dtype of the positions; also selects the tolerance.
  """
  Nx = 8  # particles per side of the lattice
  spacing = f32(1.25)
  tol = 5e-12 if dtype == np.float64 else 5e-3
  L = Nx * spacing

  # Build the lattice for any dimensionality. The old 2D/3D-only branches left
  # R unbound for any other spatial_dimension.
  lattice_shape = (Nx,) * spatial_dimension
  R = np.stack([np.array(r) for r in onp.ndindex(*lattice_shape)]) * spacing
  R = np.array(R, dtype)

  displacement, shift = space.periodic(L)

  neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(displacement, L)
  exact_energy_fn = energy.lennard_jones_pair(displacement)

  init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
  exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift, 1e-3)

  nbrs = neighbor_fn(R)
  state = init_fn(random.PRNGKey(0), R, neighbor=nbrs)
  exact_state = exact_init_fn(random.PRNGKey(0), R)

  def body_fn(i, carry):
    # One step of both integrators; the neighbor list is refreshed from the
    # current positions before the neighbor-list integrator steps.
    nl_state, nl_nbrs, ref_state = carry
    nl_nbrs = neighbor_fn(nl_state.position, nl_nbrs)
    nl_state = apply_fn(nl_state, neighbor=nl_nbrs)
    return nl_state, nl_nbrs, exact_apply_fn(ref_state)

  for _ in range(20):
    new_state, nbrs, new_exact_state = lax.fori_loop(
        0, 100, body_fn, (state, nbrs, exact_state))
    if nbrs.did_buffer_overflow:
      # This chunk is invalid: keep the previous states of both trajectories
      # and rebuild the neighbor list from the last accepted positions.
      nbrs = neighbor_fn(state.position)
    else:
      state = new_state
      exact_state = new_exact_state

  assert state.position.dtype == dtype
  self.assertAllClose(state.position, exact_state.position, atol=tol, rtol=tol)
def test_lennard_jones_neighbor_list_force(
    self, spatial_dimension, dtype, format):
  """Neighbor-list LJ forces match forces from the exact pair energy.

  ``PARTICLE_COUNT`` particles are drawn uniformly at random in a periodic
  box of side 15. The forces from the neighbor-list energy (built with the
  parameterized ``format``) must match the forces from
  ``energy.lennard_jones_pair``, cast to ``dtype``.

  NOTE(review): another test method with this exact name (without the
  ``format`` parameter) appears in this file. If both are defined in the same
  class, only the later definition runs; confirm which is intended.

  Args:
    spatial_dimension: Dimensionality of the particle positions.
    dtype: Floating-point dtype used for the positions and the comparison.
    format: Neighbor-list format forwarded to the neighbor-list builder.
  """
  key = random.PRNGKey(1)
  box_size = f32(15.0)
  displacement, _ = space.periodic(box_size)
  exact_force_fn = quantity.force(energy.lennard_jones_pair(displacement))

  r = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
  neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
      displacement, box_size, format=format)
  force_fn = quantity.force(energy_fn)
  nbrs = neighbor_fn.allocate(r)

  # float32 combined with the OrderedSparse format is compared with a looser
  # tolerance; every other combination uses assertAllClose's defaults.
  if dtype == f32 and format is partition.OrderedSparse:
    tolerance = {'atol': 5e-5, 'rtol': 5e-5}
  else:
    tolerance = {}

  self.assertAllClose(
      np.array(exact_force_fn(r), dtype=dtype), force_fn(r, nbrs),
      **tolerance)
default=0.8442, type=float) args = parser.parse_args() edge_length = pow(args.parts / args.dense, 1.0 / 3.0) # edge_length*=2 spatial_dimension = 3 box_size = onp.asarray([edge_length] * spatial_dimension) displacement_fn, shift_fn = space.periodic(box_size) key = random.PRNGKey(0) R = random.uniform(key, (args.parts, spatial_dimension), minval=0.0, maxval=box_size[0], dtype=np.float64) # print(R) energy_fn = energy.lennard_jones_pair(displacement_fn) print('E = {}'.format(energy_fn(R))) force_fn = quantity.force(energy_fn) print('Total Squared Force = {}'.format(np.sum(force_fn(R)**2))) init, apply = simulate.nve(energy_fn, shift_fn, args.time / args.steps) apply = jit(apply) state = init(key, R, velocity_scale=0.0) PE = [] KE = [] print_every = args.log old_time = time.perf_counter() print('Step\tKE\tPE\tTotal Energy\ttime/step') print('----------------------------------------') for i in range(args.steps // print_every):