def _build_sampler_loop_body(model, observed_time_series, is_missing=None):
    """Builds a Gibbs sampler for the given model and observed data.

  Args:
    model: A `tfp.sts.StructuralTimeSeries` model instance. This must be of the
      form constructed by `build_model_for_gibbs_fitting`.
    observed_time_series: Float `Tensor` time series of shape
      `[..., num_timesteps]`.
    is_missing: Optional `bool` `Tensor` of shape `[..., num_timesteps]`. A
      `True` value indicates that the observation for that timestep is missing.
  Returns:
    sampler_loop_body: Python callable that performs a single cycle of Gibbs
      sampling. Its first argument is a `GibbsSamplerState`, and it returns a
      new `GibbsSamplerState`. The second argument (passed by `tf.scan`) is
      ignored.
  """
    level_component = model.components[0]
    if not (isinstance(level_component, sts.LocalLevel)
            or isinstance(level_component, sts.LocalLinearTrend)):
        raise ValueError(
            'Expected the first model component to be an instance of '
            '`tfp.sts.LocalLevel` or `tfp.sts.LocalLinearTrend`; '
            'instead saw {}'.format(level_component))
    model_has_slope = isinstance(level_component, sts.LocalLinearTrend)

    # Require that the model has exactly the parameters expected by
    # `GibbsSamplerState`.
    if model_has_slope:
        (observation_noise_param, level_scale_param, slope_scale_param,
         weights_param) = model.parameters
    else:
        (observation_noise_param, level_scale_param,
         weights_param) = model.parameters
    # If all non-`slope_scale` parameters are as expected, by process of
    # elimination `slope_scale` must be correct as well.
    if (('observation_noise' not in observation_noise_param.name)
            or ('level_scale' not in level_scale_param.name)
            or ('weights' not in weights_param.name)):
        raise ValueError(
            'Model parameters {} do not match the expected sampler '
            'state.'.format(model.parameters))

    if is_missing is not None:  # Ensure series does not contain NaNs.
        observed_time_series = tf.where(is_missing,
                                        tf.zeros_like(observed_time_series),
                                        observed_time_series)

    num_observed_steps = prefer_static.shape(observed_time_series)[-1]
    design_matrix = _get_design_matrix(model).to_dense()[:num_observed_steps]

    # Untransform scale priors -> variance priors by reaching through the
    # Sqrt bijector.
    level_scale_variance_prior = level_scale_param.prior.distribution
    if model_has_slope:
        slope_scale_variance_prior = slope_scale_param.prior.distribution
    observation_noise_variance_prior = observation_noise_param.prior.distribution

    def sampler_loop_body(previous_sample, _):
        """Runs one sampler iteration, resampling all model variables."""

        (weights_seed, level_seed, observation_noise_scale_seed,
         level_scale_seed,
         loop_seed) = samplers.split_seed(previous_sample.seed,
                                          n=5,
                                          salt='sampler_loop_body')
        # Preserve backward-compatible seed behavior by splitting slope separately.
        slope_scale_seed, = samplers.split_seed(previous_sample.seed,
                                                n=1,
                                                salt='sampler_loop_body_slope')

        # We encourage a reasonable initialization by sampling the weights first,
        # so at the first step they are regressed directly against the observed
        # time series. If we instead sampled the level first it might 'explain away'
        # some observed variation that we would ultimately prefer to explain through
        # the regression weights, because the level can represent arbitrary
        # variation, while the weights are limited to representing variation in the
        # subspace given by the design matrix.
        weights = _resample_weights(
            design_matrix=design_matrix,
            target_residuals=(observed_time_series - previous_sample.level),
            observation_noise_scale=previous_sample.observation_noise_scale,
            weights_prior_scale=weights_param.prior.distribution.scale,
            is_missing=is_missing,
            seed=weights_seed)

        regression_residuals = observed_time_series - tf.linalg.matvec(
            design_matrix, weights)
        latents = _resample_latents(
            observed_residuals=regression_residuals,
            level_scale=previous_sample.level_scale,
            slope_scale=previous_sample.slope_scale
            if model_has_slope else None,
            observation_noise_scale=previous_sample.observation_noise_scale,
            initial_state_prior=level_component.initial_state_prior,
            is_missing=is_missing,
            seed=level_seed)
        level = latents[..., 0]
        level_residuals = level[..., 1:] - level[..., :-1]
        if model_has_slope:
            slope = latents[..., 1]
            level_residuals -= slope[..., :-1]
            slope_residuals = slope[..., 1:] - slope[..., :-1]

        # Estimate level scale from the empirical changes in level.
        level_scale = _resample_scale(prior=level_scale_variance_prior,
                                      observed_residuals=level_residuals,
                                      is_missing=None,
                                      seed=level_scale_seed)
        if model_has_slope:
            slope_scale = _resample_scale(prior=slope_scale_variance_prior,
                                          observed_residuals=slope_residuals,
                                          is_missing=None,
                                          seed=slope_scale_seed)
        # Estimate noise scale from the residuals.
        observation_noise_scale = _resample_scale(
            prior=observation_noise_variance_prior,
            observed_residuals=regression_residuals - level,
            is_missing=is_missing,
            seed=observation_noise_scale_seed)

        return GibbsSamplerState(
            observation_noise_scale=observation_noise_scale,
            level_scale=level_scale,
            slope_scale=(slope_scale
                         if model_has_slope else previous_sample.slope_scale),
            weights=weights,
            level=level,
            slope=(slope if model_has_slope else previous_sample.slope),
            seed=loop_seed)

    return sampler_loop_body
Example #2
def fit_with_gibbs_sampling(model,
                            observed_time_series,
                            num_chains=(),
                            num_results=2000,
                            num_warmup_steps=200,
                            initial_state=None,
                            seed=None):
  """Fits parameters for an STS model using Gibbs sampling.

  Args:
    model: A `tfp.sts.StructuralTimeSeries` model instance returned by
      `build_model_for_gibbs_fitting`.
    observed_time_series: `float` `Tensor` of shape `[..., T, 1]`
      (omitting the trailing unit dimension is also supported when `T > 1`),
      specifying an observed time series. May optionally be an instance of
      `tfp.sts.MaskedTimeSeries`, which includes a mask `Tensor` to specify
      timesteps with missing observations.
    num_chains: Optional int to indicate the number of parallel MCMC chains.
      Defaults to an empty tuple, which samples a single chain.
    num_results: Optional int to indicate the number of MCMC samples to return.
    num_warmup_steps: Optional int to indicate the number of warmup MCMC steps;
      these are discarded from the returned samples.
    initial_state: A `GibbsSamplerState` structure of the initial states of the
      MCMC chains.
    seed: Optional `Python` `int` seed controlling the sampled values.
  Returns:
    samples: A `GibbsSamplerState` structure of posterior samples.
  """
  if not hasattr(model, 'supports_gibbs_sampling'):
    raise ValueError('This STS model does not support Gibbs sampling. Models '
                     'for Gibbs sampling must be created using the '
                     'method `build_model_for_gibbs_fitting`.')
  if not tf.nest.is_nested(num_chains):
    num_chains = [num_chains]

  [
      observed_time_series,
      is_missing
  ] = sts_util.canonicalize_observed_time_series_with_mask(
      observed_time_series)
  dtype = observed_time_series.dtype

  # The canonicalized time series always has trailing dimension `1`,
  # because although LinearGaussianSSMs support vector observations, STS models
  # describe scalar time series only. For our purposes it'll be cleaner to
  # remove this dimension.
  observed_time_series = observed_time_series[..., 0]
  batch_shape = prefer_static.concat(
      [num_chains,
       prefer_static.shape(observed_time_series)[:-1]], axis=-1)
  level_slope_shape = prefer_static.concat(
      [num_chains, prefer_static.shape(observed_time_series)], axis=-1)

  # Treat a LocalLevel model as the special case of LocalLinearTrend where
  # the slope_scale is always zero.
  initial_slope_scale = 0.
  initial_slope = 0.
  if isinstance(model.components[0], sts.LocalLinearTrend):
    initial_slope_scale = 1. * tf.ones(batch_shape, dtype=dtype)
    initial_slope = tf.zeros(level_slope_shape, dtype=dtype)

  if initial_state is None:
    initial_state = GibbsSamplerState(
        observation_noise_scale=tf.ones(batch_shape, dtype=dtype),
        level_scale=tf.ones(batch_shape, dtype=dtype),
        slope_scale=initial_slope_scale,
        weights=tf.zeros(prefer_static.concat([
            batch_shape,
            _get_design_matrix(model).shape[-1:]], axis=0), dtype=dtype),
        level=tf.zeros(level_slope_shape, dtype=dtype),
        slope=initial_slope,
        seed=None)  # Set below.

  if isinstance(seed, six.integer_types):
    tf.random.set_seed(seed)

  # Always use the passed-in `seed` arg, ignoring any seed in the initial state.
  initial_state = initial_state._replace(
      seed=samplers.sanitize_seed(seed, salt='initial_GibbsSamplerState'))

  sampler_loop_body = _build_sampler_loop_body(model,
                                               observed_time_series,
                                               is_missing)

  samples = tf.scan(sampler_loop_body,
                    np.arange(num_warmup_steps + num_results),
                    initial_state)
  return tf.nest.map_structure(lambda x: x[num_warmup_steps:], samples)
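
# A toy sketch of the `tf.scan` pattern used above, assuming the module's `tf`
# and `np` imports are in scope. The carried value stands in for
# `GibbsSamplerState`, and the integer sequence only fixes the number of
# iterations (its values are ignored, just as `sampler_loop_body` ignores its
# second argument). Names here are illustrative, not part of the sampler.
def _toy_loop_body(previous_state, _):
  return previous_state + 1.

_toy_chain = tf.scan(_toy_loop_body, np.arange(5), tf.constant(0.))
# _toy_chain == [1., 2., 3., 4., 5.]; slicing off the first entries corresponds
# to discarding `num_warmup_steps` draws, as in the return statement above.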
Example #3
def _build_sampler_loop_body(model,
                             observed_time_series,
                             is_missing=None):
  """Builds a Gibbs sampler for the given model and observed data.

  Args:
    model: A `tfp.sts.StructuralTimeSeries` model instance. This must be of the
      form constructed by `build_model_for_gibbs_fitting`.
    observed_time_series: Float `Tensor` time series of shape
      `[..., num_timesteps]`.
    is_missing: Optional `bool` `Tensor` of shape `[..., num_timesteps]`. A
      `True` value indicates that the observation for that timestep is missing.
  Returns:
    sampler_loop_body: Python callable that performs a single cycle of Gibbs
      sampling. Its first argument is a `GibbsSamplerState`, and it returns a
      new `GibbsSamplerState`. The second argument (passed by `tf.scan`) is
      ignored.
  """
  level_component = model.components[0]
  if not (isinstance(level_component, sts.LocalLevel) or
          isinstance(level_component, sts.LocalLinearTrend)):
    raise ValueError('Expected the first model component to be an instance of '
                     '`tfp.sts.LocalLevel` or `tfp.sts.LocalLinearTrend`; '
                     'instead saw {}'.format(level_component))
  model_has_slope = isinstance(level_component, sts.LocalLinearTrend)

  regression_component = model.components[1]
  if not (isinstance(regression_component, sts.LinearRegression) or
          isinstance(regression_component, SpikeAndSlabSparseLinearRegression)):
    raise ValueError('Expected the second model component to be an instance of '
                     '`tfp.sts.LinearRegression` or '
                     '`SpikeAndSlabSparseLinearRegression`; '
                     'instead saw {}'.format(regression_component))
  model_has_spike_slab_regression = isinstance(
      regression_component, SpikeAndSlabSparseLinearRegression)

  if is_missing is not None:  # Ensure series does not contain NaNs.
    observed_time_series = tf.where(is_missing,
                                    tf.zeros_like(observed_time_series),
                                    observed_time_series)

  num_observed_steps = prefer_static.shape(observed_time_series)[-1]
  design_matrix = _get_design_matrix(model).to_dense()[:num_observed_steps]
  if is_missing is not None:
    # Replace design matrix with zeros at unobserved timesteps. This ensures
    # they will not affect the posterior on weights.
    design_matrix = tf.where(is_missing[..., tf.newaxis],
                             tf.zeros_like(design_matrix),
                             design_matrix)

  # Untransform scale priors -> variance priors by reaching through the Sqrt
  # bijector.
  observation_noise_param = model.parameters[0]
  if 'observation_noise' not in observation_noise_param.name:
    raise ValueError('Model parameters {} do not match the expected sampler '
                     'state.'.format(model.parameters))
  observation_noise_variance_prior = observation_noise_param.prior.distribution
  if model_has_slope:
    level_scale_variance_prior, slope_scale_variance_prior = [
        p.prior.distribution for p in level_component.parameters]
  else:
    level_scale_variance_prior = (
        level_component.parameters[0].prior.distribution)

  if model_has_spike_slab_regression:
    spike_and_slab_sampler = spike_and_slab.SpikeSlabSampler(
        design_matrix,
        weights_prior_precision=regression_component._weights_prior_precision,  # pylint: disable=protected-access
        nonzero_prior_prob=regression_component._sparse_weights_nonzero_prob,  # pylint: disable=protected-access
        observation_noise_variance_prior_concentration=(
            observation_noise_variance_prior.concentration),
        observation_noise_variance_prior_scale=(
            observation_noise_variance_prior.scale),
        observation_noise_variance_upper_bound=(
            observation_noise_variance_prior.upper_bound
            if hasattr(observation_noise_variance_prior, 'upper_bound')
            else None))
  else:
    weights_prior_scale = (
        regression_component.parameters[0].prior.scale)

  def sampler_loop_body(previous_sample, _):
    """Runs one sampler iteration, resampling all model variables."""

    (weights_seed,
     level_seed,
     observation_noise_scale_seed,
     level_scale_seed,
     loop_seed) = samplers.split_seed(
         previous_sample.seed, n=5, salt='sampler_loop_body')
    # Preserve backward-compatible seed behavior by splitting slope separately.
    slope_scale_seed, = samplers.split_seed(
        previous_sample.seed, n=1, salt='sampler_loop_body_slope')

    # We encourage a reasonable initialization by sampling the weights first,
    # so at the first step they are regressed directly against the observed
    # time series. If we instead sampled the level first it might 'explain away'
    # some observed variation that we would ultimately prefer to explain through
    # the regression weights, because the level can represent arbitrary
    # variation, while the weights are limited to representing variation in the
    # subspace given by the design matrix.
    if model_has_spike_slab_regression:
      (observation_noise_variance,
       weights) = spike_and_slab_sampler.sample_noise_variance_and_weights(
           initial_nonzeros=tf.not_equal(previous_sample.weights, 0.),
           targets=observed_time_series - previous_sample.level,
           seed=weights_seed)
      observation_noise_scale = tf.sqrt(observation_noise_variance)
    else:
      weights = _resample_weights(
          design_matrix=design_matrix,
          target_residuals=observed_time_series - previous_sample.level,
          observation_noise_scale=previous_sample.observation_noise_scale,
          weights_prior_scale=weights_prior_scale,
          seed=weights_seed)
      # Noise scale will be resampled below.
      observation_noise_scale = previous_sample.observation_noise_scale

    regression_residuals = observed_time_series - tf.linalg.matvec(
        design_matrix, weights)
    latents = _resample_latents(
        observed_residuals=regression_residuals,
        level_scale=previous_sample.level_scale,
        slope_scale=previous_sample.slope_scale if model_has_slope else None,
        observation_noise_scale=observation_noise_scale,
        initial_state_prior=level_component.initial_state_prior,
        is_missing=is_missing,
        seed=level_seed)
    level = latents[..., 0]
    level_residuals = level[..., 1:] - level[..., :-1]
    if model_has_slope:
      slope = latents[..., 1]
      level_residuals -= slope[..., :-1]
      slope_residuals = slope[..., 1:] - slope[..., :-1]

    # Estimate level scale from the empirical changes in level.
    level_scale = _resample_scale(
        prior=level_scale_variance_prior,
        observed_residuals=level_residuals,
        is_missing=None,
        seed=level_scale_seed)
    if model_has_slope:
      slope_scale = _resample_scale(
          prior=slope_scale_variance_prior,
          observed_residuals=slope_residuals,
          is_missing=None,
          seed=slope_scale_seed)
    if not model_has_spike_slab_regression:
      # Estimate noise scale from the residuals.
      observation_noise_scale = _resample_scale(
          prior=observation_noise_variance_prior,
          observed_residuals=regression_residuals - level,
          is_missing=is_missing,
          seed=observation_noise_scale_seed)

    return GibbsSamplerState(
        observation_noise_scale=observation_noise_scale,
        level_scale=level_scale,
        slope_scale=(slope_scale if model_has_slope
                     else previous_sample.slope_scale),
        weights=weights,
        level=level,
        slope=(slope if model_has_slope
               else previous_sample.slope),
        seed=loop_seed)
  return sampler_loop_body
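
# A small standalone illustration (not part of the sampler) of the missing-data
# masking used above, assuming the module's `tf` import is in scope:
# `is_missing[..., tf.newaxis]` broadcasts a per-timestep mask across the feature
# dimension, so masked rows of the design matrix are zeroed out and cannot
# influence the posterior on weights.
_demo_design = tf.constant([[1., 2.], [3., 4.], [5., 6.]])  # [num_timesteps=3, features=2]
_demo_is_missing = tf.constant([False, True, False])
_demo_masked = tf.where(_demo_is_missing[..., tf.newaxis],
                        tf.zeros_like(_demo_design), _demo_design)
# _demo_masked == [[1., 2.], [0., 0.], [5., 6.]]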
Example #4
def random_gamma_rejection(sample_shape,
                           alpha,
                           beta,
                           internal_dtype=tf.float64,
                           seed=None):
    """Samples from the gamma distribution.

  The sampling algorithm is rejection sampling [1], and pathwise gradients with
  respect to alpha are computed via implicit differentiation [2].

  Args:
    sample_shape: The output sample shape. Must broadcast with both
      `alpha` and `beta`.
    alpha: Floating point tensor, the alpha params of the
      distribution(s). Must contain only positive values. Must broadcast with
      `beta`.
    beta: Floating point tensor, the inverse scale params of the
      distribution(s). Must contain only positive values. Must broadcast with
      `alpha`.
    internal_dtype: dtype to use for internal computations.
    seed: (optional) The random seed.

  Returns:
    Differentiable samples from the gamma distribution.

  #### References

  [1] George Marsaglia and Wai Wan Tsang. A simple method for generating Gamma
      variables. ACM Transactions on Mathematical Software, 2000.

  [2] Michael Figurnov, Shakir Mohamed, and Andriy Mnih. Implicit
      Reparameterization Gradients. Neural Information Processing Systems, 2018.
  """
    generate_and_test_samples_seed, alpha_fix_seed = samplers.split_seed(
        seed, salt='random_gamma')
    output_dtype = dtype_util.common_dtype([alpha, beta],
                                           dtype_hint=tf.float32)

    def rejection_sample(alpha):
        """Gamma rejection sampler."""
        # Note that alpha here already has a shape that is broadcast with beta.
        cast_alpha = tf.cast(alpha, internal_dtype)

        good_params_mask = (alpha > 0.)
        # When replacing invalid (non-positive or NaN) alpha values, use 100.,
        # since that leads to a high likelihood of the rejection sampler
        # accepting on the first pass.
        safe_alpha = tf.where(good_params_mask, cast_alpha, 100.)

        modified_safe_alpha = tf.where(safe_alpha < 1., safe_alpha + 1.,
                                       safe_alpha)

        one_third = tf.constant(1. / 3, dtype=internal_dtype)
        d = modified_safe_alpha - one_third
        c = one_third / tf.sqrt(d)

        def generate_and_test_samples(seed):
            """Generate and test samples."""
            v_seed, u_seed = samplers.split_seed(seed)

            def generate_positive_v():
                """Generate positive v."""
                def _inner(seed):
                    x = samplers.normal(sample_shape,
                                        dtype=internal_dtype,
                                        seed=seed)
                    # This implicitly broadcasts alpha up to sample shape.
                    v = 1 + c * x
                    return (x, v), v > 0.

                # Note: It should be possible to remove this 'inner' call to
                # `batched_las_vegas_algorithm` and merge the v > 0 check into the
                # overall check for a good sample. This would lead to a slightly simpler
                # implementation; it is unclear whether it would be faster. We include
                # the inner loop so this implementation is more easily comparable to
                # Ref. [1] and other implementations.
                return brs.batched_las_vegas_algorithm(_inner, v_seed)[0]

            (x, v) = generate_positive_v()
            x2 = x * x
            v3 = v * v * v
            u = samplers.uniform(sample_shape,
                                 dtype=internal_dtype,
                                 seed=u_seed)

            # In [1], the suggestion is to first check u < 1 - 0.331 * x2 * x2, and to
            # run the check below only if it fails, in order to avoid the relatively
            # expensive logarithm calls. Our algorithm operates in batch mode: we will
            # have to compute or not compute the logarithms for the entire batch, and
            # as the batch gets larger, the odds we compute it grow. Therefore we
            # don't bother with the "cheap" check.
            good_sample_mask = (tf.math.log(u) < x2 / 2. + d *
                                (1 - v3 + tf.math.log(v3)))

            return v3, good_sample_mask

        samples = brs.batched_las_vegas_algorithm(
            generate_and_test_samples, seed=generate_and_test_samples_seed)[0]

        samples = samples * d

        one = tf.constant(1., dtype=internal_dtype)

        alpha_lt_one_fix = tf.where(
            safe_alpha < 1.,
            tf.math.pow(
                samplers.uniform(sample_shape,
                                 dtype=internal_dtype,
                                 seed=alpha_fix_seed), one / safe_alpha), one)
        samples = samples * alpha_lt_one_fix
        samples = tf.where(good_params_mask, samples, np.nan)

        output_type_samples = tf.cast(samples, output_dtype)

        # We use `tf.where` instead of `tf.maximum` because we need to allow for
        # `samples` to be `nan`, but `tf.maximum(nan, x) == x`.
        output_type_samples = tf.where(
            output_type_samples == 0,
            np.finfo(dtype_util.as_numpy_dtype(
                output_type_samples.dtype)).tiny, output_type_samples)

        return output_type_samples

    broadcast_alpha_shape = ps.broadcast_shape(ps.shape(alpha), ps.shape(beta))
    broadcast_alpha = tf.broadcast_to(alpha, broadcast_alpha_shape)
    alpha_samples = rejection_sample(broadcast_alpha)

    corrected_beta = tf.where(beta > 0., beta, np.nan)
    return alpha_samples / corrected_beta
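
# A minimal sanity-check sketch for `random_gamma_rejection` as defined above,
# assuming the module's `tf` import is in scope. Gamma(alpha, beta) has mean
# alpha / beta, so the empirical mean should land near 2.5 / 2.0 = 1.25.
def _demo_gamma_rejection():
  alpha = tf.constant([2.5])
  beta = tf.constant([2.0])
  samples = random_gamma_rejection(
      sample_shape=[10000, 1], alpha=alpha, beta=beta, seed=[0, 42])
  return tf.reduce_mean(samples)  # Expected to be close to 1.25.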
Example #5
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
    """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider (in
      equation above).  If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
    # Implementation details:
    # Extend length N / 2 1-D array x to length N by zero padding onto the end.
    # Then, set
    #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
    # It is not hard to see that
    #   F[x]_k Conj(F[x]_k) = F[R]_k, where
    #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
    # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

    # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
    # based version of estimating RXX.
    # Note that this is a special case of the Wiener-Khinchin Theorem.
    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, name='x')

        # Rotate dimensions of x in order to put axis at the rightmost dim.
        # FFT op requires this.
        rank = ps.rank(x)
        if axis < 0:
            axis = rank + axis
        shift = rank - 1 - axis
        # Suppose x.shape[axis] = T, so there are T 'time' steps.
        #   ==> x_rotated.shape = B + [T],
        # where B is x_rotated's batch shape.
        x_rotated = distribution_util.rotate_transpose(x, shift)

        if center:
            x_rotated = x_rotated - tf.reduce_mean(
                x_rotated, axis=-1, keepdims=True)

        # x_len = N / 2 from above explanation.  The length of x along axis.
        # Get a value for x_len that works in all cases.
        x_len = ps.shape(x_rotated)[-1]

        # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
        # the moment is necessary so that all FFT implementations work.
        # Zero pad to the next power of 2 greater than 2 * x_len, which equals
        # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
        x_len_float64 = ps.cast(x_len, np.float64)
        target_length = ps.pow(np.float64(2.),
                               ps.ceil(ps.log(x_len_float64 * 2) / np.log(2.)))
        pad_length = ps.cast(target_length - x_len_float64, np.int32)

        # We should have:
        # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
        #                     = B + [T + pad_length]
        x_rotated_pad = distribution_util.pad(x_rotated,
                                              axis=-1,
                                              back=True,
                                              count=pad_length)

        dtype = x.dtype
        if not dtype_util.is_complex(dtype):
            if not dtype_util.is_floating(dtype):
                raise TypeError(
                    'Argument x must have either float or complex dtype'
                    ' found: {}'.format(dtype))
            x_rotated_pad = tf.complex(
                x_rotated_pad,
                dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.))

        # Autocorrelation is IFFT of power-spectral density (up to some scaling).
        fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
        spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
        # shifted_product is R[m] from above detailed explanation.
        # It is the inner product sum_n X[n] * Conj(X[n - m]).
        shifted_product = tf.signal.ifft(spectral_density)

        # Cast back to real-valued if x was real to begin with.
        shifted_product = tf.cast(shifted_product, dtype)

        # Figure out if we can deduce the final static shape, and set max_lags.
        # Use x_rotated as a reference, because it has the time dimension in the far
        # right, and was created before we performed all sorts of crazy shape
        # manipulations.
        know_static_shape = True
        if not tensorshape_util.is_fully_defined(x_rotated.shape):
            know_static_shape = False
        if max_lags is None:
            max_lags = x_len - 1
        else:
            max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
            max_lags_ = tf.get_static_value(max_lags)
            if max_lags_ is None or not know_static_shape:
                know_static_shape = False
                max_lags = tf.minimum(x_len - 1, max_lags)
            else:
                max_lags = min(x_len - 1, max_lags_)

        # Chop off the padding.
        # We allow users to provide a huge max_lags, but cut it off here.
        # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
        shifted_product_chopped = shifted_product[..., :max_lags + 1]

        # If possible, set shape.
        if know_static_shape:
            chopped_shape = tensorshape_util.as_list(x_rotated.shape)
            chopped_shape[-1] = min(x_len, max_lags + 1)
            tensorshape_util.set_shape(shifted_product_chopped, chopped_shape)

        # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
        # other terms were zeros arising only due to zero padding.
        # `denominator = (N / 2 - m)` (defined below) is the proper term to
        # divide by to make this an unbiased estimate of the expectation
        # E[X[n] Conj(X[n - m])].
        x_len = ps.cast(x_len, dtype_util.real_dtype(dtype))
        max_lags = ps.cast(max_lags, dtype_util.real_dtype(dtype))
        denominator = x_len - ps.range(0., max_lags + 1.)
        denominator = ps.cast(denominator, dtype)
        shifted_product_rotated = shifted_product_chopped / denominator

        if normalize:
            shifted_product_rotated /= shifted_product_rotated[..., :1]

        # Transpose dimensions back to those of x.
        return distribution_util.rotate_transpose(shifted_product_rotated,
                                                  -shift)
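
# A minimal sketch exercising `auto_correlation` as defined above, assuming the
# module's `tf` and `np` imports are in scope. An AR(1) series with coefficient
# 0.8 has theoretical lag-m autocorrelation 0.8**m, so the normalized estimate
# should roughly follow that decay for small m.
def _demo_auto_correlation():
  rng = np.random.RandomState(0)
  noise = rng.randn(10000).astype(np.float32)
  series = np.zeros(10000, dtype=np.float32)
  for t in range(1, 10000):
    series[t] = 0.8 * series[t - 1] + noise[t]
  rxx = auto_correlation(tf.convert_to_tensor(series), max_lags=5)
  return rxx  # rxx[0] == 1 (normalized); rxx[m] is approximately 0.8**m.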
  def __init__(self,
               num_timesteps,
               design_matrix,
               drift_scale,
               initial_state_prior,
               observation_noise_scale=0.,
               initial_step=0,
               validate_args=False,
               allow_nan_stats=True,
               name=None):
    """State space model for a dynamic linear regression.

    Args:
      num_timesteps: Scalar `int` `Tensor` number of timesteps to model
        with this distribution.
      design_matrix: float `Tensor` of shape `concat([batch_shape,
        [num_timesteps, num_features]])`.
      drift_scale: Scalar (any additional dimensions are treated as batch
        dimensions) `float` `Tensor` indicating the standard deviation of the
        latent state transitions.
      initial_state_prior: instance of `tfd.MultivariateNormal`
        representing the prior distribution on latent states.  Must have
        event shape `[num_features]`.
      observation_noise_scale: Scalar (any additional dimensions are
        treated as batch dimensions) `float` `Tensor` indicating the standard
        deviation of the observation noise.
        Default value: `0.`.
      initial_step: scalar `int` `Tensor` specifying the starting timestep.
        Default value: `0`.
      validate_args: Python `bool`. Whether to validate input with asserts. If
        `validate_args` is `False`, and the inputs are invalid, correct behavior
        is not guaranteed.
        Default value: `False`.
      allow_nan_stats: Python `bool`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
        Default value: `True`.
      name: Python `str` name prefixed to ops created by this class.
        Default value: 'DynamicLinearRegressionStateSpaceModel'.

    """

    with tf.name_scope(
        name or 'DynamicLinearRegressionStateSpaceModel') as name:
      dtype = dtype_util.common_dtype(
          [design_matrix, drift_scale, initial_state_prior])

      design_matrix = tf.convert_to_tensor(
          value=design_matrix, name='design_matrix', dtype=dtype)
      design_matrix_with_time_in_first_dim = distribution_util.move_dimension(
          design_matrix, -2, 0)

      drift_scale = tf.convert_to_tensor(
          value=drift_scale, name='drift_scale', dtype=dtype)

      observation_noise_scale = tf.convert_to_tensor(
          value=observation_noise_scale,
          name='observation_noise_scale',
          dtype=dtype)

      num_features = prefer_static.shape(design_matrix)[-1]

      def observation_matrix_fn(t):
        observation_matrix = tf.linalg.LinearOperatorFullMatrix(
            tf.gather(design_matrix_with_time_in_first_dim,
                      t)[..., tf.newaxis, :], name='observation_matrix')
        return observation_matrix

      self._drift_scale = drift_scale
      self._observation_noise_scale = observation_noise_scale

      super(DynamicLinearRegressionStateSpaceModel, self).__init__(
          num_timesteps=num_timesteps,
          transition_matrix=tf.linalg.LinearOperatorIdentity(
              num_rows=num_features,
              dtype=dtype,
              name='transition_matrix'),
          transition_noise=tfd.MultivariateNormalDiag(
              scale_diag=(drift_scale[..., tf.newaxis] *
                          tf.ones([num_features], dtype=dtype)),
              name='transition_noise'),
          observation_matrix=observation_matrix_fn,
          observation_noise=tfd.MultivariateNormalDiag(
              scale_diag=observation_noise_scale[..., tf.newaxis],
              name='observation_noise'),
          initial_state_prior=initial_state_prior,
          initial_step=initial_step,
          allow_nan_stats=allow_nan_stats,
          validate_args=validate_args,
          name=name)
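
# A minimal construction sketch for the state space model defined above, assuming
# the module's `tf` and `tfd` imports are in scope and that the surrounding class
# is `DynamicLinearRegressionStateSpaceModel` (as named in the super() call).
def _demo_dynamic_linear_regression_ssm():
  design = tf.random.stateless_normal([30, 2], seed=[0, 1])  # [num_timesteps, num_features]
  ssm = DynamicLinearRegressionStateSpaceModel(
      num_timesteps=30,
      design_matrix=design,
      drift_scale=0.1,
      initial_state_prior=tfd.MultivariateNormalDiag(scale_diag=tf.ones([2])),
      observation_noise_scale=0.5)
  return ssm.sample(seed=[2, 3])  # A sampled series of shape [30, 1].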
 def _batch_shape_tensor(self):
   return tf.broadcast_dynamic_shape(
       [] if self.amplitude is None else ps.shape(self.amplitude),
       [] if self.length_scale is None else ps.shape(self.length_scale))
    def _distributional_transform(self, x, event_shape):
        """Performs distributional transform of the mixture samples.

    The distributional transform removes the parameters from samples of a
    multivariate distribution by applying conditional CDFs:
      (F(x_1), F(x_2 | x_1), ..., F(x_d | x_1, ..., x_d-1))
    (the indexing is over the 'flattened' event dimensions).
    The result is a sample from a product of Uniform[0, 1] distributions.

    We assume that the components are factorized, so the conditional CDFs become
      F(x_i | x_1, ..., x_i-1) = sum_k w_i^k F_k (x_i),
    where w_i^k is the posterior mixture weight: for i > 0
      w_i^k = w_k prob_k(x_1, ..., x_i-1) / sum_k' w_k' prob_k'(x_1, ..., x_i-1)
    and w_0^k = w_k is the mixture probability of the k-th component.

    Args:
      x: Sample of mixture distribution
      event_shape: The event shape of this distribution

    Returns:
      Result of the distributional transform
    """

        if tensorshape_util.rank(x.shape) is None:
            # tf.math.softmax raises an error when applied to inputs of undefined
            # rank.
            raise ValueError(
                'Distributional transform does not support inputs of '
                'undefined rank.')

        # Obtain factorized components distribution and assert that it's
        # a scalar distribution.
        if isinstance(self._components_distribution, independent.Independent):
            univariate_components = self._components_distribution.distribution
        else:
            univariate_components = self._components_distribution

        with tf.control_dependencies([
                assert_util.assert_equal(
                    univariate_components.is_scalar_event(),
                    True,
                    message='`univariate_components` must have scalar event')
        ]):
            event_ndims = ps.rank_from_shape(event_shape)
            x_padded = self._pad_sample_dims(
                x, event_ndims=event_ndims)  # [S, B, 1, E]
            log_prob_x = univariate_components.log_prob(
                x_padded)  # [S, B, k, E]
            cdf_x = univariate_components.cdf(x_padded)  # [S, B, k, E]

            # log prob_k (x_1, ..., x_i-1)
            event_size = ps.cast(ps.reduce_prod(event_shape), dtype=tf.int32)
            cumsum_log_prob_x = tf.reshape(
                tf.math.cumsum(
                    # [S*prod(B)*k, prod(E)]
                    tf.reshape(log_prob_x, [-1, event_size]),
                    exclusive=True,
                    axis=-1),
                ps.shape(log_prob_x))  # [S, B, k, E]

            event_ndims = ps.rank_from_shape(event_shape)
            logits_mix_prob = self.mixture_distribution.logits_parameter()
            logits_mix_prob = tf.reshape(
                logits_mix_prob,  # [k] or [B, k]
                ps.concat([
                    ps.shape(logits_mix_prob),
                    ps.ones([event_ndims], dtype=tf.int32),
                ],
                          axis=0))  # [k, [1]*e] or [B, k, [1]*e]

            # Logits of the posterior weights: log w_k + log prob_k (x_1, ..., x_i-1)
            log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x

            component_axis = tensorshape_util.rank(x.shape) - event_ndims
            posterior_weights_x = tf.math.softmax(log_posterior_weights_x,
                                                  axis=component_axis)
            return tf.reduce_sum(posterior_weights_x * cdf_x,
                                 axis=component_axis)
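
# A tiny numeric illustration (separate from this class's code) of the factorized
# CDF used above, assuming `tf` and `tfd = tfp.distributions` are available: for
# the first event dimension the posterior weights reduce to the mixture weights
# w_k, so the transform is just F(x) = sum_k w_k F_k(x).
_demo_weights = tf.constant([0.3, 0.7])
_demo_components = tfd.Normal(loc=[-1., 2.], scale=[1., 0.5])
_demo_cdf = tf.reduce_sum(_demo_weights * _demo_components.cdf(0.5))  # Scalar in [0, 1].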
Example #9
 def _batch_shape_tensor(self):
   return ps.shape(self.concentration)
    def _reparameterize_sample(self, x, event_shape):
        """Adds reparameterization (pathwise) gradients to samples of the mixture.

    Implicit reparameterization gradients are
       dx/dphi = -(d transform(x, phi) / dx)^-1 * d transform(x, phi) / dphi,
    where transform(x, phi) is distributional transform that removes all
    parameters from samples x.

    We implement them by replacing x with
       -stop_gradient(d transform(x, phi) / dx)^-1 * transform(x, phi)
    for the backward pass (gradient computation).
    The derivative of this quantity w.r.t. phi is then the implicit
    reparameterization gradient.
    Note that this replaces the gradients w.r.t. both the mixture
    distribution parameters and components distributions parameters.

    Limitations:
      1. Fundamental: components must be fully reparameterized.
      2. Distributional transform is currently only implemented for
        factorized components.
      3. Distributional transform currently only works for known rank of the
        batch tensor.

    Args:
      x: Sample of mixture distribution
      event_shape: The event shape of this distribution

    Returns:
      Tensor with same value as x, but with reparameterization gradients
    """
        # Remove the existing gradients of x wrt parameters of the components.
        x = tf.stop_gradient(x)

        event_size = ps.cast(ps.reduce_prod(event_shape), dtype=tf.int32)
        x_2d_shape = [-1, event_size]  # [S*prod(B), prod(E)]

        # Perform distributional transform of x in [S, B, E] shape,
        # but have Jacobian of size [S*prod(B), prod(E), prod(E)].
        def reshaped_distributional_transform(x_2d):
            return tf.reshape(
                self._distributional_transform(tf.reshape(x_2d, ps.shape(x)),
                                               event_shape), x_2d_shape)

        # transform_2d: [S*prod(B), prod(E)]
        # jacobian: [S*prod(B), prod(E), prod(E)]
        x_2d = tf.reshape(x, x_2d_shape)
        transform_2d, jacobian = value_and_batch_jacobian(
            reshaped_distributional_transform, x_2d)

        # We only provide the first derivative; the second derivative computed by
        # autodiff would be incorrect, so we raise an error if it is requested.
        transform_2d = _prevent_2nd_derivative(transform_2d)

        # Compute [- stop_gradient(jacobian)^-1 * transform] by solving a linear
        # system. The Jacobian is lower triangular because the distributional
        # transform for i-th event dimension does not depend on the next
        # dimensions.
        surrogate_x_2d = -tf.linalg.triangular_solve(
            tf.stop_gradient(jacobian),
            transform_2d[..., tf.newaxis],
            lower=True)  # [S*prod(B), prod(E), 1]
        surrogate_x = tf.reshape(surrogate_x_2d, ps.shape(x))

        # Replace gradients of x with gradients of surrogate_x, but keep the value.
        return x + (surrogate_x - tf.stop_gradient(surrogate_x))
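
# A toy sketch (independent of this class) of the surrogate-gradient pattern in
# the return statement above, assuming the module's `tf` import is in scope: the
# forward value stays `x`, while the backward pass sees the surrogate's gradient.
def _demo_surrogate_gradient():
  phi = tf.constant(3.)
  with tf.GradientTape() as tape:
    tape.watch(phi)
    x = tf.stop_gradient(phi ** 2)                # 'sampled' value, no gradient
    surrogate = 2. * phi * tf.stop_gradient(phi)  # hand-built pathwise surrogate
    y = x + (surrogate - tf.stop_gradient(surrogate))
  return y, tape.gradient(y, phi)  # y == 9.0, gradient == 6.0 (= d(phi**2)/dphi).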
    def _sample_n(self, n, seed):
        components_seed, mix_seed = samplers.split_seed(
            seed, salt='MixtureSameFamily')
        try:
            seed_stream = SeedStream(seed, salt='MixtureSameFamily')
        except TypeError as e:  # Can happen for Tensor seeds.
            seed_stream = None
            seed_stream_err = e
        try:
            x = self.components_distribution.sample(  # [n, B, k, E]
                n, seed=components_seed)
            if seed_stream is not None:
                seed_stream()  # Advance even if unused.
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                'Falling back to stateful sampling for `components_distribution` '
                '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                'RNGs. This fallback may be removed after 20-Aug-2020. {}')
            warnings.warn(
                msg.format(self.components_distribution.name,
                           type(self.components_distribution), str(e)))
            x = self.components_distribution.sample(  # [n, B, k, E]
                n, seed=seed_stream())

        event_shape = None
        event_ndims = tensorshape_util.rank(self.event_shape)
        if event_ndims is None:
            event_shape = self.components_distribution.event_shape_tensor()
            event_ndims = ps.rank_from_shape(event_shape)
        event_ndims_static = tf.get_static_value(event_ndims)

        num_components = None
        if event_ndims_static is not None:
            num_components = tf.compat.dimension_value(
                x.shape[-1 - event_ndims_static])
        # We could also check if num_components can be computed statically from
        # self.mixture_distribution's logits or probs.
        if num_components is None:
            num_components = tf.shape(x)[-1 - event_ndims]

        # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        try:
            mix_sample = self.mixture_distribution.sample(
                n, seed=mix_seed)  # [n, B] or [n]
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                'Falling back to stateful sampling for `mixture_distribution` '
                '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                'RNGs. This fallback may be removed after 20-Aug-2020. ({})')
            warnings.warn(
                msg.format(self.mixture_distribution.name,
                           type(self.mixture_distribution), str(e)))
            mix_sample = self.mixture_distribution.sample(
                n, seed=seed_stream())  # [n, B] or [n]
        mask = tf.one_hot(
            indices=mix_sample,  # [n, B] or [n]
            depth=num_components,
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k] or [n, k]

        # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] .
        batch_ndims = ps.rank(x) - event_ndims - 1
        mask_batch_ndims = ps.rank(mask) - 1
        pad_ndims = batch_ndims - mask_batch_ndims
        mask_shape = ps.shape(mask)
        target_shape = ps.concat([
            mask_shape[:-1],
            ps.ones([pad_ndims], dtype=tf.int32),
            mask_shape[-1:],
            ps.ones([event_ndims], dtype=tf.int32),
        ],
                                 axis=0)
        mask = tf.reshape(mask, shape=target_shape)

        if dtype_util.is_floating(x.dtype) or dtype_util.is_complex(x.dtype):
            masked = tf.math.multiply_no_nan(x, mask)
        else:
            masked = x * mask
        ret = tf.reduce_sum(masked, axis=-1 - event_ndims)  # [n, B, E]

        if self._reparameterize:
            if event_shape is None:
                event_shape = self.components_distribution.event_shape_tensor()
            ret = self._reparameterize_sample(ret, event_shape=event_shape)

        return ret
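
# A toy sketch of the one-hot selection pattern above, assuming the module's `tf`
# import is in scope: a one-hot mask over the component axis picks each sample's
# component draw via multiply-and-sum rather than `tf.gather`.
_demo_draws = tf.constant([[10., 20., 30.],
                           [40., 50., 60.]])         # [n=2, k=3] scalar-event draws
_demo_mix_sample = tf.constant([2, 0])               # chosen component per sample
_demo_mask = tf.one_hot(_demo_mix_sample, depth=3)   # [2, 3]
_demo_selected = tf.reduce_sum(_demo_draws * _demo_mask, axis=-1)  # [30., 40.]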
Example #13
    def _sample_n(self, n, seed=None):
        power = tf.convert_to_tensor(self.power)
        shape = ps.concat([[n], ps.shape(power)], axis=0)
        numpy_dtype = dtype_util.as_numpy_dtype(power.dtype)

        seed = samplers.sanitize_seed(seed, salt='zipf')

        # Because `_hat_integral` is monotonically decreasing, the bounds for u will
        # switch.
        # Compute the hat_integral explicitly here since we can calculate the log of
        # the inputs statically in float64 with numpy.
        maxval_u = tf.math.exp(-(power - 1.) * numpy_dtype(np.log1p(0.5)) -
                               tf.math.log(power - 1.)) + 1.
        minval_u = tf.math.exp(
            -(power - 1.) *
            numpy_dtype(np.log1p(dtype_util.max(self.dtype) - 0.5)) -
            tf.math.log(power - 1.))

        def loop_body(should_continue, k, seed):
            """Resample the non-accepted points."""
            u_seed, next_seed = samplers.split_seed(seed)
            # Uniform variates must be sampled from the open-interval `(0, 1)` rather
            # than `[0, 1)`. To do so, we use
            # `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny`
            # because it is the smallest, positive, 'normal' number. A 'normal' number
            # is such that the mantissa has an implicit leading 1. Normal, positive
            # numbers x, y have the reasonable property that, `x + y >= max(x, y)`. In
            # this case, a subnormal number (i.e., np.nextafter) can cause us to
            # sample 0.
            u = samplers.uniform(
                shape,
                minval=np.finfo(dtype_util.as_numpy_dtype(power.dtype)).tiny,
                maxval=numpy_dtype(1.),
                dtype=power.dtype,
                seed=u_seed)
            # We use (1 - u) * maxval_u + u * minval_u rather than the other way
            # around, since we want to draw samples in (minval_u, maxval_u].
            u = maxval_u + (minval_u - maxval_u) * u
            # set_shape needed here because of b/139013403
            tensorshape_util.set_shape(u, should_continue.shape)

            # Sample the point X from the continuous density h(x) \propto x^(-power).
            x = self._hat_integral_inverse(u, power=power)

            # Rejection-inversion requires a `hat` function, h(x) such that
            # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
            # support. A natural hat function for us is h(x) = x^(-power).
            #
            # After sampling X from h(x), suppose it lies in the interval
            # (K - .5, K + .5) for integer K. Then the corresponding K is accepted
            # if it lies to the left of x_K, where x_K is defined by:
            #   \int_{x_k}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
            # where H(x) = \int_x^inf h(x) dx.

            # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
            # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
            # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

            # Update the non-accepted points.
            # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
            k = tf.where(should_continue, tf.floor(x + 0.5), k)
            accept = (u <= self._hat_integral(k + .5, power=power) +
                      tf.exp(self._log_prob(k + 1, power=power)))

            return [should_continue & (~accept), k, next_seed]

        should_continue, samples, _ = tf.while_loop(
            cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
            body=loop_body,
            loop_vars=[
                tf.ones(shape, dtype=tf.bool),  # should_continue
                tf.zeros(shape, dtype=power.dtype),  # k
                seed,  # seed
            ],
            maximum_iterations=self.sample_maximum_iterations,
        )
        samples = samples + 1.

        if self.validate_args and dtype_util.is_integer(self.dtype):
            samples = distribution_util.embed_check_integer_casting_closed(
                samples, target_dtype=self.dtype, assert_positive=True)

        samples = tf.cast(samples, self.dtype)

        if self.validate_args:
            npdt = dtype_util.as_numpy_dtype(self.dtype)
            v = npdt(
                dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan
            )
            samples = tf.where(should_continue, v, samples)

        return samples
Example #14
 def _transpose_around_bijector_fn(self,
                                   bijector_fn,
                                   arg,
                                   src_event_ndims,
                                   dest_event_ndims=None,
                                   fn_reduces_event=False,
                                   **kwargs):
     # This function moves the axes corresponding to `self.sample_shape` to the
     # left of the batch shape, then applies `bijector_fn`, then moves the axes
     # corresponding to `self.sample_shape` back to the event part of the shape.
     #
     # `src_event_ndims` and `dest_event_ndims` indicate the expected event rank
     # (omitting `self.sample_shape`) before and after applying `bijector_fn`.
     #
     # This function arose because forward and inverse ended up being quite
     # similar. It was then only a small generalization to also support {F/I}LDJ.
     batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                      self.distribution.batch_shape)
     extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
     arg_ndims = ps.rank(arg)
     # (1) Expand arg's dims.
     d = arg_ndims - batch_ndims - extra_sample_ndims - src_event_ndims
     arg = tf.reshape(arg,
                      shape=ps.pad(ps.shape(arg),
                                   paddings=[[ps.maximum(0, -d), 0]],
                                   constant_values=1))
     arg_ndims = ps.rank(arg)
     sample_ndims = ps.maximum(0, d)
     # (2) Transpose arg's dims.
     sample_dims = ps.range(0, sample_ndims)
     batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
     extra_sample_dims = ps.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims,
                           arg_ndims)
     perm = ps.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     arg = tf.transpose(arg, perm=perm)
     # (3) Apply underlying bijector.
     result = bijector_fn(arg, **kwargs)
     # (4) Transpose sample_shape from the sample to the event shape.
     result_ndims = ps.rank(result)
     if fn_reduces_event:
         dest_event_ndims = 0
     d = result_ndims - batch_ndims - extra_sample_ndims - dest_event_ndims
     if fn_reduces_event:
         # In some cases, fn may reduce event too far, i.e. ildj may return a
         # scalar `0.`, which won't work with the transpose we do below.
         result = tf.reshape(result,
                             shape=ps.pad(ps.shape(result),
                                          paddings=[[ps.maximum(0, -d), 0]],
                                          constant_values=1))
         result_ndims = ps.rank(result)
     sample_ndims = ps.maximum(0, d)
     sample_dims = ps.range(0, sample_ndims)
     extra_sample_dims = ps.range(sample_ndims,
                                  sample_ndims + extra_sample_ndims)
     batch_dims = ps.range(sample_ndims + extra_sample_ndims,
                           sample_ndims + extra_sample_ndims + batch_ndims)
     event_dims = ps.range(sample_ndims + extra_sample_ndims + batch_ndims,
                           result_ndims)
     perm = ps.concat(
         [sample_dims, batch_dims, extra_sample_dims, event_dims], axis=0)
     return tf.transpose(result, perm=perm)
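
# A toy sketch of the axis permutation in step (2) above, assuming the module's
# `tf` import is in scope: axes corresponding to `self.sample_shape` (size 5
# here) are moved in front of the batch axes before the underlying bijector runs.
_demo_arg = tf.zeros([7, 3, 5, 2])       # [sample, batch, extra_sample, event]
_demo_perm = [0, 2, 1, 3]                # sample, extra_sample, batch, event
_demo_transposed = tf.transpose(_demo_arg, perm=_demo_perm)
# _demo_transposed.shape == [7, 5, 3, 2]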
def _kahan_reduce_bwd(aux, grads):
    operands, inits, axis, unsqueezed_shape = aux
    del inits, axis  # unused
    return (tf.broadcast_to(tf.reshape(grads[0], unsqueezed_shape),
                            ps.shape(operands[0])), None)
def _get_reduction_axes(x, nd):
    """Enumerates the final `nd` axis indices of `x`."""
    x_rank = prefer_static.rank_from_shape(prefer_static.shape(x))
    return prefer_static.range(x_rank - 1, x_rank - nd - 1, -1)
Example #17
 def _batch_shape_tensor(self):
     return prefer_static.shape(self.concentration)
def convolution_batch(x,
                      kernel,
                      rank,
                      strides,
                      padding,
                      data_format=None,
                      dilations=None,
                      name=None):
    """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`."""
    if rank != 2:
        raise NotImplementedError(
            'Argument `rank` currently only supports `2`; '
            'saw "{}".'.format(rank))
    if data_format is not None and data_format.upper() != 'NHWBC':
        raise ValueError(
            'Argument `data_format` currently only supports "NHWBC"; '
            'saw "{}".'.format(data_format))
    with tf.name_scope(name or 'conv2d_nhwbc'):
        # Prepare arguments.
        [
            rank,
            _,  # strides
            padding,
            dilations,
            data_format,
        ] = prepare_conv_args(rank, strides, padding, dilations)
        strides = prepare_tuple_argument(strides, rank + 2, arg_name='strides')

        dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, dtype=dtype, name='x')
        kernel = tf.convert_to_tensor(kernel, dtype=dtype, name='kernel')

        # Step 1: Transpose and double flatten kernel.
        # kernel.shape = B + F + [c, c']. Eg: [b, fh, fw, c, c']
        kernel_shape = prefer_static.shape(kernel)
        kernel_batch_shape, kernel_event_shape = prefer_static.split(
            kernel_shape, num_or_size_splits=[-1, rank + 2])
        kernel_batch_size = prefer_static.reduce_prod(kernel_batch_shape)
        kernel_ndims = prefer_static.rank(kernel)
        kernel_batch_ndims = kernel_ndims - rank - 2
        perm = prefer_static.concat([
            prefer_static.range(kernel_batch_ndims, kernel_batch_ndims + rank),
            prefer_static.range(0, kernel_batch_ndims),
            prefer_static.range(kernel_batch_ndims + rank, kernel_ndims),
        ],
                                    axis=0)  # Eg, [1, 2, 0, 3, 4]
        kernel = tf.transpose(kernel, perm=perm)  # F + B + [c, c']
        kernel = tf.reshape(kernel,
                            shape=prefer_static.concat([
                                kernel_event_shape[:rank],
                                [
                                    kernel_batch_size * kernel_event_shape[-2],
                                    kernel_event_shape[-1]
                                ],
                            ],
                                                       axis=0))  # F + [bc, c']

        # Step 2: Double flatten x.
        # x.shape = N + D + B + [c]
        x_shape = prefer_static.shape(x)
        [
            x_sample_shape,
            x_rank_shape,
            x_batch_shape,
            x_channel_shape,
        ] = prefer_static.split(
            x_shape, num_or_size_splits=[-1, rank, kernel_batch_ndims, 1])
        x = tf.reshape(
            x,  # N + D + B + [c]
            shape=prefer_static.concat([
                [prefer_static.reduce_prod(x_sample_shape)],
                x_rank_shape,
                [
                    prefer_static.reduce_prod(x_batch_shape) *
                    prefer_static.reduce_prod(x_channel_shape)
                ],
            ],
                                       axis=0))  # [n] + D + [bc]

        # Step 3: Apply convolution.
        y = tf.nn.depthwise_conv2d(x,
                                   kernel,
                                   strides=strides,
                                   padding=padding,
                                   data_format='NHWC',
                                   dilations=dilations)
        #  SAME: y.shape = [n, h,      w,      bcc']
        # VALID: y.shape = [n, h-fh+1, w-fw+1, bcc']

        # Step 4: Reshape/reduce for output.
        y_shape = prefer_static.shape(y)
        y = tf.reshape(y,
                       shape=prefer_static.concat(
                           [
                               x_sample_shape,
                               y_shape[1:-1],
                               kernel_batch_shape,
                               kernel_event_shape[-2:],
                           ],
                           axis=0))  # N + D' + B + [c, c']
        y = tf.reduce_sum(y, axis=-2)  # N + D' + B + [c']

        return y
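A minimal sketch (hypothetical shapes, not the library's tests) of the depthwise-conv trick used above: the kernel batch is folded into the channel dimension, applied with `tf.nn.depthwise_conv2d`, and the input-channel axis is then summed out. Up to float round-off the result should match applying `tf.nn.conv2d` separately with each kernel in the batch.

import tensorflow as tf

b, n, h, w, c, c_out, fh, fw = 3, 2, 8, 8, 4, 5, 3, 3
x = tf.random.normal([n, h, w, b, c])             # N + D + B + [c]
kernel = tf.random.normal([b, fh, fw, c, c_out])  # B + F + [c, c']

# Double-flatten kernel: B + F + [c, c'] -> F + [b*c, c'].
k = tf.reshape(tf.transpose(kernel, [1, 2, 0, 3, 4]), [fh, fw, b * c, c_out])
# Double-flatten x: N + D + B + [c] -> N + D + [b*c].
x_flat = tf.reshape(x, [n, h, w, b * c])
y = tf.nn.depthwise_conv2d(x_flat, k, strides=[1, 1, 1, 1], padding='SAME')
# Regroup channels as [b, c, c'] and sum over the input-channel axis.
y = tf.reduce_sum(tf.reshape(y, [n, h, w, b, c, c_out]), axis=-2)

# Reference: loop over the kernel batch with ordinary conv2d.
ref = tf.stack(
    [tf.nn.conv2d(x[..., i, :], kernel[i], strides=1, padding='SAME')
     for i in range(b)], axis=-2)
print(tf.reduce_max(tf.abs(y - ref)).numpy())  # ~0 up to float error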
  def __init__(self,
               design_matrix,
               drift_scale_prior=None,
               initial_weights_prior=None,
               observed_time_series=None,
               name=None):
    """Specify a dynamic linear regression.

    Args:
      design_matrix: float `Tensor` of shape `concat([batch_shape,
        [num_timesteps, num_features]])`.
      drift_scale_prior: instance of `tfd.Distribution` specifying a prior on
        the `drift_scale` parameter. If `None`, a heuristic default prior is
        constructed based on the provided `observed_time_series`.
        Default value: `None`.
      initial_weights_prior: instance of `tfd.MultivariateNormal` representing
        the prior distribution on the latent states (the regression weights).
        Must have event shape `[num_features]`. If `None`, a weakly-informative
        Normal(0., 10.) prior is used.
        Default value: `None`.
      observed_time_series: `float` `Tensor` of shape `batch_shape + [T, 1]`
        (omitting the trailing unit dimension is also supported when `T > 1`),
        specifying an observed time series. Any priors not explicitly set will
        be given default values according to the scale of the observed time
        series (or batch of time series). May optionally be an instance of
        `tfp.sts.MaskedTimeSeries`, which includes a mask `Tensor` to specify
        timesteps with missing observations.
        Default value: `None`.
      name: Python `str` for the name of this component.
        Default value: 'DynamicLinearRegression'.

    """

    with tf.name_scope(name or 'DynamicLinearRegression') as name:

      dtype = dtype_util.common_dtype(
          [design_matrix, drift_scale_prior, initial_weights_prior])

      num_features = prefer_static.shape(design_matrix)[-1]

      # Default to a weakly-informative Normal(0., 10.) prior for the initial state.
      if initial_weights_prior is None:
        initial_weights_prior = tfd.MultivariateNormalDiag(
            scale_diag=10. * tf.ones([num_features], dtype=dtype))

      # Compute the empirical scale up front; it is needed both for the default
      # `drift_scale` prior and for the parameter's scaling bijector below.
      if observed_time_series is None:
        observed_stddev = tf.constant(1.0, dtype=dtype)
      else:
        _, observed_stddev, _ = sts_util.empirical_statistics(
            observed_time_series)

      # Heuristic default priors. Overriding these may dramatically
      # change inference performance and results.
      if drift_scale_prior is None:
        drift_scale_prior = tfd.LogNormal(
            loc=tf.math.log(.05 * observed_stddev),
            scale=3.,
            name='drift_scale_prior')

      self._initial_state_prior = initial_weights_prior
      self._design_matrix = design_matrix

      super(DynamicLinearRegression, self).__init__(
          parameters=[
              Parameter('drift_scale', drift_scale_prior,
                        tfb.Chain([tfb.AffineScalar(scale=observed_stddev),
                                   tfb.Softplus()]))
          ],
          latent_size=num_features,
          name=name)
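A hedged usage sketch with synthetic data, assuming this component is exposed as `tfp.sts.DynamicLinearRegression`:

import numpy as np
import tensorflow_probability as tfp

num_timesteps, num_features = 100, 2
design_matrix = np.random.randn(num_timesteps, num_features).astype(np.float32)
observed_time_series = np.random.randn(num_timesteps).astype(np.float32)

# Priors default to heuristics based on the observed series, as described above.
component = tfp.sts.DynamicLinearRegression(
    design_matrix=design_matrix,
    observed_time_series=observed_time_series)
print(component.latent_size)  # == num_features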
 def _batch_shape(self):
   tensors = [self.loc, self.scale, self.tailweight, self.skewness]
   return functools.reduce(
       ps.broadcast_shape,
       [ps.shape(t) for t in tensors])
Example #21
 def _batch_shape_tensor(self, concentration=None, rate=None):
     return ps.broadcast_shape(
         ps.shape(self.concentration
                  if concentration is None else concentration),
         ps.shape(self.rate if rate is None else rate))
Example #22
def find_root_chandrupatla(objective_fn,
                           low,
                           high,
                           position_tolerance=1e-8,
                           value_tolerance=0.,
                           max_iterations=50,
                           stopping_policy_fn=tf.reduce_all,
                           validate_args=False,
                           name='find_root_chandrupatla'):
    r"""Finds root(s) of a scalar function using Chandrupatla's method.

  Chandrupatla's method [1, 2] is a root-finding algorithm that is guaranteed
  to converge if a root lies within the given bounds. It generalizes the
  [bisection method](https://en.wikipedia.org/wiki/Bisection_method); at each
  step it chooses to perform either bisection or inverse quadratic
  interpolation. This makes it similar in spirit to [Brent's method](
  https://en.wikipedia.org/wiki/Brent%27s_method), which also considers steps
  that use the secant method, but Chandrupatla's method is simpler and often
  converges at least as quickly [3].

  Args:
    objective_fn: Python callable for which roots are searched. It must be a
      callable of a single variable. `objective_fn` must return a `Tensor` with
      shape `batch_shape` and dtype matching `low` and `high`.
    low: Float `Tensor` of shape `batch_shape` representing a lower
      bound(s) on the value of a root(s).
    high: Float `Tensor` of shape `batch_shape` representing an upper
      bound(s) on the value of a root(s).
    position_tolerance: Optional `Tensor` representing the maximum absolute
      error in the positions of the estimated roots. Shape must broadcast with
      `batch_shape`.
      Default value: `1e-8`.
    value_tolerance: Optional `Tensor` representing the absolute error allowed
      in the value of the objective function. If the absolute value of
      `objective_fn` is smaller than
      `value_tolerance` at a given position, then that position is considered a
      root for the function. Shape must broadcast with `batch_shape`.
      Default value: `0.`.
    max_iterations: Optional `Tensor` or Python integer specifying the maximum
      number of steps to perform. Shape must broadcast with `batch_shape`.
      Default value: `50`.
    stopping_policy_fn: Python `callable` controlling the algorithm termination.
      It must be a callable accepting a `Tensor` of booleans with the same shape
      as `low` and `high` (denoting whether each search is
      finished), and returning a scalar boolean `Tensor` indicating
      whether the overall search should stop. Typical values are
      `tf.reduce_all` (which returns only when the search is finished for all
      points), and `tf.reduce_any` (which returns as soon as the search is
      finished for any point).
      Default value: `tf.reduce_all` (returns only when the search is finished
        for all points).
    validate_args: Python `bool` indicating whether to validate arguments.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: 'find_root_chandrupatla'.

  Returns:
    root_search_results: A Python `namedtuple` containing the following items:
      estimated_root: `Tensor` containing the last position explored. If the
        search was successful within the specified tolerance, this position is
        a root of the objective function.
      objective_at_estimated_root: `Tensor` containing the value of the
        objective function at `estimated_root`. If the search was successful within
        the specified tolerance, then this is close to 0.
      num_iterations: The number of iterations performed.

  #### References

  [1] Tirupathi R. Chandrupatla. A new hybrid quadratic/bisection algorithm for
      finding the zero of a nonlinear function without using derivatives.
      _Advances in Engineering Software_, 28.3:145-149, 1997.
  [2] Philipp OJ Scherer. Computational Physics. _Springer Berlin_,
      Heidelberg, 2010.
      Section 6.1.7.3 https://books.google.com/books?id=cC-8BAAAQBAJ&pg=PA95
  [3] Jason Sachs. Ten Little Algorithms, Part 5: Quadratic Extremum
      Interpolation and Chandrupatla's Method (2015).
      https://www.embeddedrelated.com/showarticle/855.php
  """

    ################################################
    # Loop variables used by Chandrupatla's method:
    #
    #  a, b: endpoints of an interval `[min(a, b), max(a, b)]` containing the
    #     root. There is no guarantee as to which of `a` and `b` is larger.
    #  f_a: value of the objective at `a`.
    #  f_b: value of the objective at `b`.
    #  t: the next position to be evaluated as the coefficient of a convex
    #    combination of `a` and `b` (i.e., a value in the unit interval).
    #  num_iterations: integer number of steps taken so far.
    #  converged: boolean indicating whether each batch element has converged.
    #
    # All variables have the same shape `batch_shape`.

    def _should_continue(a, b, f_a, f_b, t, num_iterations, converged):
        del a, b, f_a, f_b, t  # Unused.
        all_converged = stopping_policy_fn(
            tf.logical_or(converged, num_iterations >= max_iterations))
        return ~all_converged

    def _body(a, b, f_a, f_b, t, num_iterations, converged):
        """One step of Chandrupatla's method for root finding."""
        previous_loop_vars = (a, b, f_a, f_b, t, num_iterations, converged)
        finalized_elements = tf.logical_or(converged,
                                           num_iterations >= max_iterations)

        # Evaluate the new point.
        x_new = (1 - t) * a + t * b
        f_new = objective_fn(x_new)
        # If we've bisected (t==0.5) and the new float value for `a` is identical to
        # that from the previous iteration, then we'll keep bisecting (the
        # logic below will set t==0.5 for the next step), and nothing further will
        # change.
        at_fixed_point = tf.equal(x_new, a) & tf.equal(t, 0.5)
        # Otherwise, tighten the bounds.
        a, b, c, f_a, f_b, f_c = _structure_broadcasting_where(
            tf.equal(tf.math.sign(f_new), tf.math.sign(f_a)),
            (x_new, b, a, f_new, f_b, f_a), (x_new, a, b, f_new, f_a, f_b))

        # Check for convergence.
        f_best = tf.where(tf.abs(f_a) < tf.abs(f_b), f_a, f_b)
        interval_tolerance = position_tolerance / (tf.abs(b - c))
        converged = tf.logical_or(
            interval_tolerance > 0.5,
            tf.logical_or(
                tf.math.abs(f_best) <= value_tolerance, at_fixed_point))

        # Propose next point to evaluate.
        xi = (a - b) / (c - b)
        phi = (f_a - f_b) / (f_c - f_b)
        t = tf.where(
            # Condition for inverse quadratic interpolation.
            tf.logical_and(1 - tf.math.sqrt(1 - xi) < phi,
                           tf.math.sqrt(xi) > phi),
            # Propose a point by inverse quadratic interpolation.
            (f_a / (f_b - f_a) * f_c / (f_b - f_c) + (c - a) / (b - a) * f_a /
             (f_c - f_a) * f_b / (f_c - f_b)),
            # Otherwise, just cut the interval in half (bisection).
            0.5)
        # Constrain the proposal to the current interval (0 < t < 1).
        t = tf.minimum(tf.maximum(t, interval_tolerance),
                       1 - interval_tolerance)

        # Update elements that haven't converged.
        return _structure_broadcasting_where(
            finalized_elements, previous_loop_vars,
            (a, b, f_a, f_b, t, num_iterations + 1, converged))

    with tf.name_scope(name):
        max_iterations = tf.convert_to_tensor(max_iterations,
                                              name='max_iterations',
                                              dtype_hint=tf.int32)
        a = tf.convert_to_tensor(low, name='lower_bound')
        b = tf.convert_to_tensor(high, name='upper_bound')
        f_a, f_b = objective_fn(a), objective_fn(b)
        batch_shape = ps.broadcast_shape(ps.shape(f_a), ps.shape(f_b))

        assertions = []
        if validate_args:
            assertions += [
                assert_util.assert_none_equal(
                    tf.math.sign(f_a),
                    tf.math.sign(f_b),
                    message='Bounds must be on different sides of a root.')
            ]

        with tf.control_dependencies(assertions):
            initial_loop_vars = [
                a, b, f_a, f_b,
                tf.cast(0.5, dtype=f_a.dtype),
                tf.cast(0, dtype=max_iterations.dtype), False
            ]
            a, b, f_a, f_b, _, num_iterations, _ = tf.while_loop(
                _should_continue,
                _body,
                loop_vars=tf.nest.map_structure(
                    lambda x: tf.broadcast_to(x, batch_shape),
                    initial_loop_vars))

        x_best, f_best = _structure_broadcasting_where(
            tf.abs(f_a) < tf.abs(f_b), (a, f_a), (b, f_b))
    return RootSearchResults(estimated_root=x_best,
                             objective_at_estimated_root=f_best,
                             num_iterations=num_iterations)
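An illustrative usage sketch, assuming this routine is exposed as `tfp.math.find_root_chandrupatla`: solve `x**2 = c` for a small batch of targets `c`, bracketing each root in `[0, 3]`.

import tensorflow as tf
import tensorflow_probability as tfp

c = tf.constant([2., 3., 5.], dtype=tf.float64)
results = tfp.math.find_root_chandrupatla(
    objective_fn=lambda x: x**2 - c,
    low=tf.zeros_like(c),
    high=3. * tf.ones_like(c))
print(results.estimated_root.numpy())  # ~[1.4142, 1.7321, 2.2361]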
Example #23
def covariance(x,
               y=None,
               sample_axis=0,
               event_axis=-1,
               keepdims=False,
               name=None):
    """Sample covariance between observations indexed by `event_axis`.

  Given `N` samples of scalar random variables `X` and `Y`, covariance may be
  estimated as

  ```none
  Cov[X, Y] := N^{-1} sum_{n=1}^N (X_n - Xbar) Conj{(Y_n - Ybar)}
  Xbar := N^{-1} sum_{n=1}^N X_n
  Ybar := N^{-1} sum_{n=1}^N Y_n
  ```

  For vector-variate random variables `X = (X1, ..., Xd)`, `Y = (Y1, ..., Yd)`,
  one is often interested in the covariance matrix, `C_{ij} := Cov[Xi, Yj]`.

  ```python
  x = tf.random.normal(shape=(100, 2, 3))
  y = tf.random.normal(shape=(100, 2, 3))

  # cov[i, j] is the sample covariance between x[:, i, j] and y[:, i, j].
  cov = tfp.stats.covariance(x, y, sample_axis=0, event_axis=None)

  # cov_matrix[i, m, n] is the sample covariance of x[:, i, m] and y[:, i, n]
  cov_matrix = tfp.stats.covariance(x, y, sample_axis=0, event_axis=-1)
  ```

  Notice we divide by `N`, which does not create `NaN` when `N = 1`, but is
  slightly biased.

  Args:
    x:  A numeric `Tensor` holding samples.
    y:  Optional `Tensor` with same `dtype` and `shape` as `x`.
      Default value: `None` (`y` is effectively set to `x`).
    sample_axis: Scalar or vector `Tensor` designating axis holding samples, or
      `None` (meaning all axis hold samples).
      Default value: `0` (leftmost dimension).
    event_axis:  Scalar or vector `Tensor`, or `None` (scalar events).
      Axis indexing random events, whose covariance we are interested in.
      If a vector, entries must form a contiguous block of dims. `sample_axis`
      and `event_axis` should not intersect.
      Default value: `-1` (rightmost axis holds events).
    keepdims:  Boolean.  Whether to keep the sample axis as singletons.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'covariance'`).

  Returns:
    cov: A `Tensor` of same `dtype` as the `x`, and rank equal to
      `rank(x) - len(sample_axis) + len(event_axis)` when `keepdims=False`.

  Raises:
    AssertionError:  If `x` and `y` are found to have different shape.
    ValueError:  If `sample_axis` and `event_axis` are found to overlap.
    ValueError:  If `event_axis` is found to not be contiguous.
  """

    with tf.name_scope(name or 'covariance'):
        x = tf.convert_to_tensor(x, name='x')
        # Covariance *only* uses the centered versions of x (and y).
        x = x - tf.reduce_mean(x, axis=sample_axis, keepdims=True)

        if y is None:
            y = x
        else:
            y = tf.convert_to_tensor(y, name='y', dtype=x.dtype)
            # If x and y have different shape, sample_axis and event_axis will likely
            # be wrong for one of them!
            tensorshape_util.assert_is_compatible_with(x.shape, y.shape)
            y = y - tf.reduce_mean(y, axis=sample_axis, keepdims=True)

        if event_axis is None:
            return tf.reduce_mean(x * tf.math.conj(y),
                                  axis=sample_axis,
                                  keepdims=keepdims)

        if sample_axis is None:
            raise ValueError(
                'sample_axis was None, which means all axis hold events, and this '
                'overlaps with event_axis ({})'.format(event_axis))

        event_axis = _make_positive_axis(event_axis, ps.rank(x))
        sample_axis = _make_positive_axis(sample_axis, ps.rank(x))

        # If we get lucky and axis is statically defined, we can do some checks.
        if _is_list_like(event_axis) and _is_list_like(sample_axis):
            event_axis = tuple(map(int, event_axis))
            sample_axis = tuple(map(int, sample_axis))
            if set(event_axis).intersection(sample_axis):
                raise ValueError(
                    'sample_axis ({}) and event_axis ({}) overlapped'.format(
                        sample_axis, event_axis))
            if (np.diff(np.array(sorted(event_axis))) > 1).any():
                raise ValueError(
                    'event_axis must be contiguous. Found: {}'.format(
                        event_axis))
            batch_axis = list(
                sorted(
                    set(range(tensorshape_util.rank(
                        x.shape))).difference(sample_axis + event_axis)))
        else:
            batch_axis = ps.setdiff1d(ps.range(0, ps.rank(x)),
                                      ps.concat((sample_axis, event_axis), 0))

        event_axis = ps.cast(event_axis, dtype=tf.int32)
        sample_axis = ps.cast(sample_axis, dtype=tf.int32)
        batch_axis = ps.cast(batch_axis, dtype=tf.int32)

        # Permute x/y until shape = B + E + S
        perm_for_xy = ps.concat((batch_axis, event_axis, sample_axis), 0)
        x_permed = tf.transpose(a=x, perm=perm_for_xy)
        y_permed = tf.transpose(a=y, perm=perm_for_xy)

        batch_ndims = ps.size(batch_axis)
        batch_shape = ps.shape(x_permed)[:batch_ndims]
        event_ndims = ps.size(event_axis)
        event_shape = ps.shape(x_permed)[batch_ndims:batch_ndims + event_ndims]
        sample_shape = ps.shape(x_permed)[batch_ndims + event_ndims:]
        sample_ndims = ps.size(sample_shape)
        n_samples = ps.reduce_prod(sample_shape)
        n_events = ps.reduce_prod(event_shape)

        # Flatten sample_axis into one long dim, and event_axis into another.
        x_permed_flat = tf.reshape(
            x_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))
        y_permed_flat = tf.reshape(
            y_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))

        # After matmul, cov.shape = batch_shape + [n_events, n_events]
        cov = tf.matmul(x_permed_flat, y_permed_flat,
                        adjoint_b=True) / ps.cast(n_samples, x.dtype)

        # Insert some singletons to make
        # cov.shape = batch_shape + event_shape**2 + [1,...,1]
        # This is just like x_permed.shape, except the sample_axis is all 1's, and
        # the [n_events] became event_shape**2.
        cov = tf.reshape(
            cov,
            ps.concat(
                (
                    batch_shape,
                    # event_shape**2 used here because it is the same length as
                    # event_shape, and has the same number of elements as one
                    # batch of covariance.
                    event_shape**2,
                    ps.ones([sample_ndims], tf.int32)),
                0))
        # Permuting by the argsort inverts the permutation, making
        # cov.shape have ones in the position where there were samples, and
        # [n_events * n_events] in the event position.
        cov = tf.transpose(a=cov, perm=ps.invert_permutation(perm_for_xy))

        # Now expand event_shape**2 into event_shape + event_shape.
        # We here use (for the first time) the fact that we require event_axis to be
        # contiguous.
        e_start = event_axis[0]
        e_len = 1 + event_axis[-1] - event_axis[0]
        cov = tf.reshape(
            cov,
            ps.concat((ps.shape(cov)[:e_start], event_shape, event_shape,
                       ps.shape(cov)[e_start + e_len:]), 0))

        # tf.squeeze requires python ints for axis, not Tensor.  This is enough to
        # require our axis args to be constants.
        if not keepdims:
            squeeze_axis = ps.where(sample_axis < e_start, sample_axis,
                                    sample_axis + e_len)
            cov = _squeeze(cov, axis=squeeze_axis)

        return cov
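A small sanity check (illustrative only): with the default `keepdims=False`, the result should agree with `np.cov` once the latter is also told to divide by `N` rather than `N - 1`.

import numpy as np
import tensorflow_probability as tfp

x = np.random.randn(500, 3)
cov_tfp = tfp.stats.covariance(x, sample_axis=0, event_axis=-1).numpy()
cov_np = np.cov(x, rowvar=False, bias=True)  # bias=True divides by N
print(np.max(np.abs(cov_tfp - cov_np)))      # ~0 up to float round-off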
 def _batch_shape_tensor(self):
     x = self._probs if self._logits is None else self._logits
     return ps.shape(x)
    def one_step(self, current_state, previous_kernel_results, seed=None):
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'diagonal_mass_matrix_adaptation',
                                    'one_step')):
            variance_parts = previous_kernel_results.running_variance
            diags = [
                variance_part.variance() for variance_part in variance_parts
            ]
            # Set the momentum.
            batch_ndims = ps.rank(
                unnest.get_innermost(previous_kernel_results,
                                     'target_log_prob'))
            state_parts = tf.nest.flatten(current_state)
            new_momentum_distribution = _make_momentum_distribution(
                diags, state_parts, batch_ndims)
            inner_results = self.momentum_distribution_setter_fn(
                previous_kernel_results.inner_results,
                new_momentum_distribution)

            # Step the inner kernel.
            inner_kwargs = {} if seed is None else dict(seed=seed)
            new_state, new_inner_results = self.inner_kernel.one_step(
                current_state, inner_results, **inner_kwargs)
            new_state_parts = tf.nest.flatten(new_state)
            new_variance_parts = []
            for variance_part, diag, state_part in zip(variance_parts, diags,
                                                       new_state_parts):
                # Compute new variance for each variance part, accounting for partial
                # batching of the variance calculation across chains (ie, some, all, or
                # none of the chains may share the estimated mass matrix).
                #
                # For example, say
                #
                # state_part has shape       [2, 3, 4] + [5, 6]  (batch + event)
                # variance_part has shape          [4] + [5, 6]
                # log_prob has shape         [2, 3, 4]
                #
                # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass
                # matrices, each being shared across a [2, 3]-batch of chains. Note this
                # division is inferred from the shapes of the state part, the log_prob,
                # and the user-provided initial running variances.
                #
                # Until RunningVariance supports rank > 1 chunking, we need to flatten
                # the states that go into updating the variance estimates. In the above
                # example, `state_part` will be reshaped to `[6, 4, 5, 6]`, and
                # fed to `RunningVariance.update(state_part, axis=0)`, recording
                # 6 new observations in the running variance calculation.
                # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`, and
                # the resulting momentum distribution will have batch shape of
                # `[2, 3, 4]` and event_shape of `[5, 6]`, matching the state_part.
                state_rank = ps.rank(state_part)
                variance_rank = ps.rank(diag)
                num_reduce_dims = state_rank - variance_rank

                state_part_shape = ps.shape(state_part)
                # This reshape adds a 1 when reduce_dims==0, and collapses all the lead
                # dimensions to a single one otherwise.
                reshaped_state = ps.reshape(
                    state_part,
                    ps.concat(
                        [[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
                         state_part_shape[num_reduce_dims:]],
                        axis=0))

                # The `axis=0` here removes the leading dimension we got from the
                # reshape above, so the new_variance_parts have the correct shape again.
                new_variance_parts.append(
                    variance_part.update(reshaped_state, axis=0))

            new_kernel_results = previous_kernel_results._replace(
                inner_results=new_inner_results,
                running_variance=new_variance_parts)

            return new_state, new_kernel_results
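A standalone sketch of the reshape performed in the loop above, using the shapes from the comment (the variable names here are illustrative, not library APIs): the leading chain dimensions that share a mass matrix are collapsed into a single axis so the running variance can treat them as extra observations.

import numpy as np
import tensorflow as tf

state_part = tf.zeros([2, 3, 4, 5, 6])  # chains [2, 3, 4] + event [5, 6]
variance_shape = [4, 5, 6]              # one variance per [2, 3] block of chains
num_reduce_dims = len(state_part.shape) - len(variance_shape)  # == 2

lead = int(np.prod(state_part.shape[:num_reduce_dims]))  # 2 * 3 == 6
reshaped_state = tf.reshape(
    state_part, [lead] + state_part.shape[num_reduce_dims:].as_list())
print(reshaped_state.shape)  # (6, 4, 5, 6), as described in the comment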
 def _sample_n(self, n, seed=None):
     probs = self._probs_parameter_no_checks()
     new_shape = ps.concat([[n], ps.shape(probs)], 0)
     uniform = samplers.uniform(new_shape, seed=seed, dtype=probs.dtype)
     sample = tf.less(uniform, probs)
     return tf.cast(sample, self.dtype)
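The sampler above is plain inverse-CDF sampling for a Bernoulli: threshold uniforms against the probabilities. A minimal standalone illustration:

import tensorflow as tf

probs = tf.constant([0.1, 0.5, 0.9])
uniform = tf.random.uniform([10000, 3])
samples = tf.cast(uniform < probs, tf.float32)
print(tf.reduce_mean(samples, axis=0).numpy())  # roughly [0.1, 0.5, 0.9]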
Example #27
def one_step_predictive(model,
                        posterior_samples,
                        num_forecast_steps=0,
                        original_mean=0.,
                        original_scale=1.,
                        thin_every=10):
  """Constructs a one-step-ahead predictive distribution at every timestep.

  Unlike the generic `tfp.sts.one_step_predictive`, this method uses the
  latent levels from Gibbs sampling to efficiently construct a predictive
  distribution that mixes over posterior samples. The predictive distribution
  may also include additional forecast steps.

  This method returns the predictive distributions for each timestep given
  previous timesteps and sampled model parameters, `p(observed_time_series[t] |
  observed_time_series[:t], weights, observation_noise_scale)`. Note that the
  posterior values of the weights and noise scale will in general be informed
  by observations from all timesteps *including the step being predicted*, so
  this is not a strictly kosher probabilistic quantity, but in general we assume
  that it's close, i.e., that the step being predicted had very small individual
  impact on the overall parameter posterior.

  Args:
    model: A `tfp.sts.StructuralTimeSeries` model instance. This must be of the
      form constructed by `build_model_for_gibbs_sampling`.
    posterior_samples: A `GibbsSamplerState` instance in which each element is a
      `Tensor` with initial dimension of size `num_samples`.
    num_forecast_steps: Python `int` number of additional forecast steps to
      append.
      Default value: `0`.
    original_mean: Optional scalar float `Tensor`, added to the predictive
      distribution to undo the effect of input normalization.
      Default value: `0.`
    original_scale: Optional scalar float `Tensor`, used to rescale the
      predictive distribution to undo the effect of input normalization.
      Default value: `1.`
    thin_every: Optional Python `int` factor by which to thin the posterior
      samples, to reduce complexity of the predictive distribution. For example,
      if `thin_every=10`, every `10`th sample will be used.
      Default value: `10`.
  Returns:
    predictive_dist: A `tfd.MixtureSameFamily` instance of event shape
      `[num_timesteps + num_forecast_steps]` representing the predictive
      distribution of each timestep given previous timesteps.
  """
  dtype = dtype_util.common_dtype([
      posterior_samples.level_scale,
      posterior_samples.observation_noise_scale,
      posterior_samples.level,
      original_mean,
      original_scale], dtype_hint=tf.float32)
  num_observed_steps = prefer_static.shape(posterior_samples.level)[-1]

  original_mean = tf.convert_to_tensor(original_mean, dtype=dtype)
  original_scale = tf.convert_to_tensor(original_scale, dtype=dtype)
  thinned_samples = tf.nest.map_structure(lambda x: x[::thin_every],
                                          posterior_samples)

  if prefer_static.rank_from_shape(  # If no slope was inferred, treat as zero.
      prefer_static.shape(thinned_samples.slope)) <= 1:
    thinned_samples = thinned_samples._replace(
        slope=tf.zeros_like(thinned_samples.level),
        slope_scale=tf.zeros_like(thinned_samples.level_scale))

  num_steps_from_last_observation = tf.concat([
      tf.ones([num_observed_steps], dtype=dtype),
      tf.range(1, num_forecast_steps + 1, dtype=dtype)], axis=0)

  # The local linear trend model expects that the level at step t + 1 is equal
  # to the level at step t, plus the slope at time t - 1,
  # plus transition noise of scale 'level_scale' (which we account for below).
  if num_forecast_steps > 0:
    num_batch_dims = prefer_static.rank_from_shape(
        prefer_static.shape(thinned_samples.level)) - 2
    # All else equal, the current level will remain stationary.
    forecast_level = tf.tile(thinned_samples.level[..., -1:],
                             tf.concat([tf.ones([num_batch_dims + 1],
                                                dtype=tf.int32),
                                        [num_forecast_steps]], axis=0))
    # If the model includes slope, the level will steadily increase.
    forecast_level += (thinned_samples.slope[..., -1:] *
                       tf.range(1., num_forecast_steps + 1.,
                                dtype=forecast_level.dtype))

  level_pred = tf.concat([thinned_samples.level[..., :1],  # t == 0
                          (thinned_samples.level[..., :-1] +
                           thinned_samples.slope[..., :-1])  # 1 <= t < T
                         ] + (
                             [forecast_level]
                             if num_forecast_steps > 0 else []),
                         axis=-1)

  design_matrix = _get_design_matrix(
      model).to_dense()[:num_observed_steps + num_forecast_steps]
  regression_effect = tf.linalg.matvec(design_matrix, thinned_samples.weights)

  y_mean = ((level_pred + regression_effect) *
            original_scale[..., tf.newaxis] + original_mean[..., tf.newaxis])

  # To derive a forecast variance, including slope uncertainty, let
  #  `r[:k]` be iid Gaussian RVs with variance `level_scale**2` and `s[:k]` be
  # iid Gaussian RVs with variance `slope_scale**2`. Then the forecast level at
  # step `T + k` can be written as
  #   (level[T] +           # Last known level.
  #    r[0] + ... + r[k] +  # Sum of random walk terms on level.
  #    slope[T] * k         # Contribution from last known slope.
  #    (k - 1) * s[0] +     # Contributions from random walk terms on slope.
  #    (k - 2) * s[1] +
  #    ... +
  #    1 * s[k - 1])
  # which has variance of
  #  (level_scale**2 * k +
  #   slope_scale**2 * ( (k - 1)**2 +
  #                      (k - 2)**2 +
  #                      ... + 1 ))
  # Here the `slope_scale` coefficient is the `k - 1`th square pyramidal
  # number [1], which is given by
  #  (k - 1) * k * (2 * k - 1) / 6.
  #
  # [1] https://en.wikipedia.org/wiki/Square_pyramidal_number
  variance_from_level = (thinned_samples.level_scale[..., tf.newaxis]**2 *
                         num_steps_from_last_observation)
  variance_from_slope = thinned_samples.slope_scale[..., tf.newaxis]**2 * (
      (num_steps_from_last_observation - 1) *
      num_steps_from_last_observation *
      (2 * num_steps_from_last_observation - 1)) / 6.
  y_scale = (original_scale * tf.sqrt(
      thinned_samples.observation_noise_scale[..., tf.newaxis]**2 +
      variance_from_level + variance_from_slope))

  num_posterior_draws = prefer_static.shape(y_mean)[0]
  return tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(
          logits=tf.zeros([num_posterior_draws], dtype=y_mean.dtype)),
      components_distribution=tfd.Normal(
          loc=dist_util.move_dimension(y_mean, 0, -1),
          scale=dist_util.move_dimension(y_scale, 0, -1)))
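A quick check (illustrative) of the closed form used for the slope contribution above: the coefficient multiplying `slope_scale**2` is `sum_{j=1}^{k-1} j**2`, the `(k - 1)`th square pyramidal number.

for k in range(1, 10):
    direct = sum(j**2 for j in range(1, k))
    closed_form = (k - 1) * k * (2 * k - 1) // 6
    assert direct == closed_form, (k, direct, closed_form)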
Example #28
 def _batch_shape_tensor(self, loc=None, concentration=None):
     return prefer_static.broadcast_shape(
         prefer_static.shape(self.loc if loc is None else loc),
         prefer_static.shape(self.concentration
                             if concentration is None else concentration))
Example #29
  def vectorized_fn(*args):
    """Vectorized version of `fn` that accepts arguments of any rank."""
    with tf.name_scope(name or 'make_rank_polymorphic'):
      assertions = []

      # If we got a single value for core_ndims, tile it across all args.
      core_ndims_structure = (
          core_ndims
          if tf.nest.is_nested(core_ndims)
          else tf.nest.map_structure(lambda _: core_ndims, args))

      # Build flat lists of all argument parts and their corresponding core
      # ndims.
      flat_core_ndims = tf.nest.flatten(core_ndims_structure)
      flat_args = nest.flatten_up_to(
          core_ndims_structure, args, check_types=False)

      # Filter to only the `Tensor`-valued args (taken to be those with `None`
      # values for `core_ndims`). Other args will be passed through to `fn`
      # unmodified.
      (vectorized_arg_core_ndims,
       vectorized_args,
       fn_of_vectorized_args) = _lock_in_non_vectorized_args(
           fn,
           arg_structure=core_ndims_structure,
           flat_core_ndims=flat_core_ndims,
           flat_args=flat_args)

      # `vectorized_map` requires all inputs to have a single, common batch
      # dimension `[n]`. So we broadcast all input parts to a common
      # batch shape, then flatten it down to a single dimension.

      # First, compute how many 'extra' (batch) ndims each part has. This must
      # be nonnegative.
      vectorized_arg_shapes = [ps.shape(arg) for arg in vectorized_args]
      batch_ndims = [
          ps.rank_from_shape(arg_shape) - nd
          for (arg_shape, nd) in zip(
              vectorized_arg_shapes, vectorized_arg_core_ndims)]
      static_ndims = [tf.get_static_value(nd) for nd in batch_ndims]
      if any([nd and nd < 0 for nd in static_ndims]):
        raise ValueError('Cannot broadcast a Tensor having lower rank than the '
                         'specified `core_ndims`! (saw input ranks {}, '
                         '`core_ndims` {}).'.format(
                             tf.nest.map_structure(
                                 ps.rank_from_shape,
                                 vectorized_arg_shapes),
                             vectorized_arg_core_ndims))
      if validate_args:
        for nd, part, core_nd in zip(
            batch_ndims, vectorized_args, vectorized_arg_core_ndims):
          assertions.append(tf.debugging.assert_non_negative(
              nd, message='Cannot broadcast a Tensor having lower rank than '
              'the specified `core_ndims`! (saw {} vs minimum rank {}).'.format(
                  part, core_nd)))

      # Next, split each part's shape into batch and core shapes, and
      # broadcast the batch shapes.
      with tf.control_dependencies(assertions):
        empty_shape = np.zeros([0], dtype=np.int32)
        batch_shapes, core_shapes = empty_shape, empty_shape
        if vectorized_arg_shapes:
          batch_shapes, core_shapes = zip(*[
              (arg_shape[:nd], arg_shape[nd:])
              for (arg_shape, nd) in zip(vectorized_arg_shapes, batch_ndims)])
        broadcast_batch_shape = (
            functools.reduce(ps.broadcast_shape, batch_shapes, []))

      # Flatten all of the batch dimensions into one.
      n = tf.cast(ps.reduce_prod(broadcast_batch_shape), tf.int32)
      static_n = tf.get_static_value(n)
      if static_n == 1:
        # We can bypass `vectorized_map` if the batch shape is `[]`, `[1]`,
        # `[1, 1]`, etc., just by flattening to batch shape `[]`.
        result_batch_dims = 0
        batched_result = fn_of_vectorized_args(
            tf.nest.map_structure(
                lambda x, nd: tf.reshape(x, ps.shape(x)[ps.rank(x) - nd:]),
                vectorized_args,
                vectorized_arg_core_ndims))
      else:
        # Pad all input parts to the common shape, then flatten
        # into the single leading dimension `[n]`.
        # TODO(b/145227909): If/when vmap supports broadcasting, use nested vmap
        # when batch rank is static so that we can exploit broadcasting.
        broadcast_vectorized_args = [
            tf.broadcast_to(part, ps.concat(
                [broadcast_batch_shape, core_shape], axis=0))
            for (part, core_shape) in zip(vectorized_args, core_shapes)]
        vectorized_args_with_flattened_batch_dim = [
            tf.reshape(part, ps.concat([[n], core_shape], axis=0))
            for (part, core_shape) in zip(
                broadcast_vectorized_args, core_shapes)]
        result_batch_dims = 1
        batched_result = tf.vectorized_map(
            fn_of_vectorized_args, vectorized_args_with_flattened_batch_dim)

      # Unflatten any `Tensor`s in the result.
      unflatten = lambda x: tf.reshape(x, ps.concat([  # pylint: disable=g-long-lambda
          broadcast_batch_shape, ps.shape(x)[result_batch_dims:]], axis=0))
      result = tf.nest.map_structure(
          lambda x: unflatten(x) if tf.is_tensor(x) else x, batched_result,
          expand_composites=True)
    return result
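A stripped-down sketch (hypothetical function and variable names, not the library helper) of the broadcast-flatten-`vectorized_map`-unflatten pattern implemented above, for two arguments that each have a rank-1 core:

import tensorflow as tf

def rank1_dot(parts):
  a, b = parts
  return tf.tensordot(a, b, axes=1)  # core_ndims == (1, 1)

a = tf.random.normal([2, 1, 3])  # batch [2, 1], core [3]
b = tf.random.normal([4, 3])     # batch [4],    core [3]
# Broadcast the batch dims, flatten them to one leading axis for vmap.
batch_shape = tf.broadcast_static_shape(a.shape[:-1], b.shape[:-1]).as_list()
a_flat = tf.reshape(tf.broadcast_to(a, batch_shape + [3]), [-1, 3])
b_flat = tf.reshape(tf.broadcast_to(b, batch_shape + [3]), [-1, 3])
flat_result = tf.vectorized_map(rank1_dot, (a_flat, b_flat))  # shape [8]
# Restore the broadcast batch shape on the output.
result = tf.reshape(flat_result, batch_shape)                 # shape [2, 4]
print(result.shape)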
Example #30
def _resample_weights(design_matrix,
                      target_residuals,
                      observation_noise_scale,
                      weights_prior_scale,
                      is_missing=None,
                      seed=None):
    """Samples regression weights from their conditional posterior.

  This assumes a conjugate normal regression model,

  ```
  weights ~ Normal(loc=0., covariance_matrix=weights_prior_scale**2 * I)
  target_residuals ~ Normal(loc=matvec(design_matrix, weights),
                            covariance_matrix=observation_noise_scale**2 * I)
  ```

  and returns a sample from `p(weights | target_residuals,
    observation_noise_scale, design_matrix)`.

  Args:
    design_matrix: Float `Tensor` design matrix of shape
      `[..., num_timesteps, num_features]`.
    target_residuals: Float `Tensor` of shape `[..., num_observations]`.
    observation_noise_scale: Scalar float `Tensor` (with optional batch shape)
      standard deviation of the iid observation noise.
    weights_prior_scale: Scalar float `Tensor` (with optional batch shape)
      specifying the standard deviation of the Normal prior on regression
      weights.
    is_missing: Optional `bool` `Tensor` of shape `[..., num_timesteps]`. A
      `True` value indicates that the observation for that timestep is missing.
    seed: Optional `Python` `int` seed controlling the sampled values.
  Returns:
    weights: Float `Tensor` of shape `[..., num_features]`, sampled from
      the conditional posterior `p(weights | target_residuals,
      observation_noise_scale, weights_prior_scale)`.
  """
    if is_missing is not None:
        # Replace design matrix with zeros at unobserved timesteps. This ensures
        # they will not affect the posterior on weights.
        design_matrix = tf.where(is_missing[..., tf.newaxis],
                                 tf.zeros_like(design_matrix), design_matrix)
    design_shape = prefer_static.shape(design_matrix)
    num_outputs = design_shape[-2]
    num_features = design_shape[-1]

    iid_prior_scale = tf.linalg.LinearOperatorScaledIdentity(
        num_rows=num_features, multiplier=weights_prior_scale)
    iid_likelihood_scale = tf.linalg.LinearOperatorScaledIdentity(
        num_rows=num_outputs, multiplier=observation_noise_scale)
    weights_mean, weights_prec = (
        normal_conjugate_posteriors.mvn_conjugate_linear_update(
            linear_transformation=design_matrix,
            observation=target_residuals,
            prior_scale=iid_prior_scale,
            likelihood_scale=iid_likelihood_scale))
    sampled_weights = weights_prec.cholesky().solvevec(samplers.normal(
        shape=prefer_static.shape(weights_mean),
        dtype=design_matrix.dtype,
        seed=seed),
                                                       adjoint=True)
    return weights_mean + sampled_weights
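For reference, a plain-NumPy check (illustrative, not library code) of the conjugate update this sampler relies on: with prior `weights ~ N(0, tau**2 I)` and likelihood `y ~ N(X w, sigma**2 I)`, the posterior precision is `X^T X / sigma**2 + I / tau**2` and the posterior mean is that precision solved against `X^T y / sigma**2`.

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))               # design matrix
w_true = np.array([1., -2., 0.5])
sigma, tau = 0.1, 10.                      # observation-noise / prior scales
y = X @ w_true + sigma * rng.normal(size=50)

precision = X.T @ X / sigma**2 + np.eye(3) / tau**2
posterior_mean = np.linalg.solve(precision, X.T @ y / sigma**2)
print(posterior_mean)  # close to w_true, since the noise scale is small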