예제 #1
0
def create_tree(data, label, features):
    """Recursively build a decision tree, choosing each split by gain ratio.

    Split selection follows the C4.5 heuristic visible below. Only features
    whose information gain is strictly above the mean gain are candidates.
    Among those, the feature with the largest gain ratio (gain divided by
    the entropy of the feature column) wins. If no feature qualifies,
    ``features[0]`` is used.

    Args:
        data: 2-D array of samples; ``data[:, ft]`` is the column of
            feature index ``ft``.
        label: numpy array of class labels aligned with the rows of
            ``data`` (boolean-mask indexing ``label[mask]`` is used).
        features: column indices of ``data`` still available for splitting.

    Returns:
        A ``Tree``. It is either a ``'LEAF'`` node with a ``category``, or an
        ``'INTERNAL'`` node on ``best_feature`` with one subtree added per
        value of that feature present in ``data``.

    Raises:
        ValueError: if ``label`` is empty.
    """
    if len(label) == 0:
        raise ValueError("cannot build a tree from an empty label array")

    # Count every class in one pass; np.unique returns the classes sorted,
    # together with how often each occurs.
    classes, class_counts = np.unique(label, return_counts=True)

    # 1. All samples belong to the same class: return a leaf.
    if len(classes) == 1:
        category = classes[0]
        subtree = Tree('LEAF', category=category)
        print("[INFO]: create a leaf node, category is %s" % (category))
        return subtree

    # Majority class; ties go to the smallest class value (np.argmax returns
    # the first maximum). It is used for the leaf when no features remain,
    # and it is recorded on internal nodes.
    max_class = classes[np.argmax(class_counts)]

    # 2. No features left to split on: return a majority-class leaf.
    if len(features) == 0:
        subtree = Tree('LEAF', category=max_class)
        print("[INFO]: create a leaf node, category is %s" % (max_class))
        return subtree

    # 3. Score every remaining feature by information gain and gain ratio.
    base_entropy = cal_entropy(label)
    info_gains = []
    gain_ratios = []
    for ft in features:
        attr = data[:, ft]
        info_gain = base_entropy - cal_conditional_entropy(attr, label)
        split_info = cal_entropy(attr)
        info_gains.append(info_gain)
        # A feature that is constant within this subset has zero split
        # information (and presumably zero gain). Score it 0 rather than
        # dividing by zero, which can happen once data has been partitioned.
        gain_ratios.append(info_gain / split_info if split_info > 0 else 0.0)

    # Among features with above-average gain, keep the highest gain ratio.
    mean_gain = np.mean(info_gains)
    best_feature = features[0]  # fallback when no feature beats the mean
    max_ratio = 0
    for ft, gain, ratio in zip(features, info_gains, gain_ratios):
        if gain > mean_gain and ratio > max_ratio:
            max_ratio = ratio
            best_feature = ft

    # 4. Pre-pruning is not implemented. An example would be returning a
    #    majority leaf when the best gain falls below a small threshold.

    # 5. Build the internal node and recurse once per observed value.
    sub_features = [ft for ft in features if ft != best_feature]
    tree = Tree('INTERNAL', feature=best_feature, category=max_class)
    print("[INFO]: create a internal node,feature is %sth" % (best_feature))
    column = data[:, best_feature]
    for value in set(column):
        mask = column == value
        sub_tree = create_tree(data[mask], label[mask], sub_features)
        tree.add_subtree(value, sub_tree)

    return tree
예제 #2
0
def test_add_subtree():
    """Graft a model/eval tree under a prep/split tree and check the op chain.

    The host tree is ``denoise -> ts2db`` and the grafted tree is
    ``rf -> MSE``. After ``add_subtree(["denoise", "ts2db"], ...)``, following
    the first child from the root must visit the ops
    ``denoise, ts2db, rf, MSE`` in that order.
    """
    print("Begin test_add_subtree\n")

    host = Tree()
    host.add_prep_node([], "denoise", None, None, None)
    host.add_split_node(["denoise"], "ts2db")

    graft = Tree()
    graft.add_model_node([], "rf")
    graft.add_eval_node(["rf"], "MSE")

    host.add_subtree(["denoise", "ts2db"], graft)

    # Walk down the first-child chain, checking each op on the way.
    node = host.root
    assert node.op == "denoise"
    for expected_op in ("ts2db", "rf", "MSE"):
        node = node.children[0]
        assert node.op == expected_op
예제 #3
0
파일: test_tree.py 프로젝트: ggila/_conf
class TreeTest(TestCase):
    """Unit tests for ``Tree`` built from the module-level node fixtures.

    ``tree`` is built from ``nodes``, the reference 12-node tree rooted at
    node 1. ``big_tree`` is expected to equal ``tree`` with ``small_tree``
    grafted under node 8.
    """

    # Expected ``weight`` of each node in the reference tree after
    # ``_compute_weight``. The values equal each node's subtree size: leaves
    # are 1, the root is 12 (all nodes), and every parent is 1 plus the sum
    # of its children.
    EXPECTED_WEIGHTS = {
        1: 12, 2: 5, 3: 3, 4: 3,
        5: 3, 6: 1, 7: 2, 8: 2,
        9: 1, 10: 1, 11: 1, 12: 1,
    }

    def setUp(self):
        # Deep copies keep each test's mutations away from the shared
        # module-level fixtures and from the other trees.
        self.nodes = copy.deepcopy(nodes)
        self.tree = Tree(nodes=copy.deepcopy(nodes))

        self.small_nodes = copy.deepcopy(small_nodes)
        self.small_tree = Tree(nodes=copy.deepcopy(small_nodes))

        self.big_nodes = copy.deepcopy(big_nodes)
        self.big_tree = Tree(nodes=copy.deepcopy(big_nodes))

    def test_tree_init(self):
        tree = self.tree
        self.assertEqual(tree.root.id, 1)
        # The root's children are 2, 3 and 4. The BFS tests put exactly
        # {2, 3, 4} at depth 1, and root weight 12 == 1 + 5 + 3 + 3.
        self.assertEqual(tree.root.children, {2, 3, 4})
        self.assertEqual(set(tree.nodes), set(range(1, len(self.nodes) + 1)))

    def test_bfs_seq(self):
        # Top-down BFS from the root: one level at a time.
        seq = self.tree._bfs_seq(1)
        self.assertEqual(seq[0], 1)
        self.assertEqual(set(seq[1:4]), {2, 3, 4})
        self.assertEqual(set(seq[4:8]), {5, 6, 7, 8})
        self.assertEqual(set(seq[8:]), {9, 10, 11, 12})

    def test_bfs_seq_bottomtotop(self):
        # Bottom-up order: deepest level first, root last.
        seq = self.tree._bfs_seq(1, toptobottom=False)
        self.assertEqual(set(seq[:4]), {9, 10, 11, 12})
        self.assertEqual(set(seq[4:8]), {5, 6, 7, 8})
        self.assertEqual(set(seq[8:-1]), {2, 3, 4})
        self.assertEqual(seq[-1], 1)

    def test_compute_weight(self):
        self.tree._compute_weight()
        self._check_weight()
        # Scramble every stored weight. Recomputing must overwrite all of
        # them, so the result must not depend on the previous values.
        for _, node in self.tree.items():
            node.weight = randint(1, 1000)
        self.tree._compute_weight()
        self._check_weight()

    def _check_weight(self):
        """Assert every node's weight matches ``EXPECTED_WEIGHTS``."""
        for node_id, expected in self.EXPECTED_WEIGHTS.items():
            with self.subTest(node=node_id):
                self.assertEqual(self.tree[node_id].weight, expected)

    def test_extract_subtree(self):
        # Extracting from the root must reproduce the whole tree.
        tree = self.tree
        cp_tree = self.tree.extract_subtree(tree.root.id)
        self.assertEqual(tree, cp_tree)

    def test_add_subtree(self):
        # Grafting small_tree under node 8 must yield big_tree, node by node.
        self.tree.add_subtree(self.small_tree, 8)
        for id_, node in self.tree.nodes.items():
            self.assertEqual(node, self.big_tree[id_])
        self.assertEqual(self.tree, self.big_tree)

    def test_del_subtree(self):
        # Removing the grafted subtree from big_tree must give back tree.
        self.big_tree.del_subtree(self.small_tree.root.id)
        self.assertEqual(self.big_tree, self.tree)