class TestClusterGraphMethods(unittest.TestCase): def setUp(self): self.graph = ClusterGraph() def test_get_cardinality(self): self.graph.add_edges_from([(('a', 'b', 'c'), ('a', 'b')), (('a', 'b', 'c'), ('a', 'c'))]) self.assertDictEqual(self.graph.get_cardinality(), {}) phi1 = Factor(['a', 'b', 'c'], [1, 2, 2], np.random.rand(4)) self.graph.add_factors(phi1) self.assertDictEqual(self.graph.get_cardinality(), {'a': 1, 'b': 2, 'c':2}) self.graph.remove_factors(phi1) self.assertDictEqual(self.graph.get_cardinality(), {}) phi1 = Factor(['a', 'b'], [1, 2], np.random.rand(2)) phi2 = Factor(['a', 'c'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi1, phi2) self.assertDictEqual(self.graph.get_cardinality(), {'a': 1, 'b': 2, 'c': 2}) phi3 = Factor(['a', 'c'], [1, 1], np.random.rand(1)) self.graph.add_factors(phi3) self.assertDictEqual(self.graph.get_cardinality(), {'c': 1, 'b': 2, 'a': 1}) self.graph.remove_factors(phi1, phi2, phi3) self.assertDictEqual(self.graph.get_cardinality(), {}) def test_get_cardinality_check_cardinality(self): self.graph.add_edges_from([(('a', 'b', 'c'), ('a', 'b')), (('a', 'b', 'c'), ('a', 'c'))]) phi1 = Factor(['a', 'b'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi1) self.assertRaises(ValueError, self.graph.get_cardinality, check_cardinality=True) self.graph.remove_factors(phi1) phi2 = Factor(['a', 'c'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi2) self.assertRaises(ValueError, self.graph.get_cardinality, check_cardinality=True) self.graph.add_factors(phi1) self.assertDictEqual(self.graph.get_cardinality(check_cardinality=True), {'c': 2, 'b': 2, 'a': 1}) def test_check_model(self): self.graph.add_edges_from([(('a', 'b'), ('a', 'c'))]) phi1 = Factor(['a', 'b'], [1, 2], np.random.rand(2)) phi2 = Factor(['a', 'c'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi1, phi2) self.assertTrue(self.graph.check_model()) self.graph.remove_factors(phi2) phi2 = Factor(['a', 'c'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi2) self.assertTrue(self.graph.check_model()) def test_check_model1(self): self.graph.add_edges_from([(('a', 'b'), ('a', 'c')), (('a', 'c'), ('a', 'd'))]) phi1 = Factor(['a', 'b'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi1) self.assertRaises(ValueError, self.graph.check_model) phi2 = Factor(['a', 'c'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi2) self.assertRaises(ValueError, self.graph.check_model) def test_check_model2(self): self.graph.add_edges_from([(('a', 'b'), ('a', 'c')), (('a', 'c'), ('a', 'd'))]) phi1 = Factor(['a', 'b'], [1, 2], np.random.rand(2)) phi2 = Factor(['a', 'c'], [3, 3], np.random.rand(9)) phi3 = Factor(['a', 'd'], [4, 4], np.random.rand(16)) self.graph.add_factors(phi1, phi2, phi3) self.assertRaises(ValueError, self.graph.check_model) self.graph.remove_factors(phi2) phi2 = Factor(['a', 'c'], [1, 3], np.random.rand(3)) self.graph.add_factors(phi2) self.assertRaises(ValueError, self.graph.check_model) self.graph.remove_factors(phi3) phi3 = Factor(['a', 'd'], [1, 4], np.random.rand(4)) self.graph.add_factors(phi3) self.assertTrue(self.graph.check_model()) def tearDown(self): del self.graph
class TestClusterGraphMethods(unittest.TestCase): def setUp(self): self.graph = ClusterGraph() def test_get_cardinality(self): self.graph.add_edges_from([(('a', 'b', 'c'), ('a', 'b')), (('a', 'b', 'c'), ('a', 'c'))]) self.assertDictEqual(self.graph.get_cardinality(), {}) phi1 = DiscreteFactor(['a', 'b', 'c'], [1, 2, 2], np.random.rand(4)) self.graph.add_factors(phi1) self.assertDictEqual(self.graph.get_cardinality(), {'a': 1, 'b': 2, 'c': 2}) self.graph.remove_factors(phi1) self.assertDictEqual(self.graph.get_cardinality(), {}) phi1 = DiscreteFactor(['a', 'b'], [1, 2], np.random.rand(2)) phi2 = DiscreteFactor(['a', 'c'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi1, phi2) self.assertDictEqual(self.graph.get_cardinality(), {'a': 1, 'b': 2, 'c': 2}) phi3 = DiscreteFactor(['a', 'c'], [1, 1], np.random.rand(1)) self.graph.add_factors(phi3) self.assertDictEqual(self.graph.get_cardinality(), {'c': 1, 'b': 2, 'a': 1}) self.graph.remove_factors(phi1, phi2, phi3) self.assertDictEqual(self.graph.get_cardinality(), {}) def test_get_cardinality_check_cardinality(self): self.graph.add_edges_from([(('a', 'b', 'c'), ('a', 'b')), (('a', 'b', 'c'), ('a', 'c'))]) phi1 = DiscreteFactor(['a', 'b'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi1) self.assertRaises(ValueError, self.graph.get_cardinality, check_cardinality=True) self.graph.remove_factors(phi1) phi2 = DiscreteFactor(['a', 'c'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi2) self.assertRaises(ValueError, self.graph.get_cardinality, check_cardinality=True) self.graph.add_factors(phi1) self.assertDictEqual(self.graph.get_cardinality(check_cardinality=True), {'c': 2, 'b': 2, 'a': 1}) def test_check_model(self): self.graph.add_edges_from([(('a', 'b'), ('a', 'c'))]) phi1 = DiscreteFactor(['a', 'b'], [1, 2], np.random.rand(2)) phi2 = DiscreteFactor(['a', 'c'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi1, phi2) self.assertTrue(self.graph.check_model()) self.graph.remove_factors(phi2) phi2 = DiscreteFactor(['a', 'c'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi2) self.assertTrue(self.graph.check_model()) def test_check_model1(self): self.graph.add_edges_from([(('a', 'b'), ('a', 'c')), (('a', 'c'), ('a', 'd'))]) phi1 = DiscreteFactor(['a', 'b'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi1) self.assertRaises(ValueError, self.graph.check_model) phi2 = DiscreteFactor(['a', 'c'], [1, 2], np.random.rand(2)) self.graph.add_factors(phi2) self.assertRaises(ValueError, self.graph.check_model) def test_check_model2(self): self.graph.add_edges_from([(('a', 'b'), ('a', 'c')), (('a', 'c'), ('a', 'd'))]) phi1 = DiscreteFactor(['a', 'b'], [1, 2], np.random.rand(2)) phi2 = DiscreteFactor(['a', 'c'], [3, 3], np.random.rand(9)) phi3 = DiscreteFactor(['a', 'd'], [4, 4], np.random.rand(16)) self.graph.add_factors(phi1, phi2, phi3) self.assertRaises(ValueError, self.graph.check_model) self.graph.remove_factors(phi2) phi2 = DiscreteFactor(['a', 'c'], [1, 3], np.random.rand(3)) self.graph.add_factors(phi2) self.assertRaises(ValueError, self.graph.check_model) self.graph.remove_factors(phi3) phi3 = DiscreteFactor(['a', 'd'], [1, 4], np.random.rand(4)) self.graph.add_factors(phi3) self.assertTrue(self.graph.check_model()) def test_copy_with_factors(self): self.graph.add_edges_from([[('a', 'b'), ('b', 'c')]]) phi1 = DiscreteFactor(['a', 'b'], [2, 2], np.random.rand(4)) phi2 = DiscreteFactor(['b', 'c'], [2, 2], np.random.rand(4)) self.graph.add_factors(phi1, phi2) graph_copy = self.graph.copy() self.assertIsInstance(graph_copy, ClusterGraph) self.assertEqual(hf.recursive_sorted(self.graph.nodes()), hf.recursive_sorted(graph_copy.nodes())) self.assertEqual(hf.recursive_sorted(self.graph.edges()), hf.recursive_sorted(graph_copy.edges())) self.assertTrue(graph_copy.check_model()) self.assertEqual(self.graph.get_factors(), graph_copy.get_factors()) self.graph.remove_factors(phi1, phi2) self.assertTrue(phi1 not in self.graph.factors and phi2 not in self.graph.factors) self.assertTrue(phi1 in graph_copy.factors and phi2 in graph_copy.factors) self.graph.add_factors(phi1, phi2) self.graph.factors[0] = DiscreteFactor(['a', 'b'], [2, 2], np.random.rand(4)) self.assertNotEqual(self.graph.get_factors()[0], graph_copy.get_factors()[0]) self.assertNotEqual(self.graph.factors, graph_copy.factors) def test_copy_without_factors(self): self.graph.add_nodes_from([('a', 'b', 'c'), ('a', 'b'), ('a', 'c')]) self.graph.add_edges_from([(('a', 'b', 'c'), ('a', 'b')), (('a', 'b', 'c'), ('a', 'c'))]) graph_copy = self.graph.copy() self.graph.remove_edge(('a', 'b', 'c'), ('a', 'c')) self.assertFalse(self.graph.has_edge(('a', 'b', 'c'), ('a', 'c'))) self.assertTrue(graph_copy.has_edge(('a', 'b', 'c'), ('a', 'c'))) self.graph.remove_node(('a', 'c')) self.assertFalse(self.graph.has_node(('a', 'c'))) self.assertTrue(graph_copy.has_node(('a', 'c'))) self.graph.add_node(('c', 'd')) self.assertTrue(self.graph.has_node(('c', 'd'))) self.assertFalse(graph_copy.has_node(('c', 'd'))) def tearDown(self): del self.graph