Ejemplo n.º 1
0
    def init_fn(key, R, box, mass=f32(1.0), **kwargs):
        """Initialize an NPT Nose-Hoover state.

        Args:
          key: PRNG key used to draw the initial velocities.
          R: Particle positions, unpacked below as `(N, dim)`.
          box: Simulation box. A scalar (or 0-d array) is expanded to a
            `dim x dim` diagonal box matrix.
          mass: Particle mass(es), normalized via `quantity.canonicalize_mass`.
          **kwargs: Forwarded to `force_fn`. A `kT` entry overrides the
            enclosing temperature for this initialization.

        Returns:
          An `NPTNoseHooverState` with zero box position/velocity and freshly
          initialized barostat and thermostat chains.
        """
        N, dim = R.shape

        _kT = kT if 'kT' not in kwargs else kwargs['kT']

        mass = quantity.canonicalize_mass(mass)
        V = jnp.sqrt(_kT / mass) * random.normal(key, R.shape, dtype=R.dtype)
        # Remove the net momentum: sum_i m_i (v_i - <m v> / m_i) == 0.
        V = V - jnp.mean(V * mass, axis=0, keepdims=True) / mass
        KE = quantity.kinetic_energy(V, mass)

        # The box position is defined via pos = (1 / d) log V / V_0.
        zero = jnp.zeros((), dtype=R.dtype)
        one = jnp.ones((), dtype=R.dtype)
        box_position = zero
        box_velocity = zero
        # Use the same (possibly overridden) temperature that the barostat and
        # thermostat chains below are initialized with, so all masses agree.
        box_mass = dim * (N + 1) * _kT * barostat_kwargs['tau']**2 * one
        KE_box = quantity.kinetic_energy(box_velocity, box_mass)

        if jnp.isscalar(box) or box.ndim == 0:
            # TODO(schsam): This is necessary because of JAX issue #5849.
            box = jnp.eye(R.shape[-1]) * box

        return NPTNoseHooverState(R, V, force_fn(R, box=box, **kwargs), mass,
                                  box, box_position, box_velocity, box_mass,
                                  barostat.initialize(1, KE_box, _kT),
                                  thermostat.initialize(R.size, KE, _kT))  # pytype: disable=wrong-arg-count
Ejemplo n.º 2
0
 def init_fun(key: Array,
              R: Array,
              velocity_scale: float = f32(1.0),
              mass=f32(1.0),
              **kwargs) -> NVEState:
     """Build the initial NVE state.

     Velocities are standard-normal samples scaled by sqrt(velocity_scale).
     The third state field is `force(R, **kwargs)` divided by the
     canonicalized mass.
     """
     scale = np.sqrt(velocity_scale)
     noise = random.normal(key, R.shape, dtype=R.dtype)
     velocity = scale * noise
     m = quantity.canonicalize_mass(mass)
     accel = force(R, **kwargs) / m
     return NVEState(R, velocity, accel, m)  # pytype: disable=wrong-arg-count
Ejemplo n.º 3
0
    def init_fn(key, R, mass=f32(1), **kwargs):
        """Initialize a Langevin (NVT) state.

        Args:
          key: PRNG key. It is split once: one half draws the initial
            velocities and the other is stored in the returned state.
          R: Particle positions.
          mass: Particle mass(es), normalized via `quantity.canonicalize_mass`.
          **kwargs: Forwarded to `force_fn`. A `kT` entry overrides the
            enclosing temperature used to draw velocities.

        Returns:
          `NVTLangevinState(R, V, force_fn(R, **kwargs), mass, key)`.
        """
        _kT = kT if 'kT' not in kwargs else kwargs['kT']
        mass = quantity.canonicalize_mass(mass)

        key, split = random.split(key)

        V = np.sqrt(_kT / mass) * random.normal(split, R.shape, dtype=R.dtype)
        # Zero the total momentum. Subtracting the plain mean velocity only
        # achieves this when every particle has the same mass.
        V = V - np.mean(V * mass, axis=0, keepdims=True) / mass

        return NVTLangevinState(R, V, force_fn(R, **kwargs), mass, key)  # pytype: disable=wrong-arg-count
Ejemplo n.º 4
0
    def init_fn(key, R, mass=f32(1), T_initial=f32(1)):
        """Initialize a Langevin (NVT) state.

        Args:
          key: PRNG key. It is split once: one half draws the initial
            velocities and the other is stored in the returned state.
          R: Particle positions.
          mass: Particle mass(es), normalized via `quantity.canonicalize_mass`.
          T_initial: Initial temperature. Velocities are drawn with variance
            `T_initial / mass`.

        Returns:
          `NVTLangevinState(R, V, force_fn(R, t=0), mass, key)`.
        """
        mass = quantity.canonicalize_mass(mass)

        key, split = random.split(key)

        V = np.sqrt(T_initial / mass) * random.normal(
            split, R.shape, dtype=R.dtype)
        # Zero the total momentum. Subtracting the plain mean velocity only
        # achieves this when every particle has the same mass.
        V = V - np.mean(V * mass, axis=0, keepdims=True) / mass

        # The initial force is evaluated at time t = 0.
        return NVTLangevinState(R, V, force_fn(R, t=f32(0)), mass, key)
Ejemplo n.º 5
0
    def init_fn(key, R, mass=f32(1.0), **kwargs):
        """Initialize an NVT Nose-Hoover chain state.

        Velocities are Gaussian with variance `kT / mass`, where `kT` may be
        overridden through `kwargs`. They are then shifted so the total
        momentum is zero. The chain is initialized from `R.size` degrees of
        freedom and the resulting kinetic energy.
        """
        temperature = kwargs['kT'] if 'kT' in kwargs else kT

        m = quantity.canonicalize_mass(mass)
        sigma = jnp.sqrt(temperature / m)
        V = sigma * random.normal(key, R.shape, dtype=R.dtype)
        net_momentum = jnp.mean(V * m, axis=0, keepdims=True)
        V = V - net_momentum / m
        kinetic = quantity.kinetic_energy(V, m)

        forces = force_fn(R, **kwargs)
        chain = chain_fns.initialize(R.size, kinetic, temperature)
        return NVTNoseHooverState(R, V, forces, m, chain)  # pytype: disable=wrong-arg-count
Ejemplo n.º 6
0
    def init_fn(key, R, mass=f32(1), T_initial=None, **kwargs):
        """Initialize a Langevin (NVT) state.

        Args:
          key: PRNG key. It is split once: one half draws the initial
            velocities and the other is stored in the returned state.
          R: Particle positions.
          mass: Particle mass(es), normalized via `quantity.canonicalize_mass`.
          T_initial: Initial temperature. Defaults to `T_schedule(0.0)`.
            Velocities are drawn with variance `T_initial / mass`.
          **kwargs: Forwarded to `force_fn` together with `t=0`.

        Returns:
          `NVTLangevinState(R, V, force, mass, key)`.
        """
        if T_initial is None:
            T_initial = T_schedule(0.0)

        mass = quantity.canonicalize_mass(mass)

        key, split = random.split(key)

        V = np.sqrt(T_initial / mass) * random.normal(
            split, R.shape, dtype=R.dtype)
        # Zero the total momentum. Subtracting the plain mean velocity only
        # achieves this when every particle has the same mass.
        V = V - np.mean(V * mass, axis=0, keepdims=True) / mass

        return NVTLangevinState(R, V, force_fn(R, t=f32(0), **kwargs), mass,
                                key)  # pytype: disable=wrong-arg-count
Ejemplo n.º 7
0
    def init_fun(key, R, mass=f32(1.0), T_initial=f32(1.0)):
        """Initialize a Nose-Hoover chain (NVT) state.

        Args:
          key: PRNG key used to draw the initial velocities.
          R: Particle positions.
          mass: Particle mass(es), normalized via `quantity.canonicalize_mass`.
          T_initial: Initial temperature. Velocities are drawn with variance
            `T_initial / mass`, and the chain masses `Q` are proportional to it.

        Returns:
          `NVTNoseHooverState(R, V, mass, KE, xi, v_xi, Q)`, with the chain
          variables `xi` and `v_xi` set to zero.
        """
        mass = quantity.canonicalize_mass(mass)
        V = np.sqrt(T_initial / mass) * random.normal(
            key, R.shape, dtype=R.dtype)
        # Zero the total momentum. Subtracting the plain mean velocity only
        # achieves this when every particle has the same mass.
        V = V - np.mean(V * mass, axis=0, keepdims=True) / mass
        KE = quantity.kinetic_energy(V, mass)

        # Nose-Hoover parameters.
        xi = np.zeros(chain_length, R.dtype)
        v_xi = np.zeros(chain_length, R.dtype)

        # The first chain mass is additionally scaled by the degree-of-freedom
        # count R.shape[0] * R.shape[1].
        DOF, = static_cast(R.shape[0] * R.shape[1])
        Q = T_initial * tau**f32(2) * np.ones(chain_length, dtype=R.dtype)
        Q = ops.index_update(Q, 0, Q[0] * DOF)

        return NVTNoseHooverState(R, V, mass, KE, xi, v_xi, Q)
Ejemplo n.º 8
0
    def init_fn(key, R, mass=f32(1.0), **kwargs):
        """Initialize a Nose-Hoover chain (NVT) state.

        Args:
          key: PRNG key used to draw the initial velocities.
          R: Particle positions.
          mass: Particle mass(es), normalized via `quantity.canonicalize_mass`.
          **kwargs: Forwarded to `force_fn`. A `kT` entry overrides the
            enclosing temperature used for velocities and chain masses.

        Returns:
          `NVTNoseHooverState(R, V, F, mass, KE, xi, v_xi, Q)`, with the chain
          variables `xi` and `v_xi` set to zero.
        """
        _kT = kT if 'kT' not in kwargs else kwargs['kT']

        mass = quantity.canonicalize_mass(mass)
        V = np.sqrt(_kT / mass) * random.normal(key, R.shape, dtype=R.dtype)
        # Zero the total momentum. Subtracting the plain mean velocity only
        # achieves this when every particle has the same mass.
        V = V - np.mean(V * mass, axis=0, keepdims=True) / mass
        KE = quantity.kinetic_energy(V, mass)

        # Nose-Hoover parameters.
        xi = np.zeros(chain_length, R.dtype)
        v_xi = np.zeros(chain_length, R.dtype)

        # TODO(schsam): Really, it seems like Q should be set by the goal
        # temperature rather than the initial temperature.
        # The first chain mass is additionally scaled by the degree-of-freedom
        # count R.shape[0] * R.shape[1].
        DOF = f32(R.shape[0] * R.shape[1])
        Q = _kT * tau**f32(2) * np.ones(chain_length, dtype=R.dtype)
        Q = ops.index_update(Q, 0, Q[0] * DOF)

        F = force_fn(R, **kwargs)

        return NVTNoseHooverState(R, V, F, mass, KE, xi, v_xi, Q)  # pytype: disable=wrong-arg-count
Ejemplo n.º 9
0
    def init_fun(key, R, mass=f32(1.0), T_initial=None):
        """Initialize a Nose-Hoover chain (NVT) state.

        Args:
          key: PRNG key used to draw the initial velocities.
          R: Particle positions.
          mass: Particle mass(es), normalized via `quantity.canonicalize_mass`.
          T_initial: Initial temperature. Defaults to `T_schedule(0.0)`.
            Velocities are drawn with variance `T_initial / mass`, and the
            chain masses `Q` are proportional to it.

        Returns:
          `NVTNoseHooverState(R, V, mass, KE, xi, v_xi, Q)`, with the chain
          variables `xi` and `v_xi` set to zero.
        """
        if T_initial is None:
            T_initial = T_schedule(0.0)

        mass = quantity.canonicalize_mass(mass)
        V = np.sqrt(T_initial / mass) * random.normal(
            key, R.shape, dtype=R.dtype)
        # Zero the total momentum. Subtracting the plain mean velocity only
        # achieves this when every particle has the same mass.
        V = V - np.mean(V * mass, axis=0, keepdims=True) / mass
        KE = quantity.kinetic_energy(V, mass)

        # Nose-Hoover parameters.
        xi = np.zeros(chain_length, R.dtype)
        v_xi = np.zeros(chain_length, R.dtype)

        # TODO(schsam): Really, it seems like Q should be set by the goal
        # temperature rather than the initial temperature.
        # The first chain mass is additionally scaled by the degree-of-freedom
        # count R.shape[0] * R.shape[1].
        DOF, = static_cast(R.shape[0] * R.shape[1])
        Q = T_initial * tau**f32(2) * np.ones(chain_length, dtype=R.dtype)
        Q = ops.index_update(Q, 0, Q[0] * DOF)

        return NVTNoseHooverState(R, V, mass, KE, xi, v_xi, Q)  # pytype: disable=wrong-arg-count
Ejemplo n.º 10
0
    def init_fn(key, R, mass=f32(1)):
        """Wrap positions, canonicalized mass and the PRNG key in a BrownianState."""
        return BrownianState(R, quantity.canonicalize_mass(mass), key)  # pytype: disable=wrong-arg-count
Ejemplo n.º 11
0
    def init_fn(key, R, mass=f32(1)):
        """Return the initial BrownianState for positions `R`.

        The mass is normalized via `quantity.canonicalize_mass`, and `key` is
        stored unchanged in the state.
        """
        canonical_mass = quantity.canonicalize_mass(mass)
        state = BrownianState(R, canonical_mass, key)
        return state
Ejemplo n.º 12
0
 def init_fun(key, R, velocity_scale=f32(1.0), mass=f32(1.0), **kwargs):
     """Build the initial NVE state.

     Args:
       key: PRNG key used to draw the initial velocities.
       R: Particle positions.
       velocity_scale: Variance of the Gaussian initial velocities.
       mass: Particle mass(es), normalized via `quantity.canonicalize_mass`.
       **kwargs: Forwarded to `force`. This generalizes the original
         signature; calls without extra arguments behave exactly as before.

     Returns:
       `NVEState(R, V, force(R, **kwargs) / mass, mass)`.
     """
     V = np.sqrt(velocity_scale) * random.normal(
         key, R.shape, dtype=R.dtype)
     mass = quantity.canonicalize_mass(mass)
     return NVEState(R, V, force(R, **kwargs) / mass, mass)
Ejemplo n.º 13
0
 def test_canonicalize_mass(self):
     # Python floats and float32/float64 scalars must all come back from
     # canonicalize_mass comparing equal to the value passed in.
     for value in (3.0, f32(3.0), f64(3.0)):
         assert quantity.canonicalize_mass(value) == value
Ejemplo n.º 14
0
 def init_fn(key, R, kT, mass=f32(1.0), **kwargs):
     """Build the initial NVE state.

     Velocities are Gaussian with variance `kT / mass` and are then shifted
     so the total momentum vanishes. Forces come from `force_fn(R, **kwargs)`.
     """
     m = quantity.canonicalize_mass(mass)
     sigma = jnp.sqrt(kT / m)
     V = sigma * random.normal(key, R.shape, dtype=R.dtype)
     drift = jnp.mean(V * m, axis=0, keepdims=True) / m
     V = V - drift
     F = force_fn(R, **kwargs)
     return NVEState(R, V, F, m)  # pytype: disable=wrong-arg-count
Ejemplo n.º 15
0
 def init_fn(key, R, kT, mass=f32(1.0), **kwargs):
     """Build the initial NVE state.

     Args:
       key: PRNG key used to draw the initial velocities.
       R: Particle positions.
       kT: Temperature. Velocities are drawn with variance `kT / mass`.
       mass: Particle mass(es), normalized via `quantity.canonicalize_mass`.
       **kwargs: Forwarded to `force_fn`.

     Returns:
       `NVEState(R, V, force_fn(R, **kwargs), mass)` with zero total momentum.
     """
     mass = quantity.canonicalize_mass(mass)
     V = np.sqrt(kT / mass) * random.normal(key, R.shape, dtype=R.dtype)
     # Zero the total momentum (conserved under NVE dynamics). Subtracting the
     # plain mean velocity only achieves this when all masses are equal.
     V = V - np.mean(V * mass, axis=0, keepdims=True) / mass
     return NVEState(R, V, force_fn(R, **kwargs), mass)