Пример #1
0
    def search(index: int = 0):
        """Mirror the sklearn subtree rooted at ``index`` onto the bartpy tree.

        Reads the child indices of the sklearn node at ``index``. If the node
        has children, the matching bartpy leaf in ``nodes`` is split, the two
        new leaves take their values from ``sklearn_tree.value``, the grow
        mutation is applied to ``bartpy_tree``, and ``nodes`` is updated
        before recursing into the left and then the right child.
        """
        left_idx = sklearn_tree.children_left[index]
        right_idx = sklearn_tree.children_right[index]

        # A left index of -1 means there is nothing to split here; splits are
        # binary, so checking the left side alone is sufficient.
        if left_idx == -1:
            return

        leaf_to_split: LeafNode = nodes[index]
        decision_node = split_node(
            leaf_to_split,
            map_sklearn_split_into_bartpy_split_conditions(sklearn_tree, index),
        )

        # Copy the sklearn node values onto the freshly created leaves.
        new_leaves = (
            (decision_node.left_child, left_idx),
            (decision_node.right_child, right_idx),
        )
        for leaf, sklearn_idx in new_leaves:
            leaf.set_value(sklearn_tree.value[sklearn_idx][0][0])

        mutate(bartpy_tree, GrowMutation(leaf_to_split, decision_node))

        # Keep the sklearn-index -> bartpy-node lookup current for recursion.
        nodes[index] = decision_node
        nodes[left_idx] = decision_node.left_child
        nodes[right_idx] = decision_node.right_child

        for child_idx in (left_idx, right_idx):
            search(child_idx)
Пример #2
0
 def test_internal_prune(self):
     """Pruning ``self.c`` adds the replacement leaf and drops ``c``, ``d`` and ``e``."""
     replacement_leaf = LeafNode(Split(self.data))
     mutate(self.tree, TreeMutation("prune", self.c, replacement_leaf))

     self.assertIn(replacement_leaf, self.tree.leaf_nodes)
     # None of the pruned nodes may remain in the tree.
     for removed in (self.c, self.d, self.e):
         self.assertNotIn(removed, self.tree.nodes)
Пример #3
0
 def test_head_prune(self):
     """Pruning the root of a locally built three-node tree replaces it with a leaf.

     Builds a fresh tree ``a -> (b, c)``, applies a ``PruneMutation`` that
     swaps ``a`` for ``updated_a``, and checks that the replacement is a
     leaf of the tree while the old root is gone.
     """
     b, c = LeafNode(Split(self.data)), LeafNode(Split(self.data))
     a = DecisionNode(Split(self.data), b, c)
     tree = Tree([a, b, c])
     updated_a = LeafNode(Split(self.data))
     prune_mutation = PruneMutation(a, updated_a)
     mutate(tree, prune_mutation)
     self.assertIn(updated_a, tree.leaf_nodes)
     # Assert on the local root ``a`` that was actually pruned. ``self.a`` was
     # never part of this locally built tree, so checking it could not fail.
     self.assertNotIn(a, tree.nodes)
Пример #4
0
 def test_grow(self):
     """Growing ``self.d`` into a decision node registers the new nodes and drops ``d``."""
     left_leaf = LeafNode(Split(self.data))
     right_leaf = LeafNode(Split(self.data))
     grown_d = DecisionNode(Split(self.data), left_leaf, right_leaf)

     mutate(self.tree, TreeMutation("grow", self.d, grown_d))

     tree = self.tree
     # The new decision node only has leaf children, so it must be prunable.
     self.assertIn(grown_d, tree.decision_nodes)
     self.assertIn(grown_d, tree.prunable_decision_nodes)
     self.assertIn(left_leaf, tree.leaf_nodes)
     self.assertNotIn(self.d, tree.nodes)
Пример #5
0
    def setUp(self):
        """Build the shared fixture: a root split whose right child is split again.

        Resulting attributes: ``a`` is the root with children ``b`` (left) and
        ``x`` (right); ``x`` is grown into ``c``, whose children are ``d``
        (left) and ``e`` (right).
        """
        covariates = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).values
        self.data = Data(covariates, np.array([1, 2, 3]))

        # First level: split a fresh leaf on variable 0 at value 1.
        root_leaf = LeafNode(Split(self.data))
        root_conditions = (SplitCondition(0, 1, le), SplitCondition(0, 1, gt))
        self.a = split_node(root_leaf, root_conditions)
        self.b, self.x = self.a.left_child, self.a.right_child
        self.tree = Tree([self.a, self.b, self.x])

        # Second level: split the root's right child on variable 1 at value 2
        # and register the result in the tree as a grow mutation.
        right_of_root = self.a._right_child
        inner_conditions = (SplitCondition(1, 2, le), SplitCondition(1, 2, gt))
        self.c = split_node(right_of_root, inner_conditions)
        mutate(self.tree, TreeMutation("grow", self.x, self.c))

        self.d, self.e = self.c.left_child, self.c.right_child
Пример #6
0
    def setUp(self):
        """Build the shared fixture: a root split whose right child is split again.

        The covariates go through ``format_covariate_matrix`` and the target
        is cast to float before being wrapped in ``Data``. Resulting
        attributes: ``a`` is the root with children ``b`` (left) and ``x``
        (right); ``x`` is grown into ``c``, whose children are ``d`` (left)
        and ``e`` (right).
        """
        frame = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]})
        X = format_covariate_matrix(frame)
        y = np.array([1, 2, 3]).astype(float)
        self.data = Data(X, y)

        # First level: split a fresh leaf on variable 0 at value 1.
        root_leaf = LeafNode(Split(self.data))
        root_conditions = (SplitCondition(0, 1, le), SplitCondition(0, 1, gt))
        self.a = split_node(root_leaf, root_conditions)
        self.b, self.x = self.a.left_child, self.a.right_child
        self.tree = Tree([self.a, self.b, self.x])

        # Second level: split the root's right child on variable 1 at value 2
        # and register the result in the tree as a grow mutation.
        right_of_root = self.a._right_child
        inner_conditions = (SplitCondition(1, 2, le), SplitCondition(1, 2, gt))
        self.c = split_node(right_of_root, inner_conditions)
        mutate(self.tree, TreeMutation("grow", self.x, self.c))

        self.d, self.e = self.c.left_child, self.c.right_child
Пример #7
0
 def step(self, model: Model, tree: Tree) -> Optional[TreeMutation]:
     """Sample a mutation for ``tree`` and apply it if one was produced.

     Returns the value of ``self.sample(model, tree)``: the mutation that
     was applied to ``tree``, or ``None`` when nothing was sampled.
     """
     proposal = self.sample(model, tree)
     if proposal is None:
         return None
     mutate(tree, proposal)
     return proposal
Пример #8
0
 def step(self, model: Model, tree: Tree) -> Optional[List[TreeMutation]]:
     """Sample mutations for ``tree`` and apply each of them in order.

     Returns the value of ``self.sample(model, tree)``: the mutations that
     were applied to ``tree``, or ``None`` when nothing was sampled.
     """
     proposals = self.sample(model, tree)
     if proposals is None:
         return None
     for proposal in proposals:
         mutate(tree, proposal)
     return proposals
Пример #9
0
 def step(self, model: Model, tree: Tree) -> None:
     """Sample a mutation for ``tree`` and apply it if one was produced."""
     proposal = self.sample(model, tree)
     if proposal is None:
         # Nothing was sampled, so the tree is left untouched.
         return
     mutate(tree, proposal)