def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
    """Applies `map_fn` to every atomic leaf of a nested structure.

    Arguments:
      is_atomic_fn: Predicate deciding whether an element of `nested` should
        be treated as atomic (i.e. a leaf that is mapped, not traversed).
      map_fn: Function applied to each atomic element.
      nested: A nested structure.

    Returns:
      A structure rebuilt like `nested` (through `nest._sequence_like`) in
      which every atomic element has been replaced by the result of `map_fn`.

    Raises:
      ValueError: If an element that is neither atomic nor a sequence is
        encountered.
    """
    # Leaves are mapped directly; there is nothing to recurse into.
    if is_atomic_fn(nested):
        return map_fn(nested)

    if not nest.is_sequence(nested):
        raise ValueError(
            'Received non-atomic and non-sequence element: {}'.format(nested))

    # Mappings are walked in the key order produced by `nest._sorted`; any
    # other sequence is walked in its own iteration order.
    children = (
        [nested[key] for key in nest._sorted(nested)]
        if nest._is_mapping(nested)
        else nested
    )

    converted = []
    for child in children:
        converted.append(
            map_structure_with_atomic(is_atomic_fn, map_fn, child))
    return nest._sequence_like(nested, converted)
def _packed_nest_with_indices(structure, flat, index):
    """Helper function for pack_sequence_as.

    Walks the direct children of `structure` (as produced by `_yield_value`)
    and consumes values from `flat`, starting at `index`, so that the output
    mirrors the nesting of `structure`.

    Args:
      structure: Substructure (tuple of elements and/or tuples) to mimic.
      flat: Flattened values to output substructure for.
      index: Index at which to start reading from `flat`.

    Returns:
      The tuple (new_index, packed), where:
        * new_index - the updated index into `flat` having processed
                      `structure`.
        * packed - a list holding the subset of `flat` corresponding to
                   `structure`, having started at `index`; nested children
                   are already rebuilt with `nest._sequence_like`, the top
                   level is left as a plain list for the caller to wrap.

    Raises:
      IndexError: if `flat` is list-like and `structure` contains more
        elements than `flat` provides from `index` onwards (raised by the
        `flat[cursor]` lookup; no explicit length check is done here).
    """
    cursor = index
    result = []
    for element in _yield_value(structure):
        if not is_sequence(element):
            # Leaf: take the next flat value.
            result.append(flat[cursor])
            cursor += 1
            continue
        # Nested child: recurse, then rebuild it in the child's own container.
        cursor, children = _packed_nest_with_indices(element, flat, cursor)
        result.append(nest._sequence_like(element, children))  # pylint: disable=protected-access
    return cursor, result
def _coerce_structure(shallow_tree, input_tree):
    """Implementation of coerce_structure.

    Walks `shallow_tree` and `input_tree` together and rebuilds each visited
    level of `input_tree` in the container of `shallow_tree` (through
    `nest._sequence_like`), but only where the container types differ or a
    nested branch was itself rebuilt. Recursion stops wherever `shallow_tree`
    is not nested.

    Args:
      shallow_tree: Structure providing the target containers and layout.
      input_tree: Structure to coerce.

    Returns:
      `input_tree` itself when `shallow_tree` is not nested or when nothing
      had to be rebuilt; otherwise a new structure created by
      `nest._sequence_like(shallow_tree, ...)`.

    Raises:
      TypeError: If `shallow_tree` is nested but `input_tree` is not, or if
        the two containers cannot be matched (keyed vs. iterated access).
      ValueError: If the two trees have different lengths at some level, or
        if a key/field of `shallow_tree` is missing from `input_tree`.
    """
    if not nest.is_nested(shallow_tree):
        return input_tree

    if not nest.is_nested(input_tree):
        raise TypeError(
            nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree)))

    if len(input_tree) != len(shallow_tree):
        raise ValueError(
            nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
                input_length=len(input_tree),
                shallow_length=len(shallow_tree)))

    # Determine whether shallow_tree should be treated as a Mapping or a
    # Sequence. Namedtuples can be interpreted either way (but keys take
    # precedence).
    shallow_is_namedtuple = nest._is_namedtuple(shallow_tree)
    shallow_is_mapping = isinstance(shallow_tree, collections.abc.Mapping)
    shallow_supports_keys = shallow_is_namedtuple or shallow_is_mapping
    shallow_supports_iter = shallow_is_namedtuple or not shallow_is_mapping

    # Branch-selection depends on both shallow and input container-classes.
    # `lookup_branch` maps a key of `shallow_tree` to the matching branch of
    # `input_tree`, either by name/key or by consuming a positional iterator.
    input_is_mapping = isinstance(input_tree, collections.abc.Mapping)
    if nest._is_namedtuple(input_tree):
        if shallow_supports_keys:
            lookup_branch = lambda k: getattr(input_tree, k)
        else:
            input_iter = nest._yield_value(input_tree)
            lookup_branch = lambda _: next(input_iter)
    elif shallow_supports_keys and input_is_mapping:
        lookup_branch = lambda k: input_tree[k]
    elif shallow_supports_iter and not input_is_mapping:
        input_iter = nest._yield_value(input_tree)
        lookup_branch = lambda _: next(input_iter)
    else:
        raise TypeError(
            nest._STRUCTURES_HAVE_MISMATCHING_TYPES.format(
                input_type=type(input_tree),
                shallow_type=(
                    type(shallow_tree.__wrapped__)
                    if hasattr(shallow_tree, '__wrapped__')
                    else type(shallow_tree))))

    flat_coerced = []
    needs_wrapping = type(shallow_tree) is not type(input_tree)
    for shallow_key, shallow_branch in nest._yield_sorted_items(shallow_tree):
        try:
            input_branch = lookup_branch(shallow_key)
        except (KeyError, AttributeError) as err:
            # Chain explicitly so the traceback shows the failed lookup as
            # the cause rather than as an error during error handling.
            raise ValueError(
                nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(
                    [shallow_key])) from err
        flat_coerced.append(_coerce_structure(shallow_branch, input_branch))
        # Keep track of whether nested elements have changed.
        needs_wrapping |= input_branch is not flat_coerced[-1]

    # Only create a new instance if containers differ or contents changed.
    return (nest._sequence_like(shallow_tree, flat_coerced)
            if needs_wrapping else input_tree)
def pack_sequence_as(structure, flat_sequence):
    """Returns a given flattened sequence packed into a nest.

    If `structure` is a scalar, `flat_sequence` must be a single-element list;
    in this case the return value is `flat_sequence[0]`.

    Args:
      structure: tuple or list constructed of scalars and/or other
        tuples/lists, or a scalar. Note: numpy arrays are considered scalars.
      flat_sequence: flat sequence to pack.

    Returns:
      packed: `flat_sequence` converted to have the same recursive structure
        as `structure`.

    Raises:
      TypeError: If `flat_sequence` is neither a sequence (according to
        `is_sequence`) nor a `list`.
      ValueError: If `structure` is a scalar and `flat_sequence` does not
        hold exactly one element, or if `structure` and `flat_sequence` have
        different element counts.
    """
    if not (is_sequence(flat_sequence) or isinstance(flat_sequence, list)):
        raise TypeError("Argument `flat_sequence` must be a sequence. Got "
                        f"'{type(flat_sequence).__name__}'.")

    if not is_sequence(structure):
        # Scalar structure: the only valid input is a one-element sequence.
        # The check is `!= 1`, so the message must also make sense for an
        # empty `flat_sequence` (the old text claimed "0 > 1").
        if len(flat_sequence) != 1:
            raise ValueError(
                "Argument `structure` is a scalar but "
                f"`len(flat_sequence)`={len(flat_sequence)}, "
                "expected exactly 1.")
        return flat_sequence[0]

    # Validate the element count up front so the packing helper never runs
    # past the end of `flat_sequence`.
    flat_structure = flatten(structure)
    if len(flat_structure) != len(flat_sequence):
        raise ValueError(
            "Could not pack sequence. Argument `structure` had "
            f"{len(flat_structure)} elements, but argument `flat_sequence` had "
            f"{len(flat_sequence)} elements. Received structure: "
            f"{structure}, flat_sequence: {flat_sequence}.")

    _, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
    return nest._sequence_like(structure, packed)  # pylint: disable=protected-access
def pack_sequence_as(structure, flat_sequence):
    """Returns a given flattened sequence packed into a nest.

    If `structure` is a scalar, `flat_sequence` must be a single-element list;
    in this case the return value is `flat_sequence[0]`.

    Args:
      structure: tuple or list constructed of scalars and/or other
        tuples/lists, or a scalar. Note: numpy arrays are considered scalars.
      flat_sequence: flat sequence to pack.

    Returns:
      packed: `flat_sequence` converted to have the same recursive structure
        as `structure`.

    Raises:
      TypeError: If `flat_sequence` is neither a sequence (according to
        `is_sequence`) nor a `list`.
      ValueError: If `structure` is a scalar and `flat_sequence` does not
        hold exactly one element, or if `structure` and `flat_sequence` have
        different element counts.
    """
    if not (is_sequence(flat_sequence) or isinstance(flat_sequence, list)):
        raise TypeError("flat_sequence must be a sequence")

    if not is_sequence(structure):
        # Scalar structure: the only valid input is a one-element sequence.
        # The check is `!= 1`, so the message must also make sense for an
        # empty `flat_sequence` (the old text claimed "0 > 1").
        if len(flat_sequence) != 1:
            raise ValueError(
                "Structure is a scalar but len(flat_sequence) == %d, "
                "expected exactly 1" % len(flat_sequence))
        return flat_sequence[0]

    # Validate the element count up front so the packing helper never runs
    # past the end of `flat_sequence`.
    flat_structure = flatten(structure)
    if len(flat_structure) != len(flat_sequence):
        raise ValueError(
            "Could not pack sequence. Structure had %d elements, but flat_sequence "
            "had %d elements. Structure: %s, flat_sequence: %s."
            % (len(flat_structure), len(flat_sequence), structure, flat_sequence))

    _, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
    return nest._sequence_like(structure, packed)  # pylint: disable=protected-access