# Example No. 1
0
    def test_brownian(self, spatial_dimension, dtype):
        """Check free Brownian motion against the expected coordinate variance.

        Particles start at the origin under a constant energy. After
        ``BROWNIAN_DYNAMICS_STEPS`` steps, the variance of all position
        coordinates is compared with ``2 * T / (mass * gamma) * t``.

        Args:
          spatial_dimension: number of coordinates per particle.
          dtype: floating-point dtype for positions, mass and temperature.
        """
        key = random.PRNGKey(0)
        key, T_split, mass_split = random.split(key, 3)

        _, shift = space.free()
        # Constant energy: its gradient (the force) is zero, so particles
        # move only through the stochastic term of the integrator.
        energy_fn = lambda R, **kwargs: f32(0)

        # Use the parameterized dimension rather than a hard-coded 2 so that
        # every `spatial_dimension` case actually exercises that dimension.
        R = np.zeros((BROWNIAN_PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        mass = random.uniform(mass_split, (),
                              minval=0.1,
                              maxval=10.0,
                              dtype=dtype)
        T = random.uniform(T_split, (), minval=0.3, maxval=1.4, dtype=dtype)

        dt = f32(1e-2)
        gamma = f32(0.1)

        init_fn, apply_fn = simulate.brownian(energy_fn,
                                              shift,
                                              dt,
                                              T,
                                              gamma=gamma)
        apply_fn = jit(apply_fn)

        state = init_fn(key, R, mass)

        sim_t = f32(BROWNIAN_DYNAMICS_STEPS * dt)
        for _ in range(BROWNIAN_DYNAMICS_STEPS):
            state = apply_fn(state)

        # np.var pools every coordinate of every particle, so it measures the
        # per-coordinate spread; this does not depend on spatial_dimension.
        msd = np.var(state.position)
        th_msd = dtype(2 * T / (mass * gamma) * sim_t)
        assert np.abs(msd - th_msd) / msd < 1e-2
        assert state.position.dtype == dtype
# Example No. 2
0
def run_brownian(energy_fn,
                 R_init,
                 shift,
                 key,
                 num_steps,
                 kT,
                 dt=0.00001,
                 gamma=0.1) -> jnp.ndarray:
    """Simulate Brownian motion and return the final particle positions.

    Args:
      energy_fn: energy function handed to ``simulate.brownian``.
      R_init: initial particle positions.
      shift: shift function used by the integrator to move particles
        (e.g. the one returned by ``space.free()``).
      key: JAX PRNG key; a subkey split from it seeds the simulation.
      num_steps: number of integration steps; must be a concrete Python int
        because it fixes the length of the ``lax.scan`` loop.
      kT: temperature, passed to ``simulate.brownian`` as ``kT``.
      dt: integration time step.
      gamma: friction coefficient, passed to ``simulate.brownian``.

    Returns:
      The ``position`` field of the simulation state after ``num_steps`` steps.
    """
    init, apply = simulate.brownian(energy_fn,
                                    shift,
                                    dt=dt,
                                    kT=kT,
                                    gamma=gamma)

    def scan_fn(state, _):
        # Advance one step. Emit no per-step output (None), so scan does not
        # build an unused length-`num_steps` array.
        return apply(state), None

    _, split = random.split(key)
    state = init(split, R_init)
    # lax.scan traces the body once and compiles the whole loop, so
    # wrapping `apply` or `scan_fn` in jit would be redundant.
    state, _ = lax.scan(scan_fn, state, None, length=num_steps)

    return state.position
# Example No. 3
0
def run(N=32, n_iter=1000, with_jit=True):
    """Time ``n_iter`` Brownian-dynamics steps of ``N`` soft spheres in 2D.

    Args:
      N: number of particles; positions are drawn uniformly in [0, 1)^2.
      n_iter: number of simulation steps to time.
      with_jit: if True, wrap the step function in ``jax.jit``. The first
        timed step then also includes compilation, since ``jit`` compiles
        on first call.

    Returns:
      A list with the wall-clock duration of each step, in nanoseconds.
    """
    # `time` was used below but never imported; import it with the others.
    import time

    import jax.numpy as jnp
    from jax import random, jit
    from jax_md import space, energy, simulate

    # MD configs
    dt = 1e-1
    temperature = 0.1

    # R: current position
    # dR: displacement
    # displacement(Ra, Rb):
    #   dR = Ra - Rb
    # periodic displacement(Ra, Rb):
    #   dR = Ra - Rb
    #   np.mod(dR + side * f32(0.5), side) - f32(0.5) * side
    # periodic shift:
    #   np.mod(R + dR, side)
    # shift:
    #   R + dR
    displacement, shift = space.free()

    # Simulation init
    # dr: pairwise distances
    # epsilon: interaction energy scale (const)
    # alpha: interaction stiffness
    # dr = distance(R)
    # U(dr) = np.where(dr < 1.0, (1 - dr) ** 2, 0)
    #   (rough sketch only -- confirm constant prefactors against
    #    energy.soft_sphere)
    # energy_fn(R) = diagonal_mask(U(dr))
    energy_fn = energy.soft_sphere_pair(displacement)

    # force(energy) = -d(energy)/dR
    # xi = random.normal(R.shape, R.dtype)
    # gamma = 0.1
    # nu = 1 / (mass * gamma)
    # dR = force(R) * dt * nu + np.sqrt(2 * temperature * dt * nu) * xi
    # BrownianState(position, mass, rng)
    pos_key, sim_key = random.split(random.PRNGKey(0))
    R = random.uniform(pos_key, (N, 2), dtype=jnp.float32)
    init_fn, apply_fn = simulate.brownian(energy_fn, shift, dt, temperature)
    if with_jit:
        apply_fn = jit(apply_fn)
    state = init_fn(sim_key, R)

    # Start simulation
    times = []
    for _ in range(n_iter):
        time_start = time.perf_counter_ns()
        state = apply_fn(state)
        # JAX dispatches work asynchronously. Without waiting for the result,
        # the timer would measure only how long it takes to enqueue the step.
        state.position.block_until_ready()
        time_end = time.perf_counter_ns()
        times.append(time_end - time_start)

    # Finish with profiling times
    return times