def test_mutate_any_node():
    """Check the random mutation sampler against every node of every pipeline.

    Each pipeline in ``data.pipelines`` goes through three steps before
    sampling:
    - conversion with ``pt.to_tree``;
    - annotation;
    - flattening.

    Every resulting node is then passed to
    ``sampler.sample_rules(n, return_proba=True)``. The per-node assertions
    below encode which rule (and probability) each kind of node may receive.
    Finally, every type in ``RULE_TYPES`` must have been sampled at least
    once across all nodes.
    """
    # should be able to mutate every node (except 'root')
    sampler = mutate.get_random_mutation_sampler()
    # counts how often each rule type (type(rule)) was sampled
    rule_counter = Counter()
    for p in data.pipelines:
        t = pt.to_tree(p)
        t.annotate()
        flat_nodes = flatten(t)
        for n in flat_nodes:
            rules = sampler.sample_rules(n, return_proba=True)
            # always has one rule per node
            assert len(rules) == 1
            rule, prob = rules[0]
            rule_counter[type(rule)] += 1
            # the root never gets a mutation: identity rule with zero probability
            if pt.is_root_node(n):
                assert rule is None
                assert prob == 0.0
                continue
            if pt.is_composition_op(n):
                # can only insert for these
                if n.label.endswith("Pipeline"):
                    assert isinstance(rule, ComponentInsert)
                if n.label.endswith("FeatureUnion"):
                    assert isinstance(rule, (ComponentInsert, ComponentRemove))
            if pt.is_param_node(n):
                assert n.parent is not None
                parent_label = n.parent.label
                param = n.payload[0]
                # hyperparameters not listed in the sampler's config for their
                # parent can only be removed, or left unchanged (rule None)
                if parent_label not in sampler.config or param not in sampler.config[
                        parent_label]:
                    assert isinstance(rule, HyperparamRemove) or rule is None
                    if rule is None:
                        assert prob == 0.0
                    else:
                        assert prob > 0.0
                    continue
                # a configured "score_func" param is expected to get no mutation
                if param == "score_func":
                    assert rule is None
                    assert prob == 0.0
                    continue
            if pt.is_component_node(n):
                if isinstance(rule, ComponentInsert):
                    assert pt.is_composition_node(n)
            # all others should have non-identity rule
            assert rule is not None
            assert prob > 0.0
    # should have all rule types (bad luck if don't....)
    # NOTE(review): `t` here rebinds the tree variable from the loop above;
    # harmless because the tree is no longer used at this point.
    for t in RULE_TYPES:
        assert rule_counter.get(t, 0.0) > 0
def must_delete_subtree(tree, new_children):
    """Decide whether ``tree`` should be dropped after its children were rewritten.

    Some subtrees become no-ops (or invalid) once they are left without any
    children. This returns a truthy value when both of the following hold:

    * ``tree`` is a root node, or a component node whose label is one of
      the combinators listed below;
    * every entry of ``new_children`` is ``None`` (i.e. all were removed).
    """
    # combinators that make no sense without at least one child
    collapse_when_empty = (
        "sklearn.pipeline.Pipeline",
        "tpot.builtins.stacking_estimator.StackingEstimator",
    )
    is_candidate = pt.is_root_node(tree) or (
        pt.is_component_node(tree) and tree.label in collapse_when_empty
    )
    if not is_candidate:
        return is_candidate
    every_child_removed = all(child is None for child in new_children)
    return every_child_removed
def sample_rules(self, n, return_proba=False, random_state=None):
    """Sample a single mutation rule for node ``n``.

    Returns a one-element list ``[(rule, prob)]``:

    * the root node always maps to ``(None, 0.0)``;
    * otherwise ``rule`` is ``self.get_random_mutation(n)``. If that call
      returns ``None`` or raises an ``Exception``, the rule is ``None`` and
      its probability is ``0.0``, so identity ops are deprioritized;
    * a non-``None`` rule gets a probability drawn from
      ``self.random_state.uniform()``.

    NOTE(review): ``return_proba`` and ``random_state`` are accepted but not
    used here. The ``(rule, prob)`` pair is always returned, and
    ``self.random_state`` is always the RNG that is consulted.
    """
    if pt.is_root_node(n):
        return [(None, 0.0)]
    try:
        random_mutation = self.get_random_mutation(n)
    except Exception:
        # Best effort: a node that cannot be mutated yields the identity
        # (None) rule. Catch Exception rather than using a bare except so
        # KeyboardInterrupt / SystemExit still propagate.
        random_mutation = None
    if random_mutation is None:
        # deprioritize identity ops
        random_prob = 0.0
    else:
        random_prob = self.random_state.uniform()
    return [(random_mutation, random_prob)]
def replace_node(tree, orig_node, new_node):
    """Swap ``orig_node`` for ``new_node`` somewhere inside ``tree``.

    Returns a ``(result, changed)`` pair. The search is depth-first, left to
    right, and stops at the first match:

    * only the nodes on the path from ``tree`` down to that match are
      rebuilt, via ``replace_children``;
    * every other subtree is shared by reference with the input;
    * when ``orig_node`` does not occur, the input ``tree`` itself is
      returned with ``changed`` set to ``False``.

    If the matched node is a root node, it is deep-copied with
    ``pt.deep_copy``, and ``new_node`` replaces the copy's single child.
    """
    is_target = tree == orig_node
    if is_target and pt.is_root_node(tree):
        # copy the root wholesale; roots built by pt.to_tree have one child
        root_copy = pt.deep_copy(tree)
        assert len(root_copy.children) == 1
        root_copy.replace_child(0, new_node)
        return root_copy, True
    if is_target:
        return new_node, True

    kids = tree.children
    for pos, kid in enumerate(kids):
        swapped, hit = replace_node(kid, orig_node, new_node)
        if not hit:
            continue
        # siblings before and after the hit are reused untouched
        rebuilt = list(kids[:pos])
        rebuilt.append(swapped)
        rebuilt.extend(kids[pos + 1:])
        return replace_children(tree, rebuilt), True
    # orig_node is not inside this subtree, so share it as-is
    return tree, False
def get_probability(self, rule=None, node=None):
    """Return a sampling probability for applying a rule at ``node``.

    The root node always gets ``0.0``. Any other node gets a fresh draw from
    ``self.random_state.uniform()``. ``rule`` is not consulted.
    """
    if not pt.is_root_node(node):
        return self.random_state.uniform()
    return 0.0