Example #1
0
class TestClusterGraphFactorOperations(unittest.TestCase):
    """Factor bookkeeping on ``ClusterGraph``.

    Covers attaching factors to cliques, rejecting factors that match no
    clique, looking factors up, removing them, and computing the partition
    function of the attached factors.
    """

    def setUp(self):
        # Every test starts from a brand-new, empty cluster graph.
        self.graph = ClusterGraph()

    def tearDown(self):
        del self.graph

    def _add_ab_bc_edge(self):
        """Connect the cliques ('a', 'b') and ('b', 'c') with a single edge."""
        ab_to_bc = [('a', 'b'), ('b', 'c')]
        self.graph.add_edges_from([ab_to_bc])

    @staticmethod
    def _random_binary_factor(variables):
        """Build a DiscreteFactor over two binary *variables* with random values."""
        return DiscreteFactor(variables, [2, 2], np.random.rand(4))

    def test_add_single_factor(self):
        """A factor whose scope equals an existing clique is stored."""
        self.graph.add_node(('a', 'b'))
        factor = self._random_binary_factor(['a', 'b'])
        self.graph.add_factors(factor)
        six.assertCountEqual(self, self.graph.factors, [factor])

    def test_add_single_factor_raises_error(self):
        """A factor whose scope matches no clique is rejected with ValueError."""
        self.graph.add_node(('a', 'b'))
        stray = self._random_binary_factor(['b', 'c'])
        with self.assertRaises(ValueError):
            self.graph.add_factors(stray)

    def test_add_multiple_factors(self):
        """Several factors can be attached in one ``add_factors`` call."""
        self._add_ab_bc_edge()
        phi_ab = self._random_binary_factor(['a', 'b'])
        phi_bc = self._random_binary_factor(['b', 'c'])
        self.graph.add_factors(phi_ab, phi_bc)
        six.assertCountEqual(self, self.graph.factors, [phi_ab, phi_bc])

    def test_get_factors(self):
        """``get_factors`` returns all factors, or the one for a given clique."""
        self._add_ab_bc_edge()
        phi_ab = self._random_binary_factor(['a', 'b'])
        phi_bc = self._random_binary_factor(['b', 'c'])

        # Before anything is attached the graph reports no factors.
        six.assertCountEqual(self, self.graph.get_factors(), [])

        self.graph.add_factors(phi_ab, phi_bc)
        # The lookup key ('b', 'a') is expected to find the factor attached
        # to the clique that was added as ('a', 'b').
        self.assertEqual(self.graph.get_factors(node=('b', 'a')), phi_ab)
        self.assertEqual(self.graph.get_factors(node=('b', 'c')), phi_bc)
        six.assertCountEqual(self, self.graph.get_factors(), [phi_ab, phi_bc])

    def test_remove_factors(self):
        """Removing one factor leaves the others attached."""
        self._add_ab_bc_edge()
        phi_ab = self._random_binary_factor(['a', 'b'])
        phi_bc = self._random_binary_factor(['b', 'c'])
        self.graph.add_factors(phi_ab, phi_bc)
        self.graph.remove_factors(phi_ab)
        six.assertCountEqual(self, self.graph.factors, [phi_bc])

    def test_get_partition_function(self):
        """The partition function of two fixed 0..3 tables equals 22."""
        self._add_ab_bc_edge()
        phi_ab = DiscreteFactor(['a', 'b'], [2, 2], range(4))
        phi_bc = DiscreteFactor(['b', 'c'], [2, 2], range(4))
        self.graph.add_factors(phi_ab, phi_bc)
        # 22 is the hand-computed sum, over every joint assignment of
        # (a, b, c), of the product of the two table entries.
        self.assertEqual(self.graph.get_partition_function(), 22.0)