def generate_BayesNet(root):
    '''
    Generate a BayesNet from a XMLBIF.

    This method is used internally. Do not call it outside this class.

    Parameters
    ----------
    root
        Parsed XML document exposing ``getElementsByTagName`` (presumably an
        ``xml.dom.minidom`` Document -- TODO confirm against callers).

    Returns
    -------
    BayesNet
        A network holding one DiscreteNode per <VARIABLE> tag. For every
        <DEFINITION> whose <FOR> names a known node, the <GIVEN> parents are
        announced to that node and the first <TABLE> is written into its CPD.

    Raises
    ------
    Exception
        If the document does not contain exactly one <BIF> tag, or that tag
        does not contain exactly one <NETWORK> tag.
    '''
    network = BayesNet()

    bif_nodes = root.getElementsByTagName("BIF")
    if len(bif_nodes) != 1:
        raise Exception("More than one or none <BIF>-tag in document.")
    network_nodes = bif_nodes[0].getElementsByTagName("NETWORK")
    if len(network_nodes) != 1:
        raise Exception("More than one or none <NETWORK>-tag in document.")
    network_node = network_nodes[0]

    # Pass 1: create every variable before reading any DEFINITION, so that a
    # definition can reference a parent regardless of document order.
    for variable_node in network_node.getElementsByTagName("VARIABLE"):
        # Only the first <NAME> is used; a variable without one gets a placeholder.
        name_nodes = variable_node.getElementsByTagName("NAME")
        if name_nodes:
            name = XMLBIF.get_node_text(name_nodes[0].childNodes)
        else:
            name = "Unnamed node"

        value_range = [
            XMLBIF.get_node_text(outcome_node.childNodes)
            for outcome_node in variable_node.getElementsByTagName("OUTCOME")
        ]

        # Only the first <PROPERTY> is read, and it is interpreted as the
        # node's position; (0, 0) when absent.
        property_nodes = variable_node.getElementsByTagName("PROPERTY")
        if property_nodes:
            position = XMLBIF.get_node_position_from_text(property_nodes[0].childNodes)
        else:
            position = (0, 0)

        new_node = DiscreteNode(name, value_range)
        new_node.position = position
        network.add_node(new_node)

    # Pass 2: wire up parents and conditional probability tables.
    for definition_node in network_node.getElementsByTagName("DEFINITION"):
        for_nodes = definition_node.getElementsByTagName("FOR")
        if not for_nodes:
            continue
        node = network.get_node(XMLBIF.get_node_text(for_nodes[0].childNodes))
        # Identity check: `== None` would dispatch to the node's __eq__.
        if node is None:
            continue

        for given_node in definition_node.getElementsByTagName("GIVEN"):
            parent_name = XMLBIF.get_node_text(given_node.childNodes)
            parent_node = network.get_node(parent_name)
            # NOTE(review): this only informs the node/CPD about the parent;
            # whether it also registers a graph edge (cf. BayesNet.add_edge)
            # is not visible here -- confirm.
            node.announce_parent(parent_node)

        table_nodes = definition_node.getElementsByTagName("TABLE")
        if table_nodes:
            table = XMLBIF.get_node_table_from_text(table_nodes[0].childNodes)
            # Writing through the transposed view's flat iterator stores the
            # values in column-major order of the CPD table (presumably a
            # numpy ndarray, matching XMLBIF's table ordering -- confirm).
            node.get_cpd().get_table().T.flat = table

    return network
class NodeAddAndRemoveTestCase(unittest.TestCase):
    """Structural tests for BayesNet: adding/removing nodes and edges and
    cycle detection through ``is_valid``."""

    def setUp(self):
        # Every test method starts from its own freshly constructed network.
        self.bn = BayesNet()

    def tearDown(self):
        self.bn = None

    def _fresh_linked_pair(self):
        """Empty the network, add nodes "1" and "2", and link "1" -> "2"."""
        self.bn.clear()
        first = DiscreteNode("1", [True, False])
        second = DiscreteNode("2", [True, False])
        self.bn.add_node(first)
        self.bn.add_node(second)
        self.bn.add_edge(first, second)
        return first, second

    def test_clear_and_len(self):
        # NOTE(review): this requires a freshly constructed BayesNet to be
        # non-empty before clear(); confirm that is intended.
        self.assertNotEqual(0, len(self.bn))
        self.assertNotEqual(0, self.bn.number_of_nodes())
        self.bn.clear()
        self.assertEqual(0, len(self.bn))
        self.assertEqual(0, self.bn.number_of_nodes())

    def test_add_node(self):
        self.bn.clear()
        node = DiscreteNode("Some Node", [True, False])
        self.bn.add_node(node)
        self.assertEqual(node, self.bn.get_node("Some Node"))
        self.assertIn(node, self.bn.get_nodes(["Some Node"]))
        # Node names must be unique within a network.
        duplicate = DiscreteNode("Some Node", [True, False])
        with self.assertRaises(Exception):
            self.bn.add_node(duplicate)

    def test_remove_node(self):
        self.bn.clear()
        doomed = DiscreteNode("Some Node to remove", [True, False])
        self.bn.add_node(doomed)
        self.bn.remove_node(doomed)
        self.assertNotIn(doomed, self.bn.get_nodes([]))

    def test_add_edge(self):
        tail, head = self._fresh_linked_pair()
        self.assertIn(tail, self.bn.get_parents(head))
        self.assertIn(head, self.bn.get_children(tail))

    def test_remove_edge(self):
        tail, head = self._fresh_linked_pair()
        self.assertEqual([tail], self.bn.get_parents(head))
        self.bn.remove_edge(tail, head)
        self.assertEqual([], self.bn.get_parents(head))

    def test_is_valid(self):
        first, second = self._fresh_linked_pair()
        self.assertTrue(self.bn.is_valid())

        # A self-loop is a cycle; removing it restores validity.
        self.bn.add_edge(first, first)
        self.assertFalse(self.bn.is_valid())
        self.bn.remove_edge(first, first)
        self.assertTrue(self.bn.is_valid())

        third, fourth = (DiscreteNode(label, [True, False]) for label in ("3", "4"))
        self.bn.add_node(third)
        self.bn.add_node(fourth)
        self.assertTrue(self.bn.is_valid())

        # Extending the chain 1 -> 2 -> 3 -> 4 keeps the graph acyclic...
        for tail, head in ((second, third), (third, fourth)):
            self.bn.add_edge(tail, head)
            self.assertTrue(self.bn.is_valid())

        # ...until 4 -> 1 closes the loop.
        self.bn.add_edge(fourth, first)
        self.assertFalse(self.bn.is_valid())