Example No. 1
def _extract_log_probs(num_states, dist):
  """Tabulate log probabilities from a batch of distributions."""

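  # Build a tensor of all state indices, reshaped to [num_states, 1, ..., 1]
  # so that it broadcasts against the distribution's batch shape in `log_prob`.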
  states = tf.reshape(tf.range(num_states),
                      tf.concat([[num_states],
                                 tf.ones_like(dist.batch_shape_tensor())],
                                axis=0))
  return util.move_dimension(dist.log_prob(states), 0, -1)
Example No. 2
  def _log_prob(self, value):
    with tf.control_dependencies(self._runtime_assertions):
      # The argument `value` is a tensor of sequences of observations.
      # `observation_batch_shape` is the shape of that tensor with the
      # sequence part removed.
      # `observation_batch_shape` is then broadcast to the full batch shape
      # to give the `working_shape` that defines the shape of the result.

      observation_batch_shape = tf.shape(
          value)[:-1 - self._underlying_event_rank]
      # value :: observation_batch_shape num_steps observation_event_shape
      working_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                 self.batch_shape_tensor())
      log_init = tf.broadcast_to(self._log_init,
                                 tf.concat([working_shape,
                                            [self._num_states]], axis=0))
      # log_init :: working_shape num_states
      log_transition = self._log_trans

      # `observation_event_shape` is the shape of each sequence of observations
      # emitted by the model.
      observation_event_shape = tf.shape(
          value)[-1 - self._underlying_event_rank:]
      working_obs = tf.broadcast_to(value,
                                    tf.concat([working_shape,
                                               observation_event_shape],
                                              axis=0))
      # working_obs :: working_shape observation_event_shape
      r = self._underlying_event_rank

      # Move index into sequence of observations to front so we can apply
      # tf.foldl
      working_obs = util.move_dimension(working_obs,
                                        -1 - r, 0)[..., tf.newaxis]
      # working_obs :: num_steps working_shape underlying_event_shape
      observation_probs = (
          self._observation_distribution.log_prob(working_obs))

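      # Forward algorithm step, in log space:
      # log_alpha[t] = log(alpha[t-1] @ transition) + log_obs_prob[t].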
      def forward_step(log_prev_step, log_observation):
        return _log_vector_matrix(log_prev_step,
                                  log_transition) + log_observation

      fwd_prob = tf.foldl(forward_step, observation_probs, initializer=log_init)
      # fwd_prob :: working_shape num_states

      log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
      # log_prob :: working_shape

      return log_prob
Example No. 3
def cholesky_update(chol, update_vector, multiplier=1., name=None):
    """Returns cholesky of chol @ chol.T + multiplier * u @ u.T.

  Given a (batch of) lower triangular cholesky factor(s) `chol`, along with a
  (batch of) vector(s) `update_vector`, compute the lower triangular cholesky
  factor of the rank-1 update `chol @ chol.T + multiplier * u @ u.T`, where
  `multiplier` is a (batch of) scalar(s).

  If `chol` has shape `[L, L]`, this has complexity `O(L^2)` compared to the
  naive algorithm which has complexity `O(L^3)`.

  Args:
    chol: Floating-point `Tensor` with shape `[B1, ..., Bn, L, L]`.
      Cholesky decomposition of `mat = chol @ chol.T`. Batch dimensions
      must be broadcastable with `update_vector` and `multiplier`.
    update_vector: Floating-point `Tensor` with shape `[B1, ..., Bn, L]`. Vector
      defining rank-one update. Batch dimensions must be broadcastable with
      `chol` and `multiplier`.
    multiplier: Floating-point `Tensor` with shape `[B1, ..., Bn]`. Scalar
      multiplier to the rank-one update. Batch dimensions must be broadcastable
      with `chol` and `update_vector`. Note that updates with a positive
      `multiplier` are numerically stable, while a negative `multiplier`
      (downdating) succeeds only if the resulting matrix is still positive
      definite.
    name: Optional name for this op.
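
  #### Example

  A minimal sketch of a rank-one update (the values are illustrative; `u`
  plays the role of `update_vector`):

  ```python
  mat = tf.constant([[4., 2.],
                     [2., 3.]])
  chol = tf.linalg.cholesky(mat)   # Lower-triangular factor of `mat`.
  u = tf.constant([1., -1.])
  new_chol = cholesky_update(chol, u, multiplier=0.5)
  # Then `new_chol @ tf.transpose(new_chol)` approximately equals
  # `mat + 0.5 * u @ u.T`.
  ```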

  #### References
  [1] Oswin Krause and Christian Igel. A More Efficient Rank-one Covariance
      Matrix Update for Evolution Strategies. 2015 ACM Conference.
      https://www.researchgate.net/publication/300581419_A_More_Efficient_Rank-one_Covariance_Matrix_Update_for_Evolution_Strategies
  """
    # TODO(b/154638092): Move this functionality into TensorFlow.
    with tf.name_scope(name or 'cholesky_update'):
        dtype = dtype_util.common_dtype([chol, update_vector, multiplier],
                                        dtype_hint=tf.float32)
        chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
        update_vector = tf.convert_to_tensor(update_vector,
                                             name='update_vector',
                                             dtype=dtype)
        multiplier = tf.convert_to_tensor(multiplier,
                                          name='multiplier',
                                          dtype=dtype)

        batch_shape = prefer_static.broadcast_shape(
            prefer_static.broadcast_shape(
                tf.shape(chol)[:-2],
                tf.shape(update_vector)[:-1]), tf.shape(multiplier))
        chol = tf.broadcast_to(
            chol,
            prefer_static.concat(
                [batch_shape, tf.shape(chol)[-2:]], axis=0))
        update_vector = tf.broadcast_to(
            update_vector,
            prefer_static.concat(
                [batch_shape, tf.shape(update_vector)[-1:]], axis=0))
        multiplier = tf.broadcast_to(multiplier, batch_shape)

        chol_diag = tf.linalg.diag_part(chol)

        # The algorithm in [1] is written as a doubly-nested for loop. We can
        # treat the inner loop of Algorithm 3.1 as a vector operation, reducing
        # the whole algorithm to a single loop over columns, which we express
        # as a `tf.scan`.

        # We accumulate `omega` and `b` as defined in Algorithm 3.1, since
        # these are updated at each iteration.

        def compute_new_column(accumulated_quantities, state):
            """Computes the next column of the updated cholesky."""
            _, _, omega, b = accumulated_quantities
            index, diagonal_member, col = state
            omega_at_index = tf.gather(omega, index, axis=-1)

            # Line 4
            new_diagonal_member = tf.math.sqrt(
                tf.math.square(diagonal_member) +
                multiplier / b * tf.math.square(omega_at_index))
            # `scaling_factor` is the same as `gamma` on Line 5.
            scaling_factor = (tf.math.square(diagonal_member) * b +
                              multiplier * tf.math.square(omega_at_index))

            # The following updates are the same as the for loop in lines 6-8.
            omega = omega - (omega_at_index /
                             diagonal_member)[..., tf.newaxis] * col
            new_col = new_diagonal_member[..., tf.newaxis] * (
                col / diagonal_member[..., tf.newaxis] +
                (multiplier * omega_at_index / scaling_factor)[..., tf.newaxis]
                * omega)
            b = b + multiplier * tf.math.square(
                omega_at_index / diagonal_member)
            return new_diagonal_member, new_col, omega, b

        # We will scan over the columns.
        chol = distribution_util.move_dimension(chol,
                                                source_idx=-1,
                                                dest_idx=0)
        chol_diag = distribution_util.move_dimension(chol_diag,
                                                     source_idx=-1,
                                                     dest_idx=0)

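        # The initializer seeds `omega` with `update_vector` and `b` with ones,
        # matching the initialization in Algorithm 3.1 of [1].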
        new_diag, new_chol, _, _ = tf.scan(
            fn=compute_new_column,
            elems=(tf.range(0,
                            tf.shape(chol)[0]), chol_diag, chol),
            initializer=(tf.zeros_like(multiplier), tf.zeros_like(chol[0,
                                                                       ...]),
                         update_vector, tf.ones_like(multiplier)))
        new_chol = distribution_util.move_dimension(new_chol,
                                                    source_idx=0,
                                                    dest_idx=-1)
        new_diag = distribution_util.move_dimension(new_diag,
                                                    source_idx=0,
                                                    dest_idx=-1)
        new_chol = tf.linalg.set_diag(new_chol, new_diag)
        return new_chol
Example No. 4
def decompose_forecast_by_component(model, forecast_dist, parameter_samples):
  """Decompose a forecast distribution into contributions from each component.

  Args:
    model: An instance of `tfp.sts.Sum` representing a structural time series
      model.
    forecast_dist: A `Distribution` instance returned by `tfp.sts.forecast()`.
      (specifically, must be a `tfd.MixtureSameFamily` over a
      `tfd.LinearGaussianStateSpaceModel` parameterized by posterior samples).
    parameter_samples: Python `list` of `Tensors` representing posterior samples
      of model parameters, with shapes `[concat([[num_posterior_draws],
      param.prior.batch_shape, param.prior.event_shape]) for param in
      model.parameters]`. This may optionally also be a map (Python `dict`) of
      parameter names to `Tensor` values.
  Returns:
    component_forecasts: A `collections.OrderedDict` instance mapping
      component StructuralTimeSeries instances (elements of `model.components`)
      to `tfd.Distribution` instances representing the marginal forecast for
      each component. Each distribution has batch and event shape matching
      `forecast_dist` (specifically, the event shape is
      `[num_steps_forecast]`).

  #### Examples

  Suppose we've built a model, fit it to data, and constructed a forecast
  distribution:

  ```python
    day_of_week = tfp.sts.Seasonal(
        num_seasons=7,
        observed_time_series=observed_time_series,
        name='day_of_week')
    local_linear_trend = tfp.sts.LocalLinearTrend(
        observed_time_series=observed_time_series,
        name='local_linear_trend')
    model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                        observed_time_series=observed_time_series)

    num_steps_forecast = 50
    samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series)
    forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                 parameter_samples=samples,
                                 num_steps_forecast=num_steps_forecast)
  ```

  To extract the forecast for individual components, pass the forecast
  distribution into `decompose_forecast_by_component`:

  ```python
    component_forecasts = decompose_forecast_by_component(
      model, forecast_dist, samples)

    # Component mean and stddev have shape `[num_steps_forecast]`.
    day_of_week_effect_mean = component_forecasts[day_of_week].mean()
    day_of_week_effect_stddev = component_forecasts[day_of_week].stddev()
  ```

  Using the component forecasts, we can visualize the uncertainty for each
  component:

  ```python
  from matplotlib import pylab as plt
  import numpy as np
  num_components = len(component_forecasts)
  xs = np.arange(num_steps_forecast)
  fig = plt.figure(figsize=(12, 3 * num_components))
  for i, (component, component_dist) in enumerate(component_forecasts.items()):

    # If in graph mode, replace `.numpy()` with `.eval()` or `sess.run()`.
    component_mean = component_dist.mean().numpy()
    component_stddev = component_dist.stddev().numpy()

    ax = fig.add_subplot(num_components, 1, 1 + i)
    ax.plot(xs, component_mean, lw=2)
    ax.fill_between(xs,
                    component_mean - 2 * component_stddev,
                    component_mean + 2 * component_stddev,
                    alpha=0.5)
    ax.set_title(component.name)
  ```

  """

  with tf.name_scope('decompose_forecast_by_component'):
    try:
      forecast_lgssm = forecast_dist.components_distribution
      forecast_latent_mean, _ = forecast_lgssm._joint_mean()  # pylint: disable=protected-access
      forecast_latent_covs, _ = forecast_lgssm._joint_covariances()  # pylint: disable=protected-access
    except AttributeError as e:
      raise ValueError(
          'Forecast distribution must be a MixtureSameFamily of '
          'LinearGaussianStateSpaceModel distributions, such as returned by '
          '`tfp.sts.forecast()`. (saw exception: {})'.format(e))

    # Since `parameter_samples` will have sample shape `[num_posterior_draws]`,
    # we need to move the `num_posterior_draws` dimension of the forecast
    # moments from the trailing batch dimension, where it's currently put by
    # `sts.forecast`, back to the leading (sample shape) dimension.
    forecast_latent_mean = dist_util.move_dimension(
        forecast_latent_mean, source_idx=-3, dest_idx=0)
    forecast_latent_covs = dist_util.move_dimension(
        forecast_latent_covs, source_idx=-4, dest_idx=0)
    return _decompose_from_posterior_marginals(
        model, forecast_latent_mean, forecast_latent_covs, parameter_samples,
        initial_step=forecast_lgssm.initial_step)
Example No. 5
    def posterior_marginals(self, observations):
        """Compute marginal posterior distribution for each state.

    This function computes, for each time step, the marginal
    conditional probability that the hidden Markov model was in
    each possible state given the observations that were made
    at each time step.
    So if the hidden states are `z[0],...,z[num_steps - 1]` and
    the observations are `x[0],...,x[num_steps - 1]`, then
    this function computes `P(z[i] | x[0],...,x[num_steps - 1])`
    for all `i` from `0` to `num_steps-1`.

    This operation is sometimes called smoothing. It uses a form
    of the forward-backward algorithm.

    Note: the behavior of this function is undefined if the
    `observations` argument represents impossible observations
    from the model.
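
    For example, a minimal sketch of a two-state model (illustrative numbers;
    assumes `import tensorflow_probability as tfp`):

    ```python
    tfd = tfp.distributions

    hmm = tfd.HiddenMarkovModel(
        initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
        transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                       [0.2, 0.8]]),
        observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
        num_steps=7)

    observations = [-2., 0., 2., 4., 6., 8., 10.]
    # `posterior` is a `Categorical` with batch shape `[7]`: one marginal
    # distribution over the two hidden states for each time step.
    posterior = hmm.posterior_marginals(observations)
    ```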

    Args:
      observations: A tensor representing a batch of observations
      made on the hidden Markov model.  The rightmost dimension
      of this tensor gives the steps in a sequence of observations
      from a single sample from the hidden Markov model. The size
      of this dimension should match the `num_steps` parameter
      of the hidden Markov model object. The other dimensions are
      the dimensions of the batch and these are broadcast with
      the hidden Markov model's parameters.

    Returns:
      A `Categorical` distribution object representing the marginal
      probability of the hidden Markov model being in each state at
      each step. The rightmost dimension of the `Categorical`
      distributions batch will equal the `num_steps` parameter
      providing one marginal distribution for each step. The
      other dimensions are the dimensions corresponding to the
      batch of observations.

    Raises:
      ValueError: if rightmost dimension of `observations` does not
      have size `num_steps`.
    """

        with tf.name_scope("posterior_marginals", values=[observations]):
            with tf.control_dependencies(self._runtime_assertions):

                observation_tensor_shape = tf.shape(input=observations)

                with tf.control_dependencies([
                        tf.compat.v1.assert_equal(
                            observation_tensor_shape[-1],
                            self._num_steps,
                            message="Last dimension of `observations` must "
                            "match `num_steps` of `HiddenMarkovModel`")
                ]):
                    observation_batch_shape = observation_tensor_shape[
                        :-1 - self._underlying_event_rank]
                    observation_event_shape = observation_tensor_shape[
                        -1 - self._underlying_event_rank:]

                    working_shape = tf.broadcast_dynamic_shape(
                        observation_batch_shape, self.batch_shape_tensor())
                    log_init = tf.broadcast_to(
                        self._log_init,
                        tf.concat([working_shape, [self._num_states]], axis=0))
                    log_transition = self._log_trans

                    observations = tf.broadcast_to(
                        observations,
                        tf.concat([working_shape, observation_event_shape],
                                  axis=0))
                    observation_rank = tf.rank(observations)
                    underlying_event_rank = self._underlying_event_rank
                    observations = util.move_dimension(
                        observations,
                        observation_rank - underlying_event_rank - 1,
                        0)[..., tf.newaxis]
                    observation_log_probs = self._observation_distribution.log_prob(
                        observations)

                    log_adjoint_prob = tf.zeros_like(log_init)

                    def forward_step(log_previous_step, log_observation):
                        return _log_vector_matrix(
                            log_previous_step,
                            log_transition) + log_observation

                    log_prob = log_init + observation_log_probs[0]

                    forward_log_probs = tf.scan(forward_step,
                                                observation_log_probs[1:],
                                                initializer=log_prob,
                                                name="forward_log_probs")

                    forward_log_probs = tf.concat(
                        [[log_prob], forward_log_probs], axis=0)

                    def backward_step(log_previous_step, log_observation):
                        return _log_matrix_vector(
                            log_transition,
                            log_observation + log_previous_step)

                    backward_log_adjoint_probs = tf.scan(
                        backward_step,
                        observation_log_probs[1:],
                        initializer=log_adjoint_prob,
                        reverse=True,
                        name="backward_log_adjoint_probs")

                    total_log_prob = tf.reduce_logsumexp(
                        input_tensor=forward_log_probs[-1], axis=-1)

                    backward_log_adjoint_probs = tf.concat(
                        [backward_log_adjoint_probs, [log_adjoint_prob]],
                        axis=0)

                    log_likelihoods = forward_log_probs + backward_log_adjoint_probs

                    marginal_log_probs = util.move_dimension(
                        log_likelihoods - total_log_prob[..., tf.newaxis], 0,
                        -2)

                    return categorical.Categorical(logits=marginal_log_probs)
Example No. 6
def forecast(model,
             observed_time_series,
             parameter_samples,
             num_steps_forecast,
             include_observation_noise=True):
  """Construct predictive distribution over future observations.

  Given samples from the posterior over parameters, return the predictive
  distribution over future observations for num_steps_forecast timesteps.

  Args:
    model: An instance of `StructuralTimeSeries` representing a
      time-series model. This represents a joint distribution over
      time-series and their parameters with batch shape `[b1, ..., bN]`.
    observed_time_series: `float` `Tensor` of shape
      `concat([sample_shape, model.batch_shape, [num_timesteps, 1]])` where
      `sample_shape` corresponds to i.i.d. observations, and the trailing `[1]`
      dimension may (optionally) be omitted if `num_timesteps > 1`. May
      optionally be an instance of `tfp.sts.MaskedTimeSeries` including a
      mask `Tensor` to encode the locations of missing observations.
    parameter_samples: Python `list` of `Tensors` representing posterior samples
      of model parameters, with shapes `[concat([[num_posterior_draws],
      param.prior.batch_shape, param.prior.event_shape]) for param in
      model.parameters]`. This may optionally also be a map (Python `dict`) of
      parameter names to `Tensor` values.
    num_steps_forecast: scalar `int` `Tensor` number of steps to forecast.
    include_observation_noise: Python `bool` indicating whether the forecast
      distribution should include uncertainty from observation noise. If `True`,
      the forecast is over future observations; if `False`, it is over future
      values of the latent, noise-free time series.
      Default value: `True`.

  Returns:
    forecast_dist: a `tfd.MixtureSameFamily` instance with event shape
      `[num_steps_forecast, 1]` and batch shape
      `concat([sample_shape, model.batch_shape])`, with `num_posterior_draws`
      mixture components.

  #### Examples

  Suppose we've built a model and fit it to data using HMC:

  ```python
    day_of_week = tfp.sts.Seasonal(
        num_seasons=7,
        observed_time_series=observed_time_series,
        name='day_of_week')
    local_linear_trend = tfp.sts.LocalLinearTrend(
        observed_time_series=observed_time_series,
        name='local_linear_trend')
    model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                        observed_time_series=observed_time_series)

    samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series)
  ```

  Passing the posterior samples into `forecast`, we construct a forecast
  distribution:

  ```python
    forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                     parameter_samples=samples,
                                     num_steps_forecast=50)

    forecast_mean = forecast_dist.mean()[..., 0]  # shape: [50]
    forecast_scale = forecast_dist.stddev()[..., 0]  # shape: [50]
    forecast_samples = forecast_dist.sample(10)[..., 0]  # shape: [10, 50]
  ```

  If using variational inference instead of HMC, we'd construct a forecast using
  samples from the variational posterior:

  ```python
    surrogate_posterior = tfp.sts.build_factored_surrogate_posterior(
      model=model)
    loss_curve = tfp.vi.fit_surrogate_posterior(
      target_log_prob_fn=model.joint_log_prob(observed_time_series),
      surrogate_posterior=surrogate_posterior,
      optimizer=tf.optimizers.Adam(learning_rate=0.1),
      num_steps=200)
    samples = surrogate_posterior.sample(30)

    forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                     parameter_samples=samples,
                                     num_steps_forecast=50)
  ```

  We can visualize the forecast by plotting:

  ```python
    from matplotlib import pylab as plt
    import numpy as np
    def plot_forecast(observed_time_series,
                      forecast_mean,
                      forecast_scale,
                      forecast_samples):
      plt.figure(figsize=(12, 6))

      num_steps = observed_time_series.shape[-1]
      num_steps_forecast = forecast_mean.shape[-1]
      num_steps_train = num_steps - num_steps_forecast

      c1, c2 = (0.12, 0.47, 0.71), (1.0, 0.5, 0.05)
      plt.plot(np.arange(num_steps), observed_time_series,
               lw=2, color=c1, label='ground truth')

      forecast_steps = np.arange(num_steps_train,
                       num_steps_train+num_steps_forecast)
      plt.plot(forecast_steps, forecast_samples.T, lw=1, color=c2, alpha=0.1)
      plt.plot(forecast_steps, forecast_mean, lw=2, ls='--', color=c2,
               label='forecast')
      plt.fill_between(forecast_steps,
                       forecast_mean - 2 * forecast_scale,
                       forecast_mean + 2 * forecast_scale, color=c2, alpha=0.2)

      plt.xlim([0, num_steps])
      plt.legend()

    plot_forecast(observed_time_series,
                  forecast_mean=forecast_mean,
                  forecast_scale=forecast_scale,
                  forecast_samples=forecast_samples)
  ```

  """

  with tf.name_scope('forecast'):
    [
        observed_time_series,
        mask
    ] = sts_util.canonicalize_observed_time_series_with_mask(
        observed_time_series)

    # Run filtering over the observed timesteps to extract the
    # latent state posterior at timestep T+1 (i.e., the final
    # filtering distribution, pushed through the transition model).
    # This is the prior for the forecast model ("today's prior
    # is yesterday's posterior").
    num_observed_steps = dist_util.prefer_static_value(
        tf.shape(observed_time_series))[-2]
    observed_data_ssm = model.make_state_space_model(
        num_timesteps=num_observed_steps, param_vals=parameter_samples)
    (_, _, _, predictive_means, predictive_covs, _, _
    ) = observed_data_ssm.forward_filter(observed_time_series, mask=mask)

    # Build a batch of state-space models over the forecast period. Because
    # we'll use MixtureSameFamily to mix over the posterior draws, we need to
    # do some shenanigans to move the `[num_posterior_draws]` batch dimension
    # from the leftmost to the rightmost side of the model's batch shape.
    # TODO(b/120245392): enhance `MixtureSameFamily` to reduce along an
    # arbitrary axis, and eliminate `move_dimension` calls here.
    parameter_samples = model._canonicalize_param_vals_as_map(parameter_samples)  # pylint: disable=protected-access
    parameter_samples_with_reordered_batch_dimension = {
        param.name: dist_util.move_dimension(
            parameter_samples[param.name],
            0, -(1 + _prefer_static_event_ndims(param.prior)))
        for param in model.parameters}
    forecast_prior = tfd.MultivariateNormalFullCovariance(
        loc=dist_util.move_dimension(predictive_means[..., -1, :], 0, -2),
        covariance_matrix=dist_util.move_dimension(
            predictive_covs[..., -1, :, :], 0, -3))

    # Ugly hack: because we moved `num_posterior_draws` to the trailing (rather
    # than leading) dimension of parameters, the parameter batch shapes no
    # longer broadcast against the `constant_offset` attribute used in `sts.Sum`
    # models. We fix this by manually adding an extra broadcasting dim to
    # `constant_offset` if present.
    # The root cause of this hack is that we mucked with param dimensions above
    # and are now passing params that are 'invalid' in the sense that they don't
    # match the shapes of the model's param priors. The fix (as above) will be
    # to update MixtureSameFamily so we can avoid changing param dimensions
    # altogether.
    # TODO(b/120245392): enhance `MixtureSameFamily` to reduce along an
    # arbitrary axis, and eliminate this hack.
    kwargs = {}
    if hasattr(model, 'constant_offset'):
      kwargs['constant_offset'] = tf.convert_to_tensor(
          value=model.constant_offset,
          dtype=forecast_prior.dtype)[..., tf.newaxis, :]

    if not include_observation_noise:
      parameter_samples_with_reordered_batch_dimension[
          'observation_noise_scale'] = tf.zeros_like(
              parameter_samples_with_reordered_batch_dimension[
                  'observation_noise_scale'])

    # We assume that any STS model that has a `constant_offset` attribute
    # will allow it to be overridden as a kwarg. This is currently just
    # `sts.Sum`.
    # TODO(b/120245392): when kwargs hack is removed, switch back to calling
    # the public version of `_make_state_space_model`.
    forecast_ssm = model._make_state_space_model(  # pylint: disable=protected-access
        num_timesteps=num_steps_forecast,
        param_map=parameter_samples_with_reordered_batch_dimension,
        initial_state_prior=forecast_prior,
        initial_step=num_observed_steps,
        **kwargs)

    num_posterior_draws = dist_util.prefer_static_value(
        forecast_ssm.batch_shape_tensor())[-1]
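    # Mix uniformly over the posterior draws: zero logits give each of the
    # `num_posterior_draws` component state-space models equal weight.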
    return tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(
            logits=tf.zeros([num_posterior_draws], dtype=forecast_ssm.dtype)),
        components_distribution=forecast_ssm)
Example No. 7
    def _sample_n(self, n, seed=None):
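        """Draws `n` observation sequences from the hidden Markov model."""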
        init_seed, scan_seed, observation_seed = samplers.split_seed(
            seed, n=3, salt='HiddenMarkovModel')

        transition_batch_shape = (
            self.transition_distribution.batch_shape_tensor())
        num_states = transition_batch_shape[-1]

        batch_shape = self.batch_shape_tensor()
        batch_size = tf.reduce_prod(batch_shape)

        # The batch sizes of the underlying initial distributions and
        # transition distributions might not match the batch size of
        # the HMM distribution.
        # As a result we need to ask for more samples from the
        # underlying distributions and then reshape the results into
        # the correct batch size for the HMM.
        init_repeat = (
            tf.reduce_prod(batch_shape) //
            tf.reduce_prod(self._initial_distribution.batch_shape_tensor()))
        init_state = self._initial_distribution.sample(n * init_repeat,
                                                       seed=init_seed)
        init_state = tf.reshape(init_state, [n, batch_size])
        # init_state :: n batch_size

        transition_repeat = (tf.reduce_prod(batch_shape) //
                             tf.reduce_prod(transition_batch_shape[:-1]))

        init_shape = init_state.shape

        def generate_step(state_and_seed, _):
            """Take a single step in Markov chain."""
            state, seed = state_and_seed
            sample_seed, next_seed = samplers.split_seed(seed)

            gen = self._transition_distribution.sample(n * transition_repeat,
                                                       seed=sample_seed)
            # gen :: (n * transition_repeat) transition_batch

            new_states = tf.reshape(gen, [n, batch_size, num_states])

            # new_states :: n batch_size num_states

            old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)

            # old_states_one_hot :: n batch_size num_states

            result = tf.reduce_sum(old_states_one_hot * new_states, axis=-1)
            # `generate_step` must preserve the shape of the state tensor
            # because the transition matrix is square, but TensorFlow may not
            # be able to infer this statically, so we set the shape explicitly.
            tensorshape_util.set_shape(result, init_shape)
            return result, next_seed

        def _scan_multiple_steps():
            """Take multiple steps with tf.scan."""
            dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
            hidden_states, _ = tf.scan(generate_step,
                                       dummy_index,
                                       initializer=(init_state, scan_seed))

            # TODO(b/115618503): add/use prepend_initializer to tf.scan
            return tf.concat([[init_state], hidden_states], axis=0)

        hidden_states = prefer_static.cond(self._num_steps > 1,
                                           _scan_multiple_steps,
                                           lambda: init_state[tf.newaxis, ...])

        hidden_one_hot = tf.one_hot(hidden_states,
                                    num_states,
                                    dtype=self._observation_distribution.dtype)
        # hidden_one_hot :: num_steps n batch_size num_states

        # The observation distribution batch size might not match
        # the required batch size so as with the initial and
        # transition distributions we generate more samples and
        # reshape.
        observation_repeat = (batch_size // tf.reduce_prod(
            self._observation_distribution.batch_shape_tensor()[:-1]))

        possible_observations = self._observation_distribution.sample(
            [self._num_steps, observation_repeat * n], seed=observation_seed)

        inner_shape = self._observation_distribution.event_shape_tensor()

        # possible_observations :: num_steps (observation_repeat * n)
        #                          observation_batch[:-1] num_states inner_shape

        possible_observations = tf.reshape(
            possible_observations,
            tf.concat(
                [[self._num_steps, n], batch_shape, [num_states], inner_shape],
                axis=0))

        # possible_observations :: steps n batch_size num_states inner_shape

        hidden_one_hot = tf.reshape(
            hidden_one_hot,
            tf.concat([[self._num_steps, n], batch_shape, [num_states],
                       tf.ones_like(inner_shape)],
                      axis=0))

        # hidden_one_hot :: steps n batch_size num_states "inner_shape"

        observations = tf.reduce_sum(hidden_one_hot * possible_observations,
                                     axis=-1 - tf.size(inner_shape))

        # observations :: steps n batch_size inner_shape

        observations = distribution_util.move_dimension(
            observations, 0, 1 + tf.size(batch_shape))

        # returned :: n batch_shape steps inner_shape

        return observations
    def posterior_marginals(self, observations, mask=None, name=None):
        """Compute marginal posterior distribution for each state.

    This function computes, for each time step, the marginal
    conditional probability that the hidden Markov model was in
    each possible state given the observations that were made
    at each time step.
    So if the hidden states are `z[0],...,z[num_steps - 1]` and
    the observations are `x[0], ..., x[num_steps - 1]`, then
    this function computes `P(z[i] | x[0], ..., x[num_steps - 1])`
    for all `i` from `0` to `num_steps - 1`.

    This operation is sometimes called smoothing. It uses a form
    of the forward-backward algorithm.

    Note: the behavior of this function is undefined if the
    `observations` argument represents impossible observations
    from the model.

    Args:
      observations: A tensor representing a batch of observations
        made on the hidden Markov model.  The rightmost dimension of this tensor
        gives the steps in a sequence of observations from a single sample from
        the hidden Markov model. The size of this dimension should match the
        `num_steps` parameter of the hidden Markov model object. The other
        dimensions are the dimensions of the batch and these are broadcast with
        the hidden Markov model's parameters.
      mask: optional bool-type `Tensor` with rightmost dimension matching
        `num_steps`, indicating which observations the result of this
        function should be conditioned on. Observations for which the mask
        is `True` are not used. If `mask` is `None` then all of the
        observations are used. The `mask` dimensions to the left of the last
        are broadcast with the HMM batch shape as well as with the
        observations.
      name: Python `str` name prefixed to ops created by this method.
        Default value: "posterior_marginals".

    Returns:
      posterior_marginal: A `Categorical` distribution object representing the
        marginal probability of the hidden Markov model being in each state at
        each step. The rightmost dimension of the `Categorical` distributions
        batch will equal the `num_steps` parameter providing one marginal
        distribution for each step. The other dimensions are the dimensions
        corresponding to the batch of observations.

    Raises:
      ValueError: if rightmost dimension of `observations` does not
      have size `num_steps`.
    """

        with tf.name_scope(name or "posterior_marginals"):
            with tf.control_dependencies(self._runtime_assertions):
                observation_tensor_shape = tf.shape(observations)
                mask_tensor_shape = tf.shape(
                    mask) if mask is not None else None

                with self._observation_mask_shape_preconditions(
                        observation_tensor_shape, mask_tensor_shape):
                    observation_log_probs = self._observation_log_probs(
                        observations, mask)
                    log_prob = self._log_init + observation_log_probs[0]
                    log_transition = self._log_trans
                    log_adjoint_prob = tf.zeros_like(log_prob)

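                    # Forward pass of the forward-backward algorithm: a
                    # `tf.scan` over the per-step observation log-probabilities.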
                    def _scan_multiple_steps_forwards():
                        def forward_step(log_previous_step,
                                         log_prob_observation):
                            return _log_vector_matrix(
                                log_previous_step,
                                log_transition) + log_prob_observation

                        forward_log_probs = tf.scan(forward_step,
                                                    observation_log_probs[1:],
                                                    initializer=log_prob,
                                                    name="forward_log_probs")
                        return tf.concat([[log_prob], forward_log_probs],
                                         axis=0)

                    forward_log_probs = prefer_static.cond(
                        self._num_steps > 1, _scan_multiple_steps_forwards,
                        lambda: tf.convert_to_tensor([log_prob]))

                    total_log_prob = tf.reduce_logsumexp(forward_log_probs[-1],
                                                         axis=-1)

                    def _scan_multiple_steps_backwards():
                        """Perform `scan` operation when `num_steps` > 1."""
                        def backward_step(log_previous_step,
                                          log_prob_observation):
                            return _log_matrix_vector(
                                log_transition,
                                log_prob_observation + log_previous_step)

                        backward_log_adjoint_probs = tf.scan(
                            backward_step,
                            observation_log_probs[1:],
                            initializer=log_adjoint_prob,
                            reverse=True,
                            name="backward_log_adjoint_probs")

                        return tf.concat(
                            [backward_log_adjoint_probs, [log_adjoint_prob]],
                            axis=0)

                    backward_log_adjoint_probs = prefer_static.cond(
                        self._num_steps > 1, _scan_multiple_steps_backwards,
                        lambda: tf.convert_to_tensor([log_adjoint_prob]))

                    log_likelihoods = forward_log_probs + backward_log_adjoint_probs

                    marginal_log_probs = distribution_util.move_dimension(
                        log_likelihoods - total_log_prob[..., tf.newaxis], 0,
                        -2)

                    return categorical.Categorical(logits=marginal_log_probs)
Example No. 9
def infer_trajectories(observations,
                       initial_state_prior,
                       transition_fn,
                       observation_fn,
                       num_particles,
                       initial_state_proposal=None,
                       proposal_fn=None,
                       resample_criterion_fn=ess_below_threshold,
                       rejuvenation_kernel_fn=None,
                       num_transitions_per_observation=1,
                       num_steps_state_history_to_pass=None,
                       num_steps_observation_history_to_pass=None,
                       seed=None,
                       name=None):  # pylint: disable=g-doc-args
    """Use particle filtering to sample from the posterior over trajectories.

  ${particle_filter_arg_str}
  Returns:
    trajectories: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing unbiased samples from the posterior distribution
      `p(latent_states | observations)`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each timestep `t`. Note that
      (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}

  #### Examples

  **Tracking unknown position and velocity**: Let's consider tracking an object
  moving in a one-dimensional space. We'll define a dynamical system
  by specifying an `initial_state_prior`, a `transition_fn`,
  and `observation_fn`.

  The structure of the latent state space is determined by the prior
  distribution. Here, we'll define a state space that includes the object's
  current position and velocity:

  ```python
  initial_state_prior = tfd.JointDistributionNamed({
      'position': tfd.Normal(loc=0., scale=1.),
      'velocity': tfd.Normal(loc=0., scale=0.1)})
  ```

  The `transition_fn` specifies the evolution of the system. It should
  return a distribution over latent states of the same structure as the prior.
  Here, we'll assume that the position evolves according to the velocity,
  with a small random drift, and the velocity also changes slowly, following
  a random drift:

  ```python
  def transition_fn(_, previous_state):
    return tfd.JointDistributionNamed({
        'position': tfd.Normal(
            loc=previous_state['position'] + previous_state['velocity'],
            scale=0.1),
        'velocity': tfd.Normal(loc=previous_state['velocity'], scale=0.01)})
  ```

  The `observation_fn` specifies the process by which the system is observed
  at each time step. Let's suppose we observe only a noisy version of the
  current position.

  ```python
    def observation_fn(_, state):
      return tfd.Normal(loc=state['position'], scale=0.1)
  ```

  Now let's track our object. Suppose we've been given observations
  corresponding to an initial position of `0.4` and constant velocity of `0.01`:

  ```python
  # Generate simulated observations over 40 timesteps.
  observed_positions = tfd.Normal(loc=tf.linspace(0.4, 0.8, 40),
                                  scale=0.1).sample()

  # Run particle filtering to sample plausible trajectories.
  (trajectories,  # {'position': [40, 1000], 'velocity': [40, 1000]}
   lps) = tfp.experimental.mcmc.infer_trajectories(
            observations=observed_positions,
            initial_state_prior=initial_state_prior,
            transition_fn=transition_fn,
            observation_fn=observation_fn,
            num_particles=1000)
  ```

  For all `i`, `trajectories['position'][:, i]` is a sample from the
  posterior over position sequences, given the observations:
  `p(state[0:T] | observations[0:T])`. Often, the sampled trajectories
  will be highly redundant in their earlier timesteps, because most
  of the initial particles have been discarded through resampling
  (this problem is known as 'particle degeneracy'; see section 3.5 of
  [Doucet and Johansen][1]).
  In such cases it may be useful to also consider the series of *filtering*
  distributions `p(state[t] | observations[:t])`, in which each latent state
  is inferred conditioned only on observations up to that point in time; these
  may be computed using `tfp.experimental.mcmc.particle_filter`.

  #### References

  [1] Arnaud Doucet and Adam M. Johansen. A tutorial on particle
      filtering and smoothing: Fifteen years later.
      _Handbook of nonlinear filtering_, 12(656-704), 2009.
      https://www.stats.ox.ac.uk/~doucet/doucet_johansen_tutorialPF2011.pdf

  """
    with tf.name_scope(name or 'infer_trajectories') as name:
        seed = SeedStream(seed, 'infer_trajectories')
        (particles, log_weights, parent_indices,
         step_log_marginal_likelihoods) = particle_filter(
             observations=observations,
             initial_state_prior=initial_state_prior,
             transition_fn=transition_fn,
             observation_fn=observation_fn,
             num_particles=num_particles,
             initial_state_proposal=initial_state_proposal,
             proposal_fn=proposal_fn,
             resample_criterion_fn=resample_criterion_fn,
             rejuvenation_kernel_fn=rejuvenation_kernel_fn,
             num_transitions_per_observation=num_transitions_per_observation,
             num_steps_state_history_to_pass=num_steps_state_history_to_pass,
             num_steps_observation_history_to_pass=(
                 num_steps_observation_history_to_pass),
             seed=seed,
             name=name)
        weighted_trajectories = reconstruct_trajectories(
            particles, parent_indices)

        # Resample all steps of the trajectories using the final weights.
        resample_indices = categorical.Categorical(
            dist_util.move_dimension(log_weights[-1, ...],
                                     source_idx=0,
                                     dest_idx=-1)).sample(num_particles,
                                                          seed=seed)
        trajectories = tf.nest.map_structure(
            lambda x: _batch_gather(x, resample_indices, axis=1),
            weighted_trajectories)

        return trajectories, step_log_marginal_likelihoods
Example No. 10
def resample_minimum_variance(log_probs,
                              event_size,
                              sample_shape,
                              seed=None,
                              name=None):
    """Minimum variance resampler for sequential Monte Carlo.

  This function is based on Algorithm #2 in [Maskell et al. (2006)][1].

  Args:
    log_probs: A tensor-valued batch of discrete log probability distributions.
    event_size: the dimension of the vector considered a single draw.
    sample_shape: the `sample_shape` determining the number of draws.
    seed: Python `int` used to seed calls to `tf.random.*`.
      Default value: None (i.e. no seed).
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'resample_minimum_variance'`).

  Returns:
    resampled_indices: The result is similar to sampling with
    ```python
    expanded_sample_shape = tf.concat([[event_size], sample_shape], axis=-1)
    tfd.Categorical(logits=log_probs).sample(expanded_sample_shape)
    ```
    but with values sorted along the first axis. It can be considered to be
    sampling events made up of a length-`event_size` vector of draws from
    the `Categorical` distribution. However, although the elements of
    this event have the appropriate marginal distribution, they are not
    independent of each other. Instead they have been chosen so as to form
    a good representative sample, suitable for use with Sequential Monte
    Carlo algorithms.
    The sortedness is an unintended side effect of the algorithm that is
    harmless in the context of simple SMC algorithms.
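
  #### Example

  A minimal sketch with four particles and no batch or sample dimensions
  (the weights and the seed are illustrative):

  ```python
  log_weights = tf.math.log([0.1, 0.2, 0.3, 0.4])   # particle axis leading
  new_indices = resample_minimum_variance(
      log_probs=log_weights,
      event_size=4,        # draw four new particle indices
      sample_shape=[],
      seed=42)
  # `new_indices` has shape [4]; each entry is an index in {0, 1, 2, 3},
  # sorted (non-decreasing) along the leading axis.
  ```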

  #### References
  [1]: S. Maskell, B. Alun-Jones and M. Macleod. A Single Instruction Multiple
       Data Particle Filter.
       In 2006 IEEE Nonlinear Statistical Signal Processing Workshop.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf

  """
    with tf.name_scope(name or 'resample_minimum_variance') as name:
        log_probs = tf.convert_to_tensor(log_probs, dtype_hint=tf.float32)
        log_probs = dist_util.move_dimension(log_probs,
                                             source_idx=0,
                                             dest_idx=-1)

        batch_shape = prefer_static.shape(log_probs)[:-1]
        working_shape = prefer_static.concat([sample_shape, batch_shape],
                                             axis=-1)
        log_cdf = tf.math.cumulative_logsumexp(log_probs[..., :-1], axis=-1)
        # Each resampling requires a single uniform random variable
        offset = uniform.Uniform(low=tf.constant(0., log_cdf.dtype),
                                 high=tf.constant(1., log_cdf.dtype)).sample(
                                     working_shape, seed=seed)[..., tf.newaxis]
        # It is possible for numerical error to result in a cumulative
        # sum that exceeds 1 so we need to clip.
        markers = prefer_static.cast(
            tf.floor(event_size * tf.math.exp(log_cdf) + offset), tf.int32)
        indices = markers[..., tf.newaxis]
        updates = tf.ones(prefer_static.shape(indices)[:-1], dtype=tf.int32)
        scatter_shape = prefer_static.concat([working_shape, [event_size + 1]],
                                             axis=-1)
        batch_dims = (prefer_static.rank_from_shape(sample_shape) +
                      prefer_static.rank_from_shape(batch_shape))
        x = _scatter_nd_batch(indices,
                              updates,
                              scatter_shape,
                              batch_dims=batch_dims)

        resampled = tf.cumsum(x, axis=-1)[..., :-1]
        resampled = dist_util.move_dimension(resampled,
                                             source_idx=-1,
                                             dest_idx=0)
        return resampled
    def test_basic_example_time_dependent_batched(self):
        batch_shape = (2, 3)
        ndim = 7  # Dimension of latent space
        mdim = 5  # Dimension of observation space
        nsteps = 9

        Batches = collections.namedtuple('Batches', [
            'initial_mean', 'initial_cov', 'transition_matrix',
            'transition_mean', 'transition_cov', 'observation_matrix',
            'observation_mean', 'observation_cov', 'mask'
        ])

        def batch_generator():
            # Skipping 'mask' case because it isn't used in sample generation.
            for skip in range(8):
                batch_list = (skip * [()] + [batch_shape] +
                              (9 - skip - 1) * [()])
                yield Batches(*batch_list)

        # Test the broadcasting by ensuring each parameter individually
        # can be broadcast up to the full batch size.
        seed = test_util.test_seed(sampler_type='stateless')
        for batches in batch_generator():
            iter_seed, seed = samplers.split_seed(seed, n=2, salt='')
            s = samplers.split_seed(iter_seed, n=10, salt='')
            initial_mean = _random_vector(ndim,
                                          batches.initial_mean,
                                          dtype=self.dtype,
                                          seed=s[0])
            initial_cov = _random_variance(ndim,
                                           batches.initial_cov,
                                           dtype=self.dtype,
                                           seed=s[1])
            transition_matrix = 0.2 * _random_matrix(  # Avoid blowup (eigvals > 1).
                ndim,
                ndim, (nsteps, ) + batches.transition_matrix,
                dtype=self.dtype,
                seed=s[2])
            transition_mean = _random_vector(ndim, (nsteps, ) +
                                             batches.transition_mean,
                                             dtype=self.dtype,
                                             seed=s[3])
            transition_cov = _random_variance(ndim, (nsteps, ) +
                                              batches.transition_cov,
                                              dtype=self.dtype,
                                              seed=s[4])
            observation_matrix = _random_matrix(mdim,
                                                ndim, (nsteps, ) +
                                                batches.observation_matrix,
                                                dtype=self.dtype,
                                                seed=s[5])
            observation_mean = _random_vector(mdim, (nsteps, ) +
                                              batches.observation_mean,
                                              dtype=self.dtype,
                                              seed=s[6])
            observation_cov = _random_variance(mdim, (nsteps, ) +
                                               batches.observation_cov,
                                               dtype=self.dtype,
                                               seed=s[7])
            mask = _random_mask((nsteps, ) + batches.mask,
                                dtype=tf.bool,
                                seed=s[8])

            _, y = parallel_kalman_filter_lib.sample_walk(
                transition_matrix=transition_matrix,
                transition_mean=transition_mean,
                transition_scale_tril=tf.linalg.cholesky(transition_cov),
                observation_matrix=observation_matrix,
                observation_mean=observation_mean,
                observation_scale_tril=tf.linalg.cholesky(observation_cov),
                initial_mean=initial_mean,
                initial_scale_tril=tf.linalg.cholesky(initial_cov),
                seed=s[9])

            my_filter_results = parallel_kalman_filter_lib.kalman_filter(
                transition_matrix=transition_matrix,
                transition_mean=transition_mean,
                transition_cov=transition_cov,
                observation_matrix=observation_matrix,
                observation_mean=observation_mean,
                observation_cov=observation_cov,
                initial_mean=initial_mean,
                initial_cov=initial_cov,
                y=y,
                mask=mask)
            ((my_log_likelihoods, my_filtered_means, my_filtered_covs,
              my_predicted_means, my_predicted_covs, my_observation_means,
              my_observation_covs), y, mask) = tf.nest.map_structure(
                  lambda x, r: distribution_util.move_dimension(x, 0, -r),
                  (my_filter_results, y, mask),
                  (type(my_filter_results)(1, 2, 3, 2, 3, 2, 3), 2, 1))

            # pylint: disable=g-long-lambda,cell-var-from-loop
            mvn = tfd.MultivariateNormalFullCovariance
            dist = tfd.LinearGaussianStateSpaceModel(
                num_timesteps=nsteps,
                transition_matrix=lambda t: tf.linalg.LinearOperatorFullMatrix(
                    tf.gather(transition_matrix, t, axis=0)),
                transition_noise=lambda t: mvn(
                    loc=tf.gather(transition_mean, t, axis=0),
                    covariance_matrix=tf.gather(transition_cov, t, axis=0)),
                observation_matrix=lambda t: tf.linalg.LinearOperatorFullMatrix(
                    tf.gather(observation_matrix, t, axis=0)),
                observation_noise=lambda t: mvn(
                    loc=tf.gather(observation_mean, t, axis=0),
                    covariance_matrix=tf.gather(observation_cov, t, axis=0)),
                initial_state_prior=mvn(loc=initial_mean,
                                        covariance_matrix=initial_cov),
                experimental_parallelize=False
            )  # Compare against sequential filter.
            # pylint: enable=g-long-lambda,cell-var-from-loop

            (log_likelihoods, filtered_means, filtered_covs, predicted_means,
             predicted_covs, observation_means,
             observation_covs) = dist.forward_filter(y, mask)

            rtol = (1e-6 if self.dtype == np.float64 else 1e-1)
            atol = (1e-6 if self.dtype == np.float64 else 1e-3)
            self.assertAllClose(log_likelihoods,
                                my_log_likelihoods,
                                rtol=rtol,
                                atol=atol)

            rtol = (1e-6 if self.dtype == np.float64 else 1e-3)
            atol = (1e-6 if self.dtype == np.float64 else 1e-3)
            self.assertAllClose(filtered_means,
                                my_filtered_means,
                                rtol=rtol,
                                atol=atol)
            self.assertAllClose(filtered_covs,
                                my_filtered_covs,
                                rtol=rtol,
                                atol=atol)
            self.assertAllClose(predicted_means,
                                my_predicted_means,
                                rtol=rtol,
                                atol=atol)
            self.assertAllClose(predicted_covs,
                                my_predicted_covs,
                                rtol=rtol,
                                atol=atol)
            self.assertAllClose(observation_means,
                                my_observation_means,
                                rtol=rtol,
                                atol=atol)
            self.assertAllClose(observation_covs,
                                my_observation_covs,
                                rtol=rtol,
                                atol=atol)
Ejemplo n.º 12
0
def _sample_multinomial_as_iterated_binomial(num_samples, num_classes, probs,
                                             num_trials, dtype, seed):
    """Sample a multinomial by drawing one binomial sample per class.

  The batch shape is given by broadcasting num_trials with
  remove_last_dimension(probs).

  The loop over binomial samples is a `tf.while_loop`, thus supporting a dynamic
  number of classes.

  Args:
    num_samples: Singleton integer Tensor: number of multinomial samples to
      draw.
    num_classes: Singleton integer Tensor: number of classes.
    probs: Floating Tensor with last dimension `num_classes`, of normalized
      probabilities per class.
    num_trials: Tensor of number of categorical trials each multinomial consists
      of.  num_trials[..., tf.newaxis] must broadcast with probs.
    dtype: dtype at which to emit samples.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    samples: Tensor of given dtype and shape [num_samples] + batch_shape +
      [num_classes].
  """
    with tf.name_scope('draw_sample'):
        # We `convert_to_tensor(num_classes)` here to avoid unstacking inside
        # `split_seed`.  We can't take advantage of the Python-list code path
        # anyway because the index at which we will take the seed is a Tensor.
        seeds = samplers.split_seed(seed,
                                    n=ps.convert_to_shape_tensor(num_classes),
                                    salt='multinomial_draw_sample')

        def fn(i, num_trials, consumed_prob, accum):
            """Sample the counts for one class using binomial."""
            probs_here = tf.gather(probs, i, axis=-1)
            binomial_probs = tf.clip_by_value(
                probs_here / (1. - consumed_prob), 0, 1)
            seed_here = tf.gather(seeds, i, axis=0)
            binom = binomial.Binomial(total_count=num_trials,
                                      probs=binomial_probs)
            # Not passing `num_samples` to `binom.sample`, as it is already
            # reflected in `num_trials.shape`.
            sample = binom.sample(seed=seed_here)
            accum = accum.write(i, tf.cast(sample, dtype=dtype))
            return i + 1, num_trials - sample, consumed_prob + probs_here, accum

        num_trials = tf.cast(num_trials, probs.dtype)
        # Pre-broadcast with probs
        num_trials = num_trials + tf.zeros_like(probs[..., 0])
        # Pre-enlarge for different output samples
        num_trials = _replicate_along_left(num_trials, num_samples)
        i = tf.constant(0)
        consumed_prob = tf.zeros_like(probs[..., 0])
        accum = tf.TensorArray(dtype,
                               size=num_classes,
                               element_shape=num_trials.shape)
        _, num_trials_left, _, accum = tf.while_loop(
            cond=lambda index, _0, _1, _2: tf.less(index, num_classes - 1),
            body=fn,
            loop_vars=(i, num_trials, consumed_prob, accum))
        # Force the last iteration to put all the trials into the last bucket,
        # because probs[..., -1] / (1. - consumed_prob) might numerically not be 1.
        # Also saves one iteration around the while_loop and one run of the binomial
        # sampler.
        accum = accum.write(num_classes - 1,
                            tf.cast(num_trials_left, dtype=dtype))
        # This stop_gradient is necessary to prevent spurious zero gradients coming
        # from b/138796859, and a spurious gradient through num_trials_left.
        results = tf.stop_gradient(accum.stack())
        return distribution_util.move_dimension(results, 0, -1)
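
To make the chained-binomial scheme above concrete, here is a minimal NumPy sketch of the same idea (the helper name and the use of `numpy.random` are illustrative, not part of the library): each class takes a binomial share of the trials that remain, with its probability renormalized by the mass not yet consumed, and the last class absorbs whatever is left.

```python
import numpy as np

def multinomial_via_binomials(num_trials, probs, rng=None):
  # Illustrative sketch of the chained-binomial decomposition above.
  rng = rng or np.random.default_rng()
  counts = np.zeros(len(probs), dtype=np.int64)
  trials_left = num_trials
  remaining_mass = 1.0
  for i in range(len(probs) - 1):
    p = np.clip(probs[i] / remaining_mass, 0., 1.)
    counts[i] = rng.binomial(trials_left, p)
    trials_left -= counts[i]
    remaining_mass -= probs[i]
  counts[-1] = trials_left  # Force the remaining trials into the last class.
  return counts

# multinomial_via_binomials(100, [0.2, 0.3, 0.5]) always sums to 100 and is
# distributed as Multinomial(100, [0.2, 0.3, 0.5]).
```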
Ejemplo n.º 13
0
def move_particles_to_rightmost_batch_dim(x, event_shape):
    """Moves the particles axis (dim 0) to be the rightmost batch dimension."""
    ndims = prefer_static.rank_from_shape(prefer_static.shape(x))
    event_ndims = prefer_static.rank_from_shape(event_shape)
    return dist_util.move_dimension(x, 0, ndims - event_ndims - 1)
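
A quick shape check of the helper above (toy shapes; assumes the module-level `prefer_static` / `dist_util` imports that the snippet itself relies on):

```python
import tensorflow as tf

x = tf.zeros([7, 4, 5, 3])   # [num_particles, batch, batch] + event_shape
moved = move_particles_to_rightmost_batch_dim(x, event_shape=[3])
print(moved.shape)           # (4, 5, 7, 3): particles are now the rightmost batch dim.
```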
Ejemplo n.º 14
0
def one_step_predictive(model,
                        posterior_samples,
                        num_forecast_steps=0,
                        original_mean=0.,
                        original_scale=1.,
                        thin_every=10):
  """Constructs a one-step-ahead predictive distribution at every timestep.

  Unlike the generic `tfp.sts.one_step_predictive`, this method uses the
  latent levels from Gibbs sampling to efficiently construct a predictive
  distribution that mixes over posterior samples. The predictive distribution
  may also include additional forecast steps.

  This method returns the predictive distributions for each timestep given
  previous timesteps and sampled model parameters, `p(observed_time_series[t] |
  observed_time_series[:t], weights, observation_noise_scale)`. Note that the
  posterior values of the weights and noise scale will in general be informed
  by observations from all timesteps *including the step being predicted*, so
  this is not a strictly kosher probabilistic quantity, but in general we assume
  that it's close, i.e., that the step being predicted had very small individual
  impact on the overall parameter posterior.

  Args:
    model: A `tfd.sts.StructuralTimeSeries` model instance. This must be of the
      form constructed by `build_model_for_gibbs_sampling`.
    posterior_samples: A `GibbsSamplerState` instance in which each element is a
      `Tensor` with initial dimension of size `num_samples`.
    num_forecast_steps: Python `int` number of additional forecast steps to
      append.
      Default value: `0`.
    original_mean: Optional scalar float `Tensor`, added to the predictive
      distribution to undo the effect of input normalization.
      Default value: `0.`
    original_scale: Optional scalar float `Tensor`, used to rescale the
      predictive distribution to undo the effect of input normalization.
      Default value: `1.`
    thin_every: Optional Python `int` factor by which to thin the posterior
      samples, to reduce complexity of the predictive distribution. For example,
      if `thin_every=10`, every `10`th sample will be used.
      Default value: `10`.
  Returns:
    predictive_dist: A `tfd.MixtureSameFamily` instance of event shape
      `[num_timesteps + num_forecast_steps]` representing the predictive
      distribution of each timestep given previous timesteps.
  """
  dtype = dtype_util.common_dtype([
      posterior_samples.level_scale,
      posterior_samples.observation_noise_scale,
      posterior_samples.level,
      original_mean,
      original_scale], dtype_hint=tf.float32)
  num_observed_steps = prefer_static.shape(posterior_samples.level)[-1]

  original_mean = tf.convert_to_tensor(original_mean, dtype=dtype)
  original_scale = tf.convert_to_tensor(original_scale, dtype=dtype)
  thinned_samples = tf.nest.map_structure(lambda x: x[::thin_every],
                                          posterior_samples)

  if prefer_static.rank_from_shape(  # If no slope was inferred, treat as zero.
      prefer_static.shape(thinned_samples.slope)) <= 1:
    thinned_samples = thinned_samples._replace(
        slope=tf.zeros_like(thinned_samples.level),
        slope_scale=tf.zeros_like(thinned_samples.level_scale))

  num_steps_from_last_observation = tf.concat([
      tf.ones([num_observed_steps], dtype=dtype),
      tf.range(1, num_forecast_steps + 1, dtype=dtype)], axis=0)

  # The local linear trend model expects that the level at step t + 1 is equal
  # to the level at step t, plus the slope at time t - 1,
  # plus transition noise of scale 'level_scale' (which we account for below).
  if num_forecast_steps > 0:
    num_batch_dims = prefer_static.rank_from_shape(
        prefer_static.shape(thinned_samples.level)) - 2
    # All else equal, the current level will remain stationary.
    forecast_level = tf.tile(thinned_samples.level[..., -1:],
                             tf.concat([tf.ones([num_batch_dims + 1],
                                                dtype=tf.int32),
                                        [num_forecast_steps]], axis=0))
    # If the model includes slope, the level will steadily increase.
    forecast_level += (thinned_samples.slope[..., -1:] *
                       tf.range(1., num_forecast_steps + 1.,
                                dtype=forecast_level.dtype))

  level_pred = tf.concat([thinned_samples.level[..., :1],  # t == 0
                          (thinned_samples.level[..., :-1] +
                           thinned_samples.slope[..., :-1])  # 1 <= t < T
                         ] + (
                             [forecast_level]
                             if num_forecast_steps > 0 else []),
                         axis=-1)

  design_matrix = _get_design_matrix(
      model).to_dense()[:num_observed_steps + num_forecast_steps]
  regression_effect = tf.linalg.matvec(design_matrix, thinned_samples.weights)

  y_mean = ((level_pred + regression_effect) *
            original_scale[..., tf.newaxis] + original_mean[..., tf.newaxis])

  # To derive a forecast variance, including slope uncertainty, let
  # `r[:k]` be iid Gaussian RVs with variance `level_scale**2` and `s[:k - 1]`
  # be iid Gaussian RVs with variance `slope_scale**2`. Then the forecast level
  # at step `T + k` can be written as
  #   (level[T] +               # Last known level.
  #    r[0] + ... + r[k - 1] +  # Sum of the k random walk terms on the level.
  #    slope[T] * k +           # Contribution from the last known slope.
  #    (k - 1) * s[0] +         # Contributions from random walk terms on slope.
  #    (k - 2) * s[1] +
  #    ... +
  #    1 * s[k - 2])
  # which has variance of
  #  (level_scale**2 * k +
  #   slope_scale**2 * ( (k - 1)**2 +
  #                      (k - 2)**2 +
  #                      ... + 1 ))
  # Here the `slope_scale` coefficient is the `k - 1`th square pyramidal
  # number [1], which is given by
  #  (k - 1) * k * (2 * k - 1) / 6.
  #
  # [1] https://en.wikipedia.org/wiki/Square_pyramidal_number
  variance_from_level = (thinned_samples.level_scale[..., tf.newaxis]**2 *
                         num_steps_from_last_observation)
  variance_from_slope = thinned_samples.slope_scale[..., tf.newaxis]**2 * (
      (num_steps_from_last_observation - 1) *
      num_steps_from_last_observation *
      (2 * num_steps_from_last_observation - 1)) / 6.
  y_scale = (original_scale * tf.sqrt(
      thinned_samples.observation_noise_scale[..., tf.newaxis]**2 +
      variance_from_level + variance_from_slope))

  num_posterior_draws = prefer_static.shape(y_mean)[0]
  return tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(
          logits=tf.zeros([num_posterior_draws], dtype=y_mean.dtype)),
      components_distribution=tfd.Normal(
          loc=dist_util.move_dimension(y_mean, 0, -1),
          scale=dist_util.move_dimension(y_scale, 0, -1)))
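
The square-pyramidal-number identity used in `variance_from_slope` above is easy to sanity-check; a minimal sketch (the helper name is illustrative only):

```python
import numpy as np

def slope_variance_coefficient(k):
  # Closed form used above: 1**2 + 2**2 + ... + (k - 1)**2.
  return (k - 1) * k * (2 * k - 1) / 6.

for k in range(1, 10):
  assert np.isclose(slope_variance_coefficient(k),
                    sum(j**2 for j in range(1, k)))
```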
    def _sample_n(self, n, seed=None):
        with tf.control_dependencies(self._runtime_assertions):
            strm = SeedStream(seed, salt="HiddenMarkovModel")

            num_states = self._num_states

            batch_shape = self.batch_shape_tensor()
            batch_size = tf.reduce_prod(batch_shape)

            # The batch sizes of the underlying initial distributions and
            # transition distributions might not match the batch size of
            # the HMM distribution.
            # As a result we need to ask for more samples from the
            # underlying distributions and then reshape the results into
            # the correct batch size for the HMM.
            init_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._initial_distribution.batch_shape_tensor()))
            init_state = self._initial_distribution.sample(n * init_repeat,
                                                           seed=strm())
            init_state = tf.reshape(init_state, [n, batch_size])
            # init_state :: n batch_size

            transition_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._transition_distribution.batch_shape_tensor()[:-1]))

            def generate_step(state, _):
                """Take a single step in Markov chain."""

                gen = self._transition_distribution.sample(n *
                                                           transition_repeat,
                                                           seed=strm())
                # gen :: (n * transition_repeat) transition_batch

                new_states = tf.reshape(gen, [n, batch_size, num_states])

                # new_states :: n batch_size num_states

                old_states_one_hot = tf.one_hot(state,
                                                num_states,
                                                dtype=tf.int32)

                # old_states :: n batch_size num_states

                return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

            def _scan_multiple_steps():
                """Take multiple steps with tf.scan."""
                dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
                if seed is not None:
                    # Force parallel_iterations to 1 to ensure reproducibility
                    # b/139210489
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state,
                                            parallel_iterations=1)
                else:
                    # Invoke default parallel_iterations behavior
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state)

                # TODO(b/115618503): add/use prepend_initializer to tf.scan
                return tf.concat([[init_state], hidden_states], axis=0)

            hidden_states = prefer_static.cond(
                self._num_steps > 1, _scan_multiple_steps,
                lambda: init_state[tf.newaxis, ...])

            hidden_one_hot = tf.one_hot(
                hidden_states,
                num_states,
                dtype=self._observation_distribution.dtype)
            # hidden_one_hot :: num_steps n batch_size num_states

            # The observation distribution batch size might not match
            # the required batch size so as with the initial and
            # transition distributions we generate more samples and
            # reshape.
            observation_repeat = (batch_size // tf.reduce_prod(
                self._observation_distribution.batch_shape_tensor()[:-1]))

            possible_observations = self._observation_distribution.sample(
                [self._num_steps, observation_repeat * n], seed=strm())

            inner_shape = self._observation_distribution.event_shape

            # possible_observations :: num_steps (observation_repeat * n)
            #                          observation_batch[:-1] num_states inner_shape

            possible_observations = tf.reshape(
                possible_observations,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           inner_shape],
                          axis=0))

            # possible_observations :: steps n batch_size num_states inner_shape

            hidden_one_hot = tf.reshape(
                hidden_one_hot,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           tf.ones_like(inner_shape)],
                          axis=0))

            # hidden_one_hot :: steps n batch_size num_states "inner_shape"

            observations = tf.reduce_sum(hidden_one_hot *
                                         possible_observations,
                                         axis=-1 - tf.size(inner_shape))

            # observations :: steps n batch_size inner_shape

            observations = distribution_util.move_dimension(
                observations, 0, 1 + tf.size(batch_shape))

            # returned :: n batch_shape steps inner_shape

            return observations
    def _observation_log_probs(self, observations, mask):
        """Compute and shape tensor of log probs associated with observations.."""

        # Let E be the underlying event shape
        #     M the number of steps in the HMM
        #     N the number of states of the HMM
        #
        # Then the incoming observations have shape
        #
        # observations : batch_o [M] E
        #
        # and the mask (if present) has shape
        #
        # mask : batch_m [M]
        #
        # Let this HMM distribution have batch shape batch_d
        # We need to broadcast all three of these batch shapes together
        # into the shape batch.
        #
        # We need to move the step dimension to the first dimension to make
        # them suitable for folding or scanning over.
        #
        # When we call `log_prob` for our observations we need to
        # do this for each state the observation could correspond to.
        # We do this by expanding the dimensions by 1 so we end up with:
        #
        # observations : [M] batch [1] [E]
        #
        # After calling `log_prob` we get
        #
        # observation_log_probs : [M] batch [N]
        #
        # We wish to use `mask` to select from this so we also
        # reshape and broadcast it up to shape
        #
        # mask : [M] batch [N]

        observation_tensor_shape = tf.shape(observations)
        observation_batch_shape = observation_tensor_shape[
            :-1 - self._underlying_event_rank]
        observation_event_shape = observation_tensor_shape[
            -1 - self._underlying_event_rank:]

        if mask is not None:
            mask_tensor_shape = tf.shape(mask)
            mask_batch_shape = mask_tensor_shape[:-1]

        batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                 self.batch_shape_tensor())

        if mask is not None:
            batch_shape = tf.broadcast_dynamic_shape(batch_shape,
                                                     mask_batch_shape)
        observations = tf.broadcast_to(
            observations,
            tf.concat([batch_shape, observation_event_shape], axis=0))
        observation_rank = tf.rank(observations)
        underlying_event_rank = self._underlying_event_rank
        observations = distribution_util.move_dimension(
            observations, observation_rank - underlying_event_rank - 1, 0)
        observations = tf.expand_dims(observations,
                                      observation_rank - underlying_event_rank)
        observation_log_probs = self._observation_distribution.log_prob(
            observations)

        if mask is not None:
            mask = tf.broadcast_to(
                mask, tf.concat([batch_shape, [self._num_steps]], axis=0))
            mask = distribution_util.move_dimension(mask, -1, 0)
            observation_log_probs = tf.where(
                mask[..., tf.newaxis], tf.zeros_like(observation_log_probs),
                observation_log_probs)

        return observation_log_probs
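
The mask handling at the end treats `True` entries as missing observations by zeroing their per-state log probabilities, so masked steps contribute a factor of one to the forward recursion; a minimal illustration with toy values:

```python
import tensorflow as tf

# [num_steps, num_states] log probs for a 3-step, 2-state toy example.
observation_log_probs = tf.math.log([[0.1, 0.9], [0.5, 0.5], [0.7, 0.3]])
mask = tf.constant([False, True, False])   # Step 1 is missing.
masked = tf.where(mask[..., tf.newaxis],
                  tf.zeros_like(observation_log_probs),
                  observation_log_probs)
# Row 1 becomes [0., 0.], i.e. log(1), so the masked step does not reweight
# the hidden-state distribution.
```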
Ejemplo n.º 17
0
def resample_independent(log_probs,
                         event_size,
                         sample_shape,
                         seed=None,
                         name=None):
    """Categorical resampler for sequential Monte Carlo.

  This function is based on Algorithm #1 in the paper
  [Maskell et al. (2006)][1].

  Args:
    log_probs: A tensor-valued batch of discrete log probability distributions.
    event_size: the dimension of the vector considered a single draw.
    sample_shape: the `sample_shape` determining the number of draws.
    seed: Python `int` used to seed calls to `tf.random.*`.
      Default value: `None` (i.e., no seed).
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'resample_independent'`).

  Returns:
    resampled_indices: The result is similar to sampling with
    ```python
    expanded_sample_shape = tf.concat([[event_size], sample_shape], axis=-1)
    tfd.Categorical(logits=log_probs).sample(expanded_sample_shape)
    ```
    but with values sorted along the first axis. It can be considered to be
    sampling events made up of a length-`event_size` vector of draws from
    the `Categorical` distribution. For large input values this function should
    give better performance than using `Categorical`.
    The sortedness is an unintended side effect of the algorithm that is
    harmless in the context of simple SMC algorithms.

  #### References

  [1]: S. Maskell, B. Alun-Jones and M. Macleod. A Single Instruction Multiple
       Data Particle Filter.
       In 2006 IEEE Nonlinear Statistical Signal Processing Workshop.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf

  """
    with tf.name_scope(name or 'resample_independent') as name:
        log_probs = tf.convert_to_tensor(log_probs, dtype_hint=tf.float32)
        log_probs = dist_util.move_dimension(log_probs,
                                             source_idx=0,
                                             dest_idx=-1)

        batch_shape = prefer_static.shape(log_probs)[:-1]
        num_markers = prefer_static.shape(log_probs)[-1]

        # `working_shape` specifies the total number of events
        # we will be generating.
        working_shape = prefer_static.concat([sample_shape, batch_shape],
                                             axis=0)
        # `points_shape` is the shape of the final result.
        points_shape = prefer_static.concat([working_shape, [event_size]],
                                            axis=0)
        # `markers_shape` is the shape of the markers we temporarily insert.
        markers_shape = prefer_static.concat([working_shape, [num_markers]],
                                             axis=0)
        # Generate one real point for each particle.
        log_points = -exponential.Exponential(
            rate=tf.constant(1.0, dtype=log_probs.dtype)).sample(points_shape,
                                                                 seed=seed)

        # We divide up the unit interval [0, 1] according to the provided
        # probability distributions using `cumsum`.
        # At the end of each division we place a 'marker'.
        # We generate random points on the unit interval.
        # We sort the combination of points and markers. The number
        # of points between the markers defining a division gives the number
        # of samples we require in that division.
        # For example, suppose `probs` is `[0.2, 0.3, 0.5]`.
        # We divide up `[0, 1]` using 3 markers:
        #
        #     |     |          |
        # 0.  0.2   0.5        1.0  <- markers
        #
        # Suppose we generate four points: [0.1, 0.25, 0.9, 0.75]
        # After sorting the combination we get:
        #
        # 0.1  0.25     0.75 0.9    <- points
        #  *  | *   |    *    *|
        # 0.   0.2 0.5         1.0  <- markers
        #
        # We have one sample in the first category, one in the second and
        # two in the last.
        #
        # All of these computations are carried out in batched form.
        markers = prefer_static.concat([
            tf.zeros(points_shape, dtype=tf.int32),
            tf.ones(markers_shape, dtype=tf.int32)
        ],
                                       axis=-1)
        log_marker_positions = tf.broadcast_to(
            tf.math.cumulative_logsumexp(log_probs, axis=-1), markers_shape)
        log_points_and_markers = prefer_static.concat(
            [log_points, log_marker_positions], axis=-1)
        indices = tf.argsort(log_points_and_markers, axis=-1, stable=False)
        sorted_markers = tf.gather_nd(
            markers,
            indices[..., tf.newaxis],
            batch_dims=(prefer_static.rank_from_shape(sample_shape) +
                        prefer_static.rank_from_shape(batch_shape)))
        markers_and_samples = prefer_static.cast(tf.cumsum(sorted_markers,
                                                           axis=-1),
                                                 dtype=tf.int32)
        markers_and_samples = tf.minimum(markers_and_samples, num_markers - 1)
        # Collect up samples, omitting markers.
        resampled = tf.reshape(
            markers_and_samples[tf.equal(sorted_markers, 0)], points_shape)
        resampled = dist_util.move_dimension(resampled,
                                             source_idx=-1,
                                             dest_idx=0)
        return resampled
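
The marker construction described in the comments can be reproduced in a few lines of NumPy (single batch, plain probabilities instead of log space, using the example values from the comment):

```python
import numpy as np

probs = np.array([0.2, 0.3, 0.5])
points = np.array([0.1, 0.25, 0.9, 0.75])    # Uniform draws on [0, 1].
markers = np.cumsum(probs)                    # Division boundaries 0.2, 0.5, 1.0.
is_marker = np.concatenate([np.zeros(len(points), dtype=int),
                            np.ones(len(markers), dtype=int)])
order = np.argsort(np.concatenate([points, markers]))
category = np.cumsum(is_marker[order])        # Running count of markers passed.
resampled = category[is_marker[order] == 0]   # Keep the points, drop the markers.
print(resampled)                              # [0 1 2 2]: counts 1, 1, 2 per class.
```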
Ejemplo n.º 18
0
  def _sample_n(self, n, seed=None):
    with tf.control_dependencies(self._runtime_assertions):
      seed = seed_stream.SeedStream(seed, salt="HiddenMarkovModel")

      num_states = self._num_states

      batch_shape = self.batch_shape_tensor()
      batch_size = tf.reduce_prod(batch_shape)

      # The batch sizes of the underlying initial distributions and
      # transition distributions might not match the batch size of
      # the HMM distribution.
      # As a result we need to ask for more samples from the
      # underlying distributions and then reshape the results into
      # the correct batch size for the HMM.
      init_repeat = (
          tf.reduce_prod(self.batch_shape_tensor()) //
          tf.reduce_prod(self._initial_distribution.batch_shape_tensor()))
      init_state = self._initial_distribution.sample(n * init_repeat,
                                                     seed=seed())
      init_state = tf.reshape(init_state, [n, batch_size])
      # init_state :: n batch_size

      transition_repeat = (
          tf.reduce_prod(self.batch_shape_tensor()) //
          tf.reduce_prod(
              self._transition_distribution.batch_shape_tensor()[:-1]))

      def generate_step(state, _):
        """Take a single step in Markov chain."""

        gen = self._transition_distribution.sample(n * transition_repeat,
                                                   seed=seed())
        # gen :: (n * transition_repeat) transition_batch

        new_states = tf.reshape(gen,
                                [n, batch_size, num_states])

        # new_states :: n batch_size num_states

        old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)

        # old_states :: n batch_size num_states

        return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

      if self._num_steps > 1:
        dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
        hidden_states = tf.scan(generate_step, dummy_index,
                                initializer=init_state)

        # TODO(b/115618503): add/use prepend_initializer to tf.scan
        hidden_states = tf.concat([[init_state],
                                   hidden_states], axis=0)
      else:
        hidden_states = init_state[tf.newaxis, ...]

      # hidden_states :: num_steps n batch_size num_states

      hidden_one_hot = tf.one_hot(hidden_states, num_states,
                                  dtype=self._observation_distribution.dtype)
      # hidden_one_hot :: num_steps n batch_size num_states

      # The observation distribution batch size might not match
      # the required batch size so as with the initial and
      # transition distributions we generate more samples and
      # reshape.
      observation_repeat = (
          batch_size //
          tf.reduce_prod(
              self._observation_distribution.batch_shape_tensor()[:-1]))

      possible_observations = self._observation_distribution.sample(
          [self._num_steps, observation_repeat * n], seed=seed())

      inner_shape = self._observation_distribution.event_shape

      # possible_observations :: num_steps (observation_repeat * n)
      #                          observation_batch[:-1] num_states inner_shape

      possible_observations = tf.reshape(
          possible_observations,
          tf.concat([[self._num_steps, n],
                     batch_shape,
                     [num_states],
                     inner_shape], axis=0))

      # possible_observations :: steps n batch_size num_states inner_shape

      hidden_one_hot = tf.reshape(hidden_one_hot,
                                  tf.concat([[self._num_steps, n],
                                             batch_shape,
                                             [num_states],
                                             tf.ones_like(inner_shape)],
                                            axis=0))

      # hidden_one_hot :: steps n batch_size num_states "inner_shape"

      observations = tf.reduce_sum(hidden_one_hot * possible_observations,
                                   axis=-1 - tf.size(inner_shape))

      # observations :: steps n batch_size inner_shape

      observations = util.move_dimension(observations,
                                         0, 1 + tf.size(batch_shape))

      # returned :: n batch_shape steps inner_shape

      return observations
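
The `one_hot` / `reduce_sum` pattern inside `generate_step` picks, for every chain, the sampled next state that corresponds to its current state; a toy illustration (hand-written values, not from the snippet):

```python
import tensorflow as tf

num_states = 3
state = tf.constant([0, 2])                # Current state of two chains.
new_states = tf.constant([[5, 6, 7],       # Candidate next state for each
                          [8, 9, 10]])     # (chain, current-state) pair.
one_hot = tf.one_hot(state, num_states, dtype=tf.int32)
next_state = tf.reduce_sum(one_hot * new_states, axis=-1)
# next_state == [5, 10]: chain 0 keeps column 0, chain 1 keeps column 2.
```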
Ejemplo n.º 19
0
def bracket_root(objective_fn,
                 dtype=tf.float32,
                 num_points=512,
                 name='bracket_root'):
  """Finds bounds that bracket a root of the objective function.

  This method attempts to return an interval bracketing a root of the objective
  function. It evaluates the objective in parallel at `num_points`
  locations, at exponentially increasing distance from the origin, and returns
  the first pair of adjacent points `[low, high]` such that the objective is
  finite and has a different sign at the two points. If no such pair was
  observed, it returns the trivial interval
  `[np.finfo(dtype).min, np.finfo(dtype).max]` containing all float values of
  the specified `dtype`. If the objective has multiple
  roots, the returned interval will contain at least one (but perhaps not all)
  of the roots.

  Args:
    objective_fn: Python callable for which roots are searched. It must be a
      continuous function that accepts a scalar `Tensor` of type `dtype` and
      returns a `Tensor` of shape `batch_shape`.
    dtype: Optional float `dtype` of inputs to `objective_fn`.
      Default value: `tf.float32`.
    num_points: Optional Python `int` number of points at which to evaluate
      the objective.
      Default value: `512`.
    name: Python `str` name given to ops created by this method.
  Returns:
    low: Float `Tensor` of shape `batch_shape` and dtype `dtype`. Lower bound
      on a root of `objective_fn`.
    high: Float `Tensor` of shape `batch_shape` and dtype `dtype`. Upper bound
      on a root of `objective_fn`.
  """
  with tf.name_scope(name):
    # Build a logarithmic sequence of `num_points` values whose magnitudes
    # range from `exp(-10.)` up to the largest finite value of `dtype`,
    # mirrored about zero.
    dtype_info = np.finfo(dtype_util.as_numpy_dtype(dtype))
    xs_positive = tf.exp(tf.linspace(tf.cast(-10., dtype),
                                     tf.math.log(dtype_info.max),
                                     num_points // 2))
    xs = tf.concat([tf.reverse(-xs_positive, axis=[0]), xs_positive], axis=0)

    # Evaluate the objective at all points. The objective function may return
    # a batch of values (e.g., `objective(x) = x - batch_of_roots`).
    if NUMPY_MODE:
      objective_output_spec = objective_fn(tf.zeros([], dtype=dtype))
    else:
      objective_output_spec = callable_util.get_output_spec(
          objective_fn,
          tf.convert_to_tensor(0., dtype=dtype))
    batch_ndims = tensorshape_util.rank(objective_output_spec.shape)
    if batch_ndims is None:
      raise ValueError('Cannot infer tensor rank of objective values.')
    xs_pad_shape = ps.pad([num_points],
                          paddings=[[0, batch_ndims]],
                          constant_values=1)
    ys = objective_fn(tf.reshape(xs, xs_pad_shape))

    # Find the smallest point where the objective is finite.
    is_finite = tf.math.is_finite(ys)
    ys_transposed = distribution_util.move_dimension(  # For batch gather.
        ys, 0, -1)
    first_finite_value = tf.gather(
        ys_transposed,
        tf.argmax(is_finite, axis=0),  # Index of smallest finite point.
        batch_dims=batch_ndims,
        axis=-1)
    # Select the next point where the objective has a different sign.
    sign_change_idx = tf.argmax(
        tf.not_equal(tf.math.sign(ys),
                     tf.math.sign(first_finite_value)) & is_finite,
        axis=0)
    # If the sign never changes, we can't bracket a root.
    bracketing_failed = tf.equal(sign_change_idx, 0)
    # If the objective's sign is zero, we've found an actual root.
    root_found = tf.equal(tf.gather(tf.math.sign(ys_transposed),
                                    sign_change_idx,
                                    batch_dims=batch_ndims,
                                    axis=-1),
                          0.)
    return _structure_broadcasting_where(
        bracketing_failed,
        # If we didn't detect a sign change, fall back to the trivial interval.
        (dtype_info.min, dtype_info.max),
        # Otherwise, return the points around the sign change, unless we
        # actually evaluated a root, in which case, return the zero-width
        # bracket at that root.
        (tf.gather(xs, tf.where(bracketing_failed | root_found,
                                sign_change_idx,
                                sign_change_idx - 1)),
         tf.gather(xs, sign_change_idx)))
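
A minimal usage sketch for the function above (a batched objective with known roots; assumes only the `bracket_root` definition shown here):

```python
import tensorflow as tf

roots = tf.constant([-3., 0.5, 10.])
low, high = bracket_root(lambda x: x - roots)
# Each interval should contain the corresponding root:
# tf.reduce_all((low <= roots) & (roots <= high))  ==> True
```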
Ejemplo n.º 20
0
def index_remapping_gather(params,
                           indices,
                           axis=0,
                           indices_axis=0,
                           name='index_remapping_gather'):
    """Gather values from `axis` of `params` using `indices_axis` of `indices`.

  The shape of `indices` must broadcast to that of `params` when
  their `indices_axis` and `axis` (respectively) are aligned:

  ```python
  # params.shape:
  [p[0],  ..., ...,         p[axis], ..., ..., p[rank(params) - 1]]
  # indices.shape:
        [i[0], ..., i[indices_axis], ..., i[rank(indices) - 1]]
  ```

  In particular, `params` must have at least as many
  leading dimensions as `indices` (`axis >= indices_axis`), and at least as many
  trailing dimensions (`rank(params) - axis >= rank(indices) - indices_axis`).

  The `result` has the same shape as `params`, except that the dimension
  of size `p[axis]` is replaced by one of size `i[indices_axis]`:

  ```python
  # result.shape:
  [p[0],  ..., ..., i[indices_axis], ..., ..., p[rank(params) - 1]]
  ```

  In the case where `rank(params) == 5`, `rank(indices) == 3`, `axis = 2`, and
  `indices_axis = 1`, the result is given by

   ```python
   # alignment is:                       v axis
   # params.shape    ==   [p[0], p[1], p[2], p[3], p[4]]
   # indices.shape   ==         [i[0], i[1], i[2]]
   #                                     ^ indices_axis
   result[i, j, k, l, m] = params[i, j, indices[j, k, l], l, m]
  ```

  Args:
    params:  `N-D` `Tensor` (`N > 0`) from which to gather values.
      Number of dimensions must be known statically.
    indices: `Tensor` with values in `{0, ..., params.shape[axis] - 1}`, whose
      shape broadcasts to that of `params` as described above.
    axis: Python `int` axis of `params` from which to gather.
    indices_axis: Python `int` axis of `indices` to align with the `axis`
      over which `params` is gathered.
    name: String name for scoping created ops.

  Returns:
    `Tensor` composed of elements of `params`.

  Raises:
    ValueError: If shape/rank requirements are not met.
  """
    with tf.name_scope(name):
        params = tf.convert_to_tensor(params, name='params')
        indices = tf.convert_to_tensor(indices, name='indices')

        params_ndims = tensorshape_util.rank(params.shape)
        indices_ndims = tensorshape_util.rank(indices.shape)
        # `axis` dtype must match ndims, which are 64-bit Python ints.
        axis = tf.get_static_value(tf.convert_to_tensor(axis, dtype=tf.int64))
        indices_axis = tf.get_static_value(
            tf.convert_to_tensor(indices_axis, dtype=tf.int64))

        if params_ndims is None:
            raise ValueError(
                'Rank of `params` must be known statically. This is due to '
                'tf.gather not accepting a `Tensor` for `batch_dims`.')

        if axis is None:
            raise ValueError(
                '`axis` must be known statically. This is due to '
                'tf.gather not accepting a `Tensor` for `batch_dims`.')

        if indices_axis is None:
            raise ValueError(
                '`indices_axis` must be known statically. This is due to '
                'tf.gather not accepting a `Tensor` for `batch_dims`.')

        if indices_axis > axis:
            raise ValueError(
                '`indices_axis` should be <= `axis`, but was {} > {}'.format(
                    indices_axis, axis))

        if params_ndims < 1:
            raise ValueError(
                'Rank of params should be `> 0`, but was {}'.format(
                    params_ndims))

        if indices_ndims is not None and indices_ndims < 1:
            raise ValueError(
                'Rank of indices should be `> 0`, but was {}'.format(
                    indices_ndims))

        if (indices_ndims is not None
                and (indices_ndims - indices_axis > params_ndims - axis)):
            raise ValueError(
                '`rank(params) - axis` ({} - {}) must be >= `rank(indices) - '
                'indices_axis` ({} - {}), but was not.'.format(
                    params_ndims, axis, indices_ndims, indices_axis))

        # `tf.gather` requires the axis to be the rightmost batch ndim. So, we
        # transpose `indices_axis` to be the rightmost dimension of `indices`...
        transposed_indices = dist_util.move_dimension(indices,
                                                      source_idx=indices_axis,
                                                      dest_idx=-1)

        # ... and `axis` to be the corresponding (aligned as in the docstring)
        # dimension of `params`.
        broadcast_indices_ndims = indices_ndims + (axis - indices_axis)
        transposed_params = dist_util.move_dimension(
            params, source_idx=axis, dest_idx=broadcast_indices_ndims - 1)

        # Next we broadcast `indices` so that its shape has the same prefix as
        # `params.shape`.
        transposed_params_shape = prefer_static.shape(transposed_params)
        result_shape = prefer_static.concat([
            transposed_params_shape[:broadcast_indices_ndims - 1],
            prefer_static.shape(indices)[indices_axis:indices_axis + 1],
            transposed_params_shape[broadcast_indices_ndims:]
        ],
                                            axis=0)
        broadcast_indices = prefer_static.broadcast_to(
            transposed_indices, result_shape[:broadcast_indices_ndims])

        result_t = tf.gather(transposed_params,
                             broadcast_indices,
                             batch_dims=broadcast_indices_ndims - 1,
                             axis=broadcast_indices_ndims - 1)
        return dist_util.move_dimension(result_t,
                                        source_idx=broadcast_indices_ndims - 1,
                                        dest_idx=axis)
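
A small concrete case with the default `axis=0` and `indices_axis=0` (toy values), where the behavior reduces to an ordinary row gather:

```python
import tensorflow as tf

params = tf.constant([[1., 2.], [3., 4.], [5., 6.]])   # shape [3, 2]
indices = tf.constant([2, 0, 0, 1])                     # shape [4]
result = index_remapping_gather(params, indices)
# result == [[5., 6.], [1., 2.], [1., 2.], [3., 4.]], shape [4, 2],
# i.e. the same as tf.gather(params, indices, axis=0) in this simple case.
```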
Ejemplo n.º 21
0
  def __init__(self,
               num_timesteps,
               design_matrix,
               drift_scale,
               initial_state_prior,
               observation_noise_scale=0.,
               initial_step=0,
               validate_args=False,
               allow_nan_stats=True,
               name=None):
    """State space model for a dynamic linear regression.

    Args:
      num_timesteps: Scalar `int` `Tensor` number of timesteps to model
        with this distribution.
      design_matrix: float `Tensor` of shape `concat([batch_shape,
        [num_timesteps, num_features]])`.
      drift_scale: Scalar (any additional dimensions are treated as batch
        dimensions) `float` `Tensor` indicating the standard deviation of the
        latent state transitions.
      initial_state_prior: instance of `tfd.MultivariateNormal`
        representing the prior distribution on latent states.  Must have
        event shape `[num_features]`.
      observation_noise_scale: Scalar (any additional dimensions are
        treated as batch dimensions) `float` `Tensor` indicating the standard
        deviation of the observation noise.
        Default value: `0.`.
      initial_step: scalar `int` `Tensor` specifying the starting timestep.
        Default value: `0`.
      validate_args: Python `bool`. Whether to validate input with asserts. If
        `validate_args` is `False`, and the inputs are invalid, correct behavior
        is not guaranteed.
        Default value: `False`.
      allow_nan_stats: Python `bool`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
        Default value: `True`.
      name: Python `str` name prefixed to ops created by this class.
        Default value: 'DynamicLinearRegressionStateSpaceModel'.

    """

    with tf.name_scope(
        name or 'DynamicLinearRegressionStateSpaceModel') as name:
      dtype = dtype_util.common_dtype(
          [design_matrix, drift_scale, initial_state_prior])

      design_matrix = tf.convert_to_tensor(
          value=design_matrix, name='design_matrix', dtype=dtype)
      design_matrix_with_time_in_first_dim = distribution_util.move_dimension(
          design_matrix, -2, 0)

      drift_scale = tf.convert_to_tensor(
          value=drift_scale, name='drift_scale', dtype=dtype)

      observation_noise_scale = tf.convert_to_tensor(
          value=observation_noise_scale,
          name='observation_noise_scale',
          dtype=dtype)

      num_features = prefer_static.shape(design_matrix)[-1]

      def observation_matrix_fn(t):
        observation_matrix = tf.linalg.LinearOperatorFullMatrix(
            tf.gather(design_matrix_with_time_in_first_dim,
                      t)[..., tf.newaxis, :], name='observation_matrix')
        return observation_matrix

      self._drift_scale = drift_scale
      self._observation_noise_scale = observation_noise_scale

      super(DynamicLinearRegressionStateSpaceModel, self).__init__(
          num_timesteps=num_timesteps,
          transition_matrix=tf.linalg.LinearOperatorIdentity(
              num_rows=num_features,
              dtype=dtype,
              name='transition_matrix'),
          transition_noise=tfd.MultivariateNormalDiag(
              scale_diag=(drift_scale[..., tf.newaxis] *
                          tf.ones([num_features], dtype=dtype)),
              name='transition_noise'),
          observation_matrix=observation_matrix_fn,
          observation_noise=tfd.MultivariateNormalDiag(
              scale_diag=observation_noise_scale[..., tf.newaxis],
              name='observation_noise'),
          initial_state_prior=initial_state_prior,
          initial_step=initial_step,
          allow_nan_stats=allow_nan_stats,
          validate_args=validate_args,
          name=name)
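
A minimal construction sketch (argument values are illustrative; assumes this is the class exposed as `tfp.sts.DynamicLinearRegressionStateSpaceModel`):

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

num_timesteps, num_features = 50, 3
design_matrix = tf.random.normal([num_timesteps, num_features])
ssm = tfp.sts.DynamicLinearRegressionStateSpaceModel(
    num_timesteps=num_timesteps,
    design_matrix=design_matrix,
    drift_scale=0.1,
    initial_state_prior=tfd.MultivariateNormalDiag(
        scale_diag=tf.ones([num_features])),
    observation_noise_scale=0.5)
y = ssm.sample(seed=42)      # shape [num_timesteps, 1]
log_prob = ssm.log_prob(y)   # scalar
```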