class TestClusterGraphFactorOperations(unittest.TestCase):
    """Factor bookkeeping on a ``ClusterGraph``.

    Covers adding factors (one, several, and one whose scope matches no
    node), looking factors up, removing them, and computing the partition
    function of a two-cluster chain ``('a', 'b') -- ('b', 'c')``.
    """

    def setUp(self):
        # Every test starts from a fresh, empty cluster graph.
        self.graph = ClusterGraph()

    @staticmethod
    def _random_pair_factor(first, second):
        """Build a DiscreteFactor over two cardinality-2 variables.

        The four entries come from ``np.random.rand(4)``. Only factor
        identity/equality is asserted on, not the values themselves.
        """
        return DiscreteFactor([first, second], [2, 2], np.random.rand(4))

    def _link_ab_bc(self):
        """Connect cluster ('a', 'b') to cluster ('b', 'c') with one edge."""
        self.graph.add_edges_from([[('a', 'b'), ('b', 'c')]])

    def test_add_single_factor(self):
        self.graph.add_node(('a', 'b'))
        factor_ab = self._random_pair_factor('a', 'b')

        self.graph.add_factors(factor_ab)

        six.assertCountEqual(self, self.graph.factors, [factor_ab])

    def test_add_single_factor_raises_error(self):
        # The only node is ('a', 'b'); a factor over ('b', 'c') is expected
        # to be rejected with ValueError.
        self.graph.add_node(('a', 'b'))
        factor_bc = self._random_pair_factor('b', 'c')

        with self.assertRaises(ValueError):
            self.graph.add_factors(factor_bc)

    def test_add_multiple_factors(self):
        self._link_ab_bc()
        factor_ab = self._random_pair_factor('a', 'b')
        factor_bc = self._random_pair_factor('b', 'c')

        self.graph.add_factors(factor_ab, factor_bc)

        six.assertCountEqual(self, self.graph.factors, [factor_ab, factor_bc])

    def test_get_factors(self):
        self._link_ab_bc()
        factor_ab = self._random_pair_factor('a', 'b')
        factor_bc = self._random_pair_factor('b', 'c')

        # Nothing has been attached yet.
        six.assertCountEqual(self, self.graph.get_factors(), [])

        self.graph.add_factors(factor_ab, factor_bc)

        # The node was created as ('a', 'b') but is queried as ('b', 'a'):
        # this asserts that lookup does not depend on the tuple's order.
        self.assertEqual(self.graph.get_factors(node=('b', 'a')), factor_ab)
        self.assertEqual(self.graph.get_factors(node=('b', 'c')), factor_bc)
        six.assertCountEqual(self, self.graph.get_factors(), [factor_ab, factor_bc])

    def test_remove_factors(self):
        self._link_ab_bc()
        factor_ab = self._random_pair_factor('a', 'b')
        factor_bc = self._random_pair_factor('b', 'c')
        self.graph.add_factors(factor_ab, factor_bc)

        self.graph.remove_factors(factor_ab)

        six.assertCountEqual(self, self.graph.factors, [factor_bc])

    def test_get_partition_function(self):
        # Z = sum over a, b, c of phi_ab(a, b) * phi_bc(b, c). With entries
        # 0..3 in each factor this totals 22. For example, assuming the last
        # variable varies fastest: b=0 gives (0+2)*(0+1) = 2 and b=1 gives
        # (1+3)*(2+3) = 20.
        self._link_ab_bc()
        phi_ab = DiscreteFactor(['a', 'b'], [2, 2], range(4))
        phi_bc = DiscreteFactor(['b', 'c'], [2, 2], range(4))
        self.graph.add_factors(phi_ab, phi_bc)

        self.assertEqual(self.graph.get_partition_function(), 22.0)

    def tearDown(self):
        del self.graph