def _should_broadcast(self, obj):
    """Returns whether `obj` should be broadcast to every output.

    Args:
        obj: A single object (e.g. `'mse'`) or a structure of objects.

    Returns:
        `True` when `obj` is not nested, or when it is a flat list/tuple whose
        elements are themselves not nested (e.g. `['mse']` or
        `['mse', 'mae']`); `False` otherwise.
    """
    if nest.is_nested(obj):
        # Only flat lists/tuples qualify; dicts and deeper nestings do not.
        if not isinstance(obj, (list, tuple)):
            return False
        return all(not nest.is_nested(item) for item in obj)
    # A lone, non-nested object such as 'mse'.
    return True
def call(self, inputs, states, training=None):
    """Computes one step of the cell: `inputs` projected by `self.kernel`.

    Args:
        inputs: Input for the current step; multiplied with `self.kernel` via
            `K.dot`.
        states: State from the previous step.  Only its structure is
            inspected, so that the returned state mirrors it.
        training: Unused.

    Returns:
        A tuple `(output, new_state)`.  `new_state` is `[output]` when
        `states` is a nested structure and `output` itself otherwise.
    """
    output = K.dot(inputs, self.kernel)
    # Mirror the container-ness of the incoming state.
    new_state = [output] if nest.is_nested(states) else output
    return output, new_state
def _coerce_structure(shallow_tree, input_tree):
    """Implementation of coerce_structure.

    Walks `shallow_tree` and `input_tree` in parallel and rebuilds each level
    of `input_tree` through `nest._sequence_like(shallow_tree, ...)`.  A level
    whose container type already equals that of `shallow_tree` and whose
    children all came back unchanged is returned as the original object.

    Args:
        shallow_tree: Structure that drives the traversal.  Where it is not
            nested, the corresponding part of `input_tree` is returned as-is.
        input_tree: Structure to coerce.

    Returns:
        `input_tree` itself when no level needed rewrapping, otherwise a new
        structure built from the coerced children.

    Raises:
        TypeError: If `shallow_tree` is nested but `input_tree` is not, or if
            their container kinds (mapping vs. sequence) cannot be matched.
        ValueError: If the two trees differ in length at some level, or if a
            key of `shallow_tree` cannot be looked up in `input_tree`.
    """
    # Below the leaves of the shallow tree nothing is coerced.
    if not nest.is_nested(shallow_tree):
        return input_tree

    if not nest.is_nested(input_tree):
        raise TypeError(
            nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree)))

    if len(input_tree) != len(shallow_tree):
        raise ValueError(
            nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
                input_length=len(input_tree),
                shallow_length=len(shallow_tree)))

    # Determine whether shallow_tree should be treated as a Mapping or a Sequence.
    # Namedtuples can be interpreted either way (but keys take precedence).
    _shallow_is_namedtuple = nest._is_namedtuple(shallow_tree)  # pylint: disable=invalid-name
    _shallow_is_mapping = isinstance(shallow_tree, collections.abc.Mapping)  # pylint: disable=invalid-name
    shallow_supports_keys = _shallow_is_namedtuple or _shallow_is_mapping
    shallow_supports_iter = _shallow_is_namedtuple or not _shallow_is_mapping

    # Branch-selection depends on both shallow and input container-classes.
    # `lookup_branch` maps a key of the shallow tree to the matching child of
    # the input tree, either by name/key or by consuming the next element of
    # an iterator over the input.
    input_is_mapping = isinstance(input_tree, collections.abc.Mapping)
    if nest._is_namedtuple(input_tree):
        if shallow_supports_keys:
            lookup_branch = lambda k: getattr(input_tree, k)
        else:
            input_iter = nest._yield_value(input_tree)
            lookup_branch = lambda _: next(input_iter)
    elif shallow_supports_keys and input_is_mapping:
        lookup_branch = lambda k: input_tree[k]
    elif shallow_supports_iter and not input_is_mapping:
        input_iter = nest._yield_value(input_tree)
        lookup_branch = lambda _: next(input_iter)
    else:
        # Report the wrapped type when the shallow tree exposes `__wrapped__`.
        raise TypeError(
            nest._STRUCTURES_HAVE_MISMATCHING_TYPES.format(
                input_type=type(input_tree),
                shallow_type=(type(shallow_tree.__wrapped__) if hasattr(
                    shallow_tree, '__wrapped__') else type(shallow_tree))))

    flat_coerced = []
    # A differing container class alone is enough to force a rebuild.
    needs_wrapping = type(shallow_tree) is not type(input_tree)
    for shallow_key, shallow_branch in nest._yield_sorted_items(shallow_tree):
        try:
            input_branch = lookup_branch(shallow_key)
        except (KeyError, AttributeError):
            raise ValueError(
                nest._SHALLOW_TREE_HAS_INVALID_KEYS.format([shallow_key]))
        flat_coerced.append(_coerce_structure(shallow_branch, input_branch))
        # Keep track of whether nested elements have changed.
        needs_wrapping |= input_branch is not flat_coerced[-1]

    # Only create a new instance if containers differ or contents changed.
    return (nest._sequence_like(shallow_tree, flat_coerced)
            if needs_wrapping else input_tree)
def _prepare_args(target, event_ndims):
    """Builds `RunningCovariance` streams that mirror the structure of `target`.

    The `shape` and `dtype` handed to each `RunningCovariance` are read off
    the corresponding element of `target`.

    Args:
        target: A (possibly nested) structure whose elements expose `.shape`
            and `.dtype`; it describes the chain state(s) to be tracked.
        event_ndims: `None`, a single value, or a nested structure.  `None`
            is replaced by the rank of each element of `target`; a
            non-nested value is broadcast across the structure of `target`;
            a nested structure is used unchanged.

    Returns:
        A structure of `sample_stats.RunningCovariance` objects laid out like
        `target`.
    """
    shapes = tf.nest.map_structure(lambda part: part.shape, target)
    dtypes = tf.nest.map_structure(lambda part: part.dtype, target)

    if event_ndims is None:
        event_ndims = tf.nest.map_structure(ps.rank, target)
    elif not nest.is_nested(event_ndims):
        event_ndims = nest_util.broadcast_structure(target, event_ndims)

    cov_streams = nest.map_structure_up_to(
        target,
        sample_stats.RunningCovariance,
        shapes,
        event_ndims,
        dtypes,
        check_types=False,
    )
    return cov_streams
def one_step(self, new_chain_state, current_reducer_state,
             previous_kernel_results=None, axis=None):
    """Update the `current_reducer_state` with a new chain state.

    Chunking semantics are similar to those of batching and are specified by
    the `axis` parameter. If chunking is enabled (axis is not `None`), all
    elements along the specified `axis` will be treated as separate samples.
    If a single scalar value is provided for a non-scalar sample structure,
    that value will be used for all elements in the structure. If not, an
    identical structure must be provided.

    Args:
        new_chain_state: A (possibly nested) structure of incoming chain
            state(s) with shape and dtype compatible with those used to
            initialize the `current_reducer_state`.
        current_reducer_state: `CovarianceReducerState`s representing the
            current state of the running covariance.
        previous_kernel_results: A (possibly nested) structure of `Tensor`s
            representing internal calculations made in a related
            `TransitionKernel`.  May be `None`, in which case `None` is
            forwarded to `self.transform_fn` without conversion.
        axis: If chunking is desired, this is a (possibly nested) structure
            of integers that specifies the axis with chunked samples. For
            individual samples, set this to `None`. By default, samples are
            not chunked (`axis` is None).

    Returns:
        new_reducer_state: `CovarianceReducerState` with updated running
            statistics. Its `cov_state` field has an identical structure to
            the results of `self.transform_fn`. Each of the individual values
            in that structure subsequently mimics the structure of
            `current_reducer_state`.
    """
    with tf.name_scope(
            mcmc_util.make_name(self.name, 'covariance_reducer', 'one_step')):
        cov_streams = _prepare_args(current_reducer_state.init_structure,
                                    self.event_ndims)
        new_chain_state = tf.nest.map_structure(tf.convert_to_tensor,
                                                new_chain_state)
        # `tf.convert_to_tensor(None)` raises, so only convert real results.
        if previous_kernel_results is not None:
            previous_kernel_results = tf.nest.map_structure(
                tf.convert_to_tensor, previous_kernel_results)
        fn_results = tf.nest.map_structure(
            lambda fn: fn(new_chain_state, previous_kernel_results),
            self.transform_fn,
        )
        # A single (non-nested) axis applies to every transformed result.
        if not nest.is_nested(axis):
            axis = nest_util.broadcast_structure(fn_results, axis)
        running_cov_state = nest.map_structure_up_to(
            current_reducer_state.init_structure,
            lambda strm, *args: strm.update(*args),
            cov_streams,
            current_reducer_state.cov_state,
            fn_results,
            axis,
            check_types=False,
        )
        return CovarianceReducerState(current_reducer_state.init_structure,
                                      running_cov_state)
def call(self, inputs, states, training=None):
    """Advances the normalized forward recursion by one step.

    NOTE(review): from their usage `self.I`, `self.A` and `self.B` look like
    the initial distribution, transition matrix and emission tensor of an
    HMM -- confirm against where they are created.

    Args:
        inputs: Observation for the current step.  Its leading dimension is
            used as the batch size.
        states: Triple `(is_init, forward, loglik)` carried over from the
            previous step.
        training: Unused.

    Returns:
        A tuple `(loglik, new_state)`.  `loglik` is the previous
        log-likelihood plus the log of this step's normalizer.  `new_state`
        holds `[is_init, forward, loglik]` with `is_init` reset to zeros and
        `forward` renormalized; it is wrapped in an extra list when `states`
        is a nested structure.
    """
    old_is_init, old_forward, old_loglik = states

    # `R0` only contributes where the init flag is set; `R1` propagates the
    # previous forward values through `self.A` (transposed).
    I0 = tf.dtypes.cast(old_is_init, tf.float32)
    R0 = tf.tensordot(I0, self.I, axes=0)
    R1 = tf.linalg.matvec(self.A, old_forward, transpose_a=True)
    R = tf.identity(R0 + R1, name="R")

    # [units, n, s] [batch_size, s]
    # E has shape [batch_size, units, n]
    E = tf.linalg.matvec(tf.expand_dims(self.B, 0),
                         tf.expand_dims(inputs, 1),
                         name="E")
    forward = tf.multiply(E, R, name="forward")

    # Normalize the forward values and accumulate the log of the normalizer.
    S = tf.reduce_sum(forward, axis=-1, name="loglik")
    loglik = old_loglik + tf.math.log(S)
    forward = forward / tf.expand_dims(S, -1)

    # 'call' can be given None as batch size, so read it dynamically.
    batch_size = tf.shape(inputs)[0]
    is_init = tf.zeros(batch_size, dtype='int8', name="is_init")

    new_state = [is_init, forward, loglik]
    new_state = [new_state] if nest.is_nested(states) else new_state
    return loglik, new_state
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
    """Applies `map_fn` to every atomic element of a nested structure.

    Args:
        is_atomic_fn: Predicate deciding whether an element of `nested`
            should be treated as atomic.
        map_fn: Function applied to each atomic element.
        nested: A nested structure.

    Returns:
        A structure shaped like `nested` in which each atomic element has
        been replaced by the result of `map_fn`.

    Raises:
        ValueError: If an element is neither atomic nor a nested structure.
    """
    if is_atomic_fn(nested):
        return map_fn(nested)

    if not nest.is_nested(nested):
        raise ValueError(
            'Received non-atomic and non-sequence element: {}'.format(nested))

    # Pick the children to recurse into; mappings are visited in key order.
    if nest.is_mapping(nested):
        children = [nested[key] for key in sorted(nested.keys())]
    elif nest.is_attrs(nested):
        children = _astuple(nested)
    else:
        children = nested

    def _recurse(child):
        return map_structure_with_atomic(is_atomic_fn, map_fn, child)

    return nest._sequence_like(nested, [_recurse(child) for child in children])
def _is_atomic_nested(nested):
    """Returns `True` if `nested` should be treated as a single unit.

    That is the case for `ListWrapper` instances, for serialized node data,
    and for anything that is not a nested structure.
    """
    if isinstance(nested, ListWrapper) or _is_serialized_node_data(nested):
        return True
    return not nest.is_nested(nested)
def check_default_value(shape, default_value, dtype, key):
    """Returns default value as tuple if it's valid, otherwise raises errors.

    This function verifies that `default_value` is compatible with both
    `shape` and `dtype`. If it is not compatible, it raises an error. If it
    is compatible, it casts default_value to a tuple and returns it. `key` is
    used only for error message.

    Args:
        shape: An iterable of integers specifies the shape of the `Tensor`.
        default_value: If a single value is provided, the same value will be
            applied as the default value for every item. If an iterable of
            values is provided, the shape of the `default_value` should be
            equal to the given `shape`.  Objects exposing a callable `tolist`
            (e.g. numpy arrays) are converted with it first.
        dtype: defines the type of values.  Only its `is_floating` attribute
            is consulted here.
        key: Column name, used only for error messages.

    Returns:
        A tuple which will be used as default value, or `None` when
        `default_value` is `None`.

    Raises:
        ValueError: if `default_value` is a nested structure whose shape is
            not equal to `shape`.
        TypeError: if `default_value` is not compatible with `dtype`.
    """
    if default_value is None:
        return None

    if isinstance(default_value, int):
        return _create_tuple(shape, default_value)

    if isinstance(default_value, float) and dtype.is_floating:
        return _create_tuple(shape, default_value)

    if callable(getattr(default_value, 'tolist', None)):  # Handles numpy arrays
        default_value = default_value.tolist()

    if nest.is_nested(default_value):
        if not _is_shape_and_default_value_compatible(default_value, shape):
            raise ValueError(
                'The shape of default_value must be equal to given shape. '
                'default_value: {}, shape: {}, key: {}'.format(
                    default_value, shape, key))
        # Check if the values in the list are all integers or are convertible
        # to floats.  Flatten once; both checks inspect the same leaves.
        flat_values = nest.flatten(default_value)
        is_list_all_int = all(isinstance(v, int) for v in flat_values)
        is_list_has_float = any(isinstance(v, float) for v in flat_values)
        if is_list_all_int:
            return _as_tuple(default_value)
        if is_list_has_float and dtype.is_floating:
            return _as_tuple(default_value)

    raise TypeError('default_value must be compatible with dtype. '
                    'default_value: {}, dtype: {}, key: {}'.format(
                        default_value, dtype, key))
def map_to_output_names(y_pred, output_names, struct):
    """Converts a dict keyed by output name into a list ordered like the outputs.

    This is purely a convenience.  When a `Model` has a single output or a
    flat list of outputs, per-output losses and metrics may be given as a
    dict keyed by output name.  If they are given in the same structure as
    the `Model`'s outputs (recommended), nothing is remapped.

    For the Functional API the output names are the names of the last layer
    of each output.  For the Subclass API they come from
    `create_pseudo_output_names` (for example `['output_1', 'output_2']` for
    a list of outputs).

    The mapping keeps `compile` and `fit` backwards compatible.

    Args:
        y_pred: Sample outputs of the Model, used to decide whether the
            convenience applies (`struct` is returned untouched unless
            `y_pred` is a single output or a flat list/tuple).
        output_names: List of the Model's output names.  When empty or
            `None`, pseudo names are generated from `y_pred`.
        struct: The structure to map.

    Returns:
        `struct` mapped to a list in the order of `output_names` (or the
        single mapped element when there is exactly one output); names absent
        from `struct` map to `None`.  `struct` itself is returned when no
        mapping applies.

    Raises:
        ValueError: If `struct` has keys that match no output name.
    """
    if nest.is_nested(y_pred):
        # Nested outputs qualify only when they form a flat list/tuple.
        mappable_outputs = isinstance(y_pred, (list, tuple)) and not any(
            nest.is_nested(y_p) for y_p in y_pred)
    else:
        mappable_outputs = True

    if not (mappable_outputs and isinstance(struct, dict)):
        return struct

    names = output_names or create_pseudo_output_names(y_pred)
    # Work on a copy so the caller's dict is left intact while popping.
    remaining = copy.copy(struct)
    mapped = [remaining.pop(name, None) for name in names]
    if remaining:
        raise ValueError('Found unexpected keys that do not correspond '
                         'to any Model output: {}. Expected: {}'.format(
                             remaining.keys(), names))
    return mapped[0] if len(mapped) == 1 else mapped
def convert_fn(path, value, dtype, dtype_hint, name=None):
    """Converts `value` to a tensor, refusing to pack structures of tensors.

    Args:
        path: Location of `value` in the enclosing structure; used only in
            the error message.
        value: Object to convert.
        dtype: Passed through to `tf.convert_to_tensor`.
        dtype_hint: Passed through to `tf.convert_to_tensor`.
        name: Optional name passed through to `tf.convert_to_tensor`.

    Returns:
        The result of `tf.convert_to_tensor`.

    Raises:
        NotImplementedError: If packing is not allowed and `value` is a
            nested structure containing a tensor or numpy array.
    """
    if not allow_packing and nest.is_nested(value):
        for leaf in nest.flatten(value):
            # Treat arrays like Tensors for full parity in JAX backend.
            if tf.is_tensor(leaf) or isinstance(leaf, np.ndarray):
                raise NotImplementedError(
                    ('Cannot convert a structure of tensors to a '
                     'single tensor. Saw {} at path {}.').format(value, path))
    return tf.convert_to_tensor(value, dtype, dtype_hint, name=name)
def one_step(self, new_chain_state, current_reducer_state,
             previous_kernel_results=None, axis=None):
    """Folds a new chain state into `current_reducer_state`.

    The `axis` parameter controls chunking.  When chunking is enabled
    (`axis` is not `None`), every element along `axis` counts as a separate
    sample.  A single non-nested `axis` is reused for every element of the
    transformed results; otherwise a matching structure must be supplied.

    Args:
        new_chain_state: A (possibly nested) structure of incoming chain
            state(s) with shape and dtype compatible with those used to
            initialize the `current_reducer_state`.
        current_reducer_state: `ExpectationsReducerState` holding the
            current reducer state.
        previous_kernel_results: A (possibly nested) structure of `Tensor`s
            representing internal calculations made in a related
            `TransitionKernel`.  When `None`, it is forwarded unchanged.
        axis: If chunking is desired, a (possibly nested) structure of
            integers naming the axis with chunked samples.  Use `None` for
            individual samples, which is the default.

    Returns:
        new_reducer_state: `ExpectationsReducerState` with updated running
            statistics. It tracks a running total and the number of
            processed samples.
    """
    scope = mcmc_util.make_name(self.name, 'expectations_reducer', 'one_step')
    with tf.name_scope(scope):
        new_chain_state = tf.nest.map_structure(tf.convert_to_tensor,
                                                new_chain_state)
        if previous_kernel_results is not None:
            previous_kernel_results = tf.nest.map_structure(
                tf.convert_to_tensor,
                previous_kernel_results,
                expand_composites=True)

        def _apply_transform(fn):
            return fn(new_chain_state, previous_kernel_results)

        fn_results = tf.nest.map_structure(_apply_transform, self.transform_fn)

        # A lone axis value applies to every transformed result.
        if not nest.is_nested(axis):
            axis = nest_util.broadcast_structure(fn_results, axis)

        updated_state = nest.map_structure(
            lambda result, stream, ax: stream.update(result, axis=ax),
            fn_results,
            current_reducer_state.expectation_state,
            axis,
            check_types=False)
        return ExpectationsReducerState(updated_state)
def _canonicalize_event_ndims(target, event_ndims):
    """Returns `event_ndims` laid out parallel to `target`, repeating as needed."""
    # Canonicalization exists only so that different Tensors in the target
    # structure may carry different event_ndims.  Without that, a plain
    # integer (or None) would do and no structure would be required.
    if nest.is_nested(event_ndims):
        return event_ndims
    return nest_util.broadcast_structure(target, event_ndims)
def is_empty(x):
    """Check whether a possibly nested structure is empty.

    Args:
        x: Any object or nested structure.

    Returns:
        `True` when `x` is a nested structure that contains no non-nested
        element at any depth (e.g. `[]`, `{}`, `[[], {}]`); `False` when `x`
        is not nested or contains at least one such element.
    """
    # `collections.Mapping` was removed in Python 3.10; the ABC lives in
    # `collections.abc`.  Imported locally so this block does not depend on
    # which aliases the file imports at module level.
    from collections import abc as collections_abc

    if not nest.is_nested(x):
        return False
    if isinstance(x, collections_abc.Mapping):
        # Only the values of a mapping count towards emptiness.
        return is_empty(list(x.values()))
    return all(is_empty(item) for item in x)
def is_empty(x):
    """Reports whether a possibly nested structure holds no leaf elements.

    A non-nested object is never empty.  A nested structure is empty when
    every one of its children (the values, for a mapping) is itself empty.
    """
    if not nest.is_nested(x):
        return False
    children = x.values() if isinstance(x, collections_abc.Mapping) else x
    return all(is_empty(child) for child in children)
def arg_is_blockwise(block_dimensions, arg, arg_split_dim): """Detect if input should be interpreted as a list of blocks.""" # Tuples and lists of length equal to the number of operators may be # blockwise. if (isinstance(arg, (tuple, list)) and len(arg) == len(block_dimensions)): # If the elements of the iterable are not nested, interpret the input as # blockwise. if not any(nest.is_nested(x) for x in arg): return True else: arg_dims = [ ops.convert_to_tensor_v2_with_dispatch(x).shape[arg_split_dim] for x in arg ] self_dims = [dim.value for dim in block_dimensions] # If none of the operator dimensions are known, interpret the input as # blockwise if its matching dimensions are unequal. if all(self_d is None for self_d in self_dims): # A nested tuple/list with a single outermost element is not blockwise if len(arg_dims) == 1: return False elif any(dim != arg_dims[0] for dim in arg_dims): return True else: raise ValueError( "Parsing of the input structure is ambiguous. Please input " "a blockwise iterable of `Tensor`s or a single `Tensor`." ) # If input dimensions equal the respective (known) blockwise operator # dimensions, then the input is blockwise. if all(self_d == arg_d or self_d is None for self_d, arg_d in zip(self_dims, arg_dims)): return True # If input dimensions equals are all equal, and are greater than or equal # to the sum of the known operator dimensions, interpret the input as # blockwise. # input is not blockwise. self_dim = sum(self_d for self_d in self_dims if self_d is not None) if all(s == arg_dims[0] for s in arg_dims) and arg_dims[0] >= self_dim: return False # If none of these conditions is met, the input shape is mismatched. raise ValueError( "Input dimension does not match operator dimension.") else: return False
def _scan(  # pylint: disable=unused-argument
        fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
        swap_memory=False, infer_shape=True, reverse=False, name=None):
    """Scan implementation.

    Repeatedly applies `fn(accumulator, element)` along the leading axis of
    `elems` and stacks every intermediate accumulator.

    Args:
        fn: Callable taking the accumulator (structured like `initializer`)
            and one slice of `elems` (structured like `elems`), returning the
            next accumulator.
        elems: Structure of sequences to scan over.  Must not be nested when
            `initializer` is `None`.
        initializer: Initial accumulator.  When `None`, the first element of
            `elems` is used and also becomes the first output.
        parallel_iterations: Unused.
        back_prop: Unused.
        swap_memory: Unused.
        infer_shape: Unused.
        reverse: If true, `elems` is scanned back to front and the outputs
            are reversed again before being returned.
        name: Unused.

    Returns:
        The stacked accumulators, packed into the structure of `initializer`.

    Raises:
        NotImplementedError: If `initializer` is `None` and `elems` is nested.
    """
    if reverse:
        elems = nest.map_structure(lambda x: x[::-1], elems)

    if initializer is None:
        if nest.is_nested(elems):
            raise NotImplementedError
        # The first element seeds the accumulator and is re-attached to the
        # front of the output below.
        initializer = elems[0]
        elems = elems[1:]
        prepend = [[initializer]]
    else:
        prepend = None

    # `func` works on flat lists so both backends can share it.
    def func(arg, x):
        return nest.flatten(
            fn(nest.pack_sequence_as(initializer, arg),
               nest.pack_sequence_as(elems, x)))

    arg = nest.flatten(initializer)
    if JAX_MODE:
        from jax import lax  # pylint: disable=g-import-not-at-top

        def scan_body(arg, x):
            arg = func(arg, x)
            return arg, arg

        _, out = lax.scan(scan_body, arg, nest.flatten(elems))
    else:
        # One output list per flattened accumulator component.
        out = [[] for _ in range(len(arg))]
        for x in zip(*nest.flatten(elems)):
            arg = func(arg, x)
            for i, z in enumerate(arg):
                out[i].append(z)

    if prepend is not None:
        out = [pre + list(o) for (pre, o) in zip(prepend, out)]

    # Undo the reversal applied to `elems` above.
    ordering = (lambda x: x[::-1]) if reverse else (lambda x: x)
    return nest.pack_sequence_as(initializer,
                                 [ordering(np.array(o)) for o in out])
def _conform_to_outputs(self, outputs, struct):
    """Reshapes `struct` so that it lines up with the structure of `outputs`.

    Three mappings are applied in order:
    (1) a dict is mapped to a list of outputs via the output names;
    (2) missing dict keys are filled with `None`;
    (3) a single item is repeated for every output.

    Args:
        outputs: Model predictions.
        struct: Arbitrary nested structure (e.g. of labels, sample_weights,
            losses, or metrics).

    Returns:
        `struct` mapped onto the structure of `outputs`.
    """
    conformed = map_to_output_names(outputs, self._output_names, struct)
    conformed = map_missing_dict_keys(outputs, conformed)

    if nest.is_nested(conformed) or not nest.is_nested(outputs):
        return conformed
    # Allow passing one object that applies to all outputs.
    return nest.map_structure(lambda _: conformed, outputs)
def serialize(self, make_node_key, node_conversion_map): """Serializes `Node` for Functional API's `get_config`.""" # Serialization still special-cases first argument. args, kwargs = self.call_args, self.call_kwargs inputs, args, kwargs = self.layer._split_out_first_arg(args, kwargs) # Treat everything other than first argument as a kwarg. arguments = dict(zip(self.layer._call_fn_args[1:], args)) arguments.update(kwargs) kwargs = arguments kwargs = nest.map_structure(_serialize_keras_tensor, kwargs) try: json.dumps(kwargs, default=json_utils.get_json_type) except TypeError: kwarg_types = nest.map_structure(type, kwargs) raise TypeError('Layer ' + self.layer.name + ' was passed non-JSON-serializable arguments. ' + 'Arguments had types: ' + str(kwarg_types) + '. They cannot be serialized out ' 'when saving the model.') # `kwargs` is added to each Tensor in the first arg. This should be # changed in a future version of the serialization format. def serialize_first_arg_tensor(t): if is_keras_tensor(t): kh = t._keras_history node_index = kh.node_index node_key = make_node_key(kh.layer.name, node_index) new_node_index = node_conversion_map.get(node_key, 0) data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs] else: # If an element in the first call argument did not originate as a # keras tensor and is a constant value, we save it using the format # ['_CONSTANT_VALUE', -1, serializaed_tensor_or_python_constant] # (potentially including serialized kwargs in an optional 4th argument data = [ _CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs ] return tf_utils.ListWrapper(data) data = nest.map_structure(serialize_first_arg_tensor, inputs) if (not nest.is_nested(data) and not self.layer._preserve_input_structure_in_config): data = [data] data = tf_utils.convert_inner_node_data(data) return data
def _is_shape_and_default_value_compatible(default_value, shape):
    """Checks that the nesting of `default_value` matches `shape`.

    Args:
        default_value: A scalar or a nested, indexable structure of values.
        shape: Sequence of dimension sizes.

    Returns:
        `True` when every level of `default_value` has the length given by
        the corresponding entry of `shape` and the nesting depth equals
        `len(shape)`; `False` otherwise.
    """
    # A nested value needs a non-empty shape, and a scalar needs an empty one.
    value_is_nested = nest.is_nested(default_value)
    if value_is_nested != bool(shape):
        return False
    if not shape:
        return True

    outer_dim = shape[0]
    if len(default_value) != outer_dim:
        return False

    inner_shape = shape[1:]
    return all(
        _is_shape_and_default_value_compatible(default_value[index], inner_shape)
        for index in range(outer_dim))
def _generate_zero_filled_state(batch_size_tensor, state_size, dtype): """Generate a zero filled tensor with shape [batch_size, state_size].""" if batch_size_tensor is None or dtype is None: raise ValueError( 'batch_size and dtype cannot be None while constructing initial state: ' 'batch_size={}, dtype={}'.format(batch_size_tensor, dtype)) def create_zeros(unnested_state_size): flat_dims = tensor_shape.TensorShape(unnested_state_size).as_list() init_state_size = [batch_size_tensor] + flat_dims return array_ops.zeros(init_state_size, dtype=dtype) if nest.is_nested(state_size): return nest.map_structure(create_zeros, state_size) else: return create_zeros(state_size)
def replace_composites_with_components(structure):
    """Recursively expands every CompositeTensor into its components.

    Args:
        structure: A `nest`-compatible structure, possibly containing
            composite tensors.

    Returns:
        A copy of `structure` in which each composite tensor has been
        replaced by its components, so the result contains no composite
        tensors.  Note that
        `nest.flatten(replace_composites_with_components(structure))`
        returns the same value as `nest.flatten(structure)`.
    """
    if isinstance(structure, CompositeTensor):
        # Components may themselves be composites, hence the recursion.
        components = structure._type_spec._to_components(structure)  # pylint: disable=protected-access
        return replace_composites_with_components(components)
    if nest.is_nested(structure):
        return nest.map_structure(
            replace_composites_with_components,
            structure,
            expand_composites=False)
    return structure
def convert_fn(path, value, dtype, dtype_hint, name=None):
    """Converts `value` to a (shape) tensor, refusing to pack tensor structures.

    Args:
        path: Location of `value` in the enclosing structure; used only in
            the error message.
        value: Object to convert.
        dtype: Passed through to the conversion function.
        dtype_hint: Passed through to the conversion function.
        name: Optional name passed through to the conversion function.

    Returns:
        The converted tensor, or `value` unchanged when its type name
        contains `'KerasTensor'` and no shape tensor was requested.

    Raises:
        NotImplementedError: If packing is not allowed and `value` is a
            nested structure containing a tensor or numpy array.
    """
    if not allow_packing and nest.is_nested(value):
        for leaf in nest.flatten(value):
            # Treat arrays like Tensors for full parity in JAX backend.
            if tf.is_tensor(leaf) or isinstance(leaf, np.ndarray):
                raise NotImplementedError(
                    ('Cannot convert a structure of tensors to a '
                     'single tensor. Saw {} at path {}.').format(value, path))

    if as_shape_tensor:
        return ps.convert_to_shape_tensor(value, dtype, dtype_hint, name=name)

    # This is a hack to detect symbolic Keras tensors to work around
    # b/206660667. The issue was that symbolic Keras tensors would
    # break the Bijector cache on forward/inverse log det jacobian,
    # because tf.convert_to_tensor is not a no-op thereon.
    if 'KerasTensor' in str(type(value)):
        return value

    return tf.convert_to_tensor(value, dtype, dtype_hint, name=name)
def match(self, expected, actual):
    """Asserts that two nested structures agree in type, shape and values.

    Numpy arrays are converted to lists first.  Nested containers (lists,
    dicts, ...) are compared element by element; `SparseTensorValue` and
    `RaggedTensorValue` are compared through their component arrays; any
    other value is compared with `assertEqual`.

    Args:
        expected: Nested structure 1.
        actual: Nested structure 2.

    Raises:
        AssertionError if matching fails.
    """
    if isinstance(expected, np.ndarray):
        expected = expected.tolist()
    if isinstance(actual, np.ndarray):
        actual = actual.tolist()
    self.assertEqual(type(expected), type(actual))

    if nest.is_nested(expected):
        self.assertEqual(len(expected), len(actual))
        if isinstance(expected, dict):
            # Walk both dicts in sorted key order so that keys pair up.
            for expected_key, actual_key in zip(sorted(expected), sorted(actual)):
                self.assertEqual(expected_key, actual_key)
                self.match(expected[expected_key], actual[actual_key])
        else:
            for expected_item, actual_item in zip(expected, actual):
                self.match(expected_item, actual_item)
        return

    if isinstance(expected, sparse_tensor.SparseTensorValue):
        expected_parts = (expected.indices, expected.values, expected.dense_shape)
        actual_parts = (actual.indices, actual.values, actual.dense_shape)
        self.match(expected_parts, actual_parts)
        return

    if isinstance(expected, ragged_tensor_value.RaggedTensorValue):
        self.match((expected.values, expected.row_splits),
                   (actual.values, actual.row_splits))
        return

    self.assertEqual(expected, actual)
def _generate_initial_state(self, inputs, batch_size_tensor, state_size, dtype): """Generate a zero filled tensor with shape [batch_size, state_size].""" if batch_size_tensor is None or dtype is None: raise ValueError( 'batch_size and dtype cannot be None while constructing initial state: ' 'batch_size={}, dtype={}'.format(batch_size_tensor, dtype)) def create_init_values(unnested_state_size): flat_dims = tensor_shape.TensorShape(unnested_state_size).as_list() init_state_size = [batch_size_tensor] + flat_dims if self.learned_init == 'dynamic': return inputs[:, 0, :] @ self.w + self.b elif self.learned_init == 'static': # Broadcast learned init vector to batch size return tf.broadcast_to(self.b, init_state_size) else: return array_ops.zeros(init_state_size, dtype=dtype) if nest.is_nested(state_size): return nest.map_structure(create_init_values, state_size) else: return create_init_values(state_size)
def _should_broadcast(self, obj):
    """Returns `True` only when `obj` is a single, non-nested object."""
    is_structure = nest.is_nested(obj)
    return not is_structure
def build_factored_surrogate_posterior(
        event_shape=None,
        constraining_bijectors=None,
        initial_unconstrained_loc=_sample_uniform_initial_loc,
        initial_unconstrained_scale=1e-2,
        trainable_distribution_fn=_build_trainable_normal_dist,
        seed=None,
        validate_args=False,
        name=None):
    """Builds a joint variational posterior that factors over model variables.

    By default, this method creates an independent trainable Normal
    distribution for each variable, transformed using a bijector (if provided)
    to match the support of that variable. This makes extremely strong
    assumptions about the posterior: that it is approximately normal (or
    transformed normal), and that all model variables are independent.

    Args:
      event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
        specifying the event shape(s) of the posterior variables.
      constraining_bijectors: Optional `tfb.Bijector` instance, or nested
        structure of such instances, defining support(s) of the posterior
        variables. The structure must match that of `event_shape` and may
        contain `None` values. A posterior variable will
        be modeled as `tfd.TransformedDistribution(underlying_dist,
        constraining_bijector)` if a corresponding constraining bijector is
        specified, otherwise it is modeled as supported on the
        unconstrained real line.
      initial_unconstrained_loc: Optional Python `callable` with signature
        `tensor = initial_unconstrained_loc(shape, seed)` used to sample
        real-valued initializations for the unconstrained representation of
        each variable. May alternately be a nested structure of
        `Tensor`s, giving specific initial locations for each variable; these
        must have structure matching `event_shape` and shapes determined by
        the inverse image of `event_shape` under `constraining_bijectors`,
        which may optionally be prefixed with a common batch shape.
        Default value: `functools.partial(tf.random.uniform,
          minval=-2., maxval=2., dtype=tf.float32)`.
      initial_unconstrained_scale: Optional scalar float `Tensor` initial
        scale for the unconstrained distributions, or a nested structure of
        `Tensor` initial scales for each variable.
        Default value: `1e-2`.
      trainable_distribution_fn: Optional Python `callable` with signature
        `trainable_dist = trainable_distribution_fn(initial_loc,
        initial_scale, event_ndims, validate_args)`. This is called for each
        model variable to build the corresponding factor in the surrogate
        posterior. It is expected that the distribution returned is
        supported on unconstrained real values.
        Default value: `functools.partial(
          tfp.experimental.vi.build_trainable_location_scale_distribution,
          distribution_fn=tfd.Normal)`, i.e., a trainable Normal distribution.
      seed: Python integer to seed the random number generator. This is used
        only when `initial_unconstrained_loc` is a callable.
      validate_args: Python `bool`. Whether to validate input with asserts.
        This imposes a runtime cost. If `validate_args` is `False`, and the
        inputs are invalid, correct behavior is not guaranteed.
        Default value: `False`.
      name: Python `str` name prefixed to ops created by this function.
        Default value: `None` (i.e., 'build_factored_surrogate_posterior').

    Returns:
      surrogate_posterior: A `tfd.Distribution` instance whose samples have
        shape and structure matching that of `event_shape`.

    ### Examples

    Consider a Gamma model with unknown parameters, expressed as a joint
    Distribution:

    ```python
    Root = tfd.JointDistributionCoroutine.Root
    def model_fn():
      concentration = yield Root(tfd.Exponential(1.))
      rate = yield Root(tfd.Exponential(1.))
      y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                           sample_shape=4)
    model = tfd.JointDistributionCoroutine(model_fn)
    ```

    Let's use variational inference to approximate the posterior over the
    data-generating parameters for some observed `y`. We'll build a
    surrogate posterior distribution by specifying the shapes of the latent
    `concentration` and `rate` parameters, and that both are constrained to
    be positive.

    ```python
    surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
      event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
      constraining_bijectors=[tfb.Softplus(),   # Concentration is positive.
                              tfb.Softplus()])  # Rate is positive.
    ```

    This creates a trainable joint distribution, defined by variables in
    `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
    to fit this distribution by minimizing a divergence to the true posterior.

    ```python
    y = [0.2, 0.5, 0.3, 0.7]
    losses = tfp.vi.fit_surrogate_posterior(
      lambda concentration, rate: model.log_prob([concentration, rate, y]),
      surrogate_posterior=surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)

    # After optimization, samples from the surrogate will approximate
    # samples from the true posterior.
    samples = surrogate_posterior.sample(100)
    posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
    posterior_std = [tf.math.reduce_std(x) for x in samples]  # std ~= [0.3, 0.8]
    ```

    If we wanted to initialize the optimization at a specific location, we can
    specify one when we build the surrogate posterior. This function requires
    the initial location to be specified in *unconstrained* space; we do this
    by inverting the constraining bijectors (note this section also
    demonstrates the creation of a dict-structured model).

    ```python
    initial_loc = {'concentration': 0.4, 'rate': 0.2}
    constraining_bijectors={'concentration': tfb.Softplus(),  # Concentration is positive.
                            'rate': tfb.Softplus()}           # Rate is positive.
    initial_unconstrained_loc = tf.nest.map_structure(
      lambda b, x: b.inverse(x) if b is not None else x,
      constraining_bijectors, initial_loc)
    surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
      event_shape=tf.nest.map_structure(tf.shape, initial_loc),
      constraining_bijectors=constraining_bijectors,
      initial_unconstrained_loc=initial_unconstrained_loc,
      initial_unconstrained_scale=1e-4)
    ```

    """

    with tf.name_scope(name or 'build_factored_surrogate_posterior'):
        seed = tfp_util.SeedStream(seed, salt='build_factored_surrogate_posterior')

        # Convert event shapes to Tensors.
        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(
            shallow_structure, lambda s: tf.convert_to_tensor(s, dtype=tf.int32),
            event_shape)
        flat_event_shapes = tf.nest.flatten(event_shape)

        # For simplicity, we'll work with flattened lists of state parts and
        # repack the structure at the end.
        if constraining_bijectors is not None:
            flat_bijectors = tf.nest.flatten(constraining_bijectors)
        else:
            flat_bijectors = [None for _ in flat_event_shapes]
        # Shapes in unconstrained space: the inverse image of each event shape
        # under its bijector, or the event shape itself without a bijector.
        flat_unconstrained_event_shapes = [
            b.inverse_event_shape_tensor(s) if b is not None else s
            for s, b in zip(flat_event_shapes, flat_bijectors)
        ]

        # Construct initial locations for the internal unconstrained dists.
        if callable(
                initial_unconstrained_loc):  # Sample random initialization.
            flat_unconstrained_locs = [
                initial_unconstrained_loc(shape=s, seed=seed())
                for s in flat_unconstrained_event_shapes
            ]
        else:  # Use provided initialization.
            flat_unconstrained_locs = nest.flatten_up_to(
                shallow_structure, initial_unconstrained_loc, check_types=False)

        # A non-nested scale is shared by every variable.
        if nest.is_nested(initial_unconstrained_scale):
            flat_unconstrained_scales = nest.flatten_up_to(
                shallow_structure, initial_unconstrained_scale, check_types=False)
        else:
            flat_unconstrained_scales = [
                initial_unconstrained_scale for _ in flat_unconstrained_locs
            ]

        # Extract the rank of each event, so that we build distributions with the
        # correct event shapes.
        flat_unconstrained_event_ndims = [
            prefer_static.rank_from_shape(s)
            for s in flat_unconstrained_event_shapes
        ]

        # Build the component surrogate posteriors.
        flat_component_dists = []
        for initial_loc, initial_scale, event_ndims, bijector in zip(
                flat_unconstrained_locs, flat_unconstrained_scales,
                flat_unconstrained_event_ndims, flat_bijectors):
            unconstrained_dist = trainable_distribution_fn(
                initial_loc=initial_loc,
                initial_scale=initial_scale,
                event_ndims=event_ndims,
                validate_args=validate_args)
            # Calling a bijector on the distribution yields the transformed
            # (constrained) factor.
            flat_component_dists.append(
                bijector(unconstrained_dist
                         ) if bijector is not None else unconstrained_dist)
        component_distributions = tf.nest.pack_sequence_as(
            event_shape, flat_component_dists)

        # Return a `Distribution` object whose events have the specified structure.
        return (joint_distribution_util.
                independent_joint_distribution_from_structure(
                    component_distributions, validate_args=validate_args))
def _affine_surrogate_posterior(event_shape,
                                operators='diag',
                                bijector=None,
                                base_distribution=normal.Normal,
                                dtype=tf.float32,
                                batch_shape=(),
                                validate_args=False,
                                name=None):
  """Builds a joint variational posterior with a given `event_shape`.

  The surrogate posterior is obtained by applying a trainable affine
  transformation to a standard base distribution and then constraining the
  samples with `bijector`. Its event shape equals the input `event_shape`.

  This is a convenience wrapper around
  `build_affine_surrogate_posterior_from_base_distribution`: the caller
  specifies the desired posterior `event_shape` rather than pre-constructed
  base distributions, giving up full control over the base distribution types
  and parameterizations in exchange.

  Args:
    event_shape: (Nested) event shape of the posterior.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to
      create a mean-field surrogate posterior) and "tril" (to create a
      full-covariance surrogate posterior). A list/tuple may be passed to
      induce other posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector`
      of the distribution over the prior latent variables.) `None` entries in
      a nested structure are replaced by identity bijectors.
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc`
      and `scale`. The base distribution of the transformed surrogate has
      `loc=0.` and `scale=1.`.
      Default value: `tfd.Normal`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    batch_shape: Batch shape (Python tuple, list, or int) of the surrogate
      posterior, to enable parallel optimization from multiple
      initializations.
      Default value: `()`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs
      are invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_affine_surrogate_posterior').

  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  Returns:
    surrogate_posterior: The distribution returned by
      `_affine_surrogate_posterior_from_base_distribution` for the standard
      base distributions constructed here.

  #### Examples

  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Define a joint probabilistic model.
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(
        tfd.Gamma(concentration=concentration, rate=rate),
        sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)

  # Assume the `y` are observed, such that the posterior is a joint
  # distribution over `concentration` and `rate`. The posterior event shape is
  # then equal to the first two components of the model's event shape.
  posterior_event_shape = model.event_shape_tensor()[:-1]

  # Constrain the posterior values to be positive using the `Exp` bijector.
  bijector = [tfb.Exp(), tfb.Exp()]

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior(
        event_shape=posterior_event_shape,
        operators='tril',
        bijector=bijector))

  # For an example defining `'operators'` as a list to express an alternative
  # covariance structure, see
  # `build_affine_surrogate_posterior_from_base_distribution`.

  # Fit the model.
  y = [0.2, 0.5, 0.3, 0.7]
  target_model = model.experimental_pin(y=y)
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```
  """

  def _to_int32_tensor(shape):
    return tf.convert_to_tensor(shape, dtype=tf.int32)

  def _identity_if_none(maybe_bijector):
    return identity.Identity() if maybe_bijector is None else maybe_bijector

  def _standard_component(unused_shape):
    # One scalar standard distribution per event part; the part's shape is
    # applied in a separate pass below.
    return base_distribution(loc=tf.zeros([], dtype=dtype), scale=1.)

  def _with_event_shape(dist, shape):
    if distribution_util.shape_may_be_nontrivial(shape):
      return sample.Sample(
          dist, sample_shape=shape, validate_args=validate_args)
    return dist

  def _with_batch_shape(dist):
    return batch_broadcast.BatchBroadcast(
        dist, to_shape=batch_shape, validate_args=validate_args)

  with tf.name_scope(name or 'build_affine_surrogate_posterior'):
    # Convert event shapes to int32 Tensors.
    shallow_structure = _get_event_shape_shallow_structure(event_shape)
    event_shape = nest.map_structure_up_to(
        shallow_structure, _to_int32_tensor, event_shape)

    # A structure of bijectors is wrapped into a single joint bijector.
    if nest.is_nested(bijector):
      bijector = joint_map.JointMap(
          nest.map_structure(_identity_if_none, bijector),
          validate_args=validate_args)

    unconstrained_event_shape = (
        event_shape if bijector is None
        else bijector.inverse_event_shape_tensor(event_shape))

    # Build the standard base distributions in the unconstrained space.
    components = nest.map_structure(
        _standard_component, unconstrained_event_shape)
    components = nest.map_structure(
        _with_event_shape, components, unconstrained_event_shape)
    if distribution_util.shape_may_be_nontrivial(batch_shape):
      components = nest.map_structure(_with_batch_shape, components)

    return (yield from _affine_surrogate_posterior_from_base_distribution(
        components,
        operators=operators,
        bijector=bijector,
        validate_args=validate_args))
def _affine_surrogate_posterior_from_base_distribution(
    base_distribution,
    operators='diag',
    bijector=None,
    initial_unconstrained_loc_fn=_sample_uniform_initial_loc,
    validate_args=False,
    name=None):
  """Builds a variational posterior by linearly transforming base distributions.

  This function builds a surrogate posterior by applying a trainable
  transformation to a base distribution (typically a `tfd.JointDistribution`)
  or nested structure of base distributions, and constraining the samples with
  `bijector`. Note that the distributions must have event shapes corresponding
  to the *pretransformed* surrogate posterior -- that is, if `bijector`
  contains a shape-changing bijector, then the corresponding base distribution
  event shape is the inverse event shape of the bijector applied to the desired
  surrogate posterior shape.

  The surrogate posterior is constructed as follows:

  1. Flatten the base distribution event shapes to vectors, and pack the base
     distributions into a `tfd.JointDistribution`.
  2. Apply a trainable blockwise LinearOperator bijector to the joint base
     distribution.
  3. Apply the constraining bijectors and return the resulting trainable
     `tfd.TransformedDistribution` instance.

  Args:
    base_distribution: `tfd.Distribution` instance (typically a
      `tfd.JointDistribution`), or a nested structure of `tfd.Distribution`
      instances.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to
      create a mean-field surrogate posterior) and "tril" (to create a
      full-covariance surrogate posterior). A list/tuple may be passed to
      induce other posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector`
      of the distribution over the prior latent variables.) `None` entries in
      a nested structure are replaced by identity bijectors.
      Default value: `None` (i.e., the posterior is over R^n).
    initial_unconstrained_loc_fn: Optional Python `callable` with signature
      `initial_loc = initial_unconstrained_loc_fn(shape, dtype, seed)` used to
      sample real-valued initializations for the unconstrained location of
      each variable. Here it is partially applied to the shape and `dtype`
      and wrapped in a yielded `Parameter`; the remaining arguments are
      presumably supplied by whatever consumes the parameter.
      Default value: `_sample_uniform_initial_loc` (defined elsewhere in this
      file; previously documented as uniform sampling on `[-2., 2.]` --
      TODO confirm).
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs
      are invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
      'affine_surrogate_posterior_from_base_distribution').

  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  Returns:
    surrogate_posterior: `tfd.TransformedDistribution` over the flattened base
      distribution, transformed by the chained scale, shift, reshape and
      (optional) constraining bijectors.

  Raises:
    NotImplementedError: Base distributions with mixed dtypes are not
      supported.
    ValueError: If `operators` is a string other than "diag" or "tril".

  #### Examples

  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Fit a multivariate Normal surrogate posterior on the Eight Schools model
  # [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]

  def model_fn():
    avg_effect = yield tfd.Normal(loc=0., scale=10., name='avg_effect')
    log_stddev = yield tfd.Normal(loc=5., scale=1., name='log_stddev')
    school_effects = yield tfd.Sample(
        tfd.Normal(loc=avg_effect, scale=tf.exp(log_stddev)),
        sample_shape=[8],
        name='school_effects')
    treatment_effects = yield tfd.Independent(
        tfd.Normal(loc=school_effects, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1,
        name='treatment_effects')
  model = tfd.JointDistributionCoroutineAutoBatched(model_fn)

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Define a lower triangular structure of `LinearOperator` subclasses that
  # models full covariance among latent variables except for the 8 dimensions
  # of `school_effects`, which are modeled as independent (using
  # `LinearOperatorDiag`).
  operators = [
    [tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorDiag]]

  # Constrain the posterior values to the support of the prior.
  bijector = target_model.experimental_default_event_space_bijector()

  # `base_distribution` must be defined by the caller: a joint distribution
  # (or structure of distributions) whose event shapes are those of the
  # unconstrained latent variables.

  # Build the surrogate posterior with the block structure defined above.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
        base_distribution=base_distribution,
        operators=operators,
        bijector=bijector))

  # Fit the model.
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
  scope_name = name or 'affine_surrogate_posterior_from_base_distribution'
  with tf.name_scope(scope_name):

    # Normalize structured inputs into single joint objects.
    if nest.is_nested(base_distribution):
      base_distribution = (
          joint_distribution_util.independent_joint_distribution_from_structure(
              base_distribution, validate_args=validate_args))

    if nest.is_nested(bijector):
      filled_bijectors = nest.map_structure(
          lambda b: identity.Identity() if b is None else b, bijector)
      bijector = joint_map.JointMap(
          filled_bijectors, validate_args=validate_args)

    batch_shape = base_distribution.batch_shape_tensor()
    if tf.nest.is_nested(batch_shape):
      # Base is a classic JointDistribution: broadcast the per-part batch
      # shapes into a single shape.
      batch_shape = functools.reduce(
          ps.broadcast_shape, tf.nest.flatten(batch_shape))

    event_shape = base_distribution.event_shape_tensor()
    part_sizes = nest.map_structure(ps.reduce_prod, event_shape)
    flat_event_size = nest.flatten(part_sizes)
    num_components = len(flat_event_size)

    base_dtypes = {
        dtype_util.base_dtype(d)
        for d in nest.flatten(base_distribution.dtype)
    }
    if len(base_dtypes) > 1:
      raise NotImplementedError(
          'Base distributions with mixed dtype are not supported. Saw '
          'components of dtype {}'.format(base_dtypes))
    base_dtype = list(base_dtypes)[0]

    # Expand the string shorthands into explicit operator structures.
    if operators == 'diag':
      operators = [tf.linalg.LinearOperatorDiag] * num_components
    elif operators == 'tril':
      operators = []
      for row in range(num_components):
        operators.append(
            [tf.linalg.LinearOperatorFullMatrix] * row +
            [tf.linalg.LinearOperatorLowerTriangular])
    elif isinstance(operators, str):
      raise ValueError(
          'Unrecognized operator type {}. Valid operators are "diag", "tril", '
          'or a structure that can be passed to '
          '`tfp.experimental.vi.util.build_trainable_linear_operator_block` as '
          'the `operators` arg.'.format(operators))

    if nest.is_nested(operators):
      operators = yield from (
          trainable_linear_operators._trainable_linear_operator_block(  # pylint: disable=protected-access
              operators,
              block_dims=flat_event_size,
              dtype=base_dtype,
              batch_shape=batch_shape))

    # Apply LinOp to scale the standard dist.
    linop_bijector = scale_matvec_linear_operator.ScaleMatvecLinearOperatorBlock(
        scale=operators, validate_args=validate_args)

    def _trainable_shift(size):
      # Yield a trainable location for one flattened event part.
      loc_shape = ps.concat([batch_shape, [size]], axis=0)
      loc = yield trainable_state_util.Parameter(
          functools.partial(
              initial_unconstrained_loc_fn, loc_shape, dtype=base_dtype))
      return shift.Shift(loc)

    # Allow the mean of the standard dist to shift from 0.
    loc_bijectors = yield from nest_util.map_structure_coroutine(
        _trainable_shift, flat_event_size)
    loc_bijector = joint_map.JointMap(
        loc_bijectors, validate_args=validate_args)

    # Maps the flat list of vectors back to the base event structure/shapes.
    reshape_parts = joint_map.JointMap(
        nest.map_structure(reshape.Reshape, event_shape),
        validate_args=validate_args)
    unflatten_parts = restructure.Restructure(
        nest.pack_sequence_as(event_shape, range(num_components)))
    unflatten_and_reshape = chain.Chain(
        [reshape_parts, unflatten_parts], validate_args=validate_args)

    stages = [unflatten_and_reshape, loc_bijector, linop_bijector]
    if bijector is not None:
      stages = [bijector] + stages
    full_bijector = chain.Chain(stages, validate_args=validate_args)

    flat_base_distribution = invert.Invert(unflatten_and_reshape)(
        base_distribution)

    return transformed_distribution.TransformedDistribution(
        flat_base_distribution,
        bijector=full_bijector,
        validate_args=validate_args)
def _factored_surrogate_posterior(  # pylint: disable=dangerous-default-value
    event_shape=None,
    bijector=None,
    batch_shape=(),
    base_distribution_cls=normal.Normal,
    initial_parameters={'scale': 1e-2},
    dtype=tf.float32,
    validate_args=False,
    name=None):
  """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    bijector: Optional `tfb.Bijector` instance, or nested structure of such
      instances, defining support(s) of the posterior variables. The structure
      must match that of `event_shape` and may contain `None` values. A
      posterior variable will be modeled as
      `tfd.TransformedDistribution(underlying_dist, bijector)` if a
      corresponding constraining bijector is specified, otherwise it is
      modeled as supported on the unconstrained real line.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    base_distribution_cls: Subclass of `tfd.Distribution` that is instantiated
      and optionally transformed by the bijector to define the component
      distributions. May optionally be a structure of such subclasses
      matching `event_shape`.
      Default value: `tfd.Normal`.
    initial_parameters: Optional `str : Tensor` dictionary specifying initial
      values for some or all of the base distribution's trainable parameters,
      or a Python `callable` with signature
      `value = parameter_init_fn(parameter_name, shape, dtype, seed,
      constraining_bijector)`, passed to `tfp.experimental.util.make_trainable`.
      May optionally be a structure matching `event_shape` of such dictionaries
      and/or callables. Dictionary entries that do not correspond to parameter
      names are ignored.
      Default value: `{'scale': 1e-2}` (ignored when `base_distribution_cls`
        does not have a `scale` parameter).
    dtype: Optional float `dtype` for trainable parameters. May
      optionally be a structure of such `dtype`s matching `event_shape`.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').

  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  Returns:
    surrogate_posterior: The joint distribution over unconstrained values if
      no bijector was given, otherwise a `tfd.TransformedDistribution`
      wrapping it with the constraining bijector.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `concentration` and `rate` parameters, and that both are constrained to be
  positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    bijector=[tfb.Softplus(),   # Concentration is positive.
              tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda concentration, rate: model.log_prob([concentration, rate, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify initial parameters when we build the surrogate posterior. Note that
  these parameterize the distribution(s) over unconstrained values,
  so we need to transform our desired constrained locations using the inverse
  of the constraining bijector(s).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    bijector={'concentration': tfb.Softplus(),   # Concentration is positive.
              'rate': tfb.Softplus()},           # Rate is positive.
    initial_parameters={
      'concentration': {'loc': tfb.Softplus().inverse(0.4), 'scale': 1e-2},
      'rate': {'loc': tfb.Softplus().inverse(0.2), 'scale': 1e-2}})
  ```

  """
  # NOTE(review): the dict default for `initial_parameters` is shared across
  # calls. This function never mutates it; presumably
  # `trainable._make_trainable` only reads it as well -- confirm before
  # relying on that.
  with tf.name_scope(name or 'build_factored_surrogate_posterior'):
    # Convert event shapes to Tensors.
    shallow_structure = _get_event_shape_shallow_structure(event_shape)
    event_shape = nest.map_structure_up_to(
        shallow_structure, lambda s: tf.convert_to_tensor(s, dtype=tf.int32),
        event_shape)

    if nest.is_nested(bijector):
      # Coerce the bijector structure to that of `event_shape` and fill any
      # missing (`None`) entries with identity bijectors.
      event_space_bijector = joint_map.JointMap(
          nest.map_structure(
              lambda b: identity.Identity() if b is None else b,
              nest_util.coerce_structure(event_shape, bijector)),
          validate_args=validate_args)
    else:
      event_space_bijector = bijector

    if event_space_bijector is None:
      unconstrained_event_shape = event_shape
    else:
      unconstrained_event_shape = (
          event_space_bijector.inverse_event_shape_tensor(event_shape))
    unconstrained_batch_and_event_shape = tf.nest.map_structure(
        lambda s: ps.concat([batch_shape, s], axis=0),
        unconstrained_event_shape)

    base_distribution_cls = nest_util.broadcast_structure(
        event_shape, base_distribution_cls)
    try:
      # Check that we have initial parameters for each event part.
      nest.assert_shallow_structure(event_shape, initial_parameters)
    except (ValueError, TypeError):
      # If not, broadcast the parameters to match the event structure.
      # We do this manually rather than using `nest_util.broadcast_structure`
      # because the initial parameters can themselves be structures (dicts).
      shared_parameters = initial_parameters
      initial_parameters = nest.map_structure(
          lambda unused_shape: shared_parameters, event_shape)

    unconstrained_trainable_distributions = yield from (
        nest_util.map_structure_coroutine(
            trainable._make_trainable,  # pylint: disable=protected-access
            cls=base_distribution_cls,
            initial_parameters=initial_parameters,
            batch_and_event_shape=unconstrained_batch_and_event_shape,
            parameter_dtype=nest_util.broadcast_structure(event_shape, dtype),
            _up_to=event_shape))
    unconstrained_trainable_distribution = (
        joint_distribution_util.independent_joint_distribution_from_structure(
            unconstrained_trainable_distributions,
            batch_ndims=ps.rank_from_shape(batch_shape),
            validate_args=validate_args))
    if event_space_bijector is None:
      return unconstrained_trainable_distribution
    # Propagate `validate_args` to the outer distribution as well, matching
    # the other surrogate-posterior builders in this file.
    return transformed_distribution.TransformedDistribution(
        unconstrained_trainable_distribution,
        event_space_bijector,
        validate_args=validate_args)
def build_split_flow_surrogate_posterior(event_shape,
                                         trainable_bijector,
                                         constraining_bijector=None,
                                         base_distribution=normal.Normal,
                                         batch_shape=(),
                                         dtype=tf.float32,
                                         validate_args=False,
                                         name=None):
  """Builds a joint variational posterior by splitting a normalizing flow.

  A single vector-valued base distribution is transformed by
  `trainable_bijector`, split into one vector per posterior component,
  reshaped and restructured to the (unconstrained) event structure, and
  finally constrained by `constraining_bijector` if one is given.

  Args:
    event_shape: (Nested) event shape of the surrogate posterior.
    trainable_bijector: A trainable `tfb.Bijector` instance that operates on
      `Tensor`s (not structures), e.g. `tfb.MaskedAutoregressiveFlow` or
      `tfb.RealNVP`. This bijector transforms the base distribution before it
      is split.
    constraining_bijector: `tfb.Bijector` instance, or nested structure of
      `tfb.Bijector` instances, that maps (nested) values in R^n to the
      support of the posterior. (This can be the
      `experimental_default_event_space_bijector` of the distribution over
      the prior latent variables.) `None` entries in a nested structure are
      replaced by identity bijectors.
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc`
      and `scale`. The base distribution for the transformed surrogate has
      `loc=0.` and `scale=1.`.
      Default value: `tfd.Normal`.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs
      are invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_split_flow_surrogate_posterior').

  Returns:
    surrogate_distribution: Trainable `tfd.TransformedDistribution` with event
      shape equal to `event_shape`.

  ### Examples

  ```python

  # Train a normalizing flow on the Eight Schools model [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]
  model = tfd.JointDistributionNamed({
      'avg_effect':
          tfd.Normal(loc=0., scale=10., name='avg_effect'),
      'log_stddev':
          tfd.Normal(loc=5., scale=1., name='log_stddev'),
      'school_effects':
          lambda log_stddev, avg_effect: (
              tfd.Independent(
                  tfd.Normal(
                      loc=avg_effect[..., None] * tf.ones(8),
                      scale=tf.exp(log_stddev[..., None]) * tf.ones(8),
                      name='school_effects'),
                  reinterpreted_batch_ndims=1)),
      'treatment_effects': lambda school_effects: tfd.Independent(
          tfd.Normal(loc=school_effects, scale=treatment_stddevs),
          reinterpreted_batch_ndims=1)
  })

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Create a Masked Autoregressive Flow bijector.
  net = tfb.AutoregressiveNetwork(2, hidden_units=[16, 16], dtype=tf.float32)
  maf = tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=net)

  # Build and fit the surrogate posterior.
  surrogate_posterior = (
      tfp.experimental.vi.build_split_flow_surrogate_posterior(
          event_shape=target_model.event_shape_tensor(),
          trainable_bijector=maf,
          constraining_bijector=(
              target_model.experimental_default_event_space_bijector())))

  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
  with tf.name_scope(name or 'build_split_flow_surrogate_posterior'):

    # Convert event shapes to shape Tensors.
    event_shape = nest.map_structure_up_to(
        _get_event_shape_shallow_structure(event_shape),
        ps.convert_to_shape_tensor,
        event_shape)

    # A structure of bijectors is wrapped into a single joint bijector.
    if nest.is_nested(constraining_bijector):
      filled_bijectors = nest.map_structure(
          lambda b: identity.Identity() if b is None else b,
          constraining_bijector)
      constraining_bijector = joint_map.JointMap(
          filled_bijectors, validate_args=validate_args)

    unconstrained_event_shape = (
        event_shape if constraining_bijector is None
        else constraining_bijector.inverse_event_shape_tensor(event_shape))

    # The base distribution is one vector holding every component's elements.
    flat_part_shapes = nest.flatten(unconstrained_event_shape)
    flat_part_sizes = [tf.reduce_prod(shape) for shape in flat_part_shapes]
    total_size = tf.reduce_sum(flat_part_sizes)

    vector_base = sample.Sample(
        base_distribution(tf.zeros(batch_shape, dtype=dtype), scale=1.),
        [total_size])

    # After transforming base distribution samples with `trainable_bijector`,
    # split them into vector-valued components.
    split_bijector = split.Split(flat_part_sizes, validate_args=validate_args)

    # Reshape the vectors to the correct posterior event shape.
    event_reshape = joint_map.JointMap(
        nest.map_structure(reshape.Reshape, unconstrained_event_shape),
        validate_args=validate_args)

    # Restructure the flat list of components to the correct posterior
    # structure.
    part_indices = range(len(flat_part_shapes))
    event_unflatten = restructure.Restructure(
        nest.pack_sequence_as(unconstrained_event_shape, part_indices))

    # `Chain` lists bijectors outermost-first: the flow is applied first and
    # the constraining bijector (if any) last.
    stages = [event_reshape, event_unflatten, split_bijector,
              trainable_bijector]
    if constraining_bijector is not None:
      stages = [constraining_bijector] + stages
    flow = chain.Chain(stages, validate_args=validate_args)

    return transformed_distribution.TransformedDistribution(
        vector_base, bijector=flow, validate_args=validate_args)