Example #1
def adaptive_hamiltonian_monte_carlo_init(
    state: 'fun_mc.TensorNest',
    target_log_prob_fn: 'fun_mc.PotentialFn',
    step_size: 'fun_mc.FloatTensor' = 1e-2,
    initial_mean: 'fun_mc.FloatNest' = 0.,
    initial_scale: 'fun_mc.FloatNest' = 1.,
    scale_smoothing_steps: 'fun_mc.IntTensor' = 10,
) -> 'AdaptiveHamiltonianMonteCarloState':
  """Initializes `AdaptiveHamiltonianMonteCarloState`.

  Args:
    state: Initial state of the chain.
    target_log_prob_fn: Target log prob fn.
    step_size: Initial scalar step size.
    initial_mean: Initial mean for computing the running variance estimate.
      Must broadcast structurally and tensor-wise with `state`.
    initial_scale: Initial scale for computing the running variance estimate.
      Must broadcast structurally and tensor-wise with `state`.
    scale_smoothing_steps: How much weight to assign to the `initial_mean` and
      `initial_scale`. Increase this to stabilize early adaptation.

  Returns:
    adaptive_hmc_state: State of the `adaptive_hamiltonian_monte_carlo_step`
      `TransitionOperator`.
  """
  hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn)
  dtype = util.flatten_tree(hmc_state.state)[0].dtype
  chain_ndims = len(hmc_state.target_log_prob.shape)
  running_var_state = fun_mc.running_variance_init(
      shape=util.map_tree(lambda s: s.shape[chain_ndims:], hmc_state.state),
      dtype=util.map_tree(lambda s: s.dtype, hmc_state.state),
  )
  initial_mean = fun_mc.maybe_broadcast_structure(initial_mean, state)
  initial_scale = fun_mc.maybe_broadcast_structure(initial_scale, state)

  # It's important to add some smoothing here, as the early updates can be
  # very different from the stationary distribution.
  # TODO(siege): Add a pseudo-update functionality to avoid fiddling with the
  # internals here.
  running_var_state = running_var_state._replace(
      num_points=util.map_tree(
          lambda p: (  # pylint: disable=g-long-lambda
              int(np.prod(hmc_state.target_log_prob.shape)) * tf.cast(
                  scale_smoothing_steps, p.dtype)),
          running_var_state.num_points),
      mean=util.map_tree(  # pylint: disable=g-long-lambda
          lambda m, init_m: tf.ones_like(m) * init_m, running_var_state.mean,
          initial_mean),
      variance=util.map_tree(  # pylint: disable=g-long-lambda
          lambda v, init_s: tf.ones_like(v) * init_s**2,
          running_var_state.variance, initial_scale),
  )
  ssa_state = step_size_adaptation_init(
      tf.convert_to_tensor(step_size, dtype=dtype))

  return AdaptiveHamiltonianMonteCarloState(
      hmc_state=hmc_state,
      running_var_state=running_var_state,
      ssa_state=ssa_state,
      step=tf.zeros([], tf.int32))
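
# To make the smoothing above concrete, here is a minimal pure-Python sketch
# (independent of fun_mc; all names are illustrative) of a running variance
# estimate seeded with `scale_smoothing_steps` pseudo-observations of the
# prior mean/scale, which is what the `_replace` of `num_points` above
# effectively does:


def smoothed_running_variance(samples, initial_mean, initial_scale,
                              scale_smoothing_steps):
  """Streaming mean/variance seeded with prior pseudo-observations."""
  num_points = float(scale_smoothing_steps)
  mean = initial_mean
  variance = initial_scale**2
  for x in samples:
    num_points += 1.
    delta = x - mean
    mean += delta / num_points
    # Welford-style update of the (biased) running variance.
    variance += (delta * (x - mean) - variance) / num_points
  return mean, variance


# With heavy smoothing, a few extreme warmup draws move the estimates only
# slightly: smoothed_running_variance([5., 5., 5.], 0., 1., 100) leaves the
# mean near 0.15 and the variance near 1.7, rather than jumping to (5, 0).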
Example #2
    def grad(*grads):
      grads = util.unflatten_tree(res, util.flatten_tree(grads))

      step_size_scale_bc = fun_mc.maybe_broadcast_structure(
          step_size_scale, hmc_extra.integrator_extra.momentum_grads)

      # We wish to compute `grads^T @
      # jacobian(proposed_state(trajectory_length))`.
      #
      # The Jacobian is known from Hamilton's equations:
      #
      # dx / dt = dK(v) / dv
      #
      # where `x` is the state, `v` is the momentum and `K` is the kinetic
      # energy. Since `step_size_scale` rescales the momentum, the right-hand
      # side of that expression becomes `momentum_grads * step_size_scale` by
      # the chain rule. Since the Jacobian in question has 1 row, the
      # vector-Jacobian product is simply the dot product.
      state_grads = util.map_tree(lambda s, m, g: s * m * g, step_size_scale_bc,
                                  hmc_extra.integrator_extra.momentum_grads,
                                  grads[1].proposed_state)

      def do_sum(x, shard_axis_names):
        res = tf.reduce_sum(
            x, list(range(len(trajectory_length.shape), len(x.shape))))
        if shard_axis_names:
          res = backend.distribute_lib.psum(res, shard_axis_names)
        return res

      if shard_axis_names:
        shard_axis_names_bc = shard_axis_names
      else:
        shard_axis_names_bc = util.map_tree(lambda _: [], state_grads)

      return sum(
          util.flatten_tree(
              util.map_tree_up_to(state_grads, do_sum, state_grads,
                                  shard_axis_names_bc)))
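
# The dot-product identity in the comment above is easy to check on a toy
# function (a plain-NumPy sketch; `f` and `g` are illustrative, not part of
# fun_mc). `f` maps a scalar trajectory length `t` to a state in R^3, so for
# an upstream gradient `g` the vector-Jacobian product collapses to
# `sum(g * df/dt)`:
import numpy as np


def f(t):
  return np.array([np.sin(t), t**2, np.exp(-t)])


def df_dt(t):
  return np.array([np.cos(t), 2. * t, -np.exp(-t)])


t, g = 0.7, np.array([1., -2., 0.5])
vjp = np.sum(g * df_dt(t))

# Central-difference check of the same quantity.
eps = 1e-6
assert np.isclose(vjp, np.sum(g * (f(t + eps) - f(t - eps)) / (2. * eps)))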
Example #3
def persistent_hamiltonian_monte_carlo_step(
    phmc_state: 'PersistentHamiltonianMonteCarloState',
    target_log_prob_fn: 'fun_mc.PotentialFn',
    step_size: 'Optional[Any]' = None,
    num_integrator_steps: 'Optional[fun_mc.IntTensor]' = None,
    noise_fraction: 'Optional[fun_mc.FloatTensor]' = None,
    mh_drift: 'Optional[fun_mc.FloatTensor]' = None,
    kinetic_energy_fn: 'Optional[fun_mc.PotentialFn]' = None,
    momentum_sample_fn: 'Optional[PersistentMomentumSampleFn]' = None,
    integrator_trace_fn: 'Callable[[fun_mc.IntegratorStepState, '
    'fun_mc.IntegratorStepExtras], fun_mc.TensorNest]' = lambda *args: (),
    log_uniform: 'Optional[fun_mc.FloatTensor]' = None,
    integrator_fn: 'Optional[Callable[[fun_mc.IntegratorState, '
    'fun_mc.FloatTensor], Tuple[fun_mc.IntegratorState, '
    'fun_mc.IntegratorExtras]]]' = None,
    unroll_integrator: 'bool' = False,
    max_num_integrator_steps: 'Optional[fun_mc.IntTensor]' = None,
    energy_change_fn: 'Callable[[fun_mc.IntegratorState, '
    'fun_mc.IntegratorState, fun_mc.IntegratorExtras], '
    'Tuple[fun_mc.FloatTensor, Any]]' = (
        fun_mc._default_hamiltonian_monte_carlo_energy_change_fn),  # pylint: disable=protected-access
    named_axis: 'Optional[fun_mc.StringNest]' = None,
    seed=None,
) -> ('Tuple[PersistentHamiltonianMonteCarloState, '
      'PersistentHamiltonianMonteCarloExtra]'):
    """A step of the Persistent Hamiltonian Monte Carlo `TransitionOperator`.

  This is an implementation of the generalized HMC with persistent momentum
  described in [1] (algorithm 15), combined with the persistent
  Metropolis-Hastings test from [2]. It generalizes the regular HMC with
  persistent momentum from [3] and the various underdamped Langevin dynamics
  schemes (e.g. [4]).

  The generalization lies in the free choice of `momentum_sample_fn` and
  `kinetic_energy_fn`. The former forms a Markov Chain with the stationary
  distribution implied by the `kinetic_energy_fn`. By default, the standard
  quadratic kinetic energy is used and the underdamped update is used for
  `momentum_sample_fn`, namely:

  ```none
  new_momentum = (
      (1 - noise_fraction**2)**0.5 * old_momentum +
      noise_fraction * eps)
  eps ~ Normal(0, 1)
  ```

  This update leaves a standard normal momentum distribution invariant, since
  `(1 - noise_fraction**2) + noise_fraction**2 = 1`.

  Here are the parameter settings for a few special cases:

  1. Persistent Hamiltonian Monte Carlo [1] + persistent MH [2]:

  ```none
  num_integrator_steps >= 1
  step_size > 0
  noise_fraction in [0, 1]
  mh_drift = 0.03
  ```

  Empirical results suggest that if `num_integrator_steps == 1`, then a
  reasonable value for `mh_drift` is `1 - (1 - noise_fraction**2)**0.5`.

  2. Unadjusted Underdamped Langevin Dynamics (see [4]):

  ```none
  num_integrator_steps = 1
  step_size > 0
  noise_fraction = (1 - exp(-2 * step_size * dampening))**0.5
  # This disables the MH step for all but the most extreme divergences.
  log_uniform = -1000
  ```

  `dampening` refers to the parameter in the SDE formulation of the algorithm:

  ```none
  dv_t = -dampening * v_t * dt - grad(f)(x_t) * dt + (2 * dampening)**0.5 * dB_t
  dx_t = v_t * dt
  ```
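
  As a quick numerical check of the `noise_fraction` formula (a sketch; the
  parameter values below are illustrative), the implied momentum-persistence
  coefficient `(1 - noise_fraction**2)**0.5` is exactly the
  Ornstein-Uhlenbeck decay factor of this SDE over one step:

  ```python
  import numpy as np

  step_size = 0.1
  dampening = 0.5
  noise_fraction = (1 - np.exp(-2 * step_size * dampening))**0.5
  # The friction + noise part of the SDE has the exact solution
  # v_t = exp(-dampening * t) * v_0 + noise, so the momentum refresh is its
  # exact one-step discretization.
  assert np.isclose((1 - noise_fraction**2)**0.5,
                    np.exp(-dampening * step_size))
  ```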

  Args:
    phmc_state: `PersistentHamiltonianMonteCarloState`.
    target_log_prob_fn: Target log prob fn.
    step_size: Step size, structure broadcastable to the `target_log_prob_fn`
      state. Optional if `integrator_fn` is specified.
    num_integrator_steps: Number of integrator steps to take. Optional if
      `integrator_fn` is specified.
    noise_fraction: Noise fraction when refreshing momentum. Optional if
      `momentum_sample_fn` is specified.
    mh_drift: Metropolis Hastings drift term. Optional if `log_uniform` is
      specified.
    kinetic_energy_fn: Kinetic energy function.
    momentum_sample_fn: Sampler for the momentum.
    integrator_trace_fn: Trace function for the integrator.
    log_uniform: Optional logarithm of a uniformly distributed random sample in
      [0, 1], used for the MH accept/reject step.
    integrator_fn: Integrator to use for the HMC dynamics. Uses a
      `hamiltonian_integrator` with `leapfrog_step` by default.
    unroll_integrator: Whether to unroll the loop in the integrator. Only works
      if `num_integrator_steps`/`max_num_integrator_steps` is statically known.
      Ignored if `integrator_fn` is specified.
    max_num_integrator_steps: Maximum number of integrator steps to take. Useful
      when `num_integrator_steps` is dynamic, and yet you still want
      gradients/tracing to work. Ignored if `integrator_fn` is specified.
    energy_change_fn: Callable with signature: `(current_integrator_state,
      proposed_integrator_state, integrator_extra) -> (energy_change,
      energy_change_extra)`. Computes the change in energy between the current
      and proposed states. By default, it simply subtracts the current energy
      from the proposed energy. A typical reason to override this is to improve
      numerical stability.
    named_axis: Named axes of the state, same structure as `hmc_state.state`.
    seed: For reproducibility.

  Returns:
    phmc_state: `PersistentHamiltonianMonteCarloState`.
    phmc_extra: `PersistentHamiltonianMonteCarloExtra`.

  #### References

  [1]: Neklyudov, K., Welling, M., Egorov, E., & Vetrov, D. (2020). Involutive
       MCMC: a Unifying Framework.

  [2]: Neal, R. M. (2020). Non-reversibly updating a uniform [0,1] value for
       Metropolis accept/reject decisions.

  [3]: Horowitz, A. M. (1991). A generalized guided Monte Carlo algorithm.
       Physics Letters B, 268(2), 247-252.

  [4]: Ma, Y.-A., Chatterji, N., Cheng, X., Flammarion, N., Bartlett, P., &
       Jordan, M. I. (2019). Is There an Analog of Nesterov Acceleration for
       MCMC?
  """
    state = phmc_state.state
    momentum = phmc_state.momentum
    direction = phmc_state.direction
    state_grads = phmc_state.state_grads
    target_log_prob = phmc_state.target_log_prob
    state_extra = phmc_state.state_extra
    pmh_state = phmc_state.pmh_state

    # Impute the optional args.
    if kinetic_energy_fn is None:
        kinetic_energy_fn = fun_mc.make_gaussian_kinetic_energy_fn(
            len(target_log_prob.shape)
            if target_log_prob.shape is not None else tf.rank(target_log_prob),
            named_axis=named_axis)

    if momentum_sample_fn is None:
        if named_axis is None:
            named_axis = util.map_tree(lambda _: [], state)

        def _momentum_sample_fn(old_momentum: fun_mc.State,
                                seed: Any) -> fun_mc.State:
            seeds = util.unflatten_tree(
                old_momentum,
                util.split_seed(seed, len(util.flatten_tree(old_momentum))))

            def _sample_part(old_momentum, seed, named_axis):
                seed = backend.distribute_lib.fold_in_axis_index(
                    seed, named_axis)
                return (tf.math.sqrt(1 - tf.square(noise_fraction)) *
                        old_momentum + noise_fraction * util.random_normal(
                            old_momentum.shape, old_momentum.dtype, seed))

            new_momentum = util.map_tree_up_to(state, _sample_part,
                                               old_momentum, seeds, named_axis)
            return new_momentum

        momentum_sample_fn = _momentum_sample_fn

    if integrator_fn is None:
        step_size = util.map_tree(tf.convert_to_tensor, step_size)
        step_size = fun_mc.maybe_broadcast_structure(step_size, state)

        def _integrator_fn(
            state: fun_mc.IntegratorState, direction: fun_mc.FloatTensor
        ) -> Tuple[fun_mc.IntegratorState, fun_mc.IntegratorExtras]:

            directional_step_size = util.map_tree(
                lambda step_size, state: (  # pylint: disable=g-long-lambda
                    step_size * tf.reshape(
                        direction,
                        list(direction.shape) + [1] *
                        (len(state.shape) - len(direction.shape)))),
                step_size,
                state.state)
            # TODO(siege): Ideally we'd pass in the direction here, but the
            # `hamiltonian_integrator` cannot handle dynamic direction switching like
            # that.
            return fun_mc.hamiltonian_integrator(
                state,
                num_steps=num_integrator_steps,
                integrator_step_fn=functools.partial(
                    fun_mc.leapfrog_step,
                    step_size=directional_step_size,
                    target_log_prob_fn=target_log_prob_fn,
                    kinetic_energy_fn=kinetic_energy_fn),
                kinetic_energy_fn=kinetic_energy_fn,
                unroll=unroll_integrator,
                max_num_steps=max_num_integrator_steps,
                integrator_trace_fn=integrator_trace_fn)

        integrator_fn = _integrator_fn

    seed, sample_seed = util.split_seed(seed, 2)
    momentum = momentum_sample_fn(momentum, sample_seed)

    initial_integrator_state = fun_mc.IntegratorState(
        target_log_prob=target_log_prob,
        momentum=momentum,
        state=state,
        state_grads=state_grads,
        state_extra=state_extra,
    )

    integrator_state, integrator_extra = integrator_fn(
        initial_integrator_state, direction)

    proposed_state = phmc_state._replace(
        state=integrator_state.state,
        state_grads=integrator_state.state_grads,
        target_log_prob=integrator_state.target_log_prob,
        momentum=integrator_state.momentum,
        state_extra=integrator_state.state_extra,
        # Flip the direction in the proposal, for reversibility.
        direction=util.map_tree(lambda d: -d, direction),
    )

    # Stick the new momentum into phmc_state. We're doing accept/reject purely on
    # the Hamiltonian proposal, not the momentum refreshment kernel.
    phmc_state = phmc_state._replace(momentum=momentum)

    energy_change, energy_change_extra = energy_change_fn(
        initial_integrator_state,
        integrator_state,
        integrator_extra,
    )

    if log_uniform is None:
        pmh_state, pmh_extra = fun_mc.persistent_metropolis_hastings_step(
            pmh_state,
            current_state=phmc_state,
            proposed_state=proposed_state,
            energy_change=energy_change,
            drift=mh_drift)
        is_accepted = pmh_extra.is_accepted
        phmc_state = pmh_extra.accepted_state
    else:
        # We explicitly don't update the PMH state.
        phmc_state, mh_extra = fun_mc.metropolis_hastings_step(
            current_state=phmc_state,
            proposed_state=proposed_state,
            energy_change=energy_change,
            log_uniform=log_uniform)
        is_accepted = mh_extra.is_accepted

    phmc_state = typing.cast(PersistentHamiltonianMonteCarloState, phmc_state)
    phmc_state = phmc_state._replace(
        pmh_state=pmh_state,
        # Flip the direction unconditionally; when the state is accepted, this
        # undoes the flip made in the proposal, maintaining the old momentum
        # direction.
        direction=util.map_tree(lambda d: -d, phmc_state.direction),
    )

    return phmc_state, PersistentHamiltonianMonteCarloExtra(
        is_accepted=is_accepted,
        proposed_phmc_state=proposed_state,
        log_accept_ratio=-energy_change,
        integrator_state=integrator_state,
        integrator_extra=integrator_extra,
        energy_change_extra=energy_change_extra,
        initial_momentum=momentum)
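
# A self-contained sketch of special case 2 from the docstring above
# (unadjusted underdamped Langevin dynamics): a persistent momentum refresh
# followed by a single leapfrog step, with no MH correction. Plain NumPy with
# a standard normal target; all names and parameter values are illustrative,
# not the fun_mc API.
import numpy as np

rng = np.random.default_rng(0)


def grad_log_prob(x):
  return -x  # Gradient of the standard normal log density.


step_size = 0.1
dampening = 1.
noise_fraction = (1 - np.exp(-2 * step_size * dampening))**0.5

num_chains = 10000
x = np.zeros(num_chains)
v = rng.normal(size=num_chains)
for _ in range(1000):
  # Partial momentum refresh, as in the default `momentum_sample_fn`.
  v = ((1 - noise_fraction**2)**0.5 * v +
       noise_fraction * rng.normal(size=num_chains))
  # One leapfrog step of the Hamiltonian dynamics.
  v += 0.5 * step_size * grad_log_prob(x)
  x += step_size * v
  v += 0.5 * step_size * grad_log_prob(x)

# Up to O(step_size**2) discretization bias, x is approximately N(0, 1).
print(x.mean(), x.var())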
Example #4
def chees_criterion(
    previous_state: 'fun_mc.State',
    proposed_state: 'fun_mc.State',
    accept_prob: 'fun_mc.FloatTensor',
    trajectory_length: 'Optional[fun_mc.FloatTensor]' = None,
    state_mean: 'Optional[fun_mc.State]' = None,
    state_mean_weight: 'fun_mc.FloatNest' = 0.,
    named_axis: 'Optional[fun_mc.StringNest]' = None,
    chain_named_axis: 'Optional[fun_mc.StringNest]' = None,
) -> 'Tuple[fun_mc.FloatTensor, fun_mc.FloatTensor]':
  """The ChEES criterion from [1].

  ChEES stands for Change in the Estimator of the Expected Square.

  ```none
  ChEES = 1/4 E[(||x' - E[x]||**2 - ||x - E[x]||**2)**2],
  ```

  where `x` is the previous chain state, `x'` is the next chain state, and
  `||.||` is the L2 norm. Both expectations are with respect to the chain's
  stationary distribution. In practice, the inner expectation is replaced by
  the empirical mean across chains, optionally averaged with a provided
  `state_mean` (weighted by `state_mean_weight`).

  This can be thought of as the standard expected squared jump distance (ESJD)
  criterion, except that the jump distance is computed in the space of centered
  squared L2 norms. It is also possible to relate ChEES to ESS computed in the
  same space if the true autocorrelation function of the centered squared L2
  norm follows a certain functional form.

  ChEES in this implementation is scaled by a normalized
  acceptance probability, so as to discard contributions from bad proposals.

  Unlike ChEES, regular ESJD is maximized by perfectly anticorrelated proposals,
  which can give excellent mean estimates but terrible variance estimates;
  maximizing ChEES should give good estimates across a wider range of types of
  posterior expectations.

  Args:
    previous_state: (Possibly nested) floating point `Tensor`. The previous
      state of the MCMC chain.
    proposed_state: (Possibly nested) floating point `Tensor`. The proposed
      state of the MCMC chain.
    accept_prob: Floating `Tensor`. Probability of accepting the proposed
      state.
    trajectory_length: Ignored.
    state_mean: (Possibly nested) floating point `Tensor`. Optional estimate of
      the MCMC chain mean.
    state_mean_weight: Floating point `Tensor`. Used to weight `state_mean`
      against the mean computed by averaging across the previous/proposed
      state. Setting it to 1 effectively uses `state_mean` as the only source
      of the MCMC chain mean.
    named_axis: Named axes of the state. Same structure as `previous_state`.
    chain_named_axis: Named axes of the MCMC chain that the criterion is to be
      averaged over.

  Returns:
    chees: The value of the ChEES criterion.
    per_chain_chees: The value of the ChEES criterion per chain.

  #### References

  [1]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme
       for Setting Trajectory Lengths in Hamiltonian Monte Carlo. In
       preparation.

  """
  del trajectory_length
  batch_ndims = len(accept_prob.shape)
  batch_axes = tuple(range(batch_ndims))
  no_state_mean = object()
  if state_mean is None:
    state_mean = fun_mc.maybe_broadcast_structure(no_state_mean, previous_state)
  state_mean_weight = fun_mc.maybe_broadcast_structure(state_mean_weight,
                                                       previous_state)
  if named_axis is None:
    named_axis_bc = util.map_tree(lambda _: [], previous_state)
  else:
    named_axis_bc = named_axis

  if chain_named_axis is None:
    chain_named_axis = []

  def _center_previous_state(x, mx, mw):
    x_center = distribute_lib.reduce_mean(
        x, axis=batch_axes, named_axis=chain_named_axis)
    if mx is not no_state_mean:
      x_center = x_center * (1 - mw) + mx * mw
    # The empirical mean here is a stand-in for the true mean, so we drop the
    # gradient that flows through this term.
    return x - tf.stop_gradient(x_center)

  def _center_proposed_state(x, mx, mw):
    expand_shape = list(accept_prob.shape) + [1] * (
        len(x.shape) - len(accept_prob.shape))
    expanded_accept_prob = tf.reshape(accept_prob, expand_shape)

    # Weight the proposed state by the acceptance probability. The goal here is
    # to get a reliable diagnostic of the underlying dynamics, rather than
    # incorporating the effect of the Metropolis-Hastings correction.

    # accept_prob is zero when x is NaN, but we still want to sanitize such
    # values.
    x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x))
    # If all accept_prob's are zero, the x_center will have a nonsense value,
    # but we'll set the overall criterion to zero in this case, so it's fine.
    x_center = (
        distribute_lib.reduce_sum(
            expanded_accept_prob * x_safe,
            axis=batch_axes,
            named_axis=chain_named_axis) /
        (distribute_lib.reduce_sum(
            expanded_accept_prob, axis=batch_axes, named_axis=chain_named_axis)
         + 1e-20))
    if mx is not no_state_mean:
      x_center = x_center * (1 - mw) + mx * mw
    # The empirical mean here is a stand-in for the true mean, so we drop the
    # gradient that flows through this term.
    return x - tf.stop_gradient(x_center)

  def _sum_event_part(x, named_axis):
    event_axes = tuple(range(batch_ndims, len(x.shape)))
    return distribute_lib.reduce_sum(x, axis=event_axes, named_axis=named_axis)

  def _sum_event(x):
    return sum(
        util.flatten_tree(
            util.map_tree_up_to(
                x,
                _sum_event_part,
                x,
                named_axis_bc,
            )))

  def _square(x):
    return util.map_tree(tf.square, x)

  def _sub(x, y):
    return util.map_tree(lambda x, y: x - y, x, y)

  previous_state = util.map_tree(_center_previous_state, previous_state,
                                 state_mean, state_mean_weight)
  proposed_state = util.map_tree(_center_proposed_state, proposed_state,
                                 state_mean, state_mean_weight)
  chees = 0.25 * tf.square(
      _sum_event(_sub(_square(proposed_state), _square(previous_state))))

  # Zero-out per-chain ChEES values where acceptance probability is low. Those
  # values are probably not reflective of the underlying dynamics.
  chees = tf.where(accept_prob > 1e-4, chees, 0.)
  accept_prob = accept_prob / distribute_lib.reduce_sum(
      accept_prob + 1e-20, named_axis=chain_named_axis)
  chees = chees * accept_prob

  return distribute_lib.reduce_mean(chees, named_axis=chain_named_axis), chees
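
# A minimal NumPy transcription of the ChEES formula from the docstring,
# ignoring the acceptance-probability weighting and `state_mean` for clarity
# (illustrative only, not the fun_mc API). States have shape
# [num_chains, num_dims].
import numpy as np


def toy_chees(previous_state, proposed_state):
  mean = previous_state.mean(axis=0)  # Empirical stand-in for E[x].
  prev_sq = np.sum((previous_state - mean)**2, axis=-1)
  prop_sq = np.sum((proposed_state - mean)**2, axis=-1)
  per_chain = 0.25 * (prop_sq - prev_sq)**2
  return per_chain.mean(), per_chain


rng = np.random.default_rng(0)
prev = rng.normal(size=(64, 3))
# A tiny jump barely changes the centered squared norms, so ChEES is small;
# an independent draw changes them a lot, so ChEES is larger.
print(toy_chees(prev, prev + 1e-3 * rng.normal(size=(64, 3)))[0])
print(toy_chees(prev, rng.normal(size=(64, 3)))[0])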