  def _joint_prior_distribution(self):
    """Constructs the model's parameter prior distribution."""
    # Patch prior distributions to default to their STS-associated bijectors,
    # which may enforce scaling and/or additional parameter constraints.
    return joint_distribution_auto_batched.JointDistributionNamedAutoBatched(
        collections.OrderedDict(
            (p.name, _with_default_bijector(p.prior, p.bijector))
            for p in self.parameters),
        use_vectorized_map=False,
        batch_ndims=ps.rank_from_shape(self.batch_shape_tensor,
                                       self.batch_shape))
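  # A minimal usage sketch, assuming `model` is any concrete `tfp.sts` model
  # instance (e.g. `tfp.sts.LocalLinearTrend()`); the helper is private, so
  # this is illustrative only:
  #
  #   prior = model._joint_prior_distribution()
  #   param_draws = prior.sample(10)
  #   # `param_draws` is an OrderedDict keyed by parameter name (for
  #   # LocalLinearTrend: 'level_scale', 'slope_scale'), each value holding
  #   # ten draws from that parameter's prior.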
  def joint_distribution(self,
                         observed_time_series=None,
                         num_timesteps=None,
                         trajectories_shape=None,
                         initial_step=0,
                         mask=None,
                         experimental_parallelize=False):
    """Constructs the joint distribution over parameters and observed values.

    Args:
      observed_time_series: Optional observed time series to model, as a
        `Tensor` or `tfp.sts.MaskedTimeSeries` instance having shape
        `concat([batch_shape, trajectories_shape, num_timesteps, 1])`. If an
        observed time series is provided, the `num_timesteps`,
        `trajectories_shape`, and `mask` arguments are ignored, and an
        unnormalized (pinned) distribution over parameter values is returned.
        Default value: `None`.
      num_timesteps: scalar `int` `Tensor` number of timesteps to model. This
        must be specified either directly or by passing an
        `observed_time_series`.
        Default value: `None`.
      trajectories_shape: `int` `Tensor` shape of sampled trajectories for each
        set of parameter values. If not specified (either directly or by
        passing an `observed_time_series`), defaults to a one-to-one
        correspondence between trajectories and parameter settings (implicitly
        `trajectories_shape=()`).
        Default value: `None`.
      initial_step: Optional scalar `int` `Tensor` specifying the starting
        timestep.
        Default value: `0`.
      mask: Optional `bool` `Tensor` having shape
        `concat([batch_shape, trajectories_shape, num_timesteps])`, in which
        `True` entries indicate that the series value at the corresponding step
        is missing and should be ignored. This argument should be passed only
        if `observed_time_series` is not specified or does not already contain
        a missingness mask; it is an error to pass both this argument and an
        `observed_time_series` value containing a missingness mask.
        Default value: `None`.
      experimental_parallelize: If `True`, use parallel message passing
        algorithms from `tfp.experimental.parallel_filter` to perform time
        series operations in `O(log num_timesteps)` sequential steps. The
        overall FLOP and memory cost may be larger than for the sequential
        implementations by a constant factor.
        Default value: `False`.

    Returns:
      joint_distribution: joint distribution of model parameters and observed
        trajectories. If no `observed_time_series` was specified, this is an
        instance of `tfd.JointDistributionNamedAutoBatched` with a random
        variable for each model parameter (with names and order matching
        `self.parameters`), plus a final random variable
        `observed_time_series` representing one or more trajectories
        conditioned on the parameters. If `observed_time_series` was
        specified, the return value is given by
        `joint_distribution.experimental_pin(
        observed_time_series=observed_time_series)` where `joint_distribution`
        is as just described, so it defines an unnormalized posterior
        distribution over the parameters.

    #### Example:

    The joint distribution can generate prior samples of parameters and
    trajectories:

    ```python
    from matplotlib import pylab as plt
    import tensorflow as tf
    import tensorflow_probability as tfp

    # Sample and plot 100 trajectories from the prior.
    model = tfp.sts.LocalLinearTrend()
    prior_samples = model.joint_distribution().sample([100])
    plt.plot(
        tf.linalg.matrix_transpose(
            prior_samples['observed_time_series'][..., 0]))
    ```

    It also integrates with TFP inference APIs, providing a more flexible
    alternative to the STS-specific fitting utilities.

    ```python
    jd = model.joint_distribution(observed_time_series)

    # Variational inference.
    surrogate_posterior = (
        tfp.experimental.vi.build_factored_surrogate_posterior(
            event_shape=jd.event_shape,
            bijector=jd.experimental_default_event_space_bijector()))
    losses = tfp.vi.fit_surrogate_posterior(
        target_log_prob_fn=jd.unnormalized_log_prob,
        surrogate_posterior=surrogate_posterior,
        optimizer=tf.optimizers.Adam(0.1),
        num_steps=200)
    parameter_samples = surrogate_posterior.sample(50)

    # No U-Turn Sampler.
    samples, kernel_results = tfp.experimental.mcmc.windowed_adaptive_nuts(
        n_draws=500, joint_dist=jd)
    ```
    """
    def state_space_model_likelihood(**param_vals):
      ssm = self.make_state_space_model(
          param_vals=param_vals,
          num_timesteps=num_timesteps,
          initial_step=initial_step,
          mask=mask,
          experimental_parallelize=experimental_parallelize)
      # Looping LGSSM methods are really expensive in eager mode; wrap them
      # to keep this from slowing things down in interactive use.
      ssm = tfe_util.JitPublicMethods(ssm, trace_only=True)
      if distribution_util.shape_may_be_nontrivial(trajectories_shape):
        return sample.Sample(ssm, sample_shape=trajectories_shape)
      return ssm

    batch_ndims = ps.rank_from_shape(self.batch_shape_tensor, self.batch_shape)
    if observed_time_series is not None:
      [observed_time_series,
       is_missing] = sts_util.canonicalize_observed_time_series_with_mask(
           observed_time_series)
      if is_missing is not None:
        if mask is not None:
          raise ValueError(
              'Passed non-None value for `mask`, but the observed time '
              'series already contains an `is_missing` mask.')
        mask = is_missing
      num_timesteps = ps.shape(observed_time_series)[-2]
      trajectories_shape = ps.shape(observed_time_series)[batch_ndims:-2]

    joint_distribution = (
        joint_distribution_auto_batched.JointDistributionNamedAutoBatched(
            model=collections.OrderedDict(
                # Prior.
                list(self._joint_prior_distribution().model.items()) +
                # Likelihood.
                [('observed_time_series', state_space_model_likelihood)]),
            use_vectorized_map=False,
            batch_ndims=batch_ndims))

    if observed_time_series is not None:
      return joint_distribution.experimental_pin(
          observed_time_series=observed_time_series)
    return joint_distribution
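  # A minimal sketch of the `trajectories_shape` argument, assuming `model`
  # is a `tfp.sts` model instance (e.g. `tfp.sts.LocalLinearTrend()`) and no
  # observations are pinned:
  #
  #   jd = model.joint_distribution(num_timesteps=100, trajectories_shape=[5])
  #   draws = jd.sample()
  #   # `draws['observed_time_series']` has shape [5, 100, 1]: five
  #   # trajectories of 100 timesteps each, sharing a single draw of the
  #   # model parameters (via the `Sample` wrapper in
  #   # `state_space_model_likelihood` above).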