def _resample_latents(observed_residuals,
                      level_scale,
                      observation_noise_scale,
                      initial_state_prior,
                      slope_scale=None,
                      is_missing=None,
                      sample_shape=(),
                      seed=None):
    """Uses Durbin-Koopman sampling to resample the latent level and slope.

    Durbin-Koopman sampling [1] is an efficient algorithm to sample from the
    posterior latents of a linear Gaussian state space model. This method
    builds a `LocalLevel` or `LocalLinearTrend` state space model from the
    given scales and delegates to its `posterior_sample` method.

    [1] Durbin, J. and Koopman, S.J. (2002) A simple and efficient simulation
        smoother for state space time series analysis.

    Args:
      observed_residuals: Float `Tensor` of shape `[..., num_timesteps]`,
        specifying the centered observations `(x - loc)`.
      level_scale: Float scalar `Tensor` (may contain batch dimensions)
        specifying the standard deviation of the level random walk steps.
      observation_noise_scale: Float scalar `Tensor` (may contain batch
        dimensions) specifying the standard deviation of the observation noise.
      initial_state_prior: instance of `tfd.MultivariateNormalLinearOperator`,
        passed through as the state space model's initial state prior. It must
        be compatible with the latent dimension of the model selected by
        `slope_scale`.
      slope_scale: Optional float scalar `Tensor` (may contain batch
        dimensions) specifying the standard deviation of slope random walk
        steps. If provided, a `LocalLinearTrend` model is used; otherwise a
        `LocalLevel` model is used.
      is_missing: Optional `bool` `Tensor` missingness mask, forwarded as the
        `mask` argument of `posterior_sample`.
      sample_shape: Optional `int` `Tensor` shape of samples to draw.
      seed: `int` `Tensor` of shape `[2]` controlling stateless sampling.

    Returns:
      latents: Float `Tensor` of resampled latents (the level, plus the slope
        when `slope_scale` is provided), of shape
        `[..., num_timesteps, latent_size]`, where `...` concatenates the
        sample shape with any batch shape from `observed_residuals` and the
        model parameters.
    """
    num_timesteps = prefer_static.shape(observed_residuals)[-1]

    # Arguments common to both the local level and local linear trend models.
    ssm_kwargs = {
        'num_timesteps': num_timesteps,
        'initial_state_prior': initial_state_prior,
        'observation_noise_scale': observation_noise_scale,
        'level_scale': level_scale,
    }
    if slope_scale is None:
        ssm = sts.LocalLevelStateSpaceModel(**ssm_kwargs)
    else:
        ssm = sts.LocalLinearTrendStateSpaceModel(
            slope_scale=slope_scale, **ssm_kwargs)

    # `posterior_sample` takes observations shaped
    # `[..., num_timesteps, observation_size]`; these models observe a single
    # scalar per timestep, so append a trailing size-1 axis.
    return ssm.posterior_sample(observed_residuals[..., tf.newaxis],
                                sample_shape=sample_shape,
                                mask=is_missing,
                                seed=seed)
  def resample_level(observed_residuals,
                     level_scale,
                     observation_noise_scale,
                     sample_shape=(),
                     seed=None):
    """Uses Durbin-Koopman sampling to resample the latent level.

    Durbin-Koopman sampling [1] is an efficient algorithm to sample from the
    posterior latents of a linear Gaussian state space model. This method
    implements the algorithm, specialized to the case of a one-dimensional
    latent local level model.

    [1] Durbin, J. and Koopman, S.J. (2002) A simple and efficient simulation
        smoother for state space time series analysis.

    Note: `initial_state_prior` and `is_missing` are not parameters of this
    function; they are read from the enclosing scope.

    Args:
      observed_residuals: Float `Tensor` of shape `[..., num_timesteps]`,
        specifying the centered observations `(x - loc)`.
      level_scale: Float scalar `Tensor` (may contain batch dimensions)
        specifying the standard deviation of the level random walk steps.
      observation_noise_scale: Float scalar `Tensor` (may contain batch
        dimensions) specifying the standard deviation of the observation noise.
      sample_shape: Optional `int` `Tensor` shape of samples to draw.
      seed: `int` `Tensor` of shape `[2]` controlling stateless sampling.
    Returns:
      level: Float `Tensor` resampled latent level, of shape
        `[..., num_timesteps]`, where `...` concatenates the sample shape
        with any batch shape from `observed_residuals`.
    """
    # With `slope_scale=None` the shared helper builds a
    # `LocalLevelStateSpaceModel` from exactly these arguments and returns
    # its posterior samples of shape `[..., num_timesteps, 1]`.
    latents = _resample_latents(observed_residuals,
                                level_scale=level_scale,
                                observation_noise_scale=observation_noise_scale,
                                initial_state_prior=initial_state_prior,
                                slope_scale=None,
                                is_missing=is_missing,
                                sample_shape=sample_shape,
                                seed=seed)
    # The local level latent is one-dimensional; drop its trailing axis.
    return latents[..., 0]