def pks(lgssm, ms, Ps, max_parallel=10000):
    """Parallel Kalman smoother over filtered means `ms` and covariances `Ps`."""
    max_num_levels = math.ceil(math.log2(max_parallel))
    _, Fs, Qs, *_ = lgssm
    initial_elements = make_associative_smoothing_elements(Fs, Qs, ms, Ps)
    # The smoothing pass runs backward in time: reverse the elements, run the
    # associative scan, then reverse the results back to forward time order.
    reversed_elements = tuple(
        tf.reverse(elem, axis=[0]) for elem in initial_elements)
    final_elements = scan_associative(smoothing_operator,
                                      reversed_elements,
                                      max_num_levels=max_num_levels)
    return (tf.reverse(final_elements[1], axis=[0]),
            tf.reverse(final_elements[2], axis=[0]))
def pkf(lgssm, observations, return_loglikelihood=False, max_parallel=10000):
    """Parallel Kalman filter; optionally also returns the total data log-likelihood."""
    with tf.name_scope("parallel_filter"):
        P0, Fs, Qs, H, R = lgssm
        dtype = P0.dtype
        m0 = tf.zeros(tf.shape(P0)[0], dtype=dtype)
        max_num_levels = math.ceil(math.log2(max_parallel))
        initial_elements = make_associative_filtering_elements(
            m0, P0, Fs, Qs, H, R, observations)
        final_elements = scan_associative(filtering_operator,
                                          initial_elements,
                                          max_num_levels=max_num_levels)
        if return_loglikelihood:
            # Evaluate each observation under its one-step-ahead predictive
            # distribution to obtain the marginal log-likelihood.
            filtered_means = tf.concat(
                [tf.expand_dims(m0, 0), final_elements[1][:-1]], axis=0)
            filtered_cov = tf.concat(
                [tf.expand_dims(P0, 0), final_elements[2][:-1]], axis=0)
            predicted_means = mv(Fs, filtered_means)
            predicted_covs = mm(Fs, mm(filtered_cov, Fs, transpose_b=True)) + Qs
            obs_means = mv(H, predicted_means)
            obs_covs = (mm(H, mm(predicted_covs, H, transpose_b=True))
                        + tf.expand_dims(R, 0))
            dists = MultivariateNormalTriL(obs_means, tf.linalg.cholesky(obs_covs))
            # TODO: some logic could be added here to avoid handling the
            # covariance of non-NaN models, but this has no impact for GPs.
            logprobs = dists.log_prob(observations)
            # Drop NaN log-probabilities (e.g. from missing observations)
            # before summing.
            logprobs_without_nans = tf.where(tf.math.is_nan(logprobs),
                                             tf.zeros_like(logprobs), logprobs)
            total_log_prob = tf.reduce_sum(logprobs_without_nans)
            return final_elements[1], final_elements[2], total_log_prob
        return final_elements[1], final_elements[2]
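
# Example usage (a minimal sketch, not part of the original module): the helper
# name `_example_parallel_filter_smoother` and the scalar local-level model
# below are illustrative assumptions. It builds the `lgssm = (P0, Fs, Qs, H, R)`
# tuple that `pkf` unpacks, runs the parallel filter, and feeds the filtered
# moments into the parallel smoother `pks`. The exact shapes expected by
# `make_associative_filtering_elements` are assumed from how `pkf` uses them.
def _example_parallel_filter_smoother():
    num_timesteps = 100
    dtype = tf.float64
    P0 = tf.eye(1, dtype=dtype)                              # initial covariance
    Fs = tf.ones((num_timesteps, 1, 1), dtype=dtype)         # transition matrices
    Qs = 0.1 * tf.ones((num_timesteps, 1, 1), dtype=dtype)   # transition covariances
    H = tf.ones((1, 1), dtype=dtype)                         # observation matrix
    R = 0.5 * tf.ones((1, 1), dtype=dtype)                   # observation covariance
    lgssm = (P0, Fs, Qs, H, R)
    observations = tf.random.normal((num_timesteps, 1), dtype=dtype)

    # Forward (filtering) pass, also computing the data log-likelihood.
    fms, fPs, ell = pkf(lgssm, observations, return_loglikelihood=True)
    # Backward (smoothing) pass over the filtered moments.
    sms, sPs = pks(lgssm, fms, fPs)
    return fms, fPs, sms, sPs, ell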
def kalman_filter(transition_matrix,
                  transition_mean,
                  transition_cov,
                  observation_matrix,
                  observation_mean,
                  observation_cov,
                  initial_mean,
                  initial_cov,
                  y,
                  mask,
                  return_all=True):
    """Infers latent values using a parallel Kalman filter.

    This method computes filtered marginal means and covariances of a linear
    Gaussian state-space model using a parallel message-passing algorithm, as
    described by Sarkka and Garcia-Fernandez [1]. The inference process is
    formulated as a prefix-sum problem that can be efficiently computed by
    `tfp.math.scan_associative`, so that inference for a time series of length
    `num_timesteps` requires only `O(log(num_timesteps))` sequential steps.

    As with a naive sequential implementation, the total FLOP count scales
    linearly in `num_timesteps` (as `O(T + T/2 + T/4 + ...) = O(T)`), so this
    approach does not require extra resources in an asymptotic sense. However,
    it likely has a somewhat larger constant factor, so a sequential filter may
    be preferred when throughput rather than latency is the highest priority.

    Args:
      transition_matrix: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
      transition_mean: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, latent_size]`.
      transition_cov: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
      observation_matrix: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, observation_size, latent_size]`.
      observation_mean: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, observation_size]`.
      observation_cov: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, observation_size, observation_size]`.
      initial_mean: float `Tensor` of shape `[B1, .., BN, latent_size]`.
      initial_cov: float `Tensor` of shape
        `[B1, .., BN, latent_size, latent_size]`.
      y: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, observation_size]`.
      mask: bool `Tensor` of shape `[num_timesteps, B1, .., BN]`, where `True`
        marks an observation as missing.
      return_all: Python `bool`, whether to compute log-likelihoods and
        predictive and observation distributions. If `False`, only
        `filtered_means` and `filtered_covs` are computed, and `None` is
        returned for the remaining values.

    Returns:
      log_likelihoods: float `Tensor` of shape `[num_timesteps, B1, .., BN]`,
        such that `log_likelihoods[t] = log p(y[t] | y[:t])`.
      filtered_means: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, latent_size]`, such that
        `filtered_means[t] == E[x[t] | y[:t + 1]]`.
      filtered_covs: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
      predictive_means: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, latent_size]`, such that
        `predictive_means[t] = E[x[t + 1] | y[:t + 1]]`.
      predictive_covs: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
      observation_means: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, observation_size]`, such that
        `observation_means[t] = E[y[t] | y[:t]]`.
      observation_covs: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, observation_size, observation_size]`.

    ### Mathematical Details

    The assumed model consists of latent state vectors
    `x[:num_timesteps, :latent_size]` and corresponding observed values
    `y[:num_timesteps, :observation_size]`, governed by the following dynamics:

    ```
    x[0] ~ MultivariateNormal(mean=initial_mean, cov=initial_cov)
    for t in range(num_timesteps - 1):
      x[t + 1] ~ MultivariateNormal(mean=matmul(transition_matrix[t], x[t])
                                         + transition_mean[t],
                                    cov=transition_cov[t])
    # Observed values `y[:num_timesteps]` defined at all timesteps.
    y ~ MultivariateNormal(mean=matmul(observation_matrix, x)
                                + observation_mean,
                           cov=observation_cov)
    ```

    ### Tensor layout

    `Tensor` arguments are expected to have `num_timesteps` as their *leftmost*
    axis, preceding any batch dimensions. This layout is used for internal
    computations, so providing arguments in this form avoids the need for
    potentially-spurious transposition. The returned `Tensor`s also follow this
    layout, for the same reason. Note that this differs from the layout
    mandated by the `tfd.Distribution` API (and exposed by
    `tfd.LinearGaussianStateSpaceModel`), in which the time axis is to the
    right of any batch dimensions; it is the caller's responsibility to perform
    any needed transpositions.

    Tensor arguments may be specified with partial batch shape, i.e., with
    shape prefix `[num_timesteps, Bk, ..., BN]` for `k > 1`. They will be
    internally reshaped and broadcast to the full batch shape prefix
    `[num_timesteps, B1, ..., BN]`.

    ### References

    [1] Simo Sarkka and Angel F. Garcia-Fernandez. Temporal Parallelization of
        Bayesian Smoothers. _arXiv preprint arXiv:1905.13002_, 2019.
        https://arxiv.org/abs/1905.13002
    """
    with tf.name_scope('kalman_filter'):
        time_indep, time_dep, observation = broadcast_to_full_batch_shape(
            time_indep=TimeIndependentParameters(
                initial_cov=initial_cov,
                initial_scale_tril=None,
                initial_mean=initial_mean),
            time_dep=TimeDependentParameters(
                transition_matrix=transition_matrix,
                transition_cov=transition_cov,
                transition_scale_tril=None,
                transition_mean=transition_mean,
                observation_matrix=observation_matrix,
                observation_cov=observation_cov,
                observation_scale_tril=None,
                observation_mean=observation_mean),
            observation=Observations(y, mask))

        # Prevent any masked NaNs from leaking into gradients.
        if observation.mask is not None:
            observation = Observations(
                y=tf.where(observation.mask[..., None],
                           tf.zeros([], dtype=observation.y.dtype),
                           observation.y),
                mask=observation.mask)

        # Run Kalman filter.
        filtered = tfp_math.scan_associative(
            combine_filter_elements,
            filter_elements(time_indep, time_dep, observation))
        filtered_means = filtered.posterior_mean
        filtered_covs = filtered.posterior_cov

        log_likelihoods = None
        predicted_means, predicted_covs = None, None
        observation_means, observation_covs = None, None
        # Compute derived quantities (predictive distributions, likelihood, etc.).
        if return_all:
            predicted_means = _propagate_mean(
                matrix=time_dep.transition_matrix,
                mean=filtered_means,
                added_mean=time_dep.transition_mean)
            observation_means = _propagate_mean(
                matrix=time_dep.observation_matrix,
                mean=tf.concat([[time_indep.initial_mean], predicted_means[:-1]],
                               axis=0),
                added_mean=time_dep.observation_mean)
            predicted_covs = _propagate_cov(
                matrix=time_dep.transition_matrix,
                cov=filtered_covs,
                added_cov=time_dep.transition_cov)
            observation_covs = _propagate_cov(
                matrix=time_dep.observation_matrix,
                cov=tf.concat([[time_indep.initial_cov], predicted_covs[:-1]],
                              axis=0),
                added_cov=time_dep.observation_cov)
            log_likelihoods = mvn_tril.MultivariateNormalTriL(
                loc=observation_means,
                scale_tril=tf.linalg.cholesky(observation_covs)).log_prob(
                    observation.y)
            if observation.mask is not None:
                log_likelihoods = tf.where(
                    observation.mask,
                    tf.zeros([], dtype=log_likelihoods.dtype),
                    log_likelihoods)

        return FilterResults(log_likelihoods, filtered_means, filtered_covs,
                             predicted_means, predicted_covs,
                             observation_means, observation_covs)
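
# Example usage (a minimal sketch, not part of the original module): the helper
# name `_example_kalman_filter` and the small constant-velocity-style model are
# illustrative assumptions. It shows the documented tensor layout, with the
# time axis leftmost on all time-dependent parameters and no batch dimensions.
def _example_kalman_filter():
    num_timesteps, latent_size, observation_size = 100, 2, 1
    dtype = tf.float32
    transition_matrix = tf.broadcast_to(
        tf.constant([[1., 1.], [0., 1.]], dtype=dtype),
        [num_timesteps, latent_size, latent_size])
    transition_mean = tf.zeros([num_timesteps, latent_size], dtype=dtype)
    transition_cov = 0.01 * tf.broadcast_to(
        tf.eye(latent_size, dtype=dtype),
        [num_timesteps, latent_size, latent_size])
    observation_matrix = tf.broadcast_to(
        tf.constant([[1., 0.]], dtype=dtype),
        [num_timesteps, observation_size, latent_size])
    observation_mean = tf.zeros([num_timesteps, observation_size], dtype=dtype)
    observation_cov = 0.1 * tf.broadcast_to(
        tf.eye(observation_size, dtype=dtype),
        [num_timesteps, observation_size, observation_size])
    initial_mean = tf.zeros([latent_size], dtype=dtype)
    initial_cov = tf.eye(latent_size, dtype=dtype)
    y = tf.random.normal([num_timesteps, observation_size], dtype=dtype)
    mask = tf.zeros([num_timesteps], dtype=tf.bool)  # no missing observations

    # With `return_all=True`, the result also carries log-likelihoods and the
    # predictive and observation moments described in the docstring.
    results = kalman_filter(
        transition_matrix, transition_mean, transition_cov,
        observation_matrix, observation_mean, observation_cov,
        initial_mean, initial_cov, y, mask, return_all=True)
    return results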
def sample_walk(transition_matrix,
                transition_mean,
                transition_scale_tril,
                observation_matrix,
                observation_mean,
                observation_scale_tril,
                initial_mean,
                initial_scale_tril,
                seed=None):
    """Samples from the joint distribution of a linear Gaussian state-space model.

    This method draws samples from the joint prior distribution on latent and
    observed variables in a linear Gaussian state-space model. The sampling is
    parallelized over timesteps, so that sampling a sequence of length
    `num_timesteps` requires only `O(log(num_timesteps))` sequential steps.

    As with a naive sequential implementation, the total FLOP count scales
    linearly in `num_timesteps` (as `O(T + T/2 + T/4 + ...) = O(T)`), so this
    approach does not require extra resources in an asymptotic sense. However,
    it likely has a somewhat larger constant factor, so a sequential sampler
    may be preferred when throughput rather than latency is the highest
    priority.

    Args:
      transition_matrix: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
      transition_mean: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, latent_size]`.
      transition_scale_tril: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
      observation_matrix: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, observation_size, latent_size]`.
      observation_mean: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, observation_size]`.
      observation_scale_tril: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, observation_size, observation_size]`.
      initial_mean: float `Tensor` of shape `[B1, .., BN, latent_size]`.
      initial_scale_tril: float `Tensor` of shape
        `[B1, .., BN, latent_size, latent_size]`.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

    Returns:
      x: float `Tensor` of shape `[num_timesteps, B1, .., BN, latent_size]`.
      y: float `Tensor` of shape
        `[num_timesteps, B1, .., BN, observation_size]`.

    ### Mathematical Details

    The assumed model consists of latent state vectors
    `x[:num_timesteps, :latent_size]` and corresponding observed values
    `y[:num_timesteps, :observation_size]`, governed by the following dynamics:

    ```
    x[0] ~ MultivariateNormal(mean=initial_mean, scale_tril=initial_scale_tril)
    for t in range(num_timesteps - 1):
      x[t + 1] ~ MultivariateNormal(mean=matmul(transition_matrix[t], x[t])
                                         + transition_mean[t],
                                    scale_tril=transition_scale_tril[t])
    # Observed values `y[:num_timesteps]` defined at all timesteps.
    y ~ MultivariateNormal(mean=matmul(observation_matrix, x)
                                + observation_mean,
                           scale_tril=observation_scale_tril)
    ```

    ### Tensor layout

    `Tensor` arguments are expected to have `num_timesteps` as their *leftmost*
    axis, preceding any batch dimensions. This layout is used for internal
    computations, so providing arguments in this form avoids the need for
    potentially-spurious transposition. The returned `Tensor`s also follow this
    layout, for the same reason. Note that this differs from the layout
    mandated by the `tfd.Distribution` API (and exposed by
    `tfd.LinearGaussianStateSpaceModel`), in which the time axis is to the
    right of any batch dimensions; it is the caller's responsibility to perform
    any needed transpositions.

    Note that this method takes `scale_tril` matrices specifying the Cholesky
    factors of covariance matrices, in contrast to
    `tfp.experimental.parallel_filter.kalman_filter`, which takes the
    covariance matrices directly. This is to avoid redundant factorization,
    since the sampling process uses Cholesky factors natively, while the
    filtering updates we implement require covariance matrices.
    In addition, taking `scale_tril` matrices directly ensures that sampling is
    well-defined even when one or more components of the model are
    deterministic (`scale_tril=zeros([...])`).

    Tensor arguments may be specified with partial batch shape, i.e., with
    shape prefix `[num_timesteps, Bk, ..., BN]` for `k > 1`. They will be
    internally reshaped and broadcast to the full batch shape prefix
    `[num_timesteps, B1, ..., BN]`.
    """
    with tf.name_scope('sample_walk'):
        time_indep, time_dep, _ = broadcast_to_full_batch_shape(
            time_indep=TimeIndependentParameters(
                initial_cov=None,
                initial_scale_tril=initial_scale_tril,
                initial_mean=initial_mean),
            time_dep=TimeDependentParameters(
                transition_matrix=transition_matrix,
                transition_cov=None,
                transition_scale_tril=transition_scale_tril,
                transition_mean=transition_mean,
                observation_matrix=observation_matrix,
                observation_cov=None,
                observation_scale_tril=observation_scale_tril,
                observation_mean=observation_mean))
        s1, s2, s3 = samplers.split_seed(seed, n=3)

        # Sample the per-step transition noise and compose the resulting affine
        # updates with an associative scan, so that `updates[t]` maps the
        # initial state `x0` directly to `x[t + 1]`.
        updates = tfp_math.scan_associative(
            combine_walk,
            AffineUpdate(
                transition_matrix=time_dep.transition_matrix[:-1],
                mean=mvn_tril.MultivariateNormalTriL(
                    loc=time_dep.transition_mean[:-1],
                    scale_tril=time_dep.transition_scale_tril[:-1]).sample(
                        seed=s1)))
        x0 = mvn_tril.MultivariateNormalTriL(
            loc=time_indep.initial_mean,
            scale_tril=time_indep.initial_scale_tril).sample(seed=s2)
        x = tf.concat(
            [[x0],
             tf.linalg.matvec(updates.transition_matrix, x0) + updates.mean],
            axis=0)
        y = (tf.linalg.matvec(time_dep.observation_matrix, x) +
             time_dep.observation_mean +
             mvn_tril.MultivariateNormalTriL(
                 scale_tril=time_dep.observation_scale_tril).sample(seed=s3))
        return x, y
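
# Example usage (a minimal sketch, not part of the original module): the helper
# name `_example_sample_walk` and the scalar random-walk model are illustrative
# assumptions. Note that `scale_tril` (Cholesky) factors are passed rather than
# covariance matrices, and the seed is a two-integer stateless seed as accepted
# by `tfp.random.sanitize_seed`.
def _example_sample_walk():
    num_timesteps = 100
    dtype = tf.float32
    eye = tf.eye(1, dtype=dtype)
    transition_matrix = tf.broadcast_to(eye, [num_timesteps, 1, 1])
    transition_mean = tf.zeros([num_timesteps, 1], dtype=dtype)
    transition_scale_tril = 0.1 * tf.broadcast_to(eye, [num_timesteps, 1, 1])
    observation_matrix = tf.broadcast_to(eye, [num_timesteps, 1, 1])
    observation_mean = tf.zeros([num_timesteps, 1], dtype=dtype)
    observation_scale_tril = 0.5 * tf.broadcast_to(eye, [num_timesteps, 1, 1])
    initial_mean = tf.zeros([1], dtype=dtype)
    initial_scale_tril = eye

    # Returns sampled latents `x` and observations `y`, each with the time
    # axis leftmost as described in the docstring.
    x, y = sample_walk(
        transition_matrix, transition_mean, transition_scale_tril,
        observation_matrix, observation_mean, observation_scale_tril,
        initial_mean, initial_scale_tril, seed=(0, 42))
    return x, y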