Пример #1
0
class Inspection():
    """Baseline label statistics (majority vote, error rate, Gini impurity).

    Every row's last element (``row[-1]``) is treated as its label.
    """

    def __init__(self, ori_dataset):
        self.ori_dataset = ori_dataset
        self.ld = LoadData()

        # Most recent results, cached by error_rate() and gini_impurity().
        self.er = 0
        self.gi = 0

    def majority_vote(self, dataset):
        """Return the most frequent label in ``dataset``.

        Candidate labels, and their order, come from
        ``self.ld.get_value(dataset, -1)``. A tie goes to the candidate that
        appears later in that order; with two labels, that is ``label[1]``.
        Any number of distinct labels is supported.

        Raises:
            ValueError: if ``get_value`` reports no candidate labels
                (e.g. for an empty dataset).
        """
        labels = self.ld.get_value(dataset, -1)
        if len(labels) == 0:
            raise ValueError('cannot take a majority vote of an empty dataset')
        counts = [sum(1 for row in dataset if row[-1] == lab) for lab in labels]
        # Key (count, index): a higher count wins; on equal counts the larger
        # index (the later label) wins, preserving the original tie-break.
        best = max(range(len(counts)), key=lambda i: (counts[i], i))
        return labels[best]

    def error_rate(self, dataset):
        """Return the fraction of rows whose label differs from the majority vote.

        An empty dataset has error rate 0, consistent with gini_impurity.
        The result is also stored in ``self.er``.
        """
        if len(dataset) == 0:
            self.er = 0
            return self.er
        majority = self.majority_vote(dataset)
        mistakes = sum(1 for row in dataset if row[-1] != majority)
        self.er = mistakes / len(dataset)
        return self.er

    def gini_impurity(self, dataset):
        """Return the Gini impurity of the labels: the sum over classes of p * (1 - p).

        Classes are visited in order of first appearance. For two classes this
        is exactly ``(c1/n)*(c2/n) + (c2/n)*(c1/n)`` (i.e. 2pq), and more than
        two classes are handled correctly. Empty data has impurity 0. The
        result is also stored in ``self.gi``.
        """
        from collections import Counter

        total = len(dataset)
        if total == 0:
            self.gi = 0
        else:
            counts = Counter(row[-1] for row in dataset)
            self.gi = sum((c / total) * ((total - c) / total)
                          for c in counts.values())
        return self.gi

    def evaluate(self):
        """Return ``(error_rate, gini_impurity)`` for the dataset given at construction."""
        err_rate = self.error_rate(self.ori_dataset)
        gini_impurity = self.gini_impurity(self.ori_dataset)
        return err_rate, gini_impurity
class DecisionTree():
    """Decision tree grown greedily by maximising Gini gain.

    The last column of each row is the label; all other columns are
    attributes. Each split is two-way (see ``divide_dataset``). A child that
    meets a stopping condition is stored as ``None``.
    """

    def __init__(self, ori_dataset, max_depth):
        # Attribute indices already used on the current root-to-node path.
        # construct() pushes before recursing and pops afterwards.
        self.col = []

        self.max_depth = max_depth
        self.dataset = ori_dataset

        self.ld = LoadData()
        self.ins = Inspection(ori_dataset)

    def divide_dataset(self, dataset, col_index):
        """Split ``dataset`` in two on attribute column ``col_index``.

        Returns:
            ``(left, right)`` as numpy arrays. ``left`` holds the rows whose
            value equals the first value returned by
            ``self.ld.get_value(dataset, col_index)``. ``right`` holds every
            other row, so attributes with more than two values are split
            one-vs-rest.
        """
        label = self.ld.get_value(dataset, col_index)
        left = []
        right = []
        for row in dataset:
            if row[col_index] == label[0]:
                left.append(row)
            else:
                right.append(row)
        return np.array(left), np.array(right)

    def gini_impurity(self, dataset, col_index=-1):
        """Return the Gini impurity of ``dataset``, or of a split of it.

        With ``col_index == -1``, returns the label impurity of ``dataset``
        itself. Otherwise returns the size-weighted impurity of the two
        subsets produced by ``divide_dataset(dataset, col_index)``. An empty
        dataset has impurity 0, which also avoids a division by zero.
        """
        if col_index == -1:
            return self.ins.gini_impurity(dataset)
        total = len(dataset)
        if total == 0:
            return 0
        left, right = self.divide_dataset(dataset, col_index)
        gi_left = self.ins.gini_impurity(left)
        gi_right = self.ins.gini_impurity(right)
        return (len(left) / total) * gi_left + (len(right) / total) * gi_right

    def gini_gain(self, dataset, col_index):
        """Return the impurity reduction obtained by splitting on ``col_index``."""
        return self.gini_impurity(dataset) - self.gini_impurity(dataset, col_index)

    def get_attribute(self, dataset, used_col):
        """Return the attribute column, not in ``used_col``, with the largest Gini gain.

        Candidates are scanned in ascending column order, and ``max`` keeps
        the first maximum, so ties deterministically go to the lowest index.

        Raises:
            ValueError: if every attribute column is already in ``used_col``.
        """
        used = set(used_col)
        candidates = [i for i in range(len(dataset[0]) - 1) if i not in used]
        if not candidates:
            raise ValueError('no unused attribute left to split on')
        return max(candidates, key=lambda i: self.gini_gain(dataset, i))

    # Record the path of used attributes (self.col) while recursing.
    def construct(self, dataset, col_index=-1, depth=0):
        """Recursively build the subtree for ``dataset`` and return its root.

        Args:
            dataset: rows reaching this node (last column is the label).
            col_index: not used to choose the split; it is recomputed by
                ``get_attribute``. Kept for backward compatibility.
            depth: depth of the node being built (the root is 0).

        Returns:
            A ``Node`` splitting on the best attribute. Returns ``None`` when
            the depth limit is exceeded, the data is empty, all attributes
            are used, or the node is already pure.
        """
        # NOTE(review): the check is ``depth > max_depth``, so split nodes are
        # created at depths 0..max_depth inclusive (max_depth=0 still gives one
        # split). Confirm that this is the intended meaning of max_depth.
        if depth > self.max_depth:
            print('depth reach max depth')
            return None
        if len(dataset) == 0:
            print('dataset is empty.')
            return None
        if len(dataset[0]) == len(self.col) + 1:
            print('all the attributes have been used.')
            return None
        if self.gini_impurity(dataset) == 0:
            print('no need to do more division!')
            return None

        col_index = self.get_attribute(dataset, self.col)
        node = Node(col_index, dataset, depth=depth)
        self.col.append(col_index)
        try:
            left, right = self.divide_dataset(dataset, col_index)
            node.left = self.construct(left, col_index, depth + 1)
            node.right = self.construct(right, col_index, depth + 1)
        finally:
            # Pop this node's attribute even if recursion raised, so self.col
            # stays consistent for any later construct() call.
            self.col.pop()
        return node

    def traverse(self, node):
        """Print each node's depth and split attribute in pre-order (node, left, right)."""
        if node:
            print(node.depth, '\t', node.attribute)
            self.traverse(node.left)
            self.traverse(node.right)