def _maybe_build_joint_distribution(structure_of_distributions):
    """Turns a (potentially nested) structure of dists into a single dist."""
    # Base case: if we already have a Distribution, return it.
    if dist_util.is_distribution_instance(structure_of_distributions):
        return structure_of_distributions

    # Otherwise, recursively convert all interior nested structures into JDs.
    outer_structure = tf.nest.map_structure(_maybe_build_joint_distribution,
                                            structure_of_distributions)
    if (hasattr(outer_structure, '_asdict')
            or isinstance(outer_structure, collections.abc.Mapping)):
        return joint_distribution_named.JointDistributionNamed(outer_structure)
    else:
        return joint_distribution_sequential.JointDistributionSequential(
            outer_structure)
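
A minimal usage sketch of the helper above, assuming `tfd = tfp.distributions` and that `_maybe_build_joint_distribution` is in scope; the distribution names and keys are illustrative only:

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

structure = {'scale': tfd.Gamma(concentration=1., rate=1.),
             'noise': tfd.Normal(loc=0., scale=1.)}
joint = _maybe_build_joint_distribution(structure)
# `joint` is a `JointDistributionNamed`; its samples mirror the input dict:
draw = joint.sample()  # ==> {'scale': Tensor, 'noise': Tensor}
```
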
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  validate_args=False):
    """Turns a (potentially nested) structure of dists into a single dist.

  Args:
    structure_of_distributions: instance of `tfd.Distribution`, or nested
      structure (tuple, list, dict, etc.) in which all leaves are
      `tfd.Distribution` instances.
    validate_args: Python `bool`. Whether the joint distribution should validate
      input with asserts. This imposes a runtime cost. If `validate_args` is
      `False`, and the inputs are invalid, correct behavior is not guaranteed.
      Default value: `False`.
  Returns:
    distribution: instance of `tfd.Distribution` such that
      `distribution.sample()` is equivalent to
      `tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions)`.
      If `structure_of_distributions` was indeed a structure (as opposed to
      a single `Distribution` instance), this will be a `JointDistribution`
      with the corresponding structure.
  Raises:
    TypeError: if any leaves of the input structure are not `tfd.Distribution`
      instances.
  """
    # If input is already a Distribution, just return it.
    if dist_util.is_distribution_instance(structure_of_distributions):
        return structure_of_distributions

    # If this structure contains other structures (i.e., has elements at depth > 1),
    # recursively turn them into JDs.
    element_depths = nest.map_structure_with_tuple_paths(
        lambda path, x: len(path), structure_of_distributions)
    if max(tf.nest.flatten(element_depths)) > 1:
        next_level_shallow_structure = nest.get_traverse_shallow_structure(
            traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
            structure=element_depths)
        structure_of_distributions = nest.map_structure_up_to(
            next_level_shallow_structure,
            independent_joint_distribution_from_structure,
            structure_of_distributions)

    # Otherwise, build a JD from the current structure.
    if (hasattr(structure_of_distributions, '_asdict')
            or isinstance(structure_of_distributions, collections.abc.Mapping)):
        return joint_distribution_named.JointDistributionNamed(
            structure_of_distributions, validate_args=validate_args)
    return joint_distribution_sequential.JointDistributionSequential(
        structure_of_distributions, validate_args=validate_args)
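
A hedged usage sketch of `independent_joint_distribution_from_structure`, assuming `tfd = tfp.distributions`; the nested structure is illustrative and exercises the depth > 1 branch above:

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

structure = {'global_scale': tfd.HalfNormal(scale=1.),
             'weights': [tfd.Normal(0., 1.), tfd.Normal(0., 2.)]}
dist = independent_joint_distribution_from_structure(structure)
# Sampling preserves the original structure, per the docstring:
x = dist.sample()      # ==> {'global_scale': Tensor, 'weights': [Tensor, Tensor]}
lp = dist.log_prob(x)  # scalar log-density summed over all leaves
```
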
def joint_prior_on_parameters_and_state(parameter_prior,
                                        parameterized_initial_state_prior_fn,
                                        parameter_constraining_bijector,
                                        prior_is_constrained=True):
    """Constructs a joint dist. from p(parameters) and p(state | parameters)."""
    if prior_is_constrained:
        parameter_prior = transformed_distribution.TransformedDistribution(
            parameter_prior,
            invert.Invert(parameter_constraining_bijector),
            name='unconstrained_parameter_prior')

    return joint_distribution_named.JointDistributionNamed(
        ParametersAndState(
            unconstrained_parameters=parameter_prior,
            state=lambda unconstrained_parameters: (  # pylint: disable=g-long-lambda
                parameterized_initial_state_prior_fn(
                    parameter_constraining_bijector.forward(
                        unconstrained_parameters)))))
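
A hedged sketch of how the constructor above might be called, assuming `tfd = tfp.distributions`, `tfb = tfp.bijectors`, and that `ParametersAndState` is the namedtuple used by this module; the single positive scale parameter is illustrative only:

```python
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

prior = joint_prior_on_parameters_and_state(
    parameter_prior=tfd.Gamma(concentration=2., rate=2.),  # constrained (positive) prior
    parameterized_initial_state_prior_fn=lambda scale: tfd.Normal(loc=0., scale=scale),
    parameter_constraining_bijector=tfb.Softplus())
draw = prior.sample()
# `draw.unconstrained_parameters` lives in unconstrained space; applying
# `tfb.Softplus().forward` recovers the positive scale that parameterized
# `draw.state`.
```
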
  def augmented_fn(step, state_with_history, **kwargs):
    """Builds history-tracking dist. over `StateWithHistory` instances."""
    with tf.name_scope('augment_with_state_history'):
      new_state_dist = fn(step, state_with_history, **kwargs)

      def new_state_history_dist(state):
        with tf.name_scope('new_state_history_dist'):
          new_state_histories = tf.nest.map_structure(
              lambda h, s: tf.concat([h[:, 1:],  # pylint: disable=g-long-lambda
                                      s[:, tf.newaxis]], axis=1),
              state_with_history.state_history,
              state)
          return (
              joint_distribution_util
              .independent_joint_distribution_from_structure(
                  _wrap_as_distributions(new_state_histories)))

    return joint_distribution_named.JointDistributionNamed(
        StateWithHistory(
            state=new_state_dist,
            state_history=new_state_history_dist))
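
A toy illustration (hedged; not from the source) of the rolling-history update performed inside `new_state_history_dist`: the oldest history entry is dropped and the newest state is appended along `axis=1`:

```python
import tensorflow as tf

history = tf.constant([[1., 2., 3.],
                       [4., 5., 6.]])   # [num_particles=2, history_size=3]
new_state = tf.constant([9., 10.])      # [num_particles=2]
updated = tf.concat([history[:, 1:], new_state[:, tf.newaxis]], axis=1)
# ==> [[ 2.,  3.,  9.],
#      [ 5.,  6., 10.]]
```
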
  def params_and_state_transition_fn(step,
                                     params_and_state,
                                     perturbation_scale,
                                     **kwargs):
    """Transition function operating on a `ParamsAndState` namedtuple."""
    # Extract the state, to pass through to the observation fn.
    unconstrained_params, state = params_and_state
    if 'state_history' in kwargs:
      kwargs['state_history'] = kwargs['state_history'].state

    # Perturb each (unconstrained) parameter with normally-distributed noise.
    if not tf.nest.is_nested(perturbation_scale):
      perturbation_scale = tf.nest.map_structure(
          lambda x: tf.convert_to_tensor(perturbation_scale,  # pylint: disable=g-long-lambda
                                         name='perturbation_scale',
                                         dtype=x.dtype),
          unconstrained_params)
    perturbed_unconstrained_parameter_dists = tf.nest.map_structure(
        lambda x, p, s: independent.Independent(  # pylint: disable=g-long-lambda
            normal.Normal(loc=x, scale=p),
            reinterpreted_batch_ndims=prefer_static.rank_from_shape(s)),
        unconstrained_params,
        perturbation_scale,
        parameter_prior.event_shape_tensor())

    # For the joint transition, pass the perturbed parameters
    # into the original transition fn (after pushing them into constrained
    # space).
    return joint_distribution_named.JointDistributionNamed(
        ParametersAndState(
            unconstrained_parameters=_maybe_build_joint_distribution(
                perturbed_unconstrained_parameter_dists),
            state=lambda unconstrained_parameters: (  # pylint: disable=g-long-lambda
                parameterized_transition_fn(
                    step,
                    state,
                    parameters=parameter_constraining_bijector.forward(
                        unconstrained_parameters),
                    **kwargs))))
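
A toy illustration (hedged; not from the source) of the per-parameter perturbation built above: each unconstrained parameter gets an `Independent` Normal random-walk kernel centered at its current value, with `reinterpreted_batch_ndims` equal to the rank of that parameter's event shape. Assumes `tfd = tfp.distributions`:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

current_value = tf.constant([0.3, -1.2])   # one unconstrained parameter, event_shape=[2]
perturbation = tfd.Independent(
    tfd.Normal(loc=current_value, scale=0.1),
    reinterpreted_batch_ndims=1)           # rank of the parameter's event shape
proposed_value = perturbation.sample()     # random-walk proposal for this parameter
```
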
def augment_prior_with_state_history(prior, history_size):
    """Augments a prior or proposal distribution's state space with history.

  The augmented state space is over `tfp.experimental.mcmc.StateWithHistory`
  namedtuples, which contain the original `state` as well as a `state_history`.
  The `state_history` is a structure of `Tensor`s matching `state`, of shape
  `concat([[num_particles, history_size], state.shape[1:]])`. In other words,
  previous states for each particle are indexed along `axis=1`, to the right
  of the particle indices.

  Args:
    prior: a (joint) distribution over the initial latent state,
      with optional batch shape `[b1, ..., bN]`.
    history_size: integer `Tensor` number of steps of history to retain.
  Returns:
    augmented_prior: a `tfd.JointDistributionNamed` instance whose samples
      are `tfp.experimental.mcmc.StateWithHistory` namedtuples.

  #### Example

  As a toy example, let's see how we'd use state history to experiment with
  stochastic 'Fibonacci sequences'. We'll assume that the sequence starts at a
  value sampled from a Poisson distribution.

  ```python
  initial_state_prior = tfd.Poisson(5.)
  initial_state_with_history_prior = (
    tfp.experimental.mcmc.augment_prior_with_state_history(
      initial_state_prior, history_size=2))
  ```

  Note that we've augmented the state space to include a state history of
  size two. The augmented state space is over instances of
  `tfp.experimental.mcmc.StateWithHistory`. Initially, the state history
  will simply tile the initial state: if
  `s = initial_state_with_history_prior.sample()`, then
  `s.state_history == [s.state, s.state]`.

  Next, we'll define a `transition_fn` that uses the history to
  sample the next integer in the sequence, also from a Poisson distribution.

  ```python
  @tfp.experimental.mcmc.augment_with_state_history
  def fibonacci_transition_fn(_, state_with_history):
    expected_next_element = tf.reduce_sum(
      state_with_history.state_history[:, -2:], axis=1)
    return tfd.Poisson(rate=expected_next_element)
  ```

  Our transition function must accept `state_with_history`,
  so that it can access the history, but it returns a distribution
  only over the next state. Decorating it with `augment_with_state_history`
  ensures that the state history is automatically propagated.

  Note: if we were using an `initial_state_proposal` and/or `proposal_fn`, we
  would need to wrap them similarly to the prior and transition function
  shown here.

  Combined with an observation function (which must also now be defined on the
  augmented `StateWithHistory` space), we can track stochastic Fibonacci
  sequences and, for example, infer the initial value of a sequence:

  ```python

  def observation_fn(_, state_with_history):
    return tfd.Poisson(rate=state_with_history.state)

  trajectories, _ = tfp.experimental.mcmc.infer_trajectories(
    observations=tf.convert_to_tensor([4., 11., 16., 23., 40., 69., 100.]),
    initial_state_prior=initial_state_with_history_prior,
    transition_fn=fibonacci_transition_fn,
    observation_fn=observation_fn,
    num_particles=1024)
  inferred_initial_states = trajectories.state[0]
  print(tf.unique_with_counts(inferred_initial_states))
  ```

  """
    def initialize_state_history(state):
        """Build an initial state history by replicating the initial state."""
        with tf.name_scope('initialize_state_history'):
            initial_state_histories = tf.nest.map_structure(
                lambda x: tf.broadcast_to(  # pylint: disable=g-long-lambda
                    tf.expand_dims(x, ps.minimum(ps.rank(x), 1)),
                    ps.concat(
                        [ps.shape(x)[:1], [history_size],
                         ps.shape(x)[1:]],
                        axis=0)),
                state)
            return (joint_distribution_util.
                    independent_joint_distribution_from_structure(
                        _wrap_as_distributions(initial_state_histories)))

    return joint_distribution_named.JointDistributionNamed(
        StateWithHistory(state=prior, state_history=initialize_state_history))
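
A toy illustration (hedged; not from the source) of the tiling performed by `initialize_state_history`: the initial state is replicated `history_size` times along a new `axis=1`, so the history initially repeats the current state. The example assumes a rank-1 state; the source uses `ps.minimum(ps.rank(x), 1)` to also handle scalars:

```python
import tensorflow as tf

state = tf.constant([3., 7.])   # [num_particles=2]
history_size = 2
tiled = tf.broadcast_to(
    tf.expand_dims(state, axis=1),
    tf.concat([tf.shape(state)[:1], [history_size], tf.shape(state)[1:]], axis=0))
# ==> [[3., 3.],
#      [7., 7.]]   # shape [num_particles, history_size]
```
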
def build_factored_surrogate_posterior(model,
                                       batch_shape=(),
                                       seed=None,
                                       name=None):
    """Build a variational posterior that factors over model parameters.

  The surrogate posterior consists of independent Normal distributions for
  each parameter with trainable `loc` and `scale`, transformed using the
  parameter's `bijector` to the appropriate support space for that parameter.

  Args:
    model: An instance of `StructuralTimeSeries` representing a
        time-series model. This represents a joint distribution over
        time-series and their parameters with batch shape `[b1, ..., bN]`.
    batch_shape: Batch shape (Python `tuple`, `list`, or `int`) of initial
      states to optimize in parallel.
      Default value: `()` (i.e., just run a single optimization).
    seed: Python integer to seed the random number generator.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').
  Returns:
    variational_posterior: `tfd.JointDistributionNamed` defining a trainable
        surrogate posterior over model parameters. Samples from this
        distribution are Python `dict`s with Python `str` parameter names as
        keys.

  ### Examples

  Assume we've built a structural time-series model:

  ```python
    day_of_week = tfp.sts.Seasonal(
        num_seasons=7,
        observed_time_series=observed_time_series,
        name='day_of_week')
    local_linear_trend = tfp.sts.LocalLinearTrend(
        observed_time_series=observed_time_series,
        name='local_linear_trend')
    model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                        observed_time_series=observed_time_series)
  ```

  To fit the model to data, we define a surrogate posterior and fit it
  by optimizing a variational bound:

  ```python
    surrogate_posterior = tfp.sts.build_factored_surrogate_posterior(
      model=model)
    loss_curve = tfp.vi.fit_surrogate_posterior(
      target_log_prob_fn=model.joint_log_prob(observed_time_series),
      surrogate_posterior=surrogate_posterior,
      optimizer=tf.optimizers.Adam(learning_rate=0.1),
      num_steps=200)
    posterior_samples = surrogate_posterior.sample(50)

    # In graph mode, we would need to write:
    # with tf.control_dependencies([loss_curve]):
    #   posterior_samples = surrogate_posterior.sample(50)
  ```

  For more control, we can also build and optimize a variational loss
  manually:

  ```python
    @tf.function(autograph=False)  # Ensure the loss is computed efficiently
    def loss_fn():
      return tfp.vi.monte_carlo_variational_loss(
        model.joint_log_prob(observed_time_series),
        surrogate_posterior,
        sample_size=10)

    optimizer = tf.optimizers.Adam(learning_rate=0.1)
    for step in range(200):
      with tf.GradientTape() as tape:
        loss = loss_fn()
      grads = tape.gradient(loss, surrogate_posterior.trainable_variables)
      optimizer.apply_gradients(
        zip(grads, surrogate_posterior.trainable_variables))
      if step % 20 == 0:
        print('step {} loss {}'.format(step, loss))

    posterior_samples = surrogate_posterior.sample(50)
  ```

  """
    with tf.name_scope(name or 'build_factored_surrogate_posterior'):
        seed = tfp_util.SeedStream(
            seed,
            salt='StructuralTimeSeries_build_factored_surrogate_posterior')
        variational_posterior = collections.OrderedDict()
        for param in model.parameters:
            variational_posterior[
                param.name] = _build_posterior_for_one_parameter(
                    param, batch_shape=batch_shape, seed=seed())
        return joint_distribution_named_lib.JointDistributionNamed(
            variational_posterior)
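
A simplified, hedged sketch of the per-parameter surrogate described in the docstring above (the actual `_build_posterior_for_one_parameter` helper is defined elsewhere in TFP): a Normal with trainable `loc` and `scale`, pushed through a constraining bijector such as `Softplus` for a positive-valued parameter:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

loc = tf.Variable(0., name='loc')
scale = tfp.util.TransformedVariable(0.1, bijector=tfb.Softplus(), name='scale')
one_param_surrogate = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=loc, scale=scale),
    bijector=tfb.Softplus())   # constrain samples to the parameter's support
```
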