def test_markdown(self):
    """Test the _repr_markdown_ method of a DiscreteConditional P(A|B,C)."""
    # Local (key, cardinality) pairs; these shadow any module-level A/B/C.
    A = (2, 2)
    B = (1, 2)
    C = (0, 3)
    parents = DiscreteKeys()
    parents.push_back(B)
    parents.push_back(C)
    conditional = DiscreteConditional(A, parents, "0/1 1/3 1/1 3/1 0/1 1/0")

    # One table row per (B, C) assignment; the last two columns are the
    # normalized probabilities of A=0 and A=1 for that row.
    expected = (
        " *P(A|B,C):*\n\n"
        "|*B*|*C*|0|1|\n"
        "|:-:|:-:|:-:|:-:|\n"
        "|0|0|0|1|\n"
        "|0|1|0.25|0.75|\n"
        "|0|2|0.5|0.5|\n"
        "|1|0|0.75|0.25|\n"
        "|1|1|0|1|\n"
        "|1|2|1|0|\n"
    )

    def formatter(x: int) -> str:
        """Map a key index to its display name (0 -> C, 1 -> B, 2 -> A)."""
        names = ["C", "B", "A"]
        return names[x]

    actual = conditional._repr_markdown_(formatter)
    self.assertEqual(actual, expected)
def test_marginals(self):
    """Check that marginalizing the joint P(A,B) recovers P(B) and P(A)."""
    prior_b = DiscreteConditional(B, "1/2")
    a_given_b = DiscreteConditional(A, [B], "1/2 2/1")

    # Joint P(A,B) = P(B) * P(A|B).
    joint = prior_b * a_given_b

    # Summing A out of the joint must give back the prior on B.
    self.gtsamAssertEquals(prior_b, joint.marginal(B[0]))

    # Summing B out: P(A=0) = 1/3*1/3 + 2/3*2/3 = 5/9 and P(A=1) = 4/9,
    # i.e. the ratio 5:4.
    marginal_a = DiscreteConditional(A, "5/4")
    self.gtsamAssertEquals(marginal_a, joint.marginal(A[0]))
def test_multiply(self):
    """Check calculation of joint P(A,B)"""
    conditional = DiscreteConditional(A, [B], "1/2 2/1")
    prior = DiscreteConditional(B, "1/2")

    # P(A,B) = P(A|B) * P(B) = P(B) * P(A|B): both operand orders must agree.
    for actual in [prior * conditional, conditional * prior]:
        # The product has both A and B as frontal variables.
        self.assertEqual(2, actual.nrFrontals())
        # Only the assignment is needed; each entry is re-evaluated via
        # actual(v) and compared against the pointwise product.
        for v, _ in actual.enumerate():
            self.assertAlmostEqual(actual(v), conditional(v) * prior(v))
def test_multiply2(self):
    """Check calculation of conditional joint P(A,B|C)"""
    A_given_B = DiscreteConditional(A, [B], "1/3 3/1")
    B_given_C = DiscreteConditional(B, [C], "1/3 3/1")

    # P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B): both operand orders must agree.
    for actual in [A_given_B * B_given_C, B_given_C * A_given_B]:
        # A and B become frontals; C remains the single parent.
        self.assertEqual(2, actual.nrFrontals())
        self.assertEqual(1, actual.nrParents())
        # Only the assignment is needed; the entry is re-evaluated via actual(v).
        for v, _ in actual.enumerate():
            self.assertAlmostEqual(actual(v), A_given_B(v) * B_given_C(v))
def test_multiply4(self):
    """Check calculation of joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)"""
    A_given_B = DiscreteConditional(A, [B], "1/3 3/1")
    B_given_D = DiscreteConditional(B, [D], "1/3 3/1")
    AB_given_D = A_given_B * B_given_D
    C_given_DE = DiscreteConditional(C, [D, E], "4/1 1/1 1/1 1/4")

    # P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D): both operand
    # orders must agree.
    for actual in [AB_given_D * C_given_DE, C_given_DE * AB_given_D]:
        # A, B, C become frontals; D and E remain parents.
        self.assertEqual(3, actual.nrFrontals())
        self.assertEqual(2, actual.nrParents())
        # Only the assignment is needed; the entry is re-evaluated via actual(v).
        for v, _ in actual.enumerate():
            self.assertAlmostEqual(actual(v), AB_given_D(v) * C_given_DE(v))
def test_single_value_versions(self):
    """Check likelihood() and sample() on a conditional P(X|Y)."""
    X = (0, 2)
    Y = (1, 3)
    x_given_y = DiscreteConditional(X, [Y], "2/8 4/6 5/5")

    # likelihood(x) yields a factor on Y whose entries are P(X=x|Y=y) for
    # each y, e.g. "2/8" normalizes to 0.2 for x=0 and 0.8 for x=1.
    cases = (
        (0, "0.2 0.4 0.5"),
        (1, "0.8 0.6 0.5"),
    )
    for x_value, spec in cases:
        got = x_given_y.likelihood(x_value)
        want = DecisionTreeFactor(Y, spec)
        self.gtsamAssertEquals(got, want, 1e-9)

    # sample(2) must hand back a plain Python int
    # (2 presumably being the parent Y value — TODO confirm).
    self.assertIsInstance(x_given_y.sample(2), int)
def test_constructor(self):
    """Test constructing a Bayes net."""
    net = DiscreteBayesNet()
    parent_key, child_key = (0, 2), (1, 2)

    # Root node P(Parent): it has no parents, hence an empty DiscreteKeys.
    root = DiscreteConditional(parent_key, DiscreteKeys(), "6/4")

    # Child node P(Child|Parent).
    child_parents = DiscreteKeys()
    child_parents.push_back(parent_key)
    child = DiscreteConditional(child_key, child_parents, "7/3 8/2")

    for node in (root, child):
        net.add(node)

    # Check conversion to factor graph: one factor per added conditional,
    # and fg.at(1) (presumably the factor from P(Child|Parent)) has size 2.
    graph = DiscreteFactorGraph(net)
    self.assertEqual(graph.size(), 2)
    self.assertEqual(graph.at(1).size(), 2)