def update_i_event_ndims(bij, event_ndims):
  event_ndims = nest_util.coerce_structure(
      bij.forward_min_event_ndims, event_ndims)
  if bij.has_static_min_event_ndims:
    return update_event_ndims(
        input_event_ndims=event_ndims,
        input_min_event_ndims=bij.forward_min_event_ndims,
        output_min_event_ndims=bij.inverse_min_event_ndims)
  elif isinstance(bij, composition.Composition):
    return bij._call_walk_forward(update_i_event_ndims, event_ndims)  # pylint: disable=protected-access
  else:
    return sanitize_event_ndims(bij.forward_event_ndims(event_ndims))
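# A minimal illustrative sketch (not part of the library): for a bijector with
# static `min_event_ndims`, `update_event_ndims` effectively carries the
# "excess" rank above the input minimum over to the output side. The
# scalar-only helper below is hypothetical and ignores nested structures.
def _example_update_event_ndims(input_event_ndims,
                                input_min_event_ndims,
                                output_min_event_ndims):
  excess_rank = input_event_ndims - input_min_event_ndims
  return output_min_event_ndims + excess_rank

# E.g., a matrix bijector (min_event_ndims: 2 -> 2) called with event_ndims=3
# treats one extra dimension as part of the event on both sides:
# _example_update_event_ndims(3, 2, 2) == 3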
def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
  """Compute inverse_log_det_jacobian over the composition."""
  with self._name_and_control_scope(name):
    dtype = self.forward_dtype(**kwargs)
    y = nest_util.convert_to_nested_tensor(
        y, name='y', dtype_hint=dtype,
        dtype=None if bijector.SKIP_DTYPE_CHECKS else dtype,
        allow_packing=True)
    if event_ndims is None:
      if self._has_static_min_event_ndims:
        event_ndims = self.inverse_min_event_ndims
      else:
        raise ValueError('Composition bijector with non-static '
                         '`min_event_ndims` does not support '
                         '`event_ndims=None`. Please pass a value '
                         'for `event_ndims`.')
    event_ndims = nest_util.coerce_structure(
        self.inverse_min_event_ndims, event_ndims)
    return self._inverse_log_det_jacobian(y, event_ndims, **kwargs)
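# Hedged usage sketch, assuming only the public `tfp.bijectors` API: `Chain`
# is a `Composition` with static `min_event_ndims`, so passing
# `event_ndims=None` falls back to `inverse_min_event_ndims` (0 here) via the
# branch above.
import tensorflow_probability as tfp

tfb = tfp.bijectors
chain = tfb.Chain([tfb.Scale(2.), tfb.Exp()])
ildj = chain.inverse_log_det_jacobian(3., event_ndims=None)  # Scalar events.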
def _update_forward_min_event_ndims(
    bij,
    downstream_quantities,
    get_forward_min_event_ndims=lambda b: b.forward_min_event_ndims,
    get_inverse_min_event_ndims=lambda b: b.inverse_min_event_ndims,
    inverse_event_ndims_fn=lambda b, nd: b.inverse_event_ndims(nd)):
  """Step backwards through the graph to infer `forward_min_event_ndims`.

  Args:
    bij: local tfb.Bijector instance at the current graph node.
    downstream_quantities: Instance of the
      `MinEventNdimsInferenceDownstreamQuantities` namedtuple, containing
      event_ndims that satisfy the bijector(s) downstream from `bij` in the
      graph. May be `None` if there are no such bijectors.
    get_forward_min_event_ndims: callable; may be overridden to swap
      forward/inverse direction.
    get_inverse_min_event_ndims: callable; may be overridden to swap
      forward/inverse direction.
    inverse_event_ndims_fn: callable; may be overridden to swap
      forward/inverse direction.

  Returns:
    downstream_quantities: Instance of the
      `MinEventNdimsInferenceDownstreamQuantities` namedtuple containing
      event_ndims that satisfy `bij` and all downstream bijectors.
  """
  if downstream_quantities is None:  # This is a leaf bijector.
    return MinEventNdimsInferenceDownstreamQuantities(
        forward_min_event_ndims=get_forward_min_event_ndims(bij),
        parts_interact=bij._parts_interact)  # pylint: disable=protected-access

  inverse_min_event_ndims = get_inverse_min_event_ndims(bij)
  downstream_min_event_ndims = nest_util.coerce_structure(
      inverse_min_event_ndims,
      downstream_quantities.forward_min_event_ndims)

  # Update the min_event_ndims that is a valid input to downstream bijectors
  # to also be a valid *output* of this bijector, or equivalently, a valid
  # input to `bij.inverse`.
  rank_mismatches = tf.nest.flatten(
      tf.nest.map_structure(
          lambda dim, min_dim: dim - min_dim,
          downstream_min_event_ndims,
          inverse_min_event_ndims))
  if downstream_quantities.parts_interact:
    # If downstream bijectors involve interaction between parts, then a valid
    # input to the downstream bijectors must augment the
    # `downstream_min_event_ndims` by the same rank for every part (otherwise
    # we would induce event shape broadcasting). Hopefully, this will also
    # avoid event-shape broadcasting at the current bijector; if not, the
    # composition is invalid, and the call to
    # `bij.inverse_event_ndims(valid_inverse_min_event_ndims)` below will
    # raise an exception.
    maximum_rank_deficiency = -ps.reduce_min([0] + rank_mismatches)
    valid_inverse_min_event_ndims = tf.nest.map_structure(
        lambda ndims: maximum_rank_deficiency + ndims,
        downstream_min_event_ndims)
  else:
    if bij._parts_interact:  # pylint: disable=protected-access
      # If this bijector does *not* operate independently on its parts, then a
      # valid input to `inverse` cannot require event shape broadcasting. That
      # is, each part must have the same 'excess rank' above the local
      # inverse_min_event_ndims; we ensure this by construction.
      maximum_excess_rank = ps.reduce_max([0] + rank_mismatches)
      valid_inverse_min_event_ndims = tf.nest.map_structure(
          lambda ndims: maximum_excess_rank + ndims,
          inverse_min_event_ndims)
    else:
      # If all parts are independent, we can take the pointwise max
      # event_ndims.
      valid_inverse_min_event_ndims = tf.nest.map_structure(
          ps.maximum, downstream_min_event_ndims, inverse_min_event_ndims)

  return MinEventNdimsInferenceDownstreamQuantities(
      # Pull the desired output ndims back through the bijector, to get
      # the ndims of a valid *input*.
      forward_min_event_ndims=inverse_event_ndims_fn(
          bij, valid_inverse_min_event_ndims),
      parts_interact=(
          downstream_quantities.parts_interact or
          bij._parts_interact))  # pylint: disable=protected-access
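# Hedged toy illustration of the branches above, using plain Python ints in
# place of `ps` ops (part names 'a'/'b' are hypothetical). With no interaction
# anywhere, the valid inverse ndims are the pointwise maximum; with
# interaction, every downstream part is padded by the same amount so that no
# event shape broadcasting is induced.
downstream_min_event_ndims = {'a': 1, 'b': 0}
inverse_min_event_ndims = {'a': 0, 'b': 2}

# Independent parts: pointwise max -> {'a': 1, 'b': 2}.
pointwise = {k: max(downstream_min_event_ndims[k], inverse_min_event_ndims[k])
             for k in inverse_min_event_ndims}

# Interacting parts downstream: pad every downstream ndims by the maximum
# rank deficiency (here 2, from part 'b') -> {'a': 3, 'b': 2}.
rank_mismatches = [downstream_min_event_ndims[k] - inverse_min_event_ndims[k]
                   for k in inverse_min_event_ndims]
maximum_rank_deficiency = -min([0] + rank_mismatches)
padded = {k: downstream_min_event_ndims[k] + maximum_rank_deficiency
          for k in downstream_min_event_ndims}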
def _call_walk_inverse(self, step_fn, *args, **kwargs):
  """Prepares args and calls `_walk_inverse`.

  Converts a tuple of structured positional arguments to a structure of
  argument tuples, and wraps `step_fn` to unpack inputs and re-pack returned
  values. This way, users may invoke walks using `map_structure` semantics,
  and the concrete `_walk` implementations can operate on a single structure
  of inputs (without worrying about tuple unpacking).

  For example, the `inverse` method looks roughly like:

  ```python
  MyComposition()._call_walk_inverse(
      lambda bij, y, **kwargs: bij.inverse(y, **kwargs),
      composite_inputs, **composite_kwargs)
  ```

  More complex methods may need to mutate external state from `step_fn`:

  ```python
  shape_trace = {}

  def trace_step(bijector, y_shape):
    shape_trace[bijector.name] = y_shape
    return bijector.inverse_event_shape(y_shape)

  # Calling this populates the `shape_trace` dictionary.
  composition.walk_inverse(trace_step, composite_y_shape)
  ```

  Args:
    step_fn: Callable applied to each wrapped bijector. Must accept a bijector
      instance followed by `len(args)` positional arguments whose structures
      match `bijector.inverse_min_event_ndims`, and return `len(args)`
      structures matching `bijector.forward_min_event_ndims`.
    *args: Input arguments propagated to nested bijectors.
    **kwargs: Keyword arguments forwarded to `_walk_inverse`.

  Returns:
    The transformed output. If multiple positional arguments are provided, a
    tuple of matching length will be returned.
  """
  args = tuple(nest_util.coerce_structure(self.inverse_min_event_ndims, y)
               for y in args)

  if len(args) == 1:
    return self._walk_inverse(step_fn, *args, **kwargs)

  # Convert a tuple of structures to a structure of tuples. This allows
  # `_walk` methods to route aligned structures of inputs/outputs
  # independently, and obviates the need for conditional tuple unpacking.
  packed_args = pack_structs_like(self.inverse_min_event_ndims, *args)

  def transform_wrapper(bij, packed_ys, **nested):
    ys = unpack_structs_like(bij.inverse_min_event_ndims, packed_ys)
    xs = step_fn(bij, *ys, **nested)
    return pack_structs_like(bij.forward_min_event_ndims, *xs)

  packed_result = self._walk_inverse(transform_wrapper, packed_args, **kwargs)
  return unpack_structs_like(self.forward_min_event_ndims, packed_result)
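# Hedged sketch of the "tuple of structures" -> "structure of tuples"
# conversion performed by `pack_structs_like` above, expressed with the public
# `tf.nest` API (the real helpers additionally validate the structures
# against `min_event_ndims`).
import tensorflow as tf

struct_a = {'x': 1, 'y': 2}
struct_b = {'x': 10, 'y': 20}
packed = tf.nest.map_structure(lambda *parts: tuple(parts), struct_a, struct_b)
# packed == {'x': (1, 10), 'y': (2, 20)}

# Unpacking routes each position back out to its own structure.
unpacked_a = {k: v[0] for k, v in packed.items()}  # == struct_a
unpacked_b = {k: v[1] for k, v in packed.items()}  # == struct_b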
def testCoerceStructureRaises(self, target, source, message):
  with self.assertRaisesRegex((ValueError, TypeError), message):
    nest_util.coerce_structure(target, source)
def testCoerceStructure(self, target, source, expect):
  coerced = nest_util.coerce_structure(target, source)
  self.assertAllEqualNested(expect, coerced)
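# Hedged sketch of the behavior these tests exercise, inferred from the usage
# elsewhere in this section: `coerce_structure` conforms `source` to the
# container types of `target` (raising ValueError/TypeError when the
# structures are incompatible). Values below are illustrative only.
target = (1, 2)                                    # Tuple structure.
coerced = nest_util.coerce_structure(target, [3, 4])
# coerced == (3, 4): the list is re-packed as a tuple to match `target`.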
def _factored_surrogate_posterior(  # pylint: disable=dangerous-default-value
    event_shape=None,
    bijector=None,
    batch_shape=(),
    base_distribution_cls=normal.Normal,
    initial_parameters={'scale': 1e-2},
    dtype=tf.float32,
    validate_args=False,
    name=None):
  """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to match the
  support of that variable. This makes extremely strong assumptions about the
  posterior: that it is approximately normal (or transformed normal), and that
  all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    bijector: Optional `tfb.Bijector` instance, or nested structure of such
      instances, defining support(s) of the posterior variables. The structure
      must match that of `event_shape` and may contain `None` values. A
      posterior variable will be modeled as
      `tfd.TransformedDistribution(underlying_dist, bijector)` if a
      corresponding constraining bijector is specified, otherwise it is
      modeled as supported on the unconstrained real line.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    base_distribution_cls: Subclass of `tfd.Distribution` that is instantiated
      and optionally transformed by the bijector to define the component
      distributions. May optionally be a structure of such subclasses matching
      `event_shape`.
      Default value: `tfd.Normal`.
    initial_parameters: Optional `str : Tensor` dictionary specifying initial
      values for some or all of the base distribution's trainable parameters,
      or a Python `callable` with signature
      `value = parameter_init_fn(parameter_name, shape, dtype, seed,
      constraining_bijector)`, passed to
      `tfp.experimental.util.make_trainable`. May optionally be a structure
      matching `event_shape` of such dictionaries and/or callables. Dictionary
      entries that do not correspond to parameter names are ignored.
      Default value: `{'scale': 1e-2}` (ignored when `base_distribution` does
      not have a `scale` parameter).
    dtype: Optional float `dtype` for trainable parameters. May optionally be
      a structure of such `dtype`s matching `event_shape`.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs
      are invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').

  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`.

  We'll build a surrogate posterior distribution by specifying the shapes of
  the latent `concentration` and `rate` parameters, and that both are
  constrained to be positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    bijector=[tfb.Softplus(),   # Concentration is positive.
              tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda concentration, rate: model.log_prob([concentration, rate, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify initial parameters when we build the surrogate posterior. Note that
  these parameterize the distribution(s) over unconstrained values, so we need
  to transform our desired constrained locations using the inverse of the
  constraining bijector(s).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    bijector={'concentration': tfb.Softplus(),  # Concentration is positive.
              'rate': tfb.Softplus()},          # Rate is positive.
    initial_parameters={
      'concentration': {'loc': tfb.Softplus().inverse(0.4), 'scale': 1e-2},
      'rate': {'loc': tfb.Softplus().inverse(0.2), 'scale': 1e-2}})
  ```
  """
  with tf.name_scope(name or 'build_factored_surrogate_posterior'):
    # Convert event shapes to Tensors.
    shallow_structure = _get_event_shape_shallow_structure(event_shape)
    event_shape = nest.map_structure_up_to(
        shallow_structure,
        lambda s: tf.convert_to_tensor(s, dtype=tf.int32),
        event_shape)

    if nest.is_nested(bijector):
      event_space_bijector = joint_map.JointMap(
          nest.map_structure(
              lambda b: identity.Identity() if b is None else b,
              nest_util.coerce_structure(event_shape, bijector)),
          validate_args=validate_args)
    else:
      event_space_bijector = bijector

    if event_space_bijector is None:
      unconstrained_event_shape = event_shape
    else:
      unconstrained_event_shape = (
          event_space_bijector.inverse_event_shape_tensor(event_shape))
    unconstrained_batch_and_event_shape = tf.nest.map_structure(
        lambda s: ps.concat([batch_shape, s], axis=0),
        unconstrained_event_shape)

    base_distribution_cls = nest_util.broadcast_structure(
        event_shape, base_distribution_cls)
    try:
      # Check that we have initial parameters for each event part.
      nest.assert_shallow_structure(event_shape, initial_parameters)
    except (ValueError, TypeError):
      # If not, broadcast the parameters to match the event structure.
      # We do this manually rather than using `nest_util.broadcast_structure`
      # because the initial parameters can themselves be structures (dicts).
      initial_parameters = nest.map_structure(
          lambda x: initial_parameters, event_shape)

    unconstrained_trainable_distributions = yield from (
        nest_util.map_structure_coroutine(
            trainable._make_trainable,  # pylint: disable=protected-access
            cls=base_distribution_cls,
            initial_parameters=initial_parameters,
            batch_and_event_shape=unconstrained_batch_and_event_shape,
            parameter_dtype=nest_util.broadcast_structure(event_shape, dtype),
            _up_to=event_shape))
    unconstrained_trainable_distribution = (
        joint_distribution_util.independent_joint_distribution_from_structure(
            unconstrained_trainable_distributions,
            batch_ndims=ps.rank_from_shape(batch_shape),
            validate_args=validate_args))

    if event_space_bijector is None:
      return unconstrained_trainable_distribution
    return transformed_distribution.TransformedDistribution(
        unconstrained_trainable_distribution, event_space_bijector)
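# Hedged sketch of how this generator is typically consumed, following its
# `Yields` docstring; the wrapper usage below is an assumption based on that
# docstring rather than a confirmed part of this module.
build_factored_surrogate_posterior = (
    trainable_state_util.as_stateful_builder(_factored_surrogate_posterior))
build_factored_surrogate_posterior_stateless = (
    trainable_state_util.as_stateless_builder(_factored_surrogate_posterior))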