def test_pruning_non_leaf_parent(self):
    """Pruning a decision node that has a decision-node child raises TypeError.

    ``outer`` is built with children ``leaf_c`` and ``inner``; ``inner`` in
    turn has the two leaves ``leaf_a`` and ``leaf_b``. Because one of
    ``outer``'s children is itself a decision node, PruneMutation must
    refuse to prune it.
    """
    # Three independent leaves, created in the same order as before.
    leaf_a, leaf_b, leaf_c = (LeafNode(Split(self.data)) for _ in range(3))
    inner = DecisionNode(Split(self.data), leaf_a, leaf_b)
    outer = DecisionNode(Split(self.data), leaf_c, inner)

    with self.assertRaises(TypeError):
        PruneMutation(outer, leaf_a)
def test_head_prune(self):
    """Pruning the root decision node swaps it for the supplied leaf.

    Builds a tree from root ``a`` (children ``b`` and ``c``), applies a
    PruneMutation that replaces ``a`` with a fresh leaf, and checks that the
    fresh leaf is now among the tree's leaves while ``a`` itself is gone.
    """
    b, c = LeafNode(Split(self.data)), LeafNode(Split(self.data))
    a = DecisionNode(Split(self.data), b, c)
    tree = Tree([a, b, c])

    updated_a = LeafNode(Split(self.data))
    prune_mutation = PruneMutation(a, updated_a)
    mutate(tree, prune_mutation)

    self.assertIn(updated_a, tree.leaf_nodes)
    # Check the node that was actually pruned. Asserting on ``self.a`` (a node
    # this tree was never built from) would pass regardless of what mutate did.
    self.assertNotIn(a, tree.nodes)
def test_pruning_leaf(self):
    """PruneMutation raises TypeError when the node to prune is a leaf.

    Both leaves are built before entering ``assertRaises`` so that only the
    PruneMutation call can satisfy the expected TypeError; a failure while
    constructing the fixtures is reported as an error instead of silently
    passing the test.
    """
    leaf_to_prune = LeafNode(Split(self.data))
    replacement = LeafNode(Split(self.data))
    with self.assertRaises(TypeError):
        PruneMutation(leaf_to_prune, replacement)
def uniformly_sample_prune_mutation(tree: Tree) -> TreeMutation:
    """Return a PruneMutation for a node picked by random_prunable_decision_node.

    The picked decision node is replaced by a brand-new leaf that keeps the
    node's split and depth.
    """
    target = random_prunable_decision_node(tree)
    return PruneMutation(target, LeafNode(target.split, depth=target.depth))
def prune_mutations(tree: Tree) -> List[TreeMutation]:
    """List one PruneMutation per entry of ``tree.prunable_decision_nodes``.

    Each mutation replaces the decision node with a new leaf carrying the
    same split and depth; the result follows the iteration order of
    ``tree.prunable_decision_nodes``.
    """
    def collapsed(node):
        # Leaf that takes the place of ``node`` once it is pruned.
        return LeafNode(node.split, depth=node.depth)

    return [
        PruneMutation(candidate, collapsed(candidate))
        for candidate in tree.prunable_decision_nodes
    ]
def test_invalid_prune(self):
    """PruneMutation raises TypeError when asked to prune ``self.a``.

    NOTE(review): ``self.a`` is presumably a fixture node that is not
    prunable (set up elsewhere in the test class) — confirm against setUp.

    The replacement leaf is built before entering ``assertRaises`` so that a
    TypeError during its construction cannot make the test pass by accident.
    """
    updated_a = LeafNode(Split(self.data))
    with self.assertRaises(TypeError):
        PruneMutation(self.a, updated_a)