Пример #1
0
    def testNestMapStructureWithTuplePathsUpTo(self, s1, s2, expected):
        """Checks `nest.map_structure_with_tuple_paths_up_to` on `s1`/`s2`.

        The mapped function renders each visited element as
        '<path components joined by "/">:<element>'. The mapping is run with
        `s1` as the first argument, `s2` as the input and
        `expand_composites=True`; the outcome must equal `expected`.

        Args:
          s1: Structure passed as the first (shallow) argument.
          s2: Structure passed as the mapped input.
          expected: Value the mapped result is compared against.
        """

        def render(path, element):
            # Stringify every path component before joining with '/'.
            joined_path = '/'.join(map(str, path))
            return '%s:%s' % (joined_path, element)

        mapped = nest.map_structure_with_tuple_paths_up_to(
            s1, render, s2, expand_composites=True)
        self.assertEqual(mapped, expected)
  def testNestMapStructureWithTuplePathsUpTo(self, s1, s2, expected):
    """Verifies path-aware mapping of `s2` against the structure `s1`.

    Every visited element is formatted together with its tuple path
    (components joined by '/', then ':' and the element), and the mapped
    structure is asserted equal to `expected`.

    Args:
      s1: Structure passed as the first (shallow) argument.
      s2: Structure passed as the mapped input.
      expected: Value the mapped result is compared against.
    """

    def describe(location, item):
      parts = [str(component) for component in location]
      return '%s:%s' % ('/'.join(parts), item)

    actual = nest.map_structure_with_tuple_paths_up_to(
        s1, describe, s2, expand_composites=True)
    self.assertEqual(actual, expected)
  def testNestMapStructureWithTuplePathsUpTo(self):
    """Checks tuple-path mapping through composite tensors.

    The shallow structure holds a flat `TestCompositeTensor(5, 6)` under 'y',
    while the input holds a composite nested inside a composite there. Per
    the expected value below, the inner `TestCompositeTensor(4, 5)` is handed
    to the mapped function as a single element at path (2, 'y', 0).
    """
    shallow = [
        [TestCompositeTensor(1, 2, 3)],
        100,
        {'y': TestCompositeTensor(5, 6)},
    ]
    deep = [
        [TestCompositeTensor(1, 2, 3)],
        100,
        {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)},
    ]

    def pair_with_path(tuple_path, element):
      return tuple_path, element

    mapped = nest.map_structure_with_tuple_paths_up_to(
        shallow, pair_with_path, deep, expand_composites=True)

    # Each component of a composite is replaced by a (path, component) pair.
    expected_first = TestCompositeTensor(
        ((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3))
    expected_y = TestCompositeTensor(
        ((2, 'y', 0), TestCompositeTensor(4, 5)), ((2, 'y', 1), 6))
    wanted = [[expected_first], ((1,), 100), {'y': expected_y}]
    self.assertEqual(mapped, wanted)
    def testNestMapStructureWithTuplePathsUpTo(self):
        """Checks tuple-path mapping when composites are expanded.

        The first structure has a flat composite under 'y'; the second has a
        composite nested inside a composite at the same position. The
        expected value shows the nested `TestCompositeTensor(4, 5)` paired
        with path (2, 'y', 0) as a whole, rather than being expanded further.
        """

        def tag(location, item):
            return location, item

        template = [
            [TestCompositeTensor(1, 2, 3)],
            100,
            {'y': TestCompositeTensor(5, 6)},
        ]
        data = [
            [TestCompositeTensor(1, 2, 3)],
            100,
            {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)},
        ]

        outcome = nest.map_structure_with_tuple_paths_up_to(
            template, tag, data, expand_composites=True)

        # Build the expectation piece by piece: list entry, scalar, dict.
        list_part = [
            TestCompositeTensor(
                ((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3)),
        ]
        scalar_part = ((1, ), 100)
        dict_part = {
            'y': TestCompositeTensor(
                ((2, 'y', 0), TestCompositeTensor(4, 5)),
                ((2, 'y', 1), 6)),
        }
        self.assertEqual(outcome, [list_part, scalar_part, dict_part])
Пример #5
0
def convert_to_nested_tensor(value,
                             dtype=None,
                             dtype_hint=None,
                             allow_packing=False,
                             as_shape_tensor=False,
                             name=None):
    """Converts the given `value` to a (structure of) `Tensor`.

    This function converts Python objects of various types to a (structure
    of) `Tensor` objects. It accepts `Tensor` objects, numpy arrays, Python
    lists, and Python scalars.

    Args:
      value: An object whose structure matches that of `dtype` and for which
        each leaf has a registered `Tensor` conversion function. It is passed
        through `coerce_structure(dtype, value)` before conversion.
      dtype: Optional structure of dtypes defining the structure of outputs
        and the `dtype` argument for nested calls to `convert_to_tensor`. If
        not nested, will be broadcasted to match the structure of
        `dtype_hint`.
      dtype_hint: Optional structure of dtypes defining the structure of
        outputs and the `dtype_hint` argument for nested calls to
        `convert_to_tensor`. If not nested, will be broadcasted to match the
        structure of `dtype`.
      allow_packing: Python `bool`, default `False`. If `True`, allow
        `convert_to_nested_tensor` to stack nested lists of Tensors along the
        leading dimension. Otherwise, raise.
      as_shape_tensor: Optional boolean; if `True`, leaves are converted with
        `prefer_static.convert_to_shape_tensor` instead of
        `tf.convert_to_tensor`, for JAX compatibility.
      name: Optional name to use if a new `Tensor` is created. If inputs are
        structured, elements are named according to
        '{name}/{path}.{to}.{elem}'.

    Returns:
      tensor: A (structure of) `Tensor` based on `value`.

    Raises:
      NotImplementedError: If `allow_packing` is `False` and a leaf (relative
        to `dtype`) is itself nested and contains a tensor or numpy array.
    """
    dtype_nested = nest.is_nested(dtype)
    hint_nested = nest.is_nested(dtype_hint)
    # When exactly one of dtype/dtype_hint is nested, the atomic one is
    # broadcast to the nested one's structure; when both are nested they
    # must share a structure.
    if dtype_nested:
        if hint_nested:
            nest.assert_same_structure(dtype, dtype_hint)
        else:
            dtype_hint = broadcast_structure(dtype, dtype_hint)
    elif hint_nested:
        dtype = broadcast_structure(dtype_hint, dtype)

    # Force the argument structure to match dtype.
    value = coerce_structure(dtype, value)

    def _has_tensor_leaf(structure):
        # Arrays are treated like Tensors for full parity in JAX backend.
        return any(
            tf.is_tensor(x) or isinstance(x, np.ndarray)
            for x in nest.flatten(structure))

    def _convert_leaf(path, leaf, leaf_dtype, leaf_hint, name=None):
        packing_needed = (
            not allow_packing and nest.is_nested(leaf)
            and _has_tensor_leaf(leaf))
        if packing_needed:
            raise NotImplementedError(
                ('Cannot convert a structure of tensors to a '
                 'single tensor. Saw {} at path {}.').format(leaf, path))
        if as_shape_tensor:
            return ps.convert_to_shape_tensor(
                leaf, leaf_dtype, leaf_hint, name=name)
        if 'KerasTensor' in str(type(leaf)):
            # This is a hack to detect symbolic Keras tensors to work around
            # b/206660667.  The issue was that symbolic Keras tensors would
            # break the Bijector cache on forward/inverse log det jacobian,
            # because tf.convert_to_tensor is not a no-op thereon.
            return leaf
        return tf.convert_to_tensor(leaf, leaf_dtype, leaf_hint, name=name)

    ### The remaining branches only differ in how results are named.
    # Unstructured call: use the provided name directly.
    if not nest.is_nested(dtype):
        return _convert_leaf((), value, dtype, dtype_hint, name=name)

    # Structured call without a name: no scope, and no struct-path is passed
    # on as the leaf name.
    if name is None:
        return nest.map_structure_with_tuple_paths_up_to(
            dtype, _convert_leaf, value, dtype, dtype_hint,
            check_types=False)

    # Structured call with a name: open a scope and name each member
    # according to "{path}.{to}.{element}".
    def _convert_named_leaf(path, *args):
        return _convert_leaf(path, *args, name='.'.join(map(str, path)))

    with tf.name_scope(name):
        return nest.map_structure_with_tuple_paths_up_to(
            dtype, _convert_named_leaf, value, dtype, dtype_hint,
            check_types=False)
Пример #6
0
def convert_to_nested_tensor(value, dtype=None, dtype_hint=None, name=None):
    """Converts the given `value` to a (structure of) `Tensor`.

    This function converts Python objects of various types to a (structure
    of) `Tensor` objects. It accepts `Tensor` objects, numpy arrays, Python
    lists, and Python scalars.

    Args:
      value: An object whose structure matches that of `dtype` and for which
        each leaf has a registered `Tensor` conversion function.
      dtype: Optional structure of dtypes defining the structure of outputs
        and the `dtype` argument for nested calls to `convert_to_tensor`. If
        not nested, will be broadcasted to match the structure of
        `dtype_hint`.
      dtype_hint: Optional structure of dtypes defining the structure of
        outputs and the `dtype_hint` argument for nested calls to
        `convert_to_tensor`. If not nested, will be broadcasted to match the
        structure of `dtype`.
      name: Optional name to use if a new `Tensor` is created. If inputs are
        structured, elements are named according to
        '{name}/{path}.{to}.{elem}'.

    Returns:
      tensor: A (structure of) `Tensor` based on `value`.

    Raises:
      NotImplementedError: If a leaf (relative to `dtype`) is itself nested
        and contains a tensor or numpy array.
    """
    dtype_nested = nest.is_nested(dtype)
    hint_nested = nest.is_nested(dtype_hint)
    # When exactly one of dtype/dtype_hint is nested, the atomic one is
    # broadcast to the nested one's structure; when both are nested they
    # must share a structure.
    if dtype_nested:
        if hint_nested:
            nest.assert_same_structure(dtype, dtype_hint)
        else:
            dtype_hint = broadcast_structure(dtype, dtype_hint)
    elif hint_nested:
        dtype = broadcast_structure(dtype_hint, dtype)

    def _convert_leaf(path, leaf, leaf_dtype, leaf_hint, name=None):
        if nest.is_nested(leaf):
            # Arrays are treated like Tensors for full parity in JAX backend.
            holds_tensor = any(
                tf.is_tensor(x) or isinstance(x, np.ndarray)
                for x in nest.flatten(leaf))
            if holds_tensor:
                raise NotImplementedError(
                    ('Cannot convert a structure of tensors to a '
                     'single tensor. Saw {} at path {}.').format(leaf, path))
        return tf.convert_to_tensor(leaf, leaf_dtype, leaf_hint, name=name)

    ### The remaining branches only differ in how results are named.
    # Unstructured call: use the provided name directly.
    if not nest.is_nested(dtype):
        return _convert_leaf((), value, dtype, dtype_hint, name=name)

    # Structured call without a name: no scope, and no struct-path is passed
    # on as the leaf name.
    if name is None:
        return nest.map_structure_with_tuple_paths_up_to(
            dtype, _convert_leaf, value, dtype, dtype_hint,
            check_types=False)

    # Structured call with a name: open a scope and name each member
    # according to "{path}.{to}.{element}".
    def _convert_named_leaf(path, *args):
        return _convert_leaf(path, *args, name='.'.join(map(str, path)))

    with tf.name_scope(name):
        return nest.map_structure_with_tuple_paths_up_to(
            dtype, _convert_named_leaf, value, dtype, dtype_hint,
            check_types=False)