def test_categorical_node_map_with_dirichlet_prior():
    """Check MAP fitting of a CategoricalNode under a Dirichlet prior.

    A 4-valued categorical node is fitted to the sample [0, 0, 2, 3, 3, 3]
    (per-value counts [2, 0, 1, 3]) twice: once with the trainer's default
    prior and once with an explicit prior of 2 for every value.
    """
    import sys

    # Progress marker. sys.stdout.write behaves the same on Python 2 and 3,
    # whereas a `print '...',` statement is a syntax error on Python 3.
    sys.stdout.write('\bcmwdp. ')

    drv, cn = DiscreteRV(4), CategoricalNode()
    cn.set_rv(drv)
    data = [0, 0, 2, 3, 3, 3]

    # Default prior: the expected masses are the raw frequencies count / 6.
    train.categorical_node_map_with_dirichlet_prior(cn, data)
    assert np.allclose(cn.masses, [2. / 6, 0. / 6, 1. / 6, 3. / 6])

    # Prior of 2 per value: the expected masses are (count + 1) / 10, i.e.
    # each count is shifted by (2 - 1) and renormalized over 6 + 4.
    train.categorical_node_map_with_dirichlet_prior(
        cn, data, prior_params=[2, 2, 2, 2])
    assert np.allclose(cn.masses, [3. / 10, 1. / 10, 2. / 10, 4. / 10])
def seed_network(trn, schema, alpha=2):
    """Build an independent network from `schema` and seed its leaves from `trn`.

    Each child of the network's root is fitted with
    `train.categorical_node_map_with_dirichlet_prior`. The child is fitted to
    the matching column of `trn`, using a prior that puts the same value
    `alpha` on every category.

    Args:
        trn: 2-D training data indexed as trn[:, i]. NOTE(review): presumably
            column i holds the variable modelled by the i-th root child;
            confirm against `network.independent`.
        schema: Sequence passed (as a numpy array) to `network.independent`.
        alpha: Prior value used for every category of every child.
            Defaults to 2.

    Returns:
        The network created by `network.independent`, with its root's
        children fitted in place.
    """
    net = network.independent(np.array(schema))
    for i, cn in enumerate(net.graph.root.children):
        # np.asarray(...).size handles an ndarray (where `.size` is an int
        # attribute, so calling `.size()` raises TypeError) as well as any
        # other array-like masses container numpy can convert.
        n_values = np.asarray(cn.masses).size
        train.categorical_node_map_with_dirichlet_prior(
            cn, trn[:, i], alpha * np.ones(n_values))
    return net