Ejemplo n.º 1
0
def _yield_value(iterable):
  """Generates the elements of `iterable` in a reproducible order.

  Dispatch rules:
    * Mappings: their values, visited in the key order produced by `_sorted`.
      The insertion order of an `OrderedDict` is intentionally ignored.
    * `SparseTensorValue`: the object itself, as a single element.
    * Objects accepted by `nest._is_attrs`: their attribute values, in the
      order returned by `nest._get_attrs_items`.
    * Anything else: whatever plain iteration over it produces.

  Args:
    iterable: an iterable.

  Yields:
    The iterable elements in a deterministic order.
  """
  # pylint: disable=protected-access
  if isinstance(iterable, _collections_abc.Mapping):
    # Sorting the keys (instead of honoring `OrderedDict` order) avoids subtle
    # bugs when plain dicts and `OrderedDict`s are mixed, e.g. flattening a
    # dict and then packing the values back into a corresponding `OrderedDict`.
    for key in _sorted(iterable):
      yield iterable[key]
    return

  if isinstance(iterable, _sparse_tensor.SparseTensorValue):
    # Emitted whole; this function never iterates into it.
    yield iterable
    return

  if nest._is_attrs(iterable):
    for _name, attr_value in nest._get_attrs_items(iterable):
      yield attr_value
    return

  for element in iterable:
    yield element
Ejemplo n.º 2
0
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
    """Applies `map_fn` to every atomic leaf of a nested structure.

    Args:
      is_atomic_fn: Predicate deciding whether an element of `nested` is
        atomic, i.e. passed to `map_fn` instead of being recursed into.
      map_fn: The function to apply to atomic elements of `nested`.
      nested: A nested structure.

    Returns:
      The nested structure, with atomic elements mapped according to `map_fn`.
      The container is rebuilt with `nest._sequence_like`.

    Raises:
      ValueError: If an element that is neither atomic nor a sequence is
        encountered.
    """
    # Leaves are transformed directly; nothing else is inspected.
    if is_atomic_fn(nested):
        return map_fn(nested)

    # Anything that is not atomic must be a sequence we can recurse into.
    if not nest.is_sequence(nested):
        raise ValueError(
            'Received non-atomic and non-sequence element: {}'.format(nested))

    # pylint: disable=protected-access
    if nest._is_mapping(nested):
        # Mapping values are visited in the key order given by `nest._sorted`.
        children = [nested[key] for key in nest._sorted(nested)]
    elif nest._is_attrs(nested):
        children = _astuple(nested)
    else:
        children = nested

    mapped_children = [
        map_structure_with_atomic(is_atomic_fn, map_fn, child)
        for child in children
    ]
    return nest._sequence_like(nested, mapped_children)
Ejemplo n.º 3
0
    def testAttrsFlattenAndPack(self):
        """Round-trips an attrs instance through flatten/pack_sequence_as.

        Also checks that flattening `NestTest.BadAttr()` raises a TypeError
        whose message matches "object is not iterable".
        """
        if attr is None:
            self.skipTest("attr module is unavailable.")

        field_values = [1, 2]
        sample_attr = NestTest.SampleAttr(*field_values)
        self.assertFalse(nest._is_attrs(field_values))
        self.assertTrue(nest._is_attrs(sample_attr))
        flat = nest.flatten(sample_attr)
        self.assertEqual(field_values, flat)
        restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
        self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
        self.assertEqual(restructured_from_flat, sample_attr)

        # Check that flatten fails if attributes are not iterable.
        # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which was
        # deprecated in Python 3.2 and removed in Python 3.12.
        with self.assertRaisesRegex(TypeError, "object is not iterable"):
            nest.flatten(NestTest.BadAttr())
Ejemplo n.º 4
0
  def testAttrsFlattenAndPack(self):
    """Round-trips an attrs instance through flatten/pack_sequence_as.

    Also checks that flattening `NestTest.BadAttr()` raises a TypeError whose
    message matches "object is not iterable".
    """
    if attr is None:
      self.skipTest("attr module is unavailable.")

    field_values = [1, 2]
    sample_attr = NestTest.SampleAttr(*field_values)
    self.assertFalse(nest._is_attrs(field_values))
    self.assertTrue(nest._is_attrs(sample_attr))
    flat = nest.flatten(sample_attr)
    self.assertEqual(field_values, flat)
    restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
    self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
    self.assertEqual(restructured_from_flat, sample_attr)

    # Check that flatten fails if attributes are not iterable.
    # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which was
    # deprecated in Python 3.2 and removed in Python 3.12.
    with self.assertRaisesRegex(TypeError, "object is not iterable"):
      nest.flatten(NestTest.BadAttr())
Ejemplo n.º 5
0
        def arg_retriving_path(arg, path=()):
            """Yields the retrieval path of every leaf in an argument structure.

            The traversal order follows `_yield_sorted_items` in
            https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/nest.py

            Args:
              arg: The input signature of an argument. Anything for which
                `nest.is_sequence` is false is treated as a leaf.
              path: Steps already taken to reach `arg`; callers normally leave
                this at its default `()`.

            Yields:
              One tuple per leaf. Each element of the tuple is a step pair
              `(accessor, key)`, where `accessor` is `'[]'` for mapping keys
              and positional indices, and `'.'` for attrs fields, namedtuple
              fields, or a TypeSpec's value-type name.
            """
            # Anything nest does not treat as a sequence is a leaf.
            if not nest.is_sequence(arg):
                yield path
                return

            # Each branch builds a lazy stream of (step, child) pairs, so
            # children are fetched one at a time, interleaved with recursion.
            # pylint: disable=protected-access
            if isinstance(arg, nest._collections_abc.Mapping):
                # Mapping values are visited in sorted-key order.
                children = ((('[]', key), arg[key])
                            for key in nest._sorted(arg))
            elif nest._is_attrs(arg):
                children = ((('.', name), value)
                            for name, value in nest._get_attrs_items(arg))
            elif nest._is_namedtuple(arg):
                children = ((('.', field), getattr(arg, field))
                            for field in arg._fields)
            # Unlike `_yield_sorted_items`, CompositeTensor values are not
            # special-cased here.
            elif nest._is_type_spec(arg):
                # Note: to allow CompositeTensors and their TypeSpecs to have
                # matching structures, we need to use the same key string here.
                component_specs = arg._component_specs
                children = ((('.', arg.value_type.__name__), component_specs),)
            else:
                children = ((('[]', index), value)
                            for index, value in enumerate(arg))

            for step, child in children:
                for sub_path in arg_retriving_path(child, path + (step,)):
                    yield sub_path
Ejemplo n.º 6
0
 def __len__(self) -> int:
     """Return the number of items `nest._get_attrs_items` reports for self.

     Raises:
       TypeError: If `self` is not recognized by `nest._is_attrs`.
     """
     # NOTE(review): the author marked this as a temporary hotfix. It warns on
     # every call, subject to the active warning filters. stacklevel=2
     # attributes the warning to the code that called len(), not to this line.
     warnings.warn("Temporary hotfix", stacklevel=2)
     # An explicit check rather than `assert`: assert statements are stripped
     # when Python runs with -O, which would let a non-attrs object reach
     # `_get_attrs_items` and fail with an unrelated error.
     if not nest._is_attrs(self):
         raise TypeError(
             "{} is not an attrs instance; len() is unsupported".format(
                 type(self).__name__))
     return len(nest._get_attrs_items(self))