Пример #1
0
    def test_delete_node_subtree(self):
        # Deleting an inner node removes it together with every descendant.
        root = SimpleTreeNode(0, None)
        tree = SimpleTree(root)

        first, second, branch = (SimpleTreeNode(v, None) for v in (1, 2, 3))
        leaf_a, leaf_b, leaf_c = (SimpleTreeNode(v, None) for v in (1, 5, 1))

        for child in (first, second, branch):
            tree.AddChild(root, child)
        for child in (leaf_a, leaf_b, leaf_c):
            tree.AddChild(branch, child)

        self.assertEqual(tree.Count(), 7)

        tree.DeleteNode(branch)

        # only root, first and second survive
        self.assertEqual(tree.Count(), 3)
        for removed in (branch, leaf_a, leaf_b, leaf_c):
            self.assertNotIn(removed, tree.GetAllNodes())
Пример #2
0
    def test_add_child(self):
        # A freshly attached child is listed under its parent and points back to it.
        root = SimpleTreeNode(0, None)
        tree = SimpleTree(root)
        child = SimpleTreeNode(1, None)

        tree.AddChild(root, child)

        expected_children = [child]
        self.assertEqual(root.Children, expected_children)
        self.assertEqual(child.Parent, root)
Пример #3
0
    def __init__(self):
        '''
        Initialise the classifier: an empty result tree, the two target
        labels and the key names used to describe a parent node.
        '''
        self.result_tree = SimpleTree()
        # target labels
        self.POSITIVE_VAL, self.NEGATIVE_VAL = 'Yes', 'No'

        # key names for the parent-description dictionary
        self.ATTRIB_NAME, self.VALUE_NAME = 'Attrib', 'Value'
Пример #4
0
    def test_find_nodes_by_value(self):
        # Every node holding the value 1 is found, at any depth of the tree.
        root = SimpleTreeNode(0, None)
        tree = SimpleTree(root)

        (top_one, top_two, top_three,
         deep_one, deep_five, deep_one_again) = [
            SimpleTreeNode(v, None) for v in (1, 2, 3, 1, 5, 1)]

        expected_nodes = [top_one, deep_one, deep_one_again]

        for child in (top_one, top_two, top_three):
            tree.AddChild(root, child)
        for child in (deep_one, deep_five, deep_one_again):
            tree.AddChild(top_three, child)

        found = tree.FindNodesByValue(1)

        self.assertEqual(len(found), len(expected_nodes))
        for expected in expected_nodes:
            self.assertIn(expected, found)
Пример #5
0
    def test_get_all_nodes(self):
        # The root and both of its children are reported (order not checked).
        root = SimpleTreeNode(0, None)
        tree = SimpleTree(root)

        left, right = SimpleTreeNode(1, None), SimpleTreeNode(2, None)
        expected_nodes = [root, left, right]

        for child in (left, right):
            tree.AddChild(root, child)

        nodes = tree.GetAllNodes()

        self.assertEqual(len(nodes), len(expected_nodes))
        for expected in expected_nodes:
            self.assertIn(expected, nodes)
    def setUp(self) -> None:
        # Fixture tree:
        #          root
        #         /    \
        #        1      2
        #       / \    / \
        #      3   4  5   6
        # The subtrees under "1" and "2" are filled in first, then both
        # nodes are hung off the root.
        self.node_root = SimpleTreeNode("root", None)
        (self.node_1, self.node_2, self.node_3,
         self.node_4, self.node_5, self.node_6) = (
            SimpleTreeNode(label, None) for label in ("1", "2", "3", "4", "5", "6"))

        self.tree = SimpleTree(self.node_root)

        edges = (
            (self.node_1, self.node_3),
            (self.node_1, self.node_4),
            (self.node_2, self.node_5),
            (self.node_2, self.node_6),
            (self.node_root, self.node_1),
            (self.node_root, self.node_2),
        )
        for parent, child in edges:
            self.tree.AddChild(parent, child)
class BaseFullSimpleTreeTestCase(unittest.TestCase):
    """Test case whose fixture is a fully built three-level SimpleTree."""

    def setUp(self) -> None:
        # Fixture tree:
        #          root
        #         /    \
        #        1      2
        #       / \    / \
        #      3   4  5   6
        self.node_root = SimpleTreeNode("root", None)
        (self.node_1, self.node_2, self.node_3,
         self.node_4, self.node_5, self.node_6) = (
            SimpleTreeNode(label, None) for label in ("1", "2", "3", "4", "5", "6"))

        self.tree = SimpleTree(self.node_root)

        # grandchildren first, then "1" and "2" are attached to the root
        links = (
            (self.node_1, self.node_3),
            (self.node_1, self.node_4),
            (self.node_2, self.node_5),
            (self.node_2, self.node_6),
            (self.node_root, self.node_1),
            (self.node_root, self.node_2),
        )
        for parent, child in links:
            self.tree.AddChild(parent, child)
class AddChildToRootTestCase(unittest.TestCase):
    """Adding a node with no parent to an empty tree makes it the root."""

    def setUp(self) -> None:
        self.tree = SimpleTree(None)

    def test(self):
        root_candidate = SimpleTreeNode("root", None)

        self.assertIsNone(self.tree.Root)
        self.tree.AddChild(None, root_candidate)

        new_root = self.tree.Root
        self.assertEqual(root_candidate, new_root)
        self.assertEqual("root", new_root.NodeValue)
class AddChildToNodeTestCase(unittest.TestCase):
    """Adding a child under the root links parent and child both ways."""

    def setUp(self) -> None:
        self.root_node = SimpleTreeNode("root", None)
        self.tree = SimpleTree(self.root_node)

    def test(self):
        child = SimpleTreeNode("1", None)

        self.assertEqual("root", self.tree.Root.NodeValue)
        self.tree.AddChild(self.root_node, child)

        first_child = self.tree.Root.Children[0]
        self.assertEqual(child, first_child)
        self.assertEqual(self.root_node, child.Parent)
Пример #10
0
    def test_delete_node(self):
        # Removing a leaf lowers the count by one and drops it from the listing.
        root = SimpleTreeNode(0, None)
        tree = SimpleTree(root)

        doomed, survivor = SimpleTreeNode(1, None), SimpleTreeNode(2, None)
        for child in (doomed, survivor):
            tree.AddChild(root, child)

        self.assertEqual(tree.Count(), 3)

        tree.DeleteNode(doomed)

        self.assertEqual(tree.Count(), 2)
        self.assertNotIn(doomed, tree.GetAllNodes())
Пример #11
0
    def test_count_leaf(self):
        # root -> (1, 2, 3); node 3 -> (1, 5, 1).
        # Leaves: the first two children plus the three grandchildren = 5.
        root = SimpleTreeNode(0, None)
        tree = SimpleTree(root)

        children = [SimpleTreeNode(v, None) for v in (1, 2, 3)]
        grandchildren = [SimpleTreeNode(v, None) for v in (1, 5, 1)]

        for child in children:
            tree.AddChild(root, child)
        for grandchild in grandchildren:
            tree.AddChild(children[-1], grandchild)

        self.assertEqual(tree.LeafCount(), 5)
Пример #12
0
    def test_get_all_nodes_empty_tree(self):
        # A tree created without a root reports no nodes at all.
        empty_tree = SimpleTree(None)
        all_nodes = empty_tree.GetAllNodes()
        self.assertEqual(all_nodes, [])
Пример #13
0
 def setUp(self) -> None:
     # every test starts from a tree that has no root yet
     no_root = None
     self.tree = SimpleTree(no_root)
Пример #14
0
 def setUp(self) -> None:
     # every test starts from a one-node tree; the root is kept for the tests
     root = SimpleTreeNode("root", None)
     self.root_node = root
     self.tree = SimpleTree(root)
Пример #15
0
class ID3Classifier:
    '''
    execution of ID3 algorithm to classify data

    The learned decision tree is written into ``self.result_tree``:
    attribute names become inner nodes, target labels become leaves and
    each edge is labelled with the attribute value that leads to it.
    '''
    def __init__(self):
        '''
        init function: creates the empty result tree and the constants
        used by the algorithm
        '''
        self.result_tree = SimpleTree()
        # the two target labels registered as leaves in id3_compute
        self.POSITIVE_VAL = 'Yes'
        self.NEGATIVE_VAL = 'No'

        # keys of the ``parent`` dictionary passed down the recursion
        self.ATTRIB_NAME = 'Attrib'
        self.VALUE_NAME = 'Value'

    def id3_compute(self, attrib_names, attribs, targets):
        '''
        interface function to start the execution of ID3 process

        Arguments:
        ----------

        attrib_names: list with the name of every column of ``attribs``
        attribs: 2-D numpy array, one row per sample, one column per attribute
        targets: numpy array holding the label of every row of ``attribs``

        Raises:
        -------

        ValueError: if ``targets`` is empty
        '''
        for attrib_name in attrib_names:
            self.result_tree.add_node(attrib_name)

        self.result_tree.add_leafnode(self.POSITIVE_VAL)
        self.result_tree.add_leafnode(self.NEGATIVE_VAL)

        # the top of the recursion has no parent attribute/value
        parent = {self.ATTRIB_NAME: None, self.VALUE_NAME: None}
        self.__id3(attrib_names, attribs, targets, parent)

    def display_tree(self):
        '''
        displays the results tree
        '''
        self.result_tree.display()
        self.result_tree.show()

    @staticmethod
    def _entropy(labels):
        '''
        base-2 Shannon entropy of the label distribution in ``labels``;
        0 for an empty or single-label set
        '''
        if len(labels) == 0:
            return 0.0
        _, counts = np.unique(labels, return_counts=True)
        probs = counts / len(labels)
        # only labels that actually occur are counted, so log2(0) never happens
        return float(-np.sum(probs * np.log2(probs)))

    def _info_gain(self, column, targets, set_entropy):
        '''
        information gain of splitting ``targets`` on the values of
        ``column`` (one attribute column of the data):
        set entropy minus the weighted entropy of every value subset
        '''
        total_count = len(targets)
        gain = set_entropy
        for value in np.unique(column):
            value_filter = column == value
            weight = np.count_nonzero(value_filter) / total_count
            gain -= self._entropy(targets[value_filter]) * weight
        return gain

    def _attach(self, parent, node):
        '''
        hang ``node`` below ``parent`` in the result tree, or make it the
        root node when ``parent`` carries no attribute (top of recursion)
        '''
        if parent[self.ATTRIB_NAME] is None:
            self.result_tree.rootnode = node
        else:
            self.result_tree.add_edge(parent[self.ATTRIB_NAME], node,
                                      parent[self.VALUE_NAME])

    def __id3(self, attrib_names, attribs, targets, parent):
        '''
        Actual ID3 algorithm that is executed recursively to classify the data

        ``parent`` describes the node this subset hangs from: the attribute
        that was split on and the value that selected this subset.
        '''
        unique, counts = np.unique(targets, return_counts=True)

        if len(unique) == 0:
            raise ValueError('cannot run ID3 on an empty set of targets')

        if len(unique) == 1:
            # pure subset: its single label becomes a leaf
            self._attach(parent, unique[0])
            return

        if not attrib_names:
            # labels are still mixed but no attribute is left to split on:
            # use the majority label (argmax picks the first one on ties)
            self._attach(parent, unique[np.argmax(counts)])
            return

        entropy_set = self._entropy(targets)
        attrib_gains = [self._info_gain(attribs[:, idx], targets, entropy_set)
                        for idx in range(len(attrib_names))]

        # attribute with the maximum info gain; np.argmax returns the first
        # index on ties, so exactly one column is selected
        selected_attribute_idx = int(np.argmax(attrib_gains))
        current_attrib = attrib_names[selected_attribute_idx]

        # link the chosen attribute to its parent once (not once per value)
        self._attach(parent, current_attrib)

        # remaining attribute names, aligned with the column removed below
        attribs_names_subset = (attrib_names[:selected_attribute_idx] +
                                attrib_names[selected_attribute_idx + 1:])

        selected_column = attribs[:, selected_attribute_idx]
        for current_val in np.unique(selected_column):
            attribs_filter = selected_column == current_val
            data_subset = np.delete(attribs[attribs_filter],
                                    selected_attribute_idx, 1)
            current_parent = {
                self.ATTRIB_NAME: current_attrib,
                self.VALUE_NAME: current_val
            }
            self.__id3(attribs_names_subset, data_subset,
                       targets[attribs_filter], current_parent)

    def infer(self, case):
        '''
        goes through the case in hand in the tree and returns the result found

        Arguments:
        ----------

        case: dictionary containing the case in hand in the form 'node:branch'
        '''
        result = self.result_tree.traverse(case)
        return result
Пример #16
0
'''
simple example demonstrating the operation of the SimpleTree

1. creates a tree with nodes and edges
2. creates another tree with nodes and edges
3. appends the second tree to the first tree
'''
from simple_tree import SimpleTree

# First tree: nodes are registered by name with add_node, then linked with
# add_edge. Judging by the calls below, its arguments look like
# (parent name, child name, edge label) -- TODO confirm against simple_tree.
a_tree = SimpleTree()
a_tree.add_node('One')
a_tree.add_node('Two')
a_tree.add_node('Three')
a_tree.add_node('Four')
a_tree.add_edge('One', 'Three', 'No')
a_tree.add_node('Five')
a_tree.add_edge('One', 'Two', 'Yes')
a_tree.add_edge('Two', 'Four', 'Yes')
a_tree.add_edge('Four', 'Five', 'Yes')
a_tree.add_node('Six')
a_tree.add_edge('Four', 'Six', 'No')

# mark 'One' as the root, then print the first tree
a_tree.set_root_node('One')
a_tree.display()

# Second tree.
# NOTE(review): step 3 of the module docstring (appending b_tree to a_tree)
# is not part of the code visible here.
b_tree = SimpleTree()
b_tree.add_node('Uno')
b_tree.add_node('Due')
b_tree.add_node('Tre')
b_tree.add_node('Quattro')
b_tree.add_edge('Uno', 'Due', 'Si')
Пример #17
0
    def test_move_node_with_subtree(self):
        # Moving a node carries its whole subtree along and leaves the
        # node and leaf counts untouched.
        root = SimpleTreeNode(0, None)
        tree = SimpleTree(root)

        target, other, old_parent, moved, leaf_a, leaf_b = (
            SimpleTreeNode(v, None) for v in (1, 2, 3, 1, 5, 1))

        structure = (
            (root, target), (root, other), (root, old_parent),
            (old_parent, moved), (moved, leaf_a), (moved, leaf_b),
        )
        for parent, child in structure:
            tree.AddChild(parent, child)

        def check_shape(expected_parent):
            # same totals and same subtree, only the parent link may differ
            self.assertEqual(tree.Count(), 7)
            self.assertEqual(tree.LeafCount(), 4)
            self.assertEqual(moved.Children, [leaf_a, leaf_b])
            self.assertEqual(moved.Parent, expected_parent)

        check_shape(old_parent)

        tree.MoveNode(moved, target)

        check_shape(target)