def get_traverse_shallow_structure(traverse_fn, structure):
    """Builds a shallow structure of bools describing which subtrees to traverse.

    `traverse_fn` is called on `structure` and then, recursively, on every
    subtree it allows. For each subtree it must answer with either:

      * a scalar `bool`: `True` means "traverse every child of this subtree",
        `False` means "traverse none of it"; or
      * a depth=1 structure matching the top level of the subtree, whose
        entries are `bool` and say, child by child, whether to descend.

    Leaves (values for which `is_nested` is false) only accept a scalar
    `bool`. Examples are available in the unit tests (nest_test.py).

    Args:
      traverse_fn: Callable taking any substructure of `structure` and
        returning a scalar `bool` or a depth=1 structure of `bool`s of the
        same type as its argument.
      structure: The structure to traverse.

    Returns:
      A shallow structure of python bools (or a single bool) suitable for use
      as the shallow structure in `map_up_to` and `flatten_up_to`.

    Raises:
      TypeError: if `traverse_fn` returns a non-bool for a leaf, a non-bool
        scalar for a nested input, or a structure whose top-level entries are
        not all `bool`.
    """
    decision = traverse_fn(structure)

    # Leaf input: the only legal answer is a plain bool, passed through as-is.
    if not is_nested(structure):
        if isinstance(decision, bool):
            return decision
        raise TypeError('traverse_fn returned structure: %s for non-structure: %s'
                        % (decision, structure))

    if isinstance(decision, bool):
        if not decision:
            # Nothing below this node is traversed; stop before visiting children.
            return False
        # Blanket "yes": ask traverse_fn again about each child, in order.
        children = [
            get_traverse_shallow_structure(traverse_fn, child)
            for child in _yield_value(structure)
        ]
        return _sequence_like(structure, children)

    if not is_nested(decision):
        raise TypeError('traverse_fn returned a non-bool scalar: %s for input: %s'
                        % (decision, structure))

    # Per-child answers: the returned structure must line up with `structure`'s
    # top level before it is walked in parallel with it.
    assert_shallow_structure(decision, structure)
    children = []
    for flag, child in zip(_yield_value(decision), _yield_value(structure)):
        # Validated lazily, so children before a bad entry are already visited.
        if not isinstance(flag, bool):
            raise TypeError(
                'traverse_fn didn\'t return a depth=1 structure of bools. saw: %s '
                ' for structure: %s' % (decision, structure))
        children.append(
            get_traverse_shallow_structure(traverse_fn, child) if flag else False)
    return _sequence_like(structure, children)
def apply_to_structure(branch_fn, leaf_fn, structure):
    """`apply_to_structure` applies branch_fn and leaf_fn to branches and leaves.

    This function accepts two separate callables depending on whether the
    structure is a sequence. For a nested `structure`, `branch_fn` is called on
    it first, and the function then recurses into the children of whatever
    `branch_fn` returned. The result container is built from that returned
    value, so `branch_fn` may reshape or replace a branch before its children
    are visited. For a non-nested `structure`, `leaf_fn` is called and its
    result is returned directly.

    Args:
      branch_fn: A function to call on a struct if is_nested(struct) is `True`.
        Its return value must itself be a structure whose children can be
        iterated (presumably nested as well -- NOTE(review): confirm the
        behaviour of `_yield_value` on non-nested input).
      leaf_fn: A function to call on a struct if is_nested(struct) is `False`.
      structure: A nested structure containing arguments to be applied to.

    Returns:
      A nested structure of function outputs.

    Raises:
      TypeError: If `branch_fn` or `leaf_fn` is not callable. Exceptions
        raised by `branch_fn` or `leaf_fn` themselves propagate unchanged.
    """
    # Wrap the operand in a 1-tuple: a bare tuple argument (e.g. `()` or
    # `(1, 2)`) would otherwise be unpacked by `%` and produce an unrelated
    # formatting error instead of this message.
    if not callable(leaf_fn):
        raise TypeError('leaf_fn must be callable, got: %s' % (leaf_fn,))
    if not callable(branch_fn):
        raise TypeError('branch_fn must be callable, got: %s' % (branch_fn,))
    # The callables are validated once here; the recursive worker does not
    # repeat the checks at every level of the structure.
    return _apply_to_structure_unchecked(branch_fn, leaf_fn, structure)


def _apply_to_structure_unchecked(branch_fn, leaf_fn, structure):
    """Recursive worker for `apply_to_structure`; assumes validated callables."""
    if not is_nested(structure):
        return leaf_fn(structure)
    processed = branch_fn(structure)
    new_structure = [
        _apply_to_structure_unchecked(branch_fn, leaf_fn, value)
        for value in _yield_value(processed)
    ]
    return _sequence_like(processed, new_structure)