Example #1
0
def sample_split_node(node: LeafNode) -> DecisionNode:
    """
    Turn a leaf node into a decision node that has two leaf children
    The conditions to split on are drawn by `sample_split_condition` and applied by `split_node`
    """
    return split_node(node, sample_split_condition(node))
Example #2
0
    def search(index: int = 0):
        """
        Recursively copy the sklearn split at `index`, and every split beneath it, into the bartpy tree
        Does nothing when the sklearn node at `index` has no children
        """
        left_idx = sklearn_tree.children_left[index]
        right_idx = sklearn_tree.children_right[index]

        # Splits are binary, so a missing left child (-1) is enough to identify a leaf
        if left_idx == -1:
            return

        leaf_to_grow: LeafNode = nodes[index]
        conditions = map_sklearn_split_into_bartpy_split_conditions(
            sklearn_tree, index)
        grown = split_node(leaf_to_grow, conditions)

        # Carry the sklearn values over to the freshly created children
        new_left: LeafNode = grown.left_child
        new_right: LeafNode = grown.right_child
        new_left.set_value(sklearn_tree.value[left_idx][0][0])
        new_right.set_value(sklearn_tree.value[right_idx][0][0])

        mutate(bartpy_tree, GrowMutation(leaf_to_grow, grown))

        # Record the new nodes under their sklearn indices so deeper calls can look them up
        nodes[index] = grown
        nodes[left_idx] = grown.left_child
        nodes[right_idx] = grown.right_child

        for child_idx in (left_idx, right_idx):
            search(child_idx)
Example #3
0
    def setUp(self):
        """
        Build the tree fixture shared by the tests

        A root leaf over a three-row dataset is split into `a` (children `b` and `x`),
        then `x` is grown into `c` (children `d` and `e`) through a "grow" mutation.
        """
        covariates = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).values
        target = np.array([1, 2, 3])
        self.data = Data(covariates, target)

        root_leaf = LeafNode(Split(self.data))
        root_conditions = (SplitCondition(0, 1, le), SplitCondition(0, 1, gt))
        self.a = split_node(root_leaf, root_conditions)
        self.b = self.a.left_child
        self.x = self.a.right_child
        self.tree = Tree([self.a, self.b, self.x])

        leaf_to_grow = self.a._right_child
        grow_conditions = (SplitCondition(1, 2, le), SplitCondition(1, 2, gt))
        self.c = split_node(leaf_to_grow, grow_conditions)
        mutate(self.tree, TreeMutation("grow", self.x, self.c))

        self.d = self.c.left_child
        self.e = self.c.right_child
Example #4
0
    def setUp(self):
        """
        Build the tree fixture shared by the tests

        The covariates go through `format_covariate_matrix` and the target is cast to float.
        A root leaf is split into `a` (children `b` and `x`), then `x` is grown into `c`
        (children `d` and `e`) through a "grow" mutation.
        """
        frame = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]})
        X = format_covariate_matrix(frame)
        target = np.array([1, 2, 3]).astype(float)
        self.data = Data(X, target)

        root_leaf = LeafNode(Split(self.data))
        root_conditions = (SplitCondition(0, 1, le), SplitCondition(0, 1, gt))
        self.a = split_node(root_leaf, root_conditions)
        self.b = self.a.left_child
        self.x = self.a.right_child
        self.tree = Tree([self.a, self.b, self.x])

        leaf_to_grow = self.a._right_child
        grow_conditions = (SplitCondition(1, 2, le), SplitCondition(1, 2, gt))
        self.c = split_node(leaf_to_grow, grow_conditions)
        mutate(self.tree, TreeMutation("grow", self.x, self.c))

        self.d = self.c.left_child
        self.e = self.c.right_child
Example #5
0
def sample_split_node(node: LeafNode) -> DecisionNode:
    """
    Turn a leaf node into a decision node that has two leaf children
    A splittable node is split on conditions drawn by `sample_split_condition`;
    otherwise both children are new leaves that reuse the node's own split, one level deeper
    """
    if not node.is_splittable():
        # Nothing to sample: hand back a decision node whose children share the parent's split
        return DecisionNode(
            node.split,
            LeafNode(node.split, depth=node.depth + 1),
            LeafNode(node.split, depth=node.depth + 1),
            depth=node.depth,
        )

    return split_node(node, sample_split_condition(node))
Example #6
0
    def test_split(self):
        """
        Check that `split_node` on the fixture node yields a DecisionNode whose children
        report the expected y sums and observation counts, and that `update_y` on the
        parent changes the sum reported by the left child.
        """
        # Complementary conditions built from the same (0, 3) arguments: `le` for the left, `gt` for the right
        left_split_condition = SplitCondition(0, 3, le)
        right_split_condition = SplitCondition(0, 3, gt)
        updated_node = split_node(
            self.node, [left_split_condition, right_split_condition])
        self.assertIsInstance(updated_node, DecisionNode)

        # The parent's total (15) should equal the left (6) plus the right (9) child totals
        self.assertEqual(updated_node.data.y.summed_y(), 15)
        self.assertEqual(updated_node.left_child.data.y.summed_y(), 6)
        self.assertEqual(updated_node.right_child.data.y.summed_y(), 9)

        # Likewise the 5 observations should be divided 3 / 2 between the children
        self.assertEqual(updated_node.data.X.n_obsv, 5)
        self.assertEqual(updated_node.left_child.data.X.n_obsv, 3)
        self.assertEqual(updated_node.right_child.data.X.n_obsv, 2)

        # Replace y on the parent node, then re-check the sum seen through the left child
        updated_node.update_y([2.0, 4.0, 6.0, 8.0, 10.0])

        # NOTE(review): 12 == 2 + 4 + 6, so presumably the left child holds the first three rows - confirm against the fixture
        self.assertEqual(updated_node.left_child.data.y.summed_y(), 12)