Пример #1
0
 def test_slice_nested_mixture(self):
     """Slicing a mixture-of-mixtures produces the expected batch shapes."""
     # Inner mixture: Bernoulli components with batch shape [1, 2, 2], mixed
     # over their last batch axis.
     outer_mixing = tfd.Categorical(logits=tf.zeros([2]))
     inner_mixing = tfd.Categorical(logits=tf.zeros([2]))
     leaf = tfd.Bernoulli(logits=tf.zeros([1, 2, 2]))
     nested = tfd.MixtureSameFamily(
         outer_mixing, tfd.MixtureSameFamily(inner_mixing, leaf))
     # (slice index, expected batch shape after slicing)
     expectations = (
         ((0, Ellipsis), []),
         ((0, Ellipsis, tf.newaxis), [1]),
         ((Ellipsis, tf.newaxis), [1, 1]),
     )
     for index, expected_shape in expectations:
         self.assertAllEqual(nested[index].batch_shape_tensor(), expected_shape)
  def test_bug170030378(self):
    """Regression test: mixture log_prob over `BatchBroadcast` components.

    Compares `MixtureSameFamily.log_prob` against a manual log-sum-exp over
    the mixture components when the components distribution is an
    `Independent(Bernoulli)` broadcast with `tfd.BatchBroadcast`.
    """
    n_item = 50
    n_rater = 7

    stream = test_util.test_seed_stream()
    # Shape [n_item, 2]: one 2-way mixing weight vector per item.
    weight = self.evaluate(
        tfd.Sample(tfd.Dirichlet([0.25, 0.25]), n_item).sample(seed=stream()))
    mixture_dist = tfd.Categorical(probs=weight)  # batch_shape=[n_item]

    rater_sensitivity = self.evaluate(
        tfd.Sample(tfd.Beta(5., 1.), n_rater).sample(seed=stream()))
    rater_specificity = self.evaluate(
        tfd.Sample(tfd.Beta(2., 5.), n_rater).sample(seed=stream()))

    # Shape [1, 2, n_rater]; the leading singleton axis is broadcast to
    # `n_item` by `BatchBroadcast` below.
    probs = tf.stack([rater_sensitivity, rater_specificity])[None, ...]

    # Use `n_item` rather than a hard-coded 50 so the broadcast shape always
    # matches the mixture's batch shape.
    components_dist = tfd.BatchBroadcast(  # batch_shape=[n_item, 2]
        tfd.Independent(tfd.Bernoulli(probs=probs),
                        reinterpreted_batch_ndims=1),
        [n_item, 2])

    obs_dist = tfd.MixtureSameFamily(mixture_dist, components_dist)

    observed = self.evaluate(obs_dist.sample(seed=stream()))
    mixture_logp = obs_dist.log_prob(observed)

    # Reference value: evaluate each observation under every component (the
    # inserted axis broadcasts against the component axis), add log weights,
    # and log-sum-exp over components.
    expected_logp = tf.math.reduce_logsumexp(
        tf.math.log(weight) + components_dist.distribution.log_prob(
            observed[:, None, ...]),
        axis=-1)
    self.assertAllClose(expected_logp, mixture_logp)
Пример #3
0
 def new(params,
         event_size,
         num_components,
         dtype=None,
         validate_args=False,
         name=None):
     """Create the distribution instance from a `params` vector.

     The last axis of `params` is split in two. The first `num_components`
     entries are the mixture logits. The remainder is reshaped to
     `concat([shape(params)[:-1], [num_components, event_size]])` and used
     as the per-component `OneHotCategorical` logits.

     Args:
       params: `Tensor` holding mixture logits followed by the flattened
         component logits along its last axis.
       event_size: Size of each one-hot event.
       num_components: Number of mixture components.
       dtype: dtype of the one-hot samples; falls back to the base dtype of
         `params` when falsy.
       validate_args: Forwarded to the mixing `Categorical` only.
       name: Optional name for the name scope.

     Returns:
       A `tfd.MixtureSameFamily`. Its `_mean` is replaced, and it gains a
       `log_mean` attribute; both are `functools.partial`s of
       `_eval_all_one_hot` (defined elsewhere).
     """
     with tf.name_scope(name, 'CategoricalMixtureOfOneHotCategorical',
                        [params, event_size, num_components]):
         leading_shape = tf.shape(params)[:-1]
         components_shape = tf.concat(
             [leading_shape, [num_components, event_size]], axis=0)
         mixing = tfd.Categorical(
             logits=params[..., :num_components],
             validate_args=validate_args)
         per_component_logits = tf.reshape(params[..., num_components:],
                                           components_shape)
         components = tfd.OneHotCategorical(
             logits=per_component_logits,
             dtype=dtype or params.dtype.base_dtype,
             # Validation is off so we can eval on the simplex interior.
             validate_args=False)
         # TODO(b/120154797): Change following to `validate_args=True` after
         # fixing: "ValueError: `mixture_distribution` must have scalar
         # `event_dim`s." assertion in MixtureSameFamily.
         mixture = tfd.MixtureSameFamily(
             mixture_distribution=mixing,
             components_distribution=components,
             validate_args=False)
         # pylint: disable=protected-access
         mixture._mean = functools.partial(
             _eval_all_one_hot, tfd.Distribution.prob, mixture)
         mixture.log_mean = functools.partial(
             _eval_all_one_hot, tfd.Distribution.log_prob, mixture)
         # pylint: enable=protected-access
         return mixture
Пример #4
0
    def new(params,
            num_components,
            component_layer,
            validate_args=False,
            name=None,
            **kwargs):
        """Create the distribution instance from a `params` vector.

        The first `num_components` entries of the last axis of `params` are
        the mixture logits. The remaining entries are regrouped to shape
        `concat([shape(params)[:-1], [num_components, -1]])` and passed to
        `component_layer` to build the components distribution.

        Args:
          params: `Tensor`-like parameters; converted with
            `tf.convert_to_tensor`.
          num_components: Number of mixture components (converted to a
            `Tensor`, preferring `tf.int32`).
          component_layer: Callable mapping the regrouped component params to
            a distribution.
          validate_args: Accepted for interface compatibility, but not
            currently forwarded (see the TODO below).
          name: Optional name for the name scope.
          **kwargs: Extra keyword arguments forwarded to
            `tfd.MixtureSameFamily`.

        Returns:
          A `tfd.MixtureSameFamily` instance.
        """
        with tf.name_scope(name, 'MixtureSameFamily',
                           [params, num_components, component_layer]):
            params = tf.convert_to_tensor(params, name='params')
            num_components = tf.convert_to_tensor(
                num_components, name='num_components', preferred_dtype=tf.int32)

            flat_component_params = params[..., num_components:]
            grouped_shape = tf.concat(
                [tf.shape(params)[:-1], [num_components, -1]], axis=0)
            components_dist = component_layer(
                tf.reshape(flat_component_params, grouped_shape))
            mixture_dist = tfd.Categorical(logits=params[..., :num_components])
            # TODO(b/120154797): Change following to `validate_args=True` after
            # fixing: "ValueError: `mixture_distribution` must have scalar
            # `event_dim`s." assertion in MixtureSameFamily.
            return tfd.MixtureSameFamily(mixture_dist,
                                         components_dist,
                                         validate_args=False,
                                         **kwargs)
class BatchShapeInferenceTests(test_util.TestCase):
    """Checks the private `_inferred_batch_shape*` helpers on distributions."""

    @parameterized.named_parameters(
        dict(testcase_name='_trivial',
             value_fn=lambda: tfd.Normal(loc=0., scale=1.),
             expected_batch_shape=[]),
        dict(testcase_name='_simple_tensor_broadcasting',
             value_fn=lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
                 loc=[0., 0.],
                 scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
             expected_batch_shape=[2]),
        dict(testcase_name='_rank_deficient_tensor_broadcasting',
             value_fn=lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
                 loc=0.,
                 scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
             expected_batch_shape=[2]),
        dict(testcase_name='_mixture_same_family',
             value_fn=lambda: tfd.MixtureSameFamily(  # pylint: disable=g-long-lambda
                 mixture_distribution=tfd.Categorical(
                     logits=[[[1., 2., 3.], [4., 5., 6.]]]),
                 components_distribution=tfd.Normal(
                     loc=0., scale=[[[1., 2., 3.], [4., 5., 6.]]])),
             expected_batch_shape=[1, 2]),
        dict(testcase_name='_deeply_nested',
             value_fn=lambda: tfd.Independent(  # pylint: disable=g-long-lambda
                 tfd.Independent(
                     tfd.Independent(
                         tfd.Independent(
                             tfd.Normal(loc=0., scale=[[[[[[[[1.]]]]]]]]),
                             reinterpreted_batch_ndims=2),
                         reinterpreted_batch_ndims=0),
                     reinterpreted_batch_ndims=1),
                 reinterpreted_batch_ndims=1),
             expected_batch_shape=[1, 1, 1, 1]))
    def test_batch_shape_inference_is_correct(self, value_fn,
                                              expected_batch_shape):
        # Construction is deferred to here so it happens in the test's graph.
        dist = value_fn()
        inferred_dynamic = dist._inferred_batch_shape_tensor()  # pylint: disable=protected-access
        self.assertAllEqual(expected_batch_shape, inferred_dynamic)

        inferred_static = dist._inferred_batch_shape()  # pylint: disable=protected-access
        self.assertIsInstance(inferred_static, tf.TensorShape)
        self.assertTrue(inferred_static.is_compatible_with(expected_batch_shape))
 def test_docstring_example(self):
   """Smoke test: a batch of VonMisesFisher mixtures on the 2-sphere."""
   seeds = test_util.test_seed_stream()
   mean_directions = tfp.random.spherical_uniform([10], 3, seed=seeds())
   shared_components = tfd.VonMisesFisher(
       mean_direction=mean_directions, concentration=50.)
   mixing = tfd.Categorical(
       logits=tf.random.uniform([500, 10], seed=seeds()))
   # Broadcast the 10 shared components to every one of the 500 mixtures.
   broadcast_components = tfd.BatchBroadcast(shared_components, [500, 10])
   mixture = tfd.MixtureSameFamily(mixing, broadcast_components)
   query_points = tfp.random.spherical_uniform([20], 3, seed=seeds())
   log_probs = tfd.Sample(mixture, 20).log_prob(query_points)
   self.assertEqual([500], log_probs.shape)
   self.evaluate(log_probs)
Пример #7
0
 def testLogProbBroadcastOverDfInsideMixture(self):
   """A Wishart mixture over a vector of `df` values has scalar batch shape."""
   dims = 2
   scale = np.float32([[0.5, 0.25],  #
                       [0.25, 0.75]])
   # Five components, df = 3, 4, 5, 6, 7.
   df = np.arange(3., 8., dtype=np.float32)
   wisharts = tfd.WishartTriL(df=df, scale_tril=chol(scale))
   equal_weights = tfd.Categorical(logits=tf.zeros(df.shape))
   dist = tfd.MixtureSameFamily(
       components_distribution=wisharts,
       mixture_distribution=equal_weights)
   # Evaluate at a random matrix of the form A @ A.T.
   factor = np.random.randn(dims, dims)
   x = np.matmul(factor, factor.T)
   lp = dist.log_prob(x)
   lp_ = self.evaluate(lp)
   for shape in (dist.batch_shape, lp.shape, lp_.shape):
     self.assertAllEqual([], shape)
Пример #8
0
def make_mixture_prior():
  """Creates the prior distribution over the latent code.

  `mixture_components` and `latent_size` are free variables resolved from an
  enclosing scope; they are not parameters of this function.

  Returns:
    A `tfd.MultivariateNormalDiag` with zero `loc` of length `latent_size`
    and `scale_identity_multiplier=1.0`, when `mixture_components == 1`.

  Raises:
    NotImplementedError: For any other value of `mixture_components`.
  """
  if mixture_components == 1:
    # See the module docstring for why we don't learn the parameters here.
    return tfd.MultivariateNormalDiag(loc=tf.zeros([latent_size]),
                                      scale_identity_multiplier=1.0)
  # NOTE(review): everything after this raise is unreachable. The learned
  # mixture-of-Gaussians prior below relies on `tf.compat.v1.get_variable`,
  # so it was presumably disabled on purpose -- confirm, then either delete
  # the dead code or re-enable it.
  raise NotImplementedError()
  loc = tf.compat.v1.get_variable(name="loc",
                                  shape=[mixture_components, latent_size])
  raw_scale_diag = tf.compat.v1.get_variable(
      name="raw_scale_diag", shape=[mixture_components, latent_size])
  mixture_logits = tf.compat.v1.get_variable(name="mixture_logits",
                                             shape=[mixture_components])

  # softplus keeps the learned diagonal scales positive.
  return tfd.MixtureSameFamily(
      components_distribution=tfd.MultivariateNormalDiag(
          loc=loc, scale_diag=tf.nn.softplus(raw_scale_diag)),
      mixture_distribution=tfd.Categorical(logits=mixture_logits),
      name="prior")
Пример #9
0
def mix_over_posterior_draws(means, variances):
  """Builds a uniform Normal mixture over the leading posterior-draw axis.

  Args:
    means: float `Tensor` of shape
      `[num_posterior_draws, ..., num_timesteps]`.
    variances: float `Tensor` of shape
      `[num_posterior_draws, ..., num_timesteps]`.

  Returns:
    mixture_dist: `tfd.MixtureSameFamily(tfd.Independent(tfd.Normal))`
      instance that weights every posterior draw equally, with
      `batch_shape = ...` and `event_shape = [num_timesteps]`.
  """
  # Input layout: concat([[num_posterior_draws], sample_shape, batch_shape,
  # [num_timesteps]]). `MixtureSameFamily` mixes over the rightmost *batch*
  # axis. The draw axis is therefore moved just in front of
  # `num_timesteps`, and `Independent` keeps `num_timesteps` in the event
  # shape.
  # TODO(b/120245392): enhance `MixtureSameFamily` to reduce along an
  # arbitrary axis, and eliminate `move_dimension` calls here.
  with tf.compat.v1.name_scope(
      'mix_over_posterior_draws', values=[means, variances]):
    draws_dim = dist_util.prefer_static_value(tf.shape(input=means))[0]

    draws_last_loc = dist_util.move_dimension(means, 0, -2)
    draws_last_scale = tf.sqrt(dist_util.move_dimension(variances, 0, -2))
    per_draw = tfd.Independent(
        distribution=tfd.Normal(loc=draws_last_loc, scale=draws_last_scale),
        reinterpreted_batch_ndims=1)

    # All-zero logits give every posterior draw the same weight.
    uniform_mixing = tfd.Categorical(
        logits=tf.zeros([draws_dim], dtype=per_draw.dtype))
    return tfd.MixtureSameFamily(
        mixture_distribution=uniform_mixing,
        components_distribution=per_draw)
Пример #10
0
def forecast(model,
             observed_time_series,
             parameter_samples,
             num_steps_forecast,
             include_observation_noise=True):
    """Construct predictive distribution over future observations.

  Given samples from the posterior over parameters, return the predictive
  distribution over future observations for num_steps_forecast timesteps.

  Args:
    model: An instance of `StructuralTimeSeries` representing a
      time-series model. This represents a joint distribution over
      time-series and their parameters with batch shape `[b1, ..., bN]`.
    observed_time_series: `float` `Tensor` of shape
      `concat([sample_shape, model.batch_shape, [num_timesteps, 1]])` where
      `sample_shape` corresponds to i.i.d. observations, and the trailing `[1]`
      dimension may (optionally) be omitted if `num_timesteps > 1`. May
      optionally be an instance of `tfp.sts.MaskedTimeSeries` including a
      mask `Tensor` to encode the locations of missing observations.
    parameter_samples: Python `list` of `Tensors` representing posterior samples
      of model parameters, with shapes `[concat([[num_posterior_draws],
      param.prior.batch_shape, param.prior.event_shape]) for param in
      model.parameters]`. This may optionally also be a map (Python `dict`) of
      parameter names to `Tensor` values.
    num_steps_forecast: scalar `int` `Tensor` number of steps to forecast.
    include_observation_noise: Python `bool` indicating whether the forecast
      distribution should include uncertainty from observation noise. If `True`,
      the forecast is over future observations, if `False`, the forecast is over
      future values of the latent noise-free time series. When `False`, the
      model must have a parameter named `observation_noise_scale`; its samples
      are replaced with zeros (a missing parameter raises `KeyError`).
      Default value: `True`.

  Returns:
    forecast_dist: a `tfd.MixtureSameFamily` instance with event shape
      [num_steps_forecast, 1] and batch shape
      `concat([sample_shape, model.batch_shape])`, with `num_posterior_draws`
      mixture components.

  #### Examples

  Suppose we've built a model and fit it to data using HMC:

  ```python
    day_of_week = tfp.sts.Seasonal(
        num_seasons=7,
        observed_time_series=observed_time_series,
        name='day_of_week')
    local_linear_trend = tfp.sts.LocalLinearTrend(
        observed_time_series=observed_time_series,
        name='local_linear_trend')
    model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                        observed_time_series=observed_time_series)

    samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series)
  ```

  Passing the posterior samples into `forecast`, we construct a forecast
  distribution:

  ```python
    forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                     parameter_samples=samples,
                                     num_steps_forecast=50)

    forecast_mean = forecast_dist.mean()[..., 0]  # shape: [50]
    forecast_scale = forecast_dist.stddev()[..., 0]  # shape: [50]
    forecast_samples = forecast_dist.sample(10)[..., 0]  # shape: [10, 50]
  ```

  If using variational inference instead of HMC, we'd construct a forecast using
  samples from the variational posterior:

  ```python
    (variational_loss,
     variational_distributions) = tfp.sts.build_factored_variational_loss(
       model=model, observed_time_series=observed_time_series)

    # OMITTED: take steps to optimize variational loss

    samples = {k: q.sample(30) for (k, q) in variational_distributions.items()}
    forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                     parameter_samples=samples,
                                     num_steps_forecast=50)
  ```

  We can visualize the forecast by plotting:

  ```python
    from matplotlib import pylab as plt
    def plot_forecast(observed_time_series,
                      forecast_mean,
                      forecast_scale,
                      forecast_samples):
      plt.figure(figsize=(12, 6))

      num_steps = observed_time_series.shape[-1]
      num_steps_forecast = forecast_mean.shape[-1]
      num_steps_train = num_steps - num_steps_forecast

      c1, c2 = (0.12, 0.47, 0.71), (1.0, 0.5, 0.05)
      plt.plot(np.arange(num_steps), observed_time_series,
               lw=2, color=c1, label='ground truth')

      forecast_steps = np.arange(num_steps_train,
                                 num_steps_train+num_steps_forecast)
      plt.plot(forecast_steps, forecast_samples.T, lw=1, color=c2, alpha=0.1)
      plt.plot(forecast_steps, forecast_mean, lw=2, ls='--', color=c2,
               label='forecast')
      plt.fill_between(forecast_steps,
                       forecast_mean - 2 * forecast_scale,
                       forecast_mean + 2 * forecast_scale, color=c2, alpha=0.2)

      plt.xlim([0, num_steps])
      plt.legend()

    plot_forecast(observed_time_series,
                  forecast_mean=forecast_mean,
                  forecast_scale=forecast_scale,
                  forecast_samples=forecast_samples)
  ```

  """

    with tf.name_scope('forecast'):
        [observed_time_series,
         mask] = sts_util.canonicalize_observed_time_series_with_mask(
             observed_time_series)

        # Run filtering over the observed timesteps to extract the
        # latent state posterior at timestep T+1 (i.e., the final
        # filtering distribution, pushed through the transition model).
        # This is the prior for the forecast model ("today's prior
        # is yesterday's posterior").
        num_observed_steps = dist_util.prefer_static_value(
            tf.shape(input=observed_time_series))[-2]
        observed_data_ssm = model.make_state_space_model(
            num_timesteps=num_observed_steps, param_vals=parameter_samples)
        # Only the predictive means and covariances from `forward_filter`
        # are used; the other outputs are discarded.
        (_, _, _, predictive_means, predictive_covs, _,
         _) = observed_data_ssm.forward_filter(observed_time_series, mask=mask)

        # Build a batch of state-space models over the forecast period. Because
        # we'll use MixtureSameFamily to mix over the posterior draws, we need to
        # do some shenanigans to move the `[num_posterior_draws]` batch dimension
        # from the leftmost to the rightmost side of the model's batch shape.
        # TODO(b/120245392): enhance `MixtureSameFamily` to reduce along an
        # arbitrary axis, and eliminate `move_dimension` calls here.
        parameter_samples = model._canonicalize_param_vals_as_map(  # pylint: disable=protected-access
            parameter_samples)
        # Each parameter's draw axis (axis 0) is moved to sit just before that
        # parameter's event dims, i.e. to the end of its batch dims.
        parameter_samples_with_reordered_batch_dimension = {
            param.name: dist_util.move_dimension(
                parameter_samples[param.name], 0,
                -(1 + _prefer_static_event_ndims(param.prior)))
            for param in model.parameters
        }
        # Initial-state prior for the forecast: the predictive moments at the
        # last observed step (index -1), with the draw axis moved to be the
        # rightmost batch axis (before the state / state x state dims).
        forecast_prior = tfd.MultivariateNormalFullCovariance(
            loc=dist_util.move_dimension(predictive_means[..., -1, :], 0, -2),
            covariance_matrix=dist_util.move_dimension(
                predictive_covs[..., -1, :, :], 0, -3))

        # Ugly hack: because we moved `num_posterior_draws` to the trailing (rather
        # than leading) dimension of parameters, the parameter batch shapes no
        # longer broadcast against the `constant_offset` attribute used in `sts.Sum`
        # models. We fix this by manually adding an extra broadcasting dim to
        # `constant_offset` if present.
        # The root cause of this hack is that we mucked with param dimensions above
        # and are now passing params that are 'invalid' in the sense that they don't
        # match the shapes of the model's param priors. The fix (as above) will be
        # to update MixtureSameFamily so we can avoid changing param dimensions
        # altogether.
        # TODO(b/120245392): enhance `MixtureSameFamily` to reduce along an
        # arbitrary axis, and eliminate this hack.
        kwargs = {}
        if hasattr(model, 'constant_offset'):
            kwargs['constant_offset'] = tf.convert_to_tensor(
                value=model.constant_offset,
                dtype=forecast_prior.dtype)[..., tf.newaxis]

        # Zeroing the noise scale turns the forecast into one over the latent
        # noise-free series. Raises KeyError if the model has no
        # `observation_noise_scale` parameter.
        if not include_observation_noise:
            parameter_samples_with_reordered_batch_dimension[
                'observation_noise_scale'] = tf.zeros_like(
                    parameter_samples_with_reordered_batch_dimension[
                        'observation_noise_scale'])

        # We assume that any STS model that has a `constant_offset` attribute
        # will allow it to be overridden as a kwarg. This is currently just
        # `sts.Sum`.
        # TODO(b/120245392): when kwargs hack is removed, switch back to calling
        # the public version of `_make_state_space_model`.
        forecast_ssm = model._make_state_space_model(  # pylint: disable=protected-access
            num_timesteps=num_steps_forecast,
            param_map=parameter_samples_with_reordered_batch_dimension,
            initial_state_prior=forecast_prior,
            initial_step=num_observed_steps,
            **kwargs)

        # The draw axis is now the rightmost batch axis of `forecast_ssm`,
        # so its size is the number of mixture components. All-zero logits
        # give every posterior draw equal weight.
        num_posterior_draws = dist_util.prefer_static_value(
            forecast_ssm.batch_shape_tensor())[-1]
        return tfd.MixtureSameFamily(mixture_distribution=tfd.Categorical(
            logits=tf.zeros([num_posterior_draws], dtype=forecast_ssm.dtype)),
                                     components_distribution=forecast_ssm)
Пример #11
0
def mixtures_same_family(draw,
                         batch_shape=None,
                         event_dim=None,
                         enable_vars=False,
                         depth=None):
    """Strategy for drawing `MixtureSameFamily` distributions.

    The component distribution is drawn from the `distributions` strategy.

    The Categorical mixture distribution has one of two batch shapes. It is
    either scalar-batched (shared across all batch members) or batched with
    the full mixture batch shape (one per batch member), as required by
    `MixtureSameFamily`.

    Args:
      draw: Hypothesis strategy sampler supplied by `@hps.composite`.
      batch_shape: An optional `TensorShape`. The batch shape of the resulting
        `MixtureSameFamily` distribution. The component distribution will have
        a batch shape of 1 rank higher (for the components being mixed), with
        at least 2 components. Hypothesis will pick a batch shape if omitted.
      event_dim: Optional Python int giving the size of each of the component
        distribution's parameters' event dimensions. This is shared across all
        parameters, permitting square event matrices, compatible location and
        scale Tensors, etc. If omitted, Hypothesis will choose one.
      enable_vars: TODO(bjp): Make this `True` all the time and put variable
        initialization in slicing_test. If `False`, the returned parameters are
        all `tf.Tensor`s and not {`tf.Variable`, `tfp.util.DeferredTensor`
        `tfp.util.TransformedVariable`}
      depth: Python `int` giving maximum nesting depth of compound
        Distributions.

    Returns:
      dists: A strategy for drawing `MixtureSameFamily` distributions with the
        specified `batch_shape` (or an arbitrary one if omitted).

    Raises:
      AssertionError: If the constructed distribution's batch shape is not the
        component batch shape with its last (component) dim dropped.
    """
    # NOTE: the order of `draw(...)` calls below determines how Hypothesis
    # generates and shrinks examples; keep it stable.
    if depth is None:
        depth = draw(depths())

    # From here on, `batch_shape` is the *component* batch shape. Its last
    # dim (size >= 2) is the axis being mixed over.
    if batch_shape is None:
        # Ensure the components dist has at least one batch dim (a component dim).
        batch_shape = draw(tfp_hps.shapes(min_ndims=1, min_lastdimsize=2))
    else:  # This mixture adds a batch dim to its underlying components dist.
        batch_shape = tensorshape_util.concatenate(
            batch_shape,
            draw(tfp_hps.shapes(min_ndims=1, max_ndims=1, min_lastdimsize=2)))

    # The component is one nesting level deeper than this mixture.
    component = draw(
        distributions(batch_shape=batch_shape,
                      event_dim=event_dim,
                      enable_vars=enable_vars,
                      depth=depth - 1))
    hp.note(
        'Drawing MixtureSameFamily with component {}; parameters {}'.format(
            component, params_used(component)))
    # scalar or same-shaped categorical?
    mixture_batch_shape = draw(
        hps.one_of(hps.just(batch_shape[:-1]), hps.just(tf.TensorShape([]))))
    # One category per mixture component (the last component batch dim).
    mixture_dist = draw(
        base_distributions(dist_name='Categorical',
                           batch_shape=mixture_batch_shape,
                           event_dim=tensorshape_util.as_list(batch_shape)[-1],
                           enable_vars=enable_vars))
    hp.note(('Forming MixtureSameFamily with '
             'mixture distribution {}; parameters {}').format(
                 mixture_dist, params_used(mixture_dist)))
    result_dist = tfd.MixtureSameFamily(components_distribution=component,
                                        mixture_distribution=mixture_dist,
                                        validate_args=True)
    # Sanity check: mixing should drop exactly the component dim.
    if batch_shape[:-1] != result_dist.batch_shape:
        msg = ('MixtureSameFamily strategy generated a bad batch shape '
               'for {}, should have been {}.').format(result_dist,
                                                      batch_shape[:-1])
        raise AssertionError(msg)
    return result_dist
    def testRWM2DMixNormal(self):
        """Sampling from a 2-D Mixture Normal Distribution."""
        # NOTE(review): despite `RWM` in the test name, every replica below
        # uses a `HamiltonianMonteCarlo` kernel.
        dtype = np.float32

        # By symmetry, target has mean [0, 0]
        # Therefore, Var = E[X^2] = E[E[X^2 | c]], where c is the component.
        # Now, for the first component,
        #   E[X1^2] =  Var[X1] + Mean[X1]^2
        #           =  0.3^2 + 1^2,
        # and similarly for the second.  As a result,
        # Var[mixture] = 1.09.
        target = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=[0.5, 0.5]),
            components_distribution=tfd.MultivariateNormalDiag(
                loc=[[-1., -1], [1., 1.]],
                scale_identity_multiplier=[0.3, 0.3]))

        # Four replicas at inverse temperatures 1, 10^(-2/3), 10^(-4/3), 0.01.
        # Each replica gets its own step size (same length as the temperatures).
        inverse_temperatures = 10.**tf.linspace(0., -2., 4)
        step_sizes = tf.constant([0.3, 0.6, 1.2, 2.4])

        def make_kernel_fn(target_log_prob_fn, seed):
            # Each call picks the next entry of `step_sizes` via the
            # function-attribute counter `make_kernel_fn.idx`.
            kernel = tfp.mcmc.HamiltonianMonteCarlo(
                target_log_prob_fn=target_log_prob_fn,
                seed=seed,
                step_size=step_sizes[make_kernel_fn.idx],
                num_leapfrog_steps=2)
            make_kernel_fn.idx += 1
            return kernel

        # TODO(b/124770732): Remove this hack.
        make_kernel_fn.idx = 0

        remc = tfp.mcmc.ReplicaExchangeMC(
            target_log_prob_fn=tf.function(target.log_prob, autograph=False),
            # Verified that test fails if inverse_temperatures = [1.]
            inverse_temperatures=inverse_temperatures,
            make_kernel_fn=make_kernel_fn,
            seed=_set_seed())

        def _trace_log_accept_ratio(state, results):
            # Trace one log-acceptance ratio per replica.
            del state
            return [
                r.log_accept_ratio for r in results.sampled_replica_results
            ]

        num_results = 1000
        samples, log_accept_ratios = tfp.mcmc.sample_chain(
            num_results=num_results,
            # Start at one of the modes, in order to make mode jumping necessary
            # if we want to pass test.
            current_state=np.ones(2, dtype=dtype),
            kernel=remc,
            num_burnin_steps=500,
            trace_fn=_trace_log_accept_ratio,
            parallel_iterations=1)  # For determinism.
        self.assertAllEqual((num_results, 2), samples.shape)
        # Mean acceptance probability per replica: exp(min(0, log ratio)).
        log_accept_ratios = [
            tf.reduce_mean(input_tensor=tf.exp(tf.minimum(0., lar)))
            for lar in log_accept_ratios
        ]

        sample_mean = tf.reduce_mean(input_tensor=samples, axis=0)
        sample_std = tf.sqrt(
            tf.reduce_mean(input_tensor=tf.math.squared_difference(
                samples, sample_mean),
                           axis=0))
        [sample_mean_, sample_std_, log_accept_ratios_
         ] = self.evaluate([sample_mean, sample_std, log_accept_ratios])
        tf1.logging.vlog(1, 'log_accept_ratios: %s  eager: %s',
                         log_accept_ratios_, tf.executing_eagerly())

        # Compare against the analytic moments derived above.
        self.assertAllClose(sample_mean_, [0., 0.], atol=0.3, rtol=0.3)
        self.assertAllClose(sample_std_,
                            [np.sqrt(1.09), np.sqrt(1.09)],
                            atol=0.1,
                            rtol=0.1)
Пример #13
0
class BatchShapeInferenceTests(test_util.TestCase):
  """Tests public batch-shape inference and parameter broadcasting."""

  @parameterized.named_parameters(
      dict(testcase_name='_trivial',
           value_fn=lambda: tfd.Normal(loc=0., scale=1.),
           expected_batch_shape=[]),
      dict(testcase_name='_simple_tensor_broadcasting',
           value_fn=lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
               loc=[0., 0.],
               scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
           expected_batch_shape=[2]),
      dict(testcase_name='_rank_deficient_tensor_broadcasting',
           value_fn=lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
               loc=0.,
               scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
           expected_batch_shape=[2]),
      dict(testcase_name='_mixture_same_family',
           value_fn=lambda: tfd.MixtureSameFamily(  # pylint: disable=g-long-lambda
               mixture_distribution=tfd.Categorical(
                   logits=[[[1., 2., 3.], [4., 5., 6.]]]),
               components_distribution=tfd.Normal(
                   loc=0., scale=[[[1., 2., 3.], [4., 5., 6.]]])),
           expected_batch_shape=[1, 2]),
      dict(testcase_name='_deeply_nested',
           value_fn=lambda: tfd.Independent(  # pylint: disable=g-long-lambda
               tfd.Independent(
                   tfd.Independent(
                       tfd.Independent(
                           tfd.Normal(loc=0., scale=[[[[[[[[1.]]]]]]]]),
                           reinterpreted_batch_ndims=2),
                       reinterpreted_batch_ndims=0),
                   reinterpreted_batch_ndims=1),
               reinterpreted_batch_ndims=1),
           expected_batch_shape=[1, 1, 1, 1]))
  def test_batch_shape_inference_is_correct(
      self, value_fn, expected_batch_shape):
    # Construction is deferred to here so it happens in the test's graph.
    dist = value_fn()
    self.assertAllEqual(expected_batch_shape, dist.batch_shape_tensor())

    static_shape = dist.batch_shape
    self.assertIsInstance(static_shape, tf.TensorShape)
    self.assertTrue(static_shape.is_compatible_with(expected_batch_shape))

  def assert_all_parameters_have_full_batch_shape(
      self, dist, expected_batch_shape):
    """Asserts `dist` and every parameter part have `expected_batch_shape`."""
    self.assertAllEqual(expected_batch_shape, dist.batch_shape_tensor())
    for part_shape in batch_shape_lib.batch_shape_parts(dist).values():
      self.assertAllEqual(expected_batch_shape, part_shape)

  @parameterized.named_parameters(
      dict(testcase_name='_trivial',
           dist_fn=lambda: tfd.Normal(loc=0., scale=1.)),
      dict(testcase_name='_simple_tensor_broadcasting',
           dist_fn=lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
               loc=[0., 0.],
               scale_diag=[[1., 1.], [1., 1.]])),
      dict(testcase_name='_rank_deficient_tensor_broadcasting',
           dist_fn=lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
               loc=0.,
               scale_diag=[[1., 1.], [1., 1.]])),
      dict(testcase_name='_deeply_nested',
           dist_fn=lambda: tfd.Independent(  # pylint: disable=g-long-lambda
               tfd.Independent(
                   tfd.Independent(
                       tfd.Independent(
                           tfd.Normal(loc=0., scale=[[[[[[[[1.]]]]]]]]),
                           reinterpreted_batch_ndims=2),
                       reinterpreted_batch_ndims=0),
                   reinterpreted_batch_ndims=1),
               reinterpreted_batch_ndims=1)),
      dict(testcase_name='_transformed_dist_simple',
           dist_fn=lambda: tfd.TransformedDistribution(  # pylint: disable=g-long-lambda
               tfd.Normal(loc=[[1., 2., 3.], [3., 4., 5.]], scale=[1.]),
               tfb.Scale(scale=[2., 3., 4.]))),
      dict(testcase_name='_transformed_dist_with_chain',
           dist_fn=lambda: tfd.TransformedDistribution(  # pylint: disable=g-long-lambda
               tfd.Normal(loc=[[1., 2., 3.], [3., 4., 5.]], scale=[1.]),
               tfb.Shift(-4.)(tfb.Scale(scale=[2., 3., 4.])))),
      dict(testcase_name='_transformed_dist_multipart_nested',
           dist_fn=lambda: tfd.TransformedDistribution(  # pylint: disable=g-long-lambda
               tfd.TransformedDistribution(
                   tfd.TransformedDistribution(
                       tfd.MultivariateNormalDiag(tf.zeros([4, 6]),
                                                  tf.ones([6])),
                       tfb.Split([3, 3])),
                   tfb.JointMap([tfb.Identity(), tfb.Reshape([3, 1])])),
               tfb.JointMap([tfb.Scale(scale=[2., 3., 4.]), tfb.Shift(1.)]))))
  def test_batch_broadcasting(self, dist_fn):
    dist = dist_fn()

    # Broadcast parameters to the distribution's own batch shape.
    self_broadcast = dist._broadcast_parameters_with_batch_shape(  # pylint: disable=protected-access
        dist.batch_shape)
    self.assert_all_parameters_have_full_batch_shape(
        self_broadcast,
        expected_batch_shape=self_broadcast.batch_shape_tensor())

    # Broadcast parameters to a batch shape with two extra leading dims.
    expanded_batch_shape = ps.concat([[7, 4], dist.batch_shape], axis=0)
    expanded_params = batch_shape_lib.broadcast_parameters_with_batch_shape(
        dist, expanded_batch_shape)
    expanded_dist = dist.copy(**expanded_params)
    self.assert_all_parameters_have_full_batch_shape(
        expanded_dist, expected_batch_shape=expanded_batch_shape)
Пример #14
0
def one_step_predictive(model,
                        posterior_samples,
                        num_forecast_steps=0,
                        original_mean=0.,
                        original_scale=1.,
                        thin_every=10):
    """Constructs a one-step-ahead predictive distribution at every timestep.

    Unlike the generic `tfp.sts.one_step_predictive`, this method uses the
    latent levels from Gibbs sampling to efficiently construct a predictive
    distribution that mixes over posterior samples. The predictive distribution
    may also include additional forecast steps.

    This method returns the predictive distributions for each timestep given
    previous timesteps and sampled model parameters, `p(observed_time_series[t] |
    observed_time_series[:t], weights, observation_noise_scale)`. Note that the
    posterior values of the weights and noise scale will in general be informed
    by observations from all timesteps *including the step being predicted*, so
    this is not a strictly kosher probabilistic quantity, but in general we
    assume that it's close, i.e., that the step being predicted had very small
    individual impact on the overall parameter posterior.

    Args:
      model: A `tfp.sts.StructuralTimeSeries` model instance. This must be of
        the form constructed by `build_model_for_gibbs_sampling`.
      posterior_samples: A `GibbsSamplerState` instance in which each element is
        a `Tensor` with initial dimension of size `num_samples`.
      num_forecast_steps: Python `int` number of additional forecast steps to
        append.
        Default value: `0`.
      original_mean: Optional scalar float `Tensor`, added to the predictive
        distribution to undo the effect of input normalization.
        Default value: `0.`
      original_scale: Optional scalar float `Tensor`, used to rescale the
        predictive distribution to undo the effect of input normalization.
        Default value: `1.`
      thin_every: Optional Python `int` factor by which to thin the posterior
        samples, to reduce complexity of the predictive distribution. For
        example, if `thin_every=10`, every `10`th sample will be used.
        Default value: `10`.

    Returns:
      predictive_dist: A `tfd.MixtureSameFamily` that mixes uniformly over the
        thinned posterior draws. Its components are scalar `tfd.Normal`s, so
        the event shape is scalar and the batch shape ends in
        `[num_timesteps + num_forecast_steps]`, preceded by any batch
        dimensions of the posterior samples. Entry `t` is the predictive
        distribution of timestep `t` given previous timesteps.
    """
    # Pass the sample tensors themselves (as the local-linear-trend variant
    # does) rather than their `.dtype`s. `common_dtype` infers dtype from its
    # arguments, and bare `tf.DType` objects may not be recognized, which would
    # silently fall back to `dtype_hint`.
    dtype = dtype_util.common_dtype([
        posterior_samples.level_scale,
        posterior_samples.observation_noise_scale,
        posterior_samples.level,
        original_mean,
        original_scale,
    ], dtype_hint=tf.float32)
    # Time is the trailing axis of the sampled levels.
    num_observed_steps = prefer_static.shape(posterior_samples.level)[-1]

    original_mean = tf.convert_to_tensor(original_mean, dtype=dtype)
    original_scale = tf.convert_to_tensor(original_scale, dtype=dtype)
    # Keep every `thin_every`-th draw along the leading (sample) axis.
    thinned_samples = tf.nest.map_structure(lambda x: x[::thin_every],
                                            posterior_samples)

    # The local level model expects that the level at step t+1 is equal
    # to the level at step t (plus transition noise of scale 'level_scale', which
    # we account for below).
    if num_forecast_steps > 0:
        num_batch_dims = prefer_static.rank_from_shape(
            prefer_static.shape(thinned_samples.level)) - 2
        # Repeat the last sampled level across every forecast step.
        forecast_level = tf.tile(
            thinned_samples.level[..., -1:],
            tf.concat([
                tf.ones([num_batch_dims + 1], dtype=tf.int32),
                [num_forecast_steps]
            ], axis=0))
    level_pred = tf.concat(
        [
            thinned_samples.level[..., :1],  # t == 0
            thinned_samples.level[..., :-1]  # 1 <= t < T
        ] + ([forecast_level] if num_forecast_steps > 0 else []),
        axis=-1)

    design_matrix = _get_design_matrix(model).to_dense()[:num_observed_steps +
                                                         num_forecast_steps]
    regression_effect = tf.linalg.matvec(design_matrix,
                                         thinned_samples.weights)

    y_mean = (
        (level_pred + regression_effect) * original_scale[..., tf.newaxis] +
        original_mean[..., tf.newaxis])

    # Level noise accumulates once per step since the last observation:
    # 1 for observed steps, then 1, 2, ..., num_forecast_steps.
    num_steps_from_last_observation = tf.concat([
        tf.ones([num_observed_steps], dtype=dtype),
        tf.range(1, num_forecast_steps + 1, dtype=dtype)
    ], axis=0)
    # Expand `original_scale` over the time axis exactly as `y_mean` does, so a
    # batched scale lines up with the batch dims rather than the time axis.
    y_scale = (
        original_scale[..., tf.newaxis] *
        tf.sqrt(thinned_samples.observation_noise_scale[..., tf.newaxis]**2 +
                thinned_samples.level_scale[..., tf.newaxis]**2 *
                num_steps_from_last_observation))

    # Mix uniformly over posterior draws: move the leading draw axis to be the
    # rightmost batch dimension, which `MixtureSameFamily` reduces over.
    num_posterior_draws = prefer_static.shape(y_mean)[0]
    return tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(
            logits=tf.zeros([num_posterior_draws], dtype=y_mean.dtype)),
        components_distribution=tfd.Normal(
            loc=dist_util.move_dimension(y_mean, 0, -1),
            scale=dist_util.move_dimension(y_scale, 0, -1)))
Пример #15
0
def one_step_predictive(model,
                        posterior_samples,
                        num_forecast_steps=0,
                        original_mean=0.,
                        original_scale=1.,
                        thin_every=10):
    """Constructs a one-step-ahead predictive distribution at every timestep.

    Unlike the generic `tfp.sts.one_step_predictive`, this method uses the
    latent levels from Gibbs sampling to efficiently construct a predictive
    distribution that mixes over posterior samples. The predictive distribution
    may also include additional forecast steps.

    This method returns the predictive distributions for each timestep given
    previous timesteps and sampled model parameters, `p(observed_time_series[t] |
    observed_time_series[:t], weights, observation_noise_scale)`. Note that the
    posterior values of the weights and noise scale will in general be informed
    by observations from all timesteps *including the step being predicted*, so
    this is not a strictly kosher probabilistic quantity, but in general we
    assume that it's close, i.e., that the step being predicted had very small
    individual impact on the overall parameter posterior.

    Args:
      model: A `tfp.sts.StructuralTimeSeries` model instance. This must be of
        the form constructed by `build_model_for_gibbs_sampling`.
      posterior_samples: A `GibbsSamplerState` instance in which each element is
        a `Tensor` with initial dimension of size `num_samples`.
      num_forecast_steps: Python `int` number of additional forecast steps to
        append.
        Default value: `0`.
      original_mean: Optional scalar float `Tensor`, added to the predictive
        distribution to undo the effect of input normalization.
        Default value: `0.`
      original_scale: Optional scalar float `Tensor`, used to rescale the
        predictive distribution to undo the effect of input normalization.
        Default value: `1.`
      thin_every: Optional Python `int` factor by which to thin the posterior
        samples, to reduce complexity of the predictive distribution. For
        example, if `thin_every=10`, every `10`th sample will be used.
        Default value: `10`.

    Returns:
      predictive_dist: A `tfd.MixtureSameFamily` that mixes uniformly over the
        thinned posterior draws. Its components are scalar `tfd.Normal`s, so
        the event shape is scalar and the batch shape ends in
        `[num_timesteps + num_forecast_steps]`, preceded by any batch
        dimensions of the posterior samples. Entry `t` is the predictive
        distribution of timestep `t` given previous timesteps.
    """
    dtype = dtype_util.common_dtype([
        posterior_samples.level_scale,
        posterior_samples.observation_noise_scale,
        posterior_samples.level,
        original_mean,
        original_scale,
    ], dtype_hint=tf.float32)
    # Time is the trailing axis of the sampled levels.
    num_observed_steps = prefer_static.shape(posterior_samples.level)[-1]

    original_mean = tf.convert_to_tensor(original_mean, dtype=dtype)
    original_scale = tf.convert_to_tensor(original_scale, dtype=dtype)
    # Keep every `thin_every`-th draw along the leading (sample) axis.
    thinned_samples = tf.nest.map_structure(lambda x: x[::thin_every],
                                            posterior_samples)

    # If no slope was inferred (slope has rank <= 1), treat it as zero.
    if prefer_static.rank_from_shape(
            prefer_static.shape(thinned_samples.slope)) <= 1:
        thinned_samples = thinned_samples._replace(
            slope=tf.zeros_like(thinned_samples.level),
            slope_scale=tf.zeros_like(thinned_samples.level_scale))

    # Number of transitions since the last observation: 1 for observed steps,
    # then 1, 2, ..., num_forecast_steps.
    num_steps_from_last_observation = tf.concat([
        tf.ones([num_observed_steps], dtype=dtype),
        tf.range(1, num_forecast_steps + 1, dtype=dtype)
    ], axis=0)

    # The local linear trend model expects that the level at step t + 1 is equal
    # to the level at step t plus the slope at step t, plus transition noise of
    # scale 'level_scale' (which we account for below).
    if num_forecast_steps > 0:
        num_batch_dims = prefer_static.rank_from_shape(
            prefer_static.shape(thinned_samples.level)) - 2
        # All else equal, the current level will remain stationary.
        forecast_level = tf.tile(
            thinned_samples.level[..., -1:],
            tf.concat([
                tf.ones([num_batch_dims + 1], dtype=tf.int32),
                [num_forecast_steps]
            ], axis=0))
        # With a nonzero slope, the level changes linearly at the last slope.
        forecast_level += (
            thinned_samples.slope[..., -1:] *
            tf.range(1., num_forecast_steps + 1., dtype=forecast_level.dtype))

    level_pred = tf.concat(
        [
            thinned_samples.level[..., :1],  # t == 0
            (thinned_samples.level[..., :-1] +
             thinned_samples.slope[..., :-1])  # 1 <= t < T
        ] + ([forecast_level] if num_forecast_steps > 0 else []),
        axis=-1)

    design_matrix = _get_design_matrix(model).to_dense()[:num_observed_steps +
                                                         num_forecast_steps]
    regression_effect = tf.linalg.matvec(design_matrix,
                                         thinned_samples.weights)

    y_mean = (
        (level_pred + regression_effect) * original_scale[..., tf.newaxis] +
        original_mean[..., tf.newaxis])

    # To derive a forecast variance, including slope uncertainty, let
    # `r[:k]` be iid Gaussian RVs with variance `level_scale**2` and `s[:k]` be
    # iid Gaussian RVs with variance `slope_scale**2`. Then the forecast level at
    # step `T + k` can be written as
    #   (level[T] +               # Last known level.
    #    r[0] + ... + r[k - 1] +  # Sum of random walk terms on level.
    #    slope[T] * k +           # Contribution from last known slope.
    #    (k - 1) * s[0] +         # Contributions from random walk terms on slope.
    #    (k - 2) * s[1] +
    #    ... +
    #    1 * s[k - 2])
    # which has variance of
    #  (level_scale**2 * k +
    #   slope_scale**2 * ( (k - 1)**2 +
    #                      (k - 2)**2 +
    #                      ... + 1 ))
    # Here the `slope_scale` coefficient is the `k - 1`th square pyramidal
    # number [1], which is given by
    #  (k - 1) * k * (2 * k - 1) / 6.
    #
    # [1] https://en.wikipedia.org/wiki/Square_pyramidal_number
    variance_from_level = (thinned_samples.level_scale[..., tf.newaxis]**2 *
                           num_steps_from_last_observation)
    variance_from_slope = thinned_samples.slope_scale[..., tf.newaxis]**2 * (
        (num_steps_from_last_observation - 1) *
        num_steps_from_last_observation *
        (2 * num_steps_from_last_observation - 1)) / 6.
    # Expand `original_scale` over the time axis exactly as `y_mean` does, so a
    # batched scale lines up with the batch dims rather than the time axis.
    y_scale = (
        original_scale[..., tf.newaxis] *
        tf.sqrt(thinned_samples.observation_noise_scale[..., tf.newaxis]**2 +
                variance_from_level + variance_from_slope))

    # Mix uniformly over posterior draws: move the leading draw axis to be the
    # rightmost batch dimension, which `MixtureSameFamily` reduces over.
    num_posterior_draws = prefer_static.shape(y_mean)[0]
    return tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(
            logits=tf.zeros([num_posterior_draws], dtype=y_mean.dtype)),
        components_distribution=tfd.Normal(
            loc=dist_util.move_dimension(y_mean, 0, -1),
            scale=dist_util.move_dimension(y_scale, 0, -1)))
Пример #16
0
def one_step_predictive(model, observed_time_series, parameter_samples):
    """Compute one-step-ahead predictive distributions for all timesteps.

    Given samples from the posterior over parameters, return the predictive
    distribution over observations at each time `T`, given observations up
    through time `T-1`.

    Args:
      model: An instance of `StructuralTimeSeries` representing a
        time-series model. This represents a joint distribution over
        time-series and their parameters with batch shape `[b1, ..., bN]`.
      observed_time_series: `float` `Tensor` of shape
        `concat([sample_shape, model.batch_shape, [num_timesteps, 1]])` where
        `sample_shape` corresponds to i.i.d. observations, and the trailing
        `[1]` dimension may (optionally) be omitted if `num_timesteps > 1`.
      parameter_samples: Python `list` of `Tensors` representing posterior
        samples of model parameters, with shapes `[concat([[num_posterior_draws],
        param.prior.batch_shape, param.prior.event_shape]) for param in
        model.parameters]`. This may optionally also be a map (Python `dict`) of
        parameter names to `Tensor` values.

    Returns:
      forecast_dist: a `tfd.MixtureSameFamily` instance with event shape
        [num_timesteps] and
        batch shape `concat([sample_shape, model.batch_shape])`, with
        `num_posterior_draws` mixture components. The `t`th step represents the
        forecast distribution `p(observed_time_series[t] |
        observed_time_series[0:t-1], parameter_samples)`.

    #### Examples

    Suppose we've built a model and fit it to data using HMC:

    ```python
      day_of_week = tfp.sts.Seasonal(
          num_seasons=7,
          observed_time_series=observed_time_series,
          name='day_of_week')
      local_linear_trend = tfp.sts.LocalLinearTrend(
          observed_time_series=observed_time_series,
          name='local_linear_trend')
      model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                          observed_time_series=observed_time_series)

      samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series)
    ```

    Passing the posterior samples into `one_step_predictive`, we construct a
    one-step-ahead predictive distribution:

    ```python
      one_step_predictive_dist = tfp.sts.one_step_predictive(
        model, observed_time_series, parameter_samples=samples)

      predictive_means = one_step_predictive_dist.mean()
      predictive_scales = one_step_predictive_dist.stddev()
    ```

    If using variational inference instead of HMC, we'd construct a forecast
    using samples from the variational posterior:

    ```python
      (variational_loss,
       variational_distributions) = tfp.sts.build_factored_variational_loss(
         model=model, observed_time_series=observed_time_series)

      # OMITTED: take steps to optimize variational loss

      samples = {k: q.sample(30) for (k, q) in variational_distributions.items()}
      one_step_predictive_dist = tfp.sts.one_step_predictive(
        model, observed_time_series, parameter_samples=samples)
    ```

    We can visualize the forecast by plotting:

    ```python
      from matplotlib import pylab as plt
      def plot_one_step_predictive(observed_time_series,
                                   forecast_mean,
                                   forecast_scale):
        plt.figure(figsize=(12, 6))
        num_timesteps = forecast_mean.shape[-1]
        c1, c2 = (0.12, 0.47, 0.71), (1.0, 0.5, 0.05)
        plt.plot(observed_time_series, label="observed time series", color=c1)
        plt.plot(forecast_mean, label="one-step prediction", color=c2)
        plt.fill_between(np.arange(num_timesteps),
                         forecast_mean - 2 * forecast_scale,
                         forecast_mean + 2 * forecast_scale,
                         alpha=0.1, color=c2)
        plt.legend()

      plot_one_step_predictive(observed_time_series,
                               forecast_mean=predictive_means,
                               forecast_scale=predictive_scales)
    ```

    To detect anomalous timesteps, we check whether the observed value at each
    step is within a 95% predictive interval, i.e., two standard deviations from
    the mean:

    ```python
      z_scores = ((observed_time_series[..., 1:] - predictive_means[..., :-1])
                   / predictive_scales[..., :-1])
      anomalous_timesteps = tf.boolean_mask(
          tf.range(1, num_timesteps),
          tf.abs(z_scores) > 2.0)
    ```

    """

    with tf.name_scope('one_step_predictive',
                       values=[observed_time_series, parameter_samples]):
        series = sts_util.maybe_expand_trailing_dim(
            tf.convert_to_tensor(value=observed_time_series,
                                 name='observed_time_series'))

        # Filter over the observed timesteps; only the last two outputs (the
        # per-step predictive means and covariances over observations) are
        # needed here.
        num_timesteps = dist_util.prefer_static_value(
            tf.shape(input=series))[-2]
        ssm = model.make_state_space_model(num_timesteps=num_timesteps,
                                           param_vals=parameter_samples)
        (_, _, _, _, _, obs_means, obs_covs) = ssm.forward_filter(series)

        # The predictive parameters have shape
        #   `concat([[num_posterior_draws], sample_shape, batch_shape,
        #            [num_timesteps, 1]])`.
        # `MixtureSameFamily` mixes over the rightmost batch dimension, so the
        # leading draw axis is moved to sit just before the time axis; wrapping
        # in `Independent` then keeps `num_timesteps` in the event shape.
        # TODO(b/120245392): enhance `MixtureSameFamily` to reduce along an
        # arbitrary axis, and eliminate `move_dimension` calls here.
        component_loc = dist_util.move_dimension(obs_means[..., 0], 0, -2)
        component_scale = tf.sqrt(
            dist_util.move_dimension(obs_covs[..., 0, 0], 0, -2))
        components = tfd.Independent(
            distribution=tfd.Normal(loc=component_loc, scale=component_scale),
            reinterpreted_batch_ndims=1)

        draws = dist_util.prefer_static_value(tf.shape(input=obs_means))[0]
        uniform_weights = tfd.Categorical(
            logits=tf.zeros([draws], dtype=components.dtype))
        return tfd.MixtureSameFamily(mixture_distribution=uniform_weights,
                                     components_distribution=components)
Пример #17
0
def forecast(model, observed_time_series, parameter_samples,
             num_steps_forecast):
    """Construct predictive distribution over future observations.

    Given samples from the posterior over parameters, return the predictive
    distribution over future observations for num_steps_forecast timesteps.

    Args:
      model: An instance of `StructuralTimeSeries` representing a
        time-series model. This represents a joint distribution over
        time-series and their parameters with batch shape `[b1, ..., bN]`.
      observed_time_series: `float` `Tensor` of shape
        `concat([sample_shape, model.batch_shape, [num_timesteps, 1]])` where
        `sample_shape` corresponds to i.i.d. observations, and the trailing
        `[1]` dimension may (optionally) be omitted if `num_timesteps > 1`.
      parameter_samples: Python `list` of `Tensors` representing posterior
        samples of model parameters, with shapes `[concat([[num_posterior_draws],
        param.prior.batch_shape, param.prior.event_shape]) for param in
        model.parameters]`. This may optionally also be a map (Python `dict`) of
        parameter names to `Tensor` values.
      num_steps_forecast: scalar `int` `Tensor` number of steps to forecast.

    Returns:
      forecast_dist: a `tfd.MixtureSameFamily` instance with event shape
        [num_steps_forecast, 1] and batch shape
        `concat([sample_shape, model.batch_shape])`, with `num_posterior_draws`
        mixture components.

    #### Examples

    Suppose we've built a model and fit it to data using HMC:

    ```python
      day_of_week = tfp.sts.Seasonal(
          num_seasons=7,
          observed_time_series=observed_time_series,
          name='day_of_week')
      local_linear_trend = tfp.sts.LocalLinearTrend(
          observed_time_series=observed_time_series,
          name='local_linear_trend')
      model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                          observed_time_series=observed_time_series)

      samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series)
    ```

    Passing the posterior samples into `forecast`, we construct a forecast
    distribution:

    ```python
      forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                       parameter_samples=samples,
                                       num_steps_forecast=50)

      forecast_mean = forecast_dist.mean()[..., 0]  # shape: [50]
      forecast_scale = forecast_dist.stddev()[..., 0]  # shape: [50]
      forecast_samples = forecast_dist.sample(10)[..., 0]  # shape: [10, 50]
    ```

    If using variational inference instead of HMC, we'd construct a forecast
    using samples from the variational posterior:

    ```python
      (variational_loss,
       variational_distributions) = tfp.sts.build_factored_variational_loss(
         model=model, observed_time_series=observed_time_series)

      # OMITTED: take steps to optimize variational loss

      samples = {k: q.sample(30) for (k, q) in variational_distributions.items()}
      forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                       parameter_samples=samples,
                                       num_steps_forecast=50)
    ```

    We can visualize the forecast by plotting:

    ```python
      from matplotlib import pylab as plt
      def plot_forecast(observed_time_series,
                        forecast_mean,
                        forecast_scale,
                        forecast_samples):
        plt.figure(figsize=(12, 6))

        num_steps = observed_time_series.shape[-1]
        num_steps_forecast = forecast_mean.shape[-1]
        num_steps_train = num_steps - num_steps_forecast

        c1, c2 = (0.12, 0.47, 0.71), (1.0, 0.5, 0.05)
        plt.plot(np.arange(num_steps), observed_time_series,
                 lw=2, color=c1, label='ground truth')

        forecast_steps = np.arange(num_steps_train,
                                   num_steps_train+num_steps_forecast)
        plt.plot(forecast_steps, forecast_samples.T, lw=1, color=c2, alpha=0.1)
        plt.plot(forecast_steps, forecast_mean, lw=2, ls='--', color=c2,
                 label='forecast')
        plt.fill_between(forecast_steps,
                         forecast_mean - 2 * forecast_scale,
                         forecast_mean + 2 * forecast_scale, color=c2, alpha=0.2)

        plt.xlim([0, num_steps])
        plt.legend()

      plot_forecast(observed_time_series,
                    forecast_mean=forecast_mean,
                    forecast_scale=forecast_scale,
                    forecast_samples=forecast_samples)
    ```

    """

    with tf.name_scope('forecast',
                       values=[observed_time_series, parameter_samples,
                               num_steps_forecast]):
        series = sts_util.maybe_expand_trailing_dim(
            tf.convert_to_tensor(value=observed_time_series,
                                 name='observed_time_series'))

        # Filter over the observed timesteps. The predicted latent-state moments
        # at the final observed step serve as the prior for the forecast model
        # ("today's prior is yesterday's posterior").
        num_observed_steps = dist_util.prefer_static_value(
            tf.shape(input=series))[-2]
        observed_ssm = model.make_state_space_model(
            num_timesteps=num_observed_steps, param_vals=parameter_samples)
        (_, _, _, predicted_means, predicted_covs, _,
         _) = observed_ssm.forward_filter(series)

        # `MixtureSameFamily` mixes over the rightmost batch dimension, so each
        # parameter's leading `[num_posterior_draws]` axis is moved to just
        # before that parameter's event dimensions.
        # TODO(b/120245392): enhance `MixtureSameFamily` to reduce along an
        # arbitrary axis, and eliminate `move_dimension` calls here.
        sample_map = model._canonicalize_param_vals_as_map(  # pylint: disable=protected-access
            parameter_samples)
        reordered_samples = {}
        for param in model.parameters:
            reordered_samples[param.name] = dist_util.move_dimension(
                sample_map[param.name], 0,
                -(1 + _prefer_static_event_ndims(param.prior)))

        # Same reordering for the initial-state prior: the draw axis goes just
        # before the latent (vector) and latent x latent (matrix) dimensions.
        forecast_prior = tfd.MultivariateNormalFullCovariance(
            loc=dist_util.move_dimension(predicted_means[..., -1, :], 0, -2),
            covariance_matrix=dist_util.move_dimension(
                predicted_covs[..., -1, :, :], 0, -3))
        forecast_ssm = model.make_state_space_model(
            num_timesteps=num_steps_forecast,
            param_vals=reordered_samples,
            initial_state_prior=forecast_prior,
            initial_step=num_observed_steps)

        draws = dist_util.prefer_static_value(
            forecast_ssm.batch_shape_tensor())[-1]
        uniform_weights = tfd.Categorical(
            logits=tf.zeros([draws], dtype=forecast_ssm.dtype))
        return tfd.MixtureSameFamily(mixture_distribution=uniform_weights,
                                     components_distribution=forecast_ssm)