Пример #1
0
def test_add_cpds():
    """Check how ``add_cpds()`` interacts with node types on a SemiparametricBN.

    The network is built from six arcs over the nodes ``a``-``d`` with node
    ``d`` declared as CKDE. The test verifies that:

    * adding a CKDE for an untyped node sets that node's type;
    * adding a CPD whose type conflicts with the declared node type raises
      ``ValueError``;
    * fitted and unfitted CPDs can both be added and keep their fitted state;
    * asking for the CPD of a node that has none raises ``ValueError``.
    """
    arcs = [('a', 'b'), ('a', 'c'), ('a', 'd'),
            ('b', 'c'), ('b', 'd'), ('c', 'd')]
    spbn = SemiparametricBN(arcs, [('d', pbn.CKDEType())])

    # Node 'a' starts untyped; adding a CKDE for it fixes its node type.
    assert spbn.node_type('a') == pbn.UnknownFactorType()
    spbn.add_cpds([CKDE('a', [])])
    assert spbn.node_type('a') == pbn.CKDEType()

    # Node 'd' was declared CKDE, so a linear Gaussian CPD must be rejected.
    with pytest.raises(ValueError) as ex:
        spbn.add_cpds([LinearGaussianCPD('d', ['a', 'b', 'c'])])
    assert "Bayesian network expects type" in str(ex.value)

    # A CPD constructed with explicit parameters reports itself as fitted,
    # one constructed from variable/evidence names only does not.
    lg = LinearGaussianCPD('b', ['a'], [2.5, 1.65], 4)
    ckde = CKDE('d', ['a', 'b', 'c'])
    assert lg.fitted()
    assert not ckde.fitted()

    spbn.add_cpds([lg, ckde])

    # After resetting the type of 'a', its CPD is expected to be gone: the
    # lookup itself must raise, so the result is intentionally not asserted.
    spbn.set_node_type('a', pbn.UnknownFactorType())
    with pytest.raises(ValueError) as ex:
        spbn.cpd('a').fitted()
    assert ('CPD of variable "a" not added. '
            'Call add_cpds() or fit() to add the CPD.') in str(ex.value)

    assert spbn.cpd('b').fitted()

    # No CPD was ever added for 'c'.
    with pytest.raises(ValueError) as ex:
        spbn.cpd('c').fitted()
    assert ('CPD of variable "c" not added. '
            'Call add_cpds() or fit() to add the CPD.') in str(ex.value)

    assert not spbn.cpd('d').fitted()
Пример #2
0
def test_serialization_bn_model(gaussian_bytes, spbn_bytes, kde_bytes,
                                discrete_bytes, genericbn_bytes, newbn_bytes,
                                otherbn_bytes):
    """Check that pickled Bayesian network models are restored intact.

    Each argument is a pickled model (presumably supplied by pytest fixtures
    defined elsewhere -- confirm against the test module). For every model the
    test unpickles it and verifies its node set, its arcs and its network
    type; for the models with per-node factor types it also verifies
    ``node_types()``.

    Args:
        gaussian_bytes: Pickled model expected to be of ``GaussianNetworkType``.
        spbn_bytes: Pickled model expected to be of ``SemiparametricBNType``.
        kde_bytes: Pickled model expected to be of ``KDENetworkType``.
        discrete_bytes: Pickled model expected to be of ``DiscreteBNType``.
        genericbn_bytes: Pickled model expected to be of
            ``MyRestrictedGaussianNetworkType``.
        newbn_bytes: Pickled model expected to be of
            ``MyRestrictedGaussianNetworkType``.
        otherbn_bytes: Pickled model expected to be of ``NonHomogeneousType``
            and to carry an ``extra_info`` attribute.
    """
    expected_nodes = {"a", "b", "c", "d"}
    expected_arcs = [("a", "b")]

    loaded_g = pickle.loads(gaussian_bytes)
    assert set(loaded_g.nodes()) == expected_nodes
    assert loaded_g.arcs() == expected_arcs
    assert loaded_g.type() == pbn.GaussianNetworkType()

    loaded_s = pickle.loads(spbn_bytes)
    assert set(loaded_s.nodes()) == expected_nodes
    assert loaded_s.arcs() == expected_arcs
    assert loaded_s.type() == pbn.SemiparametricBNType()
    assert loaded_s.node_types() == {
        'a': pbn.UnknownFactorType(),
        'b': pbn.CKDEType(),
        'c': pbn.UnknownFactorType(),
        'd': pbn.UnknownFactorType(),
    }

    loaded_k = pickle.loads(kde_bytes)
    assert set(loaded_k.nodes()) == expected_nodes
    assert loaded_k.arcs() == expected_arcs
    assert loaded_k.type() == pbn.KDENetworkType()

    loaded_d = pickle.loads(discrete_bytes)
    assert set(loaded_d.nodes()) == expected_nodes
    assert loaded_d.arcs() == expected_arcs
    assert loaded_d.type() == pbn.DiscreteBNType()

    loaded_gen = pickle.loads(genericbn_bytes)
    assert set(loaded_gen.nodes()) == expected_nodes
    assert loaded_gen.arcs() == expected_arcs
    assert loaded_gen.type() == MyRestrictedGaussianNetworkType()

    # Check the nodes of the model that was just loaded (previously this
    # re-checked ``loaded_g`` and left ``loaded_nn`` unverified).
    loaded_nn = pickle.loads(newbn_bytes)
    assert set(loaded_nn.nodes()) == expected_nodes
    assert loaded_nn.arcs() == expected_arcs
    assert loaded_nn.type() == MyRestrictedGaussianNetworkType()

    # Same here: verify ``loaded_o`` itself rather than ``loaded_g``.
    loaded_o = pickle.loads(otherbn_bytes)
    assert set(loaded_o.nodes()) == expected_nodes
    assert loaded_o.arcs() == expected_arcs
    assert loaded_o.type() == NonHomogeneousType()
    assert loaded_o.node_types() == {
        'a': pbn.UnknownFactorType(),
        'b': pbn.LinearGaussianCPDType(),
        'c': pbn.CKDEType(),
        'd': pbn.DiscreteFactorType(),
    }
    # The custom attribute must survive the pickle round trip.
    assert loaded_o.extra_info == "extra"

    # The two custom network types must stay distinguishable after loading.
    assert loaded_nn.type() != loaded_o.type()
Пример #3
0
def test_serialization_dbn_model(dyn_gaussian_bytes, dyn_spbn_bytes,
                                 dyn_kde_bytes, dyn_discrete_bytes,
                                 dyn_genericbn_bytes, dyn_newbn_bytes,
                                 dyn_otherbn_bytes):
    """Check that pickled dynamic Bayesian networks are restored intact.

    Every argument is a pickled dynamic model. After unpickling, the test
    verifies the variable set, the arcs of the static and transition networks
    and the network type. For the semiparametric model it also checks the
    node types of the transition network, and for the last model it checks the
    ``extra_info`` attribute and selected node types.
    """
    variables = {"a", "b", "c", "d"}
    static_arcs = [("a_t_2", "d_t_1")]
    transition_arcs_cb = [("c_t_2", "b_t_0")]
    transition_arcs_ab = [("a_t_2", "b_t_0")]

    gaussian = pickle.loads(dyn_gaussian_bytes)
    assert set(gaussian.variables()) == variables
    assert gaussian.static_bn().arcs() == static_arcs
    assert gaussian.transition_bn().arcs() == transition_arcs_cb
    assert gaussian.type() == pbn.GaussianNetworkType()

    semiparametric = pickle.loads(dyn_spbn_bytes)
    assert set(semiparametric.variables()) == variables
    assert semiparametric.static_bn().arcs() == static_arcs
    assert semiparametric.transition_bn().arcs() == transition_arcs_cb
    assert semiparametric.type() == pbn.SemiparametricBNType()
    # Every "_t_0" node is expected to be untyped, except "b_t_0".
    expected_types = {
        name + "_t_0": pbn.UnknownFactorType()
        for name in semiparametric.variables()
    }
    expected_types["b_t_0"] = pbn.CKDEType()
    assert semiparametric.transition_bn().node_types() == expected_types

    kde = pickle.loads(dyn_kde_bytes)
    assert set(kde.variables()) == variables
    assert kde.static_bn().arcs() == static_arcs
    assert kde.transition_bn().arcs() == transition_arcs_cb
    assert kde.type() == pbn.KDENetworkType()

    discrete = pickle.loads(dyn_discrete_bytes)
    assert set(discrete.variables()) == variables
    assert discrete.static_bn().arcs() == static_arcs
    assert discrete.transition_bn().arcs() == transition_arcs_cb
    assert discrete.type() == pbn.DiscreteBNType()

    generic = pickle.loads(dyn_genericbn_bytes)
    assert set(generic.variables()) == variables
    assert generic.static_bn().arcs() == static_arcs
    assert generic.transition_bn().arcs() == transition_arcs_ab
    assert generic.type() == MyRestrictedGaussianNetworkType()

    new_bn = pickle.loads(dyn_newbn_bytes)
    assert set(new_bn.variables()) == variables
    assert new_bn.static_bn().arcs() == static_arcs
    assert new_bn.transition_bn().arcs() == transition_arcs_ab
    assert new_bn.type() == MyRestrictedGaussianNetworkType()

    other = pickle.loads(dyn_otherbn_bytes)
    assert set(other.variables()) == variables
    assert other.static_bn().arcs() == static_arcs
    assert other.transition_bn().arcs() == transition_arcs_ab
    assert other.type() == NonHomogeneousType()
    assert other.extra_info == "extra"

    # Per-node factor types must survive in both sub-networks.
    other_static = other.static_bn()
    assert other_static.node_type("c_t_1") == pbn.DiscreteFactorType()
    assert other.static_bn().node_type("d_t_1") == pbn.CKDEType()
    assert other.transition_bn().node_type("d_t_0") == pbn.CKDEType()
Пример #4
0
def test_node_type():
    """Verify default node types and ``set_node_type()`` on a SemiparametricBN.

    A network created from node names only must have no arcs and report every
    node as ``UnknownFactorType``; a node type set explicitly must be returned
    by ``node_type()`` and can be overwritten by a later call.
    """
    names = ['a', 'b', 'c', 'd']
    network = SemiparametricBN(names)
    assert network.num_nodes() == 4
    assert network.num_arcs() == 0
    assert network.nodes() == names

    # Nothing has been typed yet.
    for node in network.nodes():
        assert network.node_type(node) == pbn.UnknownFactorType()

    # Set the type of 'b' twice; the most recent value wins.
    for factor_type in (pbn.CKDEType(), pbn.LinearGaussianCPDType()):
        network.set_node_type('b', factor_type)
        assert network.node_type('b') == factor_type
Пример #5
0
def test_apply():
    """Apply structure-learning operators to networks and check the result.

    ``AddArc``, ``FlipArc`` and ``RemoveArc`` are applied in turn to a
    Gaussian network and to a semiparametric network, checking the arc set
    after each step. ``ChangeNodeType`` must raise ``ValueError`` on the
    Gaussian network and change the node type on the semiparametric one.
    """
    names = ['a', 'b', 'c', 'd']

    # --- Gaussian network -------------------------------------------------
    gaussian = pbn.GaussianNetwork(names)
    assert gaussian.num_arcs() == 0
    assert not gaussian.has_arc('a', 'b')

    pbn.AddArc("a", "b", 1).apply(gaussian)
    assert gaussian.num_arcs() == 1
    assert gaussian.has_arc('a', 'b')

    pbn.FlipArc("a", "b", 1).apply(gaussian)
    assert gaussian.num_arcs() == 1
    assert not gaussian.has_arc('a', 'b')
    assert gaussian.has_arc('b', 'a')

    pbn.RemoveArc("b", "a", 1).apply(gaussian)
    assert gaussian.num_arcs() == 0
    assert not gaussian.has_arc('b', 'a')

    # Changing a node to CKDE is rejected by the Gaussian network.
    change_type = pbn.ChangeNodeType("a", pbn.CKDEType(), 1)
    with pytest.raises(ValueError) as ex:
        change_type.apply(gaussian)
    assert "Wrong factor type" in str(ex.value)

    # --- Semiparametric network -------------------------------------------
    semiparametric = pbn.SemiparametricBN(names)
    assert semiparametric.num_arcs() == 0

    change_type = pbn.ChangeNodeType("a", pbn.CKDEType(), 1)
    assert semiparametric.node_type('a') == pbn.UnknownFactorType()
    change_type.apply(semiparametric)
    assert semiparametric.node_type('a') == pbn.CKDEType()

    assert not semiparametric.has_arc('a', 'b')
    pbn.AddArc("a", "b", 1).apply(semiparametric)
    assert semiparametric.num_arcs() == 1
    assert semiparametric.has_arc('a', 'b')

    pbn.FlipArc("a", "b", 1).apply(semiparametric)
    assert semiparametric.num_arcs() == 1
    assert not semiparametric.has_arc('a', 'b')
    assert semiparametric.has_arc('b', 'a')

    pbn.RemoveArc("b", "a", 1).apply(semiparametric)
    assert semiparametric.num_arcs() == 0
    assert not semiparametric.has_arc('b', 'a')
Пример #6
0
def test_serialization_conditional_bn_model(
        cond_gaussian_bytes, cond_spbn_bytes, cond_kde_bytes,
        cond_discrete_bytes, cond_genericbn_bytes, cond_newbn_bytes,
        cond_otherbn_bytes, newbn_bytes, otherbn_bytes):
    """Check that pickled conditional Bayesian networks are restored intact.

    Each ``cond_*`` argument is a pickled conditional model. After unpickling,
    the test verifies its nodes, interface nodes, arcs and network type (plus
    ``node_types()`` and ``extra_info`` where applicable). ``newbn_bytes`` and
    ``otherbn_bytes`` are the pickled unconditional counterparts, used to
    confirm that conditional and unconditional models share a network type.
    """
    nodes = {"c", "d"}
    interface = {"a", "b"}
    arcs = [("a", "c")]

    gaussian = pickle.loads(cond_gaussian_bytes)
    assert set(gaussian.nodes()) == nodes
    assert set(gaussian.interface_nodes()) == interface
    assert gaussian.arcs() == arcs
    assert gaussian.type() == pbn.GaussianNetworkType()

    semiparametric = pickle.loads(cond_spbn_bytes)
    assert set(semiparametric.nodes()) == nodes
    assert set(semiparametric.interface_nodes()) == interface
    assert semiparametric.arcs() == arcs
    assert semiparametric.type() == pbn.SemiparametricBNType()
    assert semiparametric.node_types() == {
        'c': pbn.CKDEType(),
        'd': pbn.UnknownFactorType(),
    }

    kde = pickle.loads(cond_kde_bytes)
    assert set(kde.nodes()) == nodes
    assert set(kde.interface_nodes()) == interface
    assert kde.arcs() == arcs
    assert kde.type() == pbn.KDENetworkType()

    discrete = pickle.loads(cond_discrete_bytes)
    assert set(discrete.nodes()) == nodes
    assert set(discrete.interface_nodes()) == interface
    assert discrete.arcs() == arcs
    assert discrete.type() == pbn.DiscreteBNType()

    generic = pickle.loads(cond_genericbn_bytes)
    assert set(generic.nodes()) == nodes
    assert set(generic.interface_nodes()) == interface
    assert generic.arcs() == arcs
    assert generic.type() == MyRestrictedGaussianNetworkType()

    cond_new = pickle.loads(cond_newbn_bytes)
    assert set(cond_new.nodes()) == nodes
    assert set(cond_new.interface_nodes()) == interface
    assert cond_new.arcs() == arcs
    assert cond_new.type() == MyRestrictedGaussianNetworkType()

    cond_other = pickle.loads(cond_otherbn_bytes)
    assert set(cond_other.nodes()) == nodes
    assert set(cond_other.interface_nodes()) == interface
    assert cond_other.arcs() == arcs
    assert cond_other.type() == NonHomogeneousType()
    assert cond_other.node_types() == {
        'c': pbn.CKDEType(),
        'd': pbn.DiscreteFactorType(),
    }
    assert cond_other.extra_info == "extra"

    # The two custom network types must remain distinct after loading.
    assert cond_new.type() != cond_other.type()

    # A conditional model and its unconditional counterpart share one type.
    plain_new = pickle.loads(newbn_bytes)
    plain_other = pickle.loads(otherbn_bytes)

    assert cond_new.type() == plain_new.type()
    assert cond_other.type() == plain_other.type()
Пример #7
0
def test_create_spbn():
    """Exercise the SemiparametricBN constructor forms and their errors.

    Construction is tried from nodes only, from nodes plus arcs and from arcs
    only, first without and then with a list of ``(node, factor type)`` pairs.
    For each form the node count, arc count, node order and node types are
    checked. Malformed arcs must raise ``TypeError``, arcs over missing nodes
    ``IndexError`` and cyclic arc sets ``ValueError``.
    """
    nodes = ['a', 'b', 'c', 'd']
    arc_order_nodes = ['a', 'c', 'b', 'd']

    def ckde_pairs():
        # Fresh list of node-type pairs marking 'a' and 'c' as CKDE.
        return [('a', pbn.CKDEType()), ('c', pbn.CKDEType())]

    def assert_all_unknown(network):
        # Every node of ``network`` must still be untyped.
        for name in network.nodes():
            assert network.node_type(name) == pbn.UnknownFactorType()

    # ---- Without explicit node types -------------------------------------
    spbn = SemiparametricBN(['a', 'b', 'c', 'd'])
    assert spbn.num_nodes() == 4
    assert spbn.num_arcs() == 0
    assert spbn.nodes() == nodes
    assert_all_unknown(spbn)

    spbn = SemiparametricBN(['a', 'b', 'c', 'd'], [('a', 'c')])
    assert spbn.num_nodes() == 4
    assert spbn.num_arcs() == 1
    assert spbn.nodes() == nodes
    assert_all_unknown(spbn)

    # Built from arcs only, nodes come out in a different order.
    spbn = SemiparametricBN([('a', 'c'), ('b', 'd'), ('c', 'd')])
    assert spbn.num_nodes() == 4
    assert spbn.num_arcs() == 3
    assert spbn.nodes() == arc_order_nodes
    assert_all_unknown(spbn)

    # An arc with three elements matches no constructor overload.
    with pytest.raises(TypeError) as ex:
        SemiparametricBN(['a', 'b', 'c'], [('a', 'c', 'b')])
    assert "incompatible constructor arguments" in str(ex.value)

    # Arc endpoint 'd' is not among the declared nodes.
    with pytest.raises(IndexError) as ex:
        SemiparametricBN(['a', 'b', 'c'], [('a', 'd')])
    assert "not present in the graph" in str(ex.value)

    # a -> b -> c -> a is a cycle.
    with pytest.raises(ValueError) as ex:
        SemiparametricBN([('a', 'b'), ('b', 'c'), ('c', 'a')])
    assert "must be a DAG" in str(ex.value)

    with pytest.raises(ValueError) as ex:
        SemiparametricBN(['a', 'b', 'c', 'd'],
                         [('a', 'b'), ('b', 'c'), ('c', 'a')])
    assert "must be a DAG" in str(ex.value)

    # ---- With explicit node types ----------------------------------------
    expected_node_type = {
        'a': pbn.CKDEType(),
        'b': pbn.UnknownFactorType(),
        'c': pbn.CKDEType(),
        'd': pbn.UnknownFactorType(),
    }

    def assert_expected_types(network):
        # Node types of ``network`` must match ``expected_node_type``.
        for name in network.nodes():
            assert network.node_type(name) == expected_node_type[name]

    spbn = SemiparametricBN(['a', 'b', 'c', 'd'], ckde_pairs())
    assert spbn.num_nodes() == 4
    assert spbn.num_arcs() == 0
    assert spbn.nodes() == nodes
    assert_expected_types(spbn)

    spbn = SemiparametricBN(['a', 'b', 'c', 'd'], [('a', 'c')], ckde_pairs())
    assert spbn.num_nodes() == 4
    assert spbn.num_arcs() == 1
    assert spbn.nodes() == nodes
    assert_expected_types(spbn)

    spbn = SemiparametricBN([('a', 'c'), ('b', 'd'), ('c', 'd')], ckde_pairs())
    assert spbn.num_nodes() == 4
    assert spbn.num_arcs() == 3
    assert spbn.nodes() == arc_order_nodes
    assert_expected_types(spbn)

    with pytest.raises(TypeError) as ex:
        SemiparametricBN(['a', 'b', 'c'], [('a', 'c', 'b')], ckde_pairs())
    assert "incompatible constructor arguments" in str(ex.value)

    with pytest.raises(IndexError) as ex:
        SemiparametricBN(['a', 'b', 'c'], [('a', 'd')], ckde_pairs())
    assert "not present in the graph" in str(ex.value)

    with pytest.raises(ValueError) as ex:
        SemiparametricBN([('a', 'b'), ('b', 'c'), ('c', 'a')], ckde_pairs())
    assert "must be a DAG" in str(ex.value)

    with pytest.raises(ValueError) as ex:
        SemiparametricBN(['a', 'b', 'c', 'd'],
                         [('a', 'b'), ('b', 'c'), ('c', 'a')],
                         ckde_pairs())
    assert "must be a DAG" in str(ex.value)