Ejemplo n.º 1
0
def get_traverse_shallow_structure(traverse_fn, structure):
  """Builds a shallow structure of bools describing what may be traversed.

  `traverse_fn` is called on `structure` and, recursively, on each subtree
  that it allows to be traversed.  For every subtree it must return either:

    * a scalar `True` / `False`, meaning "traverse all" / "traverse none" of
      that subtree, or
    * a depth=1 structure of `bool`s, one per top-level child, saying which
      children may be traversed.

  Examples are available in the unit tests (nest_test.py).

  Args:
    traverse_fn: Function taking a substructure and returning either a scalar
      `bool` (whether to traverse that substructure or not) or a depth=1
      shallow structure of bools describing which children to traverse.
    structure: The structure to traverse.

  Returns:
    A shallow structure containing python bools, which can be passed to
    `map_up_to` and `flatten_up_to`.  A plain `bool` is returned when
    `structure` is not nested, and `False` is returned when `traverse_fn`
    rejects the whole of a nested `structure`.

  Raises:
    TypeError: if `traverse_fn` returns a non-bool for a non-nested input,
      a non-bool scalar for a nested input, or a structure whose top-level
      entries are not all of type `bool` (i.e. deeper than depth 1).
  """
  # Always consult traverse_fn first, before inspecting the structure itself.
  verdict = traverse_fn(structure)

  # Leaf input: only a plain bool answer is acceptable.
  if not is_nested(structure):
    if isinstance(verdict, bool):
      return verdict
    raise TypeError('traverse_fn returned structure: %s for non-structure: %s'
                    % (verdict, structure))

  # Scalar answer for a nested input: all-or-nothing.
  if isinstance(verdict, bool):
    if not verdict:
      # Nothing below this point is to be traversed.
      return False
    # Everything is traversable; recurse into each child.
    children = [
        get_traverse_shallow_structure(traverse_fn, child)
        for child in _yield_value(structure)
    ]
    return _sequence_like(structure, children)

  if not is_nested(verdict):
    raise TypeError('traverse_fn returned a non-bool scalar: %s for input: %s'
                    % (verdict, structure))

  # Per-child answer: `verdict` is checked against `structure` first, then
  # each child is either recursed into or marked as not traversable.
  assert_shallow_structure(verdict, structure)
  children = []
  for flag, child in zip(_yield_value(verdict), _yield_value(structure)):
    if not isinstance(flag, bool):
      raise TypeError(
          'traverse_fn didn\'t return a depth=1 structure of bools.  saw: %s '
          ' for structure: %s' % (verdict, structure))
    children.append(
        get_traverse_shallow_structure(traverse_fn, child) if flag else False)
  return _sequence_like(structure, children)
Ejemplo n.º 2
0
def apply_to_structure(branch_fn, leaf_fn, structure):
  """Applies `branch_fn` to nested nodes and `leaf_fn` to leaves of `structure`.

  If `structure` is not nested (per `is_nested`), the result is simply
  `leaf_fn(structure)`.  Otherwise `branch_fn(structure)` is computed first,
  and the recursion continues into the values of *that result* (not of the
  original `structure`).  The processed values are then combined with
  `_sequence_like`, using the `branch_fn` output as the template.

  Args:
    branch_fn: A function to call on a struct if is_nested(struct) is `True`.
    leaf_fn: A function to call on a struct if is_nested(struct) is `False`.
    structure: A nested structure containing arguments to be applied to.

  Returns:
    A nested structure of function outputs.

  Raises:
    TypeError: If `branch_fn` or `leaf_fn` is not callable.
  """
  # The offending value is wrapped in a 1-tuple for formatting: with a bare
  # `'...%s' % value`, a tuple `value` would be unpacked as the format
  # arguments and raise an unrelated formatting error (or print only its
  # single element) instead of the intended message.
  if not callable(leaf_fn):
    raise TypeError('leaf_fn must be callable, got: %s' % (leaf_fn,))

  if not callable(branch_fn):
    raise TypeError('branch_fn must be callable, got: %s' % (branch_fn,))

  if not is_nested(structure):
    return leaf_fn(structure)

  # branch_fn may transform the node; recursion descends into its output.
  processed = branch_fn(structure)

  new_structure = [
      apply_to_structure(branch_fn, leaf_fn, value)
      for value in _yield_value(processed)
  ]
  return _sequence_like(processed, new_structure)