def chees_criterion(previous_state,
                    proposed_state,
                    accept_prob,
                    validate_args=False):
  """The ChEES criterion from [1].

  ChEES stands for Change in the Estimator of the Expected Square.

  ```none
  ChEES = 1/4 E[(||x' - E[x]||**2 - ||x - E[x]||**2)**2],
  ```

  where `x` is the previous chain state, `x'` is the next chain state, and
  `||.||` is the L2 norm. Both expectations are with respect to the chain's
  stationary distribution. In practice, the inner expectation is replaced by the
  empirical mean across chains, so computing this criterion requires that at
  least 2 chains are present. The outer expectation is computed by the caller
  (e.g. in the `GradientBasedTrajectoryLengthAdaptation` kernel).

  This can be thought of as the standard expected squared jump distance (ESJD)
  criterion, except that the jump distance is computed in the space of centered
  squared L2 norms.

  Unlike ChEES, regular ESJD is maximized by perfectly anticorrelated proposals,
  which can give excellent mean estimates but terrible variance estimates;
  maximizing ChEES should give good estimates across a wider range of types of
  posterior expectations.

  Args:
    previous_state: (Possibly nested) floating point `Tensor`. The previous
      state of the HMC chain.
    proposed_state: (Possibly nested) floating point `Tensor`. The proposed
      state of the HMC chain.
    accept_prob: Floating `Tensor`. Probability of accepting the proposed
      state.
    validate_args: Whether to perform non-static argument validation.

  Returns:
    chees: The value of the ChEES criterion.

  Raises:
    ValueError: If `accept_prob` indicates that there are fewer than 2 chains.
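
  #### Examples

  A minimal sketch of evaluating the criterion on toy values (the particular
  numbers are illustrative only):

  ```python
  import tensorflow as tf

  # Four chains with scalar state.
  previous_state = tf.constant([-1., 0., 1., 2.])
  proposed_state = tf.constant([0.5, -0.5, 1.5, 0.])
  accept_prob = tf.constant([0.9, 1.0, 0.7, 0.2])
  per_chain_chees = chees_criterion(previous_state, proposed_state, accept_prob)
  # The caller (e.g. the `GradientBasedTrajectoryLengthAdaptation` kernel)
  # averages these per-chain values to form the outer expectation.
  ```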

  #### References

  [1]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme
       for Setting Trajectory Lengths in Hamiltonian Monte Carlo. In
       preparation.

  """
  batch_ndims = ps.rank(accept_prob)
  batch_axes = ps.range(batch_ndims, dtype=tf.int32)
  num_chains = ps.size(accept_prob)
  num_chains_ = tf.get_static_value(num_chains)
  if num_chains_ is not None:
    if num_chains_ < 2:
      raise ValueError(
          'chees_criterion requires at least 2 chains. Got: {}'.format(
              num_chains_))
  elif validate_args:
    with tf.control_dependencies([
        assert_util.assert_greater_equal(
            num_chains, 2, 'chees_criterion requires at least 2 chains.')
    ]):
      previous_state = tf.nest.map_structure(tf.identity, previous_state)

  def _center_previous_state(x):
    # The empirical mean here is a stand-in for the true mean, so we drop the
    # gradient that flows through this term.
    return x - tf.stop_gradient(tf.reduce_mean(x, axis=batch_axes))

  def _center_proposed_state(x):
    # The empirical mean here is a stand-in for the true mean, so we drop the
    # gradient that flows through this term. The goal here is to get a reliable
    # diagnostic of the underlying dynamics, rather than incorporating the
    # effect of the Metropolis-Hastings correction.
    # TODO(mhoffman): Needs more experimentation.
    expanded_accept_prob = mcmc_util.left_justified_expand_dims_like(
        accept_prob, x)

    # accept_prob is zero when x is NaN, but we still want to sanitize such
    # values.
    x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x))
    # If all accept_prob's are zero, the x_center will have a nonsense value,
    # but we'll discard the resultant gradients later on, so it's fine.
    x_center = (
        tf.reduce_sum(expanded_accept_prob * x_safe, axis=batch_axes) /
        (tf.reduce_sum(expanded_accept_prob, axis=batch_axes) + 1e-20))

    return x - tf.stop_gradient(x_center)

  def _sum_event_part(x):
    event_axes = ps.range(batch_ndims, ps.rank(x))
    return tf.reduce_sum(x, axis=event_axes)

  def _sum_event(x):
    return sum(tf.nest.flatten(tf.nest.map_structure(
        _sum_event_part,
        x,
    )))

  def _square(x):
    return tf.nest.map_structure(tf.square, x)

  def _sub(x, y):
    return tf.nest.map_structure(lambda x, y: x - y, x, y)

  previous_state = tf.nest.map_structure(_center_previous_state, previous_state)
  proposed_state = tf.nest.map_structure(_center_proposed_state, proposed_state)
  chees = 0.25 * tf.square(
      _sum_event(_sub(_square(proposed_state), _square(previous_state))))
  return chees
    def posterior_marginals(self, observations, mask=None, name=None):
        """Compute marginal posterior distribution for each state.

    This function computes, for each time step, the marginal
    conditional probability that the hidden Markov model was in
    each possible state given the observations that were made
    at each time step.
    So if the hidden states are `z[0],...,z[num_steps - 1]` and
    the observations are `x[0], ..., x[num_steps - 1]`, then
    this function computes `P(z[i] | x[0], ..., x[num_steps - 1])`
    for all `i` from `0` to `num_steps - 1`.

    This operation is sometimes called smoothing. It uses a form
    of the forward-backward algorithm.

    Note: the behavior of this function is undefined if the
    `observations` argument represents impossible observations
    from the model.

    Args:
      observations: A tensor representing a batch of observations
        made on the hidden Markov model.  The rightmost dimension of this tensor
        gives the steps in a sequence of observations from a single sample from
        the hidden Markov model. The size of this dimension should match the
        `num_steps` parameter of the hidden Markov model object. The other
        dimensions are the dimensions of the batch and these are broadcast with
        the hidden Markov model's parameters.
      mask: optional bool-type `Tensor` with rightmost dimension matching
        `num_steps` indicating which observations the result of this
        function should be conditioned on. When the mask has value
        `True` the corresponding observations aren't used.
        If `mask` is `None` then all of the observations are used.
        The `mask` dimensions left of the last are broadcast with the
        HMM batch as well as with the observations.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Returns:
      posterior_marginal: A `Categorical` distribution object representing the
        marginal probability of the hidden Markov model being in each state at
        each step. The rightmost dimension of the `Categorical` distributions
        batch will equal the `num_steps` parameter providing one marginal
        distribution for each step. The other dimensions are the dimensions
        corresponding to the batch of observations.

    Raises:
      ValueError: if rightmost dimension of `observations` does not
      have size `num_steps`.
    """

        with tf.name_scope(name or "posterior_marginals"):
            with tf.control_dependencies(self._runtime_assertions):
                observation_tensor_shape = tf.shape(observations)
                mask_tensor_shape = tf.shape(
                    mask) if mask is not None else None

                with self._observation_mask_shape_preconditions(
                        observation_tensor_shape, mask_tensor_shape):
                    observation_log_probs = self._observation_log_probs(
                        observations, mask)
                    log_prob = self._log_init + observation_log_probs[0]
                    log_transition = self._log_trans
                    log_adjoint_prob = tf.zeros_like(log_prob)

                    def _scan_multiple_steps_forwards():
                        def forward_step(log_previous_step,
                                         log_prob_observation):
                            return _log_vector_matrix(
                                log_previous_step,
                                log_transition) + log_prob_observation

                        forward_log_probs = tf.scan(forward_step,
                                                    observation_log_probs[1:],
                                                    initializer=log_prob,
                                                    name="forward_log_probs")
                        return tf.concat([[log_prob], forward_log_probs],
                                         axis=0)

                    forward_log_probs = prefer_static.cond(
                        self._num_steps > 1, _scan_multiple_steps_forwards,
                        lambda: tf.convert_to_tensor([log_prob]))

                    total_log_prob = tf.reduce_logsumexp(forward_log_probs[-1],
                                                         axis=-1)

                    def _scan_multiple_steps_backwards():
                        """Perform `scan` operation when `num_steps` > 1."""
                        def backward_step(log_previous_step,
                                          log_prob_observation):
                            return _log_matrix_vector(
                                log_transition,
                                log_prob_observation + log_previous_step)

                        backward_log_adjoint_probs = tf.scan(
                            backward_step,
                            observation_log_probs[1:],
                            initializer=log_adjoint_prob,
                            reverse=True,
                            name="backward_log_adjoint_probs")

                        return tf.concat(
                            [backward_log_adjoint_probs, [log_adjoint_prob]],
                            axis=0)

                    backward_log_adjoint_probs = prefer_static.cond(
                        self._num_steps > 1, _scan_multiple_steps_backwards,
                        lambda: tf.convert_to_tensor([log_adjoint_prob]))

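                    # The forward and backward messages combine additively in
                    # log space; normalizing by the total log-probability of
                    # the observations yields the posterior marginals below.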
                    log_likelihoods = forward_log_probs + backward_log_adjoint_probs

                    marginal_log_probs = distribution_util.move_dimension(
                        log_likelihoods - total_log_prob[..., tf.newaxis], 0,
                        -2)

                    return categorical.Categorical(logits=marginal_log_probs)
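
# A brief usage sketch of the smoothing computation above (parameter values are
# illustrative; assumes `tensorflow` imported as `tf` and `tfp.distributions`
# imported as `tfd`):
hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7)
observations = tf.constant([0., 1., 2., 20., 25., 30., 1.])
# P(z[i] | x[0], ..., x[6]) for each step i, as a Categorical distribution.
marginals = hmm.posterior_marginals(observations)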
Example 3
    def op(x, kernel):
        input_dtype = dtype_util.common_dtype([x, kernel],
                                              dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
        kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

        batch_shape, event_shape = ps.split(ps.shape(x),
                                            num_or_size_splits=[-1, 3])
        xh, xw, c_in = ps.unstack(event_shape, num=3)
        fh, fw = filter_shape

        assertions = _maybe_validate_input_shapes(ps.shape(kernel),
                                                  channels_in=c_in,
                                                  filter_height=fh,
                                                  filter_width=fw,
                                                  validate_args=validate_args)

        with tf.control_dependencies(assertions):
            if tf.get_static_value(ps.rank(kernel)) == 2:
                flat_x = tf.reshape(x,
                                    shape=ps.concat([[-1], event_shape],
                                                    axis=0))
                flat_y = tf.nn.conv2d(x,
                                      filters=tf.reshape(
                                          kernel, shape=[fh, fw, c_in, -1]),
                                      strides=strides,
                                      padding=padding,
                                      data_format='NHWC',
                                      dilations=dilations)
                output_shape = ps.shape(flat_y)[-3:]
                return tf.reshape(flat_y,
                                  shape=ps.concat([batch_shape, output_shape],
                                                  axis=0))

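            # General path, taken when the kernel is not statically known to be
            # rank-2 (e.g. it carries batch dimensions): gather image patches
            # via an im2row index and express the convolution as a batched
            # matmul against the kernel.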
            pad_values = [
                _get_conv_padding(xdim,
                                  filter_dim=k,
                                  stride=s,
                                  dilation=d,
                                  padding=padding)
                for (xdim, k, s,
                     d) in zip((xh, xw), filter_shape, strides, dilations)
            ]

            idx, shape = im2row_index(
                (xh + sum(pad_values[0]), xw + sum(pad_values[1]), c_in),
                block_shape=filter_shape,
                slice_step=strides,
                dilations=dilations,
                dtype=dtype)

            if padding == 'SAME':
                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(pad_values,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)
                x = tf.pad(x, paddings=paddings, constant_values=0)

            flat_shape = ps.pad(batch_shape,
                                paddings=[[0, 1]],
                                constant_values=-1)
            flat_x = tf.gather(tf.reshape(x, shape=flat_shape),
                               indices=idx,
                               axis=-1)
            im_x = tf.reshape(flat_x,
                              shape=ps.concat([batch_shape, shape], axis=0))
            return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
Example 4
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
    """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```

  """
    with tf.name_scope(name or 'lu_reconstruct'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

        assertions = _lu_reconstruct_assertions(lower_upper, perm,
                                                validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)

        shape = tf.shape(lower_upper)

        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(shape[:-1], dtype=lower_upper.dtype))
        upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
        x = tf.matmul(lower, upper)

        if lower_upper.shape.ndims is None or lower_upper.shape.ndims != 2:
            # We either don't know the batch rank or there are >0 batch dims.
            batch_size = tf.reduce_prod(shape[:-2])
            d = shape[-1]
            x = tf.reshape(x, [batch_size, d, d])
            perm = tf.reshape(perm, [batch_size, d])
            perm = tf.map_fn(tf.math.invert_permutation, perm)
            batch_indices = tf.broadcast_to(
                tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
            x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
            x = tf.reshape(x, shape)
        else:
            x = tf.gather(x, tf.math.invert_permutation(perm))

        x.set_shape(lower_upper.shape)
        return x
 def _event_shape_tensor(self):
     with tf.control_dependencies(self._runtime_assertions):
         return tf.concat(
             [[self._num_steps],
              self.observation_distribution.event_shape_tensor()],
             axis=0)
Example 6
def fit_with_hmc(model,
                 observed_time_series,
                 num_results=100,
                 num_warmup_steps=50,
                 num_leapfrog_steps=15,
                 initial_state=None,
                 initial_step_size=None,
                 chain_batch_shape=(),
                 num_variational_steps=150,
                 variational_optimizer=None,
                 variational_sample_size=5,
                 seed=None,
                 name=None):
    """Draw posterior samples using Hamiltonian Monte Carlo (HMC).

  Markov chain Monte Carlo (MCMC) methods are considered the gold standard of
  Bayesian inference; under suitable conditions and in the limit of infinitely
  many draws they generate samples from the true posterior distribution. HMC [1]
  uses gradients of the model's log-density function to propose samples,
  allowing it to exploit posterior geometry. However, it is computationally more
  expensive than variational inference and relatively sensitive to tuning.

  This method attempts to provide a sensible default approach for fitting
  StructuralTimeSeries models using HMC. It first runs variational inference as
  a fast posterior approximation, and initializes the HMC sampler from the
  variational posterior, using the posterior standard deviations to set
  per-variable step sizes (equivalently, a diagonal mass matrix). During the
  warmup phase, it adapts the step size to target an acceptance rate of 0.75,
  which is thought to be in the desirable range for optimal mixing [2].


  Args:
    model: An instance of `StructuralTimeSeries` representing a
      time-series model. This represents a joint distribution over
      time-series and their parameters with batch shape `[b1, ..., bN]`.
    observed_time_series: `float` `Tensor` of shape
      `concat([sample_shape, model.batch_shape, [num_timesteps, 1]])` where
      `sample_shape` corresponds to i.i.d. observations, and the trailing `[1]`
      dimension may (optionally) be omitted if `num_timesteps > 1`. Any `NaN`s
      are interpreted as missing observations; missingness may also be
      explicitly specified by passing a `tfp.sts.MaskedTimeSeries` instance.
    num_results: Integer number of Markov chain draws.
      Default value: `100`.
    num_warmup_steps: Integer number of steps to take before starting to
      collect results. The warmup steps are also used to adapt the step size
      towards a target acceptance rate of 0.75.
      Default value: `50`.
    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
      for. Total progress per HMC step is roughly proportional to
      `step_size * num_leapfrog_steps`.
      Default value: `15`.
    initial_state: Optional Python `list` of `Tensor`s, one for each model
      parameter, representing the initial state(s) of the Markov chain(s). These
      should have shape `concat([chain_batch_shape, param.prior.batch_shape,
      param.prior.event_shape])`. If `None`, the initial state is set
      automatically using a sample from a variational posterior.
      Default value: `None`.
    initial_step_size: Python `list` of `Tensor`s, one for each model parameter,
      representing the step size for the leapfrog integrator. Must
      broadcast with the shape of `initial_state`. Larger step sizes lead to
      faster progress, but too-large step sizes make rejection exponentially
      more likely. If `None`, the step size is set automatically using the
      standard deviation of a variational posterior.
      Default value: `None`.
    chain_batch_shape: Batch shape (Python `tuple`, `list`, or `int`) of chains
      to run in parallel.
      Default value: `[]` (i.e., a single chain).
    num_variational_steps: Python `int` number of steps to run the variational
      optimization to determine the initial state and step sizes.
      Default value: `150`.
    variational_optimizer: Optional `tf.train.Optimizer` instance to use in
      the variational optimization. If `None`, defaults to
      `tf.train.AdamOptimizer(0.1)`.
      Default value: `None`.
    variational_sample_size: Python `int` number of Monte Carlo samples to use
      in estimating the variational divergence. Larger values may stabilize
      the optimization, but at higher cost per step in time and memory.
      Default value: `5`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'fit_with_hmc').

  Returns:
    samples: Python `list` of `Tensors` representing posterior samples of model
      parameters, with shapes `[concat([[num_results], chain_batch_shape,
      param.prior.batch_shape, param.prior.event_shape]) for param in
      model.parameters]`.
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within the HMC sampler.

  #### Examples

  Assume we've built a structural time-series model:

  ```python
    day_of_week = tfp.sts.Seasonal(
        num_seasons=7,
        observed_time_series=observed_time_series,
        name='day_of_week')
    local_linear_trend = tfp.sts.LocalLinearTrend(
        observed_time_series=observed_time_series,
        name='local_linear_trend')
    model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                        observed_time_series=observed_time_series)
  ```

  To draw posterior samples using HMC under default settings:

  ```python
  samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series)
  print("acceptance rate: {}".format(
    np.mean(kernel_results.inner_results.inner_results.is_accepted, axis=0)))
  print("posterior means: {}".format(
    {param.name: np.mean(param_draws, axis=0)
     for (param, param_draws) in zip(model.parameters, samples)}))
  ```

  We can also run multiple chains. This may help diagnose convergence issues
  and allows us to exploit vectorization to draw samples more quickly, although
  warmup still requires the same number of sequential steps.

  ```python
  from matplotlib import pylab as plt

  samples, kernel_results = tfp.sts.fit_with_hmc(
    model, observed_time_series, chain_batch_shape=[10])
  print("acceptance rate: {}".format(
    np.mean(kernel_results.inner_results.inner_results.is_accepted, axis=0)))

  # Plot the sampled traces for each parameter. If the chains have mixed, their
  # traces should all cover the same region of state space, frequently crossing
  # over each other.
  for (param, param_draws) in zip(model.parameters, samples):
    if param.prior.event_shape.ndims > 0:
      print("Only plotting traces for scalar parameters, skipping {}".format(
        param.name))
      continue
    plt.figure(figsize=[10, 4])
    plt.title(param.name)
    plt.plot(param_draws.numpy())
    plt.ylabel(param.name)
    plt.xlabel("HMC step")

  # Combining the samples from multiple chains into a single dimension allows
  # us to easily pass sampled parameters to downstream forecasting methods.
  combined_samples = [np.reshape(param_draws,
                                 [-1] + list(param_draws.shape[2:]))
                      for param_draws in samples]
  ```

  For greater flexibility, you may prefer to implement your own sampler using
  the TensorFlow Probability primitives in `tfp.mcmc`. The following recipe
  constructs a basic HMC sampler, using a `TransformedTransitionKernel` to
  incorporate constraints on the parameter space.

  ```python
  transformed_hmc_kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(
          inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
              target_log_prob_fn=model.joint_distribution(observed_time_series).log_prob,
              step_size=step_size,
              num_leapfrog_steps=num_leapfrog_steps,
              state_gradients_are_stopped=True,
              seed=seed),
          num_adaptation_steps = int(0.8 * num_warmup_steps)),
      bijector=[param.bijector for param in model.parameters])

  # Initialize from a Uniform[-2, 2] distribution in unconstrained space.
  initial_state = [tfp.sts.sample_uniform_initial_state(
    param, return_constrained=True) for param in model.parameters]

  samples, kernel_results = tfp.mcmc.sample_chain(
    kernel=transformed_hmc_kernel,
    num_results=num_results,
    current_state=initial_state,
    num_burnin_steps=num_warmup_steps)
  ```

  #### References

  [1]: Radford Neal. MCMC Using Hamiltonian Dynamics. _Handbook of Markov Chain
       Monte Carlo_, 2011. https://arxiv.org/abs/1206.1901
  [2]: M.J. Betancourt, Simon Byrne, and Mark Girolami. Optimizing The
       Integrator Step Size for Hamiltonian Monte Carlo.
       https://arxiv.org/abs/1411.6669

  """
    with tf.name_scope(name or 'fit_with_hmc') as name:
        seed = tfp_util.SeedStream(seed,
                                   salt='StructuralTimeSeries_fit_with_hmc')

        observed_time_series = sts_util.pad_batch_dimension_for_multiple_chains(
            observed_time_series, model, chain_batch_shape=chain_batch_shape)
        target_log_prob_fn = model.joint_distribution(
            observed_time_series).log_prob

        # Initialize state and step sizes from a variational posterior if not
        # specified.
        if initial_step_size is None or initial_state is None:
            variational_posterior = build_factored_surrogate_posterior(
                model, batch_shape=chain_batch_shape, seed=seed())

            if variational_optimizer is None:
                variational_optimizer = tf1.train.AdamOptimizer(
                    learning_rate=0.1
                )  # TODO(b/137299119) Replace with TF2 optimizer.
            loss_curve = vi.fit_surrogate_posterior(
                target_log_prob_fn,
                variational_posterior,
                sample_size=variational_sample_size,
                num_steps=num_variational_steps,
                optimizer=variational_optimizer,
                seed=seed())

            with tf.control_dependencies([loss_curve]):
                if initial_state is None:
                    posterior_sample = variational_posterior.sample()
                    initial_state = [
                        posterior_sample[p.name] for p in model.parameters
                    ]

                # Set step sizes using the unconstrained variational distribution.
                if initial_step_size is None:
                    q_dists_by_name, _ = (variational_posterior.distribution.
                                          sample_distributions())
                    initial_step_size = [
                        q_dists_by_name[p.name].stddev()
                        for p in model.parameters
                    ]

        # Run HMC to sample from the posterior on parameters.
        @tf.function(autograph=False)
        def run_hmc():
            return mcmc.sample_chain(
                num_results=num_results,
                current_state=initial_state,
                num_burnin_steps=num_warmup_steps,
                kernel=mcmc.DualAveragingStepSizeAdaptation(
                    inner_kernel=mcmc.TransformedTransitionKernel(
                        inner_kernel=mcmc.HamiltonianMonteCarlo(
                            target_log_prob_fn=target_log_prob_fn,
                            step_size=initial_step_size,
                            num_leapfrog_steps=num_leapfrog_steps,
                            state_gradients_are_stopped=True),
                        bijector=[
                            param.bijector for param in model.parameters
                        ]),
                    num_adaptation_steps=int(num_warmup_steps * 0.8)),
                seed=seed())

        samples, kernel_results = run_hmc()

        return samples, kernel_results
Example 7
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
    """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
      actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.linalg.inv(x), inv_x)
  # ==> True
  ```

  """

    with tf.name_scope(name or 'lu_solve'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
        rhs = tf.convert_to_tensor(rhs,
                                   dtype_hint=lower_upper.dtype,
                                   name='rhs')

        assertions = _lu_solve_assertions(lower_upper, perm, rhs,
                                          validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
                rhs = tf.identity(rhs)

        if rhs.shape.ndims == 2 and perm.shape.ndims == 1:
            # Both rhs and perm have scalar batch_shape.
            permuted_rhs = tf.gather(rhs, perm, axis=-2)
        else:
            # Either rhs or perm have non-scalar batch_shape or we can't determine
            # this information statically.
            rhs_shape = tf.shape(rhs)
            broadcast_batch_shape = tf.broadcast_dynamic_shape(
                rhs_shape[:-2],
                tf.shape(perm)[:-1])
            d, m = rhs_shape[-2], rhs_shape[-1]
            rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]],
                                            axis=0)

            # Tile out rhs.
            broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
            broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

            # Tile out perm and add batch indices.
            broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
            broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
            broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
            broadcast_batch_indices = tf.broadcast_to(
                tf.range(broadcast_batch_size)[:, tf.newaxis],
                [broadcast_batch_size, d])
            broadcast_perm = tf.stack(
                [broadcast_batch_indices, broadcast_perm], axis=-1)

            permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
            permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
        return linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower_upper,  # Only upper is accessed.
            linear_operator_util.matrix_triangular_solve_with_broadcast(
                lower, permuted_rhs),
            lower=False)
Example 8
    def __init__(self,
                 mean_direction,
                 concentration,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='VonMisesFisher'):
        """Creates a new `VonMisesFisher` instance.

    Args:
      mean_direction: Floating-point `Tensor` with shape [B1, ... Bn, D].
        A unit vector indicating the mode of the distribution, or the
        unit-normalized direction of the mean. (This is *not* in general the
        mean of the distribution; the mean is not generally in the support of
        the distribution.) NOTE: `D` is currently restricted to <= 5.
      concentration: Floating-point `Tensor` having batch shape [B1, ... Bn]
        broadcastable with `mean_direction`. The level of concentration of
        samples around the `mean_direction`. `concentration=0` indicates a
        uniform distribution over the unit hypersphere, and `concentration=+inf`
        indicates a `Deterministic` distribution (delta function) at
        `mean_direction`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: For known-bad arguments, i.e. unsupported event dimension.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([mean_direction, concentration],
                                            tf.float32)
            mean_direction = tf.convert_to_tensor(mean_direction,
                                                  name='mean_direction',
                                                  dtype=dtype)
            concentration = tf.convert_to_tensor(concentration,
                                                 name='concentration',
                                                 dtype=dtype)
            assertions = [
                assert_util.assert_non_negative(
                    concentration,
                    message='`concentration` must be non-negative'),
                assert_util.assert_greater(
                    tf.shape(mean_direction)[-1],
                    1,
                    message='`mean_direction` may not have scalar event shape'
                ),
                assert_util.assert_near(
                    1.,
                    tf.linalg.norm(mean_direction, axis=-1),
                    message='`mean_direction` must be unit-length')
            ] if validate_args else []
            static_event_dim = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(mean_direction.shape,
                                                    1)[-1])
            if static_event_dim is not None and static_event_dim > 5:
                raise ValueError('vMF ndims > 5 is not currently supported')
            elif validate_args:
                assertions += [
                    assert_util.assert_less_equal(
                        tf.shape(mean_direction)[-1],
                        5,
                        message='vMF ndims > 5 is not currently supported')
                ]
            with tf.control_dependencies(assertions):
                self._mean_direction = tf.identity(mean_direction)
                self._concentration = tf.identity(concentration)
            dtype_util.assert_same_float_dtype(
                [self._mean_direction, self._concentration])
            # mean_direction is always reparameterized.
            # concentration is only for event_dim==3, via an inversion sampler.
            reparameterization_type = (reparameterization.FULLY_REPARAMETERIZED
                                       if static_event_dim == 3 else
                                       reparameterization.NOT_REPARAMETERIZED)
            super(VonMisesFisher, self).__init__(
                dtype=self._concentration.dtype,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                reparameterization_type=reparameterization_type,
                parameters=parameters,
                graph_parents=[self._mean_direction, self._concentration],
                name=name)
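
# A brief usage sketch of the distribution defined above (values are
# illustrative; assumes `tfp.distributions` imported as `tfd`):
vmf = tfd.VonMisesFisher(mean_direction=[0., 1., 0.], concentration=2.)
samples = vmf.sample(5)            # Five unit vectors on S^2 near [0, 1, 0].
log_probs = vmf.log_prob(samples)  # Log-densities at the sampled points.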
Example 9
    def _sample_n(self, n, seed=None):
        seed = seed_stream.SeedStream(seed, salt='vom_mises_fisher')
        # The sampling strategy relies on the fact that vMF variates are symmetric
        # about the mean direction. Accordingly, if we have a sampling strategy for
        # the away-from-mean angle, then we can uniformly sample the remaining
        # dimensions on the S^{dim-2} sphere, and rotate these samples from a
        # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
        #
        # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
        # von-Mises distributed `x` value in [-1, 1], then uniformly select what
        # amounts to an "up" or "down" additional degree of freedom after unit
        # normalizing, followed by a final rotation to the desired mean direction
        # from a basis of (1, 0).
        #
        # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
        # unit sphere over which the distribution is uniform, in particular the
        # circle where x = \hat{x} intersects the unit sphere. We pick a point on
        # that circle, then rotate to the desired mean direction from a basis of
        # (1, 0, 0).
        event_dim = (tf.compat.dimension_value(self.event_shape[0])
                     or self._event_shape_tensor()[0])

        sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()],
                                       axis=0)
        dim = tf.cast(event_dim - 1, self.dtype)
        if event_dim == 3:
            samples_dim0 = self._sample_3d(n, seed=seed)
        else:
            # Wood'94 provides a rejection algorithm to sample the x coordinate.
            # Wood'94 definition of b:
            # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
            # https://stats.stackexchange.com/questions/156729 suggests:
            b = dim / (2 * self.concentration +
                       tf.sqrt(4 * self.concentration**2 + dim**2))
            # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
            #     https://github.com/nicola-decao/s-vae-tf/
            x = (1 - b) / (1 + b)
            c = self.concentration * x + dim * tf.math.log1p(-x**2)
            beta = beta_lib.Beta(dim / 2, dim / 2)

            def cond_fn(w, should_continue):
                del w
                return tf.reduce_any(should_continue)

            def body_fn(w, should_continue):
                z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
                w = tf1.where(should_continue,
                              (1 - (1 + b) * z) / (1 - (1 - b) * z), w)
                w = tf.debugging.check_numerics(w, 'w')
                should_continue = tf.logical_and(
                    should_continue,
                    self.concentration * w + dim * tf.math.log1p(-x * w) - c <
                    tf.math.log(
                        tf.random.uniform(sample_batch_shape,
                                          seed=seed(),
                                          dtype=self.dtype)))
                return w, should_continue

            w = tf.zeros(sample_batch_shape, dtype=self.dtype)
            should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
            samples_dim0 = tf.while_loop(cond=cond_fn,
                                         body=body_fn,
                                         loop_vars=(w, should_continue))[0]
            samples_dim0 = samples_dim0[..., tf.newaxis]
        if not self._allow_nan_stats:
            # Verify samples are w/in -1, 1, with useful error output tensors (top
            # value rather than all values).
            with tf.control_dependencies([
                    assert_util.assert_less_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(1.01),
                        data=[tf.nn.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
                    assert_util.assert_greater_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(-1.01),
                        data=[
                            -tf.nn.top_k(tf.reshape(-samples_dim0, [-1]))[0]
                        ])
            ]):
                samples_dim0 = tf.identity(samples_dim0)
        samples_otherdims_shape = tf.concat(
            [sample_batch_shape, [event_dim - 1]], axis=0)
        unit_otherdims = tf.nn.l2_normalize(tf.random.normal(
            samples_otherdims_shape, seed=seed(), dtype=self.dtype),
                                            axis=-1)
        samples = tf.concat(
            [
                samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
                tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
            ],
            axis=-1)
        samples = tf.nn.l2_normalize(samples, axis=-1)
        if not self._allow_nan_stats:
            samples = tf.debugging.check_numerics(samples, 'samples')

        # Runtime assert that samples are unit length.
        if not self._allow_nan_stats:
            worst, idx = tf.nn.top_k(
                tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
            with tf.control_dependencies([
                    assert_util.assert_near(
                        dtype_util.as_numpy_dtype(self.dtype)(0),
                        worst,
                        data=[
                            worst, idx,
                            tf.gather(tf.reshape(samples, [-1, event_dim]),
                                      idx)
                        ],
                        atol=1e-4,
                        summarize=100)
            ]):
                samples = tf.identity(samples)
        # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
        # Now, we move the mode to `self.mean_direction` using a rotation matrix.
        if not self._allow_nan_stats:
            # Assert that the basis vector rotates to the mean direction, as expected.
            basis = tf.cast(
                tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                self.dtype)
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.linalg.norm(self._rotate(basis) -
                                       self.mean_direction,
                                       axis=-1),
                        dtype_util.as_numpy_dtype(self.dtype)(1e-5))
            ]):
                return self._rotate(samples)
        return self._rotate(samples)
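
# A standalone NumPy sketch of the Wood (1994) rejection step used above to
# draw the coordinate along the mean direction (illustrative only, not the
# class's implementation; `kappa` is the concentration and `event_dim` the
# ambient dimension):
import numpy as np

def sample_vmf_x_coordinate(kappa, event_dim, rng):
  dim = event_dim - 1.
  b = dim / (2. * kappa + np.sqrt(4. * kappa**2 + dim**2))
  x0 = (1. - b) / (1. + b)
  c = kappa * x0 + dim * np.log1p(-x0**2)
  while True:
    z = rng.beta(dim / 2., dim / 2.)
    w = (1. - (1. + b) * z) / (1. - (1. - b) * z)
    # Accept when the log acceptance ratio exceeds a uniform log threshold.
    if kappa * w + dim * np.log1p(-x0 * w) - c >= np.log(rng.uniform()):
      return w  # Remaining coordinates are uniform on the orthogonal sphere.

w = sample_vmf_x_coordinate(kappa=2., event_dim=3, rng=np.random.default_rng(0))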
Example 10
    def push(self, value, mask, name=None):
        """Pushes `value` onto the stack, advances frame of batch members in `mask`.

    In this impl, we update each thread's top-of-stack (regardless of `mask`) to
    the corresponding `value`, then advance the stack pointers of only those
    threads indicated by `mask`.

    Args:
      value: `Tensor` having the shape of a single batch of the variable.
      mask: Boolean `Tensor` of shape `[batch_size]`. Threads at `True` indices
          of `mask` have their stack frames advanced; the others remain.
      name: Optional name for this op.

    Returns:
      stack: Updated stack. Does not mutate `self`.
      asserted_value: An assertion-bound snapshot of the input `value`; the
          assertions are used to catch stack overflows.
    """
        with tf.name_scope(name or 'Stack.push'):
            value = tf.convert_to_tensor(value=value, name='value')
            mask = tf.convert_to_tensor(value=mask, name='mask')
            # self.stack:       [max_stack_depth * batch_size, ...]
            # self.stack_index:                   [batch_size]
            # value:                              [batch_size, ...]
            batch_size = (tf.compat.dimension_value(self.stack_index.shape[0])
                          or tf.shape(input=self.stack_index)[0])
            max_stack_depth = (tf.compat.dimension_value(self.stack.shape[0])
                               or tf.shape(input=self.stack)[0]) // batch_size
            max_stack_depth_tensor = tf.convert_to_tensor(
                value=max_stack_depth)
            tiled_value = tf.tile(
                input=value[tf.newaxis, ...],
                multiples=tf.concat(
                    [[max_stack_depth_tensor],
                     tf.ones(tf.rank(value),
                             dtype=max_stack_depth_tensor.dtype)],
                    axis=0))
            update_stack_mask = tf.one_hot(
                self.stack_index,
                depth=max_stack_depth,
                # Stack depth x batch are both in outermost dim, stack major.
                axis=0,
                on_value=True,
                off_value=False,
                dtype=tf.bool)
            new_stack = tf1.where(
                tf.reshape(update_stack_mask, [-1]),
                tf.reshape(tiled_value, tf.shape(input=self.stack)),
                self.stack)
            new_stack.set_shape(self.stack.shape)
            new_stack_index = self.stack_index + tf.cast(
                mask, self.stack_index.dtype)
            new_stack_index.set_shape(self.stack_index.shape)
            if self._safety_checks():
                with tf.control_dependencies([
                        tf1.assert_less(
                            new_stack_index,
                            tf.cast(max_stack_depth_tensor,
                                    new_stack_index.dtype))
                ]):
                    value = tf.identity(value)
                    new_stack_index = tf.identity(new_stack_index)
            return type(self)(new_stack, new_stack_index), value
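
# A small NumPy illustration of the push semantics documented above: every
# thread's current top-of-stack slot is overwritten, but only threads selected
# by `mask` have their stack pointer advanced. (Kept 2-D here for clarity; the
# class itself stores a flattened `[max_stack_depth * batch_size, ...]` buffer.)
import numpy as np

stack = np.zeros([3, 4])                      # max_stack_depth=3, batch_size=4.
stack_index = np.array([0, 1, 0, 2])
value = np.array([10., 20., 30., 40.])
mask = np.array([True, False, True, False])

stack[stack_index, np.arange(4)] = value      # Overwrite the top for all threads.
stack_index = stack_index + mask.astype(int)  # Advance only the masked threads.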
Example 11
    def _solve(
        self,
        ode_fn,
        initial_time,
        initial_state,
        solution_times,
        jacobian_fn=None,
        jacobian_sparsity=None,
        batch_ndims=None,
        previous_solver_internal_state=None,
    ):
        # This function is comprised of the following sequential stages:
        # (1) Make static assertions.
        # (2) Initialize variables.
        # (3) Make non-static assertions.
        # (4) Solve up to final time.
        # (5) Return `Results` object.
        #
        # The stages can be found in the code by searching for (n) where n=1..5.
        #
        # By static vs. non-static assertions (see stages 1 and 3), we mean
        # assertions that can be made before the graph is run vs. those that can
        # only be made at run time. The latter are constructed as a list of
        # tf.Assert operations by the function `assert_ops` (see below).
        #
        # If `solution_times` is specified as a `Tensor`, stage 4 consists of three
        # nested loops, which can be conceptually understood as follows:
        # ```
        # current_time, current_state = initial_time, initial_state
        # order, step_size = 1, first_step_size
        # for solution_time in solution_times:
        #   while current_time < solution_time:
        #     while True:
        #       next_time = current_time + step_size
        #       next_state, error = (
        #           solve_nonlinear_equation_to_get_approximate_state_at_next_time(
        #           current_time, current_state, next_time, order))
        #       if error < tolerance:
        #         current_time, current_state = next_time, next_state
        #         order, step_size = (
        #           maybe_update_order_and_step_size(order, step_size))
        #         break
        #       else:
        #         step_size = decrease_step_size(step_size)
        # ```
        # The outermost loop advances the solver to the next `solution_time` (see
        # `advance_to_solution_time`). The middle loop advances the solver by a
        # small timestep (see `step`). The innermost loop determines the size of
        # that timestep (see `maybe_step`).
        #
        # If `solution_times` is specified as
        # `tfp.math.ode.ChosenBySolver(final_time)`, the outermost loop is skipped
        # and `solution_time` in the middle loop is replaced by `final_time`.
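        #
        # For reference, a typical call reaching this method through the public
        # API looks roughly like the following sketch (assuming the standard
        # `tfp.math.ode` layout):
        #
        #   solver = tfp.math.ode.BDF()
        #   results = solver.solve(ode_fn, initial_time=0., initial_state=y0,
        #                          solution_times=[0.5, 1.0])
        #   results.states  # Approximate states at the requested times.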

        def assert_ops():
            """Creates a list of assert operations."""
            if not self._validate_args:
                return []
            assert_ops = []
            if previous_solver_internal_state is not None:
                assert_initial_state_matches_previous_solver_internal_state = (
                    tf1.assert_near(
                        tf.norm(
                            initial_state_vec - previous_solver_internal_state.
                            backward_differences[0], np.inf),
                        0.,
                        message=
                        '`previous_solver_internal_state` does not match '
                        '`initial_state`.'))
                assert_ops.append(
                    assert_initial_state_matches_previous_solver_internal_state
                )
            if solution_times_chosen_by_solver:
                assert_ops.append(
                    util.assert_positive(final_time - initial_time,
                                         'final_time - initial_time'))
            else:
                assert_ops += [
                    util.assert_increasing(solution_times, 'solution_times'),
                    util.assert_nonnegative(
                        solution_times[0] - initial_time,
                        'solution_times[0] - initial_time'),
                ]
            if max_num_steps is not None:
                assert_ops.append(
                    util.assert_positive(max_num_steps, 'max_num_steps'))
            if max_num_newton_iters is not None:
                assert_ops.append(
                    util.assert_positive(max_num_newton_iters,
                                         'max_num_newton_iters'))
            assert_ops += [
                util.assert_positive(rtol, 'rtol'),
                util.assert_positive(atol, 'atol'),
                util.assert_positive(first_step_size, 'first_step_size'),
                util.assert_positive(safety_factor, 'safety_factor'),
                util.assert_positive(min_step_size_factor,
                                     'min_step_size_factor'),
                util.assert_positive(max_step_size_factor,
                                     'max_step_size_factor'),
                tf.Assert((max_order >= 1) & (max_order <= bdf_util.MAX_ORDER),
                          [
                              '`max_order` must be between 1 and {}.'.format(
                                  bdf_util.MAX_ORDER)
                          ]),
                util.assert_positive(newton_tol_factor, 'newton_tol_factor'),
                util.assert_positive(newton_step_size_factor,
                                     'newton_step_size_factor'),
            ]
            return assert_ops

        def advance_to_solution_time(n, diagnostics, iterand,
                                     solver_internal_state, state_vec_array,
                                     time_array):
            """Takes multiple steps to advance time to `solution_times[n]`."""
            def step_cond(next_time, diagnostics, iterand, *_):
                return (iterand.time < next_time) & (tf.equal(
                    diagnostics.status, 0))

            nth_solution_time = solution_time_array.read(n)
            [
                _, diagnostics, iterand, solver_internal_state,
                state_vec_array, time_array
            ] = tf.while_loop(step_cond, step, [
                nth_solution_time, diagnostics, iterand, solver_internal_state,
                state_vec_array, time_array
            ])
            state_vec_array = state_vec_array.write(
                n, solver_internal_state.backward_differences[0])
            time_array = time_array.write(n, nth_solution_time)
            return (n + 1, diagnostics, iterand, solver_internal_state,
                    state_vec_array, time_array)

        def step(next_time, diagnostics, iterand, solver_internal_state,
                 state_vec_array, time_array):
            """Takes a single step."""
            distance_to_next_time = next_time - iterand.time
            overstepped = iterand.new_step_size > distance_to_next_time
            iterand = iterand._replace(new_step_size=tf1.where(
                overstepped, distance_to_next_time, iterand.new_step_size),
                                       should_update_step_size=overstepped
                                       | iterand.should_update_step_size)

            if not self._evaluate_jacobian_lazily:
                diagnostics = diagnostics._replace(
                    num_jacobian_evaluations=diagnostics.
                    num_jacobian_evaluations + 1)
                iterand = iterand._replace(jacobian_mat=jacobian_fn_mat(
                    iterand.time,
                    solver_internal_state.backward_differences[0]),
                                           jacobian_is_up_to_date=True)

            def maybe_step_cond(accepted, diagnostics, *_):
                return tf.logical_not(accepted) & tf.equal(
                    diagnostics.status, 0)

            _, diagnostics, iterand, solver_internal_state = tf.while_loop(
                maybe_step_cond, maybe_step,
                [False, diagnostics, iterand, solver_internal_state])

            if solution_times_chosen_by_solver:
                state_vec_array = state_vec_array.write(
                    state_vec_array.size(),
                    solver_internal_state.backward_differences[0])
                time_array = time_array.write(time_array.size(), iterand.time)

            return (next_time, diagnostics, iterand, solver_internal_state,
                    state_vec_array, time_array)

        def maybe_step(accepted, diagnostics, iterand, solver_internal_state):
            """Takes a single step only if the outcome has a low enough error."""
            [
                num_jacobian_evaluations, num_matrix_factorizations,
                num_ode_fn_evaluations, status
            ] = diagnostics
            [
                jacobian_mat, jacobian_is_up_to_date, new_step_size, num_steps,
                num_steps_same_size, should_update_jacobian,
                should_update_step_size, time, unitary, upper
            ] = iterand
            [backward_differences, order, step_size] = solver_internal_state

            if max_num_steps is not None:
                status = tf1.where(tf.equal(num_steps, max_num_steps), -1, 0)

            backward_differences = tf1.where(
                should_update_step_size,
                bdf_util.interpolate_backward_differences(
                    backward_differences, order, new_step_size / step_size),
                backward_differences)
            step_size = tf1.where(should_update_step_size, new_step_size,
                                  step_size)
            should_update_factorization = should_update_step_size
            num_steps_same_size = tf1.where(should_update_step_size, 0,
                                            num_steps_same_size)

            def update_factorization():
                return bdf_util.newton_qr(
                    jacobian_mat, newton_coefficients_array.read(order),
                    step_size)

            if self._evaluate_jacobian_lazily:

                def update_jacobian_and_factorization():
                    new_jacobian_mat = jacobian_fn_mat(time,
                                                       backward_differences[0])
                    new_unitary, new_upper = update_factorization()
                    return [
                        new_jacobian_mat, True, num_jacobian_evaluations + 1,
                        new_unitary, new_upper
                    ]

                def maybe_update_factorization():
                    new_unitary, new_upper = tf.cond(
                        should_update_factorization, update_factorization,
                        lambda: [unitary, upper])
                    return [
                        jacobian_mat, jacobian_is_up_to_date,
                        num_jacobian_evaluations, new_unitary, new_upper
                    ]

                [
                    jacobian_mat, jacobian_is_up_to_date,
                    num_jacobian_evaluations, unitary, upper
                ] = tf.cond(should_update_jacobian,
                            update_jacobian_and_factorization,
                            maybe_update_factorization)
            else:
                unitary, upper = update_factorization()
                num_matrix_factorizations += 1

            tol = atol + rtol * tf.abs(backward_differences[0])
            newton_tol = newton_tol_factor * tf.norm(tol)

            [
                newton_converged, next_backward_difference, next_state_vec,
                newton_num_iters
            ] = bdf_util.newton(backward_differences, max_num_newton_iters,
                                newton_coefficients_array.read(order),
                                ode_fn_vec, order, step_size, time, newton_tol,
                                unitary, upper)
            num_steps += 1
            num_ode_fn_evaluations += newton_num_iters

            # If Newton's method failed and the Jacobian was up to date, decrease the
            # step size.
            newton_failed = tf.logical_not(newton_converged)
            should_update_step_size = newton_failed & jacobian_is_up_to_date
            new_step_size = step_size * tf1.where(should_update_step_size,
                                                  newton_step_size_factor, 1.)

            # If Newton's method failed and the Jacobian was NOT up to date, update
            # the Jacobian.
            should_update_jacobian = newton_failed & tf.logical_not(
                jacobian_is_up_to_date)

            error_ratio = tf1.where(
                newton_converged,
                bdf_util.error_ratio(next_backward_difference,
                                     error_coefficients_array.read(order),
                                     tol), np.nan)
            accepted = error_ratio < 1.
            converged_and_rejected = newton_converged & tf.logical_not(
                accepted)

            # If Newton's method converged but the solution was NOT accepted, decrease
            # the step size.
            new_step_size = tf1.where(
                converged_and_rejected,
                util.next_step_size(step_size, order, error_ratio,
                                    safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = should_update_step_size | converged_and_rejected

            # If Newton's method converged and the solution was accepted, update the
            # matrix of backward differences.
            time = tf1.where(accepted, time + step_size, time)
            backward_differences = tf1.where(
                accepted,
                bdf_util.update_backward_differences(backward_differences,
                                                     next_backward_difference,
                                                     next_state_vec, order),
                backward_differences)
            jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not(
                accepted)
            num_steps_same_size = tf1.where(accepted, num_steps_same_size + 1,
                                            num_steps_same_size)

            # Order and step size are only updated if we have taken strictly more than
            # order + 1 steps of the same size. This is to prevent the order from
            # being throttled.
            should_update_order_and_step_size = accepted & (num_steps_same_size
                                                            > order + 1)

            backward_differences_array = tf.TensorArray(
                backward_differences.dtype,
                size=bdf_util.MAX_ORDER + 3,
                clear_after_read=False,
                element_shape=next_backward_difference.get_shape()).unstack(
                    backward_differences)
            new_order = order
            new_error_ratio = error_ratio
            for offset in [-1, +1]:
                proposed_order = tf.clip_by_value(order + offset, 1, max_order)
                proposed_error_ratio = bdf_util.error_ratio(
                    backward_differences_array.read(proposed_order + 1),
                    error_coefficients_array.read(proposed_order), tol)
                proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio
                new_order = tf1.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_order, new_order)
                new_error_ratio = tf1.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_error_ratio,
                    new_error_ratio)
            order = new_order
            error_ratio = new_error_ratio

            new_step_size = tf1.where(
                should_update_order_and_step_size,
                util.next_step_size(step_size, order, error_ratio,
                                    safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = (should_update_step_size
                                       | should_update_order_and_step_size)

            diagnostics = _BDFDiagnostics(num_jacobian_evaluations,
                                          num_matrix_factorizations,
                                          num_ode_fn_evaluations, status)
            iterand = _BDFIterand(jacobian_mat, jacobian_is_up_to_date,
                                  new_step_size, num_steps,
                                  num_steps_same_size, should_update_jacobian,
                                  should_update_step_size, time, unitary,
                                  upper)
            solver_internal_state = _BDFSolverInternalState(
                backward_differences, order, step_size)
            return accepted, diagnostics, iterand, solver_internal_state

        # (1) Make static assertions.
        # TODO(b/138304296): Support specifying Jacobian sparsity patterns.
        if jacobian_sparsity is not None:
            raise NotImplementedError(
                'The BDF solver does not support specifying '
                'Jacobian sparsity patterns.')
        if batch_ndims is not None and batch_ndims != 0:
            raise NotImplementedError(
                'The BDF solver does not support batching.')
        solution_times_chosen_by_solver = (isinstance(solution_times,
                                                      base.ChosenBySolver))

        with tf.name_scope(self._name):

            # (2) Convert to tensors.
            error_if_wrong_dtype = functools.partial(
                util.error_if_not_real_or_complex, identifier='initial_state')

            initial_state = tf.nest.map_structure(tf.convert_to_tensor,
                                                  initial_state)
            tf.nest.map_structure(error_if_wrong_dtype, initial_state)

            state_shape = tf.nest.map_structure(tf.shape, initial_state)
            common_state_dtype = dtype_util.common_dtype(initial_state)
            real_dtype = dtype_util.real_dtype(common_state_dtype)

            if jacobian_fn is None and common_state_dtype.is_complex:
                raise NotImplementedError(
                    'The BDF solver does not support automatic '
                    'Jacobian computations for complex dtypes.')

            # Convert everything to operate on a single, concatenated vector form.
            initial_state_vec = util.get_state_vec(initial_state)
            ode_fn_vec = util.get_ode_fn_vec(ode_fn, state_shape)
            jacobian_fn_mat = util.get_jacobian_fn_mat(
                jacobian_fn,
                ode_fn_vec,
                state_shape,
                use_pfor=self._use_pfor_to_compute_jacobian,
                dtype=common_state_dtype,
            )

            num_odes = tf.size(initial_state_vec)
            # Use tf.cast instead of tf.convert_to_tensor for differentiable
            # parameters because the tf.custom_gradient decorator converts raw floats
            # into tf.float32, which cannot be converted to tf.float64.
            initial_time = tf.cast(initial_time, real_dtype)
            num_solution_times = 0
            if solution_times_chosen_by_solver:
                final_time = tf.cast(solution_times.final_time, real_dtype)
            else:
                solution_times = tf.cast(solution_times, real_dtype)
                num_solution_times = tf.size(solution_times)
                solution_time_array = tf.TensorArray(
                    solution_times.dtype,
                    size=num_solution_times,
                    element_shape=[]).unstack(solution_times)
                util.error_if_not_vector(solution_times, 'solution_times')
            rtol = tf.convert_to_tensor(self._rtol, dtype=real_dtype)
            atol = tf.convert_to_tensor(self._atol, dtype=real_dtype)
            safety_factor = tf.convert_to_tensor(self._safety_factor,
                                                 dtype=real_dtype)
            min_step_size_factor = tf.convert_to_tensor(
                self._min_step_size_factor, dtype=real_dtype)
            max_step_size_factor = tf.convert_to_tensor(
                self._max_step_size_factor, dtype=real_dtype)
            max_num_steps = self._max_num_steps
            if max_num_steps is not None:
                max_num_steps = tf.convert_to_tensor(max_num_steps,
                                                     dtype=tf.int32)
            max_order = tf.convert_to_tensor(self._max_order, dtype=tf.int32)
            max_num_newton_iters = self._max_num_newton_iters
            if max_num_newton_iters is not None:
                max_num_newton_iters = tf.convert_to_tensor(
                    max_num_newton_iters, dtype=tf.int32)
            newton_tol_factor = tf.convert_to_tensor(self._newton_tol_factor,
                                                     dtype=real_dtype)
            newton_step_size_factor = tf.convert_to_tensor(
                self._newton_step_size_factor, dtype=real_dtype)
            bdf_coefficients = tf.cast(
                tf.concat([[0.],
                           tf.convert_to_tensor(self._bdf_coefficients,
                                                dtype=real_dtype)], 0),
                common_state_dtype)
            util.error_if_not_vector(bdf_coefficients, 'bdf_coefficients')
            if self._validate_args:
                initial_time = tf.ensure_shape(initial_time, [])
                if solution_times_chosen_by_solver:
                    final_time = tf.ensure_shape(final_time, [])
                safety_factor = tf.ensure_shape(safety_factor, [])
                min_step_size_factor = tf.ensure_shape(min_step_size_factor,
                                                       [])
                max_step_size_factor = tf.ensure_shape(max_step_size_factor,
                                                       [])
                if max_num_steps is not None:
                    max_num_steps = tf.ensure_shape(max_num_steps, [])
                max_order = tf.ensure_shape(max_order, [])
                if max_num_newton_iters is not None:
                    max_num_newton_iters = tf.ensure_shape(
                        max_num_newton_iters, [])
                newton_tol_factor = tf.ensure_shape(newton_tol_factor, [])
                newton_step_size_factor = tf.ensure_shape(
                    newton_step_size_factor, [])
                bdf_coefficients = tf.ensure_shape(bdf_coefficients, [6])
            newton_coefficients = 1. / (
                (1. - bdf_coefficients) * bdf_util.RECIPROCAL_SUMS)
            newton_coefficients_array = tf.TensorArray(
                newton_coefficients.dtype,
                size=bdf_util.MAX_ORDER + 1,
                clear_after_read=False,
                element_shape=[]).unstack(newton_coefficients)
            error_coefficients = bdf_coefficients * bdf_util.RECIPROCAL_SUMS + 1. / (
                bdf_util.ORDERS + 1)
            error_coefficients_array = tf.TensorArray(
                error_coefficients.dtype,
                size=bdf_util.MAX_ORDER + 1,
                clear_after_read=False,
                element_shape=[]).unstack(error_coefficients)
            first_step_size = self._first_step_size
            if first_step_size is None:
                first_step_size = bdf_util.first_step_size(
                    atol, error_coefficients_array.read(1), initial_state_vec,
                    initial_time, ode_fn_vec, rtol, safety_factor)
            elif previous_solver_internal_state is not None:
                tf.logging.warn(
                    '`first_step_size` is ignored since '
                    '`previous_solver_internal_state` was specified.')
            first_step_size = tf.convert_to_tensor(first_step_size,
                                                   dtype=real_dtype)
            if self._validate_args:
                first_step_size = tf.ensure_shape(first_step_size, [])
            solver_internal_state = previous_solver_internal_state
            if solver_internal_state is None:
                first_order_backward_difference = ode_fn_vec(
                    initial_time, initial_state_vec) * tf.cast(
                        first_step_size, common_state_dtype)
                backward_differences = tf.concat([
                    initial_state_vec[tf.newaxis, :],
                    first_order_backward_difference[tf.newaxis, :],
                    tf.zeros(tf.stack([bdf_util.MAX_ORDER + 1, num_odes]),
                             dtype=common_state_dtype),
                ], 0)
                solver_internal_state = _BDFSolverInternalState(
                    backward_differences=backward_differences,
                    order=1,
                    step_size=first_step_size)
            state_vec_array = tf.TensorArray(
                common_state_dtype,
                size=num_solution_times,
                dynamic_size=solution_times_chosen_by_solver,
                element_shape=initial_state_vec.get_shape())
            time_array = tf.TensorArray(
                real_dtype,
                size=num_solution_times,
                dynamic_size=solution_times_chosen_by_solver,
                element_shape=tf.TensorShape([]))
            diagnostics = _BDFDiagnostics(num_jacobian_evaluations=0,
                                          num_matrix_factorizations=0,
                                          num_ode_fn_evaluations=0,
                                          status=0)
            iterand = _BDFIterand(
                jacobian_mat=tf.zeros([num_odes, num_odes],
                                      dtype=common_state_dtype),
                jacobian_is_up_to_date=False,
                new_step_size=solver_internal_state.step_size,
                num_steps=0,
                num_steps_same_size=0,
                should_update_jacobian=True,
                should_update_step_size=False,
                time=initial_time,
                unitary=tf.zeros([num_odes, num_odes],
                                 dtype=common_state_dtype),
                upper=tf.zeros([num_odes, num_odes], dtype=common_state_dtype))

            # (3) Make non-static assertions.
            with tf.control_dependencies(assert_ops()):

                # (4) Solve up to final time.
                if solution_times_chosen_by_solver:

                    def step_cond(next_time, diagnostics, iterand, *_):
                        return (iterand.time < next_time) & (tf.equal(
                            diagnostics.status, 0))

                    [
                        _, diagnostics, iterand, solver_internal_state,
                        state_vec_array, time_array
                    ] = tf.while_loop(step_cond, step, [
                        final_time, diagnostics, iterand,
                        solver_internal_state, state_vec_array, time_array
                    ])

                else:

                    def advance_to_solution_time_cond(n, diagnostics, *_):
                        return (n < num_solution_times) & (tf.equal(
                            diagnostics.status, 0))

                    [
                        _, diagnostics, iterand, solver_internal_state,
                        state_vec_array, time_array
                    ] = tf.while_loop(
                        advance_to_solution_time_cond,
                        advance_to_solution_time, [
                            0, diagnostics, iterand, solver_internal_state,
                            state_vec_array, time_array
                        ])

                # (6) Return `Results` object.
                states = util.get_state_from_vec(state_vec_array.stack(),
                                                 state_shape)
                times = time_array.stack()
                if not solution_times_chosen_by_solver:
                    times.set_shape(solution_times.get_shape())
                    tf.nest.map_structure(
                        lambda s, ini_s: s.set_shape(
                            solution_times.get_shape(  # pylint: disable=g-long-lambda
                            ).concatenate(ini_s.shape)),
                        states,
                        initial_state)
                return base.Results(
                    times=times,
                    states=states,
                    diagnostics=diagnostics,
                    solver_internal_state=solver_internal_state)
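
# A hedged usage sketch for the solver above, assuming it is the implementation
# behind the public `tfp.math.ode.BDF` API (the enclosing class is not shown in
# this snippet, so this is illustrative rather than definitive): solve the stiff
# scalar ODE dy/dt = -10 * y from an initial state of 1.
import tensorflow_probability as tfp

ode_fn = lambda t, y: -10. * y
results = tfp.math.ode.BDF().solve(ode_fn,
                                   initial_time=0.,
                                   initial_state=1.,
                                   solution_times=[0., 0.5, 1.])
# `results.times`, `results.states` and `results.diagnostics` correspond to the
# `base.Results` fields assembled at the end of the method above.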
# Example 12
def lossfun(x, alpha, scale, approximate=False, epsilon=1e-6):
  r"""Implements the general form of the loss.

  This implements the rho(x, \alpha, c) function described in "A General and
  Adaptive Robust Loss Function", Jonathan T. Barron,
  https://arxiv.org/abs/1701.03077.

  Args:
    x: The residual for which the loss is being computed. x can have any shape,
      and alpha and scale will be broadcasted to match x's shape if necessary.
      Must be a tensorflow tensor or numpy array of floats.
    alpha: The shape parameter of the loss (\alpha in the paper), where more
      negative values produce a loss with more robust behavior (outliers "cost"
      less), and more positive values produce a loss with less robust behavior
      (outliers are penalized more heavily). Alpha can be any value in
      [-infinity, infinity], but the gradient of the loss with respect to alpha
      is 0 at -infinity, infinity, 0, and 2. Must be a tensorflow tensor or
      numpy array of floats with the same precision as `x`. Varying alpha allows
      for smooth interpolation between a number of discrete robust losses:
      alpha=-Infinity: Welsch/Leclerc Loss.
      alpha=-2: Geman-McClure loss.
      alpha=0: Cauchy/Lorentzian loss.
      alpha=1: Charbonnier/pseudo-Huber loss.
      alpha=2: L2 loss.
    scale: The scale parameter of the loss. When |x| < scale, the loss is an
      L2-like quadratic bowl, and when |x| > scale the loss function takes on a
      different shape according to alpha. Must be a tensorflow tensor or numpy
      array of single-precision floats.
    approximate: a bool, where if True, this function returns an approximate and
      faster form of the loss, as described in the appendix of the paper. This
      approximation holds well everywhere except as x and alpha approach zero.
    epsilon: A float that determines how inaccurate the "approximate" version of
      the loss will be. Larger values are less accurate but more numerically
      stable. Must be greater than single-precision machine epsilon.

  Returns:
    The losses for each element of x, in the same shape as x. This is returned
    as a TensorFlow graph node of single precision floats.
  """
  # `scale` and `alpha` must have the same type as `x`.
  float_dtype = x.dtype
  tf.debugging.assert_type(scale, float_dtype)
  tf.debugging.assert_type(alpha, float_dtype)
  # `scale` must be > 0.
  assert_ops = [tf.Assert(tf.reduce_all(tf.greater(scale, 0.)), [scale])]
  with tf.control_dependencies(assert_ops):
    # Broadcast `alpha` and `scale` to have the same shape as `x`.
    alpha = tf.broadcast_to(alpha, tf.shape(x))
    scale = tf.broadcast_to(scale, tf.shape(x))

    if approximate:
      # `epsilon` must be greater than single-precision machine epsilon.
      assert epsilon > np.finfo(np.float32).eps
      # Compute an approximate form of the loss which is faster, but inaccurate
      # when x and alpha are near zero.
      b = tf.abs(alpha - tf.cast(2., float_dtype)) + epsilon
      d = tf.where(
          tf.greater_equal(alpha, 0.), alpha + epsilon, alpha - epsilon)
      loss = (b / d) * (tf.pow(tf.square(x / scale) / b + 1., 0.5 * d) - 1.)
    else:
      # Compute the exact loss.

      # This will be used repeatedly.
      squared_scaled_x = tf.square(x / scale)

      # The loss when alpha == 2.
      loss_two = 0.5 * squared_scaled_x
      # The loss when alpha == 0.
      loss_zero = util.log1p_safe(0.5 * squared_scaled_x)
      # The loss when alpha == -infinity.
      loss_neginf = -tf.math.expm1(-0.5 * squared_scaled_x)
      # The loss when alpha == +infinity.
      loss_posinf = util.expm1_safe(0.5 * squared_scaled_x)

      # The loss when not in one of the above special cases.
      machine_epsilon = tf.cast(np.finfo(np.float32).eps, float_dtype)
      # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by.
      beta_safe = tf.maximum(machine_epsilon, tf.abs(alpha - 2.))
      # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by.
      alpha_safe = tf.where(
          tf.greater_equal(alpha, 0.), tf.ones_like(alpha),
          -tf.ones_like(alpha)) * tf.maximum(machine_epsilon, tf.abs(alpha))
      loss_otherwise = (beta_safe / alpha_safe) * (
          tf.pow(squared_scaled_x / beta_safe + 1., 0.5 * alpha) - 1.)

      # Select which of the cases of the loss to return.
      loss = tf.where(
          tf.equal(alpha, -tf.cast(float('inf'), float_dtype)), loss_neginf,
          tf.where(
              tf.equal(alpha, 0.), loss_zero,
              tf.where(
                  tf.equal(alpha, 2.), loss_two,
                  tf.where(
                      tf.equal(alpha, tf.cast(float('inf'), float_dtype)),
                      loss_posinf, loss_otherwise))))

    return loss
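
# A minimal usage sketch for `lossfun` above (toy values, not from the paper):
# with alpha=2 the loss reduces to 0.5 * (x / scale)**2, and with alpha=0 it
# reduces to log1p(0.5 * (x / scale)**2), as listed in the docstring.
import tensorflow as tf

x = tf.constant([0.5, -1.0, 2.0])
l2_loss = lossfun(x, alpha=tf.constant(2.0), scale=tf.constant(1.0))
cauchy_loss = lossfun(x, alpha=tf.constant(0.0), scale=tf.constant(1.0))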
# Example 13
def _kl_independent(a, b, name='kl_independent'):
    """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that
  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```
  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default 'kl_independent'.

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event space for `a` and `b`, or their underlying
      distributions don't match.
  """
    p = a.distribution
    q = b.distribution

    # The KL between any two (non)-batched distributions is a scalar.
    # Given that the KL between two factored distributions is the sum, i.e.
    # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
    # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
    if (tensorshape_util.is_fully_defined(a.event_shape)
            and tensorshape_util.is_fully_defined(b.event_shape)):
        if a.event_shape == b.event_shape:
            if p.event_shape == q.event_shape:
                num_reduce_dims = (tensorshape_util.rank(a.event_shape) -
                                   tensorshape_util.rank(p.event_shape))
                reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

                return tf.reduce_sum(kullback_leibler.kl_divergence(p,
                                                                    q,
                                                                    name=name),
                                     axis=reduce_dims)
            else:
                raise NotImplementedError(
                    'KL between Independents with different '
                    'event shapes not supported.')
        else:
            raise ValueError('Event shapes do not match.')
    else:
        p_event_shape_tensor = p.event_shape_tensor()
        q_event_shape_tensor = q.event_shape_tensor()
        # NOTE: We could optimize by passing the event_shape_tensor of p and q
        # to a.event_shape_tensor() and b.event_shape_tensor().
        a_event_shape_tensor = a.event_shape_tensor()
        b_event_shape_tensor = b.event_shape_tensor()
        with tf.control_dependencies([
                assert_util.assert_equal(a_event_shape_tensor,
                                         b_event_shape_tensor,
                                         message='Event shapes do not match.'),
                assert_util.assert_equal(p_event_shape_tensor,
                                         q_event_shape_tensor,
                                         message='Event shapes do not match.'),
        ]):
            num_reduce_dims = (
                ps.rank_from_shape(a_event_shape_tensor, a.event_shape) -
                ps.rank_from_shape(p_event_shape_tensor, p.event_shape))
            reduce_dims = ps.range(-num_reduce_dims, 0, 1)
            return tf.reduce_sum(kullback_leibler.kl_divergence(p,
                                                                q,
                                                                name=name),
                                 axis=reduce_dims)
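
# A hedged standalone sketch of the identity used above (toy distributions, not
# part of the original): the KL between two `Independent` wrappers equals the
# sum of the per-component KLs of the underlying distributions.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.Independent(tfd.Normal(loc=[0., 1.], scale=[1., 2.]),
                    reinterpreted_batch_ndims=1)
b = tfd.Independent(tfd.Normal(loc=[0.5, 0.], scale=[1., 1.]),
                    reinterpreted_batch_ndims=1)
kl_joint = tfd.kl_divergence(a, b)  # scalar
kl_sum = tf.reduce_sum(tfd.kl_divergence(a.distribution, b.distribution))
# kl_joint and kl_sum agree.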
# Example 14
    def _parse_train_data(self, data):
        """Parse data for ShapeMask training."""
        classes = data['groundtruth_classes']
        boxes = data['groundtruth_boxes']
        masks = data['groundtruth_instance_masks']
        is_crowds = data['groundtruth_is_crowd']
        # Skips annotations with `is_crowd` = True.
        if self._skip_crowd_during_training and self._is_training:
            num_groundtruths = tf.shape(classes)[0]
            with tf.control_dependencies([num_groundtruths, is_crowds]):
                indices = tf.cond(
                    tf.greater(tf.size(is_crowds), 0),
                    lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
                    lambda: tf.cast(tf.range(num_groundtruths), tf.int64))
            classes = tf.gather(classes, indices)
            boxes = tf.gather(boxes, indices)
            masks = tf.gather(masks, indices)

        # Gets original image and its size.
        image = data['image']
        image_shape = tf.shape(image)[0:2]

        # If not using category, makes all categories with id = 0.
        if not self._use_category:
            classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32)

        # Normalizes image with mean and std pixel values.
        image = input_utils.normalize_image(image)

        # Flips image randomly during training.
        if self._aug_rand_hflip:
            image, boxes, masks = input_utils.random_horizontal_flip(
                image, boxes, masks)

        # Converts boxes from normalized coordinates to pixel coordinates.
        boxes = box_utils.denormalize_boxes(boxes, image_shape)

        # Resizes and crops image.
        image, image_info = input_utils.resize_and_crop_image(
            image,
            self._output_size,
            self._output_size,
            aug_scale_min=self._aug_scale_min,
            aug_scale_max=self._aug_scale_max)
        image_scale = image_info[2, :]
        offset = image_info[3, :]

        # Resizes and crops boxes and masks.
        boxes = input_utils.resize_and_crop_boxes(boxes, image_scale,
                                                  self._output_size, offset)

        # Filters out ground truth boxes that are all zeros.
        indices = input_utils.get_non_empty_box_indices(boxes)
        boxes = tf.gather(boxes, indices)
        classes = tf.gather(classes, indices)
        masks = tf.gather(masks, indices)

        # Assigns anchors.
        input_anchor = anchor.Anchor(self._min_level, self._max_level,
                                     self._num_scales, self._aspect_ratios,
                                     self._anchor_size, self._output_size)
        anchor_labeler = anchor.AnchorLabeler(input_anchor,
                                              self._match_threshold,
                                              self._unmatched_threshold)
        (cls_targets, box_targets,
         num_positives) = anchor_labeler.label_anchors(
             boxes, tf.cast(tf.expand_dims(classes, axis=1), tf.float32))

        # Sample groundtruth masks/boxes/classes for mask branch.
        num_masks = tf.shape(masks)[0]
        mask_shape = tf.shape(masks)[1:3]

        # Pad sampled boxes/masks/classes to a constant batch size.
        padded_boxes = input_utils.pad_to_fixed_size(boxes,
                                                     self._num_sampled_masks)
        padded_classes = input_utils.pad_to_fixed_size(classes,
                                                       self._num_sampled_masks)
        padded_masks = input_utils.pad_to_fixed_size(masks,
                                                     self._num_sampled_masks)

        # Randomly sample groundtruth masks for mask branch training. For images
        # without groundtruth masks, the dummy padded tensors are sampled instead.
        rand_indices = tf.random.shuffle(
            tf.range(tf.maximum(num_masks, self._num_sampled_masks)))
        rand_indices = tf.math.mod(rand_indices, tf.maximum(num_masks, 1))
        rand_indices = rand_indices[0:self._num_sampled_masks]
        rand_indices = tf.reshape(rand_indices, [self._num_sampled_masks])

        sampled_boxes = tf.gather(padded_boxes, rand_indices)
        sampled_classes = tf.gather(padded_classes, rand_indices)
        sampled_masks = tf.gather(padded_masks, rand_indices)
        # Jitter the sampled boxes to mimic the noisy detections.
        sampled_boxes = box_utils.jitter_boxes(
            sampled_boxes, noise_scale=self._box_jitter_scale)
        sampled_boxes = box_utils.clip_boxes(sampled_boxes, self._output_size)
        # Compute mask targets in feature crop. A feature crop fully contains a
        # sampled box.
        mask_outer_boxes = box_utils.compute_outer_boxes(
            sampled_boxes, tf.shape(image)[0:2], scale=self._outer_box_scale)
        mask_outer_boxes = box_utils.clip_boxes(mask_outer_boxes,
                                                self._output_size)
        # Compensate for the offset of mask_outer_boxes to map it back to the
        # original image scale.
        mask_outer_boxes_ori = mask_outer_boxes
        mask_outer_boxes_ori += tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
        mask_outer_boxes_ori /= tf.tile(tf.expand_dims(image_scale, axis=0),
                                        [1, 2])
        norm_mask_outer_boxes_ori = box_utils.normalize_boxes(
            mask_outer_boxes_ori, mask_shape)

        # Set sampled_masks shape to [batch_size, height, width, 1].
        sampled_masks = tf.cast(tf.expand_dims(sampled_masks, axis=-1),
                                tf.float32)
        mask_targets = tf.image.crop_and_resize(
            sampled_masks,
            norm_mask_outer_boxes_ori,
            box_indices=tf.range(self._num_sampled_masks),
            crop_size=[self._mask_crop_size, self._mask_crop_size],
            method='bilinear',
            extrapolation_value=0,
            name='train_mask_targets')
        mask_targets = tf.where(tf.greater_equal(mask_targets, 0.5),
                                tf.ones_like(mask_targets),
                                tf.zeros_like(mask_targets))
        mask_targets = tf.squeeze(mask_targets, axis=-1)
        if self._up_sample_factor > 1:
            fine_mask_targets = tf.image.crop_and_resize(
                sampled_masks,
                norm_mask_outer_boxes_ori,
                box_indices=tf.range(self._num_sampled_masks),
                crop_size=[
                    self._mask_crop_size * self._up_sample_factor,
                    self._mask_crop_size * self._up_sample_factor
                ],
                method='bilinear',
                extrapolation_value=0,
                name='train_mask_targets')
            fine_mask_targets = tf.where(
                tf.greater_equal(fine_mask_targets, 0.5),
                tf.ones_like(fine_mask_targets),
                tf.zeros_like(fine_mask_targets))
            fine_mask_targets = tf.squeeze(fine_mask_targets, axis=-1)
        else:
            fine_mask_targets = mask_targets

        # If bfloat16 is used, casts input image to tf.bfloat16.
        if self._use_bfloat16:
            image = tf.cast(image, dtype=tf.bfloat16)

        valid_image = tf.cast(tf.not_equal(num_masks, 0), tf.int32)
        if self._mask_train_class == 'all':
            mask_is_valid = valid_image * tf.ones_like(sampled_classes,
                                                       tf.int32)
        else:
            # Get the intersection of sampled classes with training splits.
            mask_valid_classes = tf.cast(
                tf.expand_dims(
                    class_utils.coco_split_class_ids(self._mask_train_class),
                    1), sampled_classes.dtype)
            match = tf.reduce_any(
                tf.equal(tf.expand_dims(sampled_classes, 0),
                         mask_valid_classes), 0)
            mask_is_valid = valid_image * tf.cast(match, tf.int32)

        # Packs labels for model_fn outputs.
        labels = {
            'cls_targets': cls_targets,
            'box_targets': box_targets,
            'anchor_boxes': input_anchor.multilevel_boxes,
            'num_positives': num_positives,
            'image_info': image_info,
            # For ShapeMask.
            'mask_boxes': sampled_boxes,
            'mask_outer_boxes': mask_outer_boxes,
            'mask_targets': mask_targets,
            'fine_mask_targets': fine_mask_targets,
            'mask_classes': sampled_classes,
            'mask_is_valid': mask_is_valid,
        }
        return image, labels
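
# A small standalone sketch of the index-sampling trick used in the mask branch
# above (toy values, not part of the original pipeline): shuffle
# max(num_masks, num_sampled) indices, wrap them with `mod` so that images with
# fewer than num_sampled masks reuse valid indices, then keep the first
# num_sampled entries.
import tensorflow as tf

num_masks = tf.constant(3)
num_sampled = 8
rand_indices = tf.random.shuffle(tf.range(tf.maximum(num_masks, num_sampled)))
rand_indices = tf.math.mod(rand_indices, tf.maximum(num_masks, 1))
rand_indices = rand_indices[:num_sampled]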
# Example 15
def interpolate1d(x, values, tangents):
    r"""Perform cubic hermite spline interpolation on a 1D spline.

  The x coordinates of the spline knots are at [0 : 1 : len(values)-1].
  Queries outside of the range of the spline are computed using linear
  extrapolation. See https://en.wikipedia.org/wiki/Cubic_Hermite_spline
  for details, where "x" corresponds to `x`, "p" corresponds to `values`, and
  "m" corresponds to `tangents`.

  Args:
    x: A tensor of any size of single or double precision floats containing the
      set of values to be used for interpolation into the spline.
    values: A vector of single or double precision floats containing the value
      of each knot of the spline being interpolated into. Must be the same
      length as `tangents` and the same type as `x`.
    tangents: A vector of single or double precision floats containing the
      tangent (derivative) of each knot of the spline being interpolated into.
      Must be the same length as `values` and the same type as `x`.

  Returns:
    The result of interpolating along the spline defined by `values`, and
    `tangents`, using `x` as the query values. Will be the same length and type
    as `x`.
  """
    # `values` and `tangents` must have the same type as `x`.
    tf.debugging.assert_type(values, x.dtype)
    tf.debugging.assert_type(tangents, x.dtype)
    float_dtype = x.dtype
    assert_ops = [
        # `values` must be a vector.
        tf.Assert(tf.equal(tf.rank(values), 1), [tf.shape(values)]),
        # `tangents` must be a vector.
        tf.Assert(tf.equal(tf.rank(tangents), 1), [tf.shape(values)]),
        # `values` and `tangents` must have the same length.
        tf.Assert(
            tf.equal(tf.shape(values)[0],
                     tf.shape(tangents)[0]),
            [tf.shape(values)[0], tf.shape(tangents)[0]]),
    ]
    with tf.control_dependencies(assert_ops):
        # Find the indices of the knots below and above each x.
        x_lo = tf.cast(
            tf.floor(
                tf.clip_by_value(x, 0.,
                                 tf.cast(tf.shape(values)[0] - 2,
                                         float_dtype))), tf.int32)
        x_hi = x_lo + 1

        # Compute the relative distance between each `x` and the knot below it.
        t = x - tf.cast(x_lo, float_dtype)

        # Compute the cubic hermite expansion of `t`.
        t_sq = tf.square(t)
        t_cu = t * t_sq
        h01 = -2. * t_cu + 3. * t_sq
        h00 = 1. - h01
        h11 = t_cu - t_sq
        h10 = h11 - t_sq + t

        # Linearly extrapolate above and below the extents of the spline for all
        # values.
        value_before = tangents[0] * t + values[0]
        value_after = tangents[-1] * (t - 1.) + values[-1]

        # Cubically interpolate between the knots below and above each query point.
        neighbor_values_lo = tf.gather(values, x_lo)
        neighbor_values_hi = tf.gather(values, x_hi)
        neighbor_tangents_lo = tf.gather(tangents, x_lo)
        neighbor_tangents_hi = tf.gather(tangents, x_hi)
        value_mid = (neighbor_values_lo * h00 + neighbor_values_hi * h01 +
                     neighbor_tangents_lo * h10 + neighbor_tangents_hi * h11)

        # Return the interpolated or extrapolated values for each query point,
        # depending on whether or not the query lies within the span of the spline.
        return tf.where(t < 0., value_before,
                        tf.where(t > 1., value_after, value_mid))
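
# A minimal usage sketch for `interpolate1d` above (toy values): three knots at
# x = 0, 1, 2 with zero tangents; queries outside [0, 2] are handled by linear
# extrapolation from the end knots.
import tensorflow as tf

values = tf.constant([0., 1., 0.])
tangents = tf.constant([0., 0., 0.])
x = tf.constant([0.5, 1.0, 2.5])
y = interpolate1d(x, values, tangents)  # y[1] == 1.; y[2] == 0. (flat tangent)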
# Example 16
def interpolate(x,
                x_data,
                y_data,
                left_slope=None,
                right_slope=None,
                validate_args=False,
                optimize_for_tpu=False,
                dtype=None,
                name=None):
    """Performs linear interpolation for supplied points.

  Given a set of knots whose x- and y- coordinates are in `x_data` and `y_data`,
  this function returns y-values for x-coordinates in `x` via piecewise
  linear interpolation.

  `x_data` must be non-decreasing, but `y_data` doesn't need to be, because we
  do not require the function approximated by these knots to be monotonic.

  #### Examples

  ```python
  x = [-10, -1, 1, 3, 6, 7, 8, 15, 18, 25, 30, 35]
  x_data = [-1, 2, 6, 8, 18, 30.0]
  y_data = [10, -1, -5, 7, 9, 20]

  result = linear_interpolation(x, x_data, y_data)
  # [ 10, 10, 2.66666667, -2, -5, 1, 7, 8.4, 9, 15.41666667, 20, 20]
  ```

  Args:
    x: x-coordinates for which we need to get interpolation. A N-D `Tensor` of
      real dtype. First N-1 dimensions represent batching dimensions.
    x_data: x coordinates. A N-D `Tensor` of real dtype. Should be sorted
      in non-decreasing order. First N-1 dimensions represent batching
      dimensions.
    y_data: y coordinates. A N-D `Tensor` of real dtype. Should have the
      compatible shape as `x_data`. First N-1 dimensions represent batching
      dimensions.
    left_slope: The slope to use for extrapolation with x-coordinate smaller
      than the min `x_data`. It's a 0-D or N-D `Tensor`.
      Default value: `None`, which maps to `0.0` meaning constant extrapolation,
      i.e. extrapolated value will be the leftmost `y_data`.
    right_slope: The slope to use for extrapolation with x-coordinate greater
      than the max `x_data`. It's a 0-D or N-D `Tensor`.
      Default value: `None` which maps to `0.0` meaning constant extrapolation,
      i.e. extrapolated value will be the rightmost `y_data`.
    validate_args: Python `bool` that indicates whether the function checks
      that the shapes of `x_data` and `y_data` are equal and that the elements
      in `x_data` are non-decreasing. If this value is set to `False` and the
      elements in `x_data` are not sorted in non-decreasing order, the result
      of linear interpolation may be wrong.
      Default value: `False`.
    optimize_for_tpu: A Python bool. If `True`, the algorithm uses one-hot
      encoding to lookup indices of `x_values` in `x_data`. This significantly
      improves performance of the algorithm on a TPU device but may slow down
      performance on the CPU.
      Default value: `False`.
    dtype: Optional tf.dtype for `x`, x_data`, `y_data`, `left_slope` and
      `right_slope`.
      Default value: `None` which means that the `dtype` inferred by TensorFlow
      is used.
    name: Python str. The name prefixed to the ops created by this function.
      Default value: `None` which maps to 'linear_interpolation'.

  Returns:
    A N-D `Tensor` of real dtype corresponding to the x-values in `x`.
  """
    name = name or 'linear_interpolation'
    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, dtype=dtype, name='x')
        dtype = dtype or x.dtype
        x_data = tf.convert_to_tensor(x_data, dtype=dtype, name='x_data')
        y_data = tf.convert_to_tensor(y_data, dtype=dtype, name='y_data')
        batch_shape = x.shape.as_list()[:-1]
        if not batch_shape:
            x = tf.expand_dims(x, 0)
            x_data = tf.expand_dims(x_data, 0)
            y_data = tf.expand_dims(y_data, 0)

        if left_slope is None:
            left_slope = tf.constant(0.0, dtype=x.dtype, name='left_slope')
        else:
            left_slope = tf.convert_to_tensor(left_slope,
                                              dtype=dtype,
                                              name='left_slope')
        if right_slope is None:
            right_slope = tf.constant(0.0, dtype=x.dtype, name='right_slope')
        else:
            right_slope = tf.convert_to_tensor(right_slope,
                                               dtype=dtype,
                                               name='right_slope')
        control_deps = []
        if validate_args:
            # Check that `x_data` elements are non-decreasing.
            diffs = x_data[..., 1:] - x_data[..., :-1]
            assertion = tf.compat.v1.debugging.assert_greater_equal(
                diffs,
                tf.zeros_like(diffs),
                message='x_data is not sorted in non-decreasing order.')
            control_deps.append(assertion)
            # Check that the shapes of `x_data` and `y_data` are equal
            control_deps.append(
                tf.compat.v1.assert_equal(tf.shape(x_data), tf.shape(y_data)))

        with tf.control_dependencies(control_deps):
            # Get upper bound indices for `x`.
            upper_indices = tf.searchsorted(x_data,
                                            x,
                                            side='left',
                                            out_type=tf.int32)
            x_data_size = x_data.shape.as_list()[-1]
            at_min = tf.equal(upper_indices, 0)
            at_max = tf.equal(upper_indices, x_data_size)
            # Create tensors in order to be used by `tf.where`.
            # `values_min` are extrapolated values for x-coordinates less than or
            # equal to `x_data[..., 0]`.
            # `values_max` are extrapolated values for x-coordinates greater than
            # `x_data[..., -1]`.

            values_min = tf.expand_dims(
                y_data[..., 0], -1) + left_slope * (x - tf.broadcast_to(
                    tf.expand_dims(x_data[..., 0], -1), shape=tf.shape(x)))
            values_max = tf.expand_dims(
                y_data[..., -1], -1) + right_slope * (x - tf.broadcast_to(
                    tf.expand_dims(x_data[..., -1], -1), shape=tf.shape(x)))

            # `tf.where` evaluates all branches, so we cap the indices to
            # ensure they do not go out of bounds.
            capped_lower_indices = tf.math.maximum(upper_indices - 1, 0)
            capped_upper_indices = tf.math.minimum(upper_indices,
                                                   x_data_size - 1)
            # Prepare indices for `tf.gather_nd` or `tf.one_hot`
            # TODO(b/156720909): Extract get_slice logic into a common utilities
            # module for cubic and linear interpolation
            if optimize_for_tpu:
                lower_encoding = tf.one_hot(capped_lower_indices,
                                            x_data_size,
                                            dtype=dtype)
                upper_encoding = tf.one_hot(capped_upper_indices,
                                            x_data_size,
                                            dtype=dtype)
            else:
                index_matrix = _prepare_indices(upper_indices)
                lower_encoding = tf.concat(
                    [index_matrix,
                     tf.expand_dims(capped_lower_indices, -1)], -1)

                upper_encoding = tf.concat(
                    [index_matrix,
                     tf.expand_dims(capped_upper_indices, -1)], -1)

            def get_slice(x, encoding):
                if optimize_for_tpu:
                    return tf.math.reduce_sum(tf.expand_dims(x, axis=-2) *
                                              encoding,
                                              axis=-1)
                else:
                    return tf.gather_nd(x, encoding)

            x_data_lower = get_slice(x_data, lower_encoding)
            x_data_upper = get_slice(x_data, upper_encoding)
            y_data_lower = get_slice(y_data, lower_encoding)
            y_data_upper = get_slice(y_data, upper_encoding)

            # NaN in unselected branches of `tf.where` could propagate through
            # the gradient calculation, so we adjust the denominator at the
            # boundary points to ensure there is no division by zero.
            x_data_diff = x_data_upper - x_data_lower
            floor_x_diff = tf.where(at_min | at_max, x_data_diff + 1,
                                    x_data_diff)
            interpolated = y_data_lower + (x - x_data_lower) * (
                y_data_upper - y_data_lower) / floor_x_diff

            interpolated = tf.where(at_min, values_min, interpolated)
            interpolated = tf.where(at_max, values_max, interpolated)
            if batch_shape:
                return interpolated
            else:
                return tf.squeeze(interpolated, 0)
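
# A hedged standalone sketch of the one-hot lookup used when
# `optimize_for_tpu=True` (toy values): selecting elements via a one-hot
# encoding followed by a reduce_sum matches a plain gather.
import tensorflow as tf

x_vals = tf.constant([10., 20., 30., 40.])
idx = tf.constant([2, 0])
one_hot = tf.one_hot(idx, depth=4, dtype=x_vals.dtype)                 # [2, 4]
via_one_hot = tf.reduce_sum(x_vals[tf.newaxis, :] * one_hot, axis=-1)  # [30., 10.]
via_gather = tf.gather(x_vals, idx)                                    # [30., 10.]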
# Example 17
    def update_state(self, values, sample_weight=None):
        """Accumulates statistics for computing the metric.

        Args:
          values: Per-example value.
          sample_weight: Optional weighting of each example. Defaults to 1.

        Returns:
          Update op.
        """
        [
            values
        ], sample_weight = metrics_utils.ragged_assert_compatible_and_get_flat_values(  # noqa: E501
            [values], sample_weight)
        try:
            values = tf.cast(values, self._dtype)
        except (ValueError, TypeError):
            msg = (
                "The output of a metric function can only be a single Tensor. "
                f"Received: {values}. ")
            if isinstance(values, dict):
                msg += (
                    "To return a dict of values, implement a custom Metric "
                    "subclass.")
            raise RuntimeError(msg)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self._dtype)
            # Update dimensions of weights to match with values if possible.
            (
                values,
                _,
                sample_weight,
            ) = losses_utils.squeeze_or_expand_dimensions(
                values, sample_weight=sample_weight)
            try:
                # Broadcast weights if possible.
                sample_weight = tf.__internal__.ops.broadcast_weights(
                    sample_weight, values)
            except ValueError:
                # Reduce values to same ndim as weight array
                ndim = backend.ndim(values)
                weight_ndim = backend.ndim(sample_weight)
                if self.reduction == metrics_utils.Reduction.SUM:
                    values = tf.reduce_sum(values,
                                           axis=list(range(weight_ndim, ndim)))
                else:
                    values = tf.reduce_mean(values,
                                            axis=list(range(weight_ndim,
                                                            ndim)))
            values = tf.multiply(values, sample_weight)

        value_sum = tf.reduce_sum(values)
        with tf.control_dependencies([value_sum]):
            update_total_op = self.total.assign_add(value_sum)

        # Exit early if the reduction doesn't have a denominator.
        if self.reduction == metrics_utils.Reduction.SUM:
            return update_total_op

        # Update `count` for reductions that require a denominator.
        if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE:
            num_values = tf.cast(tf.size(values), self._dtype)
        elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN:
            if sample_weight is None:
                num_values = tf.cast(tf.size(values), self._dtype)
            else:
                num_values = tf.reduce_sum(sample_weight)
        else:
            raise NotImplementedError(
                f'Reduction "{self.reduction}" not implemented. Expected '
                '"sum", "weighted_mean", or "sum_over_batch_size".')

        with tf.control_dependencies([update_total_op]):
            return self.count.assign_add(num_values)
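
# A hedged usage sketch of the weighted update logic above, assuming it is the
# `Reduce` base class behind `tf.keras.metrics.Mean` (the enclosing class is not
# shown, so the public metric is used instead): with a WEIGHTED_MEAN reduction
# the result is sum(values * weights) / sum(weights).
import tensorflow as tf

m = tf.keras.metrics.Mean()
m.update_state([1., 3., 5.], sample_weight=[1., 1., 0.])
# m.result() == (1. + 3.) / 2. == 2.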
# Example 18
 def _inverse(self, y):
     with tf.control_dependencies(self._assertions(y)):
         return -y, y
# Example 19
def pinv(a, rcond=None, validate_args=False, name=None):
    """Compute the Moore-Penrose pseudo-inverse of a matrix.

  Calculate the [generalized inverse of a matrix](
  https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) using its
  singular-value decomposition (SVD) and including all large singular values.

  The pseudo-inverse of a matrix `A` is defined as 'the matrix that "solves"
  [the least-squares problem] `A @ x = b`,' i.e., if `x_hat` is a solution, then
  `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown that if
  `U @ Sigma @ V.T = A` is the singular value decomposition of `A`, then
  `A_pinv = V @ inv(Sigma) @ U.T`. [(Strang, 1980)][1]

  This function is analogous to [`numpy.linalg.pinv`](
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html).
  It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
  default `rcond` is `1e-15`. Here the default is
  `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be
      pseudo-inverted.
    rcond: `Tensor` of small singular value cutoffs.  Singular values smaller
      (in modulus) than `rcond` * largest_singular_value (again, in modulus) are
      set to zero. Must broadcast against `tf.shape(a)[:-2]`.
      Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`.
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'pinv'.

  Returns:
    a_pinv: The pseudo-inverse of input `a`. Has same shape as `a` except
      rightmost two dimensions are transposed.

  Raises:
    TypeError: if input `a` does not have `float`-like `dtype`.
    ValueError: if input `a` has fewer than 2 dimensions.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp

  a = tf.constant([[1.,  0.4,  0.5],
                   [0.4, 0.2,  0.25],
                   [0.5, 0.25, 0.35]])
  tf.matmul(tfp.math.pinv(a), a)
  # ==> array([[1., 0., 0.],
               [0., 1., 0.],
               [0., 0., 1.]], dtype=float32)

  a = tf.constant([[1.,  0.4,  0.5,  1.],
                   [0.4, 0.2,  0.25, 2.],
                   [0.5, 0.25, 0.35, 3.]])
  tf.matmul(tfp.math.pinv(a), a)
  # ==> array([[ 0.76,  0.37,  0.21, -0.02],
               [ 0.37,  0.43, -0.33,  0.02],
               [ 0.21, -0.33,  0.81,  0.01],
               [-0.02,  0.02,  0.01,  1.  ]], dtype=float32)
  ```

  #### References

  [1]: G. Strang. 'Linear Algebra and Its Applications, 2nd Ed.' Academic Press,
       Inc., 1980, pp. 139-142.
  """
    with tf.name_scope(name or 'pinv'):
        a = tf.convert_to_tensor(a, name='a')

        assertions = _maybe_validate_matrix(a, validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                a = tf.identity(a)

        dtype = a.dtype.as_numpy_dtype

        if rcond is None:

            def get_dim_size(dim):
                if tf.compat.dimension_value(a.shape[dim]) is not None:
                    return tf.compat.dimension_value(a.shape[dim])
                return tf.shape(a)[dim]

            num_rows = get_dim_size(-2)
            num_cols = get_dim_size(-1)
            if isinstance(num_rows, int) and isinstance(num_cols, int):
                max_rows_cols = float(max(num_rows, num_cols))
            else:
                max_rows_cols = tf.cast(tf.maximum(num_rows, num_cols), dtype)
            rcond = 10. * max_rows_cols * np.finfo(dtype).eps

        rcond = tf.convert_to_tensor(rcond, dtype=dtype, name='rcond')

        # Calculate pseudo inverse via SVD.
        # Note: if a is symmetric then u == v. (We might observe additional
        # performance by explicitly setting `v = u` in such cases.)
        [
            singular_values,  # Sigma
            left_singular_vectors,  # U
            right_singular_vectors,  # V
        ] = tf.linalg.svd(a, full_matrices=False, compute_uv=True)

        # Saturate small singular values to inf. This has the effect of making
        # `1. / s = 0.` while not resulting in `NaN` gradients.
        cutoff = rcond * tf.reduce_max(singular_values, axis=-1)
        singular_values = tf.where(singular_values > cutoff[..., tf.newaxis],
                                   singular_values, np.array(np.inf, dtype))

        # Although `a == tf.matmul(u, s * v, transpose_b=True)` we swap
        # `u` and `v` here so that `tf.matmul(pinv(A), A) = tf.eye()`, i.e.,
        # a matrix inverse has 'transposed' semantics.
        a_pinv = tf.matmul(right_singular_vectors /
                           singular_values[..., tf.newaxis, :],
                           left_singular_vectors,
                           adjoint_b=True)

        if a.shape.ndims is not None:
            a_pinv.set_shape(a.shape[:-2].concatenate(
                [a.shape[-1], a.shape[-2]]))

        return a_pinv
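
# A standalone sketch of the small-singular-value saturation used above (toy
# values): replacing values below the cutoff with +inf makes their reciprocals
# exactly zero instead of producing huge or NaN entries.
import numpy as np
import tensorflow as tf

s = tf.constant([3.0, 1e-12])
cutoff = 1e-6 * tf.reduce_max(s)
s_safe = tf.where(s > cutoff, s, np.inf)
reciprocals = 1. / s_safe  # [1/3, 0.]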
# Example 20
 def _inverse_log_det_jacobian(self, y):
   with tf.control_dependencies(self._maybe_assert_valid_y(y)):
     return (self.power - 1.) * tf.math.log(y)
# Example 21
def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
    """Computes a matrix inverse given the matrix's LU decomposition.

  This op is conceptually identical to,

  ```python
  inv_X = tf.lu_matrix_inverse(*tf.linalg.lu(X))
  tf.assert_near(tf.matrix_inverse(X), inv_X)
  # ==> True
  ```

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
      actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_matrix_inverse').

  Returns:
    inv_x: The matrix inverse, i.e.,
      `tf.linalg.inv(tfp.math.lu_reconstruct(lu, perm))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  inv_x = tfp.math.lu_matrix_inverse(*tf.linalg.lu(x))
  tf.debugging.assert_near(tf.linalg.inv(x), inv_x)
  # ==> True
  ```

  """

    with tf.name_scope(name or 'lu_matrix_inverse'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
        assertions = _lu_reconstruct_assertions(lower_upper, perm,
                                                validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
        shape = tf.shape(lower_upper)
        return lu_solve(lower_upper,
                        perm,
                        rhs=tf.eye(shape[-1],
                                   batch_shape=shape[:-2],
                                   dtype=lower_upper.dtype),
                        validate_args=False)
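
# A short sketch (not from the library) of the same idea with SciPy: a matrix
# inverse is just an LU solve against an identity right-hand side.
import numpy as np
from scipy.linalg import lu_factor, lu_solve

x_demo = np.array([[3., 4.],
                   [1., 2.]])
lu, piv = lu_factor(x_demo)                    # analogous to tf.linalg.lu(x)
x_demo_inv = lu_solve((lu, piv), np.eye(2))    # solve against the identity
np.testing.assert_allclose(x_demo_inv, np.linalg.inv(x_demo), atol=1e-12)
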
Example no. 22
 def _forward_log_det_jacobian(self, x):
   with tf.control_dependencies(self._maybe_assert_valid_x(x)):
     if self.power == 0.:
       return x
     return (1. / self.power - 1.) * tf.math.log1p(x * self.power)
    def __init__(self,
                 initial_distribution,
                 transition_distribution,
                 observation_distribution,
                 num_steps,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HiddenMarkovModel"):
        """Initialize hidden Markov model.

    Args:
      initial_distribution: A `Categorical`-like instance.
        Determines probability of first hidden state in Markov chain.
        The number of categories must match the number of categories of
        `transition_distribution` as well as both the rightmost batch
        dimension of `transition_distribution` and the rightmost batch
        dimension of `observation_distribution`.
      transition_distribution: A `Categorical`-like instance.
        The rightmost batch dimension indexes the probability distribution
        of each hidden state conditioned on the previous hidden state.
      observation_distribution: A `tfp.distributions.Distribution`-like
        instance.  The rightmost batch dimension indexes the distribution
        of each observation conditioned on the corresponding hidden state.
      num_steps: The number of steps taken in the Markov chain. A Python `int`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Raises:
      ValueError: if `num_steps` is not at least 1.
      ValueError: if `initial_distribution` does not have scalar `event_shape`.
      ValueError: if `transition_distribution` does not have scalar
        `event_shape.`
      ValueError: if `transition_distribution` and `observation_distribution`
        are fully defined but don't have matching rightmost dimension.
    """

        parameters = dict(locals())

        # pylint: disable=protected-access
        with tf.name_scope(name) as name:
            self._runtime_assertions = []  # pylint: enable=protected-access

            num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
            if validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.rank(num_steps),
                        0,
                        message="`num_steps` must be a scalar")
                ]
                self._runtime_assertions += [
                    assert_util.assert_greater_equal(
                        num_steps,
                        1,
                        message="`num_steps` must be at least 1.")
                ]

            self._initial_distribution = initial_distribution
            self._observation_distribution = observation_distribution
            self._transition_distribution = transition_distribution

            if (initial_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        initial_distribution.event_shape) != 0):
                raise ValueError(
                    "`initial_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(initial_distribution.event_shape_tensor())[0],
                        0,
                        message="`initial_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (transition_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.event_shape) != 0):
                raise ValueError(
                    "`transition_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(
                            transition_distribution.event_shape_tensor())[0],
                        0,
                        message="`transition_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (transition_distribution.batch_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.batch_shape) == 0):
                raise ValueError(
                    "`transition_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(transition_distribution.batch_shape_tensor()),
                        0,
                        message="`transition_distribution` can't have scalar "
                        "batches")
                ]

            if (observation_distribution.batch_shape is not None
                    and tensorshape_util.rank(
                        observation_distribution.batch_shape) == 0):
                raise ValueError(
                    "`observation_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(observation_distribution.batch_shape_tensor()),
                        0,
                        message="`observation_distribution` can't have scalar "
                        "batches")
                ]

            # Infer number of hidden states and check consistency
            # between transitions and observations
            with tf.control_dependencies(self._runtime_assertions):
                self._num_states = (
                    (transition_distribution.batch_shape
                     and transition_distribution.batch_shape[-1])
                    or transition_distribution.batch_shape_tensor()[-1])

                observation_states = (
                    (observation_distribution.batch_shape
                     and observation_distribution.batch_shape[-1])
                    or observation_distribution.batch_shape_tensor()[-1])

            if (tf.is_tensor(self._num_states)
                    or tf.is_tensor(observation_states)):
                if validate_args:
                    self._runtime_assertions += [
                        assert_util.assert_equal(
                            self._num_states,
                            observation_states,
                            message="`transition_distribution` and "
                            "`observation_distribution` must agree on "
                            "last dimension of batch size")
                    ]
            elif self._num_states != observation_states:
                raise ValueError("`transition_distribution` and "
                                 "`observation_distribution` must agree on "
                                 "last dimension of batch size")

            self._log_init = _extract_log_probs(self._num_states,
                                                initial_distribution)
            self._log_trans = _extract_log_probs(self._num_states,
                                                 transition_distribution)

            self._num_steps = num_steps
            self._num_states = tf.shape(self._log_init)[-1]

            self._underlying_event_rank = tf.size(
                self._observation_distribution.event_shape_tensor())

            num_steps_ = tf.get_static_value(num_steps)
            if num_steps_ is not None:
                self.static_event_shape = tf.TensorShape([
                    num_steps_
                ]).concatenate(self._observation_distribution.event_shape)
            else:
                self.static_event_shape = None

            with tf.control_dependencies(self._runtime_assertions):
                self.static_batch_shape = tf.broadcast_static_shape(
                    self._initial_distribution.batch_shape,
                    tf.broadcast_static_shape(
                        self._transition_distribution.batch_shape[:-1],
                        self._observation_distribution.batch_shape[:-1]))

            # pylint: disable=protected-access
            super(HiddenMarkovModel, self).__init__(
                dtype=self._observation_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=(self._initial_distribution._graph_parents +
                               self._transition_distribution._graph_parents +
                               self._observation_distribution._graph_parents),
                name=name)
            # pylint: enable=protected-access

            self._parameters = parameters
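
# A hedged usage sketch of the constructor above (assumes a standard TFP
# install): the transition/observation distributions carry the number of
# hidden states in their rightmost batch dimension, and the resulting event
# shape is `[num_steps]` followed by the observation event shape.
import tensorflow_probability as tfp
tfd = tfp.distributions

hmm_demo = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7)
print(hmm_demo.event_shape)   # ==> (7,)
print(hmm_demo.batch_shape)   # ==> ()
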
Example no. 24
def _validate_arg_if_not_none(arg, assertion, validate_args):
    if arg is None:
        return arg
    with tf.control_dependencies([assertion(arg)] if validate_args else []):
        result = tf.identity(arg)
    return result
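
# A small hedged sketch of how the helper above can be wired in (illustrative
# names only): the assertion becomes a control dependency only when
# `validate_args` is True, and `None` arguments pass straight through.
import tensorflow as tf

scale_demo = _validate_arg_if_not_none(
    tf.constant([1., 2., 3.]), tf.debugging.assert_positive, validate_args=True)
assert _validate_arg_if_not_none(
    None, tf.debugging.assert_positive, validate_args=True) is None
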
    def _sample_n(self, n, seed=None):
        with tf.control_dependencies(self._runtime_assertions):
            seed = seed_stream.SeedStream(seed, salt="HiddenMarkovModel")

            num_states = self._num_states

            batch_shape = self.batch_shape_tensor()
            batch_size = tf.reduce_prod(batch_shape)

            # The batch sizes of the underlying initial distributions and
            # transition distributions might not match the batch size of
            # the HMM distribution.
            # As a result we need to ask for more samples from the
            # underlying distributions and then reshape the results into
            # the correct batch size for the HMM.
            init_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._initial_distribution.batch_shape_tensor()))
            init_state = self._initial_distribution.sample(n * init_repeat,
                                                           seed=seed())
            init_state = tf.reshape(init_state, [n, batch_size])
            # init_state :: n batch_size

            transition_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._transition_distribution.batch_shape_tensor()[:-1]))

            def generate_step(state, _):
                """Take a single step in Markov chain."""

                gen = self._transition_distribution.sample(n *
                                                           transition_repeat,
                                                           seed=seed())
                # gen :: (n * transition_repeat) transition_batch

                new_states = tf.reshape(gen, [n, batch_size, num_states])

                # new_states :: n batch_size num_states

                old_states_one_hot = tf.one_hot(state,
                                                num_states,
                                                dtype=tf.int32)

                # old_states :: n batch_size num_states

                return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

            def _scan_multiple_steps():
                dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
                hidden_states = tf.scan(generate_step,
                                        dummy_index,
                                        initializer=init_state)

                # TODO(b/115618503): add/use prepend_initializer to tf.scan
                return tf.concat([[init_state], hidden_states], axis=0)

            hidden_states = prefer_static.cond(
                self._num_steps > 1, _scan_multiple_steps,
                lambda: init_state[tf.newaxis, ...])

            hidden_one_hot = tf.one_hot(
                hidden_states,
                num_states,
                dtype=self._observation_distribution.dtype)
            # hidden_one_hot :: num_steps n batch_size num_states

            # The observation distribution batch size might not match
            # the required batch size, so, as with the initial and
            # transition distributions, we generate more samples and
            # reshape.
            observation_repeat = (batch_size // tf.reduce_prod(
                self._observation_distribution.batch_shape_tensor()[:-1]))

            possible_observations = self._observation_distribution.sample(
                [self._num_steps, observation_repeat * n], seed=seed())

            inner_shape = self._observation_distribution.event_shape

            # possible_observations :: num_steps (observation_repeat * n)
            #                          observation_batch[:-1] num_states inner_shape

            possible_observations = tf.reshape(
                possible_observations,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           inner_shape],
                          axis=0))

            # possible_observations :: steps n batch_size num_states inner_shape

            hidden_one_hot = tf.reshape(
                hidden_one_hot,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           tf.ones_like(inner_shape)],
                          axis=0))

            # hidden_one_hot :: steps n batch_size num_states "inner_shape"

            observations = tf.reduce_sum(hidden_one_hot *
                                         possible_observations,
                                         axis=-1 - tf.size(inner_shape))

            # observations :: steps n batch_size inner_shape

            observations = distribution_util.move_dimension(
                observations, 0, 1 + tf.size(batch_shape))

            # returned :: n batch_shape steps inner_shape

            return observations
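
# A NumPy sketch (illustrative only) of the one-hot selection trick used in
# `_sample_n` above: instead of a gather, candidate draws for every hidden
# state are multiplied by a one-hot encoding of the sampled state and the
# state axis is summed out.
import numpy as np

num_states_demo = 3
hidden_state_demo = np.array([2, 0, 1])              # one state per position
candidates_demo = np.array([[10., 20., 30.],         # a draw for every state
                            [11., 21., 31.],
                            [12., 22., 32.]])
one_hot_demo = np.eye(num_states_demo)[hidden_state_demo]
selected_demo = np.sum(one_hot_demo * candidates_demo, axis=-1)
print(selected_demo)                                 # ==> [30. 11. 22.]
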
Example no. 26
def _effective_sample_size_single_state(states, filter_beyond_lag,
                                        filter_threshold,
                                        filter_beyond_positive_pairs,
                                        cross_chain_dims,
                                        validate_args):
  """ESS computation for one single Tensor argument."""

  with tf.name_scope('effective_sample_size_single_state'):

    states = tf.convert_to_tensor(states, name='states')
    dt = states.dtype

    # filter_beyond_lag == None ==> auto_corr is the full sequence.
    auto_cov = stats.auto_correlation(
        states, axis=0, max_lags=filter_beyond_lag, normalize=False)
    n = _axis_size(states, axis=0)

    if cross_chain_dims is not None:
      num_chains = _axis_size(states, cross_chain_dims)
      num_chains_ = tf.get_static_value(num_chains)

      assertions = []
      msg = ('When `cross_chain_dims` is not `None`, there must be > 1 chain '
             'in `states`.')
      if num_chains_ is not None:
        if num_chains_ < 2:
          raise ValueError(msg)
      elif validate_args:
        assertions.append(
            assert_util.assert_greater(num_chains, 1., message=msg))

      with tf.control_dependencies(assertions):
        # We're computing the R[k] from equation 10 of Vehtari et al.
        # (2019):
        #
        # R[k] := 1 - (W - 1/C * Sum_{c=1}^C s_c**2 R[k, c]) / (var^+),
        #
        # where:
        #   C := number of chains
        #   N := length of chains
        #   x_hat[c] := 1 / N Sum_{n=1}^N x[n, c], chain mean.
        #   x_hat := 1 / C Sum_{c=1}^C x_hat[c], overall mean.
        #   W := 1/C Sum_{c=1}^C s_c**2, within-chain variance.
        #   B := N / (C - 1) Sum_{c=1}^C (x_hat[c] - x_hat)**2, between chain
        #     variance.
        #   s_c**2 := 1 / (N - 1) Sum_{n=1}^N (x[n, c] - x_hat[c])**2, chain
        #       variance
        #   R[k, m] := auto_corr[k, m, ...], auto-correlation indexed by chain.
        #   var^+ := (N - 1) / N * W + B / N

        cross_chain_dims = ps.non_negative_axis(
            cross_chain_dims, ps.rank(states))
        # B / N
        between_chain_variance_div_n = _reduce_variance(
            tf.reduce_mean(states, axis=0),
            biased=False,  # This makes the denominator be C - 1.
            axis=cross_chain_dims - 1)
        # W * (N - 1) / N
        biased_within_chain_variance = tf.reduce_mean(auto_cov[0],
                                                      cross_chain_dims - 1)
        # var^+
        approx_variance = (
            biased_within_chain_variance + between_chain_variance_div_n)
        # 1/C * Sum_{c=1}^C s_c**2 R[k, c]
        mean_auto_cov = tf.reduce_mean(auto_cov, cross_chain_dims)
        auto_corr = 1. - (biased_within_chain_variance -
                          mean_auto_cov) / approx_variance
    else:
      auto_corr = auto_cov / auto_cov[:1]
      num_chains = 1

    # With R[k] := auto_corr[k, ...],
    # ESS = N / {1 + 2 * Sum_{k=1}^N R[k] * (N - k) / N}
    #     = N / {-1 + 2 * Sum_{k=0}^N R[k] * (N - k) / N} (since R[0] = 1)
    #     approx N / {-1 + 2 * Sum_{k=0}^M R[k] * (N - k) / N}
    # where M is the filter_beyond_lag truncation point chosen above.

    # Get the factor (N - k) / N, and give it shape [M, 1,...,1], having total
    # ndims the same as auto_corr
    k = tf.range(0., _axis_size(auto_corr, axis=0))
    nk_factor = (n - k) / n
    if tensorshape_util.rank(auto_corr.shape) is not None:
      new_shape = [-1] + [1] * (tensorshape_util.rank(auto_corr.shape) - 1)
    else:
      new_shape = tf.concat(
          ([-1],
           tf.ones([tf.rank(auto_corr) - 1], dtype=tf.int32)),
          axis=0)
    nk_factor = tf.reshape(nk_factor, new_shape)
    weighted_auto_corr = nk_factor * auto_corr

    if filter_beyond_positive_pairs:
      def _sum_pairs(x):
        x_len = ps.shape(x)[0]
        # For odd sequences, we drop the final value.
        x = x[:x_len - x_len % 2]
        new_shape = ps.concat([[x_len // 2, 2], ps.shape(x)[1:]], axis=0)
        return tf.reduce_sum(tf.reshape(x, new_shape), 1)

      # Pairwise sums are all positive for auto-correlation spectra derived from
      # reversible MCMC chains.
      # E.g. imagine the pairwise sums are [0.2, 0.1, -0.1, -0.2]
      # Step 1: mask = [False, False, True, True]
      mask = _sum_pairs(auto_corr) < 0.
      # Step 2: mask = [0, 0, 1, 1]
      mask = tf.cast(mask, dt)
      # Step 3: mask = [0, 0, 1, 2]
      mask = tf.cumsum(mask, axis=0)
      # Step 4: mask = [1, 1, 0, 0]
      mask = tf.maximum(1. - mask, 0.)

      # N.B. this reduces the length of weighted_auto_corr by a factor of 2.
      # It still works fine in the formula below.
      weighted_auto_corr = _sum_pairs(weighted_auto_corr) * mask
    elif filter_threshold is not None:
      filter_threshold = tf.convert_to_tensor(
          filter_threshold, dtype=dt, name='filter_threshold')
      # Get a binary mask to zero out values of auto_corr below the threshold.
      #   mask[i, ...] = 1 if auto_corr[j, ...] > threshold for all j <= i,
      #   mask[i, ...] = 0, otherwise.
      # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0,...]
      # Building step by step,
      #   Assume auto_corr = [1, 0.5, 0.0, 0.3], and filter_threshold = 0.2.
      # Step 1:  mask = [False, False, True, False]
      mask = auto_corr < filter_threshold
      # Step 2:  mask = [0, 0, 1, 0]
      mask = tf.cast(mask, dtype=dt)
      # Step 3:  mask = [0, 0, 1, 1]
      mask = tf.cumsum(mask, axis=0)
      # Step 4:  mask = [1, 1, 0, 0]
      mask = tf.maximum(1. - mask, 0.)
      weighted_auto_corr *= mask

    return num_chains * n / (-1 + 2 * tf.reduce_sum(weighted_auto_corr, axis=0))
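
# A rough NumPy sketch (not the library implementation) of the single-chain
# ESS estimate above: ESS = N / (-1 + 2 * sum_k R[k] * (N - k) / N), with the
# sum truncated once the autocorrelation drops below a threshold, as in the
# `filter_threshold` branch.
import numpy as np

rng = np.random.default_rng(0)
n_demo, rho = 4000, 0.5
x = np.zeros(n_demo)
for t in range(1, n_demo):                 # AR(1) chain with known correlation
  x[t] = rho * x[t - 1] + rng.normal()

x_c = x - x.mean()
auto_cov = np.correlate(x_c, x_c, mode='full')[n_demo - 1:] / n_demo
auto_corr = auto_cov / auto_cov[0]
k = np.arange(n_demo)
weights = auto_corr * (n_demo - k) / n_demo
below = np.cumsum(auto_corr < 0.05) > 0    # truncate past the first small lag
weights = np.where(below, 0., weights)
ess = n_demo / (-1. + 2. * weights.sum())
print(ess)   # ballpark n * (1 - rho) / (1 + rho) ~= 1333 for this chain
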
    def posterior_mode(self, observations, mask=None, name=None):
        """Compute maximum likelihood sequence of hidden states.

    When this function is provided with a sequence of observations
    `x[0], ..., x[num_steps - 1]`, it returns the sequence of hidden
    states `z[0], ..., z[num_steps - 1]`, drawn from the underlying
    Markov chain, that is most likely to yield those observations.

    It uses the [Viterbi algorithm](
    https://en.wikipedia.org/wiki/Viterbi_algorithm).

    Note: the behavior of this function is undefined if the
    `observations` argument represents impossible observations
    from the model.

    Note: if there isn't a unique most likely sequence then one
    of the equally most likely sequences is chosen.

    Args:
      observations: A tensor representing a batch of observations made on the
        hidden Markov model.  The rightmost dimensions of this tensor correspond
        to the dimensions of the observation distributions of the underlying
        Markov chain.  The next dimension from the right indexes the steps in a
        sequence of observations from a single sample from the hidden Markov
        model.  The size of this dimension should match the `num_steps`
        parameter of the hidden Markov model object.  The other dimensions are
        the dimensions of the batch and these are broadcast with the hidden
        Markov model's parameters.
      mask: optional bool-type `Tensor` with rightmost dimension matching
        `num_steps` indicating which observations the result of this
        function should be conditioned on. When the mask has value
        `True` the corresponding observations aren't used.
        If `mask` is `None` then all of the observations are used.
        The `mask` dimensions left of the last are broadcast with the
        HMM batch as well as with the observations.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Returns:
      posterior_mode: A `Tensor` representing the most likely sequence of hidden
        states. The rightmost dimension of this tensor will equal the
        `num_steps` parameter providing one hidden state for each step. The
        other dimensions are those of the batch.

    Raises:
      ValueError: if the `observations` tensor does not consist of
      sequences of `num_steps` observations.

    #### Examples

    ```python
    tfd = tfp.distributions

    # A simple weather model.

    # Represent a cold day with 0 and a hot day with 1.
    # Suppose the first day of a sequence has a 0.8 chance of being cold.

    initial_distribution = tfd.Categorical(probs=[0.8, 0.2])

    # Suppose a cold day has a 30% chance of being followed by a hot day
    # and a hot day has a 20% chance of being followed by a cold day.

    transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                     [0.2, 0.8]])

    # Suppose additionally that on each day the temperature is
    # normally distributed with mean and standard deviation 0 and 5 on
    # a cold day and mean and standard deviation 15 and 10 on a hot day.

    observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

    # This gives the hidden Markov model:

    model = tfd.HiddenMarkovModel(
        initial_distribution=initial_distribution,
        transition_distribution=transition_distribution,
        observation_distribution=observation_distribution,
        num_steps=7)

    # Suppose we observe gradually rising temperatures over a week:
    temps = [-2., 0., 2., 4., 6., 8., 10.]

    # We can now compute the most probable sequence of hidden states:

    model.posterior_mode(temps)

    # The result is [0 0 0 0 0 1 1] telling us that the transition
    # from "cold" to "hot" most likely happened between the
    # 5th and 6th days.
    ```
    """

        with tf.name_scope(name or "posterior_mode"):
            observations = tf.convert_to_tensor(observations,
                                                name="observations")
            if mask is not None:
                mask = tf.convert_to_tensor(mask,
                                            name="mask",
                                            dtype_hint=tf.bool)
            with tf.control_dependencies(self._runtime_assertions):
                observation_tensor_shape = tf.shape(observations)
                mask_tensor_shape = tf.shape(
                    mask) if mask is not None else None

                with self._observation_mask_shape_preconditions(
                        observation_tensor_shape, mask_tensor_shape):
                    observation_log_probs = self._observation_log_probs(
                        observations, mask)
                    log_prob = self._log_init + observation_log_probs[0]

                    def _reduce_multiple_steps():
                        """Perform `reduce_max` operation when `num_steps` > 1."""
                        def forward_step(previous_step_pair,
                                         log_prob_observation):
                            log_prob_previous = previous_step_pair[0]
                            log_prob = (
                                log_prob_previous[..., tf.newaxis] +
                                self._log_trans +
                                log_prob_observation[..., tf.newaxis, :])
                            most_likely_given_successor = tf.argmax(log_prob,
                                                                    axis=-2)
                            max_log_p_given_successor = tf.reduce_max(
                                input_tensor=log_prob, axis=-2)
                            return (max_log_p_given_successor,
                                    most_likely_given_successor)

                        forward_log_probs, all_most_likely_given_successor = tf.scan(
                            forward_step,
                            observation_log_probs[1:],
                            initializer=(log_prob,
                                         tf.zeros(tf.shape(log_prob),
                                                  dtype=tf.int64)),
                            name="forward_log_probs")

                        most_likely_end = tf.argmax(forward_log_probs[-1],
                                                    axis=-1)

                        # We require the operation that gives C from A and B where
                        # C[i...j] = A[i...j, B[i...j]]
                        # and A = most_likely_given_successor
                        #     B = most_likely_successor.
                        # tf.gather requires indices of known shape so instead we use
                        # reduction with tf.one_hot(B) to pick out elements from A
                        def backward_step(most_likely_successor,
                                          most_likely_given_successor):
                            return tf.reduce_sum(
                                input_tensor=(most_likely_given_successor *
                                              tf.one_hot(most_likely_successor,
                                                         self._num_states,
                                                         dtype=tf.int64)),
                                axis=-1)

                        backward_scan = tf.scan(
                            backward_step,
                            all_most_likely_given_successor,
                            most_likely_end,
                            reverse=True)
                        most_likely_sequences = tf.concat(
                            [backward_scan, [most_likely_end]], axis=0)
                        return distribution_util.move_dimension(
                            most_likely_sequences, 0, -1)

                    return prefer_static.cond(
                        self.num_steps > 1, _reduce_multiple_steps,
                        lambda: tf.argmax(log_prob, axis=-1)[..., tf.newaxis])
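
# A compact NumPy sketch (illustrative, not the library code) of the Viterbi
# recursion implemented in `posterior_mode` above: a forward pass keeps the
# best log-probability per state plus argmax back-pointers, and a backward
# pass follows the pointers. All numbers below are made up.
import numpy as np

log_init = np.log([0.8, 0.2])
log_trans = np.log([[0.7, 0.3],
                    [0.2, 0.8]])
obs_logp = np.log([[0.9, 0.1],     # log p(observation | state), 4 steps
                   [0.7, 0.3],
                   [0.4, 0.6],
                   [0.1, 0.9]])

log_prob = log_init + obs_logp[0]
back_pointers = []
for t in range(1, len(obs_logp)):
  scores = log_prob[:, None] + log_trans + obs_logp[t][None, :]
  back_pointers.append(scores.argmax(axis=0))   # best predecessor per state
  log_prob = scores.max(axis=0)

state = int(log_prob.argmax())
path = [state]
for bp in reversed(back_pointers):
  state = int(bp[state])
  path.append(state)
print(path[::-1])   # ==> [0, 0, 1, 1]
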
Example no. 28
def _potential_scale_reduction_single_state(state, independent_chain_ndims,
                                            split_chains, validate_args):
  """potential_scale_reduction for one single state `Tensor`."""
  # casting integers to floats for floating-point division
  # check to see if the `state` is a numpy object for the numpy test suite
  if dtype_util.as_numpy_dtype(state.dtype) is np.int64:
    state = tf.cast(state, tf.float64)
  elif dtype_util.is_integer(state.dtype):
    state = tf.cast(state, tf.float32)
  with tf.name_scope('potential_scale_reduction_single_state'):
    # We assume exactly one leading dimension indexes e.g. correlated samples
    # from each Markov chain.
    state = tf.convert_to_tensor(state, name='state')

    n_samples_ = tf.compat.dimension_value(state.shape[0])
    if n_samples_ is not None:  # If available statically.
      if split_chains and n_samples_ < 4:
        raise ValueError(
            'Must provide at least 4 samples when splitting chains. '
            'Found {}'.format(n_samples_))
      if not split_chains and n_samples_ < 2:
        raise ValueError(
            'Must provide at least 2 samples.  Found {}'.format(n_samples_))
    elif validate_args:
      if split_chains:
        assertions = [assert_util.assert_greater_equal(
            ps.shape(state)[0], 4,
            message='Must provide at least 4 samples when splitting chains.')]
        with tf.control_dependencies(assertions):
          state = tf.identity(state)
      else:
        assertions = [assert_util.assert_greater_equal(
            ps.shape(state)[0], 2,
            message='Must provide at least 2 samples.')]
        with tf.control_dependencies(assertions):
          state = tf.identity(state)

    # Define so it's not a magic number.
    # Warning!  `if split_chains` logic assumes this is 1!
    sample_ndims = 1

    if split_chains:
      # Split the sample dimension in half, doubling the number of
      # independent chains.

      # For odd number of samples, keep all but the last sample.
      state_shape = ps.shape(state)
      n_samples = state_shape[0]
      state = state[:n_samples - n_samples % 2]

      # Suppose state = [0, 1, 2, 3, 4, 5]
      # Step 1: reshape into [[0, 1, 2], [3, 4, 5]]
      # E.g. reshape states of shape [a, b] into [2, a//2, b].
      state = tf.reshape(
          state,
          ps.concat([[2, n_samples // 2], state_shape[1:]], axis=0)
      )
      # Step 2: Put the size `2` dimension in the right place to be treated as a
      # chain, changing [[0, 1, 2], [3, 4, 5]] into [[0, 3], [1, 4], [2, 5]],
      # reshaping [2, a//2, b] into [a//2, 2, b].
      state = tf.transpose(
          a=state,
          perm=ps.concat(
              [[1, 0], ps.range(2, ps.rank(state))], axis=0))

      # We're treating the new dim as indexing 2 chains, so increment.
      independent_chain_ndims += 1

    sample_axis = ps.range(0, sample_ndims)
    chain_axis = ps.range(sample_ndims,
                          sample_ndims + independent_chain_ndims)
    sample_and_chain_axis = ps.range(
        0, sample_ndims + independent_chain_ndims)

    n = _axis_size(state, sample_axis)
    m = _axis_size(state, chain_axis)

    # In the language of Brooks and Gelman (1998),
    # B / n is the between chain variance, the variance of the chain means.
    # W is the within sequence variance, the mean of the chain variances.
    b_div_n = _reduce_variance(
        tf.reduce_mean(state, axis=sample_axis, keepdims=True),
        sample_and_chain_axis,
        biased=False)
    w = tf.reduce_mean(
        _reduce_variance(state, sample_axis, keepdims=True, biased=False),
        axis=sample_and_chain_axis)

    # sigma^2_+ is an estimate of the true variance, which would be unbiased if
    # each chain was drawn from the target.  c.f. "law of total variance."
    sigma_2_plus = ((n - 1) / n) * w + b_div_n
    return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
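
# A NumPy sketch (not the library code) of the R-hat formula returned above:
# the between-chain variance of the chain means and the mean within-chain
# variance are combined into the potential scale reduction factor.
import numpy as np

rng = np.random.default_rng(1)
n_demo, m_demo = 1000, 4                        # n samples from each of m chains
state_demo = rng.normal(size=(n_demo, m_demo))  # well-mixed => R-hat near 1

b_div_n = np.var(state_demo.mean(axis=0), ddof=1)        # B / n
w = np.mean(np.var(state_demo, axis=0, ddof=1))          # W
sigma_2_plus = (n_demo - 1) / n_demo * w + b_div_n
rhat = (m_demo + 1) / m_demo * sigma_2_plus / w - (n_demo - 1) / (m_demo * n_demo)
print(rhat)   # approximately 1.0 for independent, well-mixed chains
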
Example no. 29
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)
            kernel_shape = ps.shape(kernel)
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):
                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  kernel_shape[-1],
                                                  batch_shape, event_shape)

                idx, shape = im2row_index((xh * sh + sum(pad_values[0]),
                                           xw * sw + sum(pad_values[1]), c_in),
                                          block_shape=filter_shape,
                                          slice_step=(1, 1),
                                          dilations=dilations,
                                          dtype=dtype,
                                          transpose=True)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(pad_values,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                # Interleave the rows and columns of the input with rows and columns of
                # zeros equal to the number of strides.
                x_half_dilated = tf.concat([
                    tf.zeros(ps.concat([batch_shape, (xh * xw, sw - 1, c_in)],
                                       axis=0),
                             dtype=input_dtype),
                    tf.reshape(x,
                               shape=ps.concat(
                                   [batch_shape, (xh * xw, 1, c_in)], axis=0))
                ],
                                           axis=-2)
                y = tf.reshape(x_half_dilated,
                               shape=ps.concat(
                                   [batch_shape, (xh, 1, xw * sw, c_in)],
                                   axis=0))

                x = tf.reshape(tf.concat([
                    tf.zeros(ps.concat(
                        [batch_shape, (xh, sh - 1, xw * sw, c_in)], axis=0),
                             dtype=input_dtype), y
                ],
                                         axis=-3),
                               shape=ps.concat(
                                   [batch_shape, (xh * sh, xw * sw, c_in)],
                                   axis=0))

                truncations = -ps.minimum(ps.cast(paddings, dtype=tf.int32), 0)
                truncate_start, truncate_end = ps.unstack(truncations, axis=1)
                x_truncate = tf.slice(x,
                                      begin=truncate_start,
                                      size=ps.shape(x) -
                                      (truncate_start + truncate_end))

                x_pad = tf.pad(x_truncate,
                               paddings=ps.maximum(paddings, 0),
                               constant_values=0)

                flat_shape = ps.pad(batch_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.gather(tf.reshape(x_pad, shape=flat_shape),
                                   indices=idx,
                                   axis=-1)
                im_x = tf.reshape(flat_x,
                                  shape=ps.concat([batch_shape, shape],
                                                  axis=0))
                return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
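
# A small NumPy sketch (illustrative only; the ordering details differ
# slightly from the code above) of the zero-interleaving idea used for the
# transposed convolution: each input row/column is followed by `stride - 1`
# rows/columns of zeros before the padded forward convolution is applied.
import numpy as np

x_demo = np.arange(1., 5.).reshape(2, 2)   # [[1, 2], [3, 4]]
sh_demo = sw_demo = 2                       # strides

x_cols = np.concatenate(
    [x_demo[..., None], np.zeros((2, 2, sw_demo - 1))],
    axis=-1).reshape(2, 2 * sw_demo)
x_dilated = np.concatenate(
    [x_cols[:, None, :], np.zeros((2, sh_demo - 1, 2 * sw_demo))],
    axis=1).reshape(2 * sh_demo, 2 * sw_demo)
print(x_dilated)
# ==> [[1. 0. 2. 0.]
#      [0. 0. 0. 0.]
#      [3. 0. 4. 0.]
#      [0. 0. 0. 0.]]
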
Example no. 30
    def _sample_control_dependencies(self, x):
        """Helper which validates sample arg, e.g., input to `log_prob`."""
        x_ndims = (tf.rank(x) if tensorshape_util.rank(x.shape) is None else
                   tensorshape_util.rank(x.shape))
        event_ndims = (tf.size(self.event_shape_tensor())
                       if tensorshape_util.rank(self.event_shape) is None else
                       tensorshape_util.rank(self.event_shape))
        batch_ndims = (tf.size(self.batch_shape_tensor())
                       if tensorshape_util.rank(self.batch_shape) is None else
                       tensorshape_util.rank(self.batch_shape))
        expected_batch_event_ndims = batch_ndims + event_ndims

        if (isinstance(x_ndims, int)
                and isinstance(expected_batch_event_ndims, int)):
            if x_ndims < expected_batch_event_ndims:
                raise NotImplementedError(
                    'Broadcasting is not supported; too few batch and event dims '
                    '(expected at least {}, saw {}).'.format(
                        expected_batch_event_ndims, x_ndims))
            ndims_assertion = []
        elif self.validate_args:
            ndims_assertion = [
                assert_util.assert_greater_equal(
                    x_ndims,
                    expected_batch_event_ndims,
                    message=('Broadcasting is not supported; too few '
                             'batch and event dims.'),
                    name='assert_batch_and_event_ndims_large_enough'),
            ]

        if (tensorshape_util.is_fully_defined(self.batch_shape)
                and tensorshape_util.is_fully_defined(self.event_shape)):
            expected_batch_event_shape = np.int32(
                tensorshape_util.concatenate(self.batch_shape,
                                             self.event_shape))
        else:
            expected_batch_event_shape = tf.concat([
                self.batch_shape_tensor(),
                self.event_shape_tensor(),
            ],
                                                   axis=0)

        sample_ndims = x_ndims - expected_batch_event_ndims
        if isinstance(sample_ndims, int):
            sample_ndims = max(sample_ndims, 0)
        if (isinstance(sample_ndims, int)
                and tensorshape_util.is_fully_defined(x.shape[sample_ndims:])):
            actual_batch_event_shape = np.int32(x.shape[sample_ndims:])
        else:
            sample_ndims = tf.maximum(sample_ndims, 0)
            actual_batch_event_shape = tf.shape(x)[sample_ndims:]

        assertions = []
        if (isinstance(expected_batch_event_shape, np.ndarray)
                and isinstance(actual_batch_event_shape, np.ndarray)):
            if any(expected_batch_event_shape != actual_batch_event_shape):
                raise NotImplementedError('Broadcasting is not supported; '
                                          'unexpected batch and event shape '
                                          '(expected {}, saw {}).'.format(
                                              expected_batch_event_shape,
                                              actual_batch_event_shape))
            assertions.extend(ndims_assertion)
        elif self.validate_args:
            with tf.control_dependencies(ndims_assertion):
                shape_assertion = assert_util.assert_equal(
                    expected_batch_event_shape,
                    actual_batch_event_shape,
                    message=('Broadcasting is not supported; '
                             'unexpected batch and event shape.'),
                    name='assert_batch_and_event_shape_same')
            assertions.append(shape_assertion)

        return assertions
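
# A shape-only NumPy sketch (hypothetical shapes) of the check performed by
# `_sample_control_dependencies` above: the trailing dimensions of `x` must
# match batch_shape + event_shape, and anything to their left is treated as
# sample dimensions.
import numpy as np

batch_shape, event_shape = [3], [2]
x_demo = np.zeros([7, 3, 2])                        # one leading sample dim
expected = batch_shape + event_shape
sample_ndims = x_demo.ndim - len(expected)          # ==> 1
assert list(x_demo.shape[sample_ndims:]) == expected
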