def testNestYieldFlatPaths(self):
  structure = [[TestCompositeTensor(1, 2, 3)],
               100,
               {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)}]

  result1 = list(nest.yield_flat_paths(structure, expand_composites=True))
  expected1 = [(0, 0, 0), (0, 0, 1), (0, 0, 2), (1,),
               (2, 'y', 0, 0), (2, 'y', 0, 1), (2, 'y', 1)]
  self.assertEqual(result1, expected1)

  result2 = list(nest.yield_flat_paths(structure, expand_composites=False))
  expected2 = [(0, 0), (1,), (2, 'y')]
  self.assertEqual(result2, expected2)

def testNestFlatten(self, structure, expected, paths, expand_composites=True):
  result = nest.flatten(structure, expand_composites=expand_composites)
  self.assertEqual(result, expected)

  result_with_paths = nest.flatten_with_tuple_paths(
      structure, expand_composites=expand_composites)
  self.assertEqual(result_with_paths, list(zip(paths, expected)))

  string_paths = ['/'.join(str(p) for p in path) for path in paths]  # pylint: disable=g-complex-comprehension
  result_with_string_paths = nest.flatten_with_joined_string_paths(
      structure, expand_composites=expand_composites)
  self.assertEqual(result_with_string_paths,
                   list(zip(string_paths, expected)))

  flat_paths_result = list(
      nest.yield_flat_paths(structure, expand_composites=expand_composites))
  self.assertEqual(flat_paths_result, paths)

def _create_pseudo_names(tensors, prefix):
  """Creates pseudo {input | output} names for subclassed Models.

  Warning: this function should only be used to define default names for
  `Metrics` and `SavedModel`. No other use cases should rely on a `Model`'s
  input or output names.

  Example with dict:

  `{'a': [x1, x2], 'b': x3}` becomes:
  `['a_1', 'a_2', 'b']`

  Example with list:

  `[x, y]` becomes:
  `['output_1', 'output_2']`

  Arguments:
    tensors: `Model`'s outputs or inputs.
    prefix: 'output_' for outputs, 'input_' for inputs.

  Returns:
    Flattened list of pseudo names.
  """

  def one_index(ele):
    # Start with "output_1" instead of "output_0".
    if isinstance(ele, int):
      return ele + 1
    return ele

  flat_paths = list(nest.yield_flat_paths(tensors))
  flat_paths = nest.map_structure(one_index, flat_paths)
  names = []
  for path in flat_paths:
    if not path:
      name = prefix + '1'  # Single output.
    else:
      name = '_'.join(str(p) for p in path)
      if isinstance(path[0], int):
        name = prefix + name
    names.append(name)
  return names

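# A minimal usage sketch for `_create_pseudo_names`, mirroring the examples in
# the docstring above. The placeholder objects standing in for tensors are an
# illustrative assumption, not part of the library.
x1, x2, x3 = object(), object(), object()
assert _create_pseudo_names({'a': [x1, x2], 'b': x3}, prefix='output_') == [
    'a_1', 'a_2', 'b']
assert _create_pseudo_names([x1, x2], prefix='output_') == [
    'output_1', 'output_2']
assert _create_pseudo_names(x1, prefix='input_') == ['input_1']  # Single input.
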
def testYieldFlatStringPaths(self):
  for inputs_expected in (
      {"inputs": [], "expected": []},
      {"inputs": 3, "expected": [()]},
      {"inputs": [3], "expected": [(0,)]},
      {"inputs": {"a": 3}, "expected": [("a",)]},
      {"inputs": {"a": {"b": 4}}, "expected": [("a", "b")]},
      {"inputs": [{"a": 2}], "expected": [(0, "a")]},
      {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
      {"inputs": [{"a": [(23, 42)]}],
       "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
      {"inputs": [{"a": ([23], 42)}],
       "expected": [(0, "a", 0, 0), (0, "a", 1)]},
      {"inputs": {"a": {"a": 2}, "c": [[[4]]]},
       "expected": [("a", "a"), ("c", 0, 0, 0)]},
      {"inputs": {"0": [{"1": 23}]}, "expected": [("0", 0, "1")]}):
    inputs = inputs_expected["inputs"]
    expected = inputs_expected["expected"]
    self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)

def create_output_names(y_pred):
  """Creates output names for subclassed Model outputs.

  These names are used for naming `Metric`s.

  Example with dict:

  `{'a': [x1, x2], 'b': x3}` becomes:
  `['a_1', 'a_2', 'b']`

  Example with list:

  `[x, y]` becomes:
  `['output_1', 'output_2']`

  Arguments:
    y_pred: `Model`'s outputs.

  Returns:
    Flattened list of output names.
  """

  def one_index(ele):
    # Start with "output_1" instead of "output_0".
    if isinstance(ele, int):
      return ele + 1
    return ele

  flat_paths = list(nest.yield_flat_paths(y_pred))
  flat_paths = nest.map_structure(one_index, flat_paths)
  output_names = []
  for path in flat_paths:
    if not path:
      output_name = 'output_1'
    else:
      output_name = '_'.join(str(p) for p in path)
      if isinstance(path[0], int):
        output_name = 'output_' + output_name
    output_names.append(output_name)
  return output_names

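# A quick illustrative check of `create_output_names`, following the docstring
# examples above (placeholder objects stand in for output tensors; this usage
# sketch is not part of the library).
y1, y2, y3 = object(), object(), object()
assert create_output_names({'a': [y1, y2], 'b': y3}) == ['a_1', 'a_2', 'b']
assert create_output_names([y1, y2]) == ['output_1', 'output_2']
assert create_output_names(y1) == ['output_1']  # Single output.
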
def build_trainable_linear_operator_block(
    operators, block_dims=None, batch_shape=(), dtype=None, name=None):
  """Builds a trainable blockwise `tf.linalg.LinearOperator`.

  This function returns a trainable blockwise `LinearOperator`. If `operators`
  is a flat list, it is interpreted as blocks along the diagonal of the
  structure and an instance of `tf.linalg.LinearOperatorBlockDiag` is
  returned. If `operators` is a doubly nested list, then a
  `tf.linalg.LinearOperatorBlockLowerTriangular` instance is returned, with
  the block in row `i` column `j` (`i >= j`) given by `operators[i][j]`.

  The `operators` list may contain `LinearOperator` instances, `LinearOperator`
  subclasses, or callables that return `LinearOperator` instances. The
  dimensions of the blocks are given by `block_dims`; this argument may be
  omitted if `operators` contains only `LinearOperator` instances.

  ### Examples

  ```python
  # Build a 5x5 trainable `LinearOperatorBlockDiag` given `LinearOperator`
  # subclasses and `block_dims`.
  op = build_trainable_linear_operator_block(
      operators=(tf.linalg.LinearOperatorDiag,
                 tf.linalg.LinearOperatorLowerTriangular),
      block_dims=[3, 2],
      dtype=tf.float32)

  # Build an 8x8 `LinearOperatorBlockLowerTriangular`, with a callable that
  # returns a `LinearOperator` in the upper left block, and `LinearOperator`
  # subclasses in the lower two blocks.
  op = build_trainable_linear_operator_block(
      operators=(
          (lambda shape, dtype: tf.linalg.LinearOperatorScaledIdentity(
              num_rows=shape[-1], multiplier=tf.Variable(1., dtype=dtype)),),
          (tf.linalg.LinearOperatorFullMatrix,
           tf.linalg.LinearOperatorLowerTriangular)),
      block_dims=[4, 4],
      dtype=tf.float64)

  # Build a 6x6 `LinearOperatorBlockDiag` with batch shape `(4,)`. Since
  # `operators` contains only `LinearOperator` instances, the `block_dims`
  # argument is not necessary.
  op = build_trainable_linear_operator_block(
      operators=(tf.linalg.LinearOperatorDiag(tf.Variable(tf.ones((4, 3)))),
                 tf.linalg.LinearOperatorFullMatrix([4.]),
                 tf.linalg.LinearOperatorIdentity(2)))
  ```

  Args:
    operators: A list or tuple containing `LinearOperator` subclasses,
      `LinearOperator` instances, or callables returning `LinearOperator`
      instances. If the list is flat, a `tf.linalg.LinearOperatorBlockDiag`
      instance is returned. Otherwise, the list must be singly nested, with
      the first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is returned. Callables contained in the lists must take two
      arguments -- `shape`, the shape of the `tf.Variable` instantiating the
      `LinearOperator`, and `dtype`, the `tf.dtype` of the `LinearOperator`.
    block_dims: List or tuple of integers, representing the sizes of the
      blocks along one dimension of the (square) blockwise `LinearOperator`.
      If `operators` contains only `LinearOperator` instances, `block_dims`
      may be `None` and the dimensions are inferred.
    batch_shape: Batch shape of the `LinearOperator`.
    dtype: `tf.dtype` of the `LinearOperator`.
    name: str, name for `tf.name_scope`.

  Returns:
    Trainable instance of `tf.linalg.LinearOperatorBlockDiag` or
      `tf.linalg.LinearOperatorBlockLowerTriangular`.
""" with tf.name_scope(name or 'build_trainable_blockwise_tril_operator'): operator_instances = [op for op in nest.flatten(operators) if isinstance(op, tf.linalg.LinearOperator)] if (block_dims is None and len(operator_instances) < len(nest.flatten(operators))): # If `operator_instances` contains fewer elements than `operators`, # then some elements of `operators` are not instances of `LinearOperator`. raise ValueError('Argument `block_dims` must be defined unless ' '`operators` contains only `tf.linalg.LinearOperator` ' 'instances.') batch_shape = ps.cast(batch_shape, tf.int32) if dtype is None: dtype = dtype_util.common_dtype(operator_instances) def convert_operator(path, op): if isinstance(op, tf.linalg.LinearOperator): return op builder = _OPERATOR_BUILDERS.get(op, op) if len(set(path)) == 1: # for operators on the diagonal return builder( ps.concat([batch_shape, [block_dims[path[0]]]], axis=0), dtype=dtype) return builder( ps.concat([batch_shape, [block_dims[path[0]], block_dims[path[1]]]], axis=0), dtype=dtype) operator_blocks = nest.map_structure_with_tuple_paths( convert_operator, operators) paths = nest.yield_flat_paths(operators) if all(len(p) == 1 for p in paths): return tf.linalg.LinearOperatorBlockDiag( operator_blocks, is_non_singular=True) elif all(len(p) == 2 for p in paths): return tf.linalg.LinearOperatorBlockLowerTriangular( operator_blocks, is_non_singular=True) else: raise ValueError( 'Argument `operators` must be a flat or singly-nested sequence.')
def flatten_with_tuple_paths(structure):
  return list(zip(nest.yield_flat_paths(structure),
                  nest.flatten(structure)))

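# Illustrative sketch of the (path, leaf) pairing `flatten_with_tuple_paths`
# produces (the example values are an assumption, not from the source):
assert flatten_with_tuple_paths({'a': [1, 2], 'b': 3}) == [
    (('a', 0), 1), (('a', 1), 2), (('b',), 3)]
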
def map_structure_coroutine(
    coroutine,
    *structures,
    _expand_composites=False,  # pylint: disable=invalid-name
    _up_to=UNSPECIFIED,  # pylint: disable=invalid-name
    _with_tuple_paths=False,  # pylint: disable=invalid-name
    **named_structures):  # pylint: disable=g-doc-return-or-yield
  """Invokes a coroutine multiple times with args from provided structures.

  This is semantically identical to `map_structure_with_named_args`, except
  that the first argument is a generator or coroutine (a callable whose body
  contains `yield` statements) rather than a function. This is invoked with
  arguments from the provided structure(s), thus defining an outer generator/
  coroutine that `yield`s values in sequence from each call to the inner
  `coroutine`. The argument structures are traversed, and the coroutine is
  invoked, in the order defined by `tf.nest.flatten`.

  A stripped-down implementation of the core logic is as follows:

  ```python
  def map_structure_coroutine(coroutine, *structures):
    flat_results = []
    for args in zip(*[tf.nest.flatten(s) for s in structures]):
      retval = yield from coroutine(*args)
      flat_results.append(retval)
    return tf.nest.pack_sequence_as(structures[0], flat_results)
  ```

  Args:
    coroutine: a generator/coroutine callable that accepts one or more named
      arguments.
    *structures: Structures of arguments passed positionally to `coroutine`.
    _expand_composites: Forwarded as
      `tf.nest.flatten(..., expand_composites=_expand_composites)`.
    _up_to: Optional shallow structure to map up to. If provided,
      `nest.map_structure_up_to` is called rather than `nest.map_structure`.
      Default value: `UNSPECIFIED`.
    _with_tuple_paths: Python bool. If `True`, the first argument to
      `coroutine` is a tuple path to the current leaf of the argument
      structure(s).
      Default value: `False`.
    **named_structures: Structures of arguments passed by name to `coroutine`.

  Yields:
    Values `yield`ed by each invocation of `coroutine`, with invocations in
    order corresponding to `tf.nest.flatten`.

  Returns:
    A new structure matching that of the input structures (or the shallow
    structure `_up_to`, if specified), in which each element is the return
    value from applying `coroutine` to the corresponding elements of the input
    structures.

  ## Examples

  A JointDistributionCoroutine may define a reusable submodel as its own
  coroutine, for example:

  ```python
  def horseshoe_prior(path, scale):
    # Auxiliary-variable representation of a horseshoe prior on sparse
    # weights.
    name = ','.join(path)
    z = yield tfd.HalfCauchy(loc=0., scale=scale, name=name + '_z')
    w_noncentered = yield tfd.Normal(
        loc=0., scale=z, name=name + '_w_noncentered')
    return z * w_noncentered
  ```

  Note that this submodel yields two auxiliary random variables, and returns
  the sampled weight as a third value. Using `map_structure_coroutine` we can
  define a structure of such submodels, and collect their return values:

  ```python
  @tfd.JointDistributionCoroutineAutoBatched
  def model():
    weights = yield from nest_util.map_structure_coroutine(
        horseshoe_prior,
        scale={'a': tf.ones([5]) * 100.,
               'b': tf.ones([2]) * 1e-2},
        _with_tuple_paths=True)
    # ==> `weights` is a dict of weight values.
    yield tfd.Deterministic(
        tf.sqrt(tf.norm(weights['a'])**2 + tf.norm(weights['b'])**2),
        name='weights_norm')

  print(model.event_shape)
  # ==> StructTuple(
  #       a_z=TensorShape([5]),
  #       a_w_noncentered=TensorShape([5]),
  #       b_z=TensorShape([2]),
  #       b_w_noncentered=TensorShape([2]),
  #       weights_norm=TensorShape([]))
  ```
  """
  # pylint: enable=g-doc-return-or-yield
  names, named_structure_values = (
      zip(*named_structures.items()) if named_structures else ((), ()))
  all_structures = structures + named_structure_values
  result_structure = all_structures[0] if _up_to is UNSPECIFIED else _up_to
  flat_arg_structures = [
      nest.flatten_up_to(result_structure, s) for s in all_structures]
  if _with_tuple_paths:
    # Pass tuple paths as a first positional arg (before any provided args).
    flat_paths = nest.yield_flat_paths(
        result_structure, expand_composites=_expand_composites)
    flat_arg_structures = [list(flat_paths)] + flat_arg_structures
    num_positional_args = 1 + len(structures)
  else:
    num_positional_args = len(structures)

  flat_results = []
  for leaf_values in zip(*flat_arg_structures):
    result = yield from coroutine(
        *leaf_values[:num_positional_args],
        **dict(zip(names, leaf_values[num_positional_args:])))
    flat_results.append(result)
  return nest.pack_sequence_as(result_structure, flat_results)

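# A standalone sketch of `map_structure_coroutine` semantics, using a trivial
# inner coroutine (`_double` is a hypothetical example, not library code):
# yielded values pass through to the caller, and each coroutine's return value
# becomes the corresponding leaf of the packed result.
def _double(x):
  yield ('visiting', x)
  return 2 * x

gen = map_structure_coroutine(_double, {'a': 1, 'b': [2, 3]})
try:
  while True:
    print(next(gen))  # ==> ('visiting', 1), ('visiting', 2), ('visiting', 3)
except StopIteration as e:
  assert e.value == {'a': 2, 'b': [4, 6]}
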
def _trainable_linear_operator_block(operators,
                                     block_dims=None,
                                     batch_shape=(),
                                     dtype=None,
                                     name=None):
  """Builds a trainable blockwise `tf.linalg.LinearOperator`.

  This function returns a trainable blockwise `LinearOperator`. If `operators`
  is a flat list, it is interpreted as blocks along the diagonal of the
  structure and an instance of `tf.linalg.LinearOperatorBlockDiag` is
  returned. If `operators` is a doubly nested list, then a
  `tf.linalg.LinearOperatorBlockLowerTriangular` instance is returned, with
  the block in row `i` column `j` (`i >= j`) given by `operators[i][j]`.

  The `operators` list may contain `LinearOperator` instances, `LinearOperator`
  subclasses, or callables defining custom constructors (see example below).
  The dimensions of the blocks are given by `block_dims`; this argument may be
  omitted if `operators` contains only `LinearOperator` instances.

  Args:
    operators: A list or tuple containing `LinearOperator` subclasses,
      `LinearOperator` instances, and/or callables returning
      `(init_fn, apply_fn)` pairs. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance is returned. Otherwise, the
      list must be singly nested, with the first element of length 1, second
      element of length 2, etc.; the elements of the outer list are
      interpreted as rows of a lower-triangular block structure, and a
      `tf.linalg.LinearOperatorBlockLowerTriangular` instance is returned.
      Callables contained in the lists must take two arguments -- `shape`, the
      shape of the parameter instantiating the `LinearOperator`, and `dtype`,
      the `tf.dtype` of the `LinearOperator` -- and return a further pair of
      callables representing a stateless trainable operator (see example
      below).
    block_dims: List or tuple of integers, representing the sizes of the
      blocks along one dimension of the (square) blockwise `LinearOperator`.
      If `operators` contains only `LinearOperator` instances, `block_dims`
      may be `None` and the dimensions are inferred.
    batch_shape: Batch shape of the `LinearOperator`.
    dtype: `tf.dtype` of the `LinearOperator`.
    name: str, name for `tf.name_scope`.

  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  ### Examples

  To build a 5x5 trainable `LinearOperatorBlockDiag` given `LinearOperator`
  subclasses and `block_dims`:

  ```python
  op = build_trainable_linear_operator_block(
      operators=(tf.linalg.LinearOperatorDiag,
                 tf.linalg.LinearOperatorLowerTriangular),
      block_dims=[3, 2],
      dtype=tf.float32)
  ```

  If `operators` contains only `LinearOperator` instances, the `block_dims`
  argument is not necessary:

  ```python
  # Builds a 6x6 `LinearOperatorBlockDiag` with batch shape `(4,)`.
  op = build_trainable_linear_operator_block(
      operators=(tf.linalg.LinearOperatorDiag(tf.Variable(tf.ones((4, 3)))),
                 tf.linalg.LinearOperatorFullMatrix([4.]),
                 tf.linalg.LinearOperatorIdentity(2)))
  ```

  A custom operator constructor may be specified as a callable taking
  arguments `shape` and `dtype`, and returning a pair of callables
  `(init_fn, apply_fn)` describing a parameterized operator, with the
  following signatures:

  ```python
  raw_parameters = init_fn(seed)
  linear_operator = apply_fn(raw_parameters)
  ```

  For example, to define a custom initialization for a diagonal operator:

  ```python
  import functools

  def diag_operator_with_uniform_initialization(shape, dtype):
    init_fn = functools.partial(
        samplers.uniform, shape, maxval=2., dtype=dtype)
    apply_fn = lambda scale_diag: tf.linalg.LinearOperatorDiag(
        scale_diag, is_non_singular=True)
    return init_fn, apply_fn

  # Build an 8x8 `LinearOperatorBlockLowerTriangular`, with our custom
  # diagonal operator in the upper left block, and `LinearOperator` subclasses
  # in the lower two blocks.
  op = build_trainable_linear_operator_block(
      operators=((diag_operator_with_uniform_initialization,),
                 (tf.linalg.LinearOperatorFullMatrix,
                  tf.linalg.LinearOperatorLowerTriangular)),
      block_dims=[4, 4],
      dtype=tf.float64)
  ```
  """
  with tf.name_scope(name or 'trainable_linear_operator_block'):
    operator_instances = [op for op in nest.flatten(operators)
                          if isinstance(op, tf.linalg.LinearOperator)]
    if (block_dims is None
        and len(operator_instances) < len(nest.flatten(operators))):
      # If `operator_instances` contains fewer elements than `operators`, then
      # some elements of `operators` are not instances of `LinearOperator`.
      raise ValueError('Argument `block_dims` must be defined unless '
                       '`operators` contains only `tf.linalg.LinearOperator` '
                       'instances.')

    batch_shape = ps.cast(batch_shape, tf.int32)
    if dtype is None:
      dtype = dtype_util.common_dtype(operator_instances)

    def convert_operator(path, op):
      if isinstance(op, tf.linalg.LinearOperator):
        return op
      if len(set(path)) == 1:  # for operators on the diagonal
        shape = ps.concat([batch_shape, [block_dims[path[0]]]], axis=0)
      else:
        shape = ps.concat(
            [batch_shape, [block_dims[path[0]], block_dims[path[1]]]], axis=0)
      if op in _OPERATOR_COROUTINES:
        operator = yield from _OPERATOR_COROUTINES[op](shape=shape,
                                                       dtype=dtype)
      else:  # Custom stateless constructor.
        init_fn, apply_fn = op(shape=shape, dtype=dtype)
        raw_params = yield trainable_state_util.Parameter(init_fn)
        operator = apply_fn(raw_params)
      return operator

    # Build a structure of component trainable LinearOperators.
    operator_blocks = yield from nest_util.map_structure_coroutine(
        convert_operator, operators, _with_tuple_paths=True)
    # Materialize the paths as a list so that both `all(...)` checks below see
    # every path; a raw generator would be partially consumed by the first.
    paths = list(nest.yield_flat_paths(operators))
    if all(len(p) == 1 for p in paths):
      return tf.linalg.LinearOperatorBlockDiag(
          operator_blocks, is_non_singular=True)
    elif all(len(p) == 2 for p in paths):
      return tf.linalg.LinearOperatorBlockLowerTriangular(
          operator_blocks, is_non_singular=True)
    else:
      raise ValueError(
          'Argument `operators` must be a flat or singly-nested sequence.')

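# The `Yields` section above says the emitted `Parameter`s are meant to be
# consumed by `trainable_state_util`; a plausible wiring is sketched below.
# The stateless builder's public name is an assumption for illustration.
build_trainable_linear_operator_block = (
    trainable_state_util.as_stateful_builder(
        _trainable_linear_operator_block))
build_linear_operator_block = (  # Name assumed for illustration.
    trainable_state_util.as_stateless_builder(
        _trainable_linear_operator_block))
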