def sample_split_node(node: LeafNode) -> DecisionNode:
    """Convert the leaf ``node`` into a DecisionNode with two leaf children.

    The splitting variable and splitting value are sampled from their
    respective distributions by ``sample_split_condition``. The resulting
    conditions are then applied to ``node`` through ``split_node``.
    """
    return split_node(node, sample_split_condition(node))
def search(index: int = 0) -> None:
    """Replay the fitted sklearn tree, starting at node ``index``, onto ``bartpy_tree``.

    For every internal sklearn node reached, the matching LeafNode stored in
    ``nodes`` is split with the conditions from
    ``map_sklearn_split_into_bartpy_split_conditions``. The two new leaves get
    their values from ``sklearn_tree.value``. The split is applied to
    ``bartpy_tree`` as a ``GrowMutation``, and ``nodes`` is updated so that the
    parent and both child indices point at the new bartpy nodes.

    Nodes are visited in pre-order: the node, then its whole left subtree,
    then its right subtree. This is the same order as a recursive walk. An
    explicit stack is used instead of recursion, so very deep sklearn trees
    (e.g. fitted with ``max_depth=None``) cannot hit Python's recursion limit.

    Args:
        index: sklearn node index to start from. Defaults to 0.

    Uses from the enclosing scope: ``sklearn_tree`` (exposes
    ``children_left``, ``children_right`` and ``value``), ``nodes`` (indexed by
    sklearn node index; must already hold a LeafNode at ``index``) and
    ``bartpy_tree`` (the tree being mutated).
    """
    pending = [index]
    while pending:
        current = pending.pop()
        left_child_index = sklearn_tree.children_left[current]
        right_child_index = sklearn_tree.children_right[current]
        if left_child_index == -1:
            # Trees are binary splits, so a missing left child means a leaf.
            continue

        searched_node: LeafNode = nodes[current]
        split_conditions = map_sklearn_split_into_bartpy_split_conditions(
            sklearn_tree, current)
        decision_node = split_node(searched_node, split_conditions)

        left_child: LeafNode = decision_node.left_child
        right_child: LeafNode = decision_node.right_child
        left_child.set_value(sklearn_tree.value[left_child_index][0][0])
        right_child.set_value(sklearn_tree.value[right_child_index][0][0])

        mutate(bartpy_tree, GrowMutation(searched_node, decision_node))

        # Children must be registered before they are popped and processed.
        nodes[current] = decision_node
        nodes[left_child_index] = decision_node.left_child
        nodes[right_child_index] = decision_node.right_child

        # Push right first so the entire left subtree is handled before the
        # right one. This matches the original recursive pre-order.
        pending.append(right_child_index)
        pending.append(left_child_index)
def setUp(self):
    """Build the shared fixture tree.

    ``self.data`` is built from a two-column DataFrame's ``.values`` with an
    integer target. Root ``a`` is a leaf split with
    ``SplitCondition(0, 1, le/gt)``; its children are ``b`` (left) and ``x``
    (right). ``self.tree`` is made from those three nodes. The right child is
    then split with ``SplitCondition(1, 2, le/gt)`` into ``c``, grafted into
    the tree through a "grow" ``TreeMutation``, and ``c``'s children are kept
    as ``d`` and ``e``.
    """
    covariates = pd.DataFrame({
        "a": [1, 2, 3],
        "b": [1, 2, 3],
    }).values
    self.data = Data(covariates, np.array([1, 2, 3]))

    root_leaf = LeafNode(Split(self.data))
    root_conditions = (SplitCondition(0, 1, le), SplitCondition(0, 1, gt))
    self.a = split_node(root_leaf, root_conditions)
    self.b = self.a.left_child
    self.x = self.a.right_child
    self.tree = Tree([self.a, self.b, self.x])

    grow_target = self.a._right_child
    grow_conditions = (SplitCondition(1, 2, le), SplitCondition(1, 2, gt))
    self.c = split_node(grow_target, grow_conditions)
    mutate(self.tree, TreeMutation("grow", self.x, self.c))
    self.d = self.c.left_child
    self.e = self.c.right_child
def setUp(self):
    """Build the shared fixture tree from a formatted covariate matrix.

    ``self.data`` holds the output of ``format_covariate_matrix`` for a
    two-column DataFrame, together with a float target. Root ``a`` is a leaf
    split with ``SplitCondition(0, 1, le/gt)`` into ``b`` (left) and ``x``
    (right), and ``self.tree`` is made from those three nodes. The right
    child is then split with ``SplitCondition(1, 2, le/gt)`` into ``c``,
    grafted via a "grow" ``TreeMutation``, and ``c``'s children are stored as
    ``d`` and ``e``.
    """
    frame = pd.DataFrame({
        "a": [1, 2, 3],
        "b": [1, 2, 3],
    })
    covariates = format_covariate_matrix(frame)
    target = np.array([1, 2, 3]).astype(float)
    self.data = Data(covariates, target)

    root_leaf = LeafNode(Split(self.data))
    root_conditions = (SplitCondition(0, 1, le), SplitCondition(0, 1, gt))
    self.a = split_node(root_leaf, root_conditions)
    self.b = self.a.left_child
    self.x = self.a.right_child
    self.tree = Tree([self.a, self.b, self.x])

    grow_target = self.a._right_child
    grow_conditions = (SplitCondition(1, 2, le), SplitCondition(1, 2, gt))
    self.c = split_node(grow_target, grow_conditions)
    mutate(self.tree, TreeMutation("grow", self.x, self.c))
    self.d = self.c.left_child
    self.e = self.c.right_child
def sample_split_node(node: LeafNode) -> DecisionNode:
    """Turn the leaf ``node`` into a DecisionNode with two leaf children.

    If ``node.is_splittable()`` is false, nothing is sampled. The result is a
    DecisionNode at ``node.depth`` that reuses ``node.split`` for itself and
    for two separate leaf children at ``node.depth + 1``.

    Otherwise the splitting variable and value are sampled from their
    respective distributions by ``sample_split_condition`` and applied with
    ``split_node``.
    """
    if not node.is_splittable():
        child_depth = node.depth + 1
        left_leaf = LeafNode(node.split, depth=child_depth)
        right_leaf = LeafNode(node.split, depth=child_depth)
        return DecisionNode(node.split, left_leaf, right_leaf, depth=node.depth)

    return split_node(node, sample_split_condition(node))
def test_split(self):
    """Splitting ``self.node`` with ``SplitCondition(0, 3, le/gt)`` partitions
    y and X between the children, and ``update_y`` on the parent is reflected
    in the left child's targets."""
    conditions = [SplitCondition(0, 3, le), SplitCondition(0, 3, gt)]
    parent = split_node(self.node, conditions)
    self.assertIsInstance(parent, DecisionNode)

    expected_target_sums = (
        (parent, 15),
        (parent.left_child, 6),
        (parent.right_child, 9),
    )
    for node, expected_sum in expected_target_sums:
        self.assertEqual(node.data.y.summed_y(), expected_sum)

    expected_counts = (
        (parent, 5),
        (parent.left_child, 3),
        (parent.right_child, 2),
    )
    for node, expected_count in expected_counts:
        self.assertEqual(node.data.X.n_obsv, expected_count)

    parent.update_y([2.0, 4.0, 6.0, 8.0, 10.0])
    self.assertEqual(parent.left_child.data.y.summed_y(), 12)