class Inspection():
    """Label statistics for a dataset: majority vote, error rate, Gini impurity.

    A dataset is a sequence of rows; the last element of each row is treated
    as the class label.
    """

    def __init__(self, ori_dataset):
        """Store the dataset that evaluate() will inspect."""
        self.ori_dataset = ori_dataset
        self.ld = LoadData()
        # Results of the most recent error_rate() / gini_impurity() call.
        self.er = 0
        self.gi = 0

    def majority_vote(self, dataset):
        """Return the more frequent of the first two label values.

        The candidate labels come from ``self.ld.get_value(dataset, -1)``
        (presumably the distinct values of the label column -- confirm
        against LoadData).  Only the first two entries are considered, so
        this is a binary vote: rows not equal to the first label are all
        counted for the second.  A tie on a non-empty dataset goes to the
        second label.
        """
        labels = self.ld.get_value(dataset, -1)
        first_count = sum(1 for row in dataset if row[-1] == labels[0])
        rest_count = len(dataset) - first_count

        if first_count > rest_count:
            return labels[0]
        if rest_count > first_count:
            return labels[1]
        # Tie: both counts are zero only for an empty dataset, in which case
        # the first label is returned; otherwise the second label wins.
        return labels[0] if rest_count == 0 else labels[1]

    def error_rate(self, dataset):
        """Return the fraction of rows whose label differs from the majority.

        The result is also stored in ``self.er``.  An empty dataset yields 0
        (the same convention gini_impurity() uses) instead of dividing by
        zero.
        """
        total = len(dataset)
        if total == 0:
            self.er = 0
            return self.er

        majority = self.majority_vote(dataset)
        wrong = sum(1 for row in dataset if row[-1] != majority)
        self.er = wrong / total
        return self.er

    def gini_impurity(self, dataset):
        """Return the Gini impurity of the label column.

        Computed as the sum over classes of ``p_k * (1 - p_k)``.  For one or
        two classes this is exactly the ``p1*p2 + p2*p1`` form; with more
        classes every class contributes instead of being lumped together.
        The result is also stored in ``self.gi``; an empty dataset yields 0.
        """
        total = len(dataset)
        if total == 0:
            self.gi = 0
            return self.gi

        counts = {}
        for row in dataset:
            counts[row[-1]] = counts.get(row[-1], 0) + 1

        # (total - count) / total is the share of every *other* class.
        self.gi = sum(
            (count / total) * ((total - count) / total)
            for count in counts.values()
        )
        return self.gi

    def evaluate(self):
        """Return ``(error_rate, gini_impurity)`` for the stored dataset."""
        err_rate = self.error_rate(self.ori_dataset)
        gini_impurity = self.gini_impurity(self.ori_dataset)
        return err_rate, gini_impurity
class DecisionTree():
    """Binary decision tree grown greedily by maximum Gini gain.

    A dataset is a sequence of rows whose last element is the class label;
    every other column is a candidate split attribute.
    """

    def __init__(self, ori_dataset, max_depth):
        """Store the training data and the depth limit used by construct()."""
        # Attribute columns already used on the path from the root to the
        # node currently being built.
        self.col = []
        self.max_depth = max_depth
        self.dataset = ori_dataset
        self.ld = LoadData()
        self.ins = Inspection(ori_dataset)

    def divide_dataset(self, dataset, col_index):
        """Split ``dataset`` in two on column ``col_index``.

        Rows whose value equals the first entry of
        ``self.ld.get_value(dataset, col_index)`` go to the first part, all
        other rows to the second.  Both parts are returned as numpy arrays.
        """
        values = self.ld.get_value(dataset, col_index)
        matching, others = [], []
        for row in dataset:
            bucket = matching if row[col_index] == values[0] else others
            bucket.append(row)
        return np.array(matching), np.array(others)

    def gini_impurity(self, dataset, col_index=-1):
        """Return the Gini impurity of ``dataset``.

        With ``col_index == -1`` this is the impurity of the labels as
        computed by ``self.ins``.  Otherwise it is the size-weighted average
        impurity of the two parts produced by splitting on ``col_index``.
        An empty dataset yields 0 in both cases.
        """
        if col_index == -1:
            return self.ins.gini_impurity(dataset)

        total = len(dataset)
        if total == 0:
            # Nothing to split; avoids dividing by zero below.
            return 0

        left, right = self.divide_dataset(dataset, col_index)
        gi_left = self.ins.gini_impurity(left)
        gi_right = self.ins.gini_impurity(right)
        return (len(left) / total) * gi_left + (len(right) / total) * gi_right

    def gini_gain(self, dataset, col_index):
        """Return the impurity reduction obtained by splitting on ``col_index``."""
        ori_gi = self.gini_impurity(dataset)
        new_gi = self.gini_impurity(dataset, col_index)
        return ori_gi - new_gi

    def get_attribute(self, dataset, used_col):
        """Return the unused attribute column with the largest Gini gain.

        Columns are examined in ascending index order, so ties go to the
        lowest index.  Raises ValueError (from ``max``) when every attribute
        column is already in ``used_col``.
        """
        used = set(used_col)
        gains = {
            col: self.gini_gain(dataset, col)
            for col in range(len(dataset[0]) - 1)
            if col not in used
        }
        return max(gains, key=gains.get)

    def construct(self, dataset, col_index=-1, depth=0):
        """Recursively build the subtree for ``dataset`` and return its root.

        Returns None (after printing the reason) when the depth limit is
        exceeded, the dataset is empty, every attribute has been used on the
        current path, or the labels are already pure.  The incoming
        ``col_index`` is not read; it is kept for interface compatibility.

        ``self.col`` records the attributes used along the current path: the
        chosen column is added before recursing and removed afterwards.
        """
        # NOTE(review): nodes are created for depth 0..max_depth inclusive;
        # confirm that this is the intended meaning of max_depth.
        if depth > self.max_depth:
            print('depth reach max depth')
            return None
        if len(dataset) == 0:
            print('dataset is empty.')
            return None
        if len(dataset[0]) == len(self.col) + 1:
            print('all the attributes have been used.')
            return None
        if self.gini_impurity(dataset) == 0:
            print('no need to do more division!')
            return None

        # Split on the attribute with the largest Gini gain.
        col_index = self.get_attribute(dataset, self.col)
        node = Node(col_index, dataset, depth=depth)
        self.col.append(col_index)

        left_data, right_data = self.divide_dataset(dataset, col_index)
        node.left = self.construct(left_data, col_index, depth + 1)
        node.right = self.construct(right_data, col_index, depth + 1)

        # Leaving this subtree: the attribute is available to other paths.
        self.col.remove(col_index)
        return node

    def traverse(self, node):
        """Print depth and attribute of every node, pre-order, left first."""
        if node:
            print(node.depth, '\t', node.attribute)
            self.traverse(node.left)
            self.traverse(node.right)