コード例 #1
0
def test_mutate_any_node():
    """Check that every node except the root can be mutated.

    For each pipeline in ``data.pipelines`` the tree is built with
    ``pt.to_tree``, annotated, flattened, and every node is passed to
    ``sampler.sample_rules(n, return_proba=True)``. Expectations:

    * exactly one ``(rule, prob)`` pair is returned per node;
    * the root node gets ``(None, 0.0)``;
    * ``*Pipeline`` composition ops only get ``ComponentInsert``;
      ``*FeatureUnion`` composition ops get ``ComponentInsert`` or
      ``ComponentRemove``;
    * hyperparameter nodes whose (parent label, param) pair is not in
      ``sampler.config`` get either ``HyperparamRemove`` (prob > 0) or
      ``None`` (prob == 0); ``score_func`` params are never mutated;
    * a ``ComponentInsert`` on a component node implies it is a
      composition node;
    * every other node gets a non-``None`` rule with positive probability.

    Finally every type in ``RULE_TYPES`` must have been sampled at least
    once. Rules are sampled randomly, so this last check can fail by chance.
    """
    sampler = mutate.get_random_mutation_sampler()
    rule_counter = Counter()
    for p in data.pipelines:
        tree = pt.to_tree(p)
        tree.annotate()
        flat_nodes = flatten(tree)

        for n in flat_nodes:
            rules = sampler.sample_rules(n, return_proba=True)
            # always has one rule per node
            assert len(rules) == 1, "expected one rule for {!r}, got {!r}".format(
                n, rules)
            rule, prob = rules[0]
            rule_counter[type(rule)] += 1

            if pt.is_root_node(n):
                # the root itself is never mutated
                assert rule is None
                assert prob == 0.0
                continue

            if pt.is_composition_op(n):
                if n.label.endswith("Pipeline"):
                    # pipelines: insertion only
                    assert isinstance(rule, ComponentInsert)
                if n.label.endswith("FeatureUnion"):
                    # feature unions: insertion or removal
                    assert isinstance(rule, (ComponentInsert, ComponentRemove))

            if pt.is_param_node(n):
                assert n.parent is not None
                parent_label = n.parent.label
                param = n.payload[0]
                configured = (parent_label in sampler.config
                              and param in sampler.config[parent_label])
                if not configured:
                    # hyperparameters without a configured search space can
                    # only be removed (or left untouched)
                    assert isinstance(rule, HyperparamRemove) or rule is None
                    if rule is None:
                        assert prob == 0.0
                    else:
                        assert prob > 0.0
                    continue

                if param == "score_func":
                    assert rule is None
                    assert prob == 0.0
                    continue

            if pt.is_component_node(n) and isinstance(rule, ComponentInsert):
                assert pt.is_composition_node(n)

            # all others should have non-identity rule
            assert rule is not None, "no mutation sampled for {!r}".format(n)
            assert prob > 0.0, "non-positive probability for {!r}".format(n)

    # should have all rule types (bad luck if don't....)
    for rule_type in RULE_TYPES:
        # Counter returns 0 for missing keys without inserting them
        assert rule_counter[rule_type] > 0, "rule type never sampled: {!r}".format(
            rule_type)
コード例 #2
0
def must_delete_subtree(tree, new_children):
    """
    Decide whether ``tree`` should be dropped after its children were rewritten.

    Some subtrees become no-ops (or invalid) once they have no children. A
    subtree qualifies when it is the root node, or a component node whose
    label is one of the combinators listed below. It is deleted only when it
    qualifies AND every entry in ``new_children`` is ``None``.

    If ``tree`` does not qualify, the (falsy) eligibility value is returned
    unchanged; otherwise the result of the ``all(...)`` check is returned.
    """
    combinators_needing_children = (
        "sklearn.pipeline.Pipeline",
        "tpot.builtins.stacking_estimator.StackingEstimator",
    )
    eligible = pt.is_root_node(tree) or (
        pt.is_component_node(tree)
        and tree.label in combinators_needing_children)
    if not eligible:
        return eligible
    # every child was removed -> nothing left for this subtree to do
    return all(child is None for child in new_children)
コード例 #3
0
    def sample_rules(self, n, return_proba=False, random_state=None):
        """Sample one random mutation rule for node ``n``.

        Returns a one-element list ``[(rule, prob)]``:

        * root nodes always yield ``(None, 0.0)``;
        * otherwise ``rule`` is ``self.get_random_mutation(n)``. If that call
          raises an ``Exception`` or returns ``None``, the rule is ``None``
          with probability ``0.0``, which deprioritizes identity ops;
        * any other rule gets ``prob = self.random_state.uniform()``.

        NOTE(review): ``return_proba`` and ``random_state`` are currently
        unused. The ``(rule, prob)`` pair is always returned, and
        ``self.random_state`` is always used. Both are kept for interface
        compatibility.
        """
        if pt.is_root_node(n):
            return [(None, 0.0)]

        try:
            random_mutation = self.get_random_mutation(n)
        except Exception:
            # Best-effort: a node with no applicable mutation is treated as
            # identity. Exception (not a bare except) lets KeyboardInterrupt
            # and SystemExit propagate.
            random_mutation = None

        if random_mutation is None:
            # deprioritize identity ops
            random_prob = 0.0
        else:
            random_prob = self.random_state.uniform()

        return [(random_mutation, random_prob)]
コード例 #4
0
def replace_node(tree, orig_node, new_node):
    """
    Replace a node anywhere in the tree, and share references
    to unchanged portions of the tree.

    Returns ``(result, changed)``. ``changed`` is True when a node comparing
    equal (``==``) to ``orig_node`` was found. Only the path from the root
    to that node is rebuilt (via ``replace_children``). Subtrees off that
    path are reused as-is. If nothing matched, ``(tree, False)`` is returned
    with ``tree`` itself.
    """
    if tree == orig_node:
        if not pt.is_root_node(tree):
            return new_node, True
        # Replacing the root: copy the entire tree and swap its only child.
        root_copy = pt.deep_copy(tree)
        # root has a single child by pt.to_tree construction
        assert len(root_copy.children) == 1
        root_copy.replace_child(0, new_node)
        return root_copy, True

    kids = tree.children
    for pos, kid in enumerate(kids):
        rebuilt, hit = replace_node(kid, orig_node, new_node)
        if hit:
            # Siblings before and after the hit are unchanged, so reuse them.
            updated = list(kids[:pos]) + [rebuilt] + list(kids[(pos + 1):])
            return replace_children(tree, updated), True

    # not on the path to our target replacement node,
    # so can share all references
    return tree, False
コード例 #5
0
 def get_probability(self, rule=None, node=None):
     """Return a sampling probability for applying ``rule`` at ``node``.

     Root nodes get ``0.0``. Any other node gets a fresh draw from
     ``self.random_state.uniform()``. ``rule`` is not consulted.
     NOTE(review): with the default ``node=None`` the result depends on how
     ``pt.is_root_node`` treats ``None`` — confirm callers always pass a node.
     """
     if not pt.is_root_node(node):
         return self.random_state.uniform()
     return 0.0