Example #1
0
    def test_pruning_non_leaf_parent(self):
        """PruneMutation must raise TypeError for a decision node with a non-leaf child."""
        leaf_a = LeafNode(Split(self.data))
        leaf_b = LeafNode(Split(self.data))
        leaf_c = LeafNode(Split(self.data))
        inner = DecisionNode(Split(self.data), leaf_a, leaf_b)
        # The root's children are one leaf and one decision node.
        root = DecisionNode(Split(self.data), leaf_c, inner)

        self.assertRaises(TypeError, PruneMutation, root, leaf_a)
Example #2
0
 def test_head_prune(self):
     """Pruning the head decision node swaps it for the replacement leaf.

     Builds a tree of one decision node over two leaves, applies a
     PruneMutation that replaces the decision node with a fresh leaf, and
     checks that the new leaf is among ``tree.leaf_nodes`` while the pruned
     decision node is no longer among ``tree.nodes``.
     """
     b, c = LeafNode(Split(self.data)), LeafNode(Split(self.data))
     a = DecisionNode(Split(self.data), b, c)
     tree = Tree([a, b, c])
     updated_a = LeafNode(Split(self.data))
     prune_mutation = PruneMutation(a, updated_a)
     mutate(tree, prune_mutation)
     self.assertIn(updated_a, tree.leaf_nodes)
     # Assert on the local head that was actually pruned. A fixture attribute
     # such as self.a was never added to this locally built tree, so checking
     # it would pass regardless of what mutate() did.
     self.assertNotIn(a, tree.nodes)
Example #3
0
 def test_pruning_leaf(self):
     """PruneMutation must raise TypeError when given a leaf as the node to prune."""
     def prune_a_leaf():
         # Nodes are built inside the callable so their construction stays
         # within the scope checked by assertRaises.
         existing = LeafNode(Split(self.data))
         replacement = LeafNode(Split(self.data))
         return PruneMutation(existing, replacement)

     self.assertRaises(TypeError, prune_a_leaf)
Example #4
0
def uniformly_sample_prune_mutation(tree: Tree) -> TreeMutation:
    """Build a prune mutation for a randomly chosen prunable node of ``tree``.

    The node is picked by ``random_prunable_decision_node``; its replacement
    is a new ``LeafNode`` that reuses the picked node's split and depth.
    """
    target = random_prunable_decision_node(tree)
    return PruneMutation(target, LeafNode(target.split, depth=target.depth))
Example #5
0
def prune_mutations(tree: Tree) -> List[TreeMutation]:
    """Return one prune mutation per entry of ``tree.prunable_decision_nodes``.

    Each mutation pairs a prunable decision node with a new ``LeafNode``
    that carries over the node's split and depth.
    """
    mutations = []
    for decision_node in tree.prunable_decision_nodes:
        collapsed = LeafNode(decision_node.split, depth=decision_node.depth)
        mutations.append(PruneMutation(decision_node, collapsed))
    return mutations
Example #6
0
 def test_invalid_prune(self):
     """Building a PruneMutation from ``self.a`` and a fresh leaf must raise TypeError."""
     # NOTE(review): self.a is set up outside this view; presumably it is not
     # a prunable decision node — confirm against the fixture.
     def build_invalid_mutation():
         replacement = LeafNode(Split(self.data))
         return PruneMutation(self.a, replacement)

     self.assertRaises(TypeError, build_invalid_mutation)