Example #1
  def _sample_n(self, n, seed=None, **kwargs):
    validity_mask = tf.convert_to_tensor(self.validity_mask)
    # To avoid the shape gymnastics of drawing extra samples, we delegate
    # sampling to the BatchBroadcast distribution.
    bb = batch_broadcast.BatchBroadcast(self.distribution,
                                        ps.shape(validity_mask))
    samples = bb.sample(n, seed=seed, **kwargs)
    safe_val = tf.stop_gradient(self.safe_sample_fn(self.distribution))
    return tf.where(_add_event_dims_to_mask(validity_mask, dist=self),
                    samples, safe_val)
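
The same masking pattern can be reproduced with the public API: broadcast the underlying distribution to the mask's shape, sample once, and substitute a safe value wherever the mask is invalid. Below is a minimal sketch, assuming the public `tfd.BatchBroadcast` wrapper (available in recent TFP releases), with a scalar `tfd.Normal` standing in for `self.distribution` and hypothetical mask values.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

validity_mask = tf.constant([True, False, True])
base = tfd.Normal(loc=0., scale=1.)

# Broadcast the scalar-batch Normal to the mask's shape so that one sample is
# drawn per mask entry, instead of reshaping samples after the fact.
bb = tfd.BatchBroadcast(base, with_shape=tf.shape(validity_mask))
samples = bb.sample(5, seed=42)            # shape [5, 3]

# Replace samples at invalid positions with a gradient-free "safe" value.
safe_val = tf.stop_gradient(base.mean())   # scalar 0.
masked_samples = tf.where(validity_mask, samples, safe_val)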
Example #2
    def _maybe_broadcast_distribution_batch_shape(self):
        """Returns the base distribution broadcast to the TD's full batch shape."""
        distribution_batch_shape = self.distribution.batch_shape
        if tf.nest.is_nested(distribution_batch_shape) or self._is_joint:
            # TODO(b/191674464): Support joint distributions in BatchBroadcast.
            return self.distribution

        overall_batch_shape = self.batch_shape
        if (tensorshape_util.is_fully_defined(overall_batch_shape)
                and distribution_batch_shape == overall_batch_shape):
            # No need to broadcast if the distribution already has full batch shape.
            return self.distribution

        if not tensorshape_util.is_fully_defined(overall_batch_shape):
            overall_batch_shape = self.batch_shape_tensor()
        return batch_broadcast.BatchBroadcast(self.distribution,
                                              with_shape=overall_batch_shape)
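
The `with_shape` argument used above broadcasts the wrapped distribution's batch shape together with the requested shape rather than requiring an exact match. A small sketch of that behavior, assuming the public `tfd.BatchBroadcast` and hypothetical shapes:

import tensorflow_probability as tfp

tfd = tfp.distributions

# Base distribution with batch shape [2].
base = tfd.Normal(loc=[0., 1.], scale=1.)

# Broadcasting *with* [3, 1] yields the combined batch shape [3, 2]; the
# underlying parameters are shared across the new leading dimension.
bb = tfd.BatchBroadcast(base, with_shape=[3, 1])
print(bb.batch_shape)       # [3, 2]
print(bb.sample(4).shape)   # [4, 3, 2]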
Example #3
def maybe_make_list_and_batch_broadcast(momentum_distribution, batch_shape):
    """Makes the distribution list-like and batched, if possible."""
    if not mcmc_util.is_list_like(momentum_distribution.dtype):
        momentum_distribution = _CompositeJointDistributionSequential(
            [momentum_distribution], name='joint_momentum')
    if (isinstance(momentum_distribution, jds.JointDistributionSequential) and
            not isinstance(momentum_distribution, jdn.JointDistributionNamed)
            and
            # Skip this step if we already batch broadcast.
            # TODO(b/182603117): Check public BatchBroadcast when JDS/JDN is
            # CompositeTensor.
            not all(
                isinstance(md, batch_broadcast._BatchBroadcast)  # pylint: disable=protected-access
                for md in momentum_distribution.model) and not any(
                    callable(dist_fn)
                    for dist_fn in momentum_distribution.model)):
        momentum_distribution = momentum_distribution.copy(model=[
            batch_broadcast.BatchBroadcast(md, with_shape=batch_shape)
            for md in momentum_distribution.model
        ])
    return momentum_distribution
Example #4
    def _get_distributions_with_broadcast_batch_shape(self):
        """Broadcasts the mixture and component dists to have full batch shape."""
        overall_batch_shape = self.batch_shape
        if (tensorshape_util.is_fully_defined(overall_batch_shape)
                and self.components_distribution.batch_shape[:-1]
                == overall_batch_shape and
                self.mixture_distribution.batch_shape == overall_batch_shape):
            # No need to broadcast.
            return self.mixture_distribution, self.components_distribution

        if not tensorshape_util.is_fully_defined(overall_batch_shape):
            overall_batch_shape = self.batch_shape_tensor()
        # The mixture distribution is primarily accessed through its parameters
        # (e.g., logits), so broadcast those directly.
        mixture_distribution = (
            self.mixture_distribution._broadcast_parameters_with_batch_shape(
                overall_batch_shape))
        components_distribution = batch_broadcast.BatchBroadcast(
            self.components_distribution,
            with_shape=ps.concat([overall_batch_shape, [1]], axis=0))
        return mixture_distribution, components_distribution
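
The components distribution is broadcast against the overall batch shape with an extra trailing `1`, which leaves the components axis free to broadcast against the number of components. A short sketch of the resulting shapes, using the public `tfd.BatchBroadcast` and hypothetical values:

import tensorflow_probability as tfp

tfd = tfp.distributions

overall_batch_shape = [5]   # full batch shape of the mixture
num_components = 3

# Components distribution with only the components axis in its batch shape.
components = tfd.Normal(loc=[-1., 0., 1.], scale=1.)

# Broadcasting with [5, 1] aligns the trailing 1 against the components axis,
# giving batch shape [5, 3] without tiling any parameters.
broadcast_components = tfd.BatchBroadcast(
    components, with_shape=overall_batch_shape + [1])
print(broadcast_components.batch_shape)   # [5, 3]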
Example #5
  def _maybe_broadcast_distribution_batch_shape(self):
    """Returns the base distribution broadcast to the TD's full batch shape."""
    bijector_batch_shape = self._bijector_batch_shape()
    if tensorshape_util.rank(bijector_batch_shape) == 0:
      # Bijector batch shape is static and nonexistent: no broadcasting needed.
      return self.distribution
    if self._is_joint:
      # TODO(b/191674464): Support joint distributions in BatchBroadcast.
      return self.distribution

    distribution_batch_shape = self.distribution.batch_shape
    if (tensorshape_util.is_fully_defined(distribution_batch_shape) and
        distribution_batch_shape == tf.broadcast_static_shape(
            distribution_batch_shape,
            bijector_batch_shape)):
      # No need to broadcast if the distribution already has full batch shape.
      return self.distribution

    if not tensorshape_util.is_fully_defined(bijector_batch_shape):
      bijector_batch_shape = self._bijector_batch_shape_tensor()

    return batch_broadcast.BatchBroadcast(self.distribution,
                                          with_shape=bijector_batch_shape)
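
In recent TFP releases this internal broadcasting is what lets a `TransformedDistribution` report the bijector's batch shape even when the base distribution is scalar. A minimal illustration, assuming `tfb.Scale` with hypothetical parameters:

import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Scalar-batch base distribution, bijector with batch shape [3].
td = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Scale([1., 2., 3.]))

# Internally the base Normal is wrapped in BatchBroadcast, so the transformed
# distribution reports and samples with the full batch shape.
print(td.batch_shape)       # [3]
print(td.sample(2).shape)   # [2, 3]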
Example #6
def _affine_surrogate_posterior(event_shape,
                                operators='diag',
                                bijector=None,
                                base_distribution=normal.Normal,
                                dtype=tf.float32,
                                batch_shape=(),
                                validate_args=False,
                                name=None):
    """Builds a joint variational posterior with a given `event_shape`.

  This function builds a surrogate posterior by applying a trainable
  transformation to a standard base distribution and constraining the samples
  with `bijector`. The surrogate posterior has event shape equal to
  the input `event_shape`.

  This function is a convenience wrapper around
  `build_affine_surrogate_posterior_from_base_distribution` that allows the
  user to pass in the desired posterior `event_shape` instead of
  pre-constructed base distributions (at the expense of full control over the
  base distribution types and parameterizations).

  Args:
    event_shape: (Nested) event shape of the posterior.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to create
      a mean-field surrogate posterior) and "tril" (to create a full-covariance
      surrogate posterior). A list/tuple may be passed to induce other
      posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector` of
      the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc` and
      `scale`. The base distribution of the transformed surrogate has `loc=0.`
      and `scale=1.`.
      Default value: `tfd.Normal`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    batch_shape: Batch shape (Python tuple, list, or int) of the surrogate
      posterior, to enable parallel optimization from multiple initializations.
      Default value: `()`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_affine_surrogate_posterior').
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  #### Examples

  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Define a joint probabilistic model.
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(
        tfd.Gamma(concentration=concentration, rate=rate),
        sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)

  # Assume the `y` are observed, such that the posterior is a joint distribution
  # over `concentration` and `rate`. The posterior event shape is then equal to
  # the first two components of the model's event shape.
  posterior_event_shape = model.event_shape_tensor()[:-1]

  # Constrain the posterior values to be positive using the `Exp` bijector.
  bijector = [tfb.Exp(), tfb.Exp()]

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior(
        event_shape=posterior_event_shape,
        operators='tril',
        bijector=bijector))

  # For an example defining `'operators'` as a list to express an alternative
  # covariance structure, see
  # `build_affine_surrogate_posterior_from_base_distribution`.

  # Fit the model.
  y = [0.2, 0.5, 0.3, 0.7]
  target_model = model.experimental_pin(y=y)
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```
  """
    with tf.name_scope(name or 'build_affine_surrogate_posterior'):

        event_shape = nest.map_structure_up_to(
            _get_event_shape_shallow_structure(event_shape),
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape)

        if nest.is_nested(bijector):
            bijector = joint_map.JointMap(nest.map_structure(
                lambda b: identity.Identity() if b is None else b, bijector),
                                          validate_args=validate_args)

        if bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                bijector.inverse_event_shape_tensor(event_shape))

        standard_base_distribution = nest.map_structure(
            lambda s: base_distribution(loc=tf.zeros([], dtype=dtype),
                                        scale=1.), unconstrained_event_shape)
        standard_base_distribution = nest.map_structure(
            lambda d, s: (  # pylint: disable=g-long-lambda
                sample.Sample(d, sample_shape=s, validate_args=validate_args)
                if distribution_util.shape_may_be_nontrivial(s) else d),
            standard_base_distribution,
            unconstrained_event_shape)
        if distribution_util.shape_may_be_nontrivial(batch_shape):
            standard_base_distribution = nest.map_structure(
                lambda d: batch_broadcast.BatchBroadcast(  # pylint: disable=g-long-lambda
                    d,
                    to_shape=batch_shape,
                    validate_args=validate_args),
                standard_base_distribution)

        surrogate_posterior = yield from _affine_surrogate_posterior_from_base_distribution(
            standard_base_distribution,
            operators=operators,
            bijector=bijector,
            validate_args=validate_args)
        return surrogate_posterior
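
Note that the base distributions here are wrapped with `to_shape=batch_shape` rather than `with_shape`: `to_shape` requires the wrapped batch shape to broadcast to exactly the requested shape, which is the right contract for running several optimizations in parallel. A small sketch of the difference between the two arguments, assuming the public `tfd.BatchBroadcast` and hypothetical shapes:

import tensorflow_probability as tfp

tfd = tfp.distributions

base = tfd.Normal(loc=0., scale=1.)   # scalar batch shape []

# `to_shape`: the result's batch shape is exactly the requested shape.
print(tfd.BatchBroadcast(base, to_shape=[10]).batch_shape)        # [10]

# `with_shape`: the result's batch shape is the broadcast of both shapes, so
# a base with batch shape [2] combined with [10, 1] gives [10, 2].
base2 = tfd.Normal(loc=[0., 1.], scale=1.)
print(tfd.BatchBroadcast(base2, with_shape=[10, 1]).batch_shape)  # [10, 2]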
Example #7
    def __init__(self,
                 distribution,
                 shift,
                 scale,
                 tailweight=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="LambertWDistribution"):
        """Initializes the class.

    Args:
      distribution: `tf.Distribution`-like instance. Distribution F that is
        transformed to produce this Lambert W x F distribution.
      shift: shift that should be applied before & after tail transformation.
        For a location-scale family `distribution` (e.g., `Normal` or
        `StudentT`) this usually is set as the mean / location parameter. For a
        scale family `distribution` (e.g., `Gamma` or `Fisher`) this must be
        set to 0 to guarantee a proper transformation on the positive
        real-line.
      scale: scaling factor that should be applied before & after the tail
transformation. Usually the standard deviation or scaling parameter
        of the `distribution`.
      tailweight: Tail parameter `delta` of the resulting Lambert W x F
        distribution(s).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
statistics (e.g., mean, mode, variance) use the value `NaN` to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: A name for the operation (optional).
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([tailweight, shift, scale],
                                            tf.float32)
            tailweight = 0. if tailweight is None else tailweight
            self._tailweight = tensor_util.convert_nonref_to_tensor(
                tailweight, name="tailweight", dtype=dtype)
            self._shift = tensor_util.convert_nonref_to_tensor(shift,
                                                               name="shift",
                                                               dtype=dtype)
            self._scale = tensor_util.convert_nonref_to_tensor(scale,
                                                               name="scale",
                                                               dtype=dtype)
            dtype_util.assert_same_float_dtype(
                (self.tailweight, self.shift, self.scale))
            self._allow_nan_stats = allow_nan_stats
            super(LambertWDistribution, self).__init__(
                # TODO(b/160730249): Remove broadcasting when
                # TransformedDistribution's bijector can modify its `batch_shape`.
                distribution=batch_broadcast.BatchBroadcast(
                    distribution,
                    with_shape=ps.broadcast_shape(
                        ps.shape(tailweight),
                        ps.broadcast_shape(ps.shape(shift), ps.shape(scale)))),
                bijector=tfb.LambertWTail(shift=shift,
                                          scale=scale,
                                          tailweight=tailweight,
                                          validate_args=validate_args),
                parameters=parameters,
                validate_args=validate_args,
                name=name)
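
The `with_shape` passed to `BatchBroadcast` is simply the broadcast of the three parameter shapes, so the wrapped base distribution acquires the same batch shape as the `LambertWTail` bijector's parameters. A sketch of that shape computation with hypothetical parameter shapes, using the same `prefer_static` helpers as the snippet above:

import numpy as np
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static as ps

tfd = tfp.distributions

shift = np.zeros([3, 1], np.float32)   # batch shape [3, 1]
scale = np.ones([4], np.float32)       # batch shape [4]
tailweight = np.float32(0.1)           # scalar

# Broadcast of all three parameter shapes: [3, 4].
full_shape = ps.broadcast_shape(
    ps.shape(tailweight),
    ps.broadcast_shape(ps.shape(shift), ps.shape(scale)))

# The scalar base distribution is expanded to that batch shape before the
# tail-transform bijector is applied.
base = tfd.BatchBroadcast(tfd.Normal(0., 1.), with_shape=full_shape)
print(base.batch_shape)   # [3, 4]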
Example #8
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
  """Helper which processes input args to meet list-like assumptions."""
  state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(x) for x in state_parts]
  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)
  step_sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')

  # Default momentum distribution is None, but if `store_parameters_in_results`
  # is true, then `momentum_distribution` defaults to DefaultStandardNormal().
  if (momentum_distribution is None or
      isinstance(momentum_distribution, DefaultStandardNormal)):
    batch_rank = ps.rank(target_log_prob)
    def _batched_isotropic_normal_like(state_part):
      return sample.Sample(
          normal.Normal(ps.zeros([], dtype=state_part.dtype), 1.),
          ps.shape(state_part)[batch_rank:])

    momentum_distribution = jds.JointDistributionSequential(
        [_batched_isotropic_normal_like(state_part)
         for state_part in state_parts])

  # The momentum will get "maybe listified" to zip with the state parts,
  # and this step makes sure that the momentum distribution will have the
  # same "maybe listified" underlying shape.
  if not mcmc_util.is_list_like(momentum_distribution.dtype):
    momentum_distribution = jds.JointDistributionSequential(
        [momentum_distribution])

  # If all underlying distributions are independent, we can offer some help.
  # This code will also trigger for the output of the two blocks above.
  if (isinstance(momentum_distribution, jds.JointDistributionSequential) and
      not any(callable(dist_fn) for dist_fn in momentum_distribution.model)):
    batch_shape = ps.shape(target_log_prob)
    momentum_distribution = momentum_distribution.copy(model=[
        batch_broadcast.BatchBroadcast(md, to_shape=batch_shape)
        for md in momentum_distribution.model
    ])

  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')
  def maybe_flatten(x):
    return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]
  return [
      maybe_flatten(state_parts),
      maybe_flatten(step_sizes),
      momentum_distribution,
      target_log_prob,
      grads_target_log_prob,
  ]
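
Concretely, when `momentum_distribution` is left as `None`, each state part gets an isotropic standard normal whose event shape is the non-batch part of the state, and the broadcast step then gives every part the chain batch shape. A minimal sketch of the resulting shapes, assuming the public distribution classes and hypothetical state shapes:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

batch_shape = [7]   # one entry per independent chain
state_parts = [tf.zeros([7, 2]), tf.zeros([7])]

# One isotropic standard normal per state part, with the non-batch dimensions
# of the state as its event shape, broadcast to the chain batch shape.
momentum_distribution = tfd.JointDistributionSequential([
    tfd.BatchBroadcast(
        tfd.Sample(tfd.Normal(0., 1.), sample_shape=part.shape.as_list()[1:]),
        to_shape=batch_shape)
    for part in state_parts
])

momenta = momentum_distribution.sample(seed=42)
print([m.shape for m in momenta])   # shapes [7, 2] and [7]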