def apply_fn(state, **kwargs):
  """Advance a state with coupled barostat and thermostat by one step.

  The sequence is:
    1. Refresh the barostat, thermostat and box masses at the current ``kT``.
    2. Half-step the barostat, then the thermostat.
    3. Call ``inner_step``.
    4. Store the new particle and box kinetic energies on the thermostat and
       barostat states.
    5. Half-step the thermostat, then the barostat (reverse order of step 2).

  ``kT`` may be overridden per call through ``kwargs``. Otherwise the
  enclosing-scope ``kT`` is used. All ``kwargs`` are forwarded to
  ``inner_step``.

  Args:
    state: Dataclass carrying ``barostat``, ``thermostat``, ``velocity``,
      ``box_velocity``, ``mass`` and ``box_mass`` (plus whatever
      ``inner_step`` and ``update_box_mass`` read).
    **kwargs: Optional ``kT`` override plus extra arguments for ``inner_step``.

  Returns:
    A new state with updated velocities, box velocity, thermostat and barostat.
  """
  temperature = kwargs['kT'] if 'kT' in kwargs else kT

  baro = barostat.update_mass(state.barostat, temperature)
  thermo = thermostat.update_mass(state.thermostat, temperature)
  state = update_box_mass(state, temperature)

  # Opening half-steps: barostat first, then thermostat.
  box_vel, baro = barostat.half_step(state.box_velocity, baro, temperature)
  vel, thermo = thermostat.half_step(state.velocity, thermo, temperature)
  state = dataclasses.replace(state, velocity=vel, box_velocity=box_vel)

  state = inner_step(state, **kwargs)

  # The closing half-steps need the kinetic energies after the inner step.
  particle_ke = quantity.kinetic_energy(state.velocity, state.mass)
  thermo = dataclasses.replace(thermo, kinetic_energy=particle_ke)
  box_ke = quantity.kinetic_energy(state.box_velocity, state.box_mass)
  baro = dataclasses.replace(baro, kinetic_energy=box_ke)

  # Closing half-steps, mirrored order: thermostat first, then barostat.
  vel, thermo = thermostat.half_step(state.velocity, thermo, temperature)
  box_vel, baro = barostat.half_step(state.box_velocity, baro, temperature)

  return dataclasses.replace(
      state,
      thermostat=thermo,
      barostat=baro,
      velocity=vel,
      box_velocity=box_vel,
  )
def assertGraphTuplesClose(self, a, b, tol=1e-6):
  """Assert two graph dataclasses are numerically close, ignoring padded edges.

  An edge whose ``edge_idx`` entry is ``>= nodes.shape[0]`` is treated as
  padding. Its features are zeroed before comparison, so padded slots never
  cause spurious mismatches.

  Args:
    a: First graph (dataclass with ``nodes``, ``edges`` and ``edge_idx``).
    b: Second graph, same structure as ``a``.
    tol: Absolute and relative tolerance passed to ``assertAllClose``.
      Previously this argument was accepted but ignored.
  """
  def _mask_padded_edges(graph):
    # Zero the features of padded edges; leave graphs without edges untouched.
    if graph.edges is None:
      return graph
    valid = graph.edge_idx < graph.nodes.shape[0]
    # Append exactly as many singleton axes as ``edges`` has beyond
    # ``edge_idx``, so the mask broadcasts over feature dims without
    # enlarging ``edges`` when the two have the same rank.
    extra_dims = graph.edges.ndim - valid.ndim
    valid = valid.reshape(valid.shape + (1,) * max(extra_dims, 0))
    return dataclasses.replace(graph, edges=graph.edges * valid)

  a = dataclasses.asdict(_mask_padded_edges(a))
  b = dataclasses.asdict(_mask_padded_edges(b))
  self.assertAllClose(a, b, atol=tol, rtol=tol)
def __call__(self, graph: GraphsTuple) -> GraphsTuple:
  """Apply the configured edge, node and global update functions in turn.

  Each update function that is not ``None`` receives the whole graph as
  updated so far, and its result replaces the matching field. Edges are
  updated first, then nodes, then globals. A later function therefore sees
  the outputs of the earlier ones.

  Args:
    graph: Input graph.

  Returns:
    The graph with each configured field replaced.
  """
  stages = (
      ('edges', self._edge_fn),
      ('nodes', self._node_fn),
      ('globals', self._global_fn),
  )
  for field_name, update_fn in stages:
    if update_fn is None:
      continue
    graph = dataclasses.replace(graph, **{field_name: update_fn(graph)})
  return graph
def embed_fn(graph):
  """Map node, edge and global features independently.

  Each feature array is passed only to its own function (``_node_fn``,
  ``_edge_fn``, ``_global_fn``). No function sees the rest of the graph.

  Args:
    graph: Graph dataclass with ``nodes``, ``edges`` and ``globals`` fields.

  Returns:
    A copy of ``graph`` with the three fields replaced by the mapped values.
  """
  new_nodes = _node_fn(graph.nodes)
  new_edges = _edge_fn(graph.edges)
  new_globals = _global_fn(graph.globals)
  return dataclasses.replace(
      graph, nodes=new_nodes, edges=new_edges, globals=new_globals)
def apply_fn(state, **kwargs):
  """Advance a thermostatted state by one step.

  The sequence is:
    1. Refresh the chain masses at the current ``kT``.
    2. Half-step the chain on the particle velocities.
    3. Take one ``velocity_verlet`` step.
    4. Store the new kinetic energy on the chain.
    5. Half-step the chain again.

  ``kT`` may be overridden per call through ``kwargs``. All ``kwargs`` are
  forwarded to ``velocity_verlet``.

  Args:
    state: Dataclass with ``chain``, ``velocity`` and ``mass`` (plus the
      fields ``velocity_verlet`` reads).
    **kwargs: Optional ``kT`` override plus extra arguments for the step.

  Returns:
    A new state with updated velocities and chain.
  """
  temperature = kwargs['kT'] if 'kT' in kwargs else kT

  nh_chain = chain_fns.update_mass(state.chain, temperature)

  vel, nh_chain = chain_fns.half_step(state.velocity, nh_chain, temperature)
  # The integrator sees the half-stepped velocities; the chain itself is only
  # written back onto the state at the very end.
  state = velocity_verlet(
      force_fn, shift_fn, dt, dataclasses.replace(state, velocity=vel),
      **kwargs)

  kinetic = quantity.kinetic_energy(state.velocity, state.mass)
  nh_chain = dataclasses.replace(nh_chain, kinetic_energy=kinetic)

  vel, nh_chain = chain_fns.half_step(state.velocity, nh_chain, temperature)
  return dataclasses.replace(state, velocity=vel, chain=nh_chain)
def test_connect_graph_network(self, network_fn, dtype):
  """Check ``network_fn`` output matches its input, eagerly and under ``jit``.

  Each graph from ``_get_graphs()`` is first cast to ``dtype`` (nodes, edges
  and globals). The network output is then compared against that cast graph.
  """
  runners = (('nojit', lambda fn: fn), ('jit', jit))
  for graph in _get_graphs():
    graph = dataclasses.replace(
        graph,
        nodes=np.array(graph.nodes, dtype),
        edges=np.array(graph.edges, dtype),
        globals=np.array(graph.globals, dtype),
    )
    for label, wrap in runners:
      with self.subTest(label):
        result = wrap(network_fn)(graph)
        self.assertGraphTuplesClose(result, graph)
def velocity_verlet(force_fn: Callable[..., Array],
                    shift_fn: ShiftFn,
                    dt: float,
                    state: T,
                    **kwargs) -> T:
  """Apply a single step of velocity verlet integration to a state.

  With ``R``, ``V``, ``F``, ``M`` the position, velocity, force and mass
  stored on ``state``:

    R' = shift_fn(R, V * dt + F * dt**2 / 2 / M)
    F' = force_fn(R')
    V' = V + (F + F') * dt / 2 / M

  Args:
    force_fn: Called as ``force_fn(R', **kwargs)`` to get forces at the new
      positions.
    shift_fn: Called as ``shift_fn(R, dR, **kwargs)`` to displace positions.
    dt: Timestep. It and the derived half-step constants are cast with
      ``f32``.
    state: Dataclass with ``position``, ``velocity``, ``force`` and ``mass``.
    **kwargs: Forwarded to both ``shift_fn`` and ``force_fn``.

  Returns:
    A copy of ``state`` (via ``dataclasses.replace``) with new ``position``,
    ``velocity`` and ``force``. The input ``state`` is not modified.
  """
  dt = f32(dt)
  dt_2 = f32(dt / 2)
  dt2_2 = f32(dt**2 / 2)

  R, V, F, M = state.position, state.velocity, state.force, state.mass

  Minv = 1 / M

  R = shift_fn(R, V * dt + F * dt2_2 * Minv, **kwargs)
  F_new = force_fn(R, **kwargs)
  # Out-of-place update. `V += ...` would write into `state.velocity` itself
  # whenever it is a mutable (e.g. NumPy) array, silently altering the
  # caller's input state. For JAX arrays the two forms are equivalent.
  V = V + (F + F_new) * dt_2 * Minv
  return dataclasses.replace(state, position=R, velocity=V, force=F_new)
def inner_step(state, **kwargs):
  """Advance particle and box degrees of freedom by one step.

  Symmetric split update, in the order the code performs it:
    1. Kick the box velocity by ``dt_2 * G_e / M_b``.
    2. Update the particle velocities with ``exp_iL2``.
    3. Drift the box position by ``V_b * dt``.
    4. Rebuild the box.
    5. Move the particles with ``exp_iL1``.
    6. Recompute the forces.
    7. Repeat the particle-velocity update and the box-velocity kick.

  ``dt``, ``dt_2``, ``pressure``, ``force_fn``, ``box_force``, ``exp_iL1``,
  ``exp_iL2`` and ``_npt_box_info`` come from the enclosing scope. Presumably
  ``dt_2 == dt / 2``; confirm at the definition site.

  Args:
    state: Dataclass with ``position``, ``velocity``, ``mass``, ``force``,
      ``box_position``, ``box_velocity`` and ``box_mass`` fields.
    **kwargs: An optional ``pressure`` overrides the enclosing-scope
      ``pressure``. It is popped, so it is NOT forwarded. All remaining
      kwargs are forwarded to ``box_force`` and ``force_fn``.

  Returns:
    A copy of ``state`` with updated positions, velocities, forces, box
    position and box velocity. ``mass`` and ``box_mass`` are stored back
    unchanged.
  """
  # Pop so `pressure` is not passed through **kwargs to box_force/force_fn.
  _pressure = kwargs.pop('pressure', pressure)
  R, V, M, F = state.position, state.velocity, state.mass, state.force
  R_b, V_b, M_b = state.box_position, state.box_velocity, state.box_mass
  # Requires a rank-2 position array (N particles x dim); `dim` is unused here.
  N, dim = R.shape
  vol, box_fn = _npt_box_info(state)
  # Coupling factor handed to box_force and exp_iL2. Its physical meaning is
  # defined by those helpers (presumably the 1 + d/N_f factor of an
  # MTK-style barostat; TODO confirm).
  alpha = 1 + 1 / N
  # First box-velocity kick and particle-velocity update.
  G_e = box_force(alpha, vol, box_fn, R, V, M, F, _pressure, **kwargs)
  V_b = V_b + dt_2 * G_e / M_b
  V = exp_iL2(alpha, V, F / M, V_b)
  # Full-step drift of the box coordinate, then rebuild the box from it.
  R_b = R_b + V_b * dt
  state = dataclasses.replace(state, box_position=R_b)
  vol, box_fn = _npt_box_info(state)
  box = box_fn(vol)
  # Particle drift in the new box, then forces at the new positions.
  R = exp_iL1(box, R, V, V_b)
  F = force_fn(R, box=box, **kwargs)
  # Second particle-velocity update and box-velocity kick, mirroring the first.
  V = exp_iL2(alpha, V, F / M, V_b)
  G_e = box_force(alpha, vol, box_fn, R, V, M, F, _pressure, **kwargs)
  V_b = V_b + dt_2 * G_e / M_b
  return dataclasses.replace(state, position=R, velocity=V, mass=M, force=F,
                             box_position=R_b, box_velocity=V_b, box_mass=M_b)
def update_box_mass(state, kT):
  """Recompute the box mass from the particle count, dimension and ``kT``.

  box_mass = dim * (N + 1) * kT * tau**2

  Here ``N, dim`` is the shape of ``state.position`` and ``tau`` is
  ``state.barostat.tau``. The result is a ``jnp`` array with the same dtype
  as the positions.

  Args:
    state: Dataclass with a rank-2 ``position`` and a ``barostat`` exposing
      ``tau``.
    kT: Temperature in energy units.

  Returns:
    A copy of ``state`` with ``box_mass`` replaced.
  """
  n_particles, spatial_dim = state.position.shape
  tau = state.barostat.tau
  mass_value = spatial_dim * (n_particles + 1) * kT * tau**2
  return dataclasses.replace(
      state, box_mass=jnp.array(mass_value, state.position.dtype))