def test_Asia(self):
    """Run the full 'Asia' example end to end.

    Builds the Asia Bayes net, converts it to a factor graph, and checks
    sequential elimination, the MPE without and with evidence, that the
    Bayes net and factor graph agree on the MPE value, and sampling.
    """
    # Argument tuples for DiscreteBayesNet.add, in insertion order:
    # (key, spec) for root nodes, (key, parents, spec) otherwise.
    network_spec = [
        (Asia, "99/1"),
        (Smoking, "50/50"),
        (Tuberculosis, [Asia], "99/1 95/5"),
        (LungCancer, [Smoking], "99/1 90/10"),
        (Bronchitis, [Smoking], "70/30 40/60"),
        (Either, [Tuberculosis, LungCancer], "F T T T"),
        (XRay, [Either], "95/5 2/98"),
        (Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9"),
    ]
    asia = DiscreteBayesNet()
    for add_args in network_spec:
        asia.add(*add_args)

    # Factor-graph view of the same network.
    fg = DiscreteFactorGraph(asia)

    # Sequential elimination in natural key order 0..7.
    ordering = Ordering()
    for key_id in range(8):
        ordering.push_back(key_id)
    chordal = fg.eliminateSequential(ordering)

    # Conditional at index 7 must be the Bronchitis marginal:
    # 0.5*0.7 + 0.5*0.4 = 0.55 versus 0.45, i.e. "11/9".
    bronchitis_marginal = DiscreteDistribution(Bronchitis, "11/9")
    self.gtsamAssertEquals(chordal.at(7), bronchitis_marginal)

    # Without evidence, every variable takes value 0 in the MPE.
    actualMPE = fg.optimize()
    expectedMPE = DiscreteValues()
    all_vars = [
        Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer,
        Bronchitis
    ]
    for var in all_vars:
        expectedMPE[var[0]] = 0
    self.assertEqual(list(actualMPE.items()), list(expectedMPE.items()))

    # The Bayes net and the factor graph must assign the MPE the same value.
    self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))

    # Evidence: we were in Asia and we have dyspnea.
    fg.add(Asia, "0 1")
    fg.add(Dyspnea, "0 1")

    # Re-solve with evidence; expectations grouped by assigned value.
    actualMPE2 = fg.optimize()
    expectedMPE2 = DiscreteValues()
    for value, group in ((0, [XRay, Tuberculosis, Either, LungCancer]),
                         (1, [Asia, Dyspnea, Smoking, Bronchitis])):
        for var in group:
            expectedMPE2[var[0]] = value
    self.assertEqual(list(actualMPE2.items()), list(expectedMPE2.items()))

    # A sample from the re-eliminated network has one entry per variable.
    chordal2 = fg.eliminateSequential(ordering)
    sample = chordal2.sample()
    self.assertEqual(len(sample), 8)
def test_elimination(self):
    """Eliminate a thin-tree Bayes net multifrontally into a Bayes tree."""
    # Fifteen binary variables with keys 0..14.
    keys = [(j, 2) for j in range(15)]

    # (child, parent key ids, CPT spec) for every node except key 14,
    # listed (and therefore added) in the original order.
    thin_tree = [
        (0, (8, 12), "2/3 1/4 3/2 4/1"),
        (1, (8, 12), "4/1 2/3 3/2 1/4"),
        (2, (9, 12), "1/4 8/2 2/3 4/1"),
        (3, (9, 12), "1/4 2/3 3/2 4/1"),
        (4, (10, 13), "2/3 1/4 3/2 4/1"),
        (5, (10, 13), "4/1 2/3 3/2 1/4"),
        (6, (11, 13), "1/4 3/2 2/3 4/1"),
        (7, (11, 13), "1/4 2/3 3/2 4/1"),
        (8, (12, 14), "T 1/4 3/2 4/1"),
        (9, (12, 14), "4/1 2/3 F 1/4"),
        (10, (13, 14), "1/4 3/2 2/3 4/1"),
        (11, (13, 14), "1/4 2/3 3/2 4/1"),
        (12, (14,), "3/1 3/1"),
        (13, (14,), "1/3 3/1"),
    ]
    bayesNet = DiscreteBayesNet()
    for child, parent_ids, spec in thin_tree:
        bayesNet.add(keys[child], [keys[p] for p in parent_ids], spec)
    # Key 14 has no parents.
    bayesNet.add(keys[14], "1/3")

    # Factor graph -> Bayes tree using natural ordering 0..14.
    factorGraph = DiscreteFactorGraph(bayesNet)
    ordering = Ordering()
    for key_id in range(15):
        ordering.push_back(key_id)
    bayesTree = factorGraph.eliminateMultifrontal(ordering)

    # For visual inspection: print(bayesTree),
    # bayesTree[key].printSignature() for each key, or
    # bayesTree.saveGraph("test_DiscreteBayesTree.dot").

    self.assertFalse(bayesTree.empty())
    self.assertEqual(12, bayesTree.size())

    # The root clique is P(8 12 14); retrieve it via key 8.
    root = bayesTree[8]
    self.assertIsInstance(root, DiscreteBayesTreeClique)
    self.assertTrue(root.isRoot())
    self.assertIsInstance(root.conditional(), DiscreteConditional)
def test_constructor(self):
    """Build a two-node Bayes net and convert it to a factor graph."""
    parent_key, child_key = (0, 2), (1, 2)

    # Prior on the parent: a conditional with an empty parent list.
    no_parents = DiscreteKeys()
    parent_prior = DiscreteConditional(parent_key, no_parents, "6/4")

    # Child conditioned on the parent.
    child_parents = DiscreteKeys()
    child_parents.push_back(parent_key)
    child_given_parent = DiscreteConditional(child_key, child_parents,
                                             "7/3 8/2")

    bayesNet = DiscreteBayesNet()
    bayesNet.add(parent_prior)
    bayesNet.add(child_given_parent)

    # Conversion yields one factor per conditional; the second spans two keys.
    factor_graph = DiscreteFactorGraph(bayesNet)
    self.assertEqual(factor_graph.size(), 2)
    self.assertEqual(factor_graph.at(1).size(), 2)
def test_sumProduct(self):
    """Test sumProduct with the default and each explicit ordering type."""
    # Declare a bunch of keys as (key, cardinality).
    C, A, B = (0, 2), (1, 2), (2, 2)

    # Factor graph with two pairwise factors sharing C.
    graph = DiscreteFactorGraph()
    graph.add([C, A], "0.2 0.8 0.3 0.7")
    graph.add([C, B], "0.1 0.9 0.4 0.6")

    # We know the MPE: C=0, A=1, B=1.
    mpe = DiscreteValues()
    mpe[0] = 0
    mpe[1] = 1
    mpe[2] = 1

    # Default sumProduct. Regression value: 0.8 * 0.9 divided by the total
    # mass of the unnormalized product (1 + 1 = 2) gives 0.36.
    bayesNet = graph.sumProduct()
    mpeProbability = bayesNet(mpe)
    self.assertAlmostEqual(mpeProbability, 0.36)  # regression

    # Every ordering type must agree with the default. Different elimination
    # orders do the floating-point arithmetic in a different sequence, so
    # compare approximately; subTest reports which ordering type failed.
    for ordering_type in [
            OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
            OrderingType.CUSTOM
    ]:
        with self.subTest(ordering_type=ordering_type):
            bayesNet = graph.sumProduct(ordering_type)
            self.assertAlmostEqual(bayesNet(mpe), mpeProbability)
def test_MPE(self):
    """Check that maxProduct + argmax and optimize both find the known MPE."""
    C, A, B = (0, 2), (1, 2), (2, 2)

    graph = DiscreteFactorGraph()
    for factor_keys, table in (([C, A], "0.2 0.8 0.3 0.7"),
                               ([C, B], "0.1 0.9 0.4 0.6")):
        graph.add(factor_keys, table)

    # Known MPE: key 0 -> 0, key 1 -> 1, key 2 -> 1.
    expected = DiscreteValues()
    for key_id, value in ((0, 0), (1, 1), (2, 1)):
        expected[key_id] = value

    # Two-step route: max-product elimination, then argmax.
    dag = graph.maxProduct(OrderingType.COLAMD)
    self.assertEqual(list(dag.argmax().items()), list(expected.items()))

    # One-step route.
    self.assertEqual(list(graph.optimize().items()), list(expected.items()))
def test_optimize(self):
    """Test constructing and optimizing a discrete factor graph."""
    # Three binary variables; note C has key 0 and A has key 2.
    C, B, A = (0, 2), (1, 2), (2, 2)

    # Chain (A)-fAC-(C)-fBC-(B). Each table "3 1 1 3" scores equal values
    # of its two variables higher (smoothness priors).
    graph = DiscreteFactorGraph()
    graph.add([A, C], "3 1 1 3")
    graph.add([C, B], "3 1 1 3")

    # All-zero and all-one assignments tie (3 * 3); the expected result is
    # the all-zero assignment.
    expectedValues = DiscreteValues()
    for key_id in range(3):
        expectedValues[key_id] = 0

    actualValues = graph.optimize()
    self.assertEqual(list(actualValues.items()),
                     list(expectedValues.items()))
def test_evaluation(self):
    """Test constructing and evaluating a discrete factor graph."""
    # Keys as (key, cardinality): P1 and P2 are binary, P3 is ternary.
    P1 = (0, 2)
    P2 = (1, 2)
    P3 = (2, 3)

    # Create the DiscreteFactorGraph
    graph = DiscreteFactorGraph()

    # Two unary factors (priors); values may be given as a list or a string.
    graph.add(P1, [0.9, 0.3])
    graph.add(P2, "0.9 0.6")

    # Binary factor over (P1, P2), first key most significant:
    # (0,0)=4, (0,1)=1, (1,0)=10, (1,1)=4.
    graph.add([P1, P2], "4 1 10 4")

    # Instantiate Values
    assignment = DiscreteValues()
    assignment[0] = 1
    assignment[1] = 1

    # P1 prior * P2 prior * binary(1,1) = 0.3 * 0.6 * 4
    self.assertAlmostEqual(.72, graph(assignment))

    # Add the ternary variable P3 with a unary prior and a ternary factor.
    graph.add(P3, "0.9 0.2 0.5")
    keys = DiscreteKeys()
    keys.push_back(P1)
    keys.push_back(P2)
    keys.push_back(P3)
    # The entry at zero-based index P1*6 + P2*3 + P3 has value index + 1.
    graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12")

    # (P1, P2, P3) = (1, 0, 1) selects index 7, the 8th entry (value 8).
    assignment[0] = 1
    assignment[1] = 0
    assignment[2] = 1

    # P1 prior * P2 prior * binary(1,0) * P3 prior * ternary
    # = 0.3 * 0.9 * 10 * 0.2 * 8
    self.assertAlmostEqual(4.32, graph(assignment))

    # (P1, P2, P3) = (0, 1, 0) selects index 3, the 4th entry (value 4).
    assignment[0] = 0
    assignment[1] = 1
    assignment[2] = 0

    # P1 prior * P2 prior * binary(0,1) * P3 prior * ternary
    # = 0.9 * 0.6 * 1 * 0.9 * 4
    self.assertAlmostEqual(1.944, graph(assignment))

    # The product of all factors must evaluate to the same value.
    product = graph.product()
    self.assertAlmostEqual(1.944, product(assignment))