Example #1
import math

import tensorflow as tf

# Helpers assumed from the surrounding module: `make_associative_smoothing_elements`
# builds the per-timestep smoothing elements, `smoothing_operator` is the associative
# combination rule, and `scan_associative` matches `tfp.math.scan_associative`.
def pks(lgssm, ms, Ps, max_parallel=10000):
    """Parallel Kalman smoother over filtered means `ms` and covariances `Ps`."""
    max_num_levels = math.ceil(math.log2(max_parallel))
    _, Fs, Qs, *_ = lgssm
    initial_elements = make_associative_smoothing_elements(Fs, Qs, ms, Ps)
    # Smoothing runs backwards in time, so reverse the elements along the time
    # axis before the associative scan.
    reversed_elements = tuple(
        tf.reverse(elem, axis=[0]) for elem in initial_elements)

    final_elements = scan_associative(smoothing_operator,
                                      reversed_elements,
                                      max_num_levels=max_num_levels)
    # Undo the time reversal to recover smoothed means and covariances.
    return tf.reverse(final_elements[1],
                      axis=[0]), tf.reverse(final_elements[2], axis=[0])
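Both `pks` above and `pkf` below reduce inference to an associative scan: `scan_associative` combines all prefixes of an associative operator in `O(log T)` sequential steps. As a minimal, self-contained sketch of that pattern (assuming only TensorFlow Probability is available, with illustrative names), composing per-step affine updates `x -> a*x + b` is associative in the same way the filtering and smoothing elements are:

```
import tensorflow as tf
import tensorflow_probability as tfp

def affine_op(left, right):
    # Composing x -> a1*x + b1 and then x -> a2*x + b2 gives
    # x -> (a2*a1)*x + (a2*b1 + b2); this composition is associative.
    a1, b1 = left
    a2, b2 = right
    return a2 * a1, a2 * b1 + b2

a = tf.random.uniform([8])
b = tf.random.uniform([8])
# All T prefix compositions in O(log T) sequential steps.
prefix_a, prefix_b = tfp.math.scan_associative(affine_op, (a, b))
```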
Example #2
import math

import tensorflow as tf

# Helpers assumed from the surrounding module: `make_associative_filtering_elements`,
# `filtering_operator`, `scan_associative` (cf. `tfp.math.scan_associative`),
# `mv`/`mm` (batched matvec/matmul), and `MultivariateNormalTriL`
# (cf. `tfp.distributions.MultivariateNormalTriL`).
def pkf(lgssm, observations, return_loglikelihood=False, max_parallel=10000):
    """Parallel Kalman filter via an associative scan over filtering elements."""
    with tf.name_scope("parallel_filter"):
        P0, Fs, Qs, H, R = lgssm
        dtype = P0.dtype
        # Zero prior mean, matching the prior covariance P0.
        m0 = tf.zeros(tf.shape(P0)[0], dtype=dtype)

        max_num_levels = math.ceil(math.log2(max_parallel))

        initial_elements = make_associative_filtering_elements(
            m0, P0, Fs, Qs, H, R, observations)

        final_elements = scan_associative(filtering_operator,
                                          initial_elements,
                                          max_num_levels=max_num_levels)

        if return_loglikelihood:
            filtered_means = tf.concat(
                [tf.expand_dims(m0, 0), final_elements[1][:-1]], axis=0)
            filtered_cov = tf.concat(
                [tf.expand_dims(P0, 0), final_elements[2][:-1]], axis=0)
            predicted_means = mv(Fs, filtered_means)
            predicted_covs = mm(Fs, mm(filtered_cov, Fs,
                                       transpose_b=True)) + Qs
            obs_means = mv(H, predicted_means)
            obs_covs = mm(H, mm(predicted_covs, H,
                                transpose_b=True)) + tf.expand_dims(R, 0)

            dists = MultivariateNormalTriL(obs_means,
                                           tf.linalg.cholesky(obs_covs))
            # TODO: some logic could be added here to avoid handling the covariance of non-NaN models, but this has no impact for GPs.
            logprobs = dists.log_prob(observations)

            logprobs_without_nans = tf.where(tf.math.is_nan(logprobs),
                                             tf.zeros_like(logprobs), logprobs)
            total_log_prob = tf.reduce_sum(logprobs_without_nans)
            return final_elements[1], final_elements[2], total_log_prob
        return final_elements[1], final_elements[2]
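The NaN handling at the end of `pkf` simply zeroes out log-likelihood terms for missing observations before summing. A standalone sketch of that step, with toy means and covariances assumed purely for illustration:

```
import tensorflow as tf
import tensorflow_probability as tfp

obs_means = tf.zeros([3, 2])
obs_covs = tf.eye(2, batch_shape=[3])
observations = tf.constant([[0.1, -0.2],
                            [float("nan"), float("nan")],  # missing timestep
                            [0.3, 0.0]])

dists = tfp.distributions.MultivariateNormalTriL(
    obs_means, tf.linalg.cholesky(obs_covs))
logprobs = dists.log_prob(observations)  # NaN at the missing timestep.
# Zero out missing-observation contributions, as in `pkf`.
total_log_prob = tf.reduce_sum(
    tf.where(tf.math.is_nan(logprobs), tf.zeros_like(logprobs), logprobs))
```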
def kalman_filter(transition_matrix,
                  transition_mean,
                  transition_cov,
                  observation_matrix,
                  observation_mean,
                  observation_cov,
                  initial_mean,
                  initial_cov,
                  y,
                  mask,
                  return_all=True):
    """Infers latent values using a parallel Kalman filter.

  This method computes filtered marginal means and covariances of a linear
  Gaussian state-space model using a parallel message-passing algorithm, as
  described by Sarkka and Garcia-Fernandez [1]. The inference process is
  formulated as a prefix-sum problem that can be efficiently computed by
  `tfp.math.scan_associative`, so that inference for a time series of length
  `num_timesteps` requires only `O(log(num_timesteps))` sequential steps.

  As with a naive sequential implementation, the total FLOP count scales
  linearly in `num_timesteps` (as `O(T + T/2 + T/4 + ...) = O(T)`), so this
  approach does not require extra resources in an asymptotic sense. However, it
  likely has a somewhat larger constant factor, so a sequential filter may be
  preferred when throughput rather than latency is the highest priority.

  Args:
    transition_matrix: float `Tensor` of shape
       `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
    transition_mean: float `Tensor` of shape
       `[num_timesteps, B1, .., BN, latent_size]`.
    transition_cov: float `Tensor` of shape
       `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
    observation_matrix: float `Tensor` of shape
       `[num_timesteps, B1, .., BN, observation_size, latent_size]`.
    observation_mean: float `Tensor` of shape
       `[num_timesteps, B1, .., BN, observation_size]`.
    observation_cov: float `Tensor` of shape
       `[num_timesteps, B1, .., BN, observation_size, observation_size]`.
    initial_mean: float `Tensor` of shape
       `[B1, .., BN, latent_size]`.
    initial_cov: float `Tensor` of shape
       `[B1, .., BN, latent_size, latent_size]`.
    y: float `Tensor` of shape
       `[num_timesteps, B1, .., BN, observation_size]`.
    mask: bool `Tensor` of shape `[num_timesteps, B1, .., BN]`.
    return_all: Python `bool`, whether to compute log-likelihoods and
      predictive and observation distributions. If `False`, only
      `filtered_means` and `filtered_covs` are computed, and `None` is returned
      for the remaining values.
  Returns:
    log_likelihoods: float `Tensor` of shape `[num_timesteps, B1, .., BN]`, such
      that `log_likelihoods[t] = log p(y[t] | y[:t])`.
    filtered_means: float `Tensor` of shape
      `[num_timesteps, B1, .., BN, latent_size]`, such that
      `filtered_means[t] = E[x[t] | y[:t + 1]]`.
    filtered_covs: float `Tensor` of shape
      `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
    predictive_means: float `Tensor` of shape
      `[num_timesteps, B1, .., BN, latent_size]`, such that
      `predictive_means[t] = E[x[t + 1] | y[:t + 1]]`.
    predictive_covs: float `Tensor` of shape
      `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
    observation_means: float `Tensor` of shape
      `[num_timesteps, B1, .., BN, observation_size]`, such that
      `observation_means[t] = E[y[t] | y[:t]]`.
    observation_covs: float `Tensor` of shape
      `[num_timesteps, B1, .., BN, observation_size, observation_size]`.

  ### Mathematical Details

  The assumed model consists of latent state vectors
  `x[:num_timesteps, :latent_size]` and corresponding observed values
  `y[:num_timesteps, :observation_size]`, governed by the following dynamics:

  ```
  x[0] ~ MultivariateNormal(mean=initial_mean, cov=initial_cov)
  for t in range(num_timesteps - 1):
    x[t + 1] ~ MultivariateNormal(mean=matmul(transition_matrix[t],
                                              x[t]) + transition_mean[t],
                                  cov=transition_cov[t])
  # Observed values `y[:num_timesteps]` defined at all timesteps.
  y ~ MultivariateNormal(mean=matmul(observation_matrix, x) + observation_mean,
                         cov=observation_cov)
  ```

  ### Tensor layout

  `Tensor` arguments are expected to have `num_timesteps` as their *leftmost*
  axis, preceding any batch dimensions. This layout is used
  for internal computations, so providing arguments in this form avoids the
  need for potentially-spurious transposition. The returned `Tensor`s also
  follow this layout, for the same reason. Note that this differs from the
  layout mandated by the `tfd.Distribution`
  API (and exposed by `tfd.LinearGaussianStateSpaceModel`), in which the time
  axis is to the right of any batch dimensions; it is the caller's
  responsibility to perform any needed transpositions.

  Tensor arguments may be specified with partial batch shape, i.e., with
  shape prefix `[num_timesteps, Bk, ..., BN]` for `k > 1`. They will be
  internally reshaped and broadcast to the full batch shape prefix
  `[num_timesteps, B1, ..., BN]`.

  ### References

  [1] Simo Sarkka and Angel F. Garcia-Fernandez. Temporal Parallelization of
      Bayesian Smoothers. _arXiv preprint arXiv:1905.13002_, 2019.
      https://arxiv.org/abs/1905.13002

  """
    with tf.name_scope('kalman_filter'):
        time_indep, time_dep, observation = broadcast_to_full_batch_shape(
            time_indep=TimeIndependentParameters(initial_cov=initial_cov,
                                                 initial_scale_tril=None,
                                                 initial_mean=initial_mean),
            time_dep=TimeDependentParameters(
                transition_matrix=transition_matrix,
                transition_cov=transition_cov,
                transition_scale_tril=None,
                transition_mean=transition_mean,
                observation_matrix=observation_matrix,
                observation_cov=observation_cov,
                observation_scale_tril=None,
                observation_mean=observation_mean),
            observation=Observations(y, mask))

        # Prevent any masked NaNs from leaking into gradients.
        if observation.mask is not None:
            observation = Observations(y=tf.where(
                observation.mask[..., None],
                tf.zeros([], dtype=observation.y.dtype), observation.y),
                                       mask=observation.mask)

        # Run Kalman filter.
        filtered = tfp_math.scan_associative(
            combine_filter_elements,
            filter_elements(time_indep, time_dep, observation))
        filtered_means = filtered.posterior_mean
        filtered_covs = filtered.posterior_cov
        log_likelihoods = None
        predicted_means, predicted_covs = None, None
        observation_means, observation_covs = None, None
        # Compute derived quantities (predictive distributions, likelihood, etc.).
        if return_all:
            predicted_means = _propagate_mean(
                matrix=time_dep.transition_matrix,
                mean=filtered_means,
                added_mean=time_dep.transition_mean)
            observation_means = _propagate_mean(
                matrix=time_dep.observation_matrix,
                mean=tf.concat(
                    [[time_indep.initial_mean], predicted_means[:-1]], axis=0),
                added_mean=time_dep.observation_mean)
            predicted_covs = _propagate_cov(matrix=time_dep.transition_matrix,
                                            cov=filtered_covs,
                                            added_cov=time_dep.transition_cov)
            observation_covs = _propagate_cov(
                matrix=time_dep.observation_matrix,
                cov=tf.concat([[time_indep.initial_cov], predicted_covs[:-1]],
                              axis=0),
                added_cov=time_dep.observation_cov)

            log_likelihoods = mvn_tril.MultivariateNormalTriL(
                loc=observation_means,
                scale_tril=tf.linalg.cholesky(observation_covs)).log_prob(
                    observation.y)
            if observation.mask is not None:
                log_likelihoods = tf.where(
                    observation.mask,
                    tf.zeros([], dtype=log_likelihoods.dtype), log_likelihoods)

        return FilterResults(log_likelihoods, filtered_means, filtered_covs,
                             predicted_means, predicted_covs,
                             observation_means, observation_covs)
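This docstring matches the public `tfp.experimental.parallel_filter.kalman_filter` entry point; assuming that API, a minimal call on a toy time-invariant random-walk model might look like the following sketch (all shapes and noise scales chosen for illustration):

```
import tensorflow as tf
import tensorflow_probability as tfp

num_timesteps, latent_size, observation_size = 100, 2, 1
dtype = tf.float32

# Time-invariant random-walk dynamics, tiled along the leading time axis.
(log_liks, filt_means, filt_covs, pred_means, pred_covs,
 obs_means, obs_covs) = tfp.experimental.parallel_filter.kalman_filter(
     transition_matrix=tf.eye(latent_size, batch_shape=[num_timesteps],
                              dtype=dtype),
     transition_mean=tf.zeros([num_timesteps, latent_size], dtype),
     transition_cov=0.1 * tf.eye(latent_size, batch_shape=[num_timesteps],
                                 dtype=dtype),
     observation_matrix=tf.ones(
         [num_timesteps, observation_size, latent_size], dtype),
     observation_mean=tf.zeros([num_timesteps, observation_size], dtype),
     observation_cov=tf.eye(observation_size, batch_shape=[num_timesteps],
                            dtype=dtype),
     initial_mean=tf.zeros([latent_size], dtype),
     initial_cov=tf.eye(latent_size, dtype=dtype),
     y=tf.random.normal([num_timesteps, observation_size], dtype=dtype),
     mask=None)  # No missing observations; the body above handles None.
```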
def sample_walk(transition_matrix,
                transition_mean,
                transition_scale_tril,
                observation_matrix,
                observation_mean,
                observation_scale_tril,
                initial_mean,
                initial_scale_tril,
                seed=None):
    """Samples from the joint distribution of a linear Gaussian state-space model.

  This method draws samples from the joint prior distribution on latent and
  observed variables in a linear Gaussian state-space model. The sampling is
  parallelized over timesteps, so that sampling a sequence of length
  `num_timesteps` requires only `O(log(num_timesteps))` sequential steps.

  As with a naive sequential implementation, the total FLOP count scales
  linearly in `num_timesteps` (as `O(T + T/2 + T/4 + ...) = O(T)`), so this
  approach does not require extra resources in an asymptotic sense. However, it
  likely has a somewhat larger constant factor, so a sequential sampler
  may be preferred when throughput rather than latency is the highest priority.

  Args:
    transition_matrix: float `Tensor` of shape
       `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
    transition_mean: float `Tensor` of shape
       `[num_timesteps, B1, .., BN, latent_size]`.
    transition_scale_tril: float `Tensor` of shape
       `[num_timesteps, B1, .., BN, latent_size, latent_size]`.
    observation_matrix: float `Tensor` of shape
       `[num_timesteps, B1, .., BN, observation_size, latent_size]`.
    observation_mean: float `Tensor` of shape
       `[num_timesteps, B1, .., BN, observation_size]`.
    observation_scale_tril: float `Tensor` of shape
       `[num_timesteps, B1, .., BN, observation_size, observation_size]`.
    initial_mean: float `Tensor` of shape
       `[B1, .., BN, latent_size]`.
    initial_scale_tril: float `Tensor` of shape
       `[B1, .., BN, latent_size, latent_size]`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
  Returns:
    x: float `Tensor` of shape `[num_timesteps, B1, .., BN, latent_size]`.
    y: float `Tensor` of shape `[num_timesteps, B1, .., BN, observation_size]`.

  ### Mathematical Details

  The assumed model consists of latent state vectors
  `x[:num_timesteps, :latent_size]` and corresponding observed values
  `y[:num_timesteps, :observation_size]`, governed by the following dynamics:

  ```
  x[0] ~ MultivariateNormal(mean=initial_mean, scale_tril=initial_scale_tril)
  for t in range(num_timesteps - 1):
    x[t + 1] ~ MultivariateNormal(mean=matmul(transition_matrix[t],
                                              x[t]) + transition_mean[t],
                                  scale_tril=transition_scale_tril[t])
  # Observed values `y[:num_timesteps]` defined at all timesteps.
  y ~ MultivariateNormal(mean=matmul(observation_matrix, x) + observation_mean,
                         scale_tril=observation_scale_tril)
  ```

  ### Tensor layout

  `Tensor` arguments are expected to have `num_timesteps` as their *leftmost*
  axis, preceding any batch dimensions. This layout is used
  for internal computations, so providing arguments in this form avoids the
  need for potentially-spurious transposition. The returned `Tensor`s also
  follow this layout, for the same reason. Note that this differs from the
  layout mandated by the `tfd.Distribution`
  API (and exposed by `tfd.LinearGaussianStateSpaceModel`), in which the time
  axis is to the right of any batch dimensions; it is the caller's
  responsibility to perform any needed transpositions.

  Note that this method takes `scale_tril` matrices specifying the Cholesky
  factors of covariance matrices, in contrast to
  `tfp.experimental.parallel_filter.kalman_filter`, which takes the covariance
  matrices directly. This is to avoid redundant factorization, since the
  sampling process uses Cholesky factors natively, while the filtering updates
  we implement require covariance matrices. In addition, taking `scale_tril`
  matrices directly ensures that sampling is well-defined even when one or more
  components of the model are deterministic (`scale_tril=zeros([...])`).

  Tensor arguments may be specified with partial batch shape, i.e., with
  shape prefix `[num_timesteps, Bk, ..., BN]` for `k > 1`. They will be
  internally reshaped and broadcast to the full batch shape prefix
  `[num_timesteps, B1, ..., BN]`.

  """
    with tf.name_scope('sample_walk'):
        time_indep, time_dep, _ = broadcast_to_full_batch_shape(
            time_indep=TimeIndependentParameters(
                initial_cov=None,
                initial_scale_tril=initial_scale_tril,
                initial_mean=initial_mean),
            time_dep=TimeDependentParameters(
                transition_matrix=transition_matrix,
                transition_cov=None,
                transition_scale_tril=transition_scale_tril,
                transition_mean=transition_mean,
                observation_matrix=observation_matrix,
                observation_cov=None,
                observation_scale_tril=observation_scale_tril,
                observation_mean=observation_mean))

        s1, s2, s3 = samplers.split_seed(seed, n=3)
        updates = tfp_math.scan_associative(
            combine_walk,
            AffineUpdate(
                transition_matrix=time_dep.transition_matrix[:-1],
                mean=mvn_tril.MultivariateNormalTriL(
                    loc=time_dep.transition_mean[:-1],
                    scale_tril=time_dep.transition_scale_tril[:-1]).sample(
                        seed=s1)))
        x0 = mvn_tril.MultivariateNormalTriL(
            loc=time_indep.initial_mean,
            scale_tril=time_indep.initial_scale_tril).sample(seed=s2)

        x = tf.concat(
            [[x0],
             tf.linalg.matvec(updates.transition_matrix, x0) + updates.mean],
            axis=0)
        y = (tf.linalg.matvec(time_dep.observation_matrix, x) +
             time_dep.observation_mean + mvn_tril.MultivariateNormalTriL(
                 scale_tril=time_dep.observation_scale_tril).sample(seed=s3))
        return x, y
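Assuming the same module also exports this function (as `tfp.experimental.parallel_filter.sample_walk`), a toy invocation mirroring the filter sketch above:

```
import tensorflow as tf
import tensorflow_probability as tfp

num_timesteps, latent_size, observation_size = 100, 2, 1
dtype = tf.float32

x, y = tfp.experimental.parallel_filter.sample_walk(
    transition_matrix=tf.eye(latent_size, batch_shape=[num_timesteps],
                             dtype=dtype),
    transition_mean=tf.zeros([num_timesteps, latent_size], dtype),
    transition_scale_tril=0.3 * tf.eye(latent_size,
                                       batch_shape=[num_timesteps],
                                       dtype=dtype),
    observation_matrix=tf.ones(
        [num_timesteps, observation_size, latent_size], dtype),
    observation_mean=tf.zeros([num_timesteps, observation_size], dtype),
    observation_scale_tril=tf.eye(observation_size,
                                  batch_shape=[num_timesteps], dtype=dtype),
    initial_mean=tf.zeros([latent_size], dtype),
    initial_scale_tril=tf.eye(latent_size, dtype=dtype),
    seed=42)
# x: latent path [num_timesteps, latent_size]; y: corresponding observations.
```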