def setUp(self):
    """Build the cloudy/sprinkler/rain/wetgrass Bayes net used by the tests.

    Creates four binary ``Node.BayesNode`` objects and wires them as
    cloudy -> sprinkler, cloudy -> rain, sprinkler -> wetgrass and
    rain -> wetgrass. Attaches a probability table to each node, then
    stores ``self.nodes``, ``self.bnet`` (a ``Graph.BayesNet``) and
    ``self.engine`` (an ``Inference.InferenceEngine`` over that net).
    """
    labels = ("cloudy", "sprinkler", "rain", "wetgrass")
    cloudy, sprinkler, rain, wetgrass = [
        Node.BayesNode(position, 2, name=label)
        for position, label in enumerate(labels)
    ]

    # (node, parents, children): each node registers its parents first,
    # then its children, in the order listed here.
    wiring = (
        (cloudy, (), (sprinkler, rain)),
        (sprinkler, (cloudy,), (wetgrass,)),
        (rain, (cloudy,), (wetgrass,)),
        (wetgrass, (sprinkler, rain), ()),
    )
    for node, parents, children in wiring:
        for parent in parents:
            node.add_parent(parent)
        for child in children:
            node.add_child(child)
    self.nodes = [cloudy, sprinkler, rain, wetgrass]

    # cloudy has no parents: the entry addressed by generate_index([], [])
    # is set to 0.5.
    prior = Distribution.DiscreteDistribution(cloudy)
    prior_index = prior.generate_index([], [])
    prior[prior_index] = 0.5
    cloudy.set_dist(prior)

    def attach_table(family, rows):
        """Fill a Float32 table over ``family`` and attach it to its last node.

        ``family`` lists the parent nodes followed by the child node.
        ``rows`` is a sequence of (index tuple, value) pairs assigned into
        the zero-initialised table in order.
        """
        table = zeros([member.size() for member in family], type=Float32)
        for key, values in rows:
            table[key] = values
        family[-1].set_dist(Distribution.ConditionalDiscreteDistribution(
            nodes=list(family), table=table))

    # sprinkler | cloudy
    attach_table((cloudy, sprinkler), [
        ((0,), 0.5),
        ((1,), [0.9, 0.1]),
    ])
    # rain | cloudy
    attach_table((cloudy, rain), [
        ((0,), [0.8, 0.2]),
        ((1,), [0.2, 0.8]),
    ])
    # wetgrass | sprinkler, rain
    attach_table((sprinkler, rain, wetgrass), [
        ((0, 0), [1.0, 0.0]),
        ((1, 0), [0.1, 0.9]),
        ((0, 1), [0.1, 0.9]),
        ((1, 1), [0.01, 0.99]),
    ])

    self.bnet = Graph.BayesNet(self.nodes)
    self.engine = Inference.InferenceEngine(self.bnet)
def testBasicSetIndex(self):
    """ Test that indices are correct relative to each other, using very
    basic network structure.

    After ``Graph.DAG`` is built over the four nodes from ``setUp``
    (cloudy, sprinkler, rain, wetgrass), every child node must have a
    larger index than each of its parents.
    """
    self.graph = Graph.DAG(self.nodes)
    # setUp builds exactly four nodes; unpacking fails loudly if that changes.
    # (The previous version read self.nodes[4], which does not exist, and
    # required sprinkler to follow its own child wetgrass.)
    cloudy, sprinkler, rain, wetgrass = self.nodes
    # Every edge of the network, as (parent, child).
    edges = [
        (cloudy, sprinkler),
        (cloudy, rain),
        (sprinkler, wetgrass),
        (rain, wetgrass),
    ]
    for parent, child in edges:
        assert child.index > parent.index, (
            "Indexes were not set properly in DAG.topological_sort() "
            "(parent index %s, child index %s)" % (parent.index, child.index))
def testAllIndexSet(self):
    """ Check that building the DAG assigns every node an index >= 0. """
    self.graph = Graph.DAG(self.nodes)
    # all() stops at the first node whose index is negative, then the
    # assertion fails with the message below.
    assert all(node.index >= 0 for node in self.nodes), \
        "Index was less than 0"