Beispiel #1
0
    def test_dot(self):
        """Render a Bayes-net fragment to DOT using an updated position hint."""
        # A key built from gtsam.symbol('a', 0) with cardinality 2; the DOT
        # output below labels it "a0", so the 'a' position hint applies to it.
        symbol_asia = gtsam.symbol('a', 0), 2

        net = DiscreteBayesNet()
        net.add(Either, [Tuberculosis, LungCancer], "F T T T")
        net.add(Tuberculosis, [symbol_asia], "99/1 95/5")
        net.add(LungCancer, [Smoking], "99/1 90/10")

        # Read the writer's hint map, extend it, and assign it back. This checks
        # that positionHints can be *updated*, not merely read. Per the expected
        # output, symbol a0 then gets pos="0,2!".
        dot_writer = gtsam.DotWriter()
        hints: dict = dot_writer.positionHints
        hints.update({'a': 2})
        dot_writer.positionHints = hints

        rendered = net.dot(writer=dot_writer)
        expected_result = """\
            digraph {
              size="5,5";

              var3[label="3"];
              var4[label="4"];
              var5[label="5"];
              var6[label="6"];
              var6989586621679009792[label="a0", pos="0,2!"];

              var4->var6
              var6989586621679009792->var3
              var3->var5
              var6->var5
            }"""
        self.assertEqual(rendered, textwrap.dedent(expected_result))
Beispiel #2
0
    def test_fragment(self):
        """Test sampling from the Asia fragment given values for its root parents.

        Asia and Smoking appear only as parents in the fragment (no conditional
        is added for them), so they are supplied as evidence. Sampling must
        assign all five variables, keep the evidence unchanged, and give every
        sampled variable a value inside its domain.
        """

        # Create a reverse-topologically sorted fragment:
        fragment = DiscreteBayesNet()
        fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
        fragment.add(Tuberculosis, [Asia], "99/1 95/5")
        fragment.add(LungCancer, [Smoking], "99/1 90/10")

        # Create assignment with missing values:
        given = DiscreteValues()
        for key in [Asia, Smoking]:
            given[key[0]] = 0

        # Now sample from fragment:
        actual = fragment.sample(given)
        self.assertEqual(len(actual), 5)

        # The evidence must be carried through to the sample unchanged...
        for key in [Asia, Smoking]:
            self.assertEqual(actual[key[0]], 0)

        # ...and each sampled variable must take a value in [0, cardinality).
        # Discrete keys here are (index, cardinality) pairs.
        for key in [Either, Tuberculosis, LungCancer]:
            self.assertIn(actual[key[0]], range(key[1]))
Beispiel #3
0
    def test_constructor(self):
        """Build P(Parent) P(Child | Parent) and convert it to a factor graph."""

        # Discrete keys are (variable index, cardinality) pairs.
        Parent, Child = (0, 2), (1, 2)

        net = DiscreteBayesNet()

        # Prior on the root: conditioned on an empty key set.
        net.add(DiscreteConditional(Parent, DiscreteKeys(), "6/4"))

        # Child given Parent: the spec has two ratio groups, matching Parent's
        # cardinality of 2.
        child_parents = DiscreteKeys()
        child_parents.push_back(Parent)
        net.add(DiscreteConditional(Child, child_parents, "7/3 8/2"))

        # Converting yields one factor per conditional; the second factor
        # involves two keys.
        graph = DiscreteFactorGraph(net)
        self.assertEqual(graph.size(), 2)
        self.assertEqual(graph.at(1).size(), 2)
Beispiel #4
0
    def test_elimination(self):
        """Eliminate a thin-tree Bayes net multifrontally and inspect the tree."""

        # Fifteen binary variables as (index, cardinality) pairs.
        keys = [(j, 2) for j in range(15)]

        # (child, parent indices, conditional spec) for every non-root node,
        # listed in the order the conditionals are added to the Bayes net.
        # NOTE(review): "T"/"F" entries are presumably Signature shorthand for
        # deterministic rows -- confirm against the spec parser.
        thin_tree = [
            (0, (8, 12), "2/3 1/4 3/2 4/1"),
            (1, (8, 12), "4/1 2/3 3/2 1/4"),
            (2, (9, 12), "1/4 8/2 2/3 4/1"),
            (3, (9, 12), "1/4 2/3 3/2 4/1"),
            (4, (10, 13), "2/3 1/4 3/2 4/1"),
            (5, (10, 13), "4/1 2/3 3/2 1/4"),
            (6, (11, 13), "1/4 3/2 2/3 4/1"),
            (7, (11, 13), "1/4 2/3 3/2 4/1"),
            (8, (12, 14), "T 1/4 3/2 4/1"),
            (9, (12, 14), "4/1 2/3 F 1/4"),
            (10, (13, 14), "1/4 3/2 2/3 4/1"),
            (11, (13, 14), "1/4 2/3 3/2 4/1"),
            (12, (14,), "3/1 3/1"),
            (13, (14,), "1/3 3/1"),
        ]

        net = DiscreteBayesNet()
        for child, parent_ids, spec in thin_tree:
            net.add(keys[child], [keys[p] for p in parent_ids], spec)
        # Variable 14 is the parentless root.
        net.add(keys[14], "1/3")

        # Eliminate in plain index order 0..14.
        graph = DiscreteFactorGraph(net)
        ordering = Ordering()
        for index, _cardinality in keys:
            ordering.push_back(index)
        tree = graph.eliminateMultifrontal(ordering)

        # For visual debugging: print(tree), tree[key].printSignature() per
        # key, or tree.saveGraph("test_DiscreteBayesTree.dot").

        self.assertFalse(tree.empty())
        self.assertEqual(12, tree.size())

        # The root clique is P(8 12 14); it can be looked up by key 8.
        root = tree[8]
        self.assertIsInstance(root, DiscreteBayesTreeClique)
        self.assertTrue(root.isRoot())
        self.assertIsInstance(root.conditional(), DiscreteConditional)
Beispiel #5
0
    def test_Asia(self):
        """Build the full Asia network, then check elimination, MPE and sampling."""

        def make_values(assignments):
            """Collect (discrete key, state) pairs into a DiscreteValues."""
            values = DiscreteValues()
            for discrete_key, state in assignments:
                values[discrete_key[0]] = state
            return values

        asia = DiscreteBayesNet()
        # Root priors.
        asia.add(Asia, "99/1")
        asia.add(Smoking, "50/50")
        # Diseases conditioned on the roots.
        asia.add(Tuberculosis, [Asia], "99/1 95/5")
        asia.add(LungCancer, [Smoking], "99/1 90/10")
        asia.add(Bronchitis, [Smoking], "70/30 40/60")
        # "F T T T" is presumably a logical OR of the two parents.
        asia.add(Either, [Tuberculosis, LungCancer], "F T T T")
        # Observable symptoms.
        asia.add(XRay, [Either], "95/5 2/98")
        asia.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9")

        graph = DiscreteFactorGraph(asia)

        # Sequential elimination in index order 0..7; conditional 7 is the
        # marginal on Bronchitis.
        ordering = Ordering()
        for index in range(8):
            ordering.push_back(index)
        chordal = graph.eliminateSequential(ordering)
        bronchitis_marginal = DiscreteDistribution(Bronchitis, "11/9")
        self.gtsamAssertEquals(chordal.at(7), bronchitis_marginal)

        # Without evidence, the MPE sets every variable to state 0.
        mpe = graph.optimize()
        expected_mpe = make_values(
            (variable, 0)
            for variable in (Asia, Dyspnea, XRay, Tuberculosis, Smoking,
                             Either, LungCancer, Bronchitis))
        self.assertEqual(list(mpe.items()), list(expected_mpe.items()))

        # The Bayes net and the factor graph must agree on the MPE's value.
        self.assertAlmostEqual(asia(mpe), graph(mpe))

        # Evidence: we were in Asia and we have dyspnea.
        graph.add(Asia, "0 1")
        graph.add(Dyspnea, "0 1")

        mpe_with_evidence = graph.optimize()
        expected_with_evidence = make_values(
            [(variable, 0)
             for variable in (XRay, Tuberculosis, Either, LungCancer)] +
            [(variable, 1)
             for variable in (Asia, Dyspnea, Smoking, Bronchitis)])
        self.assertEqual(list(mpe_with_evidence.items()),
                         list(expected_with_evidence.items()))

        # A sample from the re-eliminated network assigns all eight variables.
        chordal_with_evidence = graph.eliminateSequential(ordering)
        sample = chordal_with_evidence.sample()
        self.assertEqual(len(sample), 8)