def test_nvt_langevin(self, spatial_dimension, dtype):
  """Langevin NVT dynamics should equilibrate to the thermostat temperature.

  Each of STOCHASTIC_SAMPLES trials places LANGEVIN_PARTICLE_COUNT particles
  with random masses in a quadratic well centred on random anchor positions,
  runs LANGEVIN_DYNAMICS_STEPS jitted steps, and samples the instantaneous
  temperature every `sample_stride` steps once `step > warmup_steps`. The
  mean sampled temperature must be within 0.1 of the target, and positions
  must keep the requested dtype.
  """
  warmup_steps = 4000
  sample_stride = 100
  shape = (LANGEVIN_PARTICLE_COUNT, spatial_dimension)

  key = random.PRNGKey(0)
  for _ in range(STOCHASTIC_SAMPLES):
    key, pos_key, anchor_key, temp_key, mass_key = random.split(key, 5)

    positions = random.normal(pos_key, shape, dtype=dtype)
    anchors = random.normal(anchor_key, shape, dtype=dtype)
    _, shift = space.free()

    # Quadratic well around `anchors`; any extra keyword arguments passed by
    # the simulator are accepted and ignored.
    energy_fn = functools.partial(
        lambda R, R0, **kwargs: np.sum((R - R0) ** 2), R0=anchors)

    target_T = random.uniform(
        temp_key, (), minval=0.3, maxval=1.4, dtype=dtype)
    masses = random.uniform(
        mass_key, (LANGEVIN_PARTICLE_COUNT,),
        minval=0.1, maxval=10.0, dtype=dtype)

    init_fn, step_fn = simulate.nvt_langevin(
        energy_fn, shift, f32(1e-2), target_T, gamma=f32(0.3))
    step_fn = jit(step_fn)

    # NOTE(review): `key` feeds init_fn here and is also split at the top of
    # the next trial, so one key has two consumers. JAX recommends a dedicated
    # subkey; left unchanged to keep this stochastic test's random stream.
    state = init_fn(key, positions, mass=masses, T_initial=dtype(1.0))

    sampled_T = []
    for step in range(LANGEVIN_DYNAMICS_STEPS):
      state = step_fn(state)
      past_warmup = step > warmup_steps
      if past_warmup and step % sample_stride == 0:
        sampled_T.append(quantity.temperature(state.velocity, state.mass))

    mean_T = np.mean(np.array(sampled_T))
    assert np.abs(mean_T - target_T) < 0.1
    assert state.position.dtype == dtype
def test_langevin_harmonic(self):
  """Langevin dynamics in a 1D harmonic well should match the analytic result.

  N independent replicas, each holding a single particle at x = 1, are
  vmapped through `steps` Langevin steps. The ensemble mean position and the
  velocity autocorrelation <v(t) v(0)> are then compared to the closed-form
  solution of the damped oscillator x'' + gamma x' + (alpha / mass) x = 0.
  """
  alpha = 1.0  # spring constant
  E = lambda x: jnp.sum(0.5 * alpha * x**2)
  _, shift = space.free()

  N = 10000
  steps = 1000
  kT = 0.25
  dt = 1e-4
  gamma = 3
  mass = 2.0
  tol = 1e-3

  # Shape (N, 1, 1): under vmap each replica sees R of shape (1, 1), i.e. a
  # single particle in one dimension.
  X = jnp.ones((N, 1, 1))
  key = random.split(random.PRNGKey(0), N)

  # Velocity centring is disabled; presumably centring a one-particle system
  # would zero its velocity. Passed by keyword so the intent is explicit.
  init_fn, step_fn = simulate.nvt_langevin(
      E, shift, dt, kT, gamma, center_velocity=False)
  step_fn = jit(vmap(step_fn))
  state = vmap(init_fn, (0, 0, None))(key, X, mass)
  v0 = state.velocity

  for _ in range(steps):
    state = step_fn(state)

  # Compare mean position and velocity autocorrelation with theoretical
  # prediction. beta_1, beta_2 are the decay rates (roots of
  # beta**2 - gamma * beta + alpha / mass = 0); gamma**2 / 4 > alpha / mass
  # here, so d is real (overdamped).
  d = jnp.sqrt(gamma**2 / 4 - alpha / mass)
  beta_1 = gamma / 2 + d
  beta_2 = gamma / 2 - d
  # A + B = 1 and beta_1 * A + beta_2 * B = 0 encode x(0) = 1 and a
  # zero initial mean velocity.
  A = -beta_2 / (beta_1 - beta_2)
  B = beta_1 / (beta_1 - beta_2)
  exp1 = lambda t: jnp.exp(-beta_1 * t)
  exp2 = lambda t: jnp.exp(-beta_2 * t)
  # Prefactor chosen so that vel_fn(0) = Z * (beta_1 - beta_2) = kT / mass.
  Z = kT / (2 * d * mass)
  pos_fn = lambda t: A * exp1(t) + B * exp2(t)
  vel_fn = lambda t: Z * (-beta_2 * exp2(t) + beta_1 * exp1(t))

  t = steps * dt
  self.assertAllClose(
      jnp.mean(state.position), pos_fn(t), rtol=tol, atol=tol)
  self.assertAllClose(
      jnp.mean(state.velocity * v0), vel_fn(t), rtol=tol, atol=tol)