def test_constructor(self):
    """Check that the DiscreteDistribution constructors agree.

    The reference distribution is built from a DecisionTreeFactor over X
    holding "0.4 0.6". The ratio-spec constructor ("2/3") and the explicit
    probability-list constructor must both compare equal to it.
    """
    key_set = DiscreteKeys()
    key_set.push_back(X)
    factor = DecisionTreeFactor(key_set, "0.4 0.6")
    reference = DiscreteDistribution(factor)

    # The ratio spec "2/3" is expected to match 0.4 / 0.6.
    from_spec = DiscreteDistribution(X, "2/3")
    self.gtsamAssertEquals(from_spec, reference)

    # Explicit probability list.
    from_list = DiscreteDistribution(X, [0.4, 0.6])
    self.gtsamAssertEquals(from_list, reference)
def test_markdown(self):
    """Test the _repr_markdown_ method.

    Renders the prior built from the ratio spec "2/3" over X and compares
    the markdown text against the exact expected table.
    """
    prior = DiscreteDistribution(X, "2/3")
    # Implicit concatenation inside parentheses replaces backslash
    # continuations: the original ended with a dangling "\" after the last
    # fragment, which silently joins whatever line happens to follow.
    expected = (" *P(0):*\n\n"
                "|0|value|\n"
                "|:-:|:-:|\n"
                "|0|0.4|\n"
                "|1|0.6|\n")
    actual = prior._repr_markdown_()
    self.assertEqual(actual, expected)
def test_Asia(self):
    """Run the full Asia network example end to end.

    Builds the eight-variable Asia Bayes net and checks four things:
    - the last conditional produced by sequential elimination;
    - the MPE without evidence, and that the net and factor graph agree on
      its value;
    - the MPE after adding evidence on Asia and Dyspnea;
    - that sampling the re-eliminated net assigns eight variables.
    """
    bayes_net = DiscreteBayesNet()
    bayes_net.add(Asia, "99/1")
    bayes_net.add(Smoking, "50/50")
    bayes_net.add(Tuberculosis, [Asia], "99/1 95/5")
    bayes_net.add(LungCancer, [Smoking], "99/1 90/10")
    bayes_net.add(Bronchitis, [Smoking], "70/30 40/60")
    bayes_net.add(Either, [Tuberculosis, LungCancer], "F T T T")
    bayes_net.add(XRay, [Either], "95/5 2/98")
    bayes_net.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9")

    # Factor-graph form of the net, used for elimination and optimization.
    graph = DiscreteFactorGraph(bayes_net)

    # Eliminate in plain index order 0..7.
    elimination_order = Ordering()
    for index in range(8):
        elimination_order.push_back(index)
    chordal_net = graph.eliminateSequential(elimination_order)
    self.gtsamAssertEquals(chordal_net.at(7),
                           DiscreteDistribution(Bronchitis, "11/9"))

    # Without evidence, every variable's expected MPE value is 0.
    mpe = graph.optimize()
    expected_mpe = DiscreteValues()
    for variable in (Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either,
                     LungCancer, Bronchitis):
        expected_mpe[variable[0]] = 0
    self.assertEqual(list(mpe.items()), list(expected_mpe.items()))

    # The Bayes net and the factor graph must give the MPE the same value.
    self.assertAlmostEqual(bayes_net(mpe), graph(mpe))

    # Evidence: we were in Asia and we have dyspnea.
    graph.add(Asia, "0 1")
    graph.add(Dyspnea, "0 1")

    mpe_with_evidence = graph.optimize()
    expected_with_evidence = DiscreteValues()
    for variable in (XRay, Tuberculosis, Either, LungCancer):
        expected_with_evidence[variable[0]] = 0
    for variable in (Asia, Dyspnea, Smoking, Bronchitis):
        expected_with_evidence[variable[0]] = 1
    self.assertEqual(list(mpe_with_evidence.items()),
                     list(expected_with_evidence.items()))

    # Re-eliminate the evidence-augmented graph and draw one sample.
    conditioned_net = graph.eliminateSequential(elimination_order)
    drawn = conditioned_net.sample()
    self.assertEqual(len(drawn), 8)
def test_multiplication(self):
    """Check the overloaded `*` operator on discrete factors.

    Covers a prior times a factor in both operand orders, then a product of
    two factors sharing one variable.
    """
    # Keys given as tuples; presumably (key, cardinality) — confirm.
    key0, key1, key2 = (0, 2), (1, 2), (2, 2)

    # Prior on key1 times a factor on (key0, key1): a Bayes-rule-style product.
    prior_on_key1 = DiscreteDistribution(key1, [1, 3])
    f01 = DecisionTreeFactor([key0, key1], "1 2 3 4")
    product01 = DecisionTreeFactor([key0, key1], "0.25 1.5 0.75 3")
    self.gtsamAssertEquals(DecisionTreeFactor(prior_on_key1) * f01, product01)
    self.gtsamAssertEquals(f01 * prior_on_key1, product01)

    # Factor on (key0, key1) times a factor on (key1, key2).
    f12 = DecisionTreeFactor([key1, key2], "5 6 7 8")
    self.gtsamAssertEquals(
        f01 * f12,
        DecisionTreeFactor([key0, key1, key2], "5 6 14 16 15 18 28 32"))
def test_sample(self):
    """Sampling the prior returns a valid value index of X.

    The prior is built from the ratio spec "2/3", which has two entries, so
    a sample must be the int 0 or 1.
    """
    prior = DiscreteDistribution(X, "2/3")
    actual = prior.sample()
    self.assertIsInstance(actual, int)
    # The type check alone would accept any out-of-range integer.
    self.assertIn(actual, (0, 1))
def test_pmf(self):
    """pmf() of the "2/3" prior over X equals [0.4, 0.6] within tolerance."""
    np.testing.assert_allclose(np.array([0.4, 0.6]),
                               DiscreteDistribution(X, "2/3").pmf())
def test_operator(self):
    """Calling the "2/3" prior with a value index returns its probability."""
    prior = DiscreteDistribution(X, "2/3")
    for value, probability in ((0, 0.4), (1, 0.6)):
        self.assertAlmostEqual(prior(value), probability)