# NOTE(review): `map_structure_with_atomic` is defined twice in this file with
# the same name; the later definition shadows the earlier one at import time.
# Consider deleting one copy.
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
    """Maps the atomic elements of a nested structure.

    Args:
      is_atomic_fn: A function that determines if an element of `nested` is
        atomic. It is consulted before any structural inspection, so an
        element it accepts is handed to `map_fn` even if it is a sequence.
      map_fn: The function to apply to atomic elements of `nested`.
      nested: A nested structure (mappings, attrs objects and other
        sequences recognised by `nest.is_sequence`).

    Returns:
      The nested structure, with atomic elements mapped according to `map_fn`,
      rebuilt with `nest._sequence_like` so container types are preserved.

    Raises:
      ValueError: If an element that is neither atomic nor a sequence is
        encountered.
    """
    if is_atomic_fn(nested):
        return map_fn(nested)

    # Recursively convert.
    if not nest.is_sequence(nested):
        raise ValueError(
            'Received non-atomic and non-sequence element: {}'.format(nested))
    if nest._is_mapping(nested):
        # Visit mapping values in the key order given by `nest._sorted`.
        values = [nested[k] for k in nest._sorted(nested)]
    elif nest._is_attrs(nested):
        # attrs instances are treated as sequences by `nest` but usually do
        # not define `__iter__`, so iterating them directly raises TypeError.
        # Take the field values in the order `nest._get_attrs_items` returns
        # them (presumably declaration order, matching how `_sequence_like`
        # rebuilds attrs objects positionally -- confirm against TF nest).
        values = [value for _, value in nest._get_attrs_items(nested)]
    else:
        values = nested
    mapped_values = [
        map_structure_with_atomic(is_atomic_fn, map_fn, ele) for ele in values
    ]
    return nest._sequence_like(nested, mapped_values)
# NOTE(review): `map_structure_with_atomic` is defined twice in this file with
# the same name; the later definition shadows the earlier one at import time.
# Consider deleting one copy.
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
    """Maps the atomic elements of a nested structure.

    Args:
      is_atomic_fn: A function that determines if an element of `nested` is
        atomic. It is consulted before any structural inspection, so an
        element it accepts is handed to `map_fn` even if it is a sequence.
      map_fn: The function to apply to atomic elements of `nested`.
      nested: A nested structure (mappings, attrs objects and other
        sequences recognised by `nest.is_sequence`).

    Returns:
      The nested structure, with atomic elements mapped according to `map_fn`,
      rebuilt with `nest._sequence_like` so container types are preserved.

    Raises:
      ValueError: If an element that is neither atomic nor a sequence is
        encountered.
    """
    if is_atomic_fn(nested):
        return map_fn(nested)

    # Recursively convert.
    if not nest.is_sequence(nested):
        raise ValueError(
            'Received non-atomic and non-sequence element: {}'.format(nested))
    if nest._is_mapping(nested):
        # Visit mapping values in the key order given by `nest._sorted`.
        values = [nested[k] for k in nest._sorted(nested)]
    elif nest._is_attrs(nested):
        # attrs instances are treated as sequences by `nest` but usually do
        # not define `__iter__`, so iterating them directly raises TypeError.
        # Take the field values in the order `nest._get_attrs_items` returns
        # them (presumably declaration order, matching how `_sequence_like`
        # rebuilds attrs objects positionally -- confirm against TF nest).
        values = [value for _, value in nest._get_attrs_items(nested)]
    else:
        values = nested
    mapped_values = [
        map_structure_with_atomic(is_atomic_fn, map_fn, ele) for ele in values
    ]
    return nest._sequence_like(nested, mapped_values)
def arg_retriving_path(arg, path=()):
    """Yield the access path to every leaf of a nested argument signature.

    Walks `arg` the same way TensorFlow's `nest._yield_sorted_items` does
    (see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/nest.py)
    and yields one path per leaf, in traversal order.

    Args:
      arg: The input signature of an argument (possibly nested).
      path: Steps already taken to reach `arg`; callers normally omit it.

    Yields:
      A tuple of `(accessor, key)` steps. `accessor` is `'[]'` for item
      access (mapping keys, positional indices) and `'.'` for attribute-style
      access (attrs fields, namedtuple fields, TypeSpec value-type names).
      A non-sequence `arg` yields `path` itself.
    """
    if not nest.is_sequence(arg):
        yield path
        return

    # Pick the (step, child) pairs for this container kind; recursion below
    # is shared by every branch.
    if isinstance(arg, nest._collections_abc.Mapping):
        children = ((('[]', key), arg[key]) for key in nest._sorted(arg))
    elif nest._is_attrs(arg):
        children = ((('.', name), value)
                    for name, value in nest._get_attrs_items(arg))
    elif nest._is_namedtuple(arg):
        children = ((('.', field), getattr(arg, field)) for field in arg._fields)
    elif nest._is_type_spec(arg):
        # Same key string as for CompositeTensors so both produce matching
        # structures. Unlike `_yield_sorted_items`, composite tensors
        # themselves are not handled here.
        # NOTE(review): `nest.is_sequence` may report False for TypeSpec
        # objects, in which case this branch is never reached -- confirm.
        component_specs = arg._component_specs
        children = [(('.', arg.value_type.__name__), component_specs)]
    else:
        children = ((('[]', index), value) for index, value in enumerate(arg))

    for step, child in children:
        for sub_path in arg_retriving_path(child, path + (step,)):
            yield sub_path