def _majority_leaf(label):
    """Return a leaf ``Node`` carrying the most frequent value in ``label``.

    ``np.unique`` returns its bins sorted and ``np.argmax`` returns the first
    maximum, so ties are broken in favour of the smallest label value.
    """
    bins, counts = np.unique(label, return_counts=True)
    return Node(bins[np.argmax(counts)], True)


def ID3(data_obj, attributes, parent, tree, limit_depth=False, max_depth=10):
    """Recursively grow an ID3 decision tree under ``parent`` inside ``tree``.

    Nodes are added to ``tree`` in place; the function returns ``None``.

    Args:
        data_obj: Dataset wrapper. It must provide ``get_column('label')``,
            ``attributes[name].possible_vals``, ``get_row_subset(feature, value)``
            and a ``raw_data`` array whose ``shape[0]`` is the row count.
        attributes: Candidate feature names still available for splitting.
            It is iterated for keys and passed to ``remove_key``.
        parent: Node to attach the new subtree to, or ``None`` to attach at
            the tree root.
        tree: Tree being built. It must provide ``add_node``, ``del_node``,
            ``get_root`` and ``get_depth``.
        limit_depth: When true, enforce ``max_depth``.
        max_depth: Largest allowed tree depth as reported by
            ``tree.get_depth(tree.get_root())``. This is only consulted when
            ``limit_depth`` is true.
    """
    label = data_obj.get_column('label')
    if len(label) <= 0:
        # No data reaches this branch: add nothing.
        return

    unique_labels = np.unique(label)
    if len(unique_labels) == 1:
        # Pure subset: a single leaf carrying that one label.
        tree.add_node(Node(unique_labels[0], True), parent)
        return

    if len(attributes) <= 0:
        # No features left to split on: fall back to the majority label.
        tree.add_node(_majority_leaf(label), parent)
        return

    # NOTE(review): the message says "greater than 0", but the check still
    # accepts max_depth == 0. Confirm which one is intended.
    if limit_depth and max_depth < 0:
        print("Max-depth should be greater than 0. Aborting!!!")
        return

    # Pick the feature with the highest information gain. max() returns the
    # first maximal key, so ties go to the earliest attribute in iteration
    # order.
    gains = {key: info_gain(data_obj, key) for key in attributes}
    best_feature = max(gains, key=gains.get)

    # Build the split node. Only values that actually occur in this subset
    # become branches. Each non-empty subset is kept so it does not have to
    # be partitioned again when recursing.
    n = Node(best_feature, False)
    branches = []
    for value in data_obj.attributes[best_feature].possible_vals:
        subset = data_obj.get_row_subset(best_feature, value)
        if subset.raw_data.shape[0] > 0:
            n.add_value(value)
            branches.append(subset)
    tree.add_node(n, parent)

    # If this split makes the tree too deep, undo it and place a
    # majority-label leaf here instead.
    if limit_depth and tree.get_depth(tree.get_root()) > max_depth:
        tree.del_node(n, parent)
        tree.add_node(_majority_leaf(label), parent)
        return

    attributes_new = remove_key(attributes, best_feature)

    # The split node just added becomes the parent of every branch. At the
    # root it is tree.get_root(). Otherwise it is assumed to be the last
    # child of ``parent``, which relies on Tree.add_node appending to
    # ``child`` (TODO confirm). Recursive calls only attach nodes beneath it,
    # so this lookup does not change inside the loop.
    new_parent = tree.get_root() if parent is None else parent.child[-1]
    for subset in branches:
        ID3(subset, attributes_new, new_parent, tree, limit_depth, max_depth)
from tree import Tree
from tree import Node


def main():
    """Smoke-test ``Tree``: build a small two-level tree and print it as it grows."""
    my_tree = Tree()

    # Root feature node 'taste' plus three child 'var' nodes.
    # NOTE(review): 'a' and 'b' are added to n (giving n three branch values
    # for its three children), while r gets its own value 'c'. If p and q were
    # meant to receive 'a' and 'b', change those two calls.
    n = Node('taste')
    n.add_value('o')
    p = Node('var')
    n.add_value('a')
    q = Node('var')
    n.add_value('b')
    r = Node('var')
    r.add_value('c')

    my_tree.add_node(n, my_tree.get_root())
    print("Traversing the tree after adding 1 node")
    my_tree.print_tree(my_tree.get_root(), 0)

    my_tree.add_node(p, n)
    print("Traversing the tree after adding 2 nodes")
    my_tree.print_tree(my_tree.get_root(), 0)

    # Print again so the last two insertions are checked as well.
    my_tree.add_node(q, n)
    my_tree.add_node(r, n)
    print("Traversing the tree after adding 4 nodes")
    my_tree.print_tree(my_tree.get_root(), 0)


if __name__ == "__main__":
    main()