def _sample_n(self, n, seed=None, **kwargs):
  validity_mask = tf.convert_to_tensor(self.validity_mask)
  # To avoid the shape gymnastics of drawing extra samples, we delegate
  # sampling to the BatchBroadcast distribution.
  bb = batch_broadcast.BatchBroadcast(self.distribution,
                                      ps.shape(validity_mask))
  samples = bb.sample(n, seed=seed, **kwargs)
  safe_val = tf.stop_gradient(self.safe_sample_fn(self.distribution))
  return tf.where(_add_event_dims_to_mask(validity_mask, dist=self),
                  samples, safe_val)
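The observable behavior of this sampler can be sketched with the public `tfd.Masked` API (assuming this snippet is the `Masked` distribution's `_sample_n`; the loc values and seed below are illustrative):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# A batch of three Normals; the middle batch member is marked invalid.
masked = tfd.Masked(tfd.Normal(loc=[0., 1., 2.], scale=1.),
                    validity_mask=[True, False, True])
x = masked.sample(seed=42)
# Invalid batch members are filled with a deterministic "safe" value
# (a mode-like value by default), with gradients stopped, so downstream
# computations stay finite.
```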
def _maybe_broadcast_distribution_batch_shape(self):
  """Returns the base distribution broadcast to the TD's full batch shape."""
  distribution_batch_shape = self.distribution.batch_shape
  if tf.nest.is_nested(distribution_batch_shape) or self._is_joint:
    # TODO(b/191674464): Support joint distributions in BatchBroadcast.
    return self.distribution
  overall_batch_shape = self.batch_shape
  if (tensorshape_util.is_fully_defined(overall_batch_shape)
      and distribution_batch_shape == overall_batch_shape):
    # No need to broadcast if the distribution already has full batch shape.
    return self.distribution
  if not tensorshape_util.is_fully_defined(overall_batch_shape):
    overall_batch_shape = self.batch_shape_tensor()
  return batch_broadcast.BatchBroadcast(self.distribution,
                                        with_shape=overall_batch_shape)
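The `with_shape` semantics this helper relies on can be seen directly with the public `tfd.BatchBroadcast` (the `[3]` shape is illustrative):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

base = tfd.Normal(loc=0., scale=1.)            # batch shape []
bcast = tfd.BatchBroadcast(base, with_shape=[3])
print(bcast.batch_shape)                       # [3]
# `with_shape` broadcasts the base batch shape *with* the given shape;
# sampling draws once from the base and broadcasts the result.
print(bcast.sample(seed=42).shape)             # [3]
```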
def maybe_make_list_and_batch_broadcast(momentum_distribution, batch_shape):
  """Makes the distribution list-like and batched, if possible."""
  if not mcmc_util.is_list_like(momentum_distribution.dtype):
    momentum_distribution = _CompositeJointDistributionSequential(
        [momentum_distribution], name='joint_momentum')
  if (isinstance(momentum_distribution, jds.JointDistributionSequential) and
      not isinstance(momentum_distribution, jdn.JointDistributionNamed) and
      # Skip this step if we already batch broadcast.
      # TODO(b/182603117): Check public BatchBroadcast when JDS/JDN is
      # CompositeTensor.
      not all(
          isinstance(md, batch_broadcast._BatchBroadcast)  # pylint: disable=protected-access
          for md in momentum_distribution.model) and
      not any(callable(dist_fn) for dist_fn in momentum_distribution.model)):
    momentum_distribution = momentum_distribution.copy(model=[
        batch_broadcast.BatchBroadcast(md, with_shape=batch_shape)
        for md in momentum_distribution.model
    ])
  return momentum_distribution
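Conceptually, and using only public classes in place of the internal `_CompositeJointDistributionSequential`, the transformation looks like this (the `[4]` batch shape is illustrative):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# A single, scalar-batch momentum distribution...
momentum = tfd.Normal(0., 1.)
# ...comes out as a one-element joint distribution whose component is
# broadcast to the chain's batch shape.
wrapped = tfd.JointDistributionSequential(
    [tfd.BatchBroadcast(momentum, with_shape=[4])])
```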
def _get_distributions_with_broadcast_batch_shape(self):
  """Broadcasts the mixture and component dists to have full batch shape."""
  overall_batch_shape = self.batch_shape
  if (tensorshape_util.is_fully_defined(overall_batch_shape)
      and self.components_distribution.batch_shape[:-1] == overall_batch_shape
      and self.mixture_distribution.batch_shape == overall_batch_shape):
    # No need to broadcast.
    return self.mixture_distribution, self.components_distribution
  if not tensorshape_util.is_fully_defined(overall_batch_shape):
    overall_batch_shape = self.batch_shape_tensor()
  # The mixture distribution is primarily accessed through its parameters
  # (e.g., logits), so broadcast those directly.
  mixture_distribution = (
      self.mixture_distribution._broadcast_parameters_with_batch_shape(
          overall_batch_shape))
  components_distribution = batch_broadcast.BatchBroadcast(
      self.components_distribution,
      with_shape=ps.concat([overall_batch_shape, [1]], axis=0))
  return mixture_distribution, components_distribution
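To see why the trailing `[1]` is concatenated: the components distribution carries an extra rightmost batch axis indexing the mixture components, so broadcasting its batch shape against `overall_batch_shape + [1]` preserves that axis. A small sketch with public APIs (shapes illustrative):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# Components have batch shape [2] (the mixture axis). If the overall batch
# shape were [3], broadcasting [2] with [3, 1] yields [3, 2], keeping the
# mixture axis rightmost.
components = tfd.Normal(loc=[-1., 1.], scale=1.)       # batch shape [2]
broadcast = tfd.BatchBroadcast(components, with_shape=[3, 1])
print(broadcast.batch_shape)                           # [3, 2]
```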
def _maybe_broadcast_distribution_batch_shape(self):
  """Returns the base distribution broadcast to the TD's full batch shape."""
  bijector_batch_shape = self._bijector_batch_shape()
  if tensorshape_util.rank(bijector_batch_shape) == 0:
    # Bijector batch shape is static and nonexistent: no broadcasting needed.
    return self.distribution
  if self._is_joint:
    # TODO(b/191674464): Support joint distributions in BatchBroadcast.
    return self.distribution
  distribution_batch_shape = self.distribution.batch_shape
  if (tensorshape_util.is_fully_defined(distribution_batch_shape)
      and distribution_batch_shape == tf.broadcast_static_shape(
          distribution_batch_shape, bijector_batch_shape)):
    # No need to broadcast if the distribution already has full batch shape.
    return self.distribution
  if not tensorshape_util.is_fully_defined(bijector_batch_shape):
    bijector_batch_shape = self._bijector_batch_shape_tensor()
  return batch_broadcast.BatchBroadcast(self.distribution,
                                        with_shape=bijector_batch_shape)
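For intuition, the situation this handles is a batched bijector wrapped around a scalar-batch base distribution, shown here with public APIs:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# The base Normal has batch shape [], but the Scale bijector contributes
# batch shape [3]; the TransformedDistribution's batch shape is their
# broadcast, so the base may need to be lifted to the bijector's batch shape.
td = tfd.TransformedDistribution(tfd.Normal(0., 1.),
                                 tfb.Scale([1., 2., 3.]))
print(td.batch_shape)  # [3]
```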
def _affine_surrogate_posterior(event_shape,
                                operators='diag',
                                bijector=None,
                                base_distribution=normal.Normal,
                                dtype=tf.float32,
                                batch_shape=(),
                                validate_args=False,
                                name=None):
  """Builds a joint variational posterior with a given `event_shape`.

  This function builds a surrogate posterior by applying a trainable
  transformation to a standard base distribution and constraining the samples
  with `bijector`. The surrogate posterior has event shape equal to the input
  `event_shape`.

  This function is a convenience wrapper around
  `build_affine_surrogate_posterior_from_base_distribution` that allows the
  user to pass in the desired posterior `event_shape` instead of
  pre-constructed base distributions (at the expense of full control over the
  base distribution types and parameterizations).

  Args:
    event_shape: (Nested) event shape of the posterior.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to
      create a mean-field surrogate posterior) and "tril" (to create a
      full-covariance surrogate posterior). A list/tuple may be passed to
      induce other posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and
      applied to the base distribution. Otherwise the list must be
      singly-nested and have a first element of length 1, second element of
      length 2, etc.; the elements of the outer list are interpreted as rows
      of a lower-triangular block structure, and a
      `tf.linalg.LinearOperatorBlockLowerTriangular` instance is created. For
      complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector`
      of the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc`
      and `scale`. The base distribution of the transformed surrogate has
      `loc=0.` and `scale=1.`.
      Default value: `tfd.Normal`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    batch_shape: Batch shape (Python tuple, list, or int) of the surrogate
      posterior, to enable parallel optimization from multiple
      initializations.
      Default value: `()`.
    validate_args: Python `bool`. Whether to validate input with asserts.
      This imposes a runtime cost. If `validate_args` is `False`, and the
      inputs are invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_affine_surrogate_posterior').

  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  #### Examples

  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Define a joint probabilistic model.
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(
        tfd.Gamma(concentration=concentration, rate=rate),
        sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)

  # Assume the `y` are observed, such that the posterior is a joint
  # distribution over `concentration` and `rate`. The posterior event shape
  # is then equal to the first two components of the model's event shape.
  posterior_event_shape = model.event_shape_tensor()[:-1]

  # Constrain the posterior values to be positive using the `Exp` bijector.
  bijector = [tfb.Exp(), tfb.Exp()]

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
      tfp.experimental.vi.build_affine_surrogate_posterior(
          event_shape=posterior_event_shape,
          operators='tril',
          bijector=bijector))

  # For an example defining `'operators'` as a list to express an alternative
  # covariance structure, see
  # `build_affine_surrogate_posterior_from_base_distribution`.

  # Fit the model.
  y = [0.2, 0.5, 0.3, 0.7]
  target_model = model.experimental_pin(y=y)
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```
  """
  with tf.name_scope(name or 'build_affine_surrogate_posterior'):
    event_shape = nest.map_structure_up_to(
        _get_event_shape_shallow_structure(event_shape),
        lambda s: tf.convert_to_tensor(s, dtype=tf.int32),
        event_shape)

    if nest.is_nested(bijector):
      bijector = joint_map.JointMap(
          nest.map_structure(
              lambda b: identity.Identity() if b is None else b, bijector),
          validate_args=validate_args)

    if bijector is None:
      unconstrained_event_shape = event_shape
    else:
      unconstrained_event_shape = (
          bijector.inverse_event_shape_tensor(event_shape))

    standard_base_distribution = nest.map_structure(
        lambda s: base_distribution(loc=tf.zeros([], dtype=dtype), scale=1.),
        unconstrained_event_shape)
    standard_base_distribution = nest.map_structure(
        lambda d, s: (  # pylint: disable=g-long-lambda
            sample.Sample(d, sample_shape=s, validate_args=validate_args)
            if distribution_util.shape_may_be_nontrivial(s)
            else d),
        standard_base_distribution, unconstrained_event_shape)
    if distribution_util.shape_may_be_nontrivial(batch_shape):
      standard_base_distribution = nest.map_structure(
          lambda d: batch_broadcast.BatchBroadcast(  # pylint: disable=g-long-lambda
              d, to_shape=batch_shape, validate_args=validate_args),
          standard_base_distribution)

    surrogate_posterior = (
        yield from _affine_surrogate_posterior_from_base_distribution(
            standard_base_distribution,
            operators=operators,
            bijector=bijector,
            validate_args=validate_args))
    return surrogate_posterior
def __init__(self,
             distribution,
             shift,
             scale,
             tailweight=None,
             validate_args=False,
             allow_nan_stats=True,
             name="LambertWDistribution"):
  """Initializes the class.

  Args:
    distribution: `tf.Distribution`-like instance. Distribution F that is
      transformed to produce this Lambert W x F distribution.
    shift: shift that should be applied before & after the tail
      transformation. For a location-scale family `distribution` (e.g.,
      `Normal` or `StudentT`) this usually is set as the mean / location
      parameter. For a scale family `distribution` (e.g., `Gamma` or
      `Fisher`) this must be set to 0 to guarantee a proper transformation
      on the positive real line.
    scale: scaling factor that should be applied before & after the tail
      transformation. Usually the standard deviation or scaling parameter
      of the `distribution`.
    tailweight: Tail parameter `delta` of the resulting Lambert W x F
      distribution(s).
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value '`NaN`' to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: A name for the operation (optional).
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([tailweight, shift, scale], tf.float32)
    tailweight = 0. if tailweight is None else tailweight
    self._tailweight = tensor_util.convert_nonref_to_tensor(
        tailweight, name="tailweight", dtype=dtype)
    self._shift = tensor_util.convert_nonref_to_tensor(
        shift, name="shift", dtype=dtype)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name="scale", dtype=dtype)
    dtype_util.assert_same_float_dtype(
        (self.tailweight, self.shift, self.scale))
    self._allow_nan_stats = allow_nan_stats
    super(LambertWDistribution, self).__init__(
        # TODO(b/160730249): Remove broadcasting when
        # TransformedDistribution's bijector can modify its `batch_shape`.
        distribution=batch_broadcast.BatchBroadcast(
            distribution,
            with_shape=ps.broadcast_shape(
                ps.shape(tailweight),
                ps.broadcast_shape(ps.shape(shift), ps.shape(scale)))),
        bijector=tfb.LambertWTail(shift=shift, scale=scale,
                                  tailweight=tailweight,
                                  validate_args=validate_args),
        parameters=parameters,
        validate_args=validate_args,
        name=name)
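A brief usage sketch via the public `tfd.LambertWDistribution` (parameter values are illustrative; with `tailweight=0.` the tail transformation reduces to the identity, recovering the base distribution):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# A heavy-tailed version of a standard Normal.
lw = tfd.LambertWDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    shift=0., scale=1., tailweight=0.5)
x = lw.sample(5, seed=42)
```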
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
  """Helper which processes input args to meet list-like assumptions."""
  state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(x) for x in state_parts]
  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob,
      grads_target_log_prob)
  step_sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')

  # Default momentum distribution is None, but if
  # `store_parameters_in_results` is true, then `momentum_distribution`
  # defaults to DefaultStandardNormal().
  if (momentum_distribution is None or
      isinstance(momentum_distribution, DefaultStandardNormal)):
    batch_rank = ps.rank(target_log_prob)

    def _batched_isotropic_normal_like(state_part):
      return sample.Sample(
          normal.Normal(ps.zeros([], dtype=state_part.dtype), 1.),
          ps.shape(state_part)[batch_rank:])

    momentum_distribution = jds.JointDistributionSequential(
        [_batched_isotropic_normal_like(state_part)
         for state_part in state_parts])

  # The momentum will get "maybe listified" to zip with the state parts,
  # and this step makes sure that the momentum distribution will have the
  # same "maybe listified" underlying shape.
  if not mcmc_util.is_list_like(momentum_distribution.dtype):
    momentum_distribution = jds.JointDistributionSequential(
        [momentum_distribution])

  # If all underlying distributions are independent, we can offer some help.
  # This code will also trigger for the output of the two blocks above.
  if (isinstance(momentum_distribution, jds.JointDistributionSequential) and
      not any(callable(dist_fn)
              for dist_fn in momentum_distribution.model)):
    batch_shape = ps.shape(target_log_prob)
    momentum_distribution = momentum_distribution.copy(model=[
        batch_broadcast.BatchBroadcast(md, to_shape=batch_shape)
        for md in momentum_distribution.model
    ])

  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')

  def maybe_flatten(x):
    return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

  return [
      maybe_flatten(state_parts),
      maybe_flatten(step_sizes),
      momentum_distribution,
      target_log_prob,
      grads_target_log_prob,
  ]
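For a concrete picture of the default momentum branch: given a single state part of shape `[4, 3]` and a target log prob with batch shape `[4]`, the constructed momentum distribution is equivalent to the following sketch built from public APIs (shapes illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

state_part = tf.zeros([4, 3])   # batch shape [4], event shape [3]
batch_rank = 1                  # rank of the target log prob

# Isotropic standard Normal over the non-batch dims of the state part,
# then broadcast to the chain's batch shape, as the last block above does.
momentum_distribution = tfd.JointDistributionSequential([
    tfd.BatchBroadcast(
        tfd.Sample(tfd.Normal(0., 1.),
                   sample_shape=state_part.shape[batch_rank:]),
        to_shape=[4]),
])
print(momentum_distribution.sample(seed=42)[0].shape)  # [4, 3]
```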