Esempio n. 1
0
 def test_tf_should_error_with_more_than_one_named_axis(self):
     """Checks that the TF backend rejects a list of two named axes."""
     if JAX_MODE:
         self.skipTest('Test only applies to TF backend.')
     expected_error = 'TensorFlow backend does not support multiple shard axes'
     two_axes = ['a', 'b']
     with self.assertRaisesRegex(ValueError, expected_error):
         distribute_lib.canonicalize_named_axis(two_axes)
Esempio n. 2
0
def pbroadcast_value(value, value_axis_names, output_axis_names):
    """Broadcasts `value` over the output axes it is not already named on.

    Args:
      value: The value handed to `distribute_lib.pbroadcast`.
      value_axis_names: Axis name(s) already associated with `value`. They are
        canonicalized with `distribute_lib.canonicalize_named_axis`.
      output_axis_names: Iterable of axis names, used as given (it is not
        canonicalized here).

    Returns:
      The result of `distribute_lib.pbroadcast` applied to `value` over every
      entry of `output_axis_names` that is absent from `value_axis_names`.
    """
    known_axes = distribute_lib.canonicalize_named_axis(value_axis_names)
    # Keep the order of `output_axis_names`; drop the axes `value` already has.
    missing_axes = [name for name in output_axis_names
                    if name not in known_axes]
    return distribute_lib.pbroadcast(value, named_axis=missing_axes)
Esempio n. 3
0
def reduce_logmeanexp(input_tensor,
                      axis=None,
                      keepdims=False,
                      experimental_named_axis=None,
                      experimental_allow_all_gather=False,
                      name=None):
    """Computes `log(mean(exp(input_tensor)))`.

    `input_tensor` is reduced along the dimensions listed in `axis`. Each entry
    of `axis` lowers the rank of the result by 1 unless `keepdims` is true, in
    which case the reduced dimensions are kept with length 1.

    If `axis` has no entries, all dimensions are reduced, and a tensor with a
    single element is returned.

    Compared to `log(reduce_mean(exp(input)))` this is more numerically stable:
    it avoids overflows from exponentiating large inputs and underflows from
    taking the log of small inputs.

    Args:
      input_tensor: The tensor to reduce. Should have numeric type.
      axis: The dimensions to reduce. If `None` (the default), reduces all
        dimensions. Must be in the range `[-rank(input_tensor),
        rank(input_tensor))`.
      keepdims: Boolean. Whether to keep the axis as singleton dimensions.
        Default value: `False` (i.e., squeeze the reduced dimensions).
      experimental_named_axis: A `str` or list of `str` axis names to
        additionally reduce over. Providing `None` will not reduce over any
        axes.
      experimental_allow_all_gather: Allow using an `all_gather`-based fallback
        under TensorFlow when computing the distributed maximum. This fallback
        is only efficient when `axis` reduces away most of the dimensions of
        `input_tensor`.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., `'reduce_logmeanexp'`).

    Returns:
      log_mean_exp: The reduced tensor.
    """
    scope_name = name or 'reduce_logmeanexp'
    with tf.name_scope(scope_name):
        axis_names = distribute_lib.canonicalize_named_axis(
            experimental_named_axis)
        log_sum_exp = distribute_lib.reduce_logsumexp(
            input_tensor,
            axis=axis,
            keepdims=keepdims,
            named_axis=axis_names,
            allow_all_gather=experimental_allow_all_gather)
        # Number of local elements folded into each entry of the result.
        count = ps.size(input_tensor) // ps.size(log_sum_exp)
        # Every named axis scales the element count by its size.
        for axis_name in axis_names:
            count = count * distribute_lib.get_axis_size(axis_name)
        # log(mean) = log(sum) - log(count).
        return log_sum_exp - tf.math.log(tf.cast(count, log_sum_exp.dtype))
Esempio n. 4
0
  def __init__(self,
               inner_kernel,
               chain_axis_names,
               validate_args=False,
               name=None):
    """Creates a `Sharded` transition kernel around `inner_kernel`.

    Args:
      inner_kernel: A `TransitionKernel` to be sharded.
      chain_axis_names: A `str` or list of `str`s naming the axes that
        independent Markov chains will be sharded across. The value is
        canonicalized before it is stored.
      validate_args: Python `bool`. When `True` kernel parameters are checked
        for validity. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this class.
    """
    canonical_axis_names = distribute_lib.canonicalize_named_axis(
        chain_axis_names)
    self._parameters = {
        'inner_kernel': inner_kernel,
        'chain_axis_names': canonical_axis_names,
        'validate_args': validate_args,
        'name': name,
    }
Esempio n. 5
0
    def bootstrap_results(self, init_state):
        """Builds the initial kernel results for `init_state`.

        Evaluates the target log-probability once to obtain the batch shape and
        rank, checks that at least two chains are available, initializes the
        running mean / variance / principal-component values, and bootstraps
        the kernel returned by `self._make_kernel`.

        Args:
          init_state: (Possibly nested) structure of values convertible to
            `Tensor`s.

        Returns:
          kernel_results: A `SNAPERHamiltonianMonteCarloResults` holding the
            inner kernel's bootstrap results, the initial running statistics
            and a zero seed.

        Raises:
          ValueError: If the rank of the target log-probability is not
            statically known.
          ValueError: If the number of chains is statically known to be less
            than 2.
        """
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'snaper_hamiltonian_monte_carlo',
                                    'bootstrap_results')):
            init_state = tf.nest.map_structure(
                lambda x: tf.convert_to_tensor(x, name='init_state'),
                init_state)

            # It is unfortunate that we need to make this extra call to the TLP here.
            # The issue is that we need this value to even construct the PHMC, and
            # the kernel will call this one itself.
            tlp = self.target_log_prob_fn(*tf.nest.flatten(init_state))
            batch_shape = ps.shape(tlp)
            batch_ndims = ps.rank(tlp)
            if tf.get_static_value(batch_ndims) is None:
                # The issue doesn't live in this file, rather it is the downstream
                # components that fail to work (notably, tfb.Reshape).
                raise ValueError(
                    'SNAPERHMC currently requires a statically known '
                    'rank of the target log probability.')

            # We need at least two chains to estimate the principal component.
            # Number of total chains is local batch size * distributed axis size
            reduce_chain_axis_names = distribute_lib.canonicalize_named_axis(
                self.experimental_reduce_chain_axis_names)
            local_axis_size = ps.maximum(ps.size(tlp), 1)
            distributed_axis_size = int(
                ps.reduce_prod([
                    distribute_lib.get_axis_size(a)
                    for a in reduce_chain_axis_names
                ]))
            num_chains = local_axis_size * distributed_axis_size
            num_chains_ = tf.get_static_value(num_chains)
            if num_chains_ is not None:
                # Chain count is known statically: fail immediately.
                if num_chains_ < 2:
                    raise ValueError(
                        'SNAPERHMC requires at least 2 chains. Got: {}'.format(
                            num_chains_))
            elif self.validate_args:
                # Chain count is only known at run time: attach the assertion to
                # `init_state` so that it is evaluated along with the state.
                with tf.control_dependencies([
                        assert_util.assert_greater_equal(
                            num_chains, 2,
                            'SNAPERHMC requires at least 2 chains.')
                ]):
                    init_state = tf.nest.map_structure(tf.identity, init_state)

            # Negative indices of the dimensions that follow the first
            # `batch_ndims` dimensions of each state part.
            event_axes = tf.nest.map_structure(
                lambda x: ps.range(batch_ndims, ps.rank(x)) - ps.rank(x),
                init_state)
            if self.experimental_shard_axis_names is None:
                # No sharding requested: use a structure of `None`s shaped like
                # `init_state`.
                shard_axis_names = tf.nest.map_structure(
                    lambda _: None, init_state)
            else:
                shard_axis_names = self.experimental_shard_axis_names

            # Running variance starts at ones, with the leading `batch_ndims`
            # dimensions of each state part dropped.
            ema_variance = tf.nest.map_structure(
                lambda x: tf.ones(  # pylint: disable=g-long-lambda
                    ps.shape(x)[batch_ndims:],
                    dtype=x.dtype,
                    name='ema_variance'),
                init_state)
            # Running mean starts at zeros with the same shapes as the variance.
            ema_mean = tf.nest.map_structure(
                lambda x: tf.zeros_like(x, name='ema_mean'), ema_variance)
            # NOTE(review): `_normalize` is defined elsewhere; presumably it
            # rescales its input to unit norm across all parts -- confirm.
            ema_principal_component = _normalize(ema_variance, event_axes,
                                                 shard_axis_names)
            # These start out at 1 for a bit of smoothing.
            state_ema_points = tf.ones([], tf.int32)
            principal_component_ema_points = tf.ones([], tf.int32)

            kernel = self._make_kernel(
                batch_shape=batch_shape,
                step=tf.zeros([], tf.int32),
                state_ema_points=state_ema_points,
                state=init_state,
                mean=ema_mean,
                variance=ema_variance,
                principal_component=ema_principal_component,
            )

            # The inner kernel receives the state as a flat list of parts.
            inner_results = kernel.bootstrap_results(
                tf.nest.flatten(init_state))

            kernel_results = SNAPERHamiltonianMonteCarloResults(
                inner_results=inner_results,
                ema_mean=ema_mean,
                ema_variance=ema_variance,
                state_ema_points=state_ema_points,
                ema_principal_component=ema_principal_component,
                principal_component_ema_points=principal_component_ema_points,
                seed=samplers.zeros_seed(),
            )
            return kernel_results
Esempio n. 6
0
    def __init__(self,
                 distribution,
                 shard_axis_name=None,
                 validate_args=False,
                 name=None):
        """Constructs a `Sharded` distribution.

        Args:
          distribution: The base distribution instance to transform. Typically
            an instance of `Distribution`. Its `experimental_shard_axis_names`
            are treated as the inner axes and are placed before the axes named
            by `shard_axis_name`.
          shard_axis_name: `str` or a list of strings for axis name(s). An empty
            list means that no sharding is actually done. This can be `None`
            under the TensorFlow backend (meaning a sharded axis is present, but
            anonymous). Only the JAX backend supports multiple axes names.
          validate_args: Python `bool`.  Whether to validate input with asserts.
            If `validate_args` is `False`, and the inputs are invalid, correct
            behavior is not guaranteed.
          name: The name for ops managed by the distribution.
            Default value: `None` (i.e., `'Sharded' + distribution.name`).

        Raises:
          ValueError: If `shard_axis_name` is `None` under the JAX backend.
          ValueError: If, under the TensorFlow backend, the inner and outer
            axes together name more than one axis.
          ValueError: If an axis name occurs more than once among the inner and
            outer axes.
        """
        # Must run before any other local is bound so that `parameters` records
        # the call arguments exactly as they were passed.
        parameters = dict(locals())

        if shard_axis_name is None:
            if JAX_MODE:
                # In JAX, axes names matter and we don't know which axis name the user
                # might intend, so we bail.
                raise ValueError(
                    'Cannot provide a `None` axis name in JAX backend.')
            # In TF, there are no axes names, so we can pick a reasonable default.
            shard_axis_name = [True]

        inner_shard_axis_names = distribution.experimental_shard_axis_names
        outer_shard_axis_names = distribute_lib.canonicalize_named_axis(
            shard_axis_name)
        # Use inner axes before outer axes
        full_shard_axis_name = inner_shard_axis_names + outer_shard_axis_names

        # Error messages report the canonicalized outer names: calling `list` on
        # a raw `str` argument would split it into single characters.
        axis_report = (
            'inner shard_axis_names: '
            f'{list(inner_shard_axis_names)}\n'
            f'outer shard_axis_names: {list(outer_shard_axis_names)}')

        if not JAX_MODE and len(full_shard_axis_name) > 1:
            raise ValueError(
                'TensorFlow backend does not support multiple shard axes:\n'
                f'{axis_report}')

        # Collect repeated names in first-seen order so that the message is
        # deterministic.
        seen = set()
        duplicates = []
        for axis_name in full_shard_axis_name:
            if axis_name in seen and axis_name not in duplicates:
                duplicates.append(axis_name)
            seen.add(axis_name)
        if duplicates:
            raise ValueError(
                'Found duplicate axis name(s).\n'
                f'{axis_report}\n'
                f'duplicates: {duplicates}')

        with tf.name_scope(name or 'Sharded' + distribution.name) as name:
            self._distribution = distribution
            self._shard_axis_name = full_shard_axis_name
            super(Sharded, self).__init__(
                dtype=self._distribution.dtype,
                validate_args=validate_args,
                allow_nan_stats=self._distribution.allow_nan_stats,
                reparameterization_type=self._distribution.
                reparameterization_type,
                parameters=parameters,
                name=name)
Esempio n. 7
0
def _update_trajectory_grad(previous_kernel_results,
                            previous_state,
                            proposed_state,
                            proposed_velocity,
                            trajectory_jitter,
                            accept_prob,
                            step_size,
                            criterion_fn,
                            max_leapfrog_steps,
                            experimental_shard_axis_names=None,
                            experimental_chain_axis_names=None):
    """Takes one adaptation step on the maximum trajectory length.

    The criterion is differentiated with respect to a time offset applied to
    the proposed state along the proposed velocity. The gradient is averaged
    across chains with `accept_prob` as weights, rescaled by a running average
    of its square, and applied as a clipped multiplicative update.

    Args:
      previous_kernel_results: Results tuple providing `adaptation_rate`,
        `step`, `averaged_sq_grad`, `max_trajectory_length` and
        `averaged_max_trajectory_length`, and supporting `_replace`.
      previous_state: State forwarded to `criterion_fn` as its first argument.
      proposed_state: Proposed state that is shifted along `proposed_velocity`.
      proposed_velocity: Velocity with the same structure as `proposed_state`.
      trajectory_jitter: Factor multiplied into the raw criterion gradient.
      accept_prob: Acceptance probabilities; used as averaging weights and
        forwarded to `criterion_fn`.
      step_size: Forwarded to `_clip_max_trajectory_length`.
      criterion_fn: Callable invoked as
        `criterion_fn(previous_state, adjusted_state, accept_prob)`.
      max_leapfrog_steps: Forwarded to `_clip_max_trajectory_length`.
      experimental_shard_axis_names: Shard axis names for the state parts.
      experimental_chain_axis_names: Named axes the chain average additionally
        reduces over.

    Returns:
      A copy of `previous_kernel_results` with `criterion`,
      `max_trajectory_length`, `averaged_sq_grad` and
      `averaged_max_trajectory_length` replaced.
    """

    def criterion_at_offset(dt):
        # Effect on the criterion value as the state follows the proposed
        # velocity for an extra time `dt`. This implicitly assumes an identity
        # mass matrix.
        def shift_part(state_part, velocity_part, shard_axes=None):
            expanded_dt = bu.left_justified_expand_dims_like(dt, velocity_part)
            offset = distribute_lib.pbroadcast(expanded_dt, shard_axes)
            return state_part + offset * velocity_part

        shifted_state = _map_structure_up_to_with_axes(
            proposed_state,
            shift_part,
            proposed_state,
            proposed_velocity,
            experimental_shard_axis_names=experimental_shard_axis_names)
        return criterion_fn(previous_state, shifted_state, accept_prob)

    # Criterion value and its gradient with respect to the offset, at zero.
    zero_offset = tf.zeros_like(accept_prob)
    criterion, trajectory_grad = gradient.value_and_gradient(
        criterion_at_offset, zero_offset)
    trajectory_grad *= trajectory_jitter

    # Acceptance-weighted average of the gradient; entries with a negligible
    # acceptance probability or a non-finite value are zeroed first.
    chain_axes = distribute_lib.canonicalize_named_axis(
        experimental_chain_axis_names)
    trajectory_grad = tf.where(accept_prob > 1e-4, trajectory_grad, 0.)
    trajectory_grad = tf.where(tf.math.is_finite(trajectory_grad),
                               trajectory_grad, 0.)
    weighted_grad_total = _reduce_sum_with_axes(
        trajectory_grad * accept_prob, None, chain_axes)
    weight_total = _reduce_sum_with_axes(accept_prob + 1e-20, None, chain_axes)
    trajectory_grad = weighted_grad_total / weight_total

    # Adam/RMSProp step size from a bias-corrected running mean of grad**2.
    adaptation_rate = previous_kernel_results.adaptation_rate
    iteration_f = tf.cast(previous_kernel_results.step,
                          adaptation_rate.dtype) + 1.
    msg_adaptation_rate = 0.05
    msg_decay = 1 - msg_adaptation_rate
    new_averaged_sq_grad = (
        msg_decay * previous_kernel_results.averaged_sq_grad +
        msg_adaptation_rate * trajectory_grad**2)
    bias_correction = 1. - msg_decay**iteration_f
    adjusted_averaged_sq_grad = new_averaged_sq_grad / bias_correction
    trajectory_step_size = (
        adaptation_rate / tf.sqrt(adjusted_averaged_sq_grad + 1e-20))

    # Apply the gradient in log space. Clip absolute value to ~log(2)/2.
    log_update = tf.clip_by_value(trajectory_step_size * trajectory_grad,
                                  -0.35, 0.35)
    unclipped_length = (
        previous_kernel_results.max_trajectory_length * tf.exp(log_update))

    # Iterate averaging in log space, using the not-yet-clipped length.
    average_weight = iteration_f**(-0.5)
    previous_log_average = tf.math.log(
        1e-10 + previous_kernel_results.averaged_max_trajectory_length)
    new_averaged_max_trajectory_length = tf.exp(
        average_weight * tf.math.log(unclipped_length) +
        (1 - average_weight) * previous_log_average)

    # Clip the maximum trajectory length.
    clipped_length = _clip_max_trajectory_length(
        unclipped_length, step_size, adaptation_rate, max_leapfrog_steps)

    return previous_kernel_results._replace(
        criterion=criterion,
        max_trajectory_length=clipped_length,
        averaged_sq_grad=new_averaged_sq_grad,
        averaged_max_trajectory_length=new_averaged_max_trajectory_length)
Esempio n. 8
0
def chees_criterion(previous_state,
                    proposed_state,
                    accept_prob,
                    validate_args=False,
                    experimental_shard_axis_names=None,
                    experimental_chain_axis_names=None):
    """Evaluates the ChEES criterion from [1].

    ChEES is short for Change in the Estimator of the Expected Square:

    ```None
    ChEES = 1/4 E[(||x' - E[x]||**2 - ||x - E[x]||**2)**2],
    ```

    Here `x` is the previous chain state, `x'` is the next chain state and
    `||.||` is the L2 norm; both expectations are taken with respect to the
    chain's stationary distribution. The inner expectation is replaced by the
    empirical mean across chains, which is why at least 2 chains must be
    present. The outer expectation is left to the caller (e.g. the
    `GradientBasedTrajectoryLengthAdaptation` kernel).

    The criterion can be seen as the standard expected squared jump distance
    (ESJD), with the jump distance measured in the space of centered squared
    L2 norms.

    Regular ESJD is maximized by perfectly anticorrelated proposals, which can
    give excellent mean estimates but terrible variance estimates; maximizing
    ChEES should give good estimates across a wider range of types of
    posterior expectations.

    Args:
      previous_state: (Possibly nested) floating point `Tensor`. The previous
        state of the HMC chain.
      proposed_state: (Possibly nested) floating point `Tensor`. The proposed
        state of the HMC chain.
      accept_prob: Floating `Tensor`. Probability of accepting the proposed
        state.
      validate_args: Whether to perform non-static argument validation.
      experimental_shard_axis_names: A structure of string names indicating how
        members of the state are sharded.
      experimental_chain_axis_names: A string or list of string names
        indicating how batches of chains are sharded.

    Returns:
      chees: The value of the ChEES criterion.

    Raises:
      ValueError: If `accept_prob` indicates that there are fewer than 2
        chains.

    #### References

    [1]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme
         for Setting Trajectory Lengths in Hamiltonian Monte Carlo. In
         preparation.

    """
    batch_ndims = ps.rank(accept_prob)
    batch_axes = ps.range(batch_ndims, dtype=tf.int32)
    chain_axes = distribute_lib.canonicalize_named_axis(
        experimental_chain_axis_names)

    # Total chain count = local batch size * product of the named axis sizes.
    local_chains = ps.maximum(ps.size(accept_prob), 1)
    distributed_chains = int(
        ps.reduce_prod(
            [distribute_lib.get_axis_size(name) for name in chain_axes]))
    num_chains = local_chains * distributed_chains
    static_num_chains = tf.get_static_value(num_chains)
    if static_num_chains is None:
        if validate_args:
            # Count only known at run time: tie the assertion to the state.
            chain_count_check = assert_util.assert_greater_equal(
                num_chains, 2, 'chees_criterion requires at least 2 chains.')
            with tf.control_dependencies([chain_count_check]):
                previous_state = tf.nest.map_structure(tf.identity,
                                                       previous_state)
    elif static_num_chains < 2:
        raise ValueError(
            'chees_criterion requires at least 2 chains. Got: {}'.format(
                static_num_chains))

    def _subtract_chain_mean(part):
        # The empirical mean stands in for the true mean, so no gradient is
        # allowed to flow through it.
        chain_mean = _reduce_mean_with_axes(part, batch_axes, chain_axes)
        return part - tf.stop_gradient(chain_mean)

    def _subtract_weighted_mean(part):
        # As above, the mean is a stand-in for the true mean and carries no
        # gradient. The aim is a reliable diagnostic of the underlying
        # dynamics, rather than incorporating the effect of the
        # MetropolisHastings correction.
        # TODO(mhoffman): Needs more experimentation.
        weights = bu.left_justified_expand_dims_like(accept_prob, part)

        # accept_prob is zero where the state is NaN, but such values still
        # need to be sanitized before they are multiplied.
        finite_part = tf.where(
            tf.math.is_finite(part), part, tf.zeros_like(part))
        # If every weight is zero the center is meaningless, but the resulting
        # gradients are discarded later on, so that is fine.
        weighted_total = _reduce_sum_with_axes(weights * finite_part,
                                               batch_axes, chain_axes)
        weight_total = _reduce_sum_with_axes(weights, batch_axes,
                                             chain_axes) + 1e-20
        return part - tf.stop_gradient(weighted_total / weight_total)

    def _event_total(part, shard_axes=None):
        event_axes = ps.range(batch_ndims, ps.rank(part))
        local_total = tf.reduce_sum(part, axis=event_axes)
        return distribute_lib.psum(local_total, shard_axes)

    centered_previous = tf.nest.map_structure(_subtract_chain_mean,
                                              previous_state)
    centered_proposed = tf.nest.map_structure(_subtract_weighted_mean,
                                              proposed_state)
    # Per-part difference of squares, then summed over event dims and parts.
    square_change = tf.nest.map_structure(
        lambda new, old: tf.square(new) - tf.square(old),
        centered_proposed, centered_previous)
    part_totals = _map_structure_up_to_with_axes(
        square_change,
        _event_total,
        square_change,
        experimental_shard_axis_names=experimental_shard_axis_names)
    return 0.25 * tf.square(sum(tf.nest.flatten(part_totals)))
Esempio n. 9
0
def snaper_criterion(previous_state,
                     proposed_state,
                     accept_prob,
                     trajectory_length,
                     direction,
                     state_mean=None,
                     state_mean_weight=0.,
                     validate_args=False,
                     experimental_shard_axis_names=None,
                     experimental_reduce_chain_axis_names=None):
    """Evaluates the SNAPER criterion from [1].

    SNAPER is short for Squared Norm Along Principal component ESJD Rate:

    ```None
    SNAPER = E[(((x' - E[x'])^T p)**2 - ((x - E[x])^T p)**2)**2 /
               trajectory_length],
    ```

    Here `x` is the previous chain state, `x'` is the next chain state and `p`
    is a unit vector (the `direction` argument). The expectations are taken
    with respect to the chain's stationary distribution. The inner expectations
    are replaced by empirical means across chains, so at least 2 chains must be
    present unless `state_mean` and `state_mean_weight` are set. The outer
    expectation is left to the caller (e.g. the
    `GradientBasedTrajectoryLengthAdaptation` kernel).

    The criterion can be seen as the standard expected squared jump distance
    (ESJD), with the jump distance measured in the space of squared projections
    onto a vector.

    The `direction` vector is typically chosen to be an approximation to the
    first principal component of the state covariance matrix.

    `state_mean` and `state_mean_weight` can be used to supplement the
    empirical means as follows:

    ```None
    E[x] ≈ (1 - state_mean_weight) * x.mean() + state_mean_weight * state_mean.
    ```

    Args:
      previous_state: (Possibly nested) floating point `Tensor`. The previous
        state of the HMC chain.
      proposed_state: (Possibly nested) floating point `Tensor`. The proposed
        state of the HMC chain.
      accept_prob: Floating `Tensor`. Probability of accepting the proposed
        state.
      trajectory_length: Floating `Tensor`. Mean trajectory length; the squared
        change in squared projections is divided by it.
      direction: (Possibly nested) floating point `Tensor`. A unit vector onto
        which the centered state should be projected before computing ESJD.
        Typically this is chosen to be an approximation to the first principal
        component of the state covariance matrix.
      state_mean: Optional (Possibly nested) floating point `Tensor`. The
        estimated state mean.
      state_mean_weight: Floating point `Tensor`. The weight of the
        `state_mean`.
      validate_args: Whether to perform non-static argument validation.
      experimental_shard_axis_names: A structure of string names indicating how
        members of the state are sharded.
      experimental_reduce_chain_axis_names: A string or list of string names
        indicating which named chain axes to reduce over when computing the
        criterion.

    Returns:
      snaper: The value of the SNAPER criterion.

    #### References

    [1]: Sountsov, P. & Hoffman, M. (2021). Focusing on Difficult Directions for
         Learning HMC Trajectory Lengths. <https://arxiv.org/abs/2110.11576>

    """
    batch_ndims = ps.rank(accept_prob)
    batch_axes = ps.range(batch_ndims, dtype=tf.int32)
    chain_axes = distribute_lib.canonicalize_named_axis(
        experimental_reduce_chain_axis_names)

    if state_mean is None:
        # No supplied mean: pair every state part with `None`, and require
        # enough chains for the empirical means.
        state_mean = tf.nest.map_structure(lambda _: None, previous_state)

        accept_prob = _check_at_least_two_chains(
            accept_prob,
            reduce_chain_axis_names=chain_axes,
            validate_args=validate_args,
            message=
            'snaper_criterion requires at least 2 chains when `state_mean` is `None`'
        )

    def _blend_with_state_mean(empirical, supplied):
        if supplied is None:
            return empirical
        return ((1. - state_mean_weight) * empirical +
                state_mean_weight * supplied)

    def _center_previous_part(part, supplied_mean):
        # The empirical mean stands in for the true mean, so no gradient is
        # allowed to flow through it.
        empirical = tf.stop_gradient(
            distribute_lib.reduce_mean(part, batch_axes, chain_axes))
        return part - _blend_with_state_mean(empirical, supplied_mean)

    def _center_proposed_part(part, supplied_mean):
        # This is not a monte carlo average of the accepted chain position but
        # an estimate of the underlying dynamics, obtained by only looking at
        # proposed states where the integration error is low.
        weights = bu.left_justified_expand_dims_like(accept_prob, part)

        # accept_prob is zero where the state is NaN, but such values still
        # need to be sanitized before they are multiplied.
        finite_part = tf.where(
            tf.math.is_finite(part), part, tf.zeros_like(part))
        # The empirical mean stands in for the true mean, so it carries no
        # gradient. If every weight is zero the center is meaningless, but the
        # resulting gradients are discarded later on, so that is fine.
        weighted_total = distribute_lib.reduce_sum(weights * finite_part,
                                                   batch_axes, chain_axes)
        weight_total = distribute_lib.reduce_sum(weights, batch_axes,
                                                 chain_axes) + 1e-20
        empirical = tf.stop_gradient(weighted_total / weight_total)
        return part - _blend_with_state_mean(empirical, supplied_mean)

    def _project_part(part, direction_part, shard_axes=None):
        event_axes = ps.range(batch_ndims, ps.rank(part))
        return distribute_lib.reduce_sum(part * direction_part, event_axes,
                                         shard_axes)

    def _project(state):
        # Dot product with `direction`, accumulated over all state parts.
        part_projections = _map_structure_up_to_with_axes(
            state,
            _project_part,
            state,
            direction,
            experimental_shard_axis_names=experimental_shard_axis_names)
        return sum(tf.nest.flatten(part_projections))

    centered_previous = tf.nest.map_structure(_center_previous_part,
                                              previous_state, state_mean)
    centered_proposed = tf.nest.map_structure(_center_proposed_part,
                                              proposed_state, state_mean)

    previous_proj = _project(centered_previous)
    proposed_proj = _project(centered_proposed)

    square_change = tf.square(proposed_proj) - tf.square(previous_proj)
    return tf.square(square_change) / trajectory_length
Esempio n. 10
0
def chees_criterion(previous_state,
                    proposed_state,
                    accept_prob,
                    trajectory_length,
                    validate_args=False,
                    experimental_shard_axis_names=None,
                    experimental_reduce_chain_axis_names=None):
    """Evaluates the ChEES criterion from [1].

    ChEES is short for Change in the Estimator of the Expected Square:

    ```None
    ChEES = 1/4 E[(||x' - E[x]||**2 - ||x - E[x]||**2)**2],
    ```

    Here `x` is the previous chain state, `x'` is the next chain state and
    `||.||` is the L2 norm; both expectations are taken with respect to the
    chain's stationary distribution. The inner expectation is replaced by the
    empirical mean across chains, which is why at least 2 chains must be
    present. The outer expectation is left to the caller (e.g. the
    `GradientBasedTrajectoryLengthAdaptation` kernel).

    The criterion can be seen as the standard expected squared jump distance
    (ESJD), with the jump distance measured in the space of centered squared
    L2 norms.

    Regular ESJD is maximized by perfectly anticorrelated proposals, which can
    give excellent mean estimates but terrible variance estimates; maximizing
    ChEES should give good estimates across a wider range of types of
    posterior expectations.

    Args:
      previous_state: (Possibly nested) floating point `Tensor`. The previous
        state of the HMC chain.
      proposed_state: (Possibly nested) floating point `Tensor`. The proposed
        state of the HMC chain.
      accept_prob: Floating `Tensor`. Probability of accepting the proposed
        state.
      trajectory_length: Floating `Tensor`. Mean trajectory length (not used in
        this criterion).
      validate_args: Whether to perform non-static argument validation.
      experimental_shard_axis_names: A structure of string names indicating how
        members of the state are sharded.
      experimental_reduce_chain_axis_names: A string or list of string names
        indicating which named chain axes to reduce over when computing the
        criterion.

    Returns:
      chees: The value of the ChEES criterion.

    Raises:
      ValueError: If `accept_prob` indicates that there are fewer than 2
        chains.

    #### References

    [1]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme
         for Setting Trajectory Lengths in Hamiltonian Monte Carlo.
         <https://proceedings.mlr.press/v130/hoffman21a>

    """
    # Accepted only so the signature matches other criteria; it is unused.
    del trajectory_length
    batch_ndims = ps.rank(accept_prob)
    batch_axes = ps.range(batch_ndims, dtype=tf.int32)
    chain_axes = distribute_lib.canonicalize_named_axis(
        experimental_reduce_chain_axis_names)

    accept_prob = _check_at_least_two_chains(
        accept_prob,
        reduce_chain_axis_names=chain_axes,
        validate_args=validate_args,
        message='chees_criterion requires at least 2 chains.',
    )

    def _subtract_chain_mean(part):
        # The empirical mean stands in for the true mean, so no gradient is
        # allowed to flow through it.
        chain_mean = _reduce_mean_with_axes(part, batch_axes, chain_axes)
        return part - tf.stop_gradient(chain_mean)

    def _subtract_weighted_mean(part):
        # This is not a monte carlo average of the accepted chain position but
        # an estimate of the underlying dynamics, obtained by only looking at
        # proposed states where the integration error is low.
        # TODO(mhoffman): Needs more experimentation.
        weights = bu.left_justified_expand_dims_like(accept_prob, part)

        # accept_prob is zero where the state is NaN, but such values still
        # need to be sanitized before they are multiplied.
        finite_part = tf.where(
            tf.math.is_finite(part), part, tf.zeros_like(part))
        # If every weight is zero the center is meaningless, but the resulting
        # gradients are discarded later on, so that is fine.
        weighted_total = _reduce_sum_with_axes(weights * finite_part,
                                               batch_axes, chain_axes)
        weight_total = _reduce_sum_with_axes(weights, batch_axes,
                                             chain_axes) + 1e-20

        # The empirical mean stands in for the true mean, so it carries no
        # gradient.
        return part - tf.stop_gradient(weighted_total / weight_total)

    def _event_total(part, shard_axes=None):
        event_axes = ps.range(batch_ndims, ps.rank(part))
        local_total = tf.reduce_sum(part, axis=event_axes)
        return distribute_lib.psum(local_total, shard_axes)

    centered_previous = tf.nest.map_structure(_subtract_chain_mean,
                                              previous_state)
    centered_proposed = tf.nest.map_structure(_subtract_weighted_mean,
                                              proposed_state)
    # Per-part difference of squares, then summed over event dims and parts.
    square_change = tf.nest.map_structure(
        lambda new, old: tf.square(new) - tf.square(old),
        centered_proposed, centered_previous)
    part_totals = _map_structure_up_to_with_axes(
        square_change,
        _event_total,
        square_change,
        experimental_shard_axis_names=experimental_shard_axis_names)
    return 0.25 * tf.square(sum(tf.nest.flatten(part_totals)))