コード例 #1
0
 def setUp(self):
     """Build two small Bayes nets, their edge lists and their CBN forms.

     Net 0 models a -> b; net 1 models a -> b -> c, all variables binary.
     Afterwards ``self.arr[k]`` is the list of directed edges of net k and
     ``self.cbn[k]`` is ``CBN.from_bn`` applied to net k.
     """
     nets = []
     self.arr = []

     # Net 0: a -> b.
     chain_ab = BN(domain=Domain(),
                   new_domain_variables={'a': [0, 1], 'b': [0, 1]})
     chain_ab.add_cpts([
         CPT(Factor(variables=['a'], data=[0.5, 0.5]), child='a'),
         CPT(Factor(variables=['a', 'b'], data=[0.3, 0.7, 0.4, 0.6]),
             child='b'),
     ])
     nets.append(chain_ab)
     self.arr.append([('a', 'b')])

     # Net 1: a -> b -> c.
     chain_abc = BN(domain=Domain(),
                    new_domain_variables={'a': [0, 1],
                                          'b': [0, 1],
                                          'c': [0, 1]})
     chain_abc.add_cpts([
         CPT(Factor(variables=['a'], data=[0.5, 0.5]), child='a'),
         CPT(Factor(variables=['a', 'b'], data=[0.3, 0.7, 0.4, 0.6]),
             child='b'),
         CPT(Factor(variables=['c', 'b'], data=[0.1, 0.9, 0.2, 0.8]),
             child='c'),
     ])
     nets.append(chain_abc)
     self.arr.append([('a', 'b'), ('b', 'c')])

     self.cbn = [CBN.from_bn(net) for net in nets]
コード例 #2
0
ファイル: test_IO.py プロジェクト: HunterAllman/kod
 def testdnet(self):
     """Compare the network loaded from Asia.dnet against the references.

     The loaded graph is checked with ``self.samegraph`` against
     ``self.asia_adg``, and each CPT in ``self.asia_cpts`` is checked with
     ``self.samecpt`` against the CPT of the same name in the loaded net.
     """
     from gPy.IO import read_dnet
     from gPy.Models import BN
     from gPy.Variables import Domain
     network = BN(domain=Domain())
     network.from_dnet(read_dnet('Asia.dnet'))
     self.samegraph(network.adg(), self.asia_adg)
     for child_name, expected_cpt in self.asia_cpts.items():
         loaded_cpt = network[child_name]
         self.samecpt(loaded_cpt, expected_cpt, loaded_cpt.child())
コード例 #3
0
 def testdnet(self):
     """Compare the network loaded from Asia.dnet against the references.

     The loaded graph is checked with ``self.samegraph`` against
     ``self.asia_adg``, and each CPT in ``self.asia_cpts`` is checked with
     ``self.samecpt`` against the CPT of the same name in the loaded net.
     """
     from gPy.IO import read_dnet
     from gPy.Models import BN
     from gPy.Variables import Domain
     network = BN(domain=Domain())
     network.from_dnet(read_dnet('Asia.dnet'))
     self.samegraph(network.adg(), self.asia_adg)
     for child_name, expected_cpt in self.asia_cpts.items():
         loaded_cpt = network[child_name]
         self.samecpt(loaded_cpt, expected_cpt, loaded_cpt.child())
コード例 #4
0
ファイル: utils_test.py プロジェクト: EJHortala/books-2
def rand_bn(vs, max_potential_parents=15):
    """Build a random Bayes net over the variables described by ``vs``.

    Children are added one at a time in ``vs.keys()`` order.  For each
    child, the candidate parents are whatever ``model.variables()``
    reports at that moment. That pool is thinned at random down to
    ``max_potential_parents`` entries. ``rand_subset`` of the pool plus
    the child itself forms the CPT's family, filled with
    ``rand_factor_data``.

    :param vs: mapping from variable name to its list of values.
    :param max_potential_parents: cap on the candidate-parent pool size.
    :returns: the populated ``BN``.
    """
    model = BN(domain=Domain(), new_domain_variables=vs)

    for child in vs.keys():
        candidates = list(model.variables())
        # Randomly discard candidates until the pool is small enough.
        while len(candidates) > max_potential_parents:
            candidates.remove(choice(candidates))

        family = rand_subset(candidates) | set([child])
        # Number of table cells = product of the family's value counts.
        n_cells = 1
        for var in family:
            n_cells *= len(vs[var])
        factor = Factor(variables=family,
                        data=rand_factor_data(n_cells),
                        domain=model,
                        check=True)
        model *= CPT(factor, child, True, True)
    return model
コード例 #5
0
ファイル: utils_test.py プロジェクト: HunterAllman/kod
def rand_bn(vs, max_potential_parents=15):
    """Build a random Bayes net over the variables described by ``vs``.

    Children are added one at a time in ``vs.keys()`` order.  For each
    child, the candidate parents are whatever ``model.variables()``
    reports at that moment. That pool is thinned at random down to
    ``max_potential_parents`` entries. ``rand_subset`` of the pool plus
    the child itself forms the CPT's family, filled with
    ``rand_factor_data``.

    :param vs: mapping from variable name to its list of values.
    :param max_potential_parents: cap on the candidate-parent pool size.
    :returns: the populated ``BN``.
    """
    model = BN(domain=Domain(), new_domain_variables=vs)

    for child in vs.keys():
        candidates = list(model.variables())
        # Randomly discard candidates until the pool is small enough.
        while len(candidates) > max_potential_parents:
            candidates.remove(choice(candidates))

        family = rand_subset(candidates) | set([child])
        # Number of table cells = product of the family's value counts.
        n_cells = 1
        for var in family:
            n_cells *= len(vs[var])
        factor = Factor(variables=family,
                        data=rand_factor_data(n_cells),
                        domain=model,
                        check=True)
        model *= CPT(factor, child, True, True)
    return model
コード例 #6
0
 def setUp(self):
     """Load the Asia net's hypergraph/ADG and build fixed graph fixtures.

     ``tarjan`` (vertices 1-10) and ``tarjan2``/``tarjan3`` (vertices 1-9)
     are UGraph fixtures. ``tarjan3`` has all of ``tarjan2``'s edges
     followed by seven extra ones. ``tarjanh1``/``tarjanh2`` are small
     Hypergraphs. ``graph1``/``graph2`` are UGraphs on vertices 'A'-'F';
     ``graph2`` has all of ``graph1``'s edges plus BC, CD and DE.
     """
     from gPy.Variables import Domain
     asia = BN(domain=Domain())
     asia.from_dnet(read_dnet('Asia.dnet'))
     self.hypergraph = asia._hypergraph
     self.adg = asia._adg

     self.tarjan = UGraph(range(1, 11),
                          ((1, 2), (1, 3), (2, 3), (2, 10), (3, 10),
                           (4, 5), (4, 7), (5, 6), (5, 9), (5, 7),
                           (6, 7), (6, 9), (7, 8), (7, 9), (8, 9),
                           (8, 10), (9, 10)))

     tarjan2_edges = ((1, 4), (1, 3), (2, 3), (2, 7), (3, 5), (3, 6),
                      (4, 5), (4, 8), (5, 6), (5, 8), (6, 7), (6, 9),
                      (7, 9), (8, 9))
     tarjan3_extra = ((3, 4), (3, 7), (4, 6), (4, 7), (5, 7), (6, 8),
                      (7, 8))
     self.tarjan2 = UGraph(range(1, 10), tarjan2_edges)
     self.tarjan3 = UGraph(range(1, 10), tarjan2_edges + tarjan3_extra)

     self.tarjanh1 = Hypergraph([[3, 4], [2, 4], [1, 2, 3]])
     self.tarjanh2 = Hypergraph([[3, 4], [2, 4], [1, 2, 3], [2, 3, 4]])

     graph1_edges = ('AB', 'AC', 'BD', 'CE', 'EF')
     self.graph1 = UGraph('ABCDEF', graph1_edges)
     self.graph2 = UGraph('ABCDEF', graph1_edges + ('BC', 'CD', 'DE'))
コード例 #7
0
ファイル: test_Parameters.py プロジェクト: HunterAllman/kod
    def setUp(self):
        """Load the Asia network and the reference data used by the tests.

        Populates the following attributes:

        * ``self.domain``
        * ``self.bnm``: the network read from Asia.dnet.
        * ``self.marginals`` and ``self.cond_marginals``: single-variable
          factors copied from Netica output.
        * ``self.cptdict``: each CPT of ``self.bnm``, keyed by its child.
        * ``self.rawdata``: ``read_csv`` of alarm_1K.dat.
        """
        from gPy.Variables import Domain
        self.domain = Domain()
        self.bnm = BN(domain=self.domain)
        self.bnm.from_dnet(read_dnet('Asia.dnet'))

        # Unconditional marginals, taken directly from Netica output.
        prior_tables = [
            ('VisitAsia', [0.99, 0.01]),
            ('Tuberculosis', [0.9896, 0.0104]),
            ('Smoking', [0.5, 0.5]),
            ('Cancer', [0.945, 0.055]),
            ('TbOrCa', [0.93517, 0.064828]),
            ('XRay', [0.11029, 0.88971]),
            ('Bronchitis', [0.55, 0.45]),
            ('Dyspnea', [0.56403, 0.43597]),
        ]
        self.marginals = [Factor((name,), table)
                          for name, table in prior_tables]

        # Conditional marginals, also taken directly from Netica output.
        # Cancer ([1, 0]) and TbOrCa ([0, 1]) are deliberately left out:
        # the other marginals are conditional on these values.
        posterior_tables = [
            ('VisitAsia', [0.95192, 0.048077]),
            ('Tuberculosis', [0, 1]),
            ('Smoking', [0.52381, 0.47619]),
            ('XRay', [0.98, 0.02]),
            ('Bronchitis', [0.55714, 0.44286]),
            ('Dyspnea', [0.21143, 0.78857]),
        ]
        self.cond_marginals = [Factor((name,), table)
                               for name, table in posterior_tables]

        self.cptdict = dict((cpt.child(), cpt) for cpt in self.bnm)

        csv_file = open('alarm_1K.dat')
        self.rawdata = read_csv(csv_file)
コード例 #8
0
ファイル: test_Parameters.py プロジェクト: EJHortala/books-2
    def setUp(self):
        """Load the Asia network and the reference data used by the tests.

        Populates the following attributes:

        * ``self.domain``
        * ``self.bnm``: the network read from Asia.dnet.
        * ``self.marginals`` and ``self.cond_marginals``: single-variable
          factors copied from Netica output.
        * ``self.cptdict``: each CPT of ``self.bnm``, keyed by its child.
        * ``self.rawdata``: ``read_csv`` of alarm_1K.dat.
        """
        from gPy.Variables import Domain
        self.domain = Domain()
        self.bnm = BN(domain=self.domain)
        self.bnm.from_dnet(read_dnet('Asia.dnet'))

        # Unconditional marginals, taken directly from Netica output.
        prior_tables = [
            ('VisitAsia', [0.99, 0.01]),
            ('Tuberculosis', [0.9896, 0.0104]),
            ('Smoking', [0.5, 0.5]),
            ('Cancer', [0.945, 0.055]),
            ('TbOrCa', [0.93517, 0.064828]),
            ('XRay', [0.11029, 0.88971]),
            ('Bronchitis', [0.55, 0.45]),
            ('Dyspnea', [0.56403, 0.43597]),
        ]
        self.marginals = [Factor((name,), table)
                          for name, table in prior_tables]

        # Conditional marginals, also taken directly from Netica output.
        # Cancer ([1, 0]) and TbOrCa ([0, 1]) are deliberately left out:
        # the other marginals are conditional on these values.
        posterior_tables = [
            ('VisitAsia', [0.95192, 0.048077]),
            ('Tuberculosis', [0, 1]),
            ('Smoking', [0.52381, 0.47619]),
            ('XRay', [0.98, 0.02]),
            ('Bronchitis', [0.55714, 0.44286]),
            ('Dyspnea', [0.21143, 0.78857]),
        ]
        self.cond_marginals = [Factor((name,), table)
                               for name, table in posterior_tables]

        self.cptdict = dict((cpt.child(), cpt) for cpt in self.bnm)

        csv_file = open('alarm_1K.dat')
        self.rawdata = read_csv(csv_file)
コード例 #9
0
ファイル: utils_test.py プロジェクト: EJHortala/books-2
def generate_dense_bn(density, num_vars=8, num_vals=3):
    """Build a random BN whose structure comes from generate_dense_parents.

    :param density: forwarded to ``generate_dense_parents``; must not
        exceed ``num_vars``.
    :param num_vars: number of variables, forwarded to
        ``generate_dense_parents``.
    :param num_vals: every variable takes the values 0 .. num_vals-1.
    :returns: a ``BN`` with one random CPT (from ``rand_factor_data``)
        per variable.
    :raises RuntimeError: if ``density > num_vars``.
    """
    # The check permits density == num_vars, so the message says
    # "must not exceed" rather than "must be less than".
    if density > num_vars:
        raise RuntimeError('density must not exceed the number of variables')

    variables, parents = generate_dense_parents(density, num_vars)
    # Frozensets are immutable, so one value set can be shared by all
    # variables.
    values = frozenset(range(num_vals))
    vals = dict((var, values) for var in variables)
    bn = BN(domain=Domain(), new_domain_variables=vals)
    for child in variables:
        # Variables without an entry in ``parents`` have no parents.  Look
        # the set up without writing it back into the returned dict.
        child_parents = parents.get(child, frozenset())
        # One table cell per joint value of the child and its parents.
        n = num_vals ** (len(child_parents) + 1)
        f = Factor(variables=frozenset([child]) | child_parents,
                   data=rand_factor_data(n),
                   domain=bn,
                   check=True)
        bn *= CPT(f, child, True, True)
    return bn
コード例 #10
0
ファイル: utils_test.py プロジェクト: HunterAllman/kod
from gPy.Examples import minibn, asia
from gPy.Models import FR,BN
from gPy.Parameters import Factor,CPT
from gPy.Variables import Domain
from random import choice,randrange,uniform,shuffle
import operator, unittest, pickle

xor = BN(domain=Domain(),
         new_domain_variables={'a': [0, 1], 'b': [0, 1], 'c': [0, 1]})
# a and b are independent fair coins.  c's CPT is deterministic in (a, b);
# per the name, it encodes c = a XOR b.
xor.add_cpts([CPT(Factor(variables=['a'], data=[0.5, 0.5]), child='a'),
              CPT(Factor(variables=['b'], data=[0.5, 0.5]), child='b'),
              CPT(Factor(variables=['c', 'a', 'b'],
                         data=[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]),
                  child='c')])
cbn_small_names = ['xor', 'minibn', 'asia']
cbn_small_test_cases = [xor, minibn, asia]
cbn_large_names = ['alarm', 'insurance', 'carpo']


def _load_pickled_bn(name):
    """Unpickle ``networks/<name>_bn.pck`` and close the file afterwards.

    The file is opened in binary mode, as the pickle module expects.
    NOTE(review): pickle.load can execute arbitrary code, so these files
    must come from a trusted source.
    """
    with open('networks/' + name + '_bn.pck', 'rb') as pck:
        return pickle.load(pck)


try:
    # Load the pickled large Bayes nets.  A list comprehension (not map)
    # forces every load to happen inside this try block. It also yields a
    # real list, which the concatenation below needs on any Python version.
    cbn_large_test_cases = [_load_pickled_bn(fn) for fn in cbn_large_names]
except Exception:
    # Best effort: the large nets are optional.  If any pickle is missing
    # or unreadable, run only the small test cases.
    cbn_large_names = []
    cbn_large_test_cases = []

cbn_test_cases = cbn_small_test_cases + cbn_large_test_cases


def distribution_of(model):
    """Returns a normalised factor representing the joint instantiation
    of the model.
    """
コード例 #11
0
ファイル: utils_test.py プロジェクト: EJHortala/books-2
from gPy.Examples import minibn, asia
from gPy.Models import FR, BN
from gPy.Parameters import Factor, CPT
from gPy.Variables import Domain
from random import choice, randrange, uniform, shuffle
import operator, unittest, pickle

xor = BN(domain=Domain(),
         new_domain_variables={
             'a': [0, 1],
             'b': [0, 1],
             'c': [0, 1]
         })
# a and b are independent fair coins.  c's CPT is deterministic in (a, b);
# per the name, it encodes c = a XOR b.
xor.add_cpts([
    CPT(Factor(variables=['a'], data=[0.5, 0.5]), child='a'),
    CPT(Factor(variables=['b'], data=[0.5, 0.5]), child='b'),
    CPT(Factor(variables=['c', 'a', 'b'],
               data=[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]),
        child='c')
])
cbn_small_names = ['xor', 'minibn', 'asia']
cbn_small_test_cases = [xor, minibn, asia]
cbn_large_names = ['alarm', 'insurance', 'carpo']


def _load_pickled_bn(name):
    """Unpickle ``networks/<name>_bn.pck`` and close the file afterwards.

    The file is opened in binary mode, as the pickle module expects.
    NOTE(review): pickle.load can execute arbitrary code, so these files
    must come from a trusted source.
    """
    with open('networks/' + name + '_bn.pck', 'rb') as pck:
        return pickle.load(pck)


try:
    # Load the pickled large Bayes nets.  A list comprehension (not map)
    # forces every load to happen inside this try block and yields a real
    # list on any Python version.
    cbn_large_test_cases = [_load_pickled_bn(fn) for fn in cbn_large_names]
except Exception:
    # Best effort: the large nets are optional.  If any pickle is missing
    # or unreadable, fall back to no large test cases.
    cbn_large_names = []
    cbn_large_test_cases = []
コード例 #12
0
ファイル: test_Parameters.py プロジェクト: EJHortala/books-2
class TestParameters(unittest.TestCase):
    """Factor/CPT arithmetic tests on the Asia network from Asia.dnet.

    Reference marginals were copied from Netica output.  Precision for
    assertAlmostEqual is the module-level ``places`` constant, which is
    defined outside this class.
    """

    # Node order used for Asia.cpts, Asia.joint and ``self.marginals``.
    _ASIA_ORDER = ['VisitAsia', 'Tuberculosis', 'Smoking', 'Cancer',
                   'TbOrCa', 'XRay', 'Bronchitis', 'Dyspnea']

    @staticmethod
    def _read_file(path):
        """Return the full text of ``path``, closing the file afterwards."""
        with open(path) as handle:
            return handle.read()

    def setUp(self):
        """Load the Asia network, reference marginals and raw data."""
        from gPy.Variables import Domain
        self.domain = Domain()
        self.bnm = BN(domain=self.domain)
        self.bnm.from_dnet(read_dnet('Asia.dnet'))

        # Unconditional marginals, taken directly from Netica output, in
        # _ASIA_ORDER.
        self.marginals = [
            Factor(('VisitAsia', ), [0.99, 0.01]),
            Factor(('Tuberculosis', ), [0.9896, 0.0104]),
            Factor(('Smoking', ), [0.5, 0.5]),
            Factor(('Cancer', ), [0.945, 0.055]),
            Factor(('TbOrCa', ), [0.93517, 0.064828]),
            Factor(('XRay', ), [0.11029, 0.88971]),
            Factor(('Bronchitis', ), [0.55, 0.45]),
            Factor(('Dyspnea', ), [0.56403, 0.43597])
        ]
        # Marginals after test_restrict's conditioning, taken directly from
        # Netica output.  Cancer ([1, 0]) and TbOrCa ([0, 1]) are left out
        # because they are the conditioned variables.
        self.cond_marginals = [
            Factor(('VisitAsia', ), [0.95192, 0.048077]),
            Factor(('Tuberculosis', ), [0, 1]),
            Factor(('Smoking', ), [0.52381, 0.47619]),
            Factor(('XRay', ), [0.98, 0.02]),
            Factor(('Bronchitis', ), [0.55714, 0.44286]),
            Factor(('Dyspnea', ), [0.21143, 0.78857])
        ]
        self.cptdict = dict((cpt.child(), cpt) for cpt in self.bnm)

        self.rawdata = read_csv(open('alarm_1K.dat'))

    def test_str(self):
        """The CPTs, printed in order, match the reference Asia.cpts."""
        out = ''.join(str(self.cptdict[node]) for node in self._ASIA_ORDER)
        self.assertEqual(out, self._read_file('Asia.cpts'))

    def test_cptcheck(self):
        """Loading a .dnet whose CPT rows do not sum to 1 raises CPTError."""
        # It would be nice to also check the message, which should read:
        #     For child:   Dyspnea
        #     For row:     Bronchitis=Absent, Dyspnea=Absent, TbOrCa=True
        #     Sum was:      0.90 (should be 1.0)
        def bnfromfile(filename):
            x = BN(domain=self.domain)
            x.from_dnet(read_dnet(filename))
            return x

        # Pass the file name, as setUp does for Asia.dnet, rather than an
        # open file object that would never be closed.
        self.assertRaises(CPTError, bnfromfile, 'Asia_wrong.dnet')

    def test_multinplace(self):
        """``*=`` keeps the same object and agrees with plain ``*``."""
        factor = self.cptdict['VisitAsia'] * self.cptdict['Tuberculosis']
        fid = id(factor)
        factor *= self.cptdict['Bronchitis']
        self.assertEqual(fid, id(factor))
        factor2 = (self.cptdict['VisitAsia'] *
                   self.cptdict['Tuberculosis'] * self.cptdict['Bronchitis'])
        self.samefactor(factor, factor2)

    def test_scalarprod(self):
        """Scalar products commute and scale every entry."""
        for factor in self.bnm:
            for scalar in 0, 0.3, 7:
                tf1 = scalar * factor
                tf2 = factor * scalar
                self.samefactor(tf1, tf2)
                for i, val in enumerate(tf1._data):
                    self.assertAlmostEqual(val, scalar * factor._data[i],
                                           places)

    def test_instbyname(self):
        """A CPT can be indexed by a {variable: value} dict."""
        dyspnea = self.cptdict['Dyspnea']
        datum = dyspnea[{
            'Dyspnea': 'Present',
            'Bronchitis': 'Present',
            'TbOrCa': 'False'
        }]
        self.assertAlmostEqual(datum, 0.8, places)

    def test_getitem(self):
        """A CPT can be indexed by a tuple of values."""
        dyspnea = self.cptdict['Dyspnea']
        datum = dyspnea['Present', 'Present', 'False']
        self.assertAlmostEqual(datum, 0.8, places)

    def test_z(self):
        """Every partial product of the CPTs sums to 1."""
        factor = 1
        for name in self._ASIA_ORDER:
            newfactor = factor * self.cptdict[name]
            factor *= self.cptdict[name]
            self.samefactor(factor, newfactor)
            self.assertAlmostEqual(factor.z(), 1, places)
        self.assertEqual(str(factor), self._read_file('Asia.joint'))
        factor = 1
        for cpt in self.cptdict.values():
            factor *= cpt
        self.assertAlmostEqual(factor.z(), 1, places)

    def test_marginals(self):
        """Single-variable marginals of the joint match Netica's."""
        joint = 1
        for cpt in self.bnm:
            joint *= cpt
        varset = set(self._ASIA_ORDER)
        for var, expected in zip(self._ASIA_ORDER, self.marginals):
            self.samefactor(expected, joint.sumout(varset - set([var])))

    def test_factorprodcommute(self):
        """Factor multiplication is commutative."""
        for f1 in self.bnm:
            for f2 in self.bnm:
                tf1 = f1 * f2
                tf2 = f2 * f1
                self.samefactor(tf1, tf2)

    def test_restrict(self):
        """Marginals after conditioning match Netica's conditional ones."""
        bnm = self.bnm.copy(copy_domain=True)
        given = {'Cancer': ['Absent'], 'TbOrCa': ['True']}
        bnm.condition(given)
        joint = 1
        for cpt in bnm:
            joint *= cpt
        order = ['VisitAsia', 'Tuberculosis', 'Smoking', 'XRay',
                 'Bronchitis', 'Dyspnea']
        varset = set(order)
        varset.update(['Cancer', 'TbOrCa'])
        for var, expected in zip(order, self.cond_marginals):
            marginal = joint.sumout(varset - set([var]))
            marginal /= marginal.z()
            self.samefactor(expected, marginal)

    def samefactor(self, tf1, tf2):
        """Assert tf1 and tf2 have equal variables and nearly equal data."""
        self.assertEqual(tf1.variables(), tf2.variables())
        tf1data = tf1.data()
        tf2data = tf2.data()
        # Check lengths first.  The element loop alone would miss extra
        # entries in tf2 and would raise IndexError on extra entries in tf1.
        self.assertEqual(len(tf1data), len(tf2data))
        for val1, val2 in zip(tf1data, tf2data):
            self.assertAlmostEqual(val1, val2, places)
コード例 #13
0

def disp(fn, samples):
    """Write the sample counts to file ``fn`` as a space-separated table.

    The first line lists the variable names followed by ``count``.  Each
    following line holds one instantiation followed by its count.

    :param fn: path of the output file (overwritten).
    :param samples: object whose ``makeFactor(samples.variables())``
        returns the count factor -- here, the result of
        ``CausalWorld.observe``.
    """
    # Build the factor before opening the file, so a failure here does not
    # truncate an existing output file.
    fact = samples.makeFactor(samples.variables())
    # 'with' closes the file even if a write raises.
    with open(fn, 'w') as f:
        for var in fact.variables():
            print >> f, var,
        print >> f, 'count'
        for inst in fact.insts():
            for i in inst:
                print >> f, i,
            print >> f, fact[inst]


# Dependent pair: b's CPT is indexed by a as well as b.
bn0 = BN(new_domain_variables={'a': [0, 1], 'b': [0, 1]}, domain=Domain())
bn0.add_cpts([
    CPT(Factor(data=[0.5, 0.5], variables=['a']), child='a'),
    CPT(Factor(data=[0.3, 0.7, 0.4, 0.6], variables=['a', 'b']), child='b'),
])
w = CausalWorld(bn0)
samples = w.observe(10000)
disp('two_depend', samples)

# Independent pair: a and b each have an unconditional CPT.
bn1 = BN(new_domain_variables={'a': [0, 1], 'b': [0, 1]}, domain=Domain())
bn1.add_cpts([
    CPT(Factor(data=[0.5, 0.5], variables=['a']), child='a'),
    CPT(Factor(data=[0.3, 0.7], variables=['b']), child='b'),
])
w = CausalWorld(bn1)
samples = w.observe(10000)
コード例 #14
0
ファイル: test_Parameters.py プロジェクト: EJHortala/books-2
 def bnfromfile(filename):
     """Return a new BN on ``self.domain`` filled from ``read_dnet(filename)``."""
     network = BN(domain=self.domain)
     network.from_dnet(read_dnet(filename))
     return network
コード例 #15
0
ファイル: gen_chi2.py プロジェクト: HunterAllman/kod
from gPy.Variables import Domain
from gPy.LearningUtils import CausalWorld

def disp(fn, samples):
    """Write the sample counts to file ``fn`` as a space-separated table.

    The first line lists the variable names followed by ``count``.  Each
    following line holds one instantiation followed by its count.

    :param fn: path of the output file (overwritten).
    :param samples: object whose ``makeFactor(samples.variables())``
        returns the count factor -- here, the result of
        ``CausalWorld.observe``.
    """
    # Build the factor before opening the file, so a failure here does not
    # truncate an existing output file.
    fact = samples.makeFactor(samples.variables())
    # 'with' closes the file even if a write raises.
    with open(fn, 'w') as f:
        for var in fact.variables():
            print >>f, var,
        print >>f, 'count'
        for inst in fact.insts():
            for i in inst:
                print >>f, i,
            print >>f, fact[inst]

# Dependent pair: b's CPT is indexed by a as well as b.
bn0 = BN(new_domain_variables={'a': [0, 1], 'b': [0, 1]}, domain=Domain())
bn0.add_cpts([
    CPT(Factor(data=[0.5, 0.5], variables=['a']), child='a'),
    CPT(Factor(data=[0.3, 0.7, 0.4, 0.6], variables=['a', 'b']), child='b'),
])
w = CausalWorld(bn0)
samples = w.observe(10000)
disp('two_depend', samples)

# Independent pair: a and b each have an unconditional CPT.
bn1 = BN(new_domain_variables={'a': [0, 1], 'b': [0, 1]}, domain=Domain())
bn1.add_cpts([
    CPT(Factor(data=[0.5, 0.5], variables=['a']), child='a'),
    CPT(Factor(data=[0.3, 0.7], variables=['b']), child='b'),
])
w = CausalWorld(bn1)
samples = w.observe(10000)
disp('two_independ', samples)
コード例 #16
0
ファイル: test_Parameters.py プロジェクト: HunterAllman/kod
class TestParameters(unittest.TestCase):
    """Factor/CPT arithmetic tests on the Asia network from Asia.dnet.

    Reference marginals were copied from Netica output.  Precision for
    assertAlmostEqual is the module-level ``places`` constant, which is
    defined outside this class.
    """

    # Node order used for Asia.cpts, Asia.joint and ``self.marginals``.
    _ASIA_ORDER = ['VisitAsia', 'Tuberculosis', 'Smoking', 'Cancer',
                   'TbOrCa', 'XRay', 'Bronchitis', 'Dyspnea']

    @staticmethod
    def _read_file(path):
        """Return the full text of ``path``, closing the file afterwards."""
        with open(path) as handle:
            return handle.read()

    def setUp(self):
        """Load the Asia network, reference marginals and raw data."""
        from gPy.Variables import Domain
        self.domain = Domain()
        self.bnm = BN(domain=self.domain)
        self.bnm.from_dnet(read_dnet('Asia.dnet'))

        # Unconditional marginals, taken directly from Netica output, in
        # _ASIA_ORDER.
        self.marginals = [
            Factor(('VisitAsia', ), [0.99, 0.01]),
            Factor(('Tuberculosis', ), [0.9896, 0.0104]),
            Factor(('Smoking', ), [0.5, 0.5]),
            Factor(('Cancer', ), [0.945, 0.055]),
            Factor(('TbOrCa', ), [0.93517, 0.064828]),
            Factor(('XRay', ), [0.11029, 0.88971]),
            Factor(('Bronchitis', ), [0.55, 0.45]),
            Factor(('Dyspnea', ), [0.56403, 0.43597])
        ]
        # Marginals after test_restrict's conditioning, taken directly from
        # Netica output.  Cancer ([1, 0]) and TbOrCa ([0, 1]) are left out
        # because they are the conditioned variables.
        self.cond_marginals = [
            Factor(('VisitAsia', ), [0.95192, 0.048077]),
            Factor(('Tuberculosis', ), [0, 1]),
            Factor(('Smoking', ), [0.52381, 0.47619]),
            Factor(('XRay', ), [0.98, 0.02]),
            Factor(('Bronchitis', ), [0.55714, 0.44286]),
            Factor(('Dyspnea', ), [0.21143, 0.78857])
        ]
        self.cptdict = dict((cpt.child(), cpt) for cpt in self.bnm)

        self.rawdata = read_csv(open('alarm_1K.dat'))

    def test_str(self):
        """The CPTs, printed in order, match the reference Asia.cpts."""
        out = ''.join(str(self.cptdict[node]) for node in self._ASIA_ORDER)
        self.assertEqual(out, self._read_file('Asia.cpts'))

    def test_cptcheck(self):
        """Loading a .dnet whose CPT rows do not sum to 1 raises CPTError."""
        # It would be nice to also check the message, which should read:
        #     For child:   Dyspnea
        #     For row:     Bronchitis=Absent, Dyspnea=Absent, TbOrCa=True
        #     Sum was:      0.90 (should be 1.0)
        def bnfromfile(filename):
            x = BN(domain=self.domain)
            x.from_dnet(read_dnet(filename))
            return x

        # Pass the file name, as setUp does for Asia.dnet, rather than an
        # open file object that would never be closed.
        self.assertRaises(CPTError, bnfromfile, 'Asia_wrong.dnet')

    def test_multinplace(self):
        """``*=`` keeps the same object and agrees with plain ``*``."""
        factor = self.cptdict['VisitAsia'] * self.cptdict['Tuberculosis']
        fid = id(factor)
        factor *= self.cptdict['Bronchitis']
        self.assertEqual(fid, id(factor))
        factor2 = (self.cptdict['VisitAsia'] *
                   self.cptdict['Tuberculosis'] * self.cptdict['Bronchitis'])
        self.samefactor(factor, factor2)

    def test_scalarprod(self):
        """Scalar products commute and scale every entry."""
        for factor in self.bnm:
            for scalar in 0, 0.3, 7:
                tf1 = scalar * factor
                tf2 = factor * scalar
                self.samefactor(tf1, tf2)
                for i, val in enumerate(tf1._data):
                    self.assertAlmostEqual(val, scalar * factor._data[i],
                                           places)

    def test_instbyname(self):
        """A CPT can be indexed by a {variable: value} dict."""
        dyspnea = self.cptdict['Dyspnea']
        datum = dyspnea[{
            'Dyspnea': 'Present',
            'Bronchitis': 'Present',
            'TbOrCa': 'False'
        }]
        self.assertAlmostEqual(datum, 0.8, places)

    def test_getitem(self):
        """A CPT can be indexed by a tuple of values."""
        dyspnea = self.cptdict['Dyspnea']
        datum = dyspnea['Present', 'Present', 'False']
        self.assertAlmostEqual(datum, 0.8, places)

    def test_z(self):
        """Every partial product of the CPTs sums to 1."""
        factor = 1
        for name in self._ASIA_ORDER:
            newfactor = factor * self.cptdict[name]
            factor *= self.cptdict[name]
            self.samefactor(factor, newfactor)
            self.assertAlmostEqual(factor.z(), 1, places)
        self.assertEqual(str(factor), self._read_file('Asia.joint'))
        factor = 1
        for cpt in self.cptdict.values():
            factor *= cpt
        self.assertAlmostEqual(factor.z(), 1, places)

    def test_marginals(self):
        """Single-variable marginals of the joint match Netica's."""
        joint = 1
        for cpt in self.bnm:
            joint *= cpt
        varset = set(self._ASIA_ORDER)
        for var, expected in zip(self._ASIA_ORDER, self.marginals):
            self.samefactor(expected, joint.sumout(varset - set([var])))

    def test_factorprodcommute(self):
        """Factor multiplication is commutative."""
        for f1 in self.bnm:
            for f2 in self.bnm:
                tf1 = f1 * f2
                tf2 = f2 * f1
                self.samefactor(tf1, tf2)

    def test_restrict(self):
        """Marginals after conditioning match Netica's conditional ones."""
        bnm = self.bnm.copy(copy_domain=True)
        given = {'Cancer': ['Absent'], 'TbOrCa': ['True']}
        bnm.condition(given)
        joint = 1
        for cpt in bnm:
            joint *= cpt
        order = ['VisitAsia', 'Tuberculosis', 'Smoking', 'XRay',
                 'Bronchitis', 'Dyspnea']
        varset = set(order)
        varset.update(['Cancer', 'TbOrCa'])
        for var, expected in zip(order, self.cond_marginals):
            marginal = joint.sumout(varset - set([var]))
            marginal /= marginal.z()
            self.samefactor(expected, marginal)

    def samefactor(self, tf1, tf2):
        """Assert tf1 and tf2 have equal variables and nearly equal data."""
        self.assertEqual(tf1.variables(), tf2.variables())
        tf1data = tf1.data()
        tf2data = tf2.data()
        # Check lengths first.  The element loop alone would miss extra
        # entries in tf2 and would raise IndexError on extra entries in tf1.
        self.assertEqual(len(tf1data), len(tf2data))
        for val1, val2 in zip(tf1data, tf2data):
            self.assertAlmostEqual(val1, val2, places)
コード例 #17
0
ファイル: test_Parameters.py プロジェクト: HunterAllman/kod
 def bnfromfile(filename):
     """Build a BN on ``self.domain`` from ``read_dnet(filename)``."""
     x = BN(domain=self.domain)
     x.from_dnet(read_dnet(filename))
     return x

 # Loading the deliberately wrong .dnet must raise CPTError.  Pass the file
 # name, as read_dnet is called elsewhere, rather than an open file object
 # that would never be closed.
 self.assertRaises(CPTError, bnfromfile, 'Asia_wrong.dnet')