def testNestMapStructureWithTuplePathsUpTo(self, s1, s2, expected):
  """Checks path-aware mapping of `s2` up to the shallow structure `s1`.

  Every mapped leaf is rendered as '<path>:<leaf>', where <path> is the
  leaf's tuple path with components joined by '/'. `expand_composites=True`
  is passed so composite tensors are traversed as well.
  """

  def _render(path_components, leaf):
    joined_path = '/'.join(map(str, path_components))
    return '%s:%s' % (joined_path, leaf)

  actual = nest.map_structure_with_tuple_paths_up_to(
      s1, _render, s2, expand_composites=True)
  self.assertEqual(actual, expected)
def testNestMapStructureWithTuplePathsUpTo(self, s1, s2, expected):
  """Maps `s2` up to `s1`'s leaves, formatting each leaf as 'a/b/c:leaf'.

  The path prefix is the '/'-joined tuple path of the leaf. Composite
  tensors are expanded during the traversal (`expand_composites=True`).
  """

  def format_leaf(tuple_path, x):
    path_str = '/'.join([str(component) for component in tuple_path])
    return '%s:%s' % (path_str, x)

  self.assertEqual(
      nest.map_structure_with_tuple_paths_up_to(
          s1, format_leaf, s2, expand_composites=True),
      expected)
def testNestMapStructureWithTuplePathsUpTo(self):
  """Maps (path, leaf) pairs through composites, stopping at `shallow`'s leaves.

  Under key 'y' the shallow structure holds a flat composite. The nested
  `TestCompositeTensor(4, 5)` in the deeper structure therefore reaches the
  mapped function as a single, unexpanded leaf at path (2, 'y', 0).
  """
  shallow = [[TestCompositeTensor(1, 2, 3)], 100,
             {'y': TestCompositeTensor(5, 6)}]
  deep = [[TestCompositeTensor(1, 2, 3)], 100,
          {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)}]

  def pair_with_path(path, leaf):
    return path, leaf

  mapped = nest.map_structure_with_tuple_paths_up_to(
      shallow, pair_with_path, deep, expand_composites=True)

  first_composite = TestCompositeTensor(
      ((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3))
  y_composite = TestCompositeTensor(
      ((2, 'y', 0), TestCompositeTensor(4, 5)), ((2, 'y', 1), 6))
  self.assertEqual(mapped, [[first_composite], ((1,), 100),
                            {'y': y_composite}])
def testNestMapStructureWithTuplePathsUpTo(self):
  """Checks that `(tuple_path, value)` pairs are produced up to `s1`'s leaves.

  Composite tensors are expanded, so their components get their own paths.
  Under 'y' the shallow structure `s1` stops at the outer composite. The
  inner `TestCompositeTensor(4, 5)` from `s2` is therefore paired with its
  path as a whole.
  """

  def identity_with_path(tuple_path, value):
    return (tuple_path, value)

  s1 = [
      [TestCompositeTensor(1, 2, 3)],
      100,
      {'y': TestCompositeTensor(5, 6)},
  ]
  s2 = [
      [TestCompositeTensor(1, 2, 3)],
      100,
      {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)},
  ]
  result = nest.map_structure_with_tuple_paths_up_to(
      s1, identity_with_path, s2, expand_composites=True)

  expected_list_item = [
      TestCompositeTensor(((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3))
  ]
  expected_scalar_item = ((1,), 100)
  expected_dict_item = {
      'y': TestCompositeTensor(((2, 'y', 0), TestCompositeTensor(4, 5)),
                               ((2, 'y', 1), 6)),
  }
  self.assertEqual(
      result, [expected_list_item, expected_scalar_item, expected_dict_item])
def convert_to_nested_tensor(value, dtype=None, dtype_hint=None,
                             allow_packing=False, as_shape_tensor=False,
                             name=None):
  """Converts `value` into a `Tensor`, or a nested structure of `Tensor`s.

  Accepted inputs include `Tensor` objects, numpy arrays, Python lists and
  Python scalars. The structure of the output is taken from `dtype` and/or
  `dtype_hint`.

  Args:
    value: An object whose structure matches that of `dtype` and whose
      leaves each have a registered `Tensor` conversion function.
    dtype: Optional structure of dtypes. It defines the output structure and
      supplies the `dtype` argument of each per-leaf `convert_to_tensor`
      call. If it is not nested, it is broadcast to the structure of
      `dtype_hint`.
    dtype_hint: Optional structure of dtypes. It defines the output structure
      and supplies the `dtype_hint` argument of each per-leaf
      `convert_to_tensor` call. If it is not nested, it is broadcast to the
      structure of `dtype`.
    allow_packing: Python `bool`, default `False`. If `True`, a nested list
      of Tensors found at a leaf position is handed to the conversion
      function, which may stack it along the leading dimension. Otherwise,
      such input raises `NotImplementedError`.
    as_shape_tensor: Python `bool`, default `False`. If `True`, leaves are
      converted with `prefer_static.convert_to_shape_tensor` instead of
      `tf.convert_to_tensor`, for JAX compatibility.
    name: Optional name to use if a new `Tensor` is created. For structured
      inputs, elements are named according to '{name}/{path}.{to}.{elem}'.

  Returns:
    tensor: A (structure of) `Tensor` based on `value`.

  Raises:
    NotImplementedError: If `allow_packing` is `False` and a leaf position of
      `dtype` holds a nested value that contains a `Tensor` or `np.ndarray`.
  """
  dtype_is_nested = nest.is_nested(dtype)
  hint_is_nested = nest.is_nested(dtype_hint)
  # When exactly one of `dtype` / `dtype_hint` is nested, broadcast the atom
  # to its structure. When both are nested, they must already agree.
  if dtype_is_nested:
    if hint_is_nested:
      nest.assert_same_structure(dtype, dtype_hint)
    else:
      dtype_hint = broadcast_structure(dtype, dtype_hint)
  elif hint_is_nested:
    dtype = broadcast_structure(dtype_hint, dtype)

  # Force the containers of `value` to match the structure of `dtype`.
  value = coerce_structure(dtype, value)

  def _convert_leaf(path, leaf, leaf_dtype, leaf_hint, name=None):
    # Converts one leaf position of `dtype`; `path` only feeds the error text.
    if not allow_packing and nest.is_nested(leaf):
      # Treat arrays like Tensors for full parity in JAX backend.
      holds_tensor = any(
          tf.is_tensor(x) or isinstance(x, np.ndarray)
          for x in nest.flatten(leaf))
      if holds_tensor:
        raise NotImplementedError(
            ('Cannot convert a structure of tensors to a '
             'single tensor. Saw {} at path {}.').format(leaf, path))
    if as_shape_tensor:
      return ps.convert_to_shape_tensor(leaf, leaf_dtype, leaf_hint, name=name)
    if 'KerasTensor' in str(type(leaf)):
      # This is a hack to detect symbolic Keras tensors to work around
      # b/206660667. The issue was that symbolic Keras tensors would
      # break the Bijector cache on forward/inverse log det jacobian,
      # because tf.convert_to_tensor is not a no-op thereon.
      return leaf
    return tf.convert_to_tensor(leaf, leaf_dtype, leaf_hint, name=name)

  # Everything below only decides how the results are named.
  if not nest.is_nested(dtype):
    # Unstructured call: pass the caller's name straight through.
    return _convert_leaf((), value, dtype, dtype_hint, name=name)

  if name is None:
    # Structured call without a name: no scope and no per-element names.
    return nest.map_structure_with_tuple_paths_up_to(
        dtype, _convert_leaf, value, dtype, dtype_hint, check_types=False)

  # Structured call with a name: open a `name` scope and name each element
  # after its "{path}.{to}.{element}".
  def _convert_named_leaf(path, *args):
    return _convert_leaf(path, *args, name='.'.join(map(str, path)))

  with tf.name_scope(name):
    return nest.map_structure_with_tuple_paths_up_to(
        dtype, _convert_named_leaf, value, dtype, dtype_hint,
        check_types=False)
def convert_to_nested_tensor(value, dtype=None, dtype_hint=None, name=None):
  """Converts `value` to a `Tensor` or a nested structure of `Tensor`s.

  Python objects of several kinds are supported: `Tensor` objects, numpy
  arrays, Python lists and Python scalars.

  Args:
    value: An object whose structure matches that of `dtype` and whose leaves
      each have a registered `Tensor` conversion function.
    dtype: Optional structure of dtypes. It defines the output structure and
      supplies the `dtype` argument of each per-leaf `convert_to_tensor`
      call. If it is not nested, it is broadcast to the structure of
      `dtype_hint`.
    dtype_hint: Optional structure of dtypes. It defines the output structure
      and supplies the `dtype_hint` argument of each per-leaf
      `convert_to_tensor` call. If it is not nested, it is broadcast to the
      structure of `dtype`.
    name: Optional name to use if a new `Tensor` is created. For structured
      inputs, elements are named according to '{name}/{path}.{to}.{elem}'.

  Returns:
    tensor: A (structure of) `Tensor` based on `value`.

  Raises:
    NotImplementedError: If a leaf position of `dtype` holds a nested value
      that contains a `Tensor` or `np.ndarray`.
  """
  dtype_is_nested = nest.is_nested(dtype)
  hint_is_nested = nest.is_nested(dtype_hint)
  # Broadcast whichever of dtype / dtype_hint is an atom so both share one
  # structure. If both are already nested, they must match exactly.
  if dtype_is_nested and not hint_is_nested:
    dtype_hint = broadcast_structure(dtype, dtype_hint)
  elif hint_is_nested and not dtype_is_nested:
    dtype = broadcast_structure(dtype_hint, dtype)
  elif dtype_is_nested and hint_is_nested:
    nest.assert_same_structure(dtype, dtype_hint)

  def to_tensor(path, leaf, leaf_dtype, leaf_hint, name=None):
    # Refuse to pack a structure containing tensors into one tensor.
    if nest.is_nested(leaf):
      for item in nest.flatten(leaf):
        # Treat arrays like Tensors for full parity in JAX backend.
        if tf.is_tensor(item) or isinstance(item, np.ndarray):
          raise NotImplementedError(
              ('Cannot convert a structure of tensors to a '
               'single tensor. Saw {} at path {}.').format(leaf, path))
    return tf.convert_to_tensor(leaf, leaf_dtype, leaf_hint, name=name)

  # The branches below differ only in how results are named.
  if not nest.is_nested(dtype):
    # Unstructured call: use the caller-provided name as-is.
    return to_tensor((), value, dtype, dtype_hint, name=name)

  if name is not None:
    # Structured call with a name: open a `name` scope and name each member
    # "{path}.{to}.{element}".
    with tf.name_scope(name):

      def to_named_tensor(path, *args):
        return to_tensor(path, *args, name='.'.join(str(p) for p in path))

      return nest.map_structure_with_tuple_paths_up_to(
          dtype, to_named_tensor, value, dtype, dtype_hint,
          check_types=False)

  # Structured call without a name: skip the scope and per-element names.
  return nest.map_structure_with_tuple_paths_up_to(
      dtype, to_tensor, value, dtype, dtype_hint, check_types=False)