def apply_fn(state: FireDescentState, **kwargs) -> FireDescentState:
  """Advance a FIRE minimization by one step.

  The step consists of a velocity-Verlet update of positions and velocities,
  then FIRE velocity mixing, then the adaptive update of the time step `dt`,
  the mixing parameter `alpha`, and the counter `n_pos`. The velocity-Verlet
  update uses the force directly as the acceleration; no mass is read here.
  `n_pos` counts consecutive steps whose power P = F . V is non-negative.

  The constants `f_inc`, `f_dec`, `f_alpha`, `alpha_start`, `dt_max` and
  `n_min`, and the helpers `shift_fn` and `force`, all come from the
  enclosing scope.

  Args:
    state: Current `FireDescentState`, unpacked as
      (R, V, F_old, dt, alpha, n_pos).
    **kwargs: Forwarded unchanged to both `shift_fn` and `force`.

  Returns:
    A new `FireDescentState` with the updated positions and velocities, the
    forces at the new positions, and the updated dt, alpha and n_pos.
  """
  R, V, F_old, dt, alpha, n_pos = dataclasses.astuple(state)

  # Velocity-Verlet position update: R += dt * V + (dt**2 / 2) * F_old.
  # The factor of 1/2 matches the trapezoidal velocity update below. Without
  # it the second-order term is doubled and the integrator is inconsistent.
  R = shift_fn(R, dt * V + f32(0.5) * dt**f32(2) * F_old, **kwargs)
  F = force(R, **kwargs)
  V = V + dt * f32(0.5) * (F_old + F)

  # NOTE(schsam): This will be wrong if F_norm ~< 1e-8.
  # TODO(schsam): We should check for forces below 1e-6. @ErrorChecking
  F_norm = jnp.sqrt(jnp.sum(F**f32(2)) + f32(1e-6))
  V_norm = jnp.sqrt(jnp.sum(V**f32(2)))

  # Power P = F . V over all flattened components.
  P = jnp.array(jnp.dot(jnp.reshape(F, (-1)), jnp.reshape(V, (-1))))

  # FIRE mixing: blend V with a vector along F whose length is |V|.
  V = V + alpha * (F * V_norm / F_norm - V)

  # NOTE(schsam): Can we clean this up at all?
  # Count consecutive steps with P >= 0; reset to zero on an uphill step.
  n_pos = jnp.where(P >= 0, n_pos + 1, 0)
  # After more than n_min downhill steps, grow dt (capped at dt_max)...
  dt_choice = jnp.array([dt * f_inc, dt_max])
  dt = jnp.where(P > 0, jnp.where(n_pos > n_min, jnp.min(dt_choice), dt), dt)
  # ...while an uphill step (P < 0) shrinks dt.
  dt = jnp.where(P < 0, dt * f_dec, dt)
  # alpha decays on long downhill runs and is reset on an uphill step.
  alpha = jnp.where(P > 0, jnp.where(n_pos > n_min, alpha * f_alpha, alpha), alpha)
  alpha = jnp.where(P < 0, alpha_start, alpha)
  # An uphill step also zeroes all velocities.
  V = (P < 0) * jnp.zeros_like(V) + (P >= 0) * V

  return FireDescentState(R, V, F, dt, alpha, n_pos)  # pytype: disable=wrong-arg-count
def update_chain_mass_fn(state, kT):
  """Rebuild the Nose-Hoover chain masses for temperature `kT`.

  Every one of the `chain_length` links gets mass kT * tau**2. The first link
  is additionally multiplied by the chain's DOF. All other chain fields are
  carried over unchanged.
  """
  xi, v_xi, _, tau_, KE, DOF = dataclasses.astuple(state)

  link_mass = kT * tau_**f32(2)
  masses = link_mass * jnp.ones(chain_length, dtype=f32)
  # The first link couples to the particles, so it is scaled by DOF.
  masses = ops.index_update(masses, 0, masses[0] * DOF)

  return NoseHooverChain(xi, v_xi, masses, tau_, KE, DOF)  # pytype: disable=wrong-arg-count
def substep_fn(delta, V, state, kT):
  """Apply one Nose-Hoover chain sub-step and rescale the velocities.

  The sub-step runs four stages in order:
    1. A backward sweep over the chain velocities `v_xi`, from the last link
       down to link 0.
    2. A rescaling of `V` by s = exp(-delta / 2 * v_xi[0]), and of `KE` by s**2.
    3. An advance of the chain positions by xi += delta / 2 * v_xi.
    4. A forward sweep from link 0 up to the last link.

  Each sweep updates a link velocity in the factorised form
      v <- s * (s * v + delta / 4 * G),   s = exp(-delta / 8 * v_next)
  where G is the force on that link and v_next is the velocity of the next
  link up the chain.

  Args:
    delta: Length of this sub-step.
    V: Velocities to rescale.
    state: `NoseHooverChain`, unpacked as (xi, v_xi, Q, _tau, KE, DOF).
    kT: Thermostat temperature (kT).

  Returns:
    A tuple (V, chain, kT) holding the rescaled velocities, the updated
    `NoseHooverChain`, and the unchanged kT. kT is presumably passed through
    so this can serve as a loop body with a (V, state, kT) carry; confirm
    against the caller.
  """
  xi, v_xi, Q, _tau, KE, DOF = dataclasses.astuple(state)
  delta_2 = delta / f32(2.0)
  delta_4 = delta_2 / f32(2.0)
  delta_8 = delta_4 / f32(2.0)

  # Index of the last chain link.
  # NOTE(review): the code reads v_xi[M - 1] and v_xi[1], so chain_length == 1
  # would wrap or out-of-bounds index; presumably chain_length >= 2 — confirm.
  M = chain_length - 1

  # Quarter-step kick of the last link. Its force depends on link M - 1.
  G = (v_xi[M - 1]**f32(2) * Q[M - 1] - kT) / Q[M]
  v_xi = ops.index_add(v_xi, M, delta_4 * G)

  # Backward sweep over links M-1 .. 1. The scan carry is the freshly updated
  # velocity of link m + 1, which sets the damping factor for link m.
  # `v_xi` here is read from the enclosing scope when lax.scan traces the
  # body, i.e. the array after the index_add above.
  def backward_loop_fn(v_xi_new, m):
    G = (v_xi[m - 1]**2 * Q[m - 1] - kT) / Q[m]
    scale = jnp.exp(-delta_8 * v_xi_new)
    v_xi_new = scale * (scale * v_xi[m] + delta_4 * G)
    return v_xi_new, v_xi_new
  idx = jnp.arange(M - 1, 0, -1)
  _, v_xi_update = lax.scan(backward_loop_fn, v_xi[M], idx, unroll=2)
  v_xi = ops.index_update(v_xi, idx, v_xi_update)

  # Link 0 is driven by the kinetic energy: G = (2 KE - DOF kT) / Q[0].
  G = (f32(2.0) * KE - DOF * kT) / Q[0]
  scale = jnp.exp(-delta_8 * v_xi[1])
  v_xi = ops.index_update(v_xi, 0, scale * (scale * v_xi[0] + delta_4 * G))

  # Rescale the particle velocities. KE scales by the square of the factor.
  scale = jnp.exp(-delta_2 * v_xi[0])
  KE = KE * scale**f32(2)
  V = V * scale

  # Half-step of the chain positions.
  xi = xi + delta_2 * v_xi

  # Forward sweep over links 0 .. M-1. The carry is the force on the current
  # link, which is recomputed from each link's freshly updated velocity.
  # `v_xi` is again read from the enclosing scope at trace time, so
  # v_xi[m + 1] holds the post-backward-sweep value.
  G = (f32(2) * KE - DOF * kT) / Q[0]
  def forward_loop_fn(G, m):
    scale = jnp.exp(-delta_8 * v_xi[m + 1])
    v_xi_update = scale * (scale * v_xi[m] + delta_4 * G)
    G = (v_xi_update**2 * Q[m] - kT) / Q[m + 1]
    return G, v_xi_update
  idx = jnp.arange(M)
  G, v_xi_update = lax.scan(forward_loop_fn, G, idx, unroll=2)
  v_xi = ops.index_update(v_xi, idx, v_xi_update)

  # Final quarter-step kick of the last link, using the G from the sweep.
  v_xi = ops.index_add(v_xi, M, delta_4 * G)
  return V, NoseHooverChain(xi, v_xi, Q, _tau, KE, DOF), kT  # pytype: disable=wrong-arg-count
def apply_fn(state, t=f32(0), **kwargs):
  """Take one overdamped (Brownian) step at time `t`.

  Positions move by
      nu * F * dt + sqrt(2 * T(t) * nu * dt) * xi,   nu = 1 / (mass * gamma)
  where T(t) comes from `T_schedule` and xi is standard-normal noise. The
  noise is drawn from a fresh split of the state's PRNG key. `t` and `kwargs`
  are forwarded to both `force_fn` and `shift`.
  """
  positions, mass, rng = dataclasses.astuple(state)
  rng, noise_key = random.split(rng)

  forces = force_fn(positions, t=t, **kwargs)
  noise = random.normal(noise_key, positions.shape, positions.dtype)

  mobility = f32(1) / (mass * gamma)
  drift = forces * dt * mobility
  diffusion = np.sqrt(f32(2) * T_schedule(t) * dt * mobility) * noise
  positions = shift(positions, drift + diffusion, t=t, **kwargs)

  return BrownianState(positions, mass, rng)  # pytype: disable=wrong-arg-count
def apply_fn(state, **kwargs):
  """Take one overdamped (Brownian) step at a fixed temperature.

  The temperature is `kwargs['kT']` when that key is supplied, and the
  enclosing `kT` otherwise. Positions move by
      nu * F * dt + sqrt(2 * kT * nu * dt) * xi,   nu = 1 / (mass * gamma)
  with xi standard-normal noise from a fresh split of the state's PRNG key.
  All kwargs, including any 'kT', are forwarded to `force_fn` and `shift`.
  """
  temperature = kwargs.get('kT', kT)
  positions, mass, rng = dataclasses.astuple(state)
  rng, noise_key = random.split(rng)

  forces = force_fn(positions, **kwargs)
  noise = random.normal(noise_key, positions.shape, positions.dtype)

  mobility = f32(1) / (mass * gamma)
  drift = forces * dt * mobility
  diffusion = np.sqrt(f32(2) * temperature * dt * mobility) * noise
  positions = shift(positions, drift + diffusion, **kwargs)

  return BrownianState(positions, mass, rng)  # pytype: disable=wrong-arg-count
def apply_fn(state, t=f32(0), **kwargs):
  """Take one NVT Langevin step at time `t`.

  Two independent standard-normal samples are drawn for the noise: xi, and
  theta / sqrt(3). They are combined with the friction `gamma` and the
  temperature `T_schedule(t)` into the second-order term C. Positions then
  advance by dt * V + C. Velocities are updated from the old and new forces,
  the friction, the noise, and C. The structure appears to follow the
  Vanden-Eijnden & Ciccotti second-order Langevin scheme; treat that as an
  assumption to confirm.

  Args:
    state: `NVTLangevinState`, unpacked as (R, V, F, mass, key). R must be
      rank 2, since it is unpacked as (N, dim).
    t: Time passed to `T_schedule`, `shift` and `force_fn`.
    **kwargs: Forwarded to `shift` and `force_fn`.

  Returns:
    A new `NVTLangevinState` with updated positions and velocities, the
    forces at the new positions, and the advanced PRNG key.
  """
  R, V, F, mass, key = dataclasses.astuple(state)
  N, dim = R.shape

  key, xi_key, theta_key = random.split(key, 3)
  xi = random.normal(xi_key, (N, dim), dtype=R.dtype)
  theta = random.normal(theta_key, (N, dim), dtype=R.dtype) / np.sqrt(f32(3))

  # NOTE(schsam): We really only need to recompute sigma if the temperature
  # is nonconstant. @Optimization
  # TODO(schsam): Check that this is really valid in the case that the masses
  # are non identical for all particles.
  sigma = np.sqrt(f32(2) * T_schedule(t) * gamma / mass)
  # Second-order position term. It carries both the deterministic
  # force/friction part (dt2 * (F - gamma * V)) and the stochastic part.
  C = dt2 * (F - gamma * V) + sigma * dt32 * (xi + theta)

  # Position update R += dt * V + C. The force already enters through C, so
  # it must not be added again here. A bare `+ F` would add a
  # dt-independent, dimensionally inconsistent displacement.
  R = shift(R, dt * V + C, t=t, **kwargs)
  F_new = force_fn(R, t=t, **kwargs)

  V = (f32(1) - dt * gamma) * V + dt_2 * (F_new + F)
  V = V + sigma * np.sqrt(dt) * xi - gamma * C

  return NVTLangevinState(R, V, F_new, mass, key)  # pytype: disable=wrong-arg-count
def apply_fn(state, **kwargs):
  """Take one NVT step: velocity Verlet wrapped in Nose-Hoover chain half-steps.

  The temperature is `kwargs['kT']` when supplied, and the enclosing `kT`
  otherwise. The chain masses are rebuilt from that temperature on every
  call, so the Q stored on the incoming state is ignored. After the velocity
  update, the mean velocity over axis 0 is subtracted and the kinetic energy
  is recomputed before the second chain half-step.
  """
  temperature = kwargs.get('kT', kT)
  pos, vel, frc, mass, ke, xi, v_xi, _ = dataclasses.astuple(state)
  dof = pos.size

  chain_mass = temperature * tau**f32(2) * np.ones(chain_length, dtype=pos.dtype)
  chain_mass = ops.index_update(chain_mass, 0, chain_mass[0] * dof)

  # First thermostat half-step.
  ke, vel, xi, v_xi, *_ = half_step_chain_fn(
      ke, vel, xi, v_xi, chain_mass, dof, temperature)

  # Velocity-Verlet position and velocity update.
  pos = shift_fn(pos, vel * dt + frc * dt**2 / (2 * mass), **kwargs)
  new_frc = force_fn(pos, **kwargs)
  vel = vel + dt_2 * (new_frc + frc) / mass

  # Remove the mean velocity along axis 0 (presumably the particle axis).
  vel = vel - np.mean(vel, axis=0, keepdims=True)
  ke = quantity.kinetic_energy(vel, mass)

  # Second thermostat half-step.
  ke, vel, xi, v_xi, *_ = half_step_chain_fn(
      ke, vel, xi, v_xi, chain_mass, dof, temperature)

  return NVTNoseHooverState(pos, vel, new_frc, mass, ke, xi, v_xi, chain_mass)
def apply_fun(state, t=f32(0), **kwargs):
  """Take one NVT step: drift-kick-drift between Nose-Hoover chain updates.

  The temperature comes from `T_schedule(t)`. The chain masses are rebuilt
  from it on every call, so the Q stored on the incoming state is ignored.
  Positions are shifted by vel * dt_2 twice, around a full-dt velocity kick.
  `t` and `kwargs` are forwarded to `shift_fn` and `force`.
  """
  temperature = T_schedule(t)
  pos, vel, mass, ke, xi, v_xi, _ = dataclasses.astuple(state)
  # R is read as rank 2 here (shape[0] * shape[1]).
  dof, = static_cast(pos.shape[0] * pos.shape[1])

  chain_mass = temperature * tau**f32(2) * np.ones(chain_length, dtype=pos.dtype)
  chain_mass = ops.index_update(chain_mass, 0, chain_mass[0] * dof)

  ke, vel, xi, v_xi = step_chain(ke, vel, xi, v_xi, chain_mass, dof, temperature)

  pos = shift_fn(pos, vel * dt_2, t=t, **kwargs)
  forces = force(pos, t=t, **kwargs)
  vel = vel + dt * forces / mass
  # NOTE(schsam): Do we need to do mean subtraction here?
  vel = vel - np.mean(vel, axis=0, keepdims=True)
  ke = quantity.kinetic_energy(vel, mass)
  pos = shift_fn(pos, vel * dt_2, t=t, **kwargs)

  ke, vel, xi, v_xi = step_chain(ke, vel, xi, v_xi, chain_mass, dof, temperature)

  return NVTNoseHooverState(pos, vel, mass, ke, xi, v_xi, chain_mass)  # pytype: disable=wrong-arg-count
def apply_fun(state: NVEState, **kwargs) -> NVEState:
  """Advance an NVE system by one velocity-Verlet step.

  The state stores accelerations (force / mass) rather than forces.
  `dt_2` comes from the enclosing scope and is presumably dt**2 / 2; confirm.
  `kwargs` are forwarded to `shift_fn` and `force`.
  """
  pos, vel, acc, mass = dataclasses.astuple(state)

  displacement = vel * dt + acc * dt_2
  pos = shift_fn(pos, displacement, **kwargs)

  new_acc = force(pos, **kwargs) / mass
  # Trapezoidal velocity update from the old and new accelerations.
  vel = vel + f32(0.5) * (acc + new_acc) * dt

  return NVEState(pos, vel, new_acc, mass)  # pytype: disable=wrong-arg-count