Ejemplo n.º 1
0
  def test_nvt_langevin(self, spatial_dimension, dtype):
    """Check that `simulate.nvt_langevin` samples the target temperature.

    For each of STOCHASTIC_SAMPLES trials, LANGEVIN_PARTICLE_COUNT particles
    with random positions and random masses in [0.1, 10.0) evolve under the
    harmonic energy ``sum((R - R0) ** 2)`` with a Langevin thermostat at a
    random temperature ``T`` drawn from [0.3, 1.4). After
    ``equilibration_steps`` steps, ``quantity.temperature`` is sampled every
    ``sample_every`` steps. The mean of those samples must lie within 0.1 of
    ``T``, and the position dtype must be preserved.
    """
    key = random.PRNGKey(0)
    # space.free() takes no arguments, so it does not depend on the trial.
    _, shift = space.free()

    equilibration_steps = 4000  # steps discarded before temperature sampling
    sample_every = 100          # sampling interval after equilibration

    for _ in range(STOCHASTIC_SAMPLES):
      # Split off a dedicated key for init_fn. `key` is kept only as the carry
      # for the next split, so no PRNG key is consumed twice.
      key, R_key, R0_key, T_key, masses_key, init_key = random.split(key, 6)

      R = random.normal(
        R_key, (LANGEVIN_PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      R0 = random.normal(
        R0_key, (LANGEVIN_PARTICLE_COUNT, spatial_dimension), dtype=dtype)

      # partial binds R0 now, avoiding late binding to the loop variable.
      E = functools.partial(
          lambda R, R0, **kwargs: np.sum((R - R0) ** 2), R0=R0)

      T = random.uniform(T_key, (), minval=0.3, maxval=1.4, dtype=dtype)
      mass = random.uniform(
        masses_key, (LANGEVIN_PARTICLE_COUNT,), minval=0.1, maxval=10.0, dtype=dtype)
      init_fn, apply_fn = simulate.nvt_langevin(E, shift, f32(1e-2), T, gamma=f32(0.3))
      apply_fn = jit(apply_fn)

      state = init_fn(init_key, R, mass=mass, T_initial=dtype(1.0))

      T_list = []
      for step in range(LANGEVIN_DYNAMICS_STEPS):
        state = apply_fn(state)
        if step > equilibration_steps and step % sample_every == 0:
          T_list.append(quantity.temperature(state.velocity, state.mass))

      # Without samples the mean below would be nan and the failure opaque.
      assert T_list, (
          'No temperature samples collected; LANGEVIN_DYNAMICS_STEPS must '
          'exceed the equilibration period.')
      T_emp = np.mean(np.array(T_list))
      assert np.abs(T_emp - T) < 0.1
      assert state.position.dtype == dtype
Ejemplo n.º 2
0
    def test_langevin_harmonic(self):
        """Compare an ensemble of 1D harmonic Langevin walkers with theory.

        ``num_walkers`` independent particles of mass ``mass`` each start at
        x = 1 in the potential ``0.5 * alpha * x**2``. They are advanced
        ``num_steps`` times with step ``dt``. At ``t = num_steps * dt`` the
        ensemble-mean position and the velocity autocorrelation
        ``mean(v(t) * v(0))`` must match the closed-form predictions within
        ``tol`` (both relative and absolute).
        """
        alpha = 1.0
        energy_fn = lambda x: jnp.sum(0.5 * alpha * x**2)
        _, shift = space.free()

        num_walkers = 10000
        num_steps = 1000
        kT = 0.25
        dt = 1e-4
        gamma = 3
        mass = 2.0
        tol = 1e-3

        # One PRNG key and one (1, 1) position per walker, all starting at 1.
        positions = jnp.ones((num_walkers, 1, 1))
        keys = random.split(random.PRNGKey(0), num_walkers)

        # NOTE(review): the trailing positional False is forwarded unchanged;
        # presumably it toggles velocity centering -- confirm against the
        # simulate.nvt_langevin signature.
        init, step_single = simulate.nvt_langevin(energy_fn, shift, dt, kT,
                                                  gamma, False)
        step_batch = jit(vmap(step_single))

        # Keys and positions are mapped per walker; mass is shared.
        state = vmap(init, (0, 0, None))(keys, positions, mass)
        initial_velocity = state.velocity

        for _ in range(num_steps):
            state = step_batch(state)

        # beta_fast / beta_slow are the decay rates -r of the roots r of
        # r**2 + gamma * r + alpha / mass = 0.
        d = jnp.sqrt(gamma**2 / 4 - alpha / mass)
        beta_fast = gamma / 2 + d
        beta_slow = gamma / 2 - d
        gap = beta_fast - beta_slow
        A = -beta_slow / gap
        B = beta_fast / gap
        Z = kT / (2 * d * mass)

        t = num_steps * dt
        decay_fast = jnp.exp(-beta_fast * t)
        decay_slow = jnp.exp(-beta_slow * t)
        expected_position = A * decay_fast + B * decay_slow
        expected_vacf = Z * (beta_fast * decay_fast - beta_slow * decay_slow)

        self.assertAllClose(jnp.mean(state.position),
                            expected_position,
                            rtol=tol,
                            atol=tol)
        self.assertAllClose(jnp.mean(state.velocity * initial_velocity),
                            expected_vacf,
                            rtol=tol,
                            atol=tol)