def test_nve_jammed_periodic_general(self, dtype, coords):
  """Check that NVE conserves total energy on the jammed test state.

  The jammed state is placed in a `space.periodic_general` box, using
  fractional coordinates when `coords == 'fractional'`. It is integrated with
  `simulate.nve` for `DYNAMICS_STEPS` steps. The total energy (potential plus
  kinetic) recorded after every step must match the initial total energy
  within a dtype-dependent tolerance, and positions must keep `dtype`.
  """
  key = random.PRNGKey(0)

  state = test_util.load_test_state('simulation_test_state.npy', dtype)
  displacement_fn, shift_fn = space.periodic_general(
      state.box, coords == 'fractional')

  E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)

  init_fn, apply_fn = simulate.nve(E, shift_fn, 1e-3)
  apply_fn = jit(apply_fn)

  # Pick `fractional_position` or `real_position` to match the space above.
  state = init_fn(key, getattr(state, coords + '_position'), kT=1e-3)

  def E_T(state):
    # Total energy: potential plus kinetic.
    return E(state.position) + quantity.kinetic_energy(state.velocity,
                                                       state.mass)

  E_initial = E_T(state) * np.ones((DYNAMICS_STEPS, ))

  def step_fn(i, state_and_energies):
    state, energies = state_and_energies
    state = apply_fn(state)
    # `jax.ops.index_update` no longer exists in JAX; use the `.at[]` API.
    energies = energies.at[i].set(E_T(state))
    return state, energies

  Es = np.zeros((DYNAMICS_STEPS, ))
  state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

  tol = 1e-3 if dtype is f32 else 1e-7
  self.assertEqual(state.position.dtype, dtype)
  self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
def test_nvt_nose_hoover_jammed(self, dtype, sy_steps):
  """Check that the NVT Nose-Hoover invariant is conserved on the jammed state.

  The jammed state is placed in a cubic `space.periodic` box of side
  `state.box[0, 0]`. It is integrated with `simulate.nvt_nose_hoover` at
  `kT = 1e-3` using the parameterized `sy_steps`. The value of
  `simulate.nvt_nose_hoover_invariant` after every step must match its
  initial value within a dtype-dependent tolerance.
  """
  key = random.PRNGKey(0)

  state = test_util.load_test_state('simulation_test_state.npy', dtype)
  displacement_fn, shift_fn = space.periodic(state.box[0, 0])

  E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
  invariant = partial(simulate.nvt_nose_hoover_invariant, E)

  kT = 1e-3
  init_fn, apply_fn = simulate.nvt_nose_hoover(E, shift_fn, 1e-3,
                                               kT=kT, sy_steps=sy_steps)
  apply_fn = jit(apply_fn)

  state = init_fn(key, state.real_position)

  E_initial = invariant(state, kT) * np.ones((DYNAMICS_STEPS, ))

  def step_fn(i, state_and_energies):
    state, energies = state_and_energies
    state = apply_fn(state)
    # `jax.ops.index_update` no longer exists in JAX; use the `.at[]` API.
    energies = energies.at[i].set(invariant(state, kT))
    return state, energies

  Es = np.zeros((DYNAMICS_STEPS, ))
  state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

  tol = 1e-3 if dtype is f32 else 1e-7
  self.assertEqual(state.position.dtype, dtype)
  self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
def test_pressure_non_minimized_free(self, dim, dtype):
  """Compare `quantity.pressure` with an explicit pair-virial sum (free space).

  Random, non-minimized positions of 64 soft spheres are used. The reference
  pressure is the trace of the explicitly summed pair stress divided by
  `dim`.
  """
  N = 64

  box = quantity.box_size_at_number_density(N, 0.8, dim)
  displacement_fn, _ = space.free()

  pos = random.uniform(random.PRNGKey(0), (N, dim)) * box
  energy_fn = energy.soft_sphere_pair(displacement_fn)

  def pair_stress(R):
    # Pairwise displacement vectors and their lengths over all (i, j).
    dR = space.map_product(displacement_fn)(R, R)
    dr = space.distance(dR)
    dU = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
    volume = quantity.volume(dim, box)
    # Every pair appears twice in the all-pairs sum, hence the factor 1/2.
    half_dU = 0.5 * dU(dr)[:, :, None, None]
    # Adding the identity keeps the i == j terms from dividing by zero.
    safe_dr = (dr + jnp.eye(N))[:, :, None, None]
    summand = (half_dU * dR[:, :, None, :] * dR[:, :, :, None]
               / (volume * safe_dr))
    return -jnp.sum(summand, axis=(0, 1))

  exact_pressure = 1 / dim * jnp.trace(pair_stress(pos))
  ad_pressure = quantity.pressure(energy_fn, pos, box)

  tol = 2e-5 if dtype is not f64 else 1e-7
  self.assertAllClose(exact_pressure, ad_pressure, atol=tol, rtol=tol)
def test_nve_jammed(self, spatial_dimension, dtype):
  """Check that NVE conserves total energy on the jammed test state.

  The state is loaded with `test_util.load_jammed_state` and placed in a cubic
  `space.periodic` box of side `state.box[0, 0]`. The total energy is recorded
  after every one of `DYNAMICS_STEPS` steps and compared with its initial
  value.
  """
  jammed = test_util.load_jammed_state('simulation_test_state.npy', dtype)
  displacement_fn, shift_fn = space.periodic(jammed.box[0, 0])
  E = energy.soft_sphere_pair(displacement_fn, jammed.species, jammed.sigma)

  init_fn, step = simulate.nve(E, shift_fn, 1e-3)
  step = jit(step)
  sim_state = init_fn(random.PRNGKey(0), jammed.real_position, kT=1e-3)

  def total_energy(s):
    # Potential energy plus kinetic energy.
    return E(s.position) + quantity.kinetic_energy(s.velocity, s.mass)

  E_initial = total_energy(sim_state) * np.ones((DYNAMICS_STEPS, ))

  def body(i, carry):
    s, history = carry
    s = step(s)
    return s, history.at[i].set(total_energy(s))

  sim_state, Es = lax.fori_loop(0, DYNAMICS_STEPS, body,
                                (sim_state, np.zeros((DYNAMICS_STEPS, ))))

  tol = 1e-7 if dtype is not f32 else 1e-3
  self.assertEqual(sim_state.position.dtype, dtype)
  self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
def test_stress_non_minimized_periodic_general(self, dim, dtype, coords):
  """Compare `quantity.stress` with an explicit pair-virial sum.

  The system uses a `space.periodic_general` box in fractional or real
  coordinates. Positions are uniform in the unit cube and are scaled by `box`
  only in real coordinates.
  """
  N = 64
  box = quantity.box_size_at_number_density(N, 0.8, dim)
  fractional = coords == 'fractional'
  displacement_fn, _ = space.periodic_general(box, fractional)

  unit_pos = random.uniform(random.PRNGKey(0), (N, dim))
  if fractional:
    pos = unit_pos
  else:
    pos = unit_pos * box

  energy_fn = energy.soft_sphere_pair(displacement_fn)

  def virial_stress(R):
    # Pairwise displacement vectors and lengths over all (i, j).
    dR = space.map_product(displacement_fn)(R, R)
    dr = space.distance(dR)
    dU = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
    volume = quantity.volume(dim, box)
    # Factor 1/2 because each pair is counted twice.
    half_dU = 0.5 * dU(dr)[:, :, None, None]
    # The identity shift avoids 0/0 on the i == j diagonal.
    safe_dr = (dr + jnp.eye(N))[:, :, None, None]
    summand = (half_dU * dR[:, :, None, :] * dR[:, :, :, None]
               / (volume * safe_dr))
    return -jnp.sum(summand, axis=(0, 1))

  expected = virial_stress(pos)
  ad_stress = quantity.stress(energy_fn, pos, box)

  tol = 2e-5 if dtype is not f64 else 1e-7
  self.assertAllClose(expected, ad_stress, atol=tol, rtol=tol)
def test_nve_ensemble_time_dependence(self, spatial_dimension, dtype):
  """NVE steps accept a time keyword `t` and keep energy drift under 1%.

  Random positions and per-particle masses in free space are integrated for
  `SHORT_DYNAMICS_STEPS` steps. At every step the absolute change in total
  energy must stay below 1% of the initial total energy.
  """
  pos_key, _, vel_key, mass_key = random.split(random.PRNGKey(0), 4)
  shape = (PARTICLE_COUNT, spatial_dimension)

  R = random.normal(pos_key, shape, dtype=dtype)
  mass = random.uniform(mass_key, (PARTICLE_COUNT, ),
                        minval=0.1, maxval=5.0, dtype=dtype)

  displacement, shift = space.free()
  E = energy.soft_sphere_pair(displacement)

  dt = 1e-3
  init_fn, step = simulate.nve(E, shift, dt)
  step = jit(step)
  state = init_fn(vel_key, R, mass=mass)

  def total_energy(s):
    # Potential energy plus kinetic energy.
    return E(s.position) + quantity.kinetic_energy(s.velocity, s.mass)

  E_initial = total_energy(state)
  for n in range(SHORT_DYNAMICS_STEPS):
    state = step(state, t=n * dt)
    drift = np.abs(total_energy(state) - E_initial)
    assert drift < E_initial * 0.01
  assert state.position.dtype == dtype
def test_pressure_jammed(self, dtype, coords):
  """Pressure of the jammed test state matches the pressure stored with it.

  The state is placed in a `space.periodic_general` box in the coordinate
  system selected by `coords`.
  """
  jammed = test_util.load_test_state('simulation_test_state.npy', dtype)
  fractional = coords == 'fractional'
  displacement_fn, _ = space.periodic_general(jammed.box, fractional)
  E = energy.soft_sphere_pair(displacement_fn, jammed.species, jammed.sigma)

  positions = getattr(jammed, coords + '_position')
  computed = quantity.pressure(E, positions, jammed.box)
  self.assertAllClose(computed, jammed.pressure)
def run(N=32, n_iter=1000, with_jit=True):
  """Time individual Brownian-dynamics steps of a 2D soft-sphere system.

  `N` particles are placed uniformly at random in the unit square in free
  space, interact through `energy.soft_sphere_pair`, and are advanced with
  `simulate.brownian` (dt = 0.1, temperature = 0.1).

  Args:
    N: number of particles.
    n_iter: number of steps to time.
    with_jit: if True, wrap the step function in `jax.jit`. The first timed
      step then also includes compilation.

  Returns:
    A list of `n_iter` wall-clock durations in nanoseconds, one per step.
  """
  import time

  import jax.numpy as jnp
  from jax import random, jit
  from jax_md import space, energy, simulate

  # MD configuration.
  dt = 1e-1
  temperature = 0.1

  # Free (non-periodic) space: displacement is Ra - Rb and shift is R + dR.
  displacement, shift = space.free()

  # Pairwise soft-sphere energy built on that displacement.
  energy_fn = energy.soft_sphere_pair(displacement)

  pos_key, sim_key = random.split(random.PRNGKey(0))
  R = random.uniform(pos_key, (N, 2), dtype=jnp.float32)

  init_fn, apply_fn = simulate.brownian(energy_fn, shift, dt, temperature)
  if with_jit:
    apply_fn = jit(apply_fn)

  state = init_fn(sim_key, R)

  times = []
  for _ in range(n_iter):
    time_start = time.perf_counter_ns()
    state = apply_fn(state)
    # JAX dispatches work asynchronously. Block until the new positions are
    # materialized so the interval measures the step, not only its dispatch.
    state.position.block_until_ready()
    time_end = time.perf_counter_ns()
    times.append(time_end - time_start)

  return times
def test_soft_sphere_cell_list_energy(self, spatial_dimension, dtype):
  """`energy.soft_sphere_cell_list` matches the dense `soft_sphere_pair` energy.

  Particles are placed uniformly in a periodic box of side 15.
  """
  box_size = f32(15.0)
  displacement, _ = space.periodic(box_size)
  reference_fn = energy.soft_sphere_pair(displacement)

  shape = (PARTICLE_COUNT, spatial_dimension)
  positions = box_size * random.uniform(random.PRNGKey(1), shape, dtype=dtype)

  cell_list_fn = energy.soft_sphere_cell_list(displacement, box_size,
                                              positions)

  expected = np.array(reference_fn(positions), dtype=dtype)
  self.assertAllClose(expected, cell_list_fn(positions), True)
def test_pressure_jammed_periodic(self, dtype):
  """Pressure of the jammed state in a `space.periodic` box matches the stored value.

  The diagonal of the stored box matrix is passed to `space.periodic`, and the
  full box is passed to `quantity.pressure`.
  """
  jammed = test_util.load_jammed_state('simulation_test_state.npy', dtype)
  displacement_fn, _ = space.periodic(jnp.diag(jammed.box))
  E = energy.soft_sphere_pair(displacement_fn, jammed.species, jammed.sigma)

  tol = 2e-5 if dtype is not f64 else 1e-7
  computed = quantity.pressure(E, jammed.real_position, jammed.box)
  self.assertAllClose(computed, jammed.pressure, atol=tol, rtol=tol)
def test_pair_grid_force_incommensurate(self, spatial_dimension, dtype):
  """`smap.grid` forces equal direct forces when 12.1 / 3.0 is not an integer.

  The box side (12.1) is not a whole multiple of the cell size (3.0).
  """
  box_size = f32(12.1)
  cell_size = f32(3.0)
  displacement, _ = space.periodic(box_size)

  energy_fn = energy.soft_sphere_pair(displacement, quantity.Dynamic)
  force_fn = quantity.force(energy_fn)

  shape = (PARTICLE_COUNT, spatial_dimension)
  R = box_size * random.uniform(random.PRNGKey(1), shape, dtype=dtype)

  grid_force_fn = jit(smap.grid(force_fn, box_size, cell_size, R))

  # Every particle is species 0. The `species` array and the trailing `1` are
  # forwarded to `force_fn`; see `soft_sphere_pair` for their meaning.
  species = np.zeros((PARTICLE_COUNT, ), dtype=np.int64)
  direct = np.array(force_fn(R, species, 1), dtype=dtype)
  self.assertAllClose(direct, grid_force_fn(R), True)
def make_periodic_general_test_system(N, dim, dtype, box_format):
  """Build one random system expressed in three periodic spaces.

  The box is a scalar, a per-axis vector, or a diagonal matrix depending on
  `box_format`. Its side is chosen for number density 1.

  Returns:
    A tuple `(R_f, R, box, (s, E), (s_gf, E_gf), (s_g, E_g)`:
    - `R_f`: uniform random positions in the unit cube.
    - `R`: `space.transform(box, R_f)`.
    - `box`: the box in the requested format.
    - Three (shift, jitted soft-sphere energy) pairs, built from
      `space.periodic`, `space.periodic_general(box)` (presumably fractional
      coordinates by default — confirm), and
      `space.periodic_general(box, fractional_coordinates=False)`.
  """
  assert box_format in BOX_FORMATS

  side = quantity.box_size_at_number_density(N, 1.0, dim)
  if box_format == 'vector':
    box = jnp.array(jnp.ones(dim) * side, dtype)
  elif box_format == 'matrix':
    box = jnp.array(jnp.eye(dim) * side, dtype)
  else:
    box = dtype(side)

  # `space.periodic` is given per-axis sides, so take the matrix diagonal.
  periodic_box = jnp.diag(box) if box_format == 'matrix' else box
  d, s = space.periodic(periodic_box)
  d_gf, s_gf = space.periodic_general(box)
  d_g, s_g = space.periodic_general(box, fractional_coordinates=False)

  R_f = random.uniform(random.PRNGKey(0), (N, dim), dtype=dtype)
  R = space.transform(box, R_f)

  E, E_gf, E_g = (jit(energy.soft_sphere_pair(disp))
                  for disp in (d, d_gf, d_g))

  return R_f, R, box, (s, E), (s_gf, E_gf), (s_g, E_g)
def main(unused_argv):
  """Minimize a 50:50 bidisperse soft-sphere mixture with FIRE descent.

  500 particles are placed at random in a periodic 2D box of side 25. The
  system is minimized for 50 steps. Every 10 steps the step number, the total
  energy, and the largest per-particle force magnitude are printed.
  """
  key = random.PRNGKey(0)

  # Setup some variables describing the system.
  N = 500
  dimension = 2
  box_size = f32(25.0)

  # Create helper functions to define a periodic box of some size.
  displacement, shift = space.periodic(box_size)

  # Use JAX's random number generator to generate random initial positions.
  key, split = random.split(key)
  R = random.uniform(split, (N, dimension), minval=0.0, maxval=box_size,
                     dtype=f32)

  # A 50:50 mixture of two particle types; `sigma[a, b]` is the interaction
  # diameter between species a and b.
  sigma = np.array([[1.0, 1.2], [1.2, 1.4]], dtype=f32)
  N_2 = N // 2
  # Give any remainder to species 1 so there is exactly one label per
  # particle even when N is odd.
  species = np.array([0] * N_2 + [1] * (N - N_2), dtype=i32)

  # Create an energy function.
  energy_fn = energy.soft_sphere_pair(displacement, species, sigma)
  force_fn = quantity.force(energy_fn)

  # Create a minimizer.
  init_fn, apply_fn = minimize.fire_descent(energy_fn, shift)
  opt_state = init_fn(R)

  # Minimize the system.
  minimize_steps = 50
  print_every = 10

  print('Minimizing.')
  print('Step\tEnergy\tMax Force')
  print('-----------------------------------')
  for step in range(minimize_steps):
    opt_state = apply_fn(opt_state)

    if step % print_every == 0:
      R = opt_state.position
      # Largest per-particle force magnitude. The maximum signed component
      # would ignore large forces pointing in negative directions.
      max_force = np.max(np.linalg.norm(force_fn(R), axis=1))
      print('{}\t{:.2f}\t{:.2f}'.format(step, energy_fn(R), max_force))
def test_cell_list_incommensurate(self, spatial_dimension, dtype):
  """Jitted cell-list energy matches dense energy for a 12.1 box with 3.0 cells.

  The box side is not a whole multiple of the cell size.
  """
  box_size = f32(12.1)
  cell_size = f32(3.0)
  displacement, _ = space.periodic(box_size)
  dense_energy_fn = energy.soft_sphere_pair(displacement)

  shape = (PARTICLE_COUNT, spatial_dimension)
  R = box_size * random.uniform(random.PRNGKey(1), shape, dtype=dtype)

  pair_energy_fn = smap.cartesian_product(energy.soft_sphere,
                                          space.metric(displacement))
  cell_energy_fn = jit(
      smap.cell_list(pair_energy_fn, box_size, cell_size, R))

  expected = np.array(dense_energy_fn(R), dtype=dtype)
  self.assertAllClose(expected, cell_energy_fn(R), True)
def test_pair_cell_list_energy(self, spatial_dimension, dtype):
  """Cell-list energy (box 9, cell 1) matches the dense pairwise energy."""
  box_size = f32(9.0)
  cell_size = f32(1.0)
  displacement, _ = space.periodic(box_size)

  dense_energy_fn = energy.soft_sphere_pair(displacement)
  pair_energy_fn = smap.cartesian_product(energy.soft_sphere,
                                          space.metric(displacement))

  shape = (PARTICLE_COUNT, spatial_dimension)
  R = box_size * random.uniform(random.PRNGKey(1), shape, dtype=dtype)
  cell_energy_fn = smap.cell_list(pair_energy_fn, box_size, cell_size, R)

  expected = np.array(dense_energy_fn(R), dtype=dtype)
  self.assertAllClose(expected, cell_energy_fn(R), True)
def test_cell_list_direct_force_jit(self, spatial_dimension, dtype):
  """Jitted cell-list forces (box 9, cell 1) match the dense pairwise forces."""
  box_size = f32(9.0)
  cell_size = f32(1.0)
  displacement, _ = space.periodic(box_size)
  dense_force_fn = quantity.force(energy.soft_sphere_pair(displacement))

  shape = (PARTICLE_COUNT, spatial_dimension)
  R = box_size * random.uniform(random.PRNGKey(1), shape, dtype=dtype)

  pair_energy_fn = smap.cartesian_product(energy.soft_sphere,
                                          space.metric(displacement))
  cell_force_fn = jit(smap.cell_list(quantity.force(pair_energy_fn),
                                     box_size, cell_size, R))

  expected = np.array(dense_force_fn(R), dtype=dtype)
  self.assertAllClose(expected, cell_force_fn(R), True)
def test_soft_sphere_neighbor_list_energy(self, spatial_dimension, dtype,
                                          format):
  """Neighbor-list soft-sphere energy matches the dense pairwise energy.

  `format` is the neighbor-list format under test and is forwarded unchanged
  to `energy.soft_sphere_neighbor_list`.
  """
  box_size = f32(15.0)
  displacement, _ = space.periodic(box_size)
  dense_energy_fn = energy.soft_sphere_pair(displacement)

  shape = (PARTICLE_COUNT, spatial_dimension)
  positions = box_size * random.uniform(random.PRNGKey(1), shape, dtype=dtype)

  neighbor_fn, nbr_energy_fn = energy.soft_sphere_neighbor_list(
      displacement, box_size, format=format)
  neighbors = neighbor_fn.allocate(positions)

  expected = np.array(dense_energy_fn(positions), dtype=dtype)
  self.assertAllClose(expected, nbr_energy_fn(positions, neighbors))
def test_pair_grid_force_nonuniform(self, spatial_dimension, dtype):
  """`smap.grid` forces match direct forces in a box with unequal sides.

  The sides are (8, 10) in 2D and (8, 10, 12) otherwise. The box is stored
  with shape (1, dim) so it broadcasts against the (N, dim) positions.
  """
  sides = [8.0, 10.0] if spatial_dimension == 2 else [8.0, 10.0, 12.0]
  box_size = f32(np.array([sides]))
  cell_size = f32(2.0)

  displacement, _ = space.periodic(box_size[0])
  force_fn = quantity.force(
      energy.soft_sphere_pair(displacement, quantity.Dynamic))

  shape = (PARTICLE_COUNT, spatial_dimension)
  R = box_size * random.uniform(random.PRNGKey(1), shape, dtype=dtype)
  grid_force_fn = smap.grid(force_fn, box_size, cell_size, R)

  # All particles are species 0. `species` and the trailing `1` are forwarded
  # to `force_fn`.
  species = np.zeros((PARTICLE_COUNT, ), dtype=np.int64)
  direct = np.array(force_fn(R, species, 1), dtype=dtype)
  self.assertAllClose(direct, grid_force_fn(R), True)
def test_EMT_from_db_dynamic(self, spatial_dimension, dtype, low_pressure):
  """Athermal elastic moduli with `sigma` passed at call time match stored data.

  For each of `NUM_SAMPLES` stored configurations:
  - the energy gradient must be below a dtype-dependent threshold;
  - the moduli tensor must have the input dtype and rank-4 shape;
  - when the computation reports convergence, the tensor must match the
    reference elements and be symmetric under (i<->j), (k<->l), and
    (ij<->lk).
  """
  if dtype == jnp.float32:
    max_grad_thresh, atol, rtol = 1e-5, 1e-4, 1e-3
  else:
    max_grad_thresh, atol, rtol = 1e-10, 1e-8, 1e-5

  moduli_shape = (spatial_dimension, ) * 4

  for index in range(NUM_SAMPLES):
    cijkl, R, sigma, box = test_util.load_elasticity_test_data(
        spatial_dimension, low_pressure, dtype, index)
    # Map the stored positions through `space.transform`, and use the [0, 0]
    # box entry as the side for `space.periodic`.
    R = space.transform(box, R)
    box = box[0, 0]
    displacement, _ = space.periodic(box)

    # The sigma bound here (1.0) is deliberately wrong; the correct value is
    # supplied as a keyword argument on every call below.
    energy_fn = energy.soft_sphere_pair(displacement, sigma=1.0)
    max_grad = jnp.max(jnp.abs(grad(energy_fn)(R, sigma=sigma)))
    assert max_grad < max_grad_thresh

    moduli_fn = jit(
        elasticity.athermal_moduli(energy_fn, check_convergence=True))
    C, converged = moduli_fn(R, box, sigma=sigma)

    assert C.dtype == dtype
    assert C.shape == moduli_shape
    if converged:
      self.assertAllClose(cijkl, elasticity._extract_elements(C, False),
                          atol=atol, rtol=rtol)

      # Make sure the symmetries are there.
      for perm in ("ijkl->jikl", "ijkl->ijlk", "ijkl->lkij"):
        self.assertAllClose(C, jnp.einsum(perm, C))
def test_npt_nose_hoover_jammed(self, dtype, sy_steps):
  """NPT Nose-Hoover conserves its invariant and holds the target pressure.

  The jammed state is run in a `space.periodic_general` box at `kT = 1e-3`,
  with the target pressure set to the state's stored pressure. After every
  one of `DYNAMICS_STEPS` steps the invariant is recorded, along with the
  pressure evaluated at the current box including the kinetic contribution.
  The invariant must stay at its initial value within a dtype-dependent
  tolerance, and the pressure must stay within 0.05 of the target.
  """
  key = random.PRNGKey(0)

  state = test_util.load_test_state('simulation_test_state.npy', dtype)
  displacement_fn, shift_fn = space.periodic_general(state.box)

  E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma)
  invariant = partial(simulate.npt_nose_hoover_invariant, E)
  pressure_fn = partial(quantity.pressure, E)

  # Chain options must be keyed by the parameter *name*. The form
  # `{sy_steps: sy_steps}` would use the integer value itself as the key,
  # which is not a valid keyword argument.
  nhc_kwargs = {'sy_steps': sy_steps}
  kT = 1e-3
  P = state.pressure

  init_fn, apply_fn = simulate.npt_nose_hoover(E, shift_fn, 1e-3, P, kT,
                                               nhc_kwargs, dict(nhc_kwargs))
  apply_fn = jit(apply_fn)
  state = init_fn(key, state.fractional_position, state.box)

  E_initial = invariant(state, P, kT) * np.ones((DYNAMICS_STEPS, ))
  P_target = P * np.ones((DYNAMICS_STEPS, ))

  def step_fn(i, state_energies_pressures):
    state, energies, pressures = state_energies_pressures
    state = apply_fn(state)
    # `jax.ops.index_update` no longer exists in JAX; use the `.at[]` API.
    energies = energies.at[i].set(invariant(state, P, kT))

    box = simulate.npt_box(state)
    KE = quantity.kinetic_energy(state.velocity, state.mass)
    pressures = pressures.at[i].set(pressure_fn(state.position, box, KE))
    return state, energies, pressures

  Es = np.zeros((DYNAMICS_STEPS, ))
  Ps = np.zeros((DYNAMICS_STEPS, ))
  state, Es, Ps = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es, Ps))

  tol = 1e-3 if dtype is f32 else 1e-7
  self.assertEqual(state.position.dtype, dtype)
  self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
  self.assertAllClose(Ps, P_target, rtol=0.05, atol=0.05)
def test_EMT_from_db_fraccoord(self, spatial_dimension, dtype, low_pressure):
  """Athermal elastic moduli in fractional coordinates match stored data.

  For each of `NUM_SAMPLES` stored configurations, a
  `space.periodic_general(box, fractional_coordinates=True)` space is used
  with `sigma` bound when the energy is built. Then:
  - the energy gradient must be below a dtype-dependent threshold;
  - the moduli tensor must have the input dtype and rank-4 shape;
  - on convergence, the tensor must match the reference and carry the
    expected index symmetries.
  """
  if dtype == jnp.float32:
    max_grad_thresh, atol, rtol = 1e-5, 1e-4, 1e-3
  else:
    max_grad_thresh, atol, rtol = 1e-10, 1e-8, 1e-5

  moduli_shape = (spatial_dimension, ) * 4

  for index in range(NUM_SAMPLES):
    cijkl, R, sigma, box = test_util.load_elasticity_test_data(
        spatial_dimension, low_pressure, dtype, index)

    displacement, _ = space.periodic_general(box,
                                             fractional_coordinates=True)
    energy_fn = energy.soft_sphere_pair(displacement, sigma=sigma)
    assert jnp.max(jnp.abs(grad(energy_fn)(R))) < max_grad_thresh

    moduli_fn = jit(
        elasticity.athermal_moduli(energy_fn, check_convergence=True))
    C, converged = moduli_fn(R, box)

    assert C.dtype == dtype
    assert C.shape == moduli_shape
    if converged:
      self.assertAllClose(cijkl, elasticity._extract_elements(C, False),
                          atol=atol, rtol=rtol)

      # Make sure the symmetries are there.
      for perm in ("ijkl->jikl", "ijkl->ijlk", "ijkl->lkij"):
        self.assertAllClose(C, jnp.einsum(perm, C))
def test_cell_list_force_nonuniform(self, spatial_dimension, dtype):
  """Cell-list forces match dense forces in a box with unequal sides.

  The sides are (8, 10) in 2D and (8, 10, 12) otherwise, with a cell size of
  2. The box is stored with shape (1, dim) so it broadcasts against the
  (N, dim) positions.
  """
  key = random.PRNGKey(1)

  if spatial_dimension == 2:
    box_size = f32(np.array([[8.0, 10.0]]))
  else:
    box_size = f32(np.array([[8.0, 10.0, 12.0]]))
  cell_size = f32(2.0)
  displacement, _ = space.periodic(box_size[0])
  energy_fn = energy.soft_sphere_pair(displacement)
  force_fn = quantity.force(energy_fn)

  R = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

  cell_energy_fn = smap.cartesian_product(
      energy.soft_sphere, space.metric(displacement))
  cell_force_fn = quantity.force(cell_energy_fn)
  cell_force_fn = smap.cell_list(cell_force_fn, box_size, cell_size, R)

  # Evaluate each force function exactly once for the comparison.
  self.assertAllClose(
      np.array(force_fn(R), dtype=dtype), cell_force_fn(R), True)