def testNestYieldFlatPaths(self):
    """Checks `nest.yield_flat_paths` with and without composite expansion."""
    leading = [TestCompositeTensor(1, 2, 3)]
    keyed = {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)}
    nested = [leading, 100, keyed]

    # With expansion, the expected paths continue into each composite's
    # components (including the composite nested inside another composite).
    expanded_paths = list(
        nest.yield_flat_paths(nested, expand_composites=True))
    self.assertEqual(
        expanded_paths,
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (1,), (2, 'y', 0, 0),
         (2, 'y', 0, 1), (2, 'y', 1)])

    # Without expansion, each composite is expected to be a single leaf.
    opaque_paths = list(
        nest.yield_flat_paths(nested, expand_composites=False))
    self.assertEqual(opaque_paths, [(0, 0), (1,), (2, 'y')])
    def testNestYieldFlatPaths(self):
        """Checks `nest.yield_flat_paths` for both `expand_composites` values."""
        structure = [
            [TestCompositeTensor(1, 2, 3)],
            100,
            {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)},
        ]
        # Each case is (expand_composites flag, expected flat paths). The
        # expanded case is checked first, then the non-expanded one.
        cases = (
            (True, [(0, 0, 0), (0, 0, 1), (0, 0, 2), (1, ), (2, 'y', 0, 0),
                    (2, 'y', 0, 1), (2, 'y', 1)]),
            (False, [(0, 0), (1, ), (2, 'y')]),
        )
        for expand, wanted_paths in cases:
            actual_paths = list(
                nest.yield_flat_paths(structure, expand_composites=expand))
            self.assertEqual(actual_paths, wanted_paths)
  def testNestFlatten(self, structure, expected, paths, expand_composites=True):
    """Checks that the flatten helpers agree on leaves and paths.

    Args:
      structure: Nested structure handed to each `nest` helper.
      expected: Flat list of leaves that `nest.flatten` should return.
      paths: Tuple paths expected for those leaves, in the same order.
      expand_composites: Forwarded as `expand_composites` to every `nest` call.
    """
    options = dict(expand_composites=expand_composites)

    # Leaves only.
    self.assertEqual(nest.flatten(structure, **options), expected)

    # Leaves paired with their tuple paths.
    tuple_pairs = nest.flatten_with_tuple_paths(structure, **options)
    self.assertEqual(tuple_pairs, list(zip(paths, expected)))

    # Leaves paired with '/'-joined string versions of the same paths.
    joined_paths = []
    for path in paths:
      joined_paths.append('/'.join(map(str, path)))
    string_pairs = nest.flatten_with_joined_string_paths(structure, **options)
    self.assertEqual(string_pairs, list(zip(joined_paths, expected)))

    # Paths only.
    self.assertEqual(list(nest.yield_flat_paths(structure, **options)), paths)
示例#4
0
  def testNestFlatten(self, structure, expected, paths, expand_composites=True):
    """Checks that the flatten helpers agree on leaves and paths.

    Args:
      structure: Nested structure handed to each `nest` helper.
      expected: Flat list of leaves that `nest.flatten` should return.
      paths: Tuple paths expected for those leaves, in the same order.
      expand_composites: Forwarded as `expand_composites` to every `nest` call.
    """
    options = dict(expand_composites=expand_composites)

    # Leaves only.
    self.assertEqual(nest.flatten(structure, **options), expected)

    # Leaves paired with their tuple paths.
    tuple_pairs = nest.flatten_with_tuple_paths(structure, **options)
    self.assertEqual(tuple_pairs, list(zip(paths, expected)))

    # Leaves paired with '/'-joined string versions of the same paths.
    joined_paths = []
    for path in paths:
      joined_paths.append('/'.join(map(str, path)))
    string_pairs = nest.flatten_with_joined_string_paths(structure, **options)
    self.assertEqual(string_pairs, list(zip(joined_paths, expected)))

    # Paths only.
    self.assertEqual(list(nest.yield_flat_paths(structure, **options)), paths)
示例#5
0
def _create_pseudo_names(tensors, prefix):
  """Builds default {input | output} names from the flat paths of `tensors`.

  Warning: this function should only be used to define default
  names for `Metrics` and `SavedModel`. No other use cases should
  rely on a `Model`'s input or output names.

  Each name is the '_'-joined flat path of a leaf, with integer path
  components shifted to be one-based. Names whose path starts with an integer
  are additionally given `prefix`.

  Example with dict:

  `{'a': [x1, x2], 'b': x3}` becomes:
  `['a_1', 'a_2', 'b']`

  Example with list (and `prefix='output_'`):

  `[x, y]` becomes:
  `['output_1', 'output_2']`

  Arguments:
    tensors: `Model`'s outputs or inputs.
    prefix: 'output_' for outputs, 'input_' for inputs.

  Returns:
    Flattened list of pseudo names.
  """

  def to_one_based(component):
    # Integer components are shifted so names start at "output_1" instead of
    # "output_0"; every other component is passed through unchanged.
    return component + 1 if isinstance(component, int) else component

  one_based_paths = nest.map_structure(
      to_one_based, list(nest.yield_flat_paths(tensors)))

  def name_for(path):
    if not path:
      # An empty path means `tensors` is a single, un-nested value.
      return prefix + '1'
    joined = '_'.join(map(str, path))
    # Only paths that start with a positional index receive the prefix.
    return prefix + joined if isinstance(path[0], int) else joined

  return [name_for(path) for path in one_based_paths]
示例#6
0
 def testYieldFlatStringPaths(self):
   """Checks `nest.yield_flat_paths` on a table of small nested inputs."""
   # Each entry is (input structure, expected list of flat paths).
   cases = (
       ([], []),
       (3, [()]),
       ([3], [(0,)]),
       ({"a": 3}, [("a",)]),
       ({"a": {"b": 4}}, [("a", "b")]),
       ([{"a": 2}], [(0, "a")]),
       ([{"a": [2]}], [(0, "a", 0)]),
       ([{"a": [(23, 42)]}], [(0, "a", 0, 0), (0, "a", 0, 1)]),
       ([{"a": ([23], 42)}], [(0, "a", 0, 0), (0, "a", 1)]),
       ({"a": {"a": 2}, "c": [[[4]]]}, [("a", "a"), ("c", 0, 0, 0)]),
       ({"0": [{"1": 23}]}, [("0", 0, "1")]),
   )
   for structure, wanted_paths in cases:
     self.assertEqual(list(nest.yield_flat_paths(structure)), wanted_paths)
示例#7
0
 def testYieldFlatStringPaths(self):
   """Checks `nest.yield_flat_paths` on a table of small nested inputs."""
   # Each entry is (input structure, expected list of flat paths).
   cases = (
       ([], []),
       (3, [()]),
       ([3], [(0,)]),
       ({"a": 3}, [("a",)]),
       ({"a": {"b": 4}}, [("a", "b")]),
       ([{"a": 2}], [(0, "a")]),
       ([{"a": [2]}], [(0, "a", 0)]),
       ([{"a": [(23, 42)]}], [(0, "a", 0, 0), (0, "a", 0, 1)]),
       ([{"a": ([23], 42)}], [(0, "a", 0, 0), (0, "a", 1)]),
       ({"a": {"a": 2}, "c": [[[4]]]}, [("a", "a"), ("c", 0, 0, 0)]),
       ({"0": [{"1": 23}]}, [("0", 0, "1")]),
   )
   for structure, wanted_paths in cases:
     self.assertEqual(list(nest.yield_flat_paths(structure)), wanted_paths)
示例#8
0
def create_output_names(y_pred):
    """Derives output names from the flat paths of a subclassed Model's outputs.

    These names are used for naming `Metric`s.

    Each name is the '_'-joined flat path of a leaf, with integer path
    components shifted to be one-based. Names whose path starts with an
    integer are prefixed with 'output_'.

    Example with dict:

    `{'a': [x1, x2], 'b': x3}` becomes:
    `['a_1', 'a_2', 'b']`

    Example with list:

    `[x, y]` becomes:
    `['output_1', 'output_2']`

    Arguments:
      y_pred: `Model`'s outputs.

    Returns:
      Flattened list of output names.
    """
    def to_one_based(component):
        # Integer components are shifted so names start at "output_1"
        # instead of "output_0"; other components pass through unchanged.
        return component + 1 if isinstance(component, int) else component

    one_based_paths = nest.map_structure(
        to_one_based, list(nest.yield_flat_paths(y_pred)))

    def name_for(path):
        if not path:
            # An empty path means `y_pred` is a single, un-nested value.
            return 'output_1'
        joined = '_'.join(map(str, path))
        # Only paths starting with a positional index receive the prefix.
        return 'output_' + joined if isinstance(path[0], int) else joined

    return [name_for(path) for path in one_based_paths]
def build_trainable_linear_operator_block(
    operators,
    block_dims=None,
    batch_shape=(),
    dtype=None,
    name=None):
  """Builds a trainable blockwise `tf.linalg.LinearOperator`.

  This function returns a trainable blockwise `LinearOperator`. If `operators`
  is a flat list, it is interpreted as blocks along the diagonal of the
  structure and an instance of `tf.linalg.LinearOperatorBlockDiag` is returned.
  If `operators` is a list of lists (i.e. singly nested), then a
  `tf.linalg.LinearOperatorBlockLowerTriangular` instance is returned, with
  the block in row `i` column `j` (`i >= j`) given by `operators[i][j]`.
  The `operators` list may contain `LinearOperator` instances, `LinearOperator`
  subclasses, or callables that return `LinearOperator` instances. The
  dimensions of the blocks are given by `block_dims`; this argument may be
  omitted if `operators` contains only `LinearOperator` instances.

  ### Examples

  ```python
  # Build a 5x5 trainable `LinearOperatorBlockDiag` given `LinearOperator`
  # subclasses and `block_dims`.
  op = build_trainable_linear_operator_block(
    operators=(tf.linalg.LinearOperatorDiag,
               tf.linalg.LinearOperatorLowerTriangular),
    block_dims=[3, 2],
    dtype=tf.float32)

  # Build an 8x8 `LinearOperatorBlockLowerTriangular`, with a callable that
  # returns a `LinearOperator` in the upper left block, and `LinearOperator`
  # subclasses in the lower two blocks. Note that every row, including the
  # first (length-1) row, must itself be a sequence.
  op = build_trainable_linear_operator_block(
    operators=(
      (lambda shape, dtype: tf.linalg.LinearOperatorScaledIdentity(
         num_rows=shape[-1], multiplier=tf.Variable(1., dtype=dtype)),),
      (tf.linalg.LinearOperatorFullMatrix,
       tf.linalg.LinearOperatorLowerTriangular)),
    block_dims=[4, 4],
    dtype=tf.float64)

  # Build a 6x6 `LinearOperatorBlockDiag` with batch shape `(4,)`. Since
  # `operators` contains only `LinearOperator` instances, the `block_dims`
  # argument is not necessary.
  op = build_trainable_linear_operator_block(
    operators=(tf.linalg.LinearOperatorDiag(tf.Variable(tf.ones((4, 3)))),
               tf.linalg.LinearOperatorFullMatrix([4.]),
               tf.linalg.LinearOperatorIdentity(2)))
  ```

  Args:
    operators: A list or tuple containing `LinearOperator` subclasses,
      `LinearOperator` instances, or callables returning `LinearOperator`
      instances. If the list is flat, a `tf.linalg.LinearOperatorBlockDiag`
      instance is returned. Otherwise, the list must be singly nested, with the
      first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is returned. Callables contained in the lists must take two
      arguments -- `shape`, the shape of the `tf.Variable` instantiating the
      `LinearOperator`, and `dtype`, the `tf.dtype` of the `LinearOperator`.
    block_dims: List or tuple of integers, representing the sizes of the blocks
      along one dimension of the (square) blockwise `LinearOperator`. If
      `operators` contains only `LinearOperator` instances, `block_dims` may be
      `None` and the dimensions are inferred.
    batch_shape: Batch shape of the `LinearOperator`.
    dtype: `tf.dtype` of the `LinearOperator`. If `None`, it is taken from
      `dtype_util.common_dtype` of the `LinearOperator` instances in
      `operators`.
    name: str, name for `tf.name_scope`.

  Returns:
    Trainable instance of `tf.linalg.LinearOperatorBlockDiag` or
      `tf.linalg.LinearOperatorBlockLowerTriangular`.

  Raises:
    ValueError: if `block_dims` is `None` while `operators` contains anything
      other than `LinearOperator` instances, or if `operators` is neither a
      flat nor a uniformly singly-nested sequence.
  """
  with tf.name_scope(name or 'build_trainable_blockwise_tril_operator'):
    flat_operators = nest.flatten(operators)
    operator_instances = [op for op in flat_operators
                          if isinstance(op, tf.linalg.LinearOperator)]
    if block_dims is None and len(operator_instances) < len(flat_operators):
      # If `operator_instances` contains fewer elements than `operators`,
      # then some elements of `operators` are not instances of `LinearOperator`.
      raise ValueError('Argument `block_dims` must be defined unless '
                       '`operators` contains only `tf.linalg.LinearOperator` '
                       'instances.')

    batch_shape = ps.cast(batch_shape, tf.int32)
    if dtype is None:
      dtype = dtype_util.common_dtype(operator_instances)

    def convert_operator(path, op):
      """Returns `op` as-is if already built, else builds it for `path`."""
      if isinstance(op, tf.linalg.LinearOperator):
        return op
      # Fall back to `op` itself (a user callable) when no builder is
      # registered for it.
      builder = _OPERATOR_BUILDERS.get(op, op)
      if len(set(path)) == 1:  # for operators on the diagonal
        return builder(
            ps.concat([batch_shape, [block_dims[path[0]]]], axis=0),
            dtype=dtype)
      return builder(
          ps.concat([batch_shape, [block_dims[path[0]], block_dims[path[1]]]],
                    axis=0),
          dtype=dtype)

    operator_blocks = nest.map_structure_with_tuple_paths(
        convert_operator, operators)
    # Materialize the paths: `yield_flat_paths` is a one-shot generator, so
    # the first `all(...)` below would otherwise consume part (or all) of it
    # and the second check would only see the leftover paths, wrongly
    # accepting mixed-depth structures.
    paths = list(nest.yield_flat_paths(operators))
    if all(len(p) == 1 for p in paths):
      return tf.linalg.LinearOperatorBlockDiag(
          operator_blocks, is_non_singular=True)
    elif all(len(p) == 2 for p in paths):
      return tf.linalg.LinearOperatorBlockLowerTriangular(
          operator_blocks, is_non_singular=True)
    else:
      raise ValueError(
          'Argument `operators` must be a flat or singly-nested sequence.')
示例#10
0
def flatten_with_tuple_paths(structure):
  """Returns a list of `(tuple_path, leaf)` pairs for `structure`."""
  leaf_paths = nest.yield_flat_paths(structure)
  leaves = nest.flatten(structure)
  # Pair each path with the leaf at the same position in the flat list.
  return list(zip(leaf_paths, leaves))
示例#11
0
def flatten_with_tuple_paths(structure):
  """Returns a list of `(tuple_path, leaf)` pairs for `structure`."""
  leaf_paths = nest.yield_flat_paths(structure)
  leaves = nest.flatten(structure)
  # Pair each path with the leaf at the same position in the flat list.
  return list(zip(leaf_paths, leaves))
示例#12
0
def map_structure_coroutine(
        coroutine,
        *structures,
        _expand_composites=False,  # pylint: disable=invalid-name
        _up_to=UNSPECIFIED,  # pylint: disable=invalid-name
        _with_tuple_paths=False,  # pylint: disable=invalid-name
        **named_structures):
    # pylint: disable=g-doc-return-or-yield
    """Invokes a coroutine multiple times with args from provided structures.

  This is semantically identical to `map_structure_with_named_args`, except
  that the first argument is a generator or coroutine (a callable whose body
  contains `yield` statements) rather than a function. This is invoked with
  arguments from the provided structure(s), thus defining an outer generator/
  coroutine that `yield`s values in sequence from each call to the inner
  `coroutine`.

  The argument structures are traversed, and the coroutine is invoked, in
  the order defined by `tf.nest.flatten`. A stripped-down implementation of
  the core logic is as follows:

  ```python
  def map_structure_coroutine(coroutine, *structures):
    flat_results = []
    for args in zip(*[tf.nest.flatten(s) for s in structures]):
      retval = yield from coroutine(*args)
      flat_results.append(retval)
    return tf.nest.pack_sequence_as(structures[0], flat_results)
  ```

  Args:
    coroutine: a generator/coroutine callable that accepts one or more named
      arguments.
    *structures: Structures of arguments passed positionally to `coroutine`.
    _expand_composites: Forwarded as `expand_composites` to every `nest` call
      made here (flattening the arguments, generating tuple paths, and packing
      the results), so that all three agree on what counts as a leaf.
    _up_to: Optional shallow structure to map up to. If provided, the argument
      structures are flattened up to it and the results are packed into it;
      otherwise the first argument structure plays that role.
      Default value: `UNSPECIFIED`.
    _with_tuple_paths: Python bool. If `True`, the first argument to `coroutine`
      is a tuple path to the current leaf of the argument structure(s).
      Default value: `False`.
    **named_structures: Structures of arguments passed by name to `coroutine`.
  Yields:
    Values `yield`ed by each invocation of `coroutine`, with invocations in
      order corresponding to `tf.nest.flatten`.
  Returns:
    A new structure matching that of the input structures (or the shallow
      structure `_up_to`, if specified), in which each element is the return
      value from applying `coroutine` to the corresponding elements of the input
      structures.

  ## Examples

  A JointDistributionCoroutine may define a reusable submodel as its own
  coroutine, for example:

  ```python
  def horseshoe_prior(path, scale):
    # Auxiliary-variable representation of a horseshoe prior on sparse weights.
    name = ','.join(path)
    z = yield tfd.HalfCauchy(loc=0., scale=scale, name=name + '_z')
    w_noncentered = yield tfd.Normal(
        loc=0., scale=z, name=name + '_w_noncentered')
    return z * w_noncentered
  ```

  Note that this submodel yields two auxiliary random variables, and returns the
  sampled weight as a third value.

  Using `map_structure_coroutine` we can define a structure of such submodels,
  and collect their return values:

  ```
  @tfd.JointDistributionCoroutineAutoBatched
  def model():
    weights = yield from nest_util.map_structure_coroutine(
        horseshoe_prior,
        scale={'a': tf.ones([5]) * 100., 'b': tf.ones([2]) * 1e-2},
        _with_tuple_paths=True)
    # ==> `weights` is a dict of weight values.
    yield tfd.Deterministic(
        tf.sqrt(tf.norm(weights['a'])**2 + tf.norm(weights['b'])**2),
        name='weights_norm')

  print(model.event_shape)
  # ==> StructTuple(
  #       a_z=TensorShape([5]),
  #       a_w_noncentered=TensorShape([5]),
  #       b_z=TensorShape([2]),
  #       b_w_noncentered=TensorShape([2]),
  #       weights_norm=TensorShape([]))
  ```
  """
    # pylint: enable=g-doc-return-or-yield

    # Split keyword structures into parallel tuples of names and values so the
    # values can be flattened alongside the positional structures.
    names, named_structure_values = (zip(
        *named_structures.items()) if named_structures else ((), ()))
    all_structures = structures + named_structure_values
    result_structure = all_structures[0] if _up_to is UNSPECIFIED else _up_to
    # `_expand_composites` is forwarded here (and to `pack_sequence_as` below)
    # so the flattened arguments line up one-to-one with the tuple paths, which
    # are generated with the same flag. Otherwise the `zip` below would
    # silently truncate to the shorter list and pair leaves with wrong paths.
    flat_arg_structures = [
        nest.flatten_up_to(
            result_structure, s, expand_composites=_expand_composites)
        for s in all_structures
    ]

    if _with_tuple_paths:
        # Pass tuple paths as a first positional arg (before any provided args).
        flat_paths = nest.yield_flat_paths(
            result_structure, expand_composites=_expand_composites)
        flat_arg_structures = [list(flat_paths)] + flat_arg_structures
        num_positional_args = 1 + len(structures)
    else:
        num_positional_args = len(structures)

    flat_results = []
    for leaf_values in zip(*flat_arg_structures):
        # The leading entries are positional args (optionally preceded by the
        # path); the remainder are matched back to their keyword names.
        result = yield from coroutine(
            *leaf_values[:num_positional_args],
            **dict(zip(names, leaf_values[num_positional_args:])))
        flat_results.append(result)

    return nest.pack_sequence_as(
        result_structure, flat_results, expand_composites=_expand_composites)
def _trainable_linear_operator_block(operators,
                                     block_dims=None,
                                     batch_shape=(),
                                     dtype=None,
                                     name=None):
    """Builds a trainable blockwise `tf.linalg.LinearOperator`.

  This function returns a trainable blockwise `LinearOperator`. If `operators`
  is a flat list, it is interpreted as blocks along the diagonal of the
  structure and an instance of `tf.linalg.LinearOperatorBlockDiag` is returned.
  If `operators` is a list of lists (i.e. singly nested), then a
  `tf.linalg.LinearOperatorBlockLowerTriangular` instance is returned, with
  the block in row `i` column `j` (`i >= j`) given by `operators[i][j]`.
  The `operators` list may contain `LinearOperator` instances, `LinearOperator`
  subclasses, or callables defining custom constructors (see example below).
  The dimensions of the blocks are given by `block_dims`; this argument may be
  omitted if `operators` contains only `LinearOperator` instances.

  Args:
    operators: A list or tuple containing `LinearOperator` subclasses,
      `LinearOperator` instances, and/or callables returning
      `(init_fn, apply_fn)` pairs. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance is returned. Otherwise, the
      list must be singly nested, with the
      first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is returned. Callables contained in the lists must take two
      arguments -- `shape`, the shape of the parameter instantiating the
      `LinearOperator`, and `dtype`, the `tf.dtype` of the `LinearOperator` --
      and return a further pair of callables representing a stateless trainable
      operator (see example below).
    block_dims: List or tuple of integers, representing the sizes of the blocks
      along one dimension of the (square) blockwise `LinearOperator`. If
      `operators` contains only `LinearOperator` instances, `block_dims` may be
      `None` and the dimensions are inferred.
    batch_shape: Batch shape of the `LinearOperator`.
    dtype: `tf.dtype` of the `LinearOperator`. If `None`, it is taken from
      `dtype_util.common_dtype` of the `LinearOperator` instances in
      `operators`.
    name: str, name for `tf.name_scope`.
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.
  Returns:
    Instance of `tf.linalg.LinearOperatorBlockDiag` or
      `tf.linalg.LinearOperatorBlockLowerTriangular` (as the generator's
      return value).
  Raises:
    ValueError: if `block_dims` is `None` while `operators` contains anything
      other than `LinearOperator` instances, or if `operators` is neither a
      flat nor a uniformly singly-nested sequence.

  ### Examples

  To build a 5x5 trainable `LinearOperatorBlockDiag` given `LinearOperator`
  subclasses and `block_dims`:

  ```python
  op = build_trainable_linear_operator_block(
    operators=(tf.linalg.LinearOperatorDiag,
               tf.linalg.LinearOperatorLowerTriangular),
    block_dims=[3, 2],
    dtype=tf.float32)
  ```

  If `operators` contains only `LinearOperator` instances, the `block_dims`
  argument is not necessary:

  ```python
  # Builds a 6x6 `LinearOperatorBlockDiag` with batch shape `(4,)`.
  op = build_trainable_linear_operator_block(
    operators=(tf.linalg.LinearOperatorDiag(tf.Variable(tf.ones((4, 3)))),
               tf.linalg.LinearOperatorFullMatrix([4.]),
               tf.linalg.LinearOperatorIdentity(2)))

  ```

  A custom operator constructor may be specified as a callable taking
  arguments `shape` and `dtype`, and returning a pair of callables
  `(init_fn, apply_fn)` describing a parameterized operator, with the following
  signatures:

  ```python
  raw_parameters = init_fn(seed)
  linear_operator = apply_fn(raw_parameters)
  ```

  For example, to define a custom initialization for a diagonal operator:

  ```python
  import functools

  def diag_operator_with_uniform_initialization(shape, dtype):
    init_fn = functools.partial(
        samplers.uniform, shape, maxval=2., dtype=dtype)
    apply_fn = lambda scale_diag: tf.linalg.LinearOperatorDiag(
        scale_diag, is_non_singular=True)
    return init_fn, apply_fn

  # Build an 8x8 `LinearOperatorBlockLowerTriangular`, with our custom diagonal
  # operator in the upper left block, and `LinearOperator` subclasses in the
  # lower two blocks. Note that every row, including the first (length-1) row,
  # must itself be a sequence.
  op = build_trainable_linear_operator_block(
    operators=((diag_operator_with_uniform_initialization,),
               (tf.linalg.LinearOperatorFullMatrix,
                tf.linalg.LinearOperatorLowerTriangular)),
    block_dims=[4, 4],
    dtype=tf.float64)
  ```

  """
    with tf.name_scope(name or 'trainable_linear_operator_block'):
        flat_operators = nest.flatten(operators)
        operator_instances = [
            op for op in flat_operators
            if isinstance(op, tf.linalg.LinearOperator)
        ]
        if (block_dims is None
                and len(operator_instances) < len(flat_operators)):
            # If `operator_instances` contains fewer elements than `operators`,
            # then some elements of `operators` are not instances of `LinearOperator`.
            raise ValueError(
                'Argument `block_dims` must be defined unless '
                '`operators` contains only `tf.linalg.LinearOperator` '
                'instances.')

        batch_shape = ps.cast(batch_shape, tf.int32)
        if dtype is None:
            dtype = dtype_util.common_dtype(operator_instances)

        def convert_operator(path, op):
            """Yields the parameters for, then returns, the block at `path`."""
            if isinstance(op, tf.linalg.LinearOperator):
                return op
            if len(set(path)) == 1:  # for operators on the diagonal
                shape = ps.concat([batch_shape, [block_dims[path[0]]]], axis=0)
            else:
                shape = ps.concat(
                    [batch_shape, [block_dims[path[0]], block_dims[path[1]]]],
                    axis=0)
            if op in _OPERATOR_COROUTINES:
                operator = yield from _OPERATOR_COROUTINES[op](shape=shape,
                                                               dtype=dtype)
            else:  # Custom stateless constructor.
                init_fn, apply_fn = op(shape=shape, dtype=dtype)
                raw_params = yield trainable_state_util.Parameter(init_fn)
                operator = apply_fn(raw_params)
            return operator

        # Build a structure of component trainable LinearOperators.
        operator_blocks = yield from nest_util.map_structure_coroutine(
            convert_operator, operators, _with_tuple_paths=True)
        # Materialize the paths: `yield_flat_paths` is a one-shot generator,
        # so the first `all(...)` below would otherwise consume part (or all)
        # of it and the second check would only see the leftover paths,
        # wrongly accepting mixed-depth structures.
        paths = list(nest.yield_flat_paths(operators))
        if all(len(p) == 1 for p in paths):
            return tf.linalg.LinearOperatorBlockDiag(operator_blocks,
                                                     is_non_singular=True)
        elif all(len(p) == 2 for p in paths):
            return tf.linalg.LinearOperatorBlockLowerTriangular(
                operator_blocks, is_non_singular=True)
        else:
            raise ValueError(
                'Argument `operators` must be a flat or singly-nested sequence.'
            )