Ejemplo n.º 1
0
    def apply_fn(state, **kwargs):
        """Advance `state` by one step with barostat and thermostat chains.

        The chains take a half step (barostat first, then thermostat), the
        particles and box are advanced by `inner_step`, the kinetic energies
        stored on the chains are refreshed from the new velocities, and the
        chains take a second half step in the reverse order (thermostat
        first, then barostat).

        Args:
          state: Simulation state; reads `barostat`, `thermostat`,
            `velocity`, `box_velocity`, `mass` and `box_mass`.
          **kwargs: Forwarded to `inner_step`. A `kT` entry, when present,
            overrides the closed-over `kT`.

        Returns:
          A new state with updated chains, velocity and box velocity.
        """
        temperature = kwargs['kT'] if 'kT' in kwargs else kT

        # Re-derive chain masses and the box mass for the active temperature.
        baro_chain = barostat.update_mass(state.barostat, temperature)
        thermo_chain = thermostat.update_mass(state.thermostat, temperature)
        current = update_box_mass(state, temperature)

        # First half step: barostat, then thermostat.
        box_vel, baro_chain = barostat.half_step(
            current.box_velocity, baro_chain, temperature)
        vel, thermo_chain = thermostat.half_step(
            current.velocity, thermo_chain, temperature)

        current = inner_step(
            dataclasses.replace(current, velocity=vel, box_velocity=box_vel),
            **kwargs)

        # The second chain half step needs the kinetic energies after the
        # inner update, so store them on the chains first.
        particle_ke = quantity.kinetic_energy(current.velocity, current.mass)
        thermo_chain = dataclasses.replace(thermo_chain,
                                           kinetic_energy=particle_ke)
        box_ke = quantity.kinetic_energy(current.box_velocity, current.box_mass)
        baro_chain = dataclasses.replace(baro_chain, kinetic_energy=box_ke)

        # Second half step in reverse order: thermostat, then barostat.
        vel, thermo_chain = thermostat.half_step(
            current.velocity, thermo_chain, temperature)
        box_vel, baro_chain = barostat.half_step(
            current.box_velocity, baro_chain, temperature)

        return dataclasses.replace(current,
                                   thermostat=thermo_chain,
                                   barostat=baro_chain,
                                   velocity=vel,
                                   box_velocity=box_vel)
Ejemplo n.º 2
0
  def assertGraphTuplesClose(self, a, b, tol=1e-6):
    """Assert two graph tuples are close, ignoring masked-out edges.

    Edge rows whose `edge_idx` entry is not less than the number of nodes
    (i.e. that do not point at a real node; presumably padding) are zeroed
    in both graphs before comparison. Arbitrary values in those rows
    therefore never cause a failure.

    Args:
      a, b: Dataclass graph tuples exposing `nodes`, `edges` and `edge_idx`.
      tol: Absolute and relative tolerance forwarded to `assertAllClose`.
    """
    def _mask_padding_edges(graph):
      # Trailing singleton axis lets the mask broadcast over edge features.
      mask = (graph.edge_idx < graph.nodes.shape[0]).reshape(
          graph.edge_idx.shape + (1,))
      return dataclasses.replace(graph, edges=graph.edges * mask)

    a = dataclasses.asdict(_mask_padding_edges(a))
    b = dataclasses.asdict(_mask_padding_edges(b))

    # Forward `tol`; it used to be accepted but ignored.
    self.assertAllClose(a, b, atol=tol, rtol=tol)
Ejemplo n.º 3
0
    def __call__(self, graph: GraphsTuple) -> GraphsTuple:
        """Apply the configured edge, node and global functions in turn.

        Each function is optional; a `None` entry leaves that field as is.
        A present function receives the whole, partially-updated graph, and
        its result replaces the matching field. So the node function sees
        the new edges, and the global function sees the new edges and nodes.
        """
        # (field to replace, attribute holding its update function), in order.
        update_order = (('edges', '_edge_fn'),
                        ('nodes', '_node_fn'),
                        ('globals', '_global_fn'))
        for field_name, attr_name in update_order:
            update_fn = getattr(self, attr_name)
            if update_fn is None:
                continue
            graph = dataclasses.replace(graph, **{field_name: update_fn(graph)})
        return graph
Ejemplo n.º 4
0
 def embed_fn(graph):
   """Return a copy of `graph` with each field transformed independently.

   Nodes, edges and globals go through the closed-over `_node_fn`,
   `_edge_fn` and `_global_fn` respectively. Each function only sees its
   own field.
   """
   new_nodes = _node_fn(graph.nodes)
   new_edges = _edge_fn(graph.edges)
   new_globals = _global_fn(graph.globals)
   return dataclasses.replace(graph,
                              nodes=new_nodes,
                              edges=new_edges,
                              globals=new_globals)
Ejemplo n.º 5
0
    def apply_fn(state, **kwargs):
        """Advance `state` by one velocity-Verlet step wrapped in a chain.

        The chain takes a half step on the velocities, `velocity_verlet`
        advances positions/velocities/forces, the chain's kinetic energy is
        refreshed from the new velocities, and the chain takes a second half
        step.

        Args:
          state: Simulation state; reads `chain`, `velocity` and `mass`.
          **kwargs: Forwarded to `velocity_verlet`. A `kT` entry, when
            present, overrides the closed-over `kT`.

        Returns:
          A new state with the updated chain and velocities.
        """
        temperature = kwargs['kT'] if 'kT' in kwargs else kT

        thermo = chain_fns.update_mass(state.chain, temperature)

        vel, thermo = chain_fns.half_step(state.velocity, thermo, temperature)
        stepped = velocity_verlet(force_fn, shift_fn, dt,
                                  dataclasses.replace(state, velocity=vel),
                                  **kwargs)

        # The second half step uses the kinetic energy after the MD update.
        kinetic = quantity.kinetic_energy(stepped.velocity, stepped.mass)
        thermo = dataclasses.replace(thermo, kinetic_energy=kinetic)

        vel, thermo = chain_fns.half_step(stepped.velocity, thermo, temperature)
        return dataclasses.replace(stepped, velocity=vel, chain=thermo)
Ejemplo n.º 6
0
 def test_connect_graph_network(self, network_fn, dtype):
     """Check `network_fn` reproduces each test graph, with and without jit.

     Each graph from `_get_graphs()` has its nodes, edges and globals cast
     to `dtype`. The network output is then compared against that cast
     graph.
     """
     for graph in _get_graphs():
         graph = dataclasses.replace(
             graph,
             nodes=np.array(graph.nodes, dtype),
             edges=np.array(graph.edges, dtype),
             globals=np.array(graph.globals, dtype))
         # Run once as-is and once jit-compiled, each in its own subTest.
         for label, wrap in (('nojit', lambda fn: fn), ('jit', jit)):
             with self.subTest(label):
                 out = wrap(network_fn)(graph)
                 self.assertGraphTuplesClose(out, graph)
Ejemplo n.º 7
0
def velocity_verlet(force_fn: Callable[..., Array], shift_fn: ShiftFn,
                    dt: float, state: T, **kwargs) -> T:
    """Apply a single step of velocity Verlet integration to a state.

    Positions are shifted using the current velocity and force. The force is
    re-evaluated at the new positions. The velocity is then updated with the
    mean of the old and new forces.

    Args:
      force_fn: Called as ``force_fn(R, **kwargs)`` at the new positions.
      shift_fn: Called as ``shift_fn(R, dR, **kwargs)`` to displace positions.
      dt: Time step; it and the derived step constants are cast with ``f32``.
      state: Dataclass exposing ``position``, ``velocity``, ``force`` and
        ``mass``.
      **kwargs: Forwarded to both ``shift_fn`` and ``force_fn``.

    Returns:
      A copy of ``state`` with ``position``, ``velocity`` and ``force``
      replaced. The arrays held by the input ``state`` are not modified.
    """
    dt = f32(dt)
    dt_2 = f32(dt / 2)
    dt2_2 = f32(dt**2 / 2)

    R, V, F, M = state.position, state.velocity, state.force, state.mass

    Minv = 1 / M

    R = shift_fn(R, V * dt + F * dt2_2 * Minv, **kwargs)
    F_new = force_fn(R, **kwargs)
    # Build a new array instead of `V += ...`. With mutable (e.g. NumPy)
    # arrays an in-place add would also overwrite `state.velocity`.
    V = V + (F + F_new) * dt_2 * Minv
    return dataclasses.replace(state, position=R, velocity=V, force=F_new)
Ejemplo n.º 8
0
    def inner_step(state, **kwargs):
        """Advance particle and box degrees of freedom by one inner step.

        Sequence, as written below:
        1. Half-kick the box velocity with the box force.
        2. Update particle velocities via `exp_iL2`.
        3. Move the box position by a full `dt` and recompute the box.
        4. Move particles via `exp_iL1`, then recompute forces in the new box.
        5. Repeat the particle-velocity and box-velocity updates in reverse
           order.

        Args:
          state: Simulation state; reads `position`, `velocity`, `mass`,
            `force`, `box_position`, `box_velocity` and `box_mass`.
          **kwargs: An optional `pressure` entry overrides the closed-over
            `pressure`. Remaining entries are forwarded to `box_force` and
            `force_fn`.

        Returns:
          A new state with updated position, velocity, force, box position
          and box velocity. `mass` and `box_mass` are written back unchanged.
        """
        # `pop` keeps `pressure` out of the kwargs forwarded below; kwargs is
        # this call's own dict, so the caller's mapping is not affected.
        _pressure = kwargs.pop('pressure', pressure)

        R, V, M, F = state.position, state.velocity, state.mass, state.force
        R_b, V_b, M_b = state.box_position, state.box_velocity, state.box_mass

        # Unpacking requires R to be 2-D; `dim` is not used below.
        N, dim = R.shape

        vol, box_fn = _npt_box_info(state)

        # NOTE(review): presumably the MTK factor 1 + d/N_f with N_f = d*N;
        # confirm against the reference integrator.
        alpha = 1 + 1 / N
        # First half: half-kick the box velocity, then update particle
        # velocities with the current forces and box velocity.
        G_e = box_force(alpha, vol, box_fn, R, V, M, F, _pressure, **kwargs)
        V_b = V_b + dt_2 * G_e / M_b
        V = exp_iL2(alpha, V, F / M, V_b)

        # Full step of the box position. The state is updated first because
        # `_npt_box_info` takes the state (presumably reading box_position).
        R_b = R_b + V_b * dt
        state = dataclasses.replace(state, box_position=R_b)

        vol, box_fn = _npt_box_info(state)

        # Move particles in the updated box and recompute forces there.
        box = box_fn(vol)
        R = exp_iL1(box, R, V, V_b)
        F = force_fn(R, box=box, **kwargs)

        # Second half, mirroring the first in reverse order with new forces.
        V = exp_iL2(alpha, V, F / M, V_b)
        G_e = box_force(alpha, vol, box_fn, R, V, M, F, _pressure, **kwargs)
        V_b = V_b + dt_2 * G_e / M_b

        return dataclasses.replace(state,
                                   position=R,
                                   velocity=V,
                                   mass=M,
                                   force=F,
                                   box_position=R_b,
                                   box_velocity=V_b,
                                   box_mass=M_b)
Ejemplo n.º 9
0
 def update_box_mass(state, kT):
     """Return `state` with `box_mass` recomputed for temperature `kT`.

     The mass is ``dim * (N + 1) * kT * tau**2``, where ``(N, dim)`` is the
     shape of ``state.position`` and ``tau`` is ``state.barostat.tau``. It is
     stored as an array with the dtype of ``state.position``.
     """
     positions = state.position
     n_particles, n_dims = positions.shape
     tau = state.barostat.tau
     mass = n_dims * (n_particles + 1) * kT * tau**2
     return dataclasses.replace(state,
                                box_mass=jnp.array(mass, positions.dtype))