def test_inactive():
    """Edges whose ``active`` flag is False must not feed node aggregates.

    Each edge emits ``si = 1 + src_node["i"]`` and each node sums its incoming
    ``si`` into ``x``. The expected values are consistent with ``i`` being the
    node index (assumed to be provided by ``Network`` -- not shown here).
    """
    network = Network.from_edges(
        n=4,
        edges=jnp.array([(0, 2), (2, 1), (2, 3), (0, 1), (1, 0)]),
        directed=True,
    )
    summing_op = operations.Fun(
        edge_f=lambda data: {"si": 1 + data.src_node["i"]},
        node_f=lambda data: {"x": data.in_edges["sum.si"]},
    )
    proc = NetworkProcess([summing_op])
    initial_props = {
        "node.x": jnp.zeros(network.n, dtype=jnp.int32),
        "edge.si": jnp.zeros(network.m, dtype=jnp.int32),
    }
    s0 = proc.new_state(network, props=initial_props)

    # Every edge active: all in-edges contribute to the sums.
    s1 = proc.run(s0, steps=1, jit=False)
    assert (s1.node["x"] == jnp.array([2, 4, 1, 3])).all()

    # Deactivate edges (2, 1) and (1, 0); their contributions must vanish.
    s1.edge["active"] = jnp.array([True, False, True, True, False])
    s2 = proc.run(s1, steps=1, jit=False)
    assert (s2.node["x"] == jnp.array([0, 1, 1, 3])).all()
def test_seir_model():
    """SEIR dynamics on a small, seeded Barabasi-Albert graph.

    With no initial infection every node stays in compartment 0. After
    seeding node 0, the per-compartment counts must fall in fixed ranges.
    """
    N = 20
    graph = nx.random_graphs.barabasi_albert_graph(N, 3, seed=40)
    network = Network.from_graph(graph)
    proc = NetworkProcess([epidemics.SEIRUpdateOp(immunity_loss=True)])
    rates = {
        "edge_expose_rate": 0.7,
        "infectious_rate": 1.0,
        "recovery_rate": 0.8,
        "immunity_loss_rate": 0.1,
    }
    s = proc.new_state(network, props=rates, seed=46)

    # A few steps with nobody infected: nothing may change.
    print(s.node["compartment"])
    s = proc.run(s, steps=5)
    print(s.node["compartment"])
    assert sum(s.node["compartment"]) == 0

    # Seed the infection at node 0 (intended as a high-degree node).
    s.node["compartment"] = jnp.array([1] + [0] * (s.n - 1))

    # Let it spread, then check how many nodes ended up in each compartment.
    s = proc.run(s, steps=5)
    compartments = s.node["compartment"]
    print(compartments)
    expected_ranges = {
        0: range(3, 10),
        1: range(3, 10),
        2: range(3, 10),
        3: range(2, 6),
    }
    for comp_id, allowed in expected_ranges.items():
        assert sum(compartments == comp_id) in allowed

    # Not essential: the process should have been traced exactly once.
    assert proc._traced == 1
def test_custom_process():
    """Full message passing through a custom ``OperationBase`` subclass.

    Exercises edge, node and param updates, recorded params, omission of
    underscore-prefixed keys from the state, and the computed degree fields.
    """

    class TestOp(OperationBase):
        def update_edge(self, data: EdgeUpdateData):
            # "_"-prefixed keys are temporaries: node updates can read their
            # aggregates, but they must not be stored in the resulting state.
            return {
                "aa": data.edge["aa"] + data.src_node["y"],
                "_nope": 1,
                "_tgt_x": data.tgt_node["x"],
                "_src_x": data.src_node["x"],
                "_deg": 1,
            }

        def update_node(self, data: NodeUpdateData):
            return {
                "x": data.in_edges["sum"]["_src_x"],
                # Out-degree >= 2 adds a fixed 100; otherwise add a random
                # integer in [200, 300) drawn from the node's PRNG key.
                "y": data.node["y"] + jax.lax.cond(
                    data.out_edges["sum"]["_deg"] >= 2,
                    lambda _: 100.0,
                    lambda _: jnp.float32(
                        jax.random.randint(data.prng_key, (), 200, 300)),
                    None,
                ),
                "indeg": data.in_edges["sum"]["_deg"],
                "outdeg": data.out_edges["sum"]["_deg"],
                "_nope": 0,
            }

        def update_params(self, data: ParamUpdateData):
            a = jnp.sum(data.state.node["indeg"])
            return {
                "_a": a,
                # Listed in record_keys below, so it is kept in the records.
                "_a_rec": a * data.state.edge["aa"][0],
            }

    proc = NetworkProcess([TestOp()], record_keys=["_a_rec"])
    sb0 = _new_state(proc)
    sb1 = proc.run(sb0, steps=1)
    print(proc.trace_log())

    assert (sb1.node["indeg"] == jnp.array([1, 2, 1, 1])).all()
    assert (sb1.node["outdeg"] == jnp.array([2, 1, 2, 0])).all()
    assert (sb1.node["x"] == jnp.array([[3, 4], [6, 8], [1, 2], [5, 6]])).all()
    # NB: this one depends on the RNG for reproducibility
    assert (sb1.node["y"] == jnp.array([100.1, 298.2, 100.3, 285.4])).all()
    # TestOp never writes "stat", so it must be carried over unchanged.
    assert (sb1.edge["stat"] == sb0.edge["stat"]).all()
    assert (sb1.edge["aa"] == jnp.array([1.1, 2.3, 3.3, 4.1, 5.2])).all()
    assert (sb1.records.all_records()["_a_rec"] == jnp.array([5.5])).all()

    # Check that underscored keys are omitted from the state.
    assert "_nope" not in sb1.node
    assert "_nope" not in sb1.edge
    assert "_nope" not in sb1

    # Also test the computed degrees.
    assert (sb1.node["in_deg"] == jnp.array([1, 2, 1, 1])).all()
    assert (sb1.node["out_deg"] == jnp.array([2, 1, 2, 0])).all()
def epi_demo(edge_beta, gamma, infect, nodes, steps):
    """Benchmark an SIR process on a random Barabasi-Albert graph.

    Args:
        edge_beta: value for the ``edge_infection_rate`` param.
        gamma: value for the ``recovery_rate`` param.
        infect: expected number of initially infected nodes; each node is
            infected independently with probability ``infect / nodes``.
        nodes: number of nodes in the generated graph.
        steps: number of simulation steps to run and time.

    Logs the graph/state creation times, the trace log and the throughput
    (steps, edge ops and node ops per second).
    """
    k = 3  # Barabasi-Albert attachment parameter
    proc = NetworkProcess([
        epidemics.SIRUpdateOp(),
        # operations.CountNodeStatesOp(states=3, key="compartment"),
        # operations.CountNodeTransitionsOp(states=3, key="compartment"),
    ])
    params = {"edge_infection_rate": edge_beta, "recovery_rate": gamma}
    log.info(
        f"Network: Barabasi-Albert. n={nodes}, k={k}, cca {nodes*k*2:.2e} directed edges"
    )
    with utils.logged_time(" Creating graph", logger=log):
        g = nx.random_graphs.barabasi_albert_graph(nodes, k)
    with utils.logged_time(" Creating state", logger=log):
        net = Network.from_graph(g)
        state = proc.new_state(net, props=params, seed=42)
        rng = jax.random.PRNGKey(43)
        comp = jnp.int32(
            jax.random.bernoulli(rng, infect / nodes, shape=[nodes]))
        state.node["compartment"] = comp
    with utils.logged_time(" Running model", logger=log):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which makes it the right clock for measuring durations.
        t0 = time.perf_counter()
        state2 = proc.run(state, steps=steps)
        # Presumably waits for pending (async) JAX computations so the
        # measured time covers the actual work -- confirm in State.
        state2.block_on_all()
        t1 = time.perf_counter()
    log.info(proc.trace_log())
    elapsed = t1 - t0
    sps = steps / elapsed
    log.info(f"{steps} steps took {elapsed:.2g} s, {sps:.3g} steps/s, "
             f"{sps*state.m:.3g} edge_ops/s, {sps * state.n:.3g} node_ops/s")
def test_sir_model():
    """SIR dynamics together with a param ``t`` incremented every step."""
    N = 10
    network = Network.from_graph(
        nx.random_graphs.barabasi_albert_graph(N, 3, seed=46))
    ops = [
        epidemics.SIRUpdateOp(),
        operations.IncrementParam("t", 1.0, default=0.0),
    ]
    proc = NetworkProcess(ops)
    s = proc.new_state(
        network,
        props={"edge_infection_rate": 0.6, "recovery_rate": 0.5},
        seed=43,
    )

    # A few steps with nobody infected: nothing may change.
    print(s.node["compartment"])
    s = proc.run(s, steps=4)
    print(s.node["compartment"])
    assert sum(s.node["compartment"]) == 0

    # Seed the infection at node 0 (intended as a high-degree node).
    s.node["compartment"] = jnp.array([1] + [0] * (s.n - 1))

    # Let it spread, then check the per-compartment counts.
    s = proc.run(s, steps=4)
    compartments = s.node["compartment"]
    print(compartments)
    for comp_id, lo, hi in ((0, 2, 5), (1, 3, 7), (2, 1, 6)):
        assert sum(compartments == comp_id) in range(lo, hi)

    # "t" started at its default 0.0 and gained 1.0 on each of the 4 + 4 steps.
    assert abs(s["t"] - 8.0) < 1e-3
    # Not essential: the process should have been traced exactly once.
    assert proc._traced == 1
def test_branching_records():
    """Runs branching from one state keep independent record histories."""
    record_step = operations.Fun(params_f=lambda data: {"_x": data.state.step})
    proc = NetworkProcess([record_step], record_keys=["_x"])
    complete4 = Network.from_graph(nx.complete_graph(4))
    base = proc.new_state(complete4, seed=32)
    trunk = proc.run(base, steps=3, jit=False)

    # Two branches continue from the same 3-step state; neither may see the
    # other's records.
    long_branch = proc.run(trunk, steps=2, jit=False)
    short_branch = proc.run(trunk, steps=1, jit=False)

    long_records = long_branch.records.all_records()["_x"]
    short_records = short_branch.records.all_records()["_x"]
    assert (long_records == jnp.array([0, 1, 2, 3, 4])).all()
    assert (short_records == jnp.array([0, 1, 2, 3])).all()
def test_regret_matching_game():
    """Regret matching on a 2x2 game whose second action dominates.

    In the payoff matrix ``[[4, 0], [5, 1]]`` the "D" row is element-wise
    larger than the "C" row (assuming rows index the player's own action --
    confirm in ``games``). Starting with everyone on action 0, nearly all
    nodes and then all nodes are expected to play action 1.

    Named distinctly from the softmax ``test_best_response_game`` defined
    later in this module. With the same name, this earlier definition was
    shadowed and never collected by pytest.
    """
    N = 30
    net_g = nx.random_graphs.barabasi_albert_graph(N, 3, seed=43)
    net = Network.from_graph(net_g)
    g = games.RegretMatchingGame(["C", "D"], jnp.array([[4, 0], [5, 1]]))
    proc = NetworkProcess([g])
    s = proc.new_state(net, seed=47, props={"node.next_action": [0] * N})

    s = proc.run(s, steps=10, jit=True)
    assert sum(s.node["action"]) > 0.9 * N

    s = proc.run(s, steps=10, jit=True)
    print(s.node["action"])
    assert sum(s.node["action"]) == N
def test_nop_process():
    """A process made only of the base ``OperationBase`` leaves data intact.

    Only ``step`` and ``prng_key`` are expected to differ after running.
    """
    proc = NetworkProcess([OperationBase()])

    # Params survive a few no-op steps.
    complete4 = Network.from_graph(nx.complete_graph(4))
    sa0 = proc.new_state(complete4, seed=32, props={"beta": 1.5})
    sa1 = proc.run(sa0, steps=4, jit=False)
    assert sa1["beta"] == 1.5

    # Apart from the step counter and the PRNG key, the state is unchanged.
    sb0 = _new_state(proc)
    sb1 = proc.run(sb0, steps=2, jit=False)
    assert sb1.step == 2
    assert not (sb0.prng_key == sb1.prng_key).all()

    # Overwrite the fields that legitimately differ so data_eq compares the rest.
    sb1["step"] = 0
    sb1["prng_key"] = sb0.prng_key
    assert sb0.data_eq(sb1)
def test_time_advancing_op():
    """``IncrementParam`` adds either a named param or a constant each step."""
    # Increment read from the "delta_t" param: starting at 0.0, 4 * 0.1 == 0.4.
    proc_param = NetworkProcess(
        [operations.IncrementParam("t", "delta_t", default=0.0)])
    complete4 = Network.from_graph(nx.complete_graph(4))
    sa1 = proc_param.run(
        proc_param.new_state(complete4, seed=32, props={"delta_t": 0.1}),
        steps=4,
        jit=False,
    )
    assert abs(sa1["t"] - 0.4) < 1e-6

    # Constant increment: starting at 0.0, 5 * 1.0 == 5.0.
    proc_const = NetworkProcess(
        [operations.IncrementParam("t", 1.0, default=0.0)])
    sb1 = proc_const.run(
        proc_const.new_state(complete4, seed=42), steps=5, jit=False)
    assert abs(sb1["t"] - 5.0) < 1e-6
def test_best_response_game():
    """Best response with a softmax policy whose ``beta`` is a state param.

    With ``beta = 1.0`` over 90% of nodes play action 1. After lowering
    ``beta`` to 0.05, the share of action 1 must fall strictly between 40%
    and 80%.
    """
    N = 30
    network = Network.from_graph(
        nx.random_graphs.barabasi_albert_graph(N, 3, seed=42))
    policy = games.SoftmaxPolicy(beta="beta")
    game = games.BestResponseGame(
        ["C", "D"], jnp.array([[4, 0], [5, 1]]), policy)
    proc = NetworkProcess([game])
    s = proc.new_state(
        network,
        seed=43,
        props={"beta": 1.0, "node.next_action": [0] * N},
    )

    s = proc.run(s, steps=10, jit=True)
    assert sum(s.node["action"]) > 0.9 * N

    # Lower beta (presumably a less greedy softmax): fewer nodes play action 1.
    s["beta"] = 0.05
    s = proc.run(s, steps=10, jit=True)
    actions = s.node["action"]
    print(actions)
    ones = sum(actions)
    assert 0.4 * N < ones < 0.8 * N
def test_si_model():
    """SI dynamics: no spread without a seed, partial spread from one seed."""
    N = 10
    network = Network.from_graph(
        nx.random_graphs.barabasi_albert_graph(N, 3, seed=42))
    proc = NetworkProcess([epidemics.SIUpdateOp()])
    s = proc.new_state(network, props={"edge_infection_rate": 0.3}, seed=43)

    # A few steps with nobody infected: nothing may change.
    print(s.node["compartment"])
    s = proc.run(s, steps=3)
    print(s.node["compartment"])
    assert sum(s.node["compartment"]) == 0

    # Seed the infection at node 0 (intended as a high-degree node).
    s.node["compartment"] = jnp.array([1] + [0] * (s.n - 1))

    # After spreading, the sum of compartment values (1 marks infected, as
    # used for seeding) must be strictly between 4 and 10.
    s = proc.run(s, steps=3)
    compartments = s.node["compartment"]
    print(compartments)
    infected = sum(compartments)
    assert 4 < infected < 10

    # Not essential: the process should have been traced exactly once.
    assert proc._traced == 1