def setUp(self):
    """Build a shared five-node fixture tree over a one-row dataset.

    Layout: root decision node ``a`` has children ``b`` (leaf) and ``c``
    (decision node); ``c`` has children ``d`` and ``e`` (both leaves).
    All five nodes are registered in ``self.tree``.
    """
    frame = pd.DataFrame({"a": [1]})
    self.data = Data(frame.values, np.array([1]))

    # Build bottom-up so every DecisionNode receives already-constructed children.
    self.d = LeafNode(Split(self.data), None)
    self.e = LeafNode(Split(self.data), None)
    self.c = DecisionNode(Split(self.data), self.d, self.e)
    self.b = LeafNode(Split(self.data))
    self.a = DecisionNode(Split(self.data), self.b, self.c)

    all_nodes = [self.a, self.b, self.c, self.d, self.e]
    self.tree = Tree(all_nodes)
def test_pruning_non_leaf_parent(self):
    """PruneMutation rejects a decision node that has a decision-node child."""
    first_leaf, second_leaf, third_leaf = (LeafNode(Split(self.data)) for _ in range(3))
    inner = DecisionNode(Split(self.data), first_leaf, second_leaf)
    root = DecisionNode(Split(self.data), third_leaf, inner)

    # ``root`` has ``inner`` (a DecisionNode) as a child, so it presumably
    # fails ``is_prunable()`` and the constructor should raise TypeError.
    with self.assertRaises(TypeError):
        PruneMutation(root, first_leaf)
def test_growing_decision_node(self):
    """GrowMutation refuses an existing node that is already a DecisionNode."""
    leaf_x, leaf_y, leaf_z = (LeafNode(Split(self.data)) for _ in range(3))
    inner = DecisionNode(Split(self.data), leaf_x, leaf_y)
    # Not used directly; gives ``inner`` an enclosing decision node.
    _outer = DecisionNode(Split(self.data), leaf_z, inner)

    # Growing is expected to be valid only on leaves; ``inner`` is a DecisionNode.
    with self.assertRaises(TypeError):
        GrowMutation(inner, leaf_x)
def setUp(self):
    """Build a shared five-node fixture tree over a one-row, float-target dataset.

    Covariates go through ``format_covariate_matrix`` and the target is cast
    to float. Layout: root ``a`` -> children ``b`` (leaf) and ``c`` (decision
    node); ``c`` -> children ``d`` and ``e`` (leaves).
    """
    covariates = format_covariate_matrix(pd.DataFrame({"a": [1]}))
    target = np.array([1]).astype(float)
    self.data = Data(covariates, target)

    def fresh_split():
        # Each node gets its own Split object over the same data.
        return Split(self.data)

    self.d = LeafNode(fresh_split(), None)
    self.e = LeafNode(fresh_split(), None)
    self.c = DecisionNode(fresh_split(), self.d, self.e)
    self.b = LeafNode(fresh_split())
    self.a = DecisionNode(fresh_split(), self.b, self.c)
    self.tree = Tree([self.a, self.b, self.c, self.d, self.e])
def test_head_prune(self):
    """Pruning the root of a three-node tree swaps in the replacement leaf.

    Builds a local tree (decision root ``a`` with leaves ``b`` and ``c``),
    prunes ``a`` into ``updated_a``, and checks that the new leaf is present
    and the pruned root is gone.
    """
    b, c = LeafNode(Split(self.data)), LeafNode(Split(self.data))
    a = DecisionNode(Split(self.data), b, c)
    tree = Tree([a, b, c])
    updated_a = LeafNode(Split(self.data))
    prune_mutation = PruneMutation(a, updated_a)
    mutate(tree, prune_mutation)
    self.assertIn(updated_a, tree.leaf_nodes)
    # Assert on the *local* root ``a``. ``self.a`` is the setUp fixture and was
    # never part of this tree, so checking it could never fail.
    self.assertNotIn(a, tree.nodes)
def test_grow(self):
    """Growing leaf ``self.d`` replaces it with a decision node and two new leaves."""
    child_one = LeafNode(Split(self.data))
    child_two = LeafNode(Split(self.data))
    grown = DecisionNode(Split(self.data), child_one, child_two)

    mutate(self.tree, TreeMutation("grow", self.d, grown))

    tree = self.tree
    self.assertIn(grown, tree.decision_nodes)
    # Both children of ``grown`` are leaves, so it should be prunable.
    self.assertIn(grown, tree.prunable_decision_nodes)
    self.assertIn(child_one, tree.leaf_nodes)
    self.assertNotIn(self.d, tree.nodes)
def sample_split_node(node: LeafNode) -> DecisionNode:
    """Turn a leaf node into a decision node with two leaf children.

    When ``node.is_splittable()`` is true, a split condition is drawn with
    ``sample_split_condition`` and applied via ``split_node``. Otherwise a
    fallback decision node is returned: its two children are fresh leaves
    that reuse ``node.split`` unchanged and sit one level deeper than
    ``node``, while the decision node keeps ``node``'s depth.
    """
    if not node.is_splittable():
        # Not splittable: both children reuse the parent's split as-is.
        child_depth = node.depth + 1
        return DecisionNode(
            node.split,
            LeafNode(node.split, depth=child_depth),
            LeafNode(node.split, depth=child_depth),
            depth=node.depth,
        )
    return split_node(node, sample_split_condition(node))
def __init__(self, existing_node: DecisionNode, updated_node: LeafNode):
    """Create a "prune" mutation replacing ``existing_node`` with ``updated_node``.

    Raises:
        TypeError: if ``existing_node``'s type is not exactly ``DecisionNode``,
            or if its ``is_prunable()`` returns a falsy value.
    """
    # Exact type match (not isinstance): subclasses of DecisionNode are rejected.
    # is_prunable() is only consulted once the type check has passed.
    can_prune = type(existing_node) is DecisionNode and existing_node.is_prunable()
    if not can_prune:
        raise TypeError("Pruning only valid on prunable decision nodes")
    super().__init__("prune", existing_node, updated_node)