def test_DAG_from_amat():
    """Build DAGs from adjacency matrices (numpy array and nested list) and check round-trips."""
    names = "ABCD"
    expected_nodes = {"A", "B", "C", "D"}

    empty_amat = np.zeros((4, 4))
    empty_graph = DAG.from_amat(empty_amat, list(names))
    empty_graph_from_list = DAG.from_amat(empty_amat.tolist(), list(names))

    # Strictly lower-triangular ones: row i has a 1 in every column j < i,
    # which the expected edge set below reads as edge (node_i, node_j).
    full_amat = np.tril(np.ones((4, 4)), -1)
    full_graph = DAG.from_amat(full_amat, list(names))

    # With no names supplied the nodes are still labelled A..D.
    assert DAG.from_amat(empty_amat).nodes == expected_nodes

    # The adjacency matrix must round-trip for both input container types.
    assert np.all(empty_graph.get_numpy_adjacency() == empty_amat)
    assert np.all(empty_graph_from_list.get_numpy_adjacency() == empty_amat)
    assert np.all(full_graph.get_numpy_adjacency() == full_amat)

    assert full_graph.nodes == empty_graph.nodes == expected_nodes
    assert empty_graph.edges == set()

    expected_edges = {
        ('C', 'A'),
        ('B', 'A'),
        ('D', 'B'),
        ('D', 'C'),
        ('D', 'A'),
        ('C', 'B'),
    }
    assert full_graph.edges == full_graph.directed_edges == expected_edges
def test_icdag_degree_caps():
    """Capping one kind of degree in ``ide_cozman`` generation must not cap the other kinds."""

    def _generate(**cap):
        # Same generator settings each time; only the degree cap varies.
        return DAG.generate("ide_cozman", 10, burn_in=1000, seed=1, **cap)

    in_capped = _generate(max_indegree=2)
    assert max(in_capped.indegree()) <= 2
    assert 2 < max(in_capped.outdegree()) <= 3, "Make sure out/degree isn't capped when indegree is set"

    out_capped = _generate(max_outdegree=2)
    assert max(out_capped.outdegree()) <= 2
    assert 2 < max(out_capped.indegree()) <= 3, "Make sure in/degree isn't capped when outdegree is set"

    total_capped = _generate(max_degree=2)
    assert max(total_capped.degree()) <= 2
def test_compare(temp_out, test_dag):
    """Check edge colour/style of ``DAG.compare`` in both directions, then plot a comparison."""
    img_path = temp_out / 'comparison.png'

    edgeless = DAG()
    edgeless.add_vertices(list("ABCD"))

    forward = test_dag.compare(edgeless)
    assert forward.es['color'] == ['red'] * 3
    assert forward.es['style'] == ['dashed'] * 3

    backward = edgeless.compare(test_dag)
    assert backward.es['color'] == ['blue'] * 3
    assert backward.es['style'] == ['solid'] * 3

    forward.plot(img_path)
    assert img_path.exists()
def test_DAG_serialise_continuous_str(test_dag):
    """A DAG with continuous parameters survives an in-memory save/load round-trip."""
    test_dag.generate_continuous_parameters()
    serialised = test_dag.save()
    restored = DAG.load(serialised)
    assert test_dag.nodes == restored.nodes
    assert test_dag.edges == restored.edges
def test_DAG_serialise_discrete_str(test_dag):
    """A DAG with discrete parameters survives an in-memory save/load round-trip."""
    test_dag.generate_discrete_parameters(seed=0)
    serialised = test_dag.save()
    restored = DAG.load(serialised)
    assert test_dag.nodes == restored.nodes
    assert test_dag.edges == restored.edges
def test_structure_types(structure_type):
    """``DAG.generate`` and the matching ``structure_generation`` function agree for a seed."""
    via_dag = DAG.generate(structure_type, 10, seed=1)

    # e.g. "Forest Fire" -> structure_generation.forest_fire
    generator_name = structure_type.lower().replace(" ", "_")
    via_module = getattr(structure_generation, generator_name)(10, seed=1)

    assert via_dag.is_dag()
    assert via_dag.nodes == via_module.nodes == set(ascii_uppercase[:10])
    # Both must produce identical edge sets, and that set must be non-empty.
    assert via_dag.edges == via_module.edges > set()
def test_DAG_from_other():
    """Converting a networkx ``DiGraph`` preserves its nodes and directed edges."""
    node_names = list("ABCD")
    edge_list = [("C", "B"), ("D", "B"), ("D", "C")]

    source = nx.DiGraph()
    source.add_nodes_from(node_names)
    source.add_edges_from(edge_list)

    converted = DAG.from_other(source)
    assert converted.edges == converted.directed_edges == set(edge_list)
    assert converted.nodes == set(list("ABCD"))
def test_DAG_serialise_discrete_file(temp_out, test_dag):
    """A DAG with discrete parameters survives a save-to-file / load-from-file round-trip."""
    # Use a discrete-specific file name: the original 'cont.pb' was misleading
    # and could clash with a continuous-parameter artifact in the same temp dir.
    dag_path = temp_out / 'discrete.pb'
    dag = test_dag
    dag.generate_discrete_parameters(seed=0)
    dag.save(dag_path)
    dag2 = DAG.load(dag_path)
    assert dag.nodes == dag2.nodes
    assert dag.edges == dag2.edges
def test_copy(test_dag):
    """``DAG.copy`` reproduces nodes and edges for a fixture DAG and a generated one."""

    def _assert_copy_matches(original):
        duplicate = original.copy()
        assert original.nodes == duplicate.nodes
        assert original.edges == duplicate.edges

    _assert_copy_matches(test_dag)
    _assert_copy_matches(DAG.generate("forest_fire", 10, seed=1))
def test_bif_library():
    """Load a network from the bundled BIF library and draw a sample from it.

    Only the first file (in sorted order) is exercised, presumably to keep the
    test fast.

    Raises:
        AssertionError: if the library directory contains no ``.bif`` files.
        RuntimeError: if loading or sampling fails; the original error is chained.
    """
    bif_dir = (Path(__file__).parent.parent / 'baynet' / 'utils' / 'bif_library').resolve()
    # Sort so the file under test is deterministic; glob order is filesystem-dependent.
    all_bifs = sorted(p.stem for p in bif_dir.glob('*.bif'))
    # Without this guard a wrong path would make the loop a no-op and the test pass.
    assert all_bifs, f"No .bif files found in {bif_dir}"
    for bif in all_bifs[:1]:
        try:
            dag = DAG.from_bif(bif)
            dag.sample(1)
        except Exception as e:
            raise RuntimeError(f"Error loading {bif}: {e}") from e
def test_dfe_parameters():
    """One DFE iteration on a single binary node yields the expected cumulative CPD."""
    dag = DAG.from_modelstring('[A]')
    dag.vs['levels'] = [["A", "B"] for _ in dag.vs]
    data = pd.DataFrame({'A': [0, 0]})
    dag.estimate_parameters(data, method="dfe", method_args={
        "iterations": 1,
        "learning_rate": 0.1,
    })
    # The cumulative sum is a floating-point result. Compare with a tolerance
    # rather than exact equality. assert_allclose also checks the shapes match.
    np.testing.assert_allclose(dag.vs[0]['CPD'].cumsum_array, np.array([0.55, 1]))
def test_DAG_get_v_structures(test_dag, reversed_dag, partial_dag):
    """Check V-structure detection, with and without the positional ``True`` flag."""
    cbd = {("C", "B", "D")}

    assert partial_dag.get_v_structures() == cbd
    assert test_dag.get_v_structures() == set()
    assert test_dag.get_v_structures(True) == cbd
    assert reversed_dag.get_v_structures(True) == {("B", "D", "C")}

    # Add vertices in reverse order: the reported V-structure tuples must not
    # depend on vertex insertion order.
    reordered = DAG()
    reordered.add_vertices(list("DCBA"))
    reordered.add_edges([("C", "B"), ("D", "B"), ("D", "C")])
    assert reordered.get_v_structures(True) == test_dag.get_v_structures(True)
def test_bif_parser():
    """Parse the bundled earthquake network, compare it to its modelstring, and reject bad inputs."""
    library_dir = Path(__file__).parent.parent / 'baynet' / 'utils' / 'bif_library'
    bif_path = (library_dir / 'earthquake.bif').resolve()

    parsed = dag_from_bif(bif_path)
    reference = DAG.from_modelstring(
        "[Alarm|Burglary:Earthquake][Burglary][Earthquake][JohnCalls|Alarm][MaryCalls|Alarm]"
    )
    assert parsed.nodes == reference.nodes
    assert parsed.edges == reference.edges
    parsed.sample(10)

    # A str and a Path that name no BIF file must each raise ValueError.
    for bad_source in ("foo", Path("foo")):
        with pytest.raises(ValueError):
            dag_from_bif(bad_source)
def empty_dag() -> DAG:
    """Return a DAG with nodes A-D and no edges (built from an all-zero 4x4 matrix).

    NOTE(review): appears to be used as a pytest fixture; no ``@pytest.fixture``
    decorator is visible in this view. Confirm it is applied.
    """
    zero_amat = np.zeros((4, 4))
    return DAG.from_amat(zero_amat, list("ABCD"))
def partial_dag() -> DAG:
    """Return a DAG on A-D whose only edges point into B, from both C and D.

    NOTE(review): used as a fixture by ``test_DAG_get_v_structures``; no
    ``@pytest.fixture`` decorator is visible in this view. Confirm it is applied.
    """
    modelstring = "[A][B|C:D][C][D]"
    return DAG.from_modelstring(modelstring)
def reversed_dag() -> DAG:
    """Return the DAG described by the module-level ``REVERSED_MODELSTRING``.

    NOTE(review): used as a fixture by ``test_DAG_get_v_structures``; no
    ``@pytest.fixture`` decorator is visible in this view. Confirm it is applied.
    """
    modelstring = REVERSED_MODELSTRING
    return DAG.from_modelstring(modelstring)
def test_dag() -> DAG:
    """Return the DAG described by the module-level ``TEST_MODELSTRING``.

    NOTE(review): injected as the ``test_dag`` fixture into several tests here.
    Because the name starts with ``test_``, it must carry ``@pytest.fixture`` or
    pytest will also try to collect it as a test. No decorator is visible in
    this view. Confirm it is applied.
    """
    modelstring = TEST_MODELSTRING
    return DAG.from_modelstring(modelstring)
def test_name_nodes():
    """The first vertex of an unnamed graph is called "A", and name/index lookups agree."""
    unnamed_graph = igraph.Graph(2, directed=True)
    dag = DAG(unnamed_graph)
    first_index = 0
    assert dag.get_node_name(first_index) == "A"
    assert dag.get_node_index("A") == first_index