Exemplo n.º 1
0
def test_basic():
    print '\bb.',
    n = Network()
    assert_raises(AttributeError, n.set, (0,1,2))

    n = network.independent([DiscreteRV(4) for i in range(10)])
    n.set((0,1,2,3,0,1,2,3,0,1))
    assert n.forward() == (1./4)**10, str(n.forward()) + ',' + str((1./4)**10)

    n = network.independent('4cx1')
    n.set((0,1,2,3))
    assert n.forward() == prod([norm.pdf(x) for x in (0,1,2,3)])

    n = network.independent('4d3')
    assert_raises(ValueError, n.set, (0,1,2,3))
Exemplo n.º 2
0
def test_history():
    print '\bh.',
    net = network.independent('8d3')

    def do_random_op():
        # randomly select a product node and child subset
        pnl = [n for n in net.pot if n.is_prod() and len(n.children) >=2]
        pn = random.choice(pnl)
        chl_size = random.choice(range(2, len(pn.children)+1))
        chl = list(pn.children)
        random.shuffle(chl)
        chl = chl[:chl_size]

        # perform an operation
        mo = ops.MixOp(net, pn, chl)
        mo.connect(random.choice(range(2, 5)))
        net.pot = None
        return mo

    opot = list(net.pot)
    ops_hist = []
    pot_hist = []
    for i in range(10):
        ops_hist.append( do_random_op() )
        pot_hist.append( list(net.pot) )
    for i in range(len(pot_hist))[::-1]:
        assert len(pot_hist[i]) == len(net.pot)
        assert set(pot_hist[i]) == set(net.pot)
        ops_hist[i].undo()
    assert len(opot) == len(net.pot)
    assert set(opot) == set(net.pot)
Exemplo n.º 3
0
Arquivo: learn.py Projeto: awd4/spnss
def seed_network(trn, schema, prior=2):
    """Build an independent network from *schema* and fit its root's children.

    Args:
        trn: training data indexed as ``trn[:, i]``; column ``i`` is used for
            the ``i``-th child of the network root (presumably a 2-D numpy
            array of shape (n_samples, n_vars) -- TODO confirm).
        schema: converted with ``np.array`` and passed to
            ``network.independent``.
        prior: pseudo-count placed on every category of each child when
            calling ``train.categorical_node_map_with_dirichlet_prior``.
            Defaults to 2, the value that was previously hard-coded, so
            existing callers see identical behaviour.

    Returns:
        The network built by ``network.independent``. Before it is returned,
        each root child is passed to
        ``train.categorical_node_map_with_dirichlet_prior``. Judging by that
        helper's name, this is presumably a MAP fit under a Dirichlet prior.
    """
    net = network.independent(np.array(schema))
    # NOTE(review): the data column is chosen by enumeration order, so this
    # assumes root.children iterates in the same order as trn's columns.
    for i, cn in enumerate(net.graph.root.children):
        # One pseudo-count per category of this node.
        alpha = prior * np.ones(cn.masses.size())
        train.categorical_node_map_with_dirichlet_prior(cn, trn[:, i], alpha)
    return net