Example #1
0
    def _joint_prior_distribution(self):
        """Builds the joint prior distribution over this model's parameters.

        Each parameter's prior is patched so that its default bijector is the
        STS-associated bijector for that parameter. That bijector may enforce
        scaling and/or additional parameter constraints.

        Returns:
          A `JointDistributionNamedAutoBatched` whose components are keyed by
          parameter name, in the order given by `self.parameters`.
        """
        priors_by_name = collections.OrderedDict()
        for param in self.parameters:
            priors_by_name[param.name] = _with_default_bijector(
                param.prior, param.bijector)

        # Treat the model's batch dimensions as batch (not event) dimensions
        # of the joint distribution.
        num_batch_dims = ps.rank_from_shape(self.batch_shape_tensor,
                                            self.batch_shape)

        return joint_distribution_auto_batched.JointDistributionNamedAutoBatched(
            priors_by_name,
            use_vectorized_map=False,
            batch_ndims=num_batch_dims)
Example #2
0
    def joint_distribution(self,
                           observed_time_series=None,
                           num_timesteps=None,
                           trajectories_shape=None,
                           initial_step=0,
                           mask=None,
                           experimental_parallelize=False):
        """Constructs the joint distribution over parameters and observed values.

        Args:
          observed_time_series: Optional observed time series to model, as a
            `Tensor` or `tfp.sts.MaskedTimeSeries` instance having shape
            `concat([batch_shape, trajectories_shape, num_timesteps, 1])`. If
            an observed time series is provided, the `num_timesteps` and
            `trajectories_shape` arguments are ignored (both are inferred from
            the series' shape), and an unnormalized (pinned) distribution over
            parameter values is returned. `mask` is still used when the series
            does not carry its own missingness mask.
            Default value: `None`.
          num_timesteps: scalar `int` `Tensor` number of timesteps to model. This
            must be specified either directly or by passing an
            `observed_time_series`.
            Default value: `None`.
          trajectories_shape: `int` `Tensor` shape of sampled trajectories
            for each set of parameter values. If not specified (either directly
            or by passing an `observed_time_series`), defaults to a
            one-to-one correspondence between trajectories and parameter
            settings (`trajectories_shape=()`).
            Default value: `None`.
          initial_step: Optional scalar `int` `Tensor` specifying the starting
            timestep.
            Default value: `0`.
          mask: Optional `bool` `Tensor` having shape
            `concat([batch_shape, trajectories_shape, num_timesteps])`, in which
            `True` entries indicate that the series value at the corresponding
            step is missing and should be ignored. This argument should be
            passed only if `observed_time_series` is not specified or does not
            already contain a missingness mask; it is an error to pass both
            this argument and an `observed_time_series` value containing a
            missingness mask.
            Default value: `None`.
          experimental_parallelize: If `True`, use parallel message passing
            algorithms from `tfp.experimental.parallel_filter` to perform time
            series operations in `O(log num_timesteps)` sequential steps. The
            overall FLOP and memory cost may be larger than for the sequential
            implementations by a constant factor.
            Default value: `False`.

        Returns:
          joint_distribution: joint distribution of model parameters and
            observed trajectories. If no `observed_time_series` was specified,
            this is an instance of `tfd.JointDistributionNamedAutoBatched` with
            a random variable for each model parameter (with names and order
            matching `self.parameters`), plus a final random variable
            `observed_time_series` representing a trajectory(ies) conditioned
            on the parameters. If `observed_time_series` was specified, the
            return value is given by `joint_distribution.experimental_pin(
            observed_time_series=observed_time_series)` where
            `joint_distribution` is as just described, so it defines an
            unnormalized posterior distribution over the parameters.

        Raises:
          ValueError: if `mask` is not `None` and `observed_time_series`
            already contains an `is_missing` mask.

        #### Example:

        The joint distribution can generate prior samples of parameters and
        trajectories:

        ```python
        from matplotlib import pylab as plt
        import tensorflow as tf
        import tensorflow_probability as tfp

        # Sample and plot 100 trajectories of 50 steps from the prior.
        model = tfp.sts.LocalLinearTrendModel()
        prior_samples = model.joint_distribution(num_timesteps=50).sample([100])
        plt.plot(
          tf.linalg.matrix_transpose(prior_samples['observed_time_series'][..., 0]))
        ```

        It also integrates with TFP inference APIs, providing a more flexible
        alternative to the STS-specific fitting utilities.

        ```python
        jd = model.joint_distribution(observed_time_series)

        # Variational inference.
        surrogate_posterior = (
          tfp.experimental.vi.build_factored_surrogate_posterior(
            event_shape=jd.event_shape,
            bijector=jd.experimental_default_event_space_bijector()))
        losses = tfp.vi.fit_surrogate_posterior(
          target_log_prob_fn=jd.unnormalized_log_prob,
          surrogate_posterior=surrogate_posterior,
          optimizer=tf.optimizers.Adam(0.1),
          num_steps=200)
        parameter_samples = surrogate_posterior.sample(50)

        # No U-Turn Sampler.
        samples, kernel_results = tfp.experimental.mcmc.windowed_adaptive_nuts(
          n_draws=500, joint_dist=jd)
        ```
        """
        batch_ndims = ps.rank_from_shape(self.batch_shape_tensor,
                                         self.batch_shape)

        # Resolve num_timesteps / trajectories_shape / mask up front, so the
        # likelihood closure below captures final values instead of relying
        # on late binding of reassigned locals.
        if observed_time_series is not None:
            observed_time_series, is_missing = (
                sts_util.canonicalize_observed_time_series_with_mask(
                    observed_time_series))
            if is_missing is not None:
                if mask is not None:
                    raise ValueError(
                        'Passed non-None value for `mask`, but the observed '
                        'time series already contains an `is_missing` mask.')
                mask = is_missing
            # Layout is [batch..., trajectories..., num_timesteps, 1], so the
            # trajectory dims sit between the batch dims and the last two.
            observed_shape = ps.shape(observed_time_series)
            num_timesteps = observed_shape[-2]
            trajectories_shape = observed_shape[batch_ndims:-2]
        elif trajectories_shape is None:
            # Documented default: one trajectory per parameter setting.
            trajectories_shape = ()

        def state_space_model_likelihood(**param_vals):
            """Builds the observation distribution for given parameter values."""
            ssm = self.make_state_space_model(
                param_vals=param_vals,
                num_timesteps=num_timesteps,
                initial_step=initial_step,
                mask=mask,
                experimental_parallelize=experimental_parallelize)
            # Looping LGSSM methods are really expensive in eager mode; wrap them
            # to keep this from slowing things down in interactive use.
            ssm = tfe_util.JitPublicMethods(ssm, trace_only=True)
            if distribution_util.shape_may_be_nontrivial(trajectories_shape):
                return sample.Sample(ssm, sample_shape=trajectories_shape)
            return ssm

        # Prior components first (in `self.parameters` order), then the
        # likelihood, a callable taking the parameter values as keywords.
        model = collections.OrderedDict(
            list(self._joint_prior_distribution().model.items()) +
            [('observed_time_series', state_space_model_likelihood)])
        jd = joint_distribution_auto_batched.JointDistributionNamedAutoBatched(
            model=model,
            use_vectorized_map=False,
            batch_ndims=batch_ndims)

        if observed_time_series is not None:
            return jd.experimental_pin(
                observed_time_series=observed_time_series)
        return jd