Example #1
0
def adaptive_hamiltonian_monte_carlo_init(
    state: 'fun_mc.TensorNest',
    target_log_prob_fn: 'fun_mc.PotentialFn',
    step_size: 'fun_mc.FloatTensor' = 1e-2,
    initial_mean: 'fun_mc.FloatNest' = 0.,
    initial_scale: 'fun_mc.FloatNest' = 1.,
    scale_smoothing_steps: 'fun_mc.IntTensor' = 10,
) -> 'AdaptiveHamiltonianMonteCarloState':
    """Initializes `AdaptiveHamiltonianMonteCarloState`.

    Builds three pieces of state:
    - the plain HMC state;
    - a running-variance accumulator pre-seeded with pseudo-observations at
      `initial_mean` / `initial_scale`;
    - an Adam optimizer state over the log of the step size.

    Args:
      state: Initial state of the chain.
      target_log_prob_fn: Target log prob fn.
      step_size: Initial scalar step size.
      initial_mean: Initial mean for computing the running variance estimate.
        Must broadcast structurally and tensor-wise with state.
      initial_scale: Initial scale for computing the running variance
        estimate. Must broadcast structurally and tensor-wise with state.
      scale_smoothing_steps: How much weight to assign to the `initial_mean` and
        `initial_scale`. Increase this to stabilize early adaptation.

    Returns:
      adaptive_hmc_state: State of the `adaptive_hamiltonian_monte_carlo_step`
        `TransitionOperator`.
    """
    hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn)
    chain_state = hmc_state.state
    # The step size is kept in the dtype of the first leaf of the state.
    step_dtype = util.flatten_tree(chain_state)[0].dtype
    # Leading dimensions shared with `target_log_prob` are batch dimensions;
    # the running variance is tracked only over the remaining trailing dims.
    batch_shape = hmc_state.target_log_prob.shape
    num_batch_dims = len(batch_shape)

    def _event_shape(x):
        return x.shape[num_batch_dims:]

    def _dtype_of(x):
        return x.dtype

    var_state = fun_mc.running_variance_init(
        shape=util.map_tree(_event_shape, chain_state),
        dtype=util.map_tree(_dtype_of, chain_state),
    )
    mean_tree = fun_mc.maybe_broadcast_structure(initial_mean, state)
    scale_tree = fun_mc.maybe_broadcast_structure(initial_scale, state)

    # It's important to add some smoothing here, as initial updates can be very
    # different than the stationary distribution. The accumulator is made to
    # look as if it had already seen `scale_smoothing_steps` observations for
    # every element of the batch shape, all at the initial mean/scale.
    # TODO(siege): Add a pseudo-update functionality to avoid fiddling with the
    # internals here.
    def _pseudo_count(count):
        return int(np.prod(batch_shape)) * tf.cast(scale_smoothing_steps,
                                                   count.dtype)

    def _fill_mean(mean, init_mean):
        return tf.ones_like(mean) * init_mean

    def _fill_variance(variance, init_scale):
        return tf.ones_like(variance) * init_scale**2

    var_state = var_state._replace(
        num_points=util.map_tree(_pseudo_count, var_state.num_points),
        mean=util.map_tree(_fill_mean, var_state.mean, mean_tree),
        variance=util.map_tree(_fill_variance, var_state.variance, scale_tree),
    )

    # The optimizer works on log(step_size) so that the step size stays
    # positive after any update.
    initial_log_step_size = tf.math.log(
        tf.convert_to_tensor(step_size, dtype=step_dtype))

    return AdaptiveHamiltonianMonteCarloState(
        hmc_state=hmc_state,
        running_var_state=var_state,
        log_step_size_opt_state=fun_mc.adam_init(initial_log_step_size),
        step=tf.zeros([], tf.int32))
Example #2
0
    def computation(state, seed):
      """Runs step-size-adapted HMC on a Softplus-transformed Gaussian.

      Returns the traced chain (in the constrained space), the traced
      log-accept ratios, and 4096 exact samples from the target for comparison.
      """
      softplus = tfp.bijectors.Softplus()
      target_dist = softplus(
          tfp.distributions.MultivariateNormalFullCovariance(
              loc=base_mean, covariance_matrix=base_cov))

      def constrained_log_prob_fn(x):
        return target_dist.log_prob(x), ()

      # HMC runs in the unconstrained space. `state_extra[0]` of the HMC state
      # is what the kernel traces as the chain value below.
      unconstrained_log_prob_fn, init_state = fun_mcmc.transform_log_prob_fn(
          constrained_log_prob_fn, softplus, state)

      def kernel(hmc_state, log_step_size_state, step, seed):
        hmc_seed, next_seed = util.split_seed(seed, 2)
        hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
            hmc_state,
            step_size=tf.exp(log_step_size_state.state),
            num_integrator_steps=num_leapfrog_steps,
            target_log_prob_fn=unconstrained_log_prob_fn,
            seed=hmc_seed)

        # Presumably decays from 0.01 to 0 over `num_adapt_steps`, freezing
        # the step size afterwards -- see prefab._polynomial_decay.
        learning_rate = prefab._polynomial_decay(  # pylint: disable=protected-access
            step=step,
            step_size=self._constant(0.01),
            power=0.5,
            decay_steps=num_adapt_steps,
            final_step_size=0.)
        # min(1, exp(log_accept_ratio)), averaged over chains.
        clipped_log_ratio = tf.minimum(self._constant(0.),
                                       hmc_extra.log_accept_ratio)
        mean_p_accept = tf.reduce_mean(tf.exp(clipped_log_ratio))

        def acceptance_gap(_):
          # Surrogate loss steering the mean acceptance toward 0.9.
          return 0.9 - mean_p_accept, ()

        log_step_size_state, _ = fun_mcmc.adam_step(
            log_step_size_state,
            fun_mcmc.make_surrogate_loss_fn(acceptance_gap),
            learning_rate=learning_rate)

        next_carry = (hmc_state, log_step_size_state, step + 1, next_seed)
        traced = (hmc_state.state_extra[0], hmc_extra.log_accept_ratio)
        return next_carry, traced

      initial_carry = (
          fun_mcmc.hamiltonian_monte_carlo_init(init_state,
                                                unconstrained_log_prob_fn),
          fun_mcmc.adam_init(tf.math.log(step_size)),
          0,
          seed,
      )
      _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
          state=initial_carry,
          fn=kernel,
          num_steps=num_adapt_steps + num_steps,
      )
      true_samples = target_dist.sample(
          4096, seed=self._make_seed(_test_seed()))
      return chain, log_accept_ratio_trace, true_samples
Example #3
0
  def testAdam(self):
    # Adam should drive a separable quadratic to its unique minimum (loss 0)
    # at (x, y) = (1, 2) within 1000 steps.

    def quadratic_loss(x, y):
      return tf.square(x - 1.) + tf.square(y - 2.), []

    def adam_update(adam_state):
      return fun_mcmc.adam_step(
          adam_state, quadratic_loss, learning_rate=self._constant(0.01))

    def record(state, extra):
      return [state.state, extra.loss]

    start = fun_mcmc.adam_init([self._constant(0.), self._constant(0.)])
    _, [(x, y), loss] = fun_mcmc.trace(
        start, adam_update, num_steps=1000, trace_fn=record)

    self.assertAllClose(1., x[-1], atol=1e-3)
    self.assertAllClose(2., y[-1], atol=1e-3)
    self.assertAllClose(0., loss[-1], atol=1e-3)