Ejemplo n.º 1
0
def test_DAG_from_amat():
    """Check ``DAG.from_amat`` on empty and fully connected 4x4 matrices.

    Covers numpy and nested-list input, default node naming, round-tripping
    through ``get_numpy_adjacency`` and the resulting node and edge sets.
    """
    empty_amat = np.zeros((4, 4))
    # Strictly lower-triangular ones: every node pair is linked exactly once.
    full_amat = np.tril(np.ones((4, 4)), -1)

    empty_graph = DAG.from_amat(empty_amat, list("ABCD"))
    empty_graph_from_list = DAG.from_amat(empty_amat.tolist(), list("ABCD"))
    full_graph = DAG.from_amat(full_amat, list("ABCD"))

    # With no names supplied, nodes are expected to be labelled A, B, C, ...
    assert DAG.from_amat(empty_amat).nodes == {"A", "B", "C", "D"}

    for graph, amat in ((empty_graph, empty_amat),
                        (empty_graph_from_list, empty_amat),
                        (full_graph, full_amat)):
        assert np.all(graph.get_numpy_adjacency() == amat)

    expected_nodes = {"A", "B", "C", "D"}
    assert full_graph.nodes == empty_graph.nodes == expected_nodes
    assert empty_graph.edges == set()

    expected_edges = {
        ('C', 'A'),
        ('B', 'A'),
        ('D', 'B'),
        ('D', 'C'),
        ('D', 'A'),
        ('C', 'B'),
    }
    assert full_graph.edges == full_graph.directed_edges == expected_edges
Ejemplo n.º 2
0
def test_icdag_degree_caps():
    """Each Ide-Cozman degree cap should bound only the degree it names."""

    def _generate(**cap):
        # Same size, burn-in and seed for every variant; only the cap differs.
        return DAG.generate("ide_cozman", 10, burn_in=1000, seed=1, **cap)

    in_capped = _generate(max_indegree=2)
    assert max(in_capped.indegree()) <= 2
    out_peak = max(in_capped.outdegree())
    assert 2 < out_peak <= 3, (
        "Make sure out/degree isn't capped when indegree is set")

    out_capped = _generate(max_outdegree=2)
    assert max(out_capped.outdegree()) <= 2
    in_peak = max(out_capped.indegree())
    assert 2 < in_peak <= 3, (
        "Make sure in/degree isn't capped when outdegree is set")

    total_capped = _generate(max_degree=2)
    assert max(total_capped.degree()) <= 2
Ejemplo n.º 3
0
def test_compare(temp_out, test_dag):
    """Compare ``test_dag`` with an edgeless DAG both ways, then plot one.

    The three edges of ``test_dag`` are expected to be styled red/dashed
    when it is the receiver of ``compare`` and blue/solid when it is the
    argument.
    """
    edgeless = DAG()
    edgeless.add_vertices(list("ABCD"))

    forward = test_dag.compare(edgeless)
    assert forward.es['color'] == ['red'] * 3
    assert forward.es['style'] == ['dashed'] * 3

    backward = edgeless.compare(test_dag)
    assert backward.es['color'] == ['blue'] * 3
    assert backward.es['style'] == ['solid'] * 3

    out_file = temp_out / 'comparison.png'
    forward.plot(out_file)
    assert out_file.exists()
Ejemplo n.º 4
0
def test_DAG_serialise_continuous_str(test_dag):
    """Round-trip a DAG with continuous parameters via ``save()``/``load()``.

    ``save()`` is called without a path and its return value is fed
    straight back into ``DAG.load``; structure must survive the trip.
    """
    test_dag.generate_continuous_parameters()
    restored = DAG.load(test_dag.save())
    assert test_dag.nodes == restored.nodes
    assert test_dag.edges == restored.edges
Ejemplo n.º 5
0
def test_DAG_serialise_discrete_str(test_dag):
    """Round-trip a DAG with discrete parameters via ``save()``/``load()``.

    ``save()`` is called without a path and its return value is fed
    straight back into ``DAG.load``; structure must survive the trip.
    """
    test_dag.generate_discrete_parameters(seed=0)
    restored = DAG.load(test_dag.save())
    assert test_dag.nodes == restored.nodes
    assert test_dag.edges == restored.edges
Ejemplo n.º 6
0
def test_structure_types(structure_type):
    """``DAG.generate`` must match the equivalent ``structure_generation`` call.

    The module-level generator name is derived from ``structure_type`` by
    lower-casing it and replacing spaces with underscores.
    """
    generator_name = structure_type.lower().replace(" ", "_")
    via_dag = DAG.generate(structure_type, 10, seed=1)
    via_module = getattr(structure_generation, generator_name)(10, seed=1)

    assert via_dag.is_dag()
    assert via_dag.nodes == via_module.nodes == set(ascii_uppercase[:10])
    assert via_dag.edges == via_module.edges
    # Proper superset of the empty set, i.e. at least one edge exists.
    assert via_module.edges > set()
Ejemplo n.º 7
0
def test_DAG_from_other():
    """Build a DAG from a networkx ``DiGraph`` with ``DAG.from_other``."""
    expected_edges = [("C", "B"), ("D", "B"), ("D", "C")]

    digraph = nx.DiGraph()
    digraph.add_nodes_from(list("ABCD"))
    digraph.add_edges_from(expected_edges)

    converted = DAG.from_other(digraph)
    assert converted.edges == converted.directed_edges == set(expected_edges)
    assert converted.nodes == set("ABCD")
Ejemplo n.º 8
0
def test_DAG_serialise_discrete_file(temp_out, test_dag):
    """Round-trip a DAG with discrete parameters through a file on disk.

    The DAG is written with ``save(path)`` and read back with
    ``DAG.load(path)``; nodes and edges must survive unchanged.
    """
    # Named for the *discrete* variant: the previous 'cont.pb' was a
    # copy-paste leftover that could clash with a continuous file test
    # writing into the same ``temp_out`` directory.
    dag_path = temp_out / 'discrete.pb'
    dag = test_dag
    dag.generate_discrete_parameters(seed=0)
    dag.save(dag_path)
    dag2 = DAG.load(dag_path)
    assert dag.nodes == dag2.nodes
    assert dag.edges == dag2.edges
Ejemplo n.º 9
0
def test_copy(test_dag):
    """``DAG.copy`` preserves nodes and edges for fixture and generated DAGs."""

    def _assert_copy_matches(original):
        duplicate = original.copy()
        assert original.nodes == duplicate.nodes
        assert original.edges == duplicate.edges

    _assert_copy_matches(test_dag)
    _assert_copy_matches(DAG.generate("forest_fire", 10, seed=1))
Ejemplo n.º 10
0
def test_bif_library():
    """Load and sample from the first network in the bundled BIF library.

    Only the first file in sorted order is exercised (presumably to keep the
    test fast). Sorting makes that choice deterministic, because
    ``Path.glob`` yields entries in arbitrary order.

    Raises:
        AssertionError: if the library directory contains no ``.bif`` files.
        RuntimeError: wrapping (and chained to) any error raised while
            loading or sampling, with the offending BIF name in the message.
    """
    bif_dir = (Path(__file__).parent.parent / 'baynet' / 'utils' /
               'bif_library').resolve()
    all_bifs = sorted(p.stem for p in bif_dir.glob('*.bif'))
    # Without this the loop below would silently pass on an empty library.
    assert all_bifs, f"No .bif files found in {bif_dir}"
    for bif in all_bifs[:1]:
        try:
            DAG.from_bif(bif).sample(1)
        except Exception as e:
            raise RuntimeError(f"Error loading {bif}: {e}") from e
Ejemplo n.º 11
0
def test_dfe_parameters():
    """One DFE iteration on a single two-level node gives the expected CPD.

    Node ``A`` has levels ``["A", "B"]`` and two observations of level 0.
    After one iteration at learning rate 0.1, the CPD's cumulative
    probabilities should be ``[0.55, 1]``.
    """
    dag = DAG.from_modelstring('[A]')
    dag.vs['levels'] = [["A", "B"] for _ in dag.vs]
    data = pd.DataFrame({'A': [0, 0]})
    dag.estimate_parameters(data,
                            method="dfe",
                            method_args={
                                "iterations": 1,
                                "learning_rate": 0.1
                            })
    # The cumulative sums come out of floating-point arithmetic, so compare
    # with a tolerance; assert_allclose still rejects a shape mismatch.
    np.testing.assert_allclose(dag.vs[0]['CPD'].cumsum_array,
                               np.array([0.55, 1.0]))
Ejemplo n.º 12
0
def test_DAG_get_v_structures(test_dag, reversed_dag, partial_dag):
    """Check ``get_v_structures`` results and their order-independence."""
    collider_at_b = {("C", "B", "D")}

    assert partial_dag.get_v_structures() == collider_at_b
    assert test_dag.get_v_structures() == set()
    assert test_dag.get_v_structures(True) == collider_at_b
    assert reversed_dag.get_v_structures(True) == {("B", "D", "C")}

    # Adding the vertices in a different order must not change the order of
    # the elements inside each returned tuple.
    shuffled = DAG()
    shuffled.add_vertices(list("DCBA"))
    shuffled.add_edges([("C", "B"), ("D", "B"), ("D", "C")])
    assert shuffled.get_v_structures(True) == test_dag.get_v_structures(True)
Ejemplo n.º 13
0
def test_bif_parser():
    """Parse ``earthquake.bif`` and compare it with the same modelstring.

    Also checks that ``dag_from_bif`` raises ``ValueError`` for an unknown
    name, whether given as ``str`` or ``Path``.
    """
    library_dir = Path(__file__).parent.parent / 'baynet' / 'utils' / 'bif_library'
    bif_path = (library_dir / 'earthquake.bif').resolve()

    parsed = dag_from_bif(bif_path)
    expected = DAG.from_modelstring(
        "[Alarm|Burglary:Earthquake][Burglary][Earthquake][JohnCalls|Alarm][MaryCalls|Alarm]"
    )
    assert parsed.nodes == expected.nodes
    assert parsed.edges == expected.edges
    parsed.sample(10)

    for missing in ("foo", Path("foo")):
        with pytest.raises(ValueError):
            dag_from_bif(missing)
Ejemplo n.º 14
0
def empty_dag() -> DAG:
    """Return a DAG over nodes A-D built from an all-zero adjacency matrix.

    NOTE(review): presumably registered as a pytest fixture; the decorator
    is not visible here.
    """
    zero_amat = np.zeros((4, 4))
    return DAG.from_amat(zero_amat, list("ABCD"))
Ejemplo n.º 15
0
def partial_dag() -> DAG:
    """Return the DAG for modelstring ``[A][B|C:D][C][D]``.

    B has parents C and D; A has no parents.

    NOTE(review): presumably registered as a pytest fixture; the decorator
    is not visible here.
    """
    modelstring = "[A][B|C:D][C][D]"
    return DAG.from_modelstring(modelstring)
Ejemplo n.º 16
0
def reversed_dag() -> DAG:
    """Return the DAG described by the module constant ``REVERSED_MODELSTRING``.

    NOTE(review): presumably registered as a pytest fixture; the decorator
    is not visible here.
    """
    modelstring = REVERSED_MODELSTRING
    return DAG.from_modelstring(modelstring)
Ejemplo n.º 17
0
def test_dag() -> DAG:
    """Return the DAG described by the module constant ``TEST_MODELSTRING``.

    NOTE(review): presumably registered as a pytest fixture; the decorator
    is not visible here. Without it, pytest would collect this as a test
    because of its ``test_`` prefix.
    """
    modelstring = TEST_MODELSTRING
    return DAG.from_modelstring(modelstring)
Ejemplo n.º 18
0
def test_name_nodes():
    """Vertices of an unnamed igraph graph are given letter names."""
    first_index = 0
    dag = DAG(igraph.Graph(2, directed=True))
    assert dag.get_node_name(first_index) == "A"
    assert dag.get_node_index("A") == first_index