コード例 #1
0
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
    """Apply `map_fn` to every atomic element of a nested structure.

    Descent stops as soon as `is_atomic_fn` reports an element as atomic, so
    such an element is passed to `map_fn` whole, even if it is itself a
    sequence.

    Arguments:
      is_atomic_fn: Predicate deciding whether an element of `nested` is
        atomic.
      map_fn: Callable applied to each atomic element.
      nested: The (possibly nested) structure to map over.

    Returns:
      A structure built by `nest._sequence_like` from `nested`, whose atomic
      elements have been replaced by the results of `map_fn`.

    Raises:
      ValueError: If an element is neither atomic nor a sequence according to
        `nest.is_sequence`.
    """
    # The atomic check must come first: atomic inputs never reach `nest`.
    if is_atomic_fn(nested):
        return map_fn(nested)

    if not nest.is_sequence(nested):
        message = 'Received non-atomic and non-sequence element: {}'.format(
            nested)
        raise ValueError(message)

    # Mappings are visited in `nest._sorted` key order; presumably the order
    # `nest._sequence_like` expects when rebuilding a mapping -- TODO confirm.
    children = ([nested[key] for key in nest._sorted(nested)]
                if nest._is_mapping(nested) else nested)

    mapped_children = []
    for child in children:
        mapped_children.append(
            map_structure_with_atomic(is_atomic_fn, map_fn, child))
    return nest._sequence_like(nested, mapped_children)
コード例 #2
0
ファイル: nest.py プロジェクト: xutianming/tensorflow
def _packed_nest_with_indices(structure, flat, index):
  """Helper for `pack_sequence_as`: packs items of `flat` into `structure`.

  Consumes consecutive items of `flat`, starting at `index`, one per leaf of
  `structure` (leaves are children for which `is_sequence` is false).

  Args:
    structure: Substructure (tuple of elements and/or tuples) to mimic.
    flat: Flattened values to output substructure for.
    index: Index at which to start reading from `flat`.

  Returns:
    The tuple `(new_index, packed)`, where:
      * new_index - the index into `flat` just past the last item consumed
        while processing `structure`.
      * packed - a plain list with one entry per direct child of `structure`
        (as yielded by `_yield_value`). Leaf children are the corresponding
        items of `flat`; nested children are rebuilt with
        `nest._sequence_like`. The caller converts `packed` itself into the
        container type of `structure`.

  Raises:
    IndexError: If `structure` has more leaves than `flat` has items left
      after `index` (raised by the `flat[index]` lookup for list/tuple
      `flat`). `pack_sequence_as` compares lengths before calling this helper.
  """
  packed = []
  for s in _yield_value(structure):
    if is_sequence(s):
      # Recurse into the child and continue reading where it stopped.
      new_index, child = _packed_nest_with_indices(s, flat, index)
      packed.append(nest._sequence_like(s, child))  # pylint: disable=protected-access
      index = new_index
    else:
      packed.append(flat[index])
      index += 1
  return index, packed
コード例 #3
0
ファイル: tf_utils.py プロジェクト: kylin9872/tensorflow
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
  """Apply `map_fn` to every atomic element of a nested structure.

  Elements that `is_atomic_fn` accepts are handed to `map_fn` whole and are
  not descended into; everything else must be a sequence and is rebuilt with
  its children mapped recursively.

  Arguments:
    is_atomic_fn: Predicate deciding whether an element of `nested` is atomic.
    map_fn: Callable applied to each atomic element.
    nested: The (possibly nested) structure to map over.

  Returns:
    A structure built by `nest._sequence_like` from `nested`, with atomic
    elements replaced by the results of `map_fn`.

  Raises:
    ValueError: If an element is neither atomic nor a sequence according to
      `nest.is_sequence`.
  """
  if is_atomic_fn(nested):
    return map_fn(nested)

  if not nest.is_sequence(nested):
    message = 'Received non-atomic and non-sequence element: {}'.format(nested)
    raise ValueError(message)

  if nest._is_mapping(nested):
    # Mapping values are gathered in `nest._sorted` key order.
    children = [nested[key] for key in nest._sorted(nested)]
  else:
    children = nested

  def _recurse(child):
    return map_structure_with_atomic(is_atomic_fn, map_fn, child)

  return nest._sequence_like(nested, [_recurse(child) for child in children])
コード例 #4
0
ファイル: nest_util.py プロジェクト: gisilvs/probability
def _coerce_structure(shallow_tree, input_tree):
    """Implementation of coerce_structure.

    Walks `shallow_tree` and `input_tree` in parallel. Wherever `shallow_tree`
    is not nested, the corresponding `input_tree` subtree is returned as-is.
    Wherever `shallow_tree` is nested, the matching level of `input_tree` is
    rebuilt with `nest._sequence_like(shallow_tree, ...)`. A level is rebuilt
    only if its container type differs from `shallow_tree`'s or one of its
    children changed. Otherwise the original `input_tree` object is returned
    (identity is preserved).

    Args:
      shallow_tree: Structure whose nesting, down to its leaves, determines
        the container types of the result.
      input_tree: Structure to coerce. It must be nested wherever
        `shallow_tree` is, with the same length at each such level.

    Returns:
      `input_tree` re-packed into `shallow_tree`'s containers, or
      `input_tree` itself when nothing needed to change.

    Raises:
      TypeError: If `input_tree` is not nested where `shallow_tree` is, or
        if the container kinds are incompatible. That means a plain
        (non-namedtuple) mapping paired with a non-mapping input, or a plain
        sequence paired with a mapping input.
      ValueError: If lengths differ at some level, or if a key/field of
        `shallow_tree` cannot be looked up in `input_tree`.
    """
    if not nest.is_nested(shallow_tree):
        return input_tree

    if not nest.is_nested(input_tree):
        raise TypeError(
            nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree)))

    if len(input_tree) != len(shallow_tree):
        raise ValueError(
            nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
                input_length=len(input_tree),
                shallow_length=len(shallow_tree)))

    # Determine whether shallow_tree should be treated as a Mapping or a Sequence.
    # Namedtuples can be interpreted either way (but keys take precedence).
    _shallow_is_namedtuple = nest._is_namedtuple(shallow_tree)  # pylint: disable=invalid-name
    _shallow_is_mapping = isinstance(shallow_tree, collections.abc.Mapping)  # pylint: disable=invalid-name
    shallow_supports_keys = _shallow_is_namedtuple or _shallow_is_mapping
    shallow_supports_iter = _shallow_is_namedtuple or not _shallow_is_mapping

    # Branch-selection depends on both shallow and input container-classes.
    # `lookup_branch(key)` returns the input child matching a shallow child.
    # It either looks the child up by key/field name, or ignores the key and
    # consumes the input positionally.
    input_is_mapping = isinstance(input_tree, collections.abc.Mapping)
    if nest._is_namedtuple(input_tree):
        if shallow_supports_keys:
            # Look up namedtuple fields by name.
            lookup_branch = lambda k: getattr(input_tree, k)
        else:
            # Plain-sequence shallow tree: consume the namedtuple in order.
            input_iter = nest._yield_value(input_tree)
            lookup_branch = lambda _: next(input_iter)
    elif shallow_supports_keys and input_is_mapping:
        lookup_branch = lambda k: input_tree[k]
    elif shallow_supports_iter and not input_is_mapping:
        # Lengths were checked equal above, so `next` cannot be exhausted.
        input_iter = nest._yield_value(input_tree)
        lookup_branch = lambda _: next(input_iter)
    else:
        raise TypeError(
            nest._STRUCTURES_HAVE_MISMATCHING_TYPES.format(
                input_type=type(input_tree),
                shallow_type=(type(shallow_tree.__wrapped__) if hasattr(
                    shallow_tree, '__wrapped__') else type(shallow_tree))))

    flat_coerced = []
    needs_wrapping = type(shallow_tree) is not type(input_tree)
    for shallow_key, shallow_branch in nest._yield_sorted_items(shallow_tree):
        try:
            input_branch = lookup_branch(shallow_key)
        except (KeyError, AttributeError) as err:
            # Chain explicitly so the lookup failure is reported as the cause
            # rather than as an unrelated error raised during handling.
            raise ValueError(
                nest._SHALLOW_TREE_HAS_INVALID_KEYS.format([shallow_key])
            ) from err
        flat_coerced.append(_coerce_structure(shallow_branch, input_branch))
        # Keep track of whether nested elements have changed.
        needs_wrapping |= input_branch is not flat_coerced[-1]

    # Only create a new instance if containers differ or contents changed.
    return (nest._sequence_like(shallow_tree, flat_coerced)
            if needs_wrapping else input_tree)
コード例 #5
0
ファイル: nest.py プロジェクト: xutianming/tensorflow
def pack_sequence_as(structure, flat_sequence):
  """Returns a given flattened sequence packed into a nest.

  If `structure` is a scalar, `flat_sequence` must be a single-element list;
  in this case the return value is `flat_sequence[0]`.

  Args:
    structure: tuple or list constructed of scalars and/or other tuples/lists,
      or a scalar.  Note: numpy arrays are considered scalars.
    flat_sequence: flat sequence to pack.

  Returns:
    packed: `flat_sequence` converted to have the same recursive structure as
      `structure`.

  Raises:
    TypeError: If `flat_sequence` is neither a sequence (per `is_sequence`)
      nor a list.
    ValueError: If nest and structure have different element counts,
      including a scalar `structure` with a `flat_sequence` whose length is
      not exactly 1.
  """
  if not (is_sequence(flat_sequence) or isinstance(flat_sequence, list)):
    raise TypeError("Argument `flat_sequence` must be a sequence. Got "
                    f"'{type(flat_sequence).__name__}'.")

  if not is_sequence(structure):
    if len(flat_sequence) != 1:
      # Report "!= 1": the length may be 0 as well as greater than 1.
      raise ValueError("Argument `structure` is a scalar but "
                       f"`len(flat_sequence)`={len(flat_sequence)} != 1")
    return flat_sequence[0]

  flat_structure = flatten(structure)
  if len(flat_structure) != len(flat_sequence):
    raise ValueError(
        "Could not pack sequence. Argument `structure` had "
        f"{len(flat_structure)} elements, but argument `flat_sequence` had "
        f"{len(flat_sequence)} elements. Received structure: "
        f"{structure}, flat_sequence: {flat_sequence}.")

  # The helper returns a plain list of top-level children; convert it into
  # `structure`'s own container type here.
  _, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
  return nest._sequence_like(structure, packed)  # pylint: disable=protected-access
コード例 #6
0
def pack_sequence_as(structure, flat_sequence):
    """Returns a given flattened sequence packed into a nest.

    If `structure` is a scalar, `flat_sequence` must be a single-element list;
    in this case the return value is `flat_sequence[0]`.

    Args:
      structure: tuple or list constructed of scalars and/or other tuples/lists,
        or a scalar.  Note: numpy arrays are considered scalars.
      flat_sequence: flat sequence to pack.

    Returns:
      packed: `flat_sequence` converted to have the same recursive structure as
        `structure`.

    Raises:
      TypeError: If `flat_sequence` is neither a sequence (per `is_sequence`)
        nor a list.
      ValueError: If nest and structure have different element counts,
        including a scalar `structure` with a `flat_sequence` whose length is
        not exactly 1.
    """
    if not (is_sequence(flat_sequence) or isinstance(flat_sequence, list)):
        raise TypeError("flat_sequence must be a sequence")

    if not is_sequence(structure):
        if len(flat_sequence) != 1:
            # Report "!= 1": the length may be 0 as well as greater than 1.
            raise ValueError(
                "Structure is a scalar but len(flat_sequence) == %d != 1" %
                len(flat_sequence))
        return flat_sequence[0]

    flat_structure = flatten(structure)
    if len(flat_structure) != len(flat_sequence):
        raise ValueError(
            "Could not pack sequence. Structure had %d elements, but flat_sequence "
            "had %d elements.  Structure: %s, flat_sequence: %s." %
            (len(flat_structure), len(flat_sequence), structure,
             flat_sequence))

    # The helper returns a plain list of top-level children; convert it into
    # `structure`'s own container type here.
    _, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
    return nest._sequence_like(structure, packed)  # pylint: disable=protected-access