def _call_execute_model(self,
                            sample_shape,
                            seed,
                            value=None,
                            sample_and_trace_fn=None):
        """Wraps the base `_call_execute_model` with vectorized_map."""
        value_might_have_sample_dims = (
            value is not None and _might_have_excess_ndims(
                # Double-flatten in case any components have structured events.
                flat_value=nest.flatten_up_to(self._single_sample_ndims,
                                              self._model_flatten(value),
                                              check_types=False),
                flat_core_ndims=tf.nest.flatten(self._single_sample_ndims)))
        sample_shape_may_be_nontrivial = (
            distribution_util.shape_may_be_nontrivial(sample_shape))

        if not self.use_vectorized_map or not (sample_shape_may_be_nontrivial
                                               or value_might_have_sample_dims):
            # No need to auto-vectorize.
            return joint_distribution_lib.JointDistribution._call_execute_model(  # pylint: disable=protected-access
                self,
                sample_shape=sample_shape,
                seed=seed,
                value=value,
                sample_and_trace_fn=sample_and_trace_fn)

        # Set up for autovectorized sampling. To support the `value` arg, we need to
        # first understand which dims are from the model itself, then wrap
        # `_call_execute_model` to batch over all remaining dims.
        value_core_ndims = None
        if value is not None:
            value_core_ndims = tf.nest.map_structure(
                lambda v, nd: None if v is None else nd,
                value,
                self._model_unflatten(self._single_sample_ndims),
                check_types=False)

        vectorized_execute_model_helper = vectorization_util.make_rank_polymorphic(
            lambda v, seed: (  # pylint: disable=g-long-lambda
                joint_distribution_lib.JointDistribution._call_execute_model(  # pylint: disable=protected-access
                    self,
                    sample_shape=(),
                    seed=seed,
                    value=v,
                    sample_and_trace_fn=sample_and_trace_fn)),
            core_ndims=[value_core_ndims, None],
            validate_args=self.validate_args)
        # Redefine the polymorphic fn to hack around `make_rank_polymorphic`
        # not currently supporting keyword args. This is needed because the
        # `iid_sample` wrapper below expects to pass through a `seed` kwarg.
        vectorized_execute_model = (
            lambda v, seed: vectorized_execute_model_helper(v, seed))  # pylint: disable=unnecessary-lambda

        if sample_shape_may_be_nontrivial:
            vectorized_execute_model = vectorization_util.iid_sample(
                vectorized_execute_model, sample_shape)

        return vectorized_execute_model(value, seed=seed)
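As a rough, self-contained sketch of the mechanism (not TFP's actual implementation), `tf.vectorized_map` batches a function written for a single example over a leading dimension; `make_rank_polymorphic` and `iid_sample` build on this idea. The `sample_one` function below is a hypothetical stand-in for the single-sample model execution:

import tensorflow as tf

def sample_one(value):
    # Hypothetical single-sample computation; the real code executes the
    # joint model once for an unbatched `value`.
    return value * 2. + 1.

batched_values = tf.random.normal([8, 3])
# Auto-vectorize over the leading (sample/batch) dim, analogous to what the
# wrapped `_call_execute_model` does for excess dims in `value`.
batched_out = tf.vectorized_map(sample_one, batched_values)  # shape [8, 3]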
    def sample_distributions(self,
                             sample_shape=(),
                             seed=None,
                             value=None,
                             name='sample_distributions',
                             **kwargs):
        """Generate samples and the (random) distributions.

    Note that a call to `sample()` without arguments will generate a single
    sample.

    Args:
      sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples.
      seed: Python integer seed for generating random numbers.
      value: `list` of `Tensor`s in `distribution_fn` order to use to
        parameterize other ("downstream") distribution makers.
        Default value: `None` (i.e., draw a sample from each distribution).
      name: name prepended to ops created by this function.
        Default value: `"sample_distributions"`.
      **kwargs: This is an alternative to passing a `value`, and achieves the
        same effect. Named arguments will be used to parameterize other
        dependent ("downstream") distribution-making functions. If a `value`
        argument is also provided, raises a ValueError.

    Returns:
      distributions: a `tuple` of `Distribution` instances for each of
        `distribution_fn`.
      samples: a `tuple` of `Tensor`s with prepended dimensions `sample_shape`
        for each of `distribution_fn`.
    """
        with self._name_and_control_scope(name):
            value = self._resolve_value(value=value,
                                        allow_partially_specified=True,
                                        **kwargs)
            might_have_batch_dims = (
                distribution_util.shape_may_be_nontrivial(sample_shape)
                or value is not None)
            if self.use_vectorized_map and might_have_batch_dims:
                raise NotImplementedError(
                    '`sample_distributions` with nontrivial '
                    'sample shape is not yet supported '
                    'for autovectorized JointDistributions.')

            ds, xs = self._call_flat_sample_distributions(sample_shape,
                                                          seed=seed,
                                                          value=value)
            if not might_have_batch_dims:
                # This is a single sample with no pinned values; this call will cache
                # the distributions if they are not already cached.
                self._get_single_sample_distributions(candidate_dists=ds)

            return self._model_unflatten(ds), self._model_unflatten(xs)
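A minimal usage sketch, assuming a toy two-variable model (the model and names below are illustrative, not from this module):

import tensorflow_probability as tfp
tfd = tfp.distributions

joint = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),           # z
    lambda z: tfd.Normal(loc=z, scale=1.),  # x given z
])
# Returns both the concrete component distributions and a sample from each;
# `dists[1]` is the Normal whose `loc` is the sampled value of `z`.
dists, samples = joint.sample_distributions(seed=42)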
def state_space_model_likelihood(**param_vals):
    ssm = self.make_state_space_model(
        param_vals=param_vals,
        num_timesteps=num_timesteps,
        initial_step=initial_step,
        mask=mask,
        experimental_parallelize=experimental_parallelize)
    # Looping LGSSM methods are really expensive in eager mode; wrap them
    # to keep this from slowing things down in interactive use.
    ssm = tfe_util.JitPublicMethods(ssm, trace_only=True)
    if distribution_util.shape_may_be_nontrivial(trajectories_shape):
        return sample.Sample(ssm, sample_shape=trajectories_shape)
    return ssm
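The `Sample` wrapper used above turns iid draws along `trajectories_shape` into part of the event; a standalone sketch of the same pattern (illustrative shapes):

import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=0., scale=1.)
# Analogous to wrapping the LGSSM when `trajectories_shape` is nontrivial:
# 5 iid draws from `base` become one event of shape [5].
trajectories = tfd.Sample(base, sample_shape=[5])
x = trajectories.sample(seed=42)  # shape [5]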
    def _sample_n(self, sample_shape, seed, value=None):
        might_have_batch_dims = (
            distribution_util.shape_may_be_nontrivial(sample_shape)
            or value is not None)
        if might_have_batch_dims:
            xs = self._call_execute_model(
                sample_shape,
                seed=seed,
                value=value,
                sample_and_trace_fn=trace_values_only)
        else:
            ds, xs = zip(*self._call_execute_model(
                sample_shape,
                seed=seed,
                value=value,
                sample_and_trace_fn=trace_distributions_and_values))
            # This is a single sample with no pinned values; this call will cache
            # the distributions if they are not already cached.
            self._get_single_sample_distributions(candidate_dists=ds)

        return self._model_unflatten(xs)
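A quick sketch of what exercises each branch from the caller's side (the toy `joint` model below is illustrative):

import tensorflow_probability as tfp
tfd = tfp.distributions

joint = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),
    lambda z: tfd.Normal(loc=z, scale=1.),
])
# Trivial sample shape and no `value`: takes the else-branch, which also
# caches the single-sample component distributions.
x = joint.sample(seed=1)
# Nontrivial sample shape: takes the batched branch, tracing values only.
xs = joint.sample([10], seed=2)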
def _affine_surrogate_posterior(event_shape,
                                operators='diag',
                                bijector=None,
                                base_distribution=normal.Normal,
                                dtype=tf.float32,
                                batch_shape=(),
                                validate_args=False,
                                name=None):
    """Builds a joint variational posterior with a given `event_shape`.

  This function builds a surrogate posterior by applying a trainable
  transformation to a standard base distribution and constraining the samples
  with `bijector`. The surrogate posterior has event shape equal to
  the input `event_shape`.

  This function is a convenience wrapper around
  `build_affine_surrogate_posterior_from_base_distribution` that allows the
  user to pass in the desired posterior `event_shape` instead of
  pre-constructed base distributions (at the expense of full control over the
  base distribution types and parameterizations).

  Args:
    event_shape: (Nested) event shape of the posterior.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to create
      a mean-field surrogate posterior) and "tril" (to create a full-covariance
      surrogate posterior). A list/tuple may be passed to induce other
      posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector` of
      the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc` and
      `scale`. The base distribution of the transformed surrogate has `loc=0.`
      and `scale=1.`.
      Default value: `tfd.Normal`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    batch_shape: Batch shape (Python tuple, list, or int) of the surrogate
      posterior, to enable parallel optimization from multiple initializations.
      Default value: `()`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_affine_surrogate_posterior').

  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp

  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Define a joint probabilistic model.
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(
        tfd.Gamma(concentration=concentration, rate=rate),
        sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)

  # Assume the `y` are observed, such that the posterior is a joint distribution
  # over `concentration` and `rate`. The posterior event shape is then equal to
  # the first two components of the model's event shape.
  posterior_event_shape = model.event_shape_tensor()[:-1]

  # Constrain the posterior values to be positive using the `Exp` bijector.
  bijector = [tfb.Exp(), tfb.Exp()]

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior(
        event_shape=posterior_event_shape,
        operators='tril',
        bijector=bijector))

  # For an example defining `'operators'` as a list to express an alternative
  # covariance structure, see
  # `build_affine_surrogate_posterior_from_base_distribution`.

  # Fit the model.
  y = [0.2, 0.5, 0.3, 0.7]
  target_model = model.experimental_pin(y=y)
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```
  """
    with tf.name_scope(name or 'build_affine_surrogate_posterior'):

        event_shape = nest.map_structure_up_to(
            _get_event_shape_shallow_structure(event_shape),
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape)

        if nest.is_nested(bijector):
            bijector = joint_map.JointMap(nest.map_structure(
                lambda b: identity.Identity() if b is None else b, bijector),
                                          validate_args=validate_args)

        if bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                bijector.inverse_event_shape_tensor(event_shape))

        standard_base_distribution = nest.map_structure(
            lambda s: base_distribution(loc=tf.zeros([], dtype=dtype),
                                        scale=1.), unconstrained_event_shape)
        standard_base_distribution = nest.map_structure(
            lambda d, s: (  # pylint: disable=g-long-lambda
                sample.Sample(d, sample_shape=s, validate_args=validate_args)
                if distribution_util.shape_may_be_nontrivial(s) else d),
            standard_base_distribution,
            unconstrained_event_shape)
        if distribution_util.shape_may_be_nontrivial(batch_shape):
            standard_base_distribution = nest.map_structure(
                lambda d: batch_broadcast.BatchBroadcast(  # pylint: disable=g-long-lambda
                    d,
                    to_shape=batch_shape,
                    validate_args=validate_args),
                standard_base_distribution)

        surrogate_posterior = yield from _affine_surrogate_posterior_from_base_distribution(
            standard_base_distribution,
            operators=operators,
            bijector=bijector,
            validate_args=validate_args)
        return surrogate_posterior
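To illustrate the `operators` list formats described in the docstring, here is a sketch of a two-block specification (block sizes and operator choices are arbitrary assumptions):

import tensorflow as tf

# A flat list produces a block-diagonal scale (independent blocks):
operators_diag = [tf.linalg.LinearOperatorLowerTriangular,
                  tf.linalg.LinearOperatorLowerTriangular]
# A singly-nested list with row lengths 1, 2, ... produces a
# block-lower-triangular scale, coupling later blocks to earlier ones:
operators_tril = [
    [tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorLowerTriangular],
]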