def test_brownian(self, spatial_dimension, dtype):
    """Check Brownian dynamics against the expected position variance.

    All particles start at the origin in free space with an energy function
    that always returns zero. After BROWNIAN_DYNAMICS_STEPS steps, the
    variance of the final positions is compared with
    ``2 * T / (mass * gamma) * t``; the relative deviation (normalised by
    the measured value) must stay below 1e-2. The positions must also keep
    the requested dtype.
    """
    # NOTE(review): `spatial_dimension` is accepted but never read here;
    # the position array below is always created with 2 columns.
    dt = f32(1e-2)
    gamma = f32(0.1)
    total_time = f32(BROWNIAN_DYNAMICS_STEPS * dt)

    def zero_energy(R, **kwargs):
        return f32(0)

    _, shift_fn = space.free()

    # One key drives the simulation; the other two draw temperature and mass.
    sim_key, temperature_key, mass_key = random.split(random.PRNGKey(0), 3)
    temperature = random.uniform(
        temperature_key, (), minval=0.3, maxval=1.4, dtype=dtype)
    mass = random.uniform(
        mass_key, (), minval=0.1, maxval=10.0, dtype=dtype)
    positions = np.zeros((BROWNIAN_PARTICLE_COUNT, 2), dtype=dtype)

    init_fn, step_fn = simulate.brownian(
        zero_energy, shift_fn, dt, temperature, gamma=gamma)
    step_fn = jit(step_fn)

    state = init_fn(sim_key, positions, mass)
    for _ in range(BROWNIAN_DYNAMICS_STEPS):
        state = step_fn(state)

    measured_msd = np.var(state.position)
    expected_msd = dtype(2 * temperature / (mass * gamma) * total_time)
    assert np.abs(measured_msd - expected_msd) / measured_msd < 1e-2
    assert state.position.dtype == dtype
def run_brownian(
    energy_fn, R_init, shift, key, num_steps, kT, dt=0.00001, gamma=0.1
) -> jnp.ndarray:
    """Run a Brownian simulation and return the final particle positions.

    Args:
        energy_fn: Energy function handed to ``simulate.brownian``.
        R_init: Initial positions passed to the simulation's init function.
        shift: Shift function handed to ``simulate.brownian``.
        key: PRNG key; it is split once and the second half seeds the state.
        num_steps: Number of integration steps to run.
        kT: Temperature argument forwarded to ``simulate.brownian``.
        dt: Time step forwarded to ``simulate.brownian``.
        gamma: Friction argument forwarded to ``simulate.brownian``.

    Returns:
        The ``position`` attribute of the state after ``num_steps`` steps.
    """
    init_fn, step_fn = simulate.brownian(
        energy_fn, shift, dt=dt, kT=kT, gamma=gamma)
    step_fn = jit(step_fn)

    @jit
    def advance(carry, _step_index):
        # The step index is ignored; 0 is a placeholder per-step output
        # because lax.scan expects a (carry, output) pair.
        return step_fn(carry), 0

    _, init_key = random.split(key)
    initial_state = init_fn(init_key, R_init)
    final_state, _ = lax.scan(advance, initial_state, jnp.arange(num_steps))
    return final_state.position
def run(N=32, n_iter=1000, with_jit=True):
    """Time individual Brownian-dynamics steps of a soft-sphere system.

    Builds ``N`` particles at uniformly random 2-D positions in free space,
    then advances the simulation ``n_iter`` times, timing every step.

    Args:
        N: Number of particles.
        n_iter: Number of simulation steps to run and time.
        with_jit: If true, the step function is wrapped in ``jax.jit``.

    Returns:
        A list of ``n_iter`` integers: the wall-clock duration of each step
        in nanoseconds (from ``time.perf_counter_ns``). With ``with_jit``
        the first entry also contains tracing/compilation time.
    """
    import jax.numpy as jnp
    from jax import random, jit
    from jax_md import space, energy, simulate

    # MD configs
    dt = 1e-1
    temperature = 0.1

    # Author's sketch of what the library calls below compute (notation
    # only -- verify against jax_md before relying on the details):
    #
    # R: current position
    # dR: displacement
    # displacement(Ra, Rb):
    #   dR = Ra - Rb
    # periodic displacement(Ra, Rb):
    #   dR = Ra - Rb
    #   np.mod(dR + side * f32(0.5), side) - f32(0.5) * side
    # periodic shift:
    #   np.mod(R + dR, side)
    # shift:
    #   R + dR
    displacement, shift = space.free()

    # Simulation init
    # dr: pairwise distances
    # epsilon: interaction energy scale (const)
    # alpha: interaction stiffness
    # dr = distance(R)
    # U(dr) = np.where(dr < 1.0, (1 - dr) ** 2, 0)
    # energy_fn(R) = diagonal_mask(U(dr))
    energy_fn = energy.soft_sphere_pair(displacement)

    # force(energy) = -d(energy)/dR
    # xi = random.normal(R.shape, R.dtype)
    # gamma = 0.1
    # nu = 1 / (mass * gamma)
    # dR = force(R) * dt * nu + np.sqrt(2 * temperature * dt * nu) * xi
    # BrownianState(position, mass, rng)
    pos_key, sim_key = random.split(random.PRNGKey(0))
    R = random.uniform(pos_key, (N, 2), dtype=jnp.float32)
    init_fn, apply_fn = simulate.brownian(energy_fn, shift, dt, temperature)
    if with_jit:
        apply_fn = jit(apply_fn)
    state = init_fn(sim_key, R)

    # JAX dispatches work asynchronously, so make sure the setup work has
    # finished before timing starts; otherwise it is charged to step one.
    state.position.block_until_ready()

    # Start simulation
    times = []
    for _ in range(n_iter):
        time_start = time.perf_counter_ns()
        state = apply_fn(state)
        # Wait for the step to finish; without this the timer would mostly
        # capture dispatch overhead rather than the computation itself.
        state.position.block_until_ready()
        time_end = time.perf_counter_ns()
        times.append(time_end - time_start)

    # Finish with profiling times
    return times