def _extract_log_probs(num_states, dist):
  """Build a table of `dist.log_prob(k)` for every state index `k`.

  Args:
    num_states: number of states; `log_prob` is evaluated at each of
      `0, ..., num_states - 1`.
    dist: distribution object exposing `batch_shape_tensor()` and
      `log_prob()`.

  Returns:
    The output of `dist.log_prob`, with the axis indexing the evaluated state
    moved from the leading position to the trailing position.
  """
  # One singleton axis per batch dimension, so the column of state indices
  # broadcasts against the whole batch of distributions.
  singleton_batch_dims = tf.ones_like(dist.batch_shape_tensor())
  column_shape = tf.concat([[num_states], singleton_batch_dims], axis=0)
  state_column = tf.reshape(tf.range(num_states), column_shape)

  tabulated = dist.log_prob(state_column)
  return util.move_dimension(tabulated, 0, -1)
def _log_prob(self, value):
  """Compute the log probability of sequences of observations.

  Runs the forward algorithm in log space. The initial distribution is
  treated as the distribution of the hidden state at the *first* step, so the
  first observation is scored against it directly and the transition model is
  applied only between consecutive steps (`num_steps - 1` times).

  Args:
    value: tensor of observation sequences, laid out as
      `observation_batch_shape + [num_steps] + underlying_event_shape`.

  Returns:
    Tensor of log probabilities of shape `working_shape`, the broadcast of
    `observation_batch_shape` with `self.batch_shape_tensor()`.
  """
  with tf.control_dependencies(self._runtime_assertions):
    r = self._underlying_event_rank
    value_shape = tf.shape(value)

    # The argument `value` is a tensor of sequences of observations.
    # `observation_batch_shape` is the shape of that tensor with the
    # sequence part removed.
    # `observation_batch_shape` is then broadcast to the full batch shape
    # to give the `working_shape` that defines the shape of the result.
    observation_batch_shape = value_shape[:-1 - r]
    # value :: observation_batch_shape num_steps underlying_event_shape
    working_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                               self.batch_shape_tensor())
    log_init = tf.broadcast_to(
        self._log_init,
        tf.concat([working_shape, [self._num_states]], axis=0))
    # log_init :: working_shape num_states
    log_transition = self._log_trans

    # `observation_event_shape` is the shape of each sequence of observations
    # emitted by the model.
    observation_event_shape = value_shape[-1 - r:]
    working_obs = tf.broadcast_to(
        value,
        tf.concat([working_shape, observation_event_shape], axis=0))
    # working_obs :: working_shape observation_event_shape

    # Move index into sequence of observations to front so we can apply
    # tf.foldl
    working_obs = util.move_dimension(working_obs, -1 - r, 0)
    # working_obs :: num_steps working_shape underlying_event_shape

    # Insert a singleton axis for the hidden state *before* the `r` event
    # dimensions so it lines up with the trailing batch dimension of the
    # observation distribution. (Appending it after the event dimensions is
    # only equivalent when `r == 0`.)
    working_obs = tf.expand_dims(working_obs, -1 - r)
    # working_obs :: num_steps working_shape 1 underlying_event_shape

    observation_probs = (
        self._observation_distribution.log_prob(working_obs))
    # observation_probs :: num_steps working_shape num_states (expected)

    def forward_step(log_prev_step, log_observation):
      return _log_vector_matrix(log_prev_step,
                                log_transition) + log_observation

    # The first observation is emitted by the initial state, so no
    # transition is applied before it.
    first_step_log_probs = log_init + observation_probs[0]
    fwd_prob = tf.foldl(forward_step, observation_probs[1:],
                        initializer=first_step_log_probs)
    # fwd_prob :: working_shape num_states

    log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
    # log_prob :: working_shape
    return log_prob
def cholesky_update(chol, update_vector, multiplier=1., name=None):
  """Returns cholesky of chol @ chol.T + multiplier * u @ u.T.

  Given a (batch of) lower triangular cholesky factor(s) `chol`, along with a
  (batch of) vector(s) `update_vector`, compute the lower triangular cholesky
  factor of the rank-1 update `chol @ chol.T + multiplier * u @ u.T`, where
  `multiplier` is a (batch of) scalar(s).

  If `chol` has shape `[L, L]`, this has complexity `O(L^2)` compared to the
  naive algorithm which has complexity `O(L^3)`.

  Args:
    chol: Floating-point `Tensor` with shape `[B1, ..., Bn, L, L]`.
      Cholesky decomposition of `mat = chol @ chol.T`. Batch dimensions
      must be broadcastable with `update_vector` and `multiplier`.
    update_vector: Floating-point `Tensor` with shape `[B1, ... Bn, L]`. Vector
      defining rank-one update. Batch dimensions must be broadcastable with
      `chol` and `multiplier`.
    multiplier: Floating-point `Tensor` with shape `[B1, ..., Bn]`. Scalar
      multiplier to rank-one update. Batch dimensions must be broadcastable
      with `chol` and `update_vector`. Note that updates where `multiplier` is
      positive are numerically stable, while when `multiplier` is negative
      (downdating), the update will only work if the new resulting matrix is
      still positive definite.
    name: Optional name for this op.

  Returns:
    The updated cholesky factor(s): a `Tensor` with the broadcast batch shape
    of the three inputs followed by `[L, L]`.

  #### References
  [1] Oswin Krause. Christian Igel. A More Efficient Rank-one Covariance
      Matrix Update for Evolution Strategies. 2015 ACM Conference.
      https://www.researchgate.net/publication/300581419_A_More_Efficient_Rank-one_Covariance_Matrix_Update_for_Evolution_Strategies
  """
  # TODO(b/154638092): Move this functionality in to TensorFlow.
  with tf.name_scope(name or 'cholesky_update'):
    dtype = dtype_util.common_dtype(
        [chol, update_vector, multiplier], dtype_hint=tf.float32)
    chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
    update_vector = tf.convert_to_tensor(
        update_vector, name='update_vector', dtype=dtype)
    multiplier = tf.convert_to_tensor(
        multiplier, name='multiplier', dtype=dtype)

    # Broadcast all three inputs to one common batch shape up front.
    chol_shape = tf.shape(chol)
    vector_shape = tf.shape(update_vector)
    chol_and_vector_batch = prefer_static.broadcast_shape(
        chol_shape[:-2], vector_shape[:-1])
    batch_shape = prefer_static.broadcast_shape(
        chol_and_vector_batch, tf.shape(multiplier))

    chol = tf.broadcast_to(
        chol, prefer_static.concat([batch_shape, chol_shape[-2:]], axis=0))
    update_vector = tf.broadcast_to(
        update_vector,
        prefer_static.concat([batch_shape, vector_shape[-1:]], axis=0))
    multiplier = tf.broadcast_to(multiplier, batch_shape)

    chol_diag = tf.linalg.diag_part(chol)

    # The algorithm in [1] is implemented as a double for loop. The inner
    # loop of Algorithm 3.1 is treated as a vector operation, which leaves a
    # single loop over columns that is expressed with `tf.scan`.
    # `omega` and `b` (as defined in Algorithm 3.1) are carried between
    # iterations because they are updated on every step.
    def compute_new_column(accumulated_quantities, state):
      """Computes the next column of the updated cholesky."""
      omega, b = accumulated_quantities[2], accumulated_quantities[3]
      index, diagonal_member, col = state

      omega_at_index = tf.gather(omega, index, axis=-1)
      diag_squared = tf.math.square(diagonal_member)
      omega_squared = tf.math.square(omega_at_index)

      # Line 4
      new_diagonal_member = tf.math.sqrt(
          diag_squared + multiplier / b * omega_squared)
      # `scaling_factor` is the same as `gamma` on Line 5.
      scaling_factor = diag_squared * b + multiplier * omega_squared

      # The following updates are the same as the for loop in lines 6-8.
      omega_over_diag = omega_at_index / diagonal_member
      omega = omega - omega_over_diag[..., tf.newaxis] * col
      omega_weight = multiplier * omega_at_index / scaling_factor
      new_col = new_diagonal_member[..., tf.newaxis] * (
          col / diagonal_member[..., tf.newaxis] +
          omega_weight[..., tf.newaxis] * omega)
      b = b + multiplier * tf.math.square(omega_over_diag)
      return new_diagonal_member, new_col, omega, b

    # We will scan over the columns, so bring the column axis to the front.
    chol = distribution_util.move_dimension(chol, source_idx=-1, dest_idx=0)
    chol_diag = distribution_util.move_dimension(
        chol_diag, source_idx=-1, dest_idx=0)

    column_indices = tf.range(0, tf.shape(chol)[0])
    initial_state = (
        tf.zeros_like(multiplier),
        tf.zeros_like(chol[0, ...]),
        update_vector,
        tf.ones_like(multiplier))
    new_diag, new_chol, _, _ = tf.scan(
        fn=compute_new_column,
        elems=(column_indices, chol_diag, chol),
        initializer=initial_state)

    # Restore the original axis order and write the new diagonal in place.
    new_chol = distribution_util.move_dimension(
        new_chol, source_idx=0, dest_idx=-1)
    new_diag = distribution_util.move_dimension(
        new_diag, source_idx=0, dest_idx=-1)
    return tf.linalg.set_diag(new_chol, new_diag)
def decompose_forecast_by_component(model, forecast_dist, parameter_samples):
  """Decompose a forecast distribution into contributions from each component.

  Args:
    model: An instance of `tfp.sts.Sum` representing a structural time series
      model.
    forecast_dist: A `Distribution` instance returned by `tfp.sts.forecast()`.
      (specifically, must be a `tfd.MixtureSameFamily` over a
      `tfd.LinearGaussianStateSpaceModel` parameterized by posterior samples).
    parameter_samples: Python `list` of `Tensors` representing posterior samples
      of model parameters, with shapes `[concat([[num_posterior_draws],
      param.prior.batch_shape, param.prior.event_shape]) for param in
      model.parameters]`. This may optionally also be a map (Python `dict`) of
      parameter names to `Tensor` values.

  Returns:
    component_forecasts: A `collections.OrderedDict` instance mapping
      component StructuralTimeSeries instances (elements of `model.components`)
      to `tfd.Distribution` instances representing the marginal forecast for
      each component. Each distribution has batch and event shape matching
      `forecast_dist` (specifically, the event shape is
      `[num_steps_forecast]`).

  Raises:
    ValueError: if reading `forecast_dist.components_distribution` or its
      joint mean / covariances raises `AttributeError`, i.e. `forecast_dist`
      is not of the expected form.

  #### Examples

  Suppose we've built a model, fit it to data, and constructed a forecast
  distribution:

  ```python
    day_of_week = tfp.sts.Seasonal(
        num_seasons=7,
        observed_time_series=observed_time_series,
        name='day_of_week')
    local_linear_trend = tfp.sts.LocalLinearTrend(
        observed_time_series=observed_time_series,
        name='local_linear_trend')
    model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                        observed_time_series=observed_time_series)

    num_steps_forecast = 50
    samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series)
    forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                     parameter_samples=samples,
                                     num_steps_forecast=num_steps_forecast)
  ```

  To extract the forecast for individual components, pass the forecast
  distribution into `decompose_forecast_by_component`:

  ```python
    component_forecasts = decompose_forecast_by_component(
      model, forecast_dist, samples)

    # Component mean and stddev have shape `[num_steps_forecast]`.
    day_of_week_effect_mean = component_forecasts[day_of_week].mean()
    day_of_week_effect_stddev = component_forecasts[day_of_week].stddev()
  ```

  Using the component forecasts, we can visualize the uncertainty for each
  component:

  ```python
  from matplotlib import pylab as plt
  num_components = len(component_forecasts)
  xs = np.arange(num_steps_forecast)
  fig = plt.figure(figsize=(12, 3 * num_components))
  for i, (component, component_dist) in enumerate(component_forecasts.items()):

    # If in graph mode, replace `.numpy()` with `.eval()` or `sess.run()`.
    component_mean = component_dist.mean().numpy()
    component_stddev = component_dist.stddev().numpy()

    ax = fig.add_subplot(num_components, 1, 1 + i)
    ax.plot(xs, component_mean, lw=2)
    ax.fill_between(xs,
                    component_mean - 2 * component_stddev,
                    component_mean + 2 * component_stddev,
                    alpha=0.5)
    ax.set_title(component.name)
  ```

  """
  with tf.name_scope('decompose_forecast_by_component'):
    try:
      forecast_lgssm = forecast_dist.components_distribution
      forecast_latent_mean, _ = forecast_lgssm._joint_mean()  # pylint: disable=protected-access
      forecast_latent_covs, _ = forecast_lgssm._joint_covariances()  # pylint: disable=protected-access
    except AttributeError as e:
      # Adjacent string literals are concatenated verbatim, so each fragment
      # needs its own trailing space to keep the words apart.
      raise ValueError(
          'Forecast distribution must be a MixtureSameFamily of '
          'LinearGaussianStateSpaceModel distributions, such as returned by '
          '`tfp.sts.forecast()`. (saw exception: {})'.format(e))

    # Since `parameter_samples` will have sample shape `[num_posterior_draws]`,
    # we need to move the `num_posterior_draws` dimension of the forecast
    # moments from the trailing batch dimension, where it's currently put by
    # `sts.forecast`, back to the leading (sample shape) dimension.
    forecast_latent_mean = dist_util.move_dimension(
        forecast_latent_mean, source_idx=-3, dest_idx=0)
    forecast_latent_covs = dist_util.move_dimension(
        forecast_latent_covs, source_idx=-4, dest_idx=0)
    return _decompose_from_posterior_marginals(
        model,
        forecast_latent_mean,
        forecast_latent_covs,
        parameter_samples,
        initial_step=forecast_lgssm.initial_step)
def posterior_marginals(self, observations):
  """Compute marginal posterior distribution for each state.

  This function computes, for each time step, the marginal
  conditional probability that the hidden Markov model was in
  each possible state given the observations that were made
  at each time step.
  So if the hidden states are `z[0],...,z[num_steps - 1]` and
  the observations are `x[0],...,x[num_steps - 1]`, then
  this function computes `P(z[i] | x[0],...,x[num_steps - 1])`
  for all `i` from `0` to `num_steps - 1`.

  This operation is sometimes called smoothing. It uses a form
  of the forward-backward algorithm.

  Note: the behavior of this function is undefined if the
  `observations` argument represents impossible observations
  from the model.

  Args:
    observations: A tensor representing a batch of observations
      made on the hidden Markov model, laid out as
      `batch_shape + [num_steps] + underlying_event_shape`. For scalar
      observations the steps dimension is therefore the rightmost one. The
      size of the steps dimension should match the `num_steps` parameter of
      the hidden Markov model object. The batch dimensions are broadcast
      with the hidden Markov model's parameters.

  Returns:
    A `Categorical` distribution object representing the marginal
    probability of the hidden Markov model being in each state at each
    step. The rightmost dimension of the `Categorical` distributions
    batch will equal the `num_steps` parameter providing one marginal
    distribution for each step. The other dimensions are the dimensions
    corresponding to the batch of observations.

  Raises:
    An assertion error from `tf.compat.v1.assert_equal` if the steps
    dimension of `observations` does not have size `num_steps` (the exact
    exception type depends on the execution mode).
  """
  with tf.name_scope("posterior_marginals", values=[observations]):
    with tf.control_dependencies(self._runtime_assertions):
      observation_tensor_shape = tf.shape(input=observations)
      underlying_event_rank = self._underlying_event_rank

      # The steps axis sits just before the `underlying_event_rank` event
      # dimensions; that is the axis that must equal `num_steps`. For scalar
      # events this is simply the last axis.
      with tf.control_dependencies([
          tf.compat.v1.assert_equal(
              observation_tensor_shape[-1 - underlying_event_rank],
              self._num_steps,
              message="Steps dimension of `observations` must match "
                      "`num_steps` of `HiddenMarkovModel`")
      ]):
        observation_batch_shape = observation_tensor_shape[
            :-1 - underlying_event_rank]
        observation_event_shape = observation_tensor_shape[
            -1 - underlying_event_rank:]

        working_shape = tf.broadcast_dynamic_shape(
            observation_batch_shape, self.batch_shape_tensor())
        log_init = tf.broadcast_to(
            self._log_init,
            tf.concat([working_shape, [self._num_states]], axis=0))
        log_transition = self._log_trans

        observations = tf.broadcast_to(
            observations,
            tf.concat([working_shape, observation_event_shape], axis=0))
        observation_rank = tf.rank(observations)
        # Bring the steps axis to the front so it can be scanned over.
        observations = util.move_dimension(
            observations, observation_rank - underlying_event_rank - 1, 0)
        # Insert the singleton states axis *before* the event dimensions so
        # it aligns with the trailing batch dimension of the observation
        # distribution (identical to appending an axis when the event is
        # scalar).
        observations = tf.expand_dims(observations,
                                      -1 - underlying_event_rank)
        observation_log_probs = self._observation_distribution.log_prob(
            observations)

        log_adjoint_prob = tf.zeros_like(log_init)

        def forward_step(log_previous_step, log_observation):
          return _log_vector_matrix(
              log_previous_step, log_transition) + log_observation

        # Forward pass: the first observation is scored against the initial
        # distribution directly; transitions apply between later steps.
        log_prob = log_init + observation_log_probs[0]
        forward_log_probs = tf.scan(forward_step,
                                    observation_log_probs[1:],
                                    initializer=log_prob,
                                    name="forward_log_probs")
        forward_log_probs = tf.concat(
            [[log_prob], forward_log_probs], axis=0)

        def backward_step(log_previous_step, log_observation):
          return _log_matrix_vector(
              log_transition, log_observation + log_previous_step)

        # Backward pass, scanned from the last step towards the first.
        backward_log_adjoint_probs = tf.scan(
            backward_step,
            observation_log_probs[1:],
            initializer=log_adjoint_prob,
            reverse=True,
            name="backward_log_adjoint_probs")

        total_log_prob = tf.reduce_logsumexp(
            input_tensor=forward_log_probs[-1], axis=-1)

        backward_log_adjoint_probs = tf.concat(
            [backward_log_adjoint_probs, [log_adjoint_prob]], axis=0)

        log_likelihoods = forward_log_probs + backward_log_adjoint_probs

        # Normalize by the total sequence likelihood and move the steps axis
        # next to the trailing states axis.
        marginal_log_probs = util.move_dimension(
            log_likelihoods - total_log_prob[..., tf.newaxis], 0, -2)

        return categorical.Categorical(logits=marginal_log_probs)
def forecast(model,
             observed_time_series,
             parameter_samples,
             num_steps_forecast,
             include_observation_noise=True):
  """Construct predictive distribution over future observations.

  Given samples from the posterior over parameters, return the predictive
  distribution over future observations for num_steps_forecast timesteps.

  Args:
    model: An instance of `StructuralTimeSeries` representing a
      time-series model. This represents a joint distribution over
      time-series and their parameters with batch shape `[b1, ..., bN]`.
    observed_time_series: `float` `Tensor` of shape
      `concat([sample_shape, model.batch_shape, [num_timesteps, 1]])` where
      `sample_shape` corresponds to i.i.d. observations, and the trailing `[1]`
      dimension may (optionally) be omitted if `num_timesteps > 1`. May
      optionally be an instance of `tfp.sts.MaskedTimeSeries` including a
      mask `Tensor` to encode the locations of missing observations.
    parameter_samples: Python `list` of `Tensors` representing posterior samples
      of model parameters, with shapes `[concat([[num_posterior_draws],
      param.prior.batch_shape, param.prior.event_shape]) for param in
      model.parameters]`. This may optionally also be a map (Python `dict`) of
      parameter names to `Tensor` values.
    num_steps_forecast: scalar `int` `Tensor` number of steps to forecast.
    include_observation_noise: Python `bool` indicating whether the forecast
      distribution should include uncertainty from observation noise. If `True`,
      the forecast is over future observations, if `False`, the forecast is over
      future values of the latent noise-free time series.
      Default value: `True`.

  Returns:
    forecast_dist: a `tfd.MixtureSameFamily` instance with event shape
      [num_steps_forecast, 1] and batch shape
      `concat([sample_shape, model.batch_shape])`, with `num_posterior_draws`
      mixture components.

  #### Examples

  Suppose we've built a model and fit it to data using HMC:

  ```python
    day_of_week = tfp.sts.Seasonal(
        num_seasons=7,
        observed_time_series=observed_time_series,
        name='day_of_week')
    local_linear_trend = tfp.sts.LocalLinearTrend(
        observed_time_series=observed_time_series,
        name='local_linear_trend')
    model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                        observed_time_series=observed_time_series)

    samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series)
  ```

  Passing the posterior samples into `forecast`, we construct a forecast
  distribution:

  ```python
    forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                     parameter_samples=samples,
                                     num_steps_forecast=50)

    forecast_mean = forecast_dist.mean()[..., 0]  # shape: [50]
    forecast_scale = forecast_dist.stddev()[..., 0]  # shape: [50]
    forecast_samples = forecast_dist.sample(10)[..., 0]  # shape: [10, 50]
  ```

  If using variational inference instead of HMC, we'd construct a forecast using
  samples from the variational posterior:

  ```python
    surrogate_posterior = tfp.sts.build_factored_surrogate_posterior(
      model=model)
    loss_curve = tfp.vi.fit_surrogate_posterior(
      target_log_prob_fn=model.joint_log_prob(observed_time_series),
      surrogate_posterior=surrogate_posterior,
      optimizer=tf.optimizers.Adam(learning_rate=0.1),
      num_steps=200)
    samples = surrogate_posterior.sample(30)

    forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                     parameter_samples=samples,
                                     num_steps_forecast=50)
  ```

  We can visualize the forecast by plotting:

  ```python
    from matplotlib import pylab as plt
    def plot_forecast(observed_time_series,
                      forecast_mean,
                      forecast_scale,
                      forecast_samples):
      plt.figure(figsize=(12, 6))

      num_steps = observed_time_series.shape[-1]
      num_steps_forecast = forecast_mean.shape[-1]
      num_steps_train = num_steps - num_steps_forecast

      c1, c2 = (0.12, 0.47, 0.71), (1.0, 0.5, 0.05)
      plt.plot(np.arange(num_steps), observed_time_series,
               lw=2, color=c1, label='ground truth')

      forecast_steps = np.arange(num_steps_train,
                       num_steps_train+num_steps_forecast)
      plt.plot(forecast_steps, forecast_samples.T, lw=1, color=c2, alpha=0.1)
      plt.plot(forecast_steps, forecast_mean, lw=2, ls='--', color=c2,
               label='forecast')
      plt.fill_between(forecast_steps,
                       forecast_mean - 2 * forecast_scale,
                       forecast_mean + 2 * forecast_scale, color=c2, alpha=0.2)

      plt.xlim([0, num_steps])
      plt.legend()

    plot_forecast(observed_time_series,
                  forecast_mean=forecast_mean,
                  forecast_scale=forecast_scale,
                  forecast_samples=forecast_samples)
  ```

  """
  with tf.name_scope('forecast'):
    observed_time_series, mask = (
        sts_util.canonicalize_observed_time_series_with_mask(
            observed_time_series))

    # Filter over the observed timesteps to obtain the latent state posterior
    # at timestep T+1 (the final filtering distribution pushed through the
    # transition model). That distribution becomes the prior of the forecast
    # model ("today's prior is yesterday's posterior").
    series_shape = dist_util.prefer_static_value(
        tf.shape(observed_time_series))
    num_observed_steps = series_shape[-2]
    observed_data_ssm = model.make_state_space_model(
        num_timesteps=num_observed_steps, param_vals=parameter_samples)
    (_, _, _, predictive_means, predictive_covs, _, _
    ) = observed_data_ssm.forward_filter(observed_time_series, mask=mask)

    # Build a batch of state-space models over the forecast period. Because
    # `MixtureSameFamily` mixes over the rightmost batch dimension, the
    # `[num_posterior_draws]` dimension has to be moved from the leftmost to
    # the rightmost side of the model's batch shape.
    # TODO(b/120245392): enhance `MixtureSameFamily` to reduce along an
    # arbitrary axis, and eliminate `move_dimension` calls here.
    parameter_samples = model._canonicalize_param_vals_as_map(parameter_samples)  # pylint: disable=protected-access
    reordered_params = {}
    for param in model.parameters:
      draws = parameter_samples[param.name]
      draws_dest_axis = -(1 + _prefer_static_event_ndims(param.prior))
      reordered_params[param.name] = dist_util.move_dimension(
          draws, 0, draws_dest_axis)

    prior_loc = dist_util.move_dimension(
        predictive_means[..., -1, :], 0, -2)
    prior_cov = dist_util.move_dimension(
        predictive_covs[..., -1, :, :], 0, -3)
    forecast_prior = tfd.MultivariateNormalFullCovariance(
        loc=prior_loc, covariance_matrix=prior_cov)

    # Ugly hack: because `num_posterior_draws` was moved to the trailing
    # (rather than leading) dimension of parameters, the parameter batch
    # shapes no longer broadcast against the `constant_offset` attribute used
    # in `sts.Sum` models. This is fixed by manually adding an extra
    # broadcasting dim to `constant_offset` if present.
    # The root cause is that the params passed below are 'invalid' in the
    # sense that they don't match the shapes of the model's param priors. The
    # fix (as above) will be to update MixtureSameFamily so that changing
    # param dimensions can be avoided altogether.
    # TODO(b/120245392): enhance `MixtureSameFamily` to reduce along an
    # arbitrary axis, and eliminate this hack.
    kwargs = {}
    if hasattr(model, 'constant_offset'):
      offset = tf.convert_to_tensor(
          value=model.constant_offset, dtype=forecast_prior.dtype)
      kwargs['constant_offset'] = offset[..., tf.newaxis, :]

    if not include_observation_noise:
      # Zero out the observation noise so the forecast covers the latent
      # noise-free series only.
      noise_scale = reordered_params['observation_noise_scale']
      reordered_params['observation_noise_scale'] = tf.zeros_like(noise_scale)

    # We assume that any STS model that has a `constant_offset` attribute
    # will allow it to be overridden as a kwarg. This is currently just
    # `sts.Sum`.
    # TODO(b/120245392): when kwargs hack is removed, switch back to calling
    # the public version of `_make_state_space_model`.
    forecast_ssm = model._make_state_space_model(  # pylint: disable=protected-access
        num_timesteps=num_steps_forecast,
        param_map=reordered_params,
        initial_state_prior=forecast_prior,
        initial_step=num_observed_steps,
        **kwargs)

    ssm_batch_shape = dist_util.prefer_static_value(
        forecast_ssm.batch_shape_tensor())
    num_posterior_draws = ssm_batch_shape[-1]
    uniform_logits = tf.zeros([num_posterior_draws], dtype=forecast_ssm.dtype)
    return tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=uniform_logits),
        components_distribution=forecast_ssm)
def _sample_n(self, n, seed=None):
  """Draw `n` sampled observation sequences from the hidden Markov model.

  Hidden states are sampled first (initial state, then `num_steps - 1`
  transitions), then one observation is sampled for every possible state at
  every step and the one matching the sampled hidden state is selected.

  Args:
    n: number of samples to draw.
    seed: seed forwarded to `samplers.split_seed`.

  Returns:
    Tensor of observations laid out as
    `[n] + batch_shape + [num_steps] + inner_shape`, where `inner_shape` is
    the event shape of the observation distribution.
  """
  init_seed, scan_seed, observation_seed = samplers.split_seed(
      seed, n=3, salt='HiddenMarkovModel')

  transition_batch_shape = self.transition_distribution.batch_shape_tensor()
  num_states = transition_batch_shape[-1]

  batch_shape = self.batch_shape_tensor()
  batch_size = tf.reduce_prod(batch_shape)

  # The batch sizes of the underlying initial and transition distributions
  # might not match the batch size of the HMM distribution. More samples are
  # therefore requested from the underlying distributions and the results
  # reshaped into the correct batch size for the HMM.
  init_batch_size = tf.reduce_prod(
      self._initial_distribution.batch_shape_tensor())
  init_repeat = batch_size // init_batch_size
  init_state = tf.reshape(
      self._initial_distribution.sample(n * init_repeat, seed=init_seed),
      [n, batch_size])
  # init_state :: n batch_size
  init_shape = init_state.shape

  transition_repeat = (
      batch_size // tf.reduce_prod(transition_batch_shape[:-1]))

  def generate_step(state_and_seed, _):
    """Take a single step in Markov chain."""
    current_state, current_seed = state_and_seed
    sample_seed, next_seed = samplers.split_seed(current_seed)

    candidates = self._transition_distribution.sample(
        n * transition_repeat, seed=sample_seed)
    # candidates :: (n * transition_repeat) transition_batch
    candidates = tf.reshape(candidates, [n, batch_size, num_states])
    # candidates :: n batch_size num_states

    # Keep, for each chain, the candidate drawn from the row of the
    # transition distribution that matches its current state.
    current_one_hot = tf.one_hot(current_state, num_states, dtype=tf.int32)
    # current_one_hot :: n batch_size num_states
    next_state = tf.reduce_sum(current_one_hot * candidates, axis=-1)

    # `generate_step` must preserve the shape of the tensor of states
    # because the transition matrix is square. TensorFlow might not know
    # this, so the shape is set explicitly.
    tensorshape_util.set_shape(next_state, init_shape)
    return next_state, next_seed

  def _scan_multiple_steps():
    """Take multiple steps with tf.scan."""
    dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
    later_states, _ = tf.scan(
        generate_step, dummy_index, initializer=(init_state, scan_seed))
    # TODO(b/115618503): add/use prepend_initializer to tf.scan
    return tf.concat([[init_state], later_states], axis=0)

  def _single_step():
    """Only the initial state exists; add the leading steps axis."""
    return init_state[tf.newaxis, ...]

  hidden_states = prefer_static.cond(
      self._num_steps > 1, _scan_multiple_steps, _single_step)

  hidden_one_hot = tf.one_hot(
      hidden_states, num_states, dtype=self._observation_distribution.dtype)
  # hidden_one_hot :: num_steps n batch_size num_states

  # The observation distribution batch size might not match the required
  # batch size, so as with the initial and transition distributions more
  # samples are generated and then reshaped.
  observation_batch_size = tf.reduce_prod(
      self._observation_distribution.batch_shape_tensor()[:-1])
  observation_repeat = batch_size // observation_batch_size
  possible_observations = self._observation_distribution.sample(
      [self._num_steps, observation_repeat * n], seed=observation_seed)
  inner_shape = self._observation_distribution.event_shape_tensor()
  # possible_observations :: num_steps (observation_repeat * n)
  #                          observation_batch[:-1] num_states inner_shape

  full_observation_shape = tf.concat(
      [[self._num_steps, n], batch_shape, [num_states], inner_shape], axis=0)
  possible_observations = tf.reshape(
      possible_observations, full_observation_shape)
  # possible_observations :: steps n batch_size num_states inner_shape

  selector_shape = tf.concat(
      [[self._num_steps, n], batch_shape, [num_states],
       tf.ones_like(inner_shape)], axis=0)
  hidden_one_hot = tf.reshape(hidden_one_hot, selector_shape)
  # hidden_one_hot :: steps n batch_size num_states "inner_shape"

  # Summing over the states axis picks out the observation that belongs to
  # the sampled hidden state.
  states_axis = -1 - tf.size(inner_shape)
  observations = tf.reduce_sum(
      hidden_one_hot * possible_observations, axis=states_axis)
  # observations :: steps n batch_size inner_shape

  observations = distribution_util.move_dimension(
      observations, 0, 1 + tf.size(batch_shape))
  # returned :: n batch_shape steps inner_shape
  return observations
def posterior_marginals(self, observations, mask=None, name=None):
  """Compute marginal posterior distribution for each state.

  For every time step, this computes the marginal conditional probability
  that the hidden Markov model was in each possible state, given the
  observations made at all time steps. With hidden states
  `z[0], ..., z[num_steps - 1]` and observations
  `x[0], ..., x[num_steps - 1]`, the result is
  `P(z[i] | x[0], ..., x[num_steps - 1])` for all `i` from `0` to
  `num_steps - 1`.

  This operation is sometimes called smoothing. It uses a form
  of the forward-backward algorithm.

  Note: the behavior of this function is undefined if the
  `observations` argument represents impossible observations
  from the model.

  Args:
    observations: A tensor representing a batch of observations
      made on the hidden Markov model. The rightmost dimension of this tensor
      gives the steps in a sequence of observations from a single sample from
      the hidden Markov model. The size of this dimension should match the
      `num_steps` parameter of the hidden Markov model object. The other
      dimensions are the dimensions of the batch and these are broadcast with
      the hidden Markov model's parameters.
    mask: Optional bool-type `Tensor` with rightmost dimension matching
      `num_steps`, indicating which observations the result of this
      function should be conditioned on. Where the mask is `True` the
      corresponding observations aren't used. If `mask` is `None` then all of
      the observations are used. The `mask` dimensions left of the last are
      broadcast with the HMM batch as well as with the observations.
    name: Optional Python `str` used as the name scope for the ops created
      here. Default value: "posterior_marginals".

  Returns:
    posterior_marginal: A `Categorical` distribution object representing the
      marginal probability of the hidden Markov model being in each state at
      each step. The rightmost dimension of the `Categorical` distributions
      batch will equal the `num_steps` parameter providing one marginal
      distribution for each step. The other dimensions are the dimensions
      corresponding to the batch of observations.

  Raises:
    ValueError: if rightmost dimension of `observations` does not
      have size `num_steps`. (NOTE(review): the check lives in
      `_observation_mask_shape_preconditions`; confirm the exception type
      there.)
  """
  with tf.name_scope(name or "posterior_marginals"):
    with tf.control_dependencies(self._runtime_assertions):
      observation_tensor_shape = tf.shape(observations)
      if mask is None:
        mask_tensor_shape = None
      else:
        mask_tensor_shape = tf.shape(mask)

      with self._observation_mask_shape_preconditions(
          observation_tensor_shape, mask_tensor_shape):
        observation_log_probs = self._observation_log_probs(
            observations, mask)
        log_prob = self._log_init + observation_log_probs[0]
        log_transition = self._log_trans
        log_adjoint_prob = tf.zeros_like(log_prob)

        def forward_step(log_previous_step, log_prob_observation):
          return _log_vector_matrix(
              log_previous_step, log_transition) + log_prob_observation

        def backward_step(log_previous_step, log_prob_observation):
          return _log_matrix_vector(
              log_transition, log_prob_observation + log_previous_step)

        def _scan_multiple_steps_forwards():
          """Forward pass over steps `1..num_steps - 1`."""
          scanned = tf.scan(forward_step,
                            observation_log_probs[1:],
                            initializer=log_prob,
                            name="forward_log_probs")
          return tf.concat([[log_prob], scanned], axis=0)

        def _scan_multiple_steps_backwards():
          """Perform `scan` operation when `num_steps` > 1."""
          scanned = tf.scan(backward_step,
                            observation_log_probs[1:],
                            initializer=log_adjoint_prob,
                            reverse=True,
                            name="backward_log_adjoint_probs")
          return tf.concat([scanned, [log_adjoint_prob]], axis=0)

        # With a single step there is nothing to scan over; the initial
        # values alone form the (length-one) sequences.
        forward_log_probs = prefer_static.cond(
            self._num_steps > 1,
            _scan_multiple_steps_forwards,
            lambda: tf.convert_to_tensor([log_prob]))

        total_log_prob = tf.reduce_logsumexp(forward_log_probs[-1], axis=-1)

        backward_log_adjoint_probs = prefer_static.cond(
            self._num_steps > 1,
            _scan_multiple_steps_backwards,
            lambda: tf.convert_to_tensor([log_adjoint_prob]))

        log_likelihoods = forward_log_probs + backward_log_adjoint_probs

        # Normalize by the total likelihood and move the steps axis next to
        # the trailing states axis.
        normalized = log_likelihoods - total_log_prob[..., tf.newaxis]
        marginal_log_probs = distribution_util.move_dimension(
            normalized, 0, -2)

        return categorical.Categorical(logits=marginal_log_probs)
def infer_trajectories(observations,
                       initial_state_prior,
                       transition_fn,
                       observation_fn,
                       num_particles,
                       initial_state_proposal=None,
                       proposal_fn=None,
                       resample_criterion_fn=ess_below_threshold,
                       rejuvenation_kernel_fn=None,
                       num_transitions_per_observation=1,
                       num_steps_state_history_to_pass=None,
                       num_steps_observation_history_to_pass=None,
                       seed=None,
                       name=None):  # pylint: disable=g-doc-args
  """Use particle filtering to sample from the posterior over trajectories.

  ${particle_filter_arg_str}
  Returns:
    trajectories: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing unbiased samples from the posterior distribution
      `p(latent_states | observations)`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each timestep `t`. Note that
      (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}

  #### Examples

  **Tracking unknown position and velocity**: Let's consider tracking an object
  moving in a one-dimensional space. We'll define a dynamical system
  by specifying an `initial_state_prior`, a `transition_fn`,
  and `observation_fn`.

  The structure of the latent state space is determined by the prior
  distribution. Here, we'll define a state space that includes the object's
  current position and velocity:

  ```python
  initial_state_prior = tfd.JointDistributionNamed({
      'position': tfd.Normal(loc=0., scale=1.),
      'velocity': tfd.Normal(loc=0., scale=0.1)})
  ```

  The `transition_fn` specifies the evolution of the system. It should
  return a distribution over latent states of the same structure as the prior.
  Here, we'll assume that the position evolves according to the velocity,
  with a small random drift, and the velocity also changes slowly, following
  a random drift:

  ```python
  def transition_fn(_, previous_state):
    return tfd.JointDistributionNamed({
        'position': tfd.Normal(
            loc=previous_state['position'] + previous_state['velocity'],
            scale=0.1),
        'velocity': tfd.Normal(loc=previous_state['velocity'], scale=0.01)})
  ```

  The `observation_fn` specifies the process by which the system is observed
  at each time step. Let's suppose we observe only a noisy version of the
  current position.

  ```python
    def observation_fn(_, state):
      return tfd.Normal(loc=state['position'], scale=0.1)
  ```

  Now let's track our object. Suppose we've been given observations
  corresponding to an initial position of `0.4` and constant velocity of `0.01`:

  ```python
  # Generate simulated observations: 40 steps, 0.4, 0.41, ..., 0.79.
  observed_positions = tfd.Normal(loc=tf.linspace(0.4, 0.79, 40),
                                  scale=0.1).sample()

  # Run particle filtering to sample plausible trajectories.
  (trajectories,  # {'position': [40, 1000], 'velocity': [40, 1000]}
   lps) = tfp.experimental.mcmc.infer_trajectories(
             observations=observed_positions,
             initial_state_prior=initial_state_prior,
             transition_fn=transition_fn,
             observation_fn=observation_fn,
             num_particles=1000)
  ```

  For all `i`, `trajectories['position'][:, i]` is a sample from the
  posterior over position sequences, given the observations:
  `p(state[0:T] | observations[0:T])`. Often, the sampled trajectories
  will be highly redundant in their earlier timesteps, because most
  of the initial particles have been discarded through resampling
  (this problem is known as 'particle degeneracy'; see section 3.5 of
  [Doucet and Johansen][1]).
  In such cases it may be useful to also consider the series of *filtering*
  distributions `p(state[t] | observations[:t])`, in which each latent state
  is inferred conditioned only on observations up to that point in time; these
  may be computed using `tfp.experimental.mcmc.particle_filter`.

  #### References

  [1] Arnaud Doucet and Adam M. Johansen. A tutorial on particle
      filtering and smoothing: Fifteen years later.
      _Handbook of nonlinear filtering_, 12(656-704), 2009.
      https://www.stats.ox.ac.uk/~doucet/doucet_johansen_tutorialPF2011.pdf

  """
  with tf.name_scope(name or 'infer_trajectories') as name:
    # Wrap the user seed in a `SeedStream` salted for this function. The same
    # stream object is handed both to `particle_filter` and to the final
    # resampling draw below, so the order of these statements matters.
    seed = SeedStream(seed, 'infer_trajectories')
    # Run the filter, forwarding every model/proposal argument unchanged.
    (particles,
     log_weights,
     parent_indices,
     step_log_marginal_likelihoods) = particle_filter(
         observations=observations,
         initial_state_prior=initial_state_prior,
         transition_fn=transition_fn,
         observation_fn=observation_fn,
         num_particles=num_particles,
         initial_state_proposal=initial_state_proposal,
         proposal_fn=proposal_fn,
         resample_criterion_fn=resample_criterion_fn,
         rejuvenation_kernel_fn=rejuvenation_kernel_fn,
         num_transitions_per_observation=num_transitions_per_observation,
         num_steps_state_history_to_pass=num_steps_state_history_to_pass,
         num_steps_observation_history_to_pass=(
             num_steps_observation_history_to_pass),
         seed=seed,
         name=name)
    # NOTE(review): presumably this follows `parent_indices` back through
    # time to stitch per-step particles into whole trajectories -- confirm
    # against `reconstruct_trajectories`.
    weighted_trajectories = reconstruct_trajectories(particles, parent_indices)

    # Resample all steps of the trajectories using the final weights.
    # The last step's log weights have their leading axis moved to the end so
    # that it becomes the category axis of the `Categorical` (the tensor is
    # passed as its first positional argument); `num_particles` indices are
    # then drawn.
    resample_indices = categorical.Categorical(
        dist_util.move_dimension(
            log_weights[-1, ...],
            source_idx=0,
            dest_idx=-1)).sample(num_particles, seed=seed)
    # Axis 1 of each trajectory tensor is the particle axis (see `Returns`),
    # so the resampled indices are applied along that axis.
    trajectories = tf.nest.map_structure(
        lambda x: _batch_gather(x, resample_indices, axis=1),
        weighted_trajectories)

    return trajectories, step_log_marginal_likelihoods
def resample_minimum_variance(
    log_probs, event_size, sample_shape, seed=None, name=None):
  """Minimum variance resampler for sequential Monte Carlo.

  This function is based on Algorithm #2 in [Maskell et al. (2006)][1].

  Args:
    log_probs: A tensor-valued batch of discrete log probability distributions.
    event_size: the dimension of the vector considered a single draw.
    sample_shape: the `sample_shape` determining the number of draws.
    seed: Python `int` used to seed calls to `tf.random.*`.
      Default value: None (i.e. no seed).
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'resample_minimum_variance'`).

  Returns:
    resampled_indices: The result is similar to sampling with
      ```python
      expanded_sample_shape = tf.concat([[event_size], sample_shape], axis=-1)
      tfd.Categorical(logits=log_probs).sample(expanded_sample_shape)
      ```
      but with values sorted along the first axis. It can be considered to be
      sampling events made up of a length-`event_size` vector of draws from
      the `Categorical` distribution. However, although the elements of this
      event have the appropriate marginal distribution, they are not
      independent of each other. Instead they have been chosen so as to form
      a good representative sample, suitable for use with Sequential Monte
      Carlo algorithms.
      The sortedness is an unintended side effect of the algorithm that is
      harmless in the context of simple SMC algorithms.

  #### References

  [1]: S. Maskell, B. Alun-Jones and M. Macleod. A Single Instruction Multiple
       Data Particle Filter.
       In 2006 IEEE Nonlinear Statistical Signal Processing Workshop.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf

  """
  with tf.name_scope(name or 'resample_minimum_variance') as name:
    log_probs = tf.convert_to_tensor(log_probs, dtype_hint=tf.float32)
    # Work with the category axis last.
    log_probs = dist_util.move_dimension(log_probs, source_idx=0, dest_idx=-1)

    batch_shape = prefer_static.shape(log_probs)[:-1]
    working_shape = prefer_static.concat([sample_shape, batch_shape], axis=-1)
    # The final category is omitted: its cumulative probability is 1 and it
    # needs no marker.
    log_cdf = tf.math.cumulative_logsumexp(log_probs[..., :-1], axis=-1)
    # Each resampling requires a single uniform random variable
    offset = uniform.Uniform(
        low=tf.constant(0., log_cdf.dtype),
        high=tf.constant(1., log_cdf.dtype)).sample(
            working_shape, seed=seed)[..., tf.newaxis]
    markers = prefer_static.cast(
        tf.floor(event_size * tf.math.exp(log_cdf) + offset), tf.int32)
    # It is possible for numerical error to result in a cumulative
    # sum that exceeds 1 so we need to clip. The scatter target below has
    # `event_size + 1` slots, so `event_size` is the largest valid index; a
    # marker landing there only touches the slot that is sliced off after the
    # `cumsum`, which is exactly the "beyond every draw" meaning we want.
    markers = tf.minimum(markers, tf.cast(event_size, tf.int32))
    indices = markers[..., tf.newaxis]
    updates = tf.ones(prefer_static.shape(indices)[:-1], dtype=tf.int32)
    scatter_shape = prefer_static.concat(
        [working_shape, [event_size + 1]], axis=-1)
    batch_dims = (prefer_static.rank_from_shape(sample_shape) +
                  prefer_static.rank_from_shape(batch_shape))
    # Count how many markers fall at each draw position ...
    x = _scatter_nd_batch(indices, updates, scatter_shape,
                          batch_dims=batch_dims)

    # ... so the running count at a position is the category index assigned
    # to that draw. The extra trailing slot is dropped.
    resampled = tf.cumsum(x, axis=-1)[..., :-1]
    resampled = dist_util.move_dimension(resampled,
                                         source_idx=-1,
                                         dest_idx=0)
    return resampled
def test_basic_example_time_dependent_batched(self):
  """Checks the parallel Kalman filter against the sequential LGSSM filter.

  For each model parameter in turn, that one parameter is given the full
  `batch_shape` while all others are unbatched; a walk is sampled, filtered
  with `parallel_kalman_filter_lib.kalman_filter`, and the results are
  compared with `tfd.LinearGaussianStateSpaceModel.forward_filter` run with
  `experimental_parallelize=False`.
  """
  batch_shape = (2, 3)
  ndim = 7  # Dimension of latent space
  mdim = 5  # Dimension of observation space
  nsteps = 9
  # One batch-shape entry per model parameter (plus the mask).
  Batches = collections.namedtuple('Batches', [
      'initial_mean', 'initial_cov', 'transition_matrix', 'transition_mean',
      'transition_cov', 'observation_matrix', 'observation_mean',
      'observation_cov', 'mask'
  ])

  def batch_generator():
    # Yields `Batches` tuples in which exactly one field is `batch_shape`
    # and the remaining eight are the empty shape `()`.
    # Skipping 'mask' case because it isn't used in sample generation.
    for skip in range(8):
      batch_list = skip * [()] + [batch_shape] + (9 - skip - 1) * [()]
      yield Batches(*batch_list)

  # Test the broadcasting by ensuring each parameter individually
  # can be broadcast up to the full batch size.
  seed = test_util.test_seed(sampler_type='stateless')
  for batches in batch_generator():
    # Peel off a fresh seed for this iteration, then split it into one seed
    # per random draw below (indices 0..9).
    iter_seed, seed = samplers.split_seed(seed, n=2, salt='')
    s = samplers.split_seed(iter_seed, n=10, salt='')
    initial_mean = _random_vector(ndim, batches.initial_mean,
                                  dtype=self.dtype, seed=s[0])
    initial_cov = _random_variance(ndim, batches.initial_cov,
                                   dtype=self.dtype, seed=s[1])
    # Time-dependent parameters get a leading `nsteps` axis in front of
    # their batch shape.
    transition_matrix = 0.2 * _random_matrix(  # Avoid blowup (eigvals > 1).
        ndim, ndim, (nsteps,) + batches.transition_matrix,
        dtype=self.dtype, seed=s[2])
    transition_mean = _random_vector(ndim,
                                     (nsteps,) + batches.transition_mean,
                                     dtype=self.dtype, seed=s[3])
    transition_cov = _random_variance(ndim,
                                      (nsteps,) + batches.transition_cov,
                                      dtype=self.dtype, seed=s[4])
    observation_matrix = _random_matrix(
        mdim, ndim, (nsteps,) + batches.observation_matrix,
        dtype=self.dtype, seed=s[5])
    observation_mean = _random_vector(mdim,
                                      (nsteps,) + batches.observation_mean,
                                      dtype=self.dtype, seed=s[6])
    observation_cov = _random_variance(mdim,
                                       (nsteps,) + batches.observation_cov,
                                       dtype=self.dtype, seed=s[7])
    mask = _random_mask((nsteps,) + batches.mask, dtype=tf.bool, seed=s[8])

    # Sample observations `y` from the model; the first output of
    # `sample_walk` is not needed here.
    _, y = parallel_kalman_filter_lib.sample_walk(
        transition_matrix=transition_matrix,
        transition_mean=transition_mean,
        transition_scale_tril=tf.linalg.cholesky(transition_cov),
        observation_matrix=observation_matrix,
        observation_mean=observation_mean,
        observation_scale_tril=tf.linalg.cholesky(observation_cov),
        initial_mean=initial_mean,
        initial_scale_tril=tf.linalg.cholesky(initial_cov),
        seed=s[9])

    my_filter_results = parallel_kalman_filter_lib.kalman_filter(
        transition_matrix=transition_matrix,
        transition_mean=transition_mean,
        transition_cov=transition_cov,
        observation_matrix=observation_matrix,
        observation_mean=observation_mean,
        observation_cov=observation_cov,
        initial_mean=initial_mean,
        initial_cov=initial_cov,
        y=y,
        mask=mask)

    # Move dimension 0 of every tensor to index `-r`, where `r` is given per
    # field by the parallel structure in the last argument: 1 for the log
    # likelihoods and `mask`, 2 for means and `y`, 3 for covariances.
    # NOTE(review): presumably this converts the time-major layout used by
    # the parallel filter into the layout `LinearGaussianStateSpaceModel`
    # works with -- confirm.
    ((my_log_likelihoods,
      my_filtered_means, my_filtered_covs,
      my_predicted_means, my_predicted_covs,
      my_observation_means, my_observation_covs),
     y, mask) = tf.nest.map_structure(
         lambda x, r: distribution_util.move_dimension(x, 0, -r),
         (my_filter_results, y, mask),
         (type(my_filter_results)(1, 2, 3, 2, 3, 2, 3), 2, 1))

    # The lambdas below intentionally close over this iteration's parameter
    # tensors; they are consumed before the next loop iteration rebinds them.
    # pylint: disable=g-long-lambda,cell-var-from-loop
    mvn = tfd.MultivariateNormalFullCovariance
    dist = tfd.LinearGaussianStateSpaceModel(
        num_timesteps=nsteps,
        transition_matrix=lambda t: tf.linalg.LinearOperatorFullMatrix(
            tf.gather(transition_matrix, t, axis=0)),
        transition_noise=lambda t: mvn(
            loc=tf.gather(transition_mean, t, axis=0),
            covariance_matrix=tf.gather(transition_cov, t, axis=0)),
        observation_matrix=lambda t: tf.linalg.LinearOperatorFullMatrix(
            tf.gather(observation_matrix, t, axis=0)),
        observation_noise=lambda t: mvn(
            loc=tf.gather(observation_mean, t, axis=0),
            covariance_matrix=tf.gather(observation_cov, t, axis=0)),
        initial_state_prior=mvn(loc=initial_mean,
                                covariance_matrix=initial_cov),
        experimental_parallelize=False)  # Compare against sequential filter.
    # pylint: enable=g-long-lambda,cell-var-from-loop

    (log_likelihoods,
     filtered_means, filtered_covs,
     predicted_means, predicted_covs,
     observation_means, observation_covs) = dist.forward_filter(y, mask)

    # The log likelihoods get a much looser relative tolerance than the other
    # outputs when not running in float64.
    rtol = (1e-6 if self.dtype == np.float64 else 1e-1)
    atol = (1e-6 if self.dtype == np.float64 else 1e-3)
    self.assertAllClose(log_likelihoods, my_log_likelihoods,
                        rtol=rtol, atol=atol)
    rtol = (1e-6 if self.dtype == np.float64 else 1e-3)
    atol = (1e-6 if self.dtype == np.float64 else 1e-3)
    self.assertAllClose(filtered_means, my_filtered_means,
                        rtol=rtol, atol=atol)
    self.assertAllClose(filtered_covs, my_filtered_covs,
                        rtol=rtol, atol=atol)
    self.assertAllClose(predicted_means, my_predicted_means,
                        rtol=rtol, atol=atol)
    self.assertAllClose(predicted_covs, my_predicted_covs,
                        rtol=rtol, atol=atol)
    self.assertAllClose(observation_means, my_observation_means,
                        rtol=rtol, atol=atol)
    self.assertAllClose(observation_covs, my_observation_covs,
                        rtol=rtol, atol=atol)
def _sample_multinomial_as_iterated_binomial(
    num_samples, num_classes, probs, num_trials, dtype, seed):
  """Draw multinomial samples as a chain of per-class binomial draws.

  Classes are visited in order. For each class a binomial is drawn over the
  trials not yet assigned, with success probability equal to the class
  probability renormalized by the probability mass not yet consumed. The
  final class receives whatever trials remain.

  The batch shape is given by broadcasting `num_trials` with
  `remove_last_dimension(probs)`. The iteration is a `tf.while_loop`, so the
  number of classes may be dynamic.

  Args:
    num_samples: Singleton integer Tensor: number of multinomial samples to
      draw.
    num_classes: Singleton integer Tensor: number of classes.
    probs: Floating Tensor with last dimension `num_classes`, of normalized
      probabilities per class.
    num_trials: Tensor of number of categorical trials each multinomial
      consists of. `num_trials[..., tf.newaxis]` must broadcast with `probs`.
    dtype: dtype at which to emit samples.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    samples: Tensor of given dtype and shape
      `[num_samples] + batch_shape + [num_classes]`.
  """
  with tf.name_scope('draw_sample'):
    # `num_classes` is converted to a tensor up front so that `split_seed`
    # does not unstack the seeds; the Python-list code path is of no use
    # anyway because the seed is later selected with a Tensor index.
    class_seeds = samplers.split_seed(
        seed,
        n=ps.convert_to_shape_tensor(num_classes),
        salt='multinomial_draw_sample')

    def _before_last_class(class_idx, *_):
      """Loop condition: stop one short of the final class."""
      return tf.less(class_idx, num_classes - 1)

    def _draw_one_class(class_idx, trials_left, prob_used, counts):
      """Sample the counts for one class using binomial."""
      class_prob = tf.gather(probs, class_idx, axis=-1)
      conditional_prob = tf.clip_by_value(
          class_prob / (1. - prob_used), 0, 1)
      class_seed = tf.gather(class_seeds, class_idx, axis=0)
      # `num_samples` is not passed to `sample`: that axis is already part
      # of `trials_left`.
      drawn = binomial.Binomial(
          total_count=trials_left,
          probs=conditional_prob).sample(seed=class_seed)
      counts = counts.write(class_idx, tf.cast(drawn, dtype=dtype))
      return (class_idx + 1,
              trials_left - drawn,
              prob_used + class_prob,
              counts)

    # Bring the trial counts to the dtype and batch shape of `probs`, then
    # replicate them once per requested output sample.
    trials = tf.cast(num_trials, probs.dtype)
    trials = trials + tf.zeros_like(probs[..., 0])
    trials = _replicate_along_left(trials, num_samples)

    initial_state = (
        tf.constant(0),
        trials,
        tf.zeros_like(probs[..., 0]),
        tf.TensorArray(dtype, size=num_classes, element_shape=trials.shape),
    )
    _, trials_remaining, _, counts = tf.while_loop(
        cond=_before_last_class,
        body=_draw_one_class,
        loop_vars=initial_state)

    # The last class takes every remaining trial directly, since
    # probs[..., -1] / (1. - consumed probability) might numerically not be
    # 1. This also saves one loop iteration and one binomial draw.
    counts = counts.write(
        num_classes - 1, tf.cast(trials_remaining, dtype=dtype))

    # This stop_gradient is necessary to prevent spurious zero gradients
    # coming from b/138796859, and a spurious gradient through the remaining
    # trial count.
    stacked = tf.stop_gradient(counts.stack())
    # The class axis was stacked in front; move it to the end.
    return distribution_util.move_dimension(stacked, 0, -1)
def move_particles_to_rightmost_batch_dim(x, event_shape):
  """Move axis 0 of `x` to sit immediately before its event dimensions.

  Args:
    x: Tensor whose leading axis is to be relocated (by the function's name,
      the particle axis).
    event_shape: shape of a single event; its length gives the number of
      trailing event dimensions of `x`.

  Returns:
    `x` with axis 0 moved to index `rank(x) - len(event_shape) - 1`.
  """
  x_rank = prefer_static.rank_from_shape(prefer_static.shape(x))
  event_rank = prefer_static.rank_from_shape(event_shape)
  rightmost_batch_axis = x_rank - event_rank - 1
  return dist_util.move_dimension(x, 0, rightmost_batch_axis)
def one_step_predictive(model,
                        posterior_samples,
                        num_forecast_steps=0,
                        original_mean=0.,
                        original_scale=1.,
                        thin_every=10):
  """Constructs a one-step-ahead predictive distribution at every timestep.

  Unlike the generic `tfp.sts.one_step_predictive`, this method uses the
  latent levels from Gibbs sampling to efficiently construct a predictive
  distribution that mixes over posterior samples. The predictive distribution
  may also include additional forecast steps.

  This method returns the predictive distributions for each timestep given
  previous timesteps and sampled model parameters, `p(observed_time_series[t] |
  observed_time_series[:t], weights, observation_noise_scale)`. Note that the
  posterior values of the weights and noise scale will in general be informed
  by observations from all timesteps *including the step being predicted*, so
  this is not a strictly kosher probabilistic quantity, but in general we assume
  that it's close, i.e., that the step being predicted had very small individual
  impact on the overall parameter posterior.

  Args:
    model: A `tfp.sts.StructuralTimeSeries` model instance. This must be of the
      form constructed by `build_model_for_gibbs_sampling`.
    posterior_samples: A `GibbsSamplerState` instance in which each element is a
      `Tensor` with initial dimension of size `num_samples`.
    num_forecast_steps: Python `int` number of additional forecast steps to
      append.
      Default value: `0`.
    original_mean: Optional scalar float `Tensor`, added to the predictive
      distribution to undo the effect of input normalization.
      Default value: `0.`
    original_scale: Optional scalar float `Tensor`, used to rescale the
      predictive distribution to undo the effect of input normalization.
      Default value: `1.`
    thin_every: Optional Python `int` factor by which to thin the posterior
      samples, to reduce complexity of the predictive distribution. For example,
      if `thin_every=10`, every `10`th sample will be used.
      Default value: `10`.

  Returns:
    predictive_dist: A `tfd.MixtureSameFamily` instance of event shape
      `[num_timesteps + num_forecast_steps]` representing the predictive
      distribution of each timestep given previous timesteps.
  """
  dtype = dtype_util.common_dtype([
      posterior_samples.level_scale,
      posterior_samples.observation_noise_scale,
      posterior_samples.level,
      original_mean,
      original_scale], dtype_hint=tf.float32)
  num_observed_steps = prefer_static.shape(posterior_samples.level)[-1]

  original_mean = tf.convert_to_tensor(original_mean, dtype=dtype)
  original_scale = tf.convert_to_tensor(original_scale, dtype=dtype)
  # Keep every `thin_every`-th posterior draw of every state field.
  thinned_samples = tf.nest.map_structure(lambda x: x[::thin_every],
                                          posterior_samples)

  # If no slope was inferred, treat as zero.
  if prefer_static.rank_from_shape(
      prefer_static.shape(thinned_samples.slope)) <= 1:
    thinned_samples = thinned_samples._replace(
        slope=tf.zeros_like(thinned_samples.level),
        slope_scale=tf.zeros_like(thinned_samples.level_scale))

  # Every observed step is predicted one step ahead; forecast step `k` is
  # `k` steps past the last observation.
  num_steps_from_last_observation = tf.concat([
      tf.ones([num_observed_steps], dtype=dtype),
      tf.range(1, num_forecast_steps + 1, dtype=dtype)], axis=0)

  # The local linear trend model expects that the level at step t + 1 is equal
  # to the level at step t, plus the slope at time t - 1,
  # plus transition noise of scale 'level_scale' (which we account for below).
  if num_forecast_steps > 0:
    num_batch_dims = prefer_static.rank_from_shape(
        prefer_static.shape(thinned_samples.level)) - 2
    # All else equal, the current level will remain stationary.
    forecast_level = tf.tile(thinned_samples.level[..., -1:],
                             tf.concat([
                                 tf.ones([num_batch_dims + 1], dtype=tf.int32),
                                 [num_forecast_steps]], axis=0))
    # If the model includes slope, the level will steadily increase.
    forecast_level += (thinned_samples.slope[..., -1:] *
                       tf.range(1., num_forecast_steps + 1.,
                                dtype=forecast_level.dtype))

  level_pred = tf.concat(
      [thinned_samples.level[..., :1],  # t == 0
       (thinned_samples.level[..., :-1] +
        thinned_samples.slope[..., :-1])  # 1 <= t < T
      ] + ([forecast_level] if num_forecast_steps > 0 else []),
      axis=-1)

  design_matrix = _get_design_matrix(
      model).to_dense()[:num_observed_steps + num_forecast_steps]
  regression_effect = tf.linalg.matvec(design_matrix, thinned_samples.weights)

  # Undo the input normalization on the predictive mean.
  y_mean = ((level_pred + regression_effect) *
            original_scale[..., tf.newaxis] + original_mean[..., tf.newaxis])

  # To derive a forecast variance, including slope uncertainty, let
  #  `r[:k]` be iid Gaussian RVs with variance `level_scale**2` and `s[:k]` be
  # iid Gaussian RVs with variance `slope_scale**2`. Then the forecast level at
  # step `T + k` can be written as
  #   (level[T] +           # Last known level.
  #    r[0] + ... + r[k] +  # Sum of random walk terms on level.
  #    slope[T] * k +       # Contribution from last known slope.
  #    (k - 1) * s[0] +     # Contributions from random walk terms on slope.
  #    (k - 2) * s[1] +
  #    ... +
  #    1 * s[k - 1])
  # which has variance of
  #  (level_scale**2 * k +
  #   slope_scale**2 * ( (k - 1)**2 +
  #                      (k - 2)**2 +
  #                      ... + 1 ))
  # Here the `slope_scale` coefficient is the `k - 1`th square pyramidal
  # number [1], which is given by
  #  (k - 1) * k * (2 * k - 1) / 6.
  #
  # [1] https://en.wikipedia.org/wiki/Square_pyramidal_number
  variance_from_level = (thinned_samples.level_scale[..., tf.newaxis]**2 *
                         num_steps_from_last_observation)
  variance_from_slope = thinned_samples.slope_scale[..., tf.newaxis]**2 * (
      (num_steps_from_last_observation - 1) *
      num_steps_from_last_observation *
      (2 * num_steps_from_last_observation - 1)) / 6.
  # `original_scale` is given a trailing time axis here exactly as it is for
  # `y_mean`, so the mean and the scale are rescaled with the same
  # broadcasting. For the documented scalar case this is a no-op.
  y_scale = (original_scale[..., tf.newaxis] * tf.sqrt(
      thinned_samples.observation_noise_scale[..., tf.newaxis]**2 +
      variance_from_level + variance_from_slope))

  # Equal-weight mixture over the thinned posterior draws. The draw axis is
  # moved last so it becomes the mixture-component axis.
  num_posterior_draws = prefer_static.shape(y_mean)[0]
  return tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(
          logits=tf.zeros([num_posterior_draws], dtype=y_mean.dtype)),
      components_distribution=tfd.Normal(
          loc=dist_util.move_dimension(y_mean, 0, -1),
          scale=dist_util.move_dimension(y_scale, 0, -1)))
def _sample_n(self, n, seed=None):
  """Draw `n` observation sequences from the hidden Markov model.

  A hidden state path is sampled first (initial state, then one transition
  per remaining step); an observation is then sampled for every possible
  state at every step and the one matching the hidden state is selected via
  a one-hot mask.

  Args:
    n: number of sequences to draw.
    seed: seed used to build the `SeedStream` that feeds every `sample` call.

  Returns:
    observations: Tensor laid out as `n batch_shape steps inner_shape`.
  """
  with tf.control_dependencies(self._runtime_assertions):
    # Each `strm()` call below yields the next seed from this stream, so the
    # order of the sampling statements determines the result.
    strm = SeedStream(seed, salt="HiddenMarkovModel")

    num_states = self._num_states

    batch_shape = self.batch_shape_tensor()
    batch_size = tf.reduce_prod(batch_shape)

    # The batch sizes of the underlying initial distributions and
    # transition distributions might not match the batch size of
    # the HMM distribution.
    # As a result we need to ask for more samples from the
    # underlying distributions and then reshape the results into
    # the correct batch size for the HMM.
    init_repeat = (
        tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
            self._initial_distribution.batch_shape_tensor()))
    init_state = self._initial_distribution.sample(n * init_repeat,
                                                   seed=strm())
    init_state = tf.reshape(init_state, [n, batch_size])
    # init_state :: n batch_size

    # The last batch dimension of the transition distribution is excluded
    # from the product via `[:-1]`.
    transition_repeat = (
        tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
            self._transition_distribution.batch_shape_tensor()[:-1]))

    def generate_step(state, _):
      """Take a single step in Markov chain."""

      gen = self._transition_distribution.sample(n * transition_repeat,
                                                 seed=strm())
      # gen :: (n * transition_repeat) transition_batch

      new_states = tf.reshape(gen, [n, batch_size, num_states])

      # new_states :: n batch_size num_states

      # Select, for each chain, the sampled successor of its current state.
      old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)

      # old_states :: n batch_size num_states

      return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

    def _scan_multiple_steps():
      """Take multiple steps with tf.scan."""
      # The scanned elements are ignored by `generate_step`; only their count
      # (`num_steps - 1`) matters.
      dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
      if seed is not None:
        # Force parallel_iterations to 1 to ensure reproducibility
        # b/139210489
        hidden_states = tf.scan(generate_step, dummy_index,
                                initializer=init_state,
                                parallel_iterations=1)
      else:
        # Invoke default parallel_iterations behavior
        hidden_states = tf.scan(generate_step, dummy_index,
                                initializer=init_state)

      # TODO(b/115618503): add/use prepend_initializer to tf.scan
      return tf.concat([[init_state],
                        hidden_states], axis=0)

    # With a single step there is nothing to scan over; the path is just the
    # initial state with a leading step axis.
    hidden_states = prefer_static.cond(
        self._num_steps > 1,
        _scan_multiple_steps,
        lambda: init_state[tf.newaxis, ...])

    hidden_one_hot = tf.one_hot(hidden_states, num_states,
                                dtype=self._observation_distribution.dtype)
    # hidden_one_hot :: num_steps n batch_size num_states

    # The observation distribution batch size might not match
    # the required batch size so as with the initial and
    # transition distributions we generate more samples and
    # reshape.
    observation_repeat = (
        batch_size // tf.reduce_prod(
            self._observation_distribution.batch_shape_tensor()[:-1]))

    possible_observations = self._observation_distribution.sample(
        [self._num_steps, observation_repeat * n], seed=strm())

    inner_shape = self._observation_distribution.event_shape

    # possible_observations :: num_steps (observation_repeat * n)
    #                          observation_batch[:-1] num_states inner_shape

    possible_observations = tf.reshape(
        possible_observations,
        tf.concat([[self._num_steps, n],
                   batch_shape,
                   [num_states],
                   inner_shape], axis=0))

    # possible_observations :: steps n batch_size num_states inner_shape

    # Give the one-hot mask size-1 axes in place of the event dimensions so
    # it broadcasts against `possible_observations`.
    hidden_one_hot = tf.reshape(hidden_one_hot,
                                tf.concat([[self._num_steps, n],
                                           batch_shape,
                                           [num_states],
                                           tf.ones_like(inner_shape)],
                                          axis=0))

    # hidden_one_hot :: steps n batch_size num_states "inner_shape"

    # Summing over the state axis keeps only the observation belonging to
    # the sampled hidden state.
    observations = tf.reduce_sum(
        hidden_one_hot * possible_observations,
        axis=-1 - tf.size(inner_shape))

    # observations :: steps n batch_size inner_shape

    observations = distribution_util.move_dimension(
        observations, 0, 1 + tf.size(batch_shape))

    # returned :: n batch_shape steps inner_shape

    return observations
def _observation_log_probs(self, observations, mask):
  """Tabulate per-step, per-state log probs of `observations`.

  Notation: `E` is the underlying event shape, `M` the number of steps and
  `N` the number of states. On entry

    observations : batch_o [M] E
    mask         : batch_m [M]      (when given)

  and the HMM itself has batch shape `batch_d`. The three batch shapes are
  broadcast together into a common `batch`. The step axis is moved to the
  front so the result can be folded or scanned over, and a size-1 axis is
  inserted after the batch so that `log_prob` is evaluated once per state:

    observations          : [M] batch [1] [E]
    observation_log_probs : [M] batch [N]

  The mask is broadcast to `[M] batch` (with a trailing axis added for the
  states) and, wherever it is true, the log prob is replaced by zero.

  Args:
    observations: tensor of observation sequences.
    mask: optional boolean tensor, or `None`.

  Returns:
    Tensor of log probabilities laid out as `[M] batch [N]`.
  """
  event_rank = self._underlying_event_rank

  # Split the incoming shape into its batch part and its `[M] E` part.
  full_shape = tf.shape(observations)
  obs_batch_shape = full_shape[:-1 - event_rank]
  obs_event_shape = full_shape[-1 - event_rank:]

  # Common batch shape of observations, this distribution and the mask.
  batch_shape = tf.broadcast_dynamic_shape(obs_batch_shape,
                                           self.batch_shape_tensor())
  has_mask = mask is not None
  if has_mask:
    mask_batch_shape = tf.shape(mask)[:-1]
    batch_shape = tf.broadcast_dynamic_shape(batch_shape, mask_batch_shape)

  target_shape = tf.concat([batch_shape, obs_event_shape], axis=0)
  observations = tf.broadcast_to(observations, target_shape)

  # Index of the step axis within the broadcast observations.
  step_axis = tf.rank(observations) - event_rank - 1
  observations = distribution_util.move_dimension(observations, step_axis, 0)
  # Size-1 axis just before the event dims, to broadcast against the states.
  observations = tf.expand_dims(observations, step_axis + 1)

  log_probs = self._observation_distribution.log_prob(observations)
  if not has_mask:
    return log_probs

  mask_shape = tf.concat([batch_shape, [self._num_steps]], axis=0)
  steps_first_mask = distribution_util.move_dimension(
      tf.broadcast_to(mask, mask_shape), -1, 0)
  return tf.where(steps_first_mask[..., tf.newaxis],
                  tf.zeros_like(log_probs),
                  log_probs)
def resample_independent(
    log_probs, event_size, sample_shape, seed=None, name=None):
  """Categorical resampler for sequential Monte Carlo.

  This function is based on Algorithm #1 in the paper
  [Maskell et al. (2006)][1].

  Args:
    log_probs: A tensor-valued batch of discrete log probability distributions.
    event_size: the dimension of the vector considered a single draw.
    sample_shape: the `sample_shape` determining the number of draws.
    seed: Python `int` used to seed calls to `tf.random.*`.
      Default value: None (i.e. no seed).
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'resample_independent'`).

  Returns:
    resampled_indices: The result is similar to sampling with
      ```python
      expanded_sample_shape = tf.concat([[event_size], sample_shape], axis=-1)
      tfd.Categorical(logits=log_probs).sample(expanded_sample_shape)
      ```
      but with values sorted along the first axis. It can be considered to be
      sampling events made up of a length-`event_size` vector of draws from
      the `Categorical` distribution. For large input values this function
      should give better performance than using `Categorical`.
      The sortedness is an unintended side effect of the algorithm that is
      harmless in the context of simple SMC algorithms.

  #### References

  [1]: S. Maskell, B. Alun-Jones and M. Macleod. A Single Instruction Multiple
       Data Particle Filter.
       In 2006 IEEE Nonlinear Statistical Signal Processing Workshop.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf

  """
  with tf.name_scope(name or 'resample_independent'):
    log_probs = tf.convert_to_tensor(log_probs, dtype_hint=tf.float32)
    # Work with the category axis last.
    log_probs = dist_util.move_dimension(log_probs, source_idx=0, dest_idx=-1)

    batch_shape = prefer_static.shape(log_probs)[:-1]
    num_markers = prefer_static.shape(log_probs)[-1]
    leading_rank = (prefer_static.rank_from_shape(sample_shape) +
                    prefer_static.rank_from_shape(batch_shape))

    # `working_shape`: one entry per event to generate.
    # `points_shape`: shape of the final result (one point per draw).
    # `markers_shape`: shape of the temporarily inserted markers.
    working_shape = prefer_static.concat([sample_shape, batch_shape], axis=0)
    points_shape = prefer_static.concat([working_shape, [event_size]], axis=0)
    markers_shape = prefer_static.concat(
        [working_shape, [num_markers]], axis=0)

    # Generate one real point for each particle.
    unit_exponential = exponential.Exponential(
        rate=tf.constant(1.0, dtype=log_probs.dtype))
    log_points = -unit_exponential.sample(points_shape, seed=seed)

    # The idea, in probability space:
    #   * `cumsum` of the probabilities cuts [0, 1] into one division per
    #     category, and a 'marker' is placed at the end of each division;
    #   * random points are thrown onto [0, 1];
    #   * points and markers are sorted together, and the number of points
    #     between consecutive markers is the number of samples for that
    #     category.
    # For example, with `probs = [0.2, 0.3, 0.5]` there are 3 markers:
    #
    # |     |          |
    # 0.   0.2        0.5        1.0  <- markers
    #
    # and four points [0.1, 0.25, 0.9, 0.75] sort into:
    #
    # 0.1  0.25      0.75 0.9    <- points
    #  *  | *   |    *    *|
    # 0.   0.2 0.5         1.0   <- markers
    #
    # giving one sample in the first category, one in the second and two in
    # the last. Everything below is the batched form of this.
    is_marker = prefer_static.concat(
        [tf.zeros(points_shape, dtype=tf.int32),
         tf.ones(markers_shape, dtype=tf.int32)],
        axis=-1)
    log_marker_positions = tf.broadcast_to(
        tf.math.cumulative_logsumexp(log_probs, axis=-1), markers_shape)
    combined = prefer_static.concat(
        [log_points, log_marker_positions], axis=-1)
    order = tf.argsort(combined, axis=-1, stable=False)
    sorted_markers = tf.gather_nd(
        is_marker,
        order[..., tf.newaxis],
        batch_dims=leading_rank)

    # The running count of markers seen so far is the category of each
    # point, capped at the last valid category index.
    category_at_position = prefer_static.cast(
        tf.cumsum(sorted_markers, axis=-1), dtype=tf.int32)
    category_at_position = tf.minimum(category_at_position, num_markers - 1)

    # Collect up samples, omitting markers.
    point_positions = tf.equal(sorted_markers, 0)
    resampled = tf.reshape(
        category_at_position[point_positions], points_shape)
    return dist_util.move_dimension(resampled, source_idx=-1, dest_idx=0)
def _sample_n(self, n, seed=None):
  """Draw `n` sequences of observations from the hidden Markov model.

  The hidden chain is simulated first: an initial state is drawn from the
  initial distribution and is then advanced `num_steps - 1` times using the
  transition distribution. A candidate observation is then drawn for every
  (step, sample, batch member, state) combination, and the candidate that
  belongs to the realized hidden state is selected with a one-hot mask.

  Args:
    n: Number of sequences to draw.
    seed: Optional seed used to build a `SeedStream`. Every sampling call
      made here (initial state, transitions, observations) takes a fresh
      seed from that stream.

  Returns:
    observations: `Tensor` laid out as
      `[n] + batch_shape + [num_steps] + observation event shape`.
  """
  with tf.control_dependencies(self._runtime_assertions):
    seed = seed_stream.SeedStream(seed, salt="HiddenMarkovModel")

    num_states = self._num_states

    batch_shape = self.batch_shape_tensor()
    batch_size = tf.reduce_prod(batch_shape)

    # The batch sizes of the underlying initial distributions and
    # transition distributions might not match the batch size of
    # the HMM distribution.
    # As a result we need to ask for more samples from the
    # underlying distributions and then reshape the results into
    # the correct batch size for the HMM.
    init_repeat = (
        batch_size //
        tf.reduce_prod(self._initial_distribution.batch_shape_tensor()))
    init_state = self._initial_distribution.sample(n * init_repeat,
                                                   seed=seed())
    init_state = tf.reshape(init_state, [n, batch_size])
    # init_state :: n batch_size

    transition_repeat = (
        batch_size //
        tf.reduce_prod(
            self._transition_distribution.batch_shape_tensor()[:-1]))

    def generate_step(state, _):
      """Take a single step in Markov chain."""

      gen = self._transition_distribution.sample(n * transition_repeat,
                                                 seed=seed())
      # gen :: (n * transition_repeat) transition_batch

      new_states = tf.reshape(gen, [n, batch_size, num_states])
      # new_states :: n batch_size num_states

      # A successor was sampled for every possible current state; the
      # one-hot mask keeps only the successor of the actual current state.
      old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)
      # old_states_one_hot :: n batch_size num_states

      return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

    if self._num_steps > 1:
      # The scanned-over values are ignored by `generate_step`; they only
      # determine how many steps are taken.
      dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
      hidden_states = tf.scan(generate_step, dummy_index,
                              initializer=init_state)

      # TODO(b/115618503): add/use prepend_initializer to tf.scan
      hidden_states = tf.concat([[init_state],
                                 hidden_states], axis=0)
    else:
      hidden_states = init_state[tf.newaxis, ...]

    # hidden_states :: num_steps n batch_size

    hidden_one_hot = tf.one_hot(hidden_states, num_states,
                                dtype=self._observation_distribution.dtype)
    # hidden_one_hot :: num_steps n batch_size num_states

    # The observation distribution batch size might not match
    # the required batch size so as with the initial and
    # transition distributions we generate more samples and
    # reshape.
    observation_repeat = (
        batch_size //
        tf.reduce_prod(
            self._observation_distribution.batch_shape_tensor()[:-1]))

    # Take the seed from the same stream as the hidden-state draws so that
    # a fixed `seed` reproduces the observations as well as the chain.
    possible_observations = self._observation_distribution.sample(
        [self._num_steps, observation_repeat * n], seed=seed())

    inner_shape = self._observation_distribution.event_shape

    # possible_observations :: num_steps (observation_repeat * n)
    #                          observation_batch[:-1] num_states inner_shape

    possible_observations = tf.reshape(
        possible_observations,
        tf.concat([[self._num_steps, n],
                   batch_shape,
                   [num_states],
                   inner_shape], axis=0))

    # possible_observations :: steps n batch_size num_states inner_shape

    # Append one singleton dimension per event dimension so the mask
    # broadcasts against the candidate observations.
    hidden_one_hot = tf.reshape(hidden_one_hot,
                                tf.concat([[self._num_steps, n],
                                           batch_shape,
                                           [num_states],
                                           tf.ones_like(inner_shape)],
                                          axis=0))

    # hidden_one_hot :: steps n batch_size num_states "inner_shape"

    # Summing over the state axis picks out the observation emitted by the
    # realized hidden state.
    observations = tf.reduce_sum(
        hidden_one_hot * possible_observations,
        axis=-1 - tf.size(inner_shape))

    # observations :: steps n batch_size inner_shape

    observations = util.move_dimension(observations, 0,
                                       1 + tf.size(batch_shape))

    # returned :: n batch_shape steps inner_shape

    return observations
def bracket_root(objective_fn,
                 dtype=tf.float32,
                 num_points=512,
                 name='bracket_root'):
  """Finds bounds that bracket a root of the objective function.

  This method attempts to return an interval bracketing a root of the objective
  function. It evaluates the objective in parallel at `num_points`
  locations, at exponentially increasing distance from the origin, and returns
  the first pair of adjacent points `[low, high]` such that the objective is
  finite and has a different sign at the two points. If no such pair was
  observed, it returns the trivial interval
  `[np.finfo(dtype).min, np.finfo(dtype).max]` containing all float values of
  the specified `dtype`. If the objective has multiple roots, the returned
  interval will contain at least one (but perhaps not all) of the roots.

  Args:
    objective_fn: Python callable for which roots are searched. It must be a
      continuous function that accepts a scalar `Tensor` of type `dtype` and
      returns a `Tensor` of shape `batch_shape`.
    dtype: Optional float `dtype` of inputs to `objective_fn`.
      Default value: `tf.float32`.
    num_points: Optional Python `int` number of points at which to evaluate
      the objective. Half of the points are placed on each side of the
      origin, so an odd value is rounded down to the nearest even number.
      Default value: `512`.
    name: Python `str` name given to ops created by this method.
  Returns:
    low: Float `Tensor` of shape `batch_shape` and dtype `dtype`. Lower bound
      on a root of `objective_fn`.
    high: Float `Tensor` of shape `batch_shape` and dtype `dtype`. Upper bound
      on a root of `objective_fn`.
  Raises:
    ValueError: If the rank of the objective's output cannot be inferred.
  """
  with tf.name_scope(name):
    # Build a sequence of points whose magnitudes are logarithmically spaced
    # between `exp(-10)` and the largest finite value of `dtype`, mirrored
    # about the origin.
    dtype_info = np.finfo(dtype_util.as_numpy_dtype(dtype))
    points_per_sign = num_points // 2
    xs_positive = tf.exp(tf.linspace(tf.cast(-10., dtype),
                                     tf.math.log(dtype_info.max),
                                     points_per_sign))
    xs = tf.concat([tf.reverse(-xs_positive, axis=[0]), xs_positive], axis=0)
    # `xs` holds this many values, which is `num_points - 1` when `num_points`
    # is odd. Reshaping with `num_points` itself would then fail.
    num_xs = 2 * points_per_sign

    # Evaluate the objective at all points. The objective function may return
    # a batch of values (e.g., `objective(x) = x - batch_of_roots`).
    if NUMPY_MODE:
      objective_output_spec = objective_fn(tf.zeros([], dtype=dtype))
    else:
      objective_output_spec = callable_util.get_output_spec(
          objective_fn,
          tf.convert_to_tensor(0., dtype=dtype))
    batch_ndims = tensorshape_util.rank(objective_output_spec.shape)
    if batch_ndims is None:
      raise ValueError('Cannot infer tensor rank of objective values.')
    # Append one singleton dimension per batch dimension so that the points
    # broadcast against the objective's batch shape.
    xs_pad_shape = ps.pad([num_xs],
                          paddings=[[0, batch_ndims]],
                          constant_values=1)
    ys = objective_fn(tf.reshape(xs, xs_pad_shape))

    # Find the smallest point where the objective is finite.
    is_finite = tf.math.is_finite(ys)
    ys_transposed = distribution_util.move_dimension(  # For batch gather.
        ys, 0, -1)
    first_finite_value = tf.gather(
        ys_transposed,
        tf.argmax(is_finite, axis=0),  # Index of smallest finite point.
        batch_dims=batch_ndims,
        axis=-1)
    # Select the next point where the objective has a different sign.
    sign_change_idx = tf.argmax(
        tf.not_equal(tf.math.sign(ys),
                     tf.math.sign(first_finite_value)) & is_finite,
        axis=0)
    # If the sign never changes, we can't bracket a root. `argmax` returns 0
    # when no entry is `True`, and index 0 can never be a genuine sign change:
    # it is either non-finite or is itself the first finite point.
    bracketing_failed = tf.equal(sign_change_idx, 0)
    # If the objective's sign is zero, we've found an actual root.
    root_found = tf.equal(tf.gather(tf.math.sign(ys_transposed),
                                    sign_change_idx,
                                    batch_dims=batch_ndims,
                                    axis=-1),
                          0.)
    return _structure_broadcasting_where(
        bracketing_failed,
        # If we didn't detect a sign change, fall back to the trivial interval.
        (dtype_info.min, dtype_info.max),
        # Otherwise, return the points around the sign change, unless we
        # actually evaluated a root, in which case, return the zero-width
        # bracket at that root.
        (tf.gather(xs, tf.where(bracketing_failed | root_found,
                                sign_change_idx,
                                sign_change_idx - 1)),
         tf.gather(xs, sign_change_idx)))
def index_remapping_gather(params,
                           indices,
                           axis=0,
                           indices_axis=0,
                           name='index_remapping_gather'):
  """Gather values from `axis` of `params` using `indices_axis` of `indices`.

  The shape of `indices` must broadcast to that of `params` when
  their `indices_axis` and `axis` (respectively) are aligned:

  ```python
  # params.shape:
  [p[0],  ..., ...,         p[axis], ..., ..., p[rank(params) - 1]]
  # indices.shape:
        [i[0], ..., i[indices_axis], ..., i[rank(indices) - 1]]
  ```

  In particular, `params` must have at least as many
  leading dimensions as `indices` (`axis >= indices_axis`), and at least as many
  trailing dimensions (`rank(params) - axis >= rank(indices) - indices_axis`).

  The `result` has the same shape as `params`, except that the dimension
  of size `p[axis]` is replaced by one of size `i[indices_axis]`:

  ```python
  # result.shape:
  [p[0],  ..., ..., i[indices_axis], ..., ..., p[rank(params) - 1]]
  ```

  In the case where `rank(params) == 5`, `rank(indices) == 3`, `axis = 2`, and
  `indices_axis = 1`, the result is given by

   ```python
   # alignment is:                       v axis
   # params.shape    ==   [p[0], p[1], p[2], p[3], p[4]]
   # indices.shape   ==         [i[0], i[1], i[2]]
   #                                     ^ indices_axis
   result[i, j, k, l, m] = params[i, j, indices[j, k, l], l, m]
  ```

  Args:
    params:  `N-D` `Tensor` (`N > 0`) from which to gather values.
      Number of dimensions must be known statically.
    indices: `Tensor` with values in `{0, ..., params.shape[axis] - 1}`, whose
      shape broadcasts to that of `params` as described above. Number of
      dimensions must be known statically.
    axis: Python `int` axis of `params` from which to gather.
    indices_axis: Python `int` axis of `indices` to align with the `axis`
      over which `params` is gathered.
    name: String name for scoping created ops.

  Returns:
    `Tensor` composed of elements of `params`.

  Raises:
    ValueError: If shape/rank requirements are not met.
  """
  with tf.name_scope(name):
    params = tf.convert_to_tensor(params, name='params')
    indices = tf.convert_to_tensor(indices, name='indices')

    params_ndims = tensorshape_util.rank(params.shape)
    indices_ndims = tensorshape_util.rank(indices.shape)
    # `axis` dtype must match ndims, which are 64-bit Python ints.
    axis = tf.get_static_value(tf.convert_to_tensor(axis, dtype=tf.int64))
    indices_axis = tf.get_static_value(
        tf.convert_to_tensor(indices_axis, dtype=tf.int64))

    if params_ndims is None:
      raise ValueError(
          'Rank of `params`, must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    # The number of batch dimensions passed to `tf.gather` below is derived
    # from the rank of `indices`, so that rank has to be static as well.
    # Without this check an unknown rank surfaces later as an opaque
    # `TypeError` from `None + int`.
    if indices_ndims is None:
      raise ValueError(
          'Rank of `indices` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if axis is None:
      raise ValueError(
          '`axis` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if indices_axis is None:
      raise ValueError(
          '`indices_axis` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if indices_axis > axis:
      raise ValueError(
          '`indices_axis` should be <= `axis`, but was {} > {}'.format(
              indices_axis, axis))

    if params_ndims < 1:
      raise ValueError(
          'Rank of params should be `> 0`, but was {}'.format(params_ndims))

    if indices_ndims < 1:
      raise ValueError(
          'Rank of indices should be `> 0`, but was {}'.format(indices_ndims))

    if indices_ndims - indices_axis > params_ndims - axis:
      raise ValueError(
          '`rank(params) - axis` ({} - {}) must be >= `rank(indices) - '
          'indices_axis` ({} - {}), but was not.'.format(
              params_ndims, axis, indices_ndims, indices_axis))

    # `tf.gather` requires the axis to be the rightmost batch ndim. So, we
    # transpose `indices_axis` to be the rightmost dimension of `indices`...
    transposed_indices = dist_util.move_dimension(indices,
                                                  source_idx=indices_axis,
                                                  dest_idx=-1)

    # ... and `axis` to be the corresponding (aligned as in the docstring)
    # dimension of `params`.
    broadcast_indices_ndims = indices_ndims + (axis - indices_axis)
    transposed_params = dist_util.move_dimension(
        params,
        source_idx=axis,
        dest_idx=broadcast_indices_ndims - 1)

    # Next we broadcast `indices` so that its shape has the same prefix as
    # `params.shape`.
    transposed_params_shape = prefer_static.shape(transposed_params)
    result_shape = prefer_static.concat([
        transposed_params_shape[:broadcast_indices_ndims - 1],
        prefer_static.shape(indices)[indices_axis:indices_axis + 1],
        transposed_params_shape[broadcast_indices_ndims:]], axis=0)
    broadcast_indices = prefer_static.broadcast_to(
        transposed_indices,
        result_shape[:broadcast_indices_ndims])

    result_t = tf.gather(transposed_params,
                         broadcast_indices,
                         batch_dims=broadcast_indices_ndims - 1,
                         axis=broadcast_indices_ndims - 1)
    # Undo the transposition so the gathered dimension sits back at `axis`.
    return dist_util.move_dimension(result_t,
                                    source_idx=broadcast_indices_ndims - 1,
                                    dest_idx=axis)
def __init__(self,
             num_timesteps,
             design_matrix,
             drift_scale,
             initial_state_prior,
             observation_noise_scale=0.,
             initial_step=0,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Constructs the state space model of a dynamic linear regression.

  The latent state holds one regression weight per feature. The transition
  matrix is the identity and the transition noise is diagonal with scale
  `drift_scale`. The observation matrix at timestep `t` is the row of
  `design_matrix` for that timestep.

  Args:
    num_timesteps: Scalar `int` `Tensor` number of timesteps to model
      with this distribution.
    design_matrix: float `Tensor` of shape `concat([batch_shape,
      [num_timesteps, num_features]])`.
    drift_scale: Scalar (any additional dimensions are treated as batch
      dimensions) `float` `Tensor` indicating the standard deviation of the
      latent state transitions.
    initial_state_prior: instance of `tfd.MultivariateNormal`
      representing the prior distribution on latent states.  Must have
      event shape `[num_features]`.
    observation_noise_scale: Scalar (any additional dimensions are
      treated as batch dimensions) `float` `Tensor` indicating the standard
      deviation of the observation noise.
      Default value: `0.`.
    initial_step: scalar `int` `Tensor` specifying the starting timestep.
      Default value: `0`.
    validate_args: Python `bool`. Whether to validate input with asserts. If
      `validate_args` is `False`, and the inputs are invalid, correct behavior
      is not guaranteed.
      Default value: `False`.
    allow_nan_stats: Python `bool`. If `False`, raise an exception if a
      statistic (e.g. mean/mode/etc...) is undefined for any batch member. If
      `True`, batch members with valid parameters leading to undefined
      statistics will return NaN for this statistic.
      Default value: `True`.
    name: Python `str` name prefixed to ops created by this class.
      Default value: 'DynamicLinearRegressionStateSpaceModel'.
  """
  with tf.name_scope(
      name or 'DynamicLinearRegressionStateSpaceModel') as name:

    # `observation_noise_scale` takes the dtype shared by the other inputs.
    dtype = dtype_util.common_dtype(
        [design_matrix, drift_scale, initial_state_prior])

    design_matrix = tf.convert_to_tensor(
        value=design_matrix, name='design_matrix', dtype=dtype)
    # Put the time axis first so a single timestep can be picked with a
    # plain `tf.gather` on the leading dimension.
    time_major_design = distribution_util.move_dimension(
        design_matrix, -2, 0)

    drift_scale = tf.convert_to_tensor(
        value=drift_scale, name='drift_scale', dtype=dtype)

    observation_noise_scale = tf.convert_to_tensor(
        value=observation_noise_scale,
        name='observation_noise_scale',
        dtype=dtype)

    num_features = prefer_static.shape(design_matrix)[-1]

    def build_observation_matrix(timestep):
      """Wraps the design-matrix row for `timestep` as a one-row operator."""
      features_at_step = tf.gather(time_major_design, timestep)
      return tf.linalg.LinearOperatorFullMatrix(
          features_at_step[..., tf.newaxis, :],
          name='observation_matrix')

    self._drift_scale = drift_scale
    self._observation_noise_scale = observation_noise_scale

    # The regression weights carry over unchanged between steps...
    identity_transition = tf.linalg.LinearOperatorIdentity(
        num_rows=num_features,
        dtype=dtype,
        name='transition_matrix')

    # ...apart from noise with the same scale for every feature.
    drift_noise = tfd.MultivariateNormalDiag(
        scale_diag=(drift_scale[..., tf.newaxis] *
                    tf.ones([num_features], dtype=dtype)),
        name='transition_noise')

    measurement_noise = tfd.MultivariateNormalDiag(
        scale_diag=observation_noise_scale[..., tf.newaxis],
        name='observation_noise')

    super(DynamicLinearRegressionStateSpaceModel, self).__init__(
        num_timesteps=num_timesteps,
        transition_matrix=identity_transition,
        transition_noise=drift_noise,
        observation_matrix=build_observation_matrix,
        observation_noise=measurement_noise,
        initial_state_prior=initial_state_prior,
        initial_step=initial_step,
        allow_nan_stats=allow_nan_stats,
        validate_args=validate_args,
        name=name)