Beispiel #1
0
 def testMapWithTuplePathsCompatibleStructures(
     self, s1, s2, check_types, expected):
   """Maps `s1` and `s2` jointly and compares against `expected`.

   Each output leaf is the pair `(tuple_path, sum_of_corresponding_leaves)`.
   """

   def pair_path_with_total(path, *leaves):
     # `leaves` holds one value per input structure at this path.
     total = sum(leaves)
     return path, total

   actual = nest.map_structure_with_tuple_paths(
       pair_path_with_total, s1, s2, check_types=check_types)
   self.assertEqual(expected, actual)
Beispiel #2
0
    def testMapWithTuplePathsCompatibleStructures(self, s1, s2, check_types,
                                                  expected):
        """Maps `s1` and `s2` jointly and compares against `expected`.

        Each output leaf is the pair `(tuple_path, sum_of_matching_leaves)`.
        """

        def pair_path_with_total(path, *leaves):
            # `leaves` holds one value per input structure at this path.
            total = sum(leaves)
            return path, total

        actual = nest.map_structure_with_tuple_paths(
            pair_path_with_total, s1, s2, check_types=check_types)
        self.assertEqual(expected, actual)
Beispiel #3
0
    def testWithTuplePaths(self):
        """Compares coroutine mapping with tuple paths to the nest mapping."""

        def return_path(path, _):
            return path
            # The unreachable `yield` is what makes this a generator function.
            yield  # pylint: disable=unreachable

        structure = {'y': 2., 'x': 3., 'z': 2.}
        mapped = nest_util.map_structure_coroutine(
            return_path, structure, _with_tuple_paths=True)
        _, returned = _get_yielded_and_returned_values(mapped)

        # Reference result: the same path-returning map without a coroutine.
        reference = nest.map_structure_with_tuple_paths(
            lambda path, _: path, structure)
        self.assertAllEqualNested(returned, reference)
Beispiel #4
0
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  validate_args=False):
    """Turns a (potentially nested) structure of dists into a single dist.

    Args:
      structure_of_distributions: instance of `tfd.Distribution`, or nested
        structure (tuple, list, dict, etc.) in which all leaves are
        `tfd.Distribution` instances.
      validate_args: Python `bool`. Whether the joint distribution should
        validate input with asserts. This imposes a runtime cost. If
        `validate_args` is `False`, and the inputs are invalid, correct
        behavior is not guaranteed. The flag is forwarded to every joint
        distribution built for nested sub-structures as well.
        Default value: `False`.

    Returns:
      distribution: instance of `tfd.Distribution` such that
        `distribution.sample()` is equivalent to
        `tf.nest.map_structure(lambda d: d.sample(),
        structure_of_distributions)`. If `structure_of_distributions` was
        indeed a structure (as opposed to a single `Distribution` instance),
        this will be a `JointDistribution` with the corresponding structure.

    Raises:
      TypeError: if any leaves of the input structure are not
        `tfd.Distribution` instances.
    """
    # Imported locally so this function does not depend on the module having
    # imported the `collections.abc` submodule. `collections.Mapping` (the
    # alias used previously) was removed in Python 3.10.
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

    # If input is already a Distribution, just return it.
    if dist_util.is_distribution_instance(structure_of_distributions):
        return structure_of_distributions

    # If this structure contains other structures (ie, has elements at depth
    # > 1), recursively turn them into JDs.
    element_depths = nest.map_structure_with_tuple_paths(
        lambda path, x: len(path), structure_of_distributions)
    if max(tf.nest.flatten(element_depths)) > 1:
        # Stop traversing at sub-structures whose leaves all sit deeper than
        # one level; each of those is collapsed into its own JD below.
        next_level_shallow_structure = nest.get_traverse_shallow_structure(
            traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
            structure=element_depths)
        structure_of_distributions = nest.map_structure_up_to(
            next_level_shallow_structure,
            # Forward `validate_args` so nested JDs honour the caller's
            # setting instead of silently falling back to the default.
            lambda substructure: independent_joint_distribution_from_structure(
                substructure, validate_args=validate_args),
            structure_of_distributions)

    # Otherwise, build a JD from the current structure: namedtuples and
    # mappings become a named JD, everything else a sequential JD.
    if (hasattr(structure_of_distributions, '_asdict')
            or isinstance(structure_of_distributions,
                          collections_abc.Mapping)):
        return joint_distribution_named.JointDistributionNamed(
            structure_of_distributions, validate_args=validate_args)
    return joint_distribution_sequential.JointDistributionSequential(
        structure_of_distributions, validate_args=validate_args)
Beispiel #5
0
    def testNestMapStructureWithPaths(self,
                                      structure,
                                      expected,
                                      expand_composites=True):
        """Checks string-path and tuple-path mapping against `expected`."""

        def label_with_string_path(path, leaf):
            return '%s:%s' % (path, leaf)

        string_path_result = nest.map_structure_with_paths(
            label_with_string_path,
            structure,
            expand_composites=expand_composites)
        self.assertEqual(string_path_result, expected)

        # Use the same test cases for map_structure_with_tuple_paths: joining
        # the tuple path with '/' is expected to give the same labels.
        def label_with_tuple_path(tuple_path, leaf):
            joined_path = '/'.join(str(v) for v in tuple_path)
            return '%s:%s' % (joined_path, leaf)

        tuple_path_result = nest.map_structure_with_tuple_paths(
            label_with_tuple_path,
            structure,
            expand_composites=expand_composites)
        self.assertEqual(tuple_path_result, expected)
  def testNestMapStructureWithPaths(self,
                                    structure,
                                    expected,
                                    expand_composites=True):
    """Checks string-path and tuple-path mapping against `expected`."""

    def label_with_string_path(path, leaf):
      return '%s:%s' % (path, leaf)

    string_path_result = nest.map_structure_with_paths(
        label_with_string_path, structure,
        expand_composites=expand_composites)
    self.assertEqual(string_path_result, expected)

    # Use the same test cases for map_structure_with_tuple_paths: joining the
    # tuple path with '/' is expected to give the same labels.
    def label_with_tuple_path(tuple_path, leaf):
      joined_path = '/'.join(str(v) for v in tuple_path)
      return '%s:%s' % (joined_path, leaf)

    tuple_path_result = nest.map_structure_with_tuple_paths(
        label_with_tuple_path, structure,
        expand_composites=expand_composites)
    self.assertEqual(tuple_path_result, expected)
  def testNestMapStructureWithTuplePaths(self):
    """Maps over nested composites and checks every (path, value) pair."""
    nested_composite = TestCompositeTensor(TestCompositeTensor(4, 5), 6)
    structure = [[TestCompositeTensor(1, 2, 3)], 100, {'y': nested_composite}]

    def tag_with_path(path, leaf):
      return (path, leaf)

    result = nest.map_structure_with_tuple_paths(
        tag_with_path, structure, expand_composites=True)

    # Expected output mirrors the input, with each leaf replaced by
    # (tuple_path, leaf); composite components are indexed by position.
    expected_first = [
        TestCompositeTensor(((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3))
    ]
    expected_second = ((1,), 100)
    expected_third = {
        'y':
            TestCompositeTensor(
                TestCompositeTensor(((2, 'y', 0, 0), 4), ((2, 'y', 0, 1), 5)),
                ((2, 'y', 1), 6))
    }
    expected = [expected_first, expected_second, expected_third]
    self.assertEqual(result, expected)
    def testNestMapStructureWithTuplePaths(self):
        """Maps over nested composites and checks every (path, value) pair."""
        nested_composite = TestCompositeTensor(TestCompositeTensor(4, 5), 6)
        structure = [
            [TestCompositeTensor(1, 2, 3)],
            100,
            {'y': nested_composite},
        ]

        def tag_with_path(path, leaf):
            return (path, leaf)

        result = nest.map_structure_with_tuple_paths(
            tag_with_path, structure, expand_composites=True)

        # Expected output mirrors the input, with each leaf replaced by
        # (tuple_path, leaf); composite components are indexed by position.
        expected_first = [
            TestCompositeTensor(((0, 0, 0), 1), ((0, 0, 1), 2),
                                ((0, 0, 2), 3))
        ]
        expected_second = ((1,), 100)
        expected_third = {
            'y':
            TestCompositeTensor(
                TestCompositeTensor(((2, 'y', 0, 0), 4), ((2, 'y', 0, 1), 5)),
                ((2, 'y', 1), 6))
        }
        expected = [expected_first, expected_second, expected_third]
        self.assertEqual(result, expected)
def build_trainable_linear_operator_block(
    operators,
    block_dims=None,
    batch_shape=(),
    dtype=None,
    name=None):
  """Builds a trainable blockwise `tf.linalg.LinearOperator`.

  This function returns a trainable blockwise `LinearOperator`. If `operators`
  is a flat list, it is interpreted as blocks along the diagonal of the
  structure and an instance of `tf.linalg.LinearOperatorBlockDiag` is returned.
  If `operators` is a singly nested list (a list of lists), then a
  `tf.linalg.LinearOperatorBlockLowerTriangular` instance is returned, with
  the block in row `i` column `j` (`i >= j`) given by `operators[i][j]`.
  The `operators` list may contain `LinearOperator` instances, `LinearOperator`
  subclasses, or callables that return `LinearOperator` instances. The
  dimensions of the blocks are given by `block_dims`; this argument may be
  omitted if `operators` contains only `LinearOperator` instances.

  ### Examples

  ```python
  # Build a 5x5 trainable `LinearOperatorBlockDiag` given `LinearOperator`
  # subclasses and `block_dims`.
  op = build_trainable_linear_operator_block(
    operators=(tf.linalg.LinearOperatorDiag,
               tf.linalg.LinearOperatorLowerTriangular),
    block_dims=[3, 2],
    dtype=tf.float32)

  # Build an 8x8 `LinearOperatorBlockLowerTriangular`, with a callable that
  # returns a `LinearOperator` in the upper left block, and `LinearOperator`
  # subclasses in the lower two blocks.
  op = build_trainable_linear_operator_block(
    operators=(
      (lambda shape, dtype: tf.linalg.LinearOperatorScaledIdentity(
         num_rows=shape[-1], multiplier=tf.Variable(1., dtype=dtype))),
      (tf.linalg.LinearOperatorFullMatrix,
       tf.linalg.LinearOperatorLowerTriangular)),
    block_dims=[4, 4],
    dtype=tf.float64)

  # Build a 6x6 `LinearOperatorBlockDiag` with batch shape `(4,)`. Since
  # `operators` contains only `LinearOperator` instances, the `block_dims`
  # argument is not necessary.
  op = build_trainable_linear_operator_block(
    operators=(tf.linalg.LinearOperatorDiag(tf.Variable(tf.ones((4, 3)))),
               tf.linalg.LinearOperatorFullMatrix([4.]),
               tf.linalg.LinearOperatorIdentity(2)))
  ```

  Args:
    operators: A list or tuple containing `LinearOperator` subclasses,
      `LinearOperator` instances, or callables returning `LinearOperator`
      instances. If the list is flat, a `tf.linalg.LinearOperatorBlockDiag`
      instance is returned. Otherwise, the list must be singly nested, with the
      first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is returned. Callables contained in the lists must take two
      arguments -- `shape`, the shape of the `tf.Variable` instantiating the
      `LinearOperator`, and `dtype`, the `tf.dtype` of the `LinearOperator`.
    block_dims: List or tuple of integers, representing the sizes of the blocks
      along one dimension of the (square) blockwise `LinearOperator`. If
      `operators` contains only `LinearOperator` instances, `block_dims` may be
      `None` and the dimensions are inferred.
    batch_shape: Batch shape of the `LinearOperator`.
    dtype: `tf.dtype` of the `LinearOperator`. If `None`, it is taken from
      `dtype_util.common_dtype` over the `LinearOperator` instances found in
      `operators`.
    name: str, name for `tf.name_scope`.

  Returns:
    Trainable instance of `tf.linalg.LinearOperatorBlockDiag` or
      `tf.linalg.LinearOperatorBlockLowerTriangular`.

  Raises:
    ValueError: if `block_dims` is `None` while `operators` contains anything
      other than `LinearOperator` instances, or if `operators` is neither a
      flat nor a uniformly singly-nested sequence.
  """
  with tf.name_scope(name or 'build_trainable_blockwise_tril_operator'):
    flat_operators = nest.flatten(operators)
    operator_instances = [op for op in flat_operators
                          if isinstance(op, tf.linalg.LinearOperator)]
    if (block_dims is None
        and len(operator_instances) < len(flat_operators)):
      # If `operator_instances` contains fewer elements than `operators`,
      # then some elements of `operators` are not instances of `LinearOperator`.
      raise ValueError('Argument `block_dims` must be defined unless '
                       '`operators` contains only `tf.linalg.LinearOperator` '
                       'instances.')

    batch_shape = ps.cast(batch_shape, tf.int32)
    if dtype is None:
      dtype = dtype_util.common_dtype(operator_instances)

    def convert_operator(path, op):
      """Returns `op` if already an instance, else builds it for `path`."""
      if isinstance(op, tf.linalg.LinearOperator):
        return op
      # Known classes map to a registered builder; anything else is assumed
      # to be a callable taking `(shape, dtype=...)` and is used directly.
      builder = _OPERATOR_BUILDERS.get(op, op)
      # A path of `(i,)` (flat input) or `(i, i)` (nested input) has a single
      # distinct index, i.e. the block lies on the diagonal.
      if len(set(path)) == 1:  # for operators on the diagonal
        return builder(
            ps.concat([batch_shape, [block_dims[path[0]]]], axis=0),
            dtype=dtype)
      return builder(
          ps.concat([batch_shape, [block_dims[path[0]], block_dims[path[1]]]],
                    axis=0),
          dtype=dtype)

    operator_blocks = nest.map_structure_with_tuple_paths(
        convert_operator, operators)
    # Materialize the paths: they are inspected twice below, and a one-shot
    # iterator would be (partially) consumed by the first `all(...)`, letting
    # malformed structures slip through the second check.
    paths = list(nest.yield_flat_paths(operators))
    if all(len(p) == 1 for p in paths):
      return tf.linalg.LinearOperatorBlockDiag(
          operator_blocks, is_non_singular=True)
    elif all(len(p) == 2 for p in paths):
      return tf.linalg.LinearOperatorBlockLowerTriangular(
          operator_blocks, is_non_singular=True)
    else:
      raise ValueError(
          'Argument `operators` must be a flat or singly-nested sequence.')
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  batch_ndims=None,
                                                  validate_args=False):
    """Collapses a (possibly nested) structure of dists into one dist.

    Args:
      structure_of_distributions: instance of `tfd.Distribution`, or nested
        structure (tuple, list, dict, etc.) in which all leaves are
        `tfd.Distribution` instances.
      batch_ndims: Optional integer `Tensor` number of leftmost batch
        dimensions shared across all members of the input structure. If this
        is specified, the returned joint distribution will be an autobatched
        distribution with the given batch rank, and all other dimensions
        absorbed into the event.
      validate_args: Python `bool`. Whether the joint distribution should
        validate input with asserts. This imposes a runtime cost. If
        `validate_args` is `False`, and the inputs are invalid, correct
        behavior is not guaranteed.
        Default value: `False`.

    Returns:
      distribution: instance of `tfd.Distribution` such that
        `distribution.sample()` is equivalent to
        `tf.nest.map_structure(lambda d: d.sample(),
        structure_of_distributions)`. If `structure_of_distributions` was
        indeed a structure (as opposed to a single `Distribution` instance),
        this will be a `JointDistribution` with the corresponding structure.

    Raises:
      TypeError: if any leaves of the input structure are not
        `tfd.Distribution` instances.
    """
    # A lone Distribution is returned as-is, except that batch dimensions in
    # excess of `batch_ndims` (when given) are folded into the event.
    if dist_util.is_distribution_instance(structure_of_distributions):
        if batch_ndims is None:
            return structure_of_distributions
        batch_rank = ps.rank_from_shape(
            structure_of_distributions.batch_shape_tensor())
        surplus_ndims = batch_rank - batch_ndims
        # The static value may be None (unknown), in which case we still wrap.
        if tf.get_static_value(surplus_ndims) == 0:
            return structure_of_distributions
        return independent.Independent(
            structure_of_distributions,
            reinterpreted_batch_ndims=surplus_ndims)

    def _depth_of_leaf(path, unused_leaf):
        return len(path)

    def _has_leaf_at_top_level(depths):
        return min(tf.nest.flatten(depths)) <= 1

    # When some leaf sits deeper than one level, first collapse each nested
    # sub-structure into its own JD by recursing with the same settings.
    leaf_depths = nest.map_structure_with_tuple_paths(
        _depth_of_leaf, structure_of_distributions)
    if max(tf.nest.flatten(leaf_depths)) > 1:
        shallow_structure = nest.get_traverse_shallow_structure(
            traverse_fn=_has_leaf_at_top_level, structure=leaf_depths)
        collapse_substructure = functools.partial(
            independent_joint_distribution_from_structure,
            batch_ndims=batch_ndims,
            validate_args=validate_args)
        structure_of_distributions = nest.map_structure_up_to(
            shallow_structure, collapse_substructure,
            structure_of_distributions)

    # Namedtuples and mappings give a named JD; anything else a sequential JD.
    is_named = (
        hasattr(structure_of_distributions, '_asdict')
        or isinstance(structure_of_distributions, collections.abc.Mapping))

    if batch_ndims is None:
        if is_named:
            return joint_distribution_named.JointDistributionNamed(
                structure_of_distributions, validate_args=validate_args)
        return joint_distribution_sequential.JointDistributionSequential(
            structure_of_distributions, validate_args=validate_args)

    # Use an autobatched JD if a specific batch rank was requested.
    if is_named:
        autobatched_cls = (
            joint_distribution_auto_batched.JointDistributionNamedAutoBatched)
    else:
        autobatched_cls = (
            joint_distribution_auto_batched.
            JointDistributionSequentialAutoBatched)
    return autobatched_cls(
        structure_of_distributions,
        batch_ndims=batch_ndims,
        use_vectorized_map=False,
        validate_args=validate_args)
Beispiel #11
0
 def testMapWithTuplePathsIncompatibleStructures(self, s1, s2, error_type):
     """Mapping over mismatched `s1` and `s2` must raise `error_type`."""

     def constant_zero(path, *leaves):
         # The mapped value is irrelevant; only the raised error is checked.
         del path, leaves
         return 0

     with self.assertRaises(error_type):
         nest.map_structure_with_tuple_paths(constant_zero, s1, s2)
Beispiel #12
0
 def testMapWithTuplePathsIncompatibleStructures(self, s1, s2, error_type):
   """Mapping over mismatched `s1` and `s2` must raise `error_type`."""

   def constant_zero(path, *leaves):
     # The mapped value is irrelevant; only the raised error is checked.
     del path, leaves
     return 0

   with self.assertRaises(error_type):
     nest.map_structure_with_tuple_paths(constant_zero, s1, s2)