def _yield_value(iterable):
    """Generate the immediate children of ``iterable`` in a deterministic order.

    Mappings are visited in sorted-key order. A ``SparseTensorValue`` is
    produced whole, as a single element. Attrs instances contribute their
    attribute values in the order ``nest._get_attrs_items`` reports them.
    Anything else is iterated directly.

    Args:
      iterable: the container whose direct elements should be produced.

    Yields:
      Each direct element of ``iterable``, in the order described above.
    """
    # pylint: disable=protected-access
    if isinstance(iterable, _collections_abc.Mapping):
        # Sorting the keys makes dict traversal deterministic. This deliberately
        # ignores the insertion order of an `OrderedDict`, so that flattening one
        # kind of mapping and packing the result back into the other (plain dict
        # vs. `OrderedDict`) cannot silently misalign values.
        yield from (iterable[key] for key in _sorted(iterable))
        return
    if isinstance(iterable, _sparse_tensor.SparseTensorValue):
        # Not descended into: the whole value is emitted as one element.
        yield iterable
        return
    if nest._is_attrs(iterable):
        # Only the attribute values are produced; the names are dropped.
        yield from (value for _name, value in nest._get_attrs_items(iterable))
        return
    yield from iterable
def arg_retriving_path(arg, path=()):
    """Yield the access path to every leaf of a nested input signature.

    The structure is walked in the same order as ``_yield_sorted_items`` in
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/nest.py
    One difference: CompositeTensor values get no special handling here.

    Args:
      arg: The (possibly nested) input signature of an argument.
      path: Tuple of ``(accessor, key)`` steps leading from the root to
        ``arg``. ``accessor`` is ``'[]'`` for a subscript step and ``'.'`` for
        an attribute step. Callers normally leave the default ``()``.

    Yields:
      One tuple per leaf: ``path`` extended with the steps from ``arg`` down
      to that leaf.
    """
    # pylint: disable=protected-access
    if not nest.is_sequence(arg):
        # A leaf: the path accumulated so far is its complete path.
        yield path
        return

    # Each branch describes the children of `arg` as (accessor, key, child)
    # triples; the loop at the bottom does the recursion for all of them.
    if isinstance(arg, nest._collections_abc.Mapping):
        steps = (('[]', key, arg[key]) for key in nest._sorted(arg))
    elif nest._is_attrs(arg):
        steps = (('.', name, value)
                 for name, value in nest._get_attrs_items(arg))
    elif nest._is_namedtuple(arg):
        steps = (('.', field, getattr(arg, field)) for field in arg._fields)
    elif nest._is_type_spec(arg):
        # To let CompositeTensors and their TypeSpecs have matching structures,
        # the step is keyed by the spec's value_type name (the same key string
        # a CompositeTensor would use).
        component_specs = arg._component_specs
        type_name = arg.value_type.__name__
        steps = [('.', type_name, component_specs)]
    else:
        steps = (('[]', index, item) for index, item in enumerate(arg))

    for accessor, key, child in steps:
        yield from arg_retriving_path(child, path + ((accessor, key),))
def __len__(self) -> int:
    """Return the number of attrs-declared attributes of ``self``.

    Each call invokes ``warnings.warn("Temporary hotfix")``. Whether the
    warning is displayed depends on the active warning filters.

    Returns:
      The number of items reported by ``nest._get_attrs_items(self)``.

    Raises:
      TypeError: if ``self`` is not an attrs instance according to
        ``nest._is_attrs``.
    """
    warnings.warn("Temporary hotfix")
    # pylint: disable=protected-access
    # Explicit check rather than `assert`: assertions are stripped under
    # `python -O`, which would let a non-attrs object reach `_get_attrs_items`.
    # TypeError is what `len()` conventionally raises for unsized objects.
    if not nest._is_attrs(self):
        raise TypeError(
            "__len__ is only supported for attrs instances, got %s"
            % type(self).__name__)
    return len(nest._get_attrs_items(self))