Пример #1
0
    def apply_fn(state: FireDescentState, **kwargs) -> FireDescentState:
        """Perform one update of a FireDescentState and return the new state.

        The positions are shifted using the stored velocity and force, the
        force is re-evaluated, and the velocity is mixed towards the force
        direction. The step size, the mixing coefficient and the counter of
        consecutive non-negative-power steps are then adapted based on the
        sign of the power ``P = F . V``.

        Args:
            state: Current state; unpacked into position, velocity, force,
                step size, mixing coefficient and step counter.
            **kwargs: Forwarded unchanged to ``shift_fn`` and ``force``.

        Returns:
            A new FireDescentState holding the updated quantities.
        """
        position, velocity, prev_force, step, mix, streak = (
            dataclasses.astuple(state))

        displacement = step * velocity + step**f32(2) * prev_force
        position = shift_fn(position, displacement, **kwargs)

        new_force = force(position, **kwargs)

        velocity = velocity + step * f32(0.5) * (prev_force + new_force)

        # NOTE(schsam): This will be wrong if F_norm ~< 1e-8.
        # TODO(schsam): We should check for forces below 1e-6. @ErrorChecking
        force_norm = jnp.sqrt(jnp.sum(new_force**f32(2)) + f32(1e-6))
        velocity_norm = jnp.sqrt(jnp.sum(velocity**f32(2)))

        # Power: dot product of the flattened force and velocity.
        power = jnp.array(
            jnp.dot(jnp.reshape(new_force, (-1)), jnp.reshape(velocity, (-1))))

        # Blend the velocity with a force-aligned vector of the same norm.
        velocity = velocity + mix * (
            new_force * velocity_norm / force_norm - velocity)

        # NOTE(schsam): Can we clean this up at all?
        downhill = power > 0
        uphill = power < 0
        not_uphill = power >= 0

        # Count consecutive steps with non-negative power; reset otherwise.
        streak = jnp.where(not_uphill, streak + 1, 0)
        long_streak = streak > n_min

        # Grow the step (capped at dt_max) after a long enough streak, and
        # shrink it whenever the power is negative.
        capped_growth = jnp.min(jnp.array([step * f_inc, dt_max]))
        step = jnp.where(downhill,
                         jnp.where(long_streak, capped_growth, step),
                         step)
        step = jnp.where(uphill, step * f_dec, step)

        # Decay the mixing coefficient under the same condition; restore the
        # starting value whenever the power is negative.
        mix = jnp.where(downhill,
                        jnp.where(long_streak, mix * f_alpha, mix),
                        mix)
        mix = jnp.where(uphill, alpha_start, mix)

        # Zero the velocity when the power is negative.
        velocity = uphill * jnp.zeros_like(velocity) + not_uphill * velocity

        return FireDescentState(position, velocity, new_force, step, mix, streak)  # pytype: disable=wrong-arg-count
Пример #2
0
    def update_chain_mass_fn(state, kT):
        """Return a NoseHooverChain whose masses are recomputed for ``kT``.

        Every entry of the mass array is set to ``kT * _tau**2``; the first
        entry is additionally multiplied by ``DOF``. All other fields of
        ``state`` are carried over unchanged.

        Args:
            state: Chain state; unpacked into six fields, of which the
                stored mass array is discarded.
            kT: Value used to rebuild the masses.

        Returns:
            A new NoseHooverChain with the rebuilt mass array.
        """
        xi, v_xi, _, _tau, KE, DOF = dataclasses.astuple(state)

        uniform_mass = kT * _tau**f32(2) * jnp.ones(chain_length, dtype=f32)
        chain_mass = ops.index_update(uniform_mass, 0, uniform_mass[0] * DOF)

        return NoseHooverChain(xi, v_xi, chain_mass, _tau, KE, DOF)  # pytype: disable=wrong-arg-count
Пример #3
0
    def substep_fn(delta, V, state, kT):
        """Apply a single update to the chain parameters and rescale the velocity.

        Args:
            delta: Step size of this substep; half, quarter and eighth
                fractions of it are used below.
            V: Array that is multiplied by the scale factor derived from the
                first chain element.
            state: Chain state; unpacked into ``xi``, ``v_xi``, ``Q``,
                ``_tau``, ``KE`` and ``DOF``.
            kT: Value entering every ``G`` term below.

        Returns:
            A tuple ``(V, chain, kT)``: the rescaled ``V``, a new
            NoseHooverChain with updated ``xi``, ``v_xi`` and ``KE``, and
            ``kT`` passed through unchanged.
        """

        xi, v_xi, Q, _tau, KE, DOF = dataclasses.astuple(state)

        # Fractions of the substep used by the individual updates.
        delta_2 = delta / f32(2.0)
        delta_4 = delta_2 / f32(2.0)
        delta_8 = delta_4 / f32(2.0)

        # Index of the last chain element.
        M = chain_length - 1

        # Quarter-step kick of the last chain element.
        G = (v_xi[M - 1]**f32(2) * Q[M - 1] - kT) / Q[M]
        v_xi = ops.index_add(v_xi, M, delta_4 * G)

        def backward_loop_fn(v_xi_new, m):
            # The carry `v_xi_new` is the value just computed for element
            # m + 1 (initially v_xi[M]). `v_xi[m]` and `v_xi[m - 1]` are read
            # from the enclosing array as bound when lax.scan is called; they
            # are not modified inside the scan.
            G = (v_xi[m - 1]**2 * Q[m - 1] - kT) / Q[m]
            scale = jnp.exp(-delta_8 * v_xi_new)
            v_xi_new = scale * (scale * v_xi[m] + delta_4 * G)
            return v_xi_new, v_xi_new

        # Sweep from element M - 1 down to element 1, then write the results
        # back into the same positions.
        idx = jnp.arange(M - 1, 0, -1)
        _, v_xi_update = lax.scan(backward_loop_fn, v_xi[M], idx, unroll=2)
        v_xi = ops.index_update(v_xi, idx, v_xi_update)

        # Element 0 is driven by KE and DOF rather than by a neighbouring
        # chain element.
        G = (f32(2.0) * KE - DOF * kT) / Q[0]
        scale = jnp.exp(-delta_8 * v_xi[1])
        v_xi = ops.index_update(v_xi, 0,
                                scale * (scale * v_xi[0] + delta_4 * G))

        # Rescale V by a factor derived from element 0; KE is scaled by the
        # square of the same factor.
        scale = jnp.exp(-delta_2 * v_xi[0])
        KE = KE * scale**f32(2)
        V = V * scale

        # Advance xi by half a substep using the current v_xi.
        xi = xi + delta_2 * v_xi

        # Recompute the drive on element 0 from the rescaled KE.
        G = (f32(2) * KE - DOF * kT) / Q[0]

        def forward_loop_fn(G, m):
            # The carry `G` is the drive on element m; the returned carry is
            # the drive on element m + 1, built from the freshly updated
            # value. `v_xi[m]` and `v_xi[m + 1]` are read from the enclosing
            # array as bound when lax.scan is called.
            scale = jnp.exp(-delta_8 * v_xi[m + 1])
            v_xi_update = scale * (scale * v_xi[m] + delta_4 * G)
            G = (v_xi_update**2 * Q[m] - kT) / Q[m + 1]
            return G, v_xi_update

        # Sweep from element 0 up to element M - 1, write the results back,
        # then kick the last element with the final carried drive.
        idx = jnp.arange(M)
        G, v_xi_update = lax.scan(forward_loop_fn, G, idx, unroll=2)
        v_xi = ops.index_update(v_xi, idx, v_xi_update)
        v_xi = ops.index_add(v_xi, M, delta_4 * G)

        return V, NoseHooverChain(xi, v_xi, Q, _tau, KE, DOF), kT  # pytype: disable=wrong-arg-count
Пример #4
0
    def apply_fn(state, t=f32(0), **kwargs):
        """Take one step from a BrownianState and return the new state.

        The displacement is the sum of a force term and a random term. The
        random term is a standard-normal draw scaled using ``T_schedule(t)``.

        Args:
            state: Current state; unpacked into position, mass and PRNG key.
            t: Time passed to ``force_fn``, ``T_schedule`` and ``shift``.
            **kwargs: Forwarded unchanged to ``force_fn`` and ``shift``.

        Returns:
            A new BrownianState with the shifted position and a fresh key.
        """
        position, mass, key = dataclasses.astuple(state)

        # Keep one half of the split key for the state, use the other now.
        key, noise_key = random.split(key)

        drift_force = force_fn(position, t=t, **kwargs)
        noise = random.normal(noise_key, position.shape, position.dtype)

        mobility = f32(1) / (mass * gamma)

        force_step = drift_force * dt * mobility
        noise_step = np.sqrt(f32(2) * T_schedule(t) * dt * mobility) * noise
        position = shift(position, force_step + noise_step, t=t, **kwargs)

        return BrownianState(position, mass, key)  # pytype: disable=wrong-arg-count
Пример #5
0
    def apply_fn(state, **kwargs):
        """Take one step from a BrownianState and return the new state.

        The displacement is the sum of a force term and a random term. The
        random term is a standard-normal draw scaled using the temperature,
        which is taken from ``kwargs['kT']`` when present and from the
        enclosing ``kT`` otherwise.

        Args:
            state: Current state; unpacked into position, mass and PRNG key.
            **kwargs: Forwarded unchanged (including any ``kT`` entry) to
                ``force_fn`` and ``shift``.

        Returns:
            A new BrownianState with the shifted position and a fresh key.
        """
        position, mass, key = dataclasses.astuple(state)

        # A per-call kT overrides the value captured from the enclosing scope.
        temperature = kwargs.get('kT', kT)

        # Keep one half of the split key for the state, use the other now.
        key, noise_key = random.split(key)

        drift_force = force_fn(position, **kwargs)
        noise = random.normal(noise_key, position.shape, position.dtype)

        mobility = f32(1) / (mass * gamma)

        force_step = drift_force * dt * mobility
        noise_step = np.sqrt(f32(2) * temperature * dt * mobility) * noise
        position = shift(position, force_step + noise_step, **kwargs)

        return BrownianState(position, mass, key)  # pytype: disable=wrong-arg-count
Пример #6
0
    def apply_fn(state, t=f32(0), **kwargs):
        """Advance an NVTLangevinState by one time step.

        Two independent normal draws (``xi`` and ``theta``) are combined with
        the current force and velocity into a correction term ``C``. The
        positions are shifted by ``dt * V + C``, the force is re-evaluated at
        the new positions, and the velocity is updated from both forces, the
        damping ``gamma`` and the noise.

        Args:
            state: Current state; unpacked into position, velocity, force,
                mass and PRNG key. The position must be two-dimensional
                because its shape is unpacked as ``(N, dim)``.
            t: Time passed to ``T_schedule``, ``shift`` and ``force_fn``.
            **kwargs: Forwarded unchanged to ``shift`` and ``force_fn``.

        Returns:
            A new NVTLangevinState holding the updated position, velocity,
            force, the unchanged mass and a fresh PRNG key.
        """
        R, V, F, mass, key = dataclasses.astuple(state)

        N, dim = R.shape

        # One sub-key per noise term; the first output replaces the state key.
        key, xi_key, theta_key = random.split(key, 3)
        xi = random.normal(xi_key, (N, dim), dtype=R.dtype)
        theta = random.normal(theta_key,
                              (N, dim), dtype=R.dtype) / np.sqrt(f32(3))

        # NOTE(schsam): We really only need to recompute sigma if the temperature
        # is nonconstant. @Optimization
        # TODO(schsam): Check that this is really valid in the case that the masses
        # are non identical for all particles.
        sigma = np.sqrt(f32(2) * T_schedule(t) * gamma / mass)
        # NOTE(review): dt2, dt32 and dt_2 come from the enclosing scope and
        # are not visible here; presumably dt**2 / 2, dt**1.5 / 2 and dt / 2.
        # Confirm against their definitions.
        C = dt2 * (F - gamma * V) + sigma * dt32 * (xi + theta)

        # The force contribution to the displacement is already contained in
        # C. The bare force must not be added here: it carries no time-step
        # factor, so the displacement would not vanish as dt -> 0.
        R = shift(R, dt * V + C, t=t, **kwargs)
        F_new = force_fn(R, t=t, **kwargs)
        V = (f32(1) - dt * gamma) * V + dt_2 * (F_new + F)
        V = V + sigma * np.sqrt(dt) * xi - gamma * C

        return NVTLangevinState(R, V, F_new, mass, key)  # pytype: disable=wrong-arg-count
Пример #7
0
    def apply_fn(state, **kwargs):
        """Advance an NVTNoseHooverState by one time step.

        The chain masses are rebuilt from the temperature, the chain is
        stepped via ``half_step_chain_fn``, the positions and velocities are
        updated from the old and new forces, the per-column mean velocity is
        removed, and the chain is stepped a second time.

        The temperature is taken from ``kwargs['kT']`` when present and from
        the enclosing ``kT`` otherwise.

        Args:
            state: Current state; unpacked into eight fields, of which the
                stored chain masses are discarded and recomputed.
            **kwargs: Forwarded unchanged (including any ``kT`` entry) to
                ``shift_fn`` and ``force_fn``.

        Returns:
            A new NVTNoseHooverState with the updated quantities.
        """
        position, velocity, force_old, mass, kinetic, xi, v_xi, _ = (
            dataclasses.astuple(state))

        # A per-call kT overrides the value captured from the enclosing scope.
        temperature = kwargs.get('kT', kT)

        dof = position.size

        # Every chain mass is kT * tau^2; the first one is also scaled by dof.
        uniform_mass = temperature * tau**f32(2) * np.ones(
            chain_length, dtype=position.dtype)
        chain_mass = ops.index_update(uniform_mass, 0, uniform_mass[0] * dof)

        kinetic, velocity, xi, v_xi, *_ = half_step_chain_fn(
            kinetic, velocity, xi, v_xi, chain_mass, dof, temperature)

        displacement = velocity * dt + force_old * dt**2 / (2 * mass)
        position = shift_fn(position, displacement, **kwargs)

        force_new = force_fn(position, **kwargs)

        velocity = velocity + dt_2 * (force_new + force_old) / mass

        # Remove the mean velocity along axis 0 before measuring the energy.
        velocity = velocity - np.mean(velocity, axis=0, keepdims=True)
        kinetic = quantity.kinetic_energy(velocity, mass)

        kinetic, velocity, xi, v_xi, *_ = half_step_chain_fn(
            kinetic, velocity, xi, v_xi, chain_mass, dof, temperature)

        return NVTNoseHooverState(position, velocity, force_new, mass, kinetic,
                                  xi, v_xi, chain_mass)
Пример #8
0
    def apply_fun(state, t=f32(0), **kwargs):
        """Advance an NVTNoseHooverState by one time step.

        The chain masses are rebuilt from ``T_schedule(t)`` and the chain is
        stepped via ``step_chain``. The positions are then shifted in two
        stages, each by ``V * dt_2``, around a velocity update from the
        freshly evaluated force. The chain is stepped once more at the end.

        Args:
            state: Current state; unpacked into seven fields, of which the
                stored chain masses are discarded and recomputed.
            t: Time passed to ``T_schedule``, ``shift_fn`` and ``force``.
            **kwargs: Forwarded unchanged to ``shift_fn`` and ``force``.

        Returns:
            A new NVTNoseHooverState with the updated quantities.
        """
        temperature = T_schedule(t)

        position, velocity, mass, kinetic, xi, v_xi, _ = (
            dataclasses.astuple(state))

        # static_cast returns a one-element sequence here.
        (dof,) = static_cast(position.shape[0] * position.shape[1])

        # Every chain mass is T * tau^2; the first one is also scaled by dof.
        uniform_mass = temperature * tau**f32(2) * np.ones(
            chain_length, dtype=position.dtype)
        chain_mass = ops.index_update(uniform_mass, 0, uniform_mass[0] * dof)

        kinetic, velocity, xi, v_xi = step_chain(
            kinetic, velocity, xi, v_xi, chain_mass, dof, temperature)
        position = shift_fn(position, velocity * dt_2, t=t, **kwargs)

        current_force = force(position, t=t, **kwargs)

        velocity = velocity + dt * current_force / mass
        # NOTE(schsam): Do we need to mean subtraction here?
        velocity = velocity - np.mean(velocity, axis=0, keepdims=True)
        kinetic = quantity.kinetic_energy(velocity, mass)
        position = shift_fn(position, velocity * dt_2, t=t, **kwargs)

        kinetic, velocity, xi, v_xi = step_chain(
            kinetic, velocity, xi, v_xi, chain_mass, dof, temperature)

        return NVTNoseHooverState(position, velocity, mass, kinetic, xi, v_xi, chain_mass)  # pytype: disable=wrong-arg-count
Пример #9
0
 def apply_fun(state: NVEState, **kwargs) -> NVEState:
     """Advance an NVEState by one time step.

     The positions are shifted using the stored velocity and acceleration,
     the acceleration is re-evaluated as ``force / mass`` at the new
     positions, and the velocity is updated with the average of the old and
     new accelerations.

     Args:
         state: Current state; unpacked into position, velocity,
             acceleration and mass.
         **kwargs: Forwarded unchanged to ``shift_fn`` and ``force``.

     Returns:
         A new NVEState with the updated position, velocity and
         acceleration and the unchanged mass.
     """
     position, velocity, accel_old, mass = dataclasses.astuple(state)

     displacement = velocity * dt + accel_old * dt_2
     position = shift_fn(position, displacement, **kwargs)

     accel_new = force(position, **kwargs) / mass
     velocity = velocity + f32(0.5) * (accel_old + accel_new) * dt

     return NVEState(position, velocity, accel_new, mass)  # pytype: disable=wrong-arg-count