Example #1
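An HMC transition kernel with gradient-based step-size adaptation, evidently from a backend-agnostic test harness (hence `self._is_on_jax` and `self._constant`). The step size is kept in log space so it stays positive; a surrogate loss built from `0.9 - mean_p_accept` steers the average acceptance probability toward 0.9 via `fun_mcmc.adam_step`, with a polynomially decaying learning rate. On the JAX backend the seed must be split explicitly each step; on TensorFlow a fixed test seed suffices.
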
def kernel(hmc_state, step_size_state, step, seed):
  if not self._is_on_jax:
    hmc_seed = _test_seed()
  else:
    hmc_seed, seed = util.split_seed(seed, 2)
  hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
      hmc_state,
      step_size=tf.exp(step_size_state.state),
      num_integrator_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn,
      seed=hmc_seed)

  rate = fun_mcmc.prefab._polynomial_decay(  # pylint: disable=protected-access
      step=step,
      step_size=self._constant(0.01),
      power=0.5,
      decay_steps=num_adapt_steps,
      final_step_size=0.)
  mean_p_accept = tf.reduce_mean(
      tf.exp(tf.minimum(self._constant(0.), hmc_extra.log_accept_ratio)))

  loss_fn = fun_mcmc.make_surrogate_loss_fn(
      lambda _: (0.9 - mean_p_accept, ()))
  step_size_state, _ = fun_mcmc.adam_step(
      step_size_state, loss_fn, learning_rate=rate)

  return ((hmc_state, step_size_state, step + 1, seed),
          (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))
Example #2
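The minimal transition-kernel pattern: a single HMC step with a fixed step size. The seed is threaded through the loop state so the kernel remains purely functional, which the JAX backend requires.
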
def kernel(hmc_state, seed):
  if not self._is_on_jax:
    hmc_seed = _test_seed()
  else:
    hmc_seed, seed = util.split_seed(seed, 2)
  hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
      hmc_state,
      step_size=step_size,
      num_integrator_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn,
      seed=hmc_seed)
  return (hmc_state, seed), hmc_extra
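Example #3
The same minimal kernel, but selecting the seeding strategy via `backend.get_backend()` rather than a test-class attribute.
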
def kernel(hmc_state, seed):
    if backend.get_backend() == backend.TENSORFLOW:
        hmc_seed = tfp_test_util.test_seed()
    else:
        hmc_seed, seed = util.split_seed(seed, 2)
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    return (hmc_state, seed), hmc_extra
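Example #4
Driving a fixed-setting HMC kernel with `fun_mcmc.trace`: the initial `HamiltonianMonteCarloState` is constructed directly, four steps are run, and the trace output is discarded.
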
def trace():
    kernel = lambda state: fun_mcmc.hamiltonian_monte_carlo(
        state,
        step_size=0.1,
        num_integrator_steps=3,
        target_log_prob_fn=target_log_prob_fn,
        seed=tfp_test_util.test_seed())

    fun_mcmc.trace(
        state=fun_mcmc.HamiltonianMonteCarloState(tf.zeros([1])),
        fn=kernel,
        num_steps=4,
        trace_fn=lambda *args: ())
Example #5
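Running an online statistic alongside the chain: after each HMC step, `fun_mcmc.running_approximate_auto_covariance_step` updates an approximate autocovariance estimate of the chain state, aggregated over the axes given by `aggregation`.
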
def kernel(hmc_state, raac_state, seed):
  if not self._is_on_jax:
    hmc_seed = _test_seed()
  else:
    hmc_seed, seed = util.split_seed(seed, 2)
  hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
      hmc_state,
      step_size=step_size,
      num_integrator_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn,
      seed=hmc_seed)
  raac_state, _ = fun_mcmc.running_approximate_auto_covariance_step(
      raac_state, hmc_state.state, axis=aggregation)
  return (hmc_state, raac_state, seed), hmc_extra
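Example #6
Like Example #4, but the chain is initialized with `fun_mcmc.hamiltonian_monte_carlo_init`, which fills in the initial target log-probability and its gradients, and the test seed is wrapped for the active backend via `self._make_seed`.
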
def trace():
    # pylint: disable=g-long-lambda
    kernel = lambda state: fun_mcmc.hamiltonian_monte_carlo(
        state,
        step_size=0.1,
        num_integrator_steps=3,
        target_log_prob_fn=target_log_prob_fn,
        seed=self._make_seed(tfp_test_util.test_seed()))

    fun_mcmc.trace(
        state=fun_mcmc.hamiltonian_monte_carlo_init(
            state=tf.zeros([1]), target_log_prob_fn=target_log_prob_fn),
        fn=kernel,
        num_steps=4,
        trace_fn=lambda *args: ())
Example #7
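The same trace pattern with an explicit dtype on the initial state, using the test-harness helpers `self._constant` and `_test_seed`.
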
def trace():
  kernel = lambda state: fun_mcmc.hamiltonian_monte_carlo(
      state,
      step_size=self._constant(0.1),
      num_integrator_steps=3,
      target_log_prob_fn=target_log_prob_fn,
      seed=_test_seed())

  fun_mcmc.trace(
      state=fun_mcmc.hamiltonian_monte_carlo_init(
          tf.zeros([1], dtype=self._dtype), target_log_prob_fn),
      fn=kernel,
      num_steps=4,
      trace_fn=lambda *args: ())
Example #8
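A complete test case: the target is a correlated bivariate Gaussian pushed through a `Softplus` bijector. `fun_mcmc.transform_log_prob_fn` folds the bijector into the log-density so HMC runs in unconstrained space, while the traced `state_extra[0]` recovers the constrained samples. Sixteen parallel chains run for 2000 steps, the first 1000 are discarded as warmup, and the sample mean and covariance are checked against direct draws from the target.
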
def testPreconditionedHMC(self):
  step_size = 0.2
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2])

  base_mean = [1., 0]
  base_cov = [[1, 0.5], [0.5, 1]]

  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  # pylint: disable=g-long-lambda
  kernel = tf.function(lambda state: fun_mcmc.hamiltonian_monte_carlo(
      state,
      step_size=step_size,
      num_integrator_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn,
      seed=_test_seed()))

  _, chain = fun_mcmc.trace(
      state=fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
      fn=kernel,
      num_steps=num_steps,
      trace_fn=lambda state, extra: state.state_extra[0])
  # Discard the warmup samples.
  chain = chain[1000:]

  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

  true_samples = target_dist.sample(4096, seed=_test_seed())

  true_mean = tf.reduce_mean(true_samples, axis=0)
  true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
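Example #9
A kernel from a script-style example: a custom momentum distribution is supplied via `momentum_sample_fn`, and when `FLAGS.mcmc_adapt_step_size` is set, `fun_mcmc.sign_adaptation` nudges the step size up or down depending on whether the mean acceptance probability sits above or below the 0.9 set point.
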
def kernel(hmc_state, step_size, step):
    """HMC kernel."""
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=FLAGS.mcmc_leapfrog_steps,
        momentum_sample_fn=create_momentum_sample_fn(hmc_state.state),
        target_log_prob_fn=log_prob_non_transformed)

    mean_p_accept = tf.reduce_mean(
        tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))

    if FLAGS.mcmc_adapt_step_size:
        step_size = fun_mcmc.sign_adaptation(
            step_size, output=mean_p_accept, set_point=0.9)

    return (hmc_state, step_size, step + 1), hmc_extra
Example #10
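Sign-based step-size adaptation with an adaptation rate that decays polynomially over `num_adapt_steps`, here computed with `tf.compat.v1.train.polynomial_decay`.
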
def kernel(hmc_state, step_size, step):
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn)

    rate = tf.compat.v1.train.polynomial_decay(
        0.01,
        global_step=step,
        power=0.5,
        decay_steps=num_adapt_steps,
        end_learning_rate=0.)
    mean_p_accept = tf.reduce_mean(
        tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))
    step_size = fun_mcmc.sign_adaptation(
        step_size, output=mean_p_accept, set_point=0.9, adaptation_rate=rate)

    return (hmc_state, step_size, step + 1), hmc_extra
Example #11
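The same adaptation goal phrased as an optimization problem: `fun_mcmc.make_surrogate_loss_fn` turns the scalar criterion `0.9 - mean_p_accept` into a loss whose gradient `fun_mcmc.adam_step` can follow, and keeping the step size in log space (`tf.exp(step_size_state.state)`) guarantees positivity.
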
def kernel(hmc_state, step_size_state, step):
  hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
      hmc_state,
      step_size=tf.exp(step_size_state.state),
      num_integrator_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn)

  rate = tf.compat.v1.train.polynomial_decay(
      0.01,
      global_step=step,
      power=0.5,
      decay_steps=num_adapt_steps,
      end_learning_rate=0.)
  mean_p_accept = tf.reduce_mean(
      tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))

  loss_fn = fun_mcmc.make_surrogate_loss_fn(
      lambda _: (0.9 - mean_p_accept, ()))
  step_size_state, _ = fun_mcmc.adam_step(
      step_size_state, loss_fn, learning_rate=rate)

  return ((hmc_state, step_size_state, step + 1),
          (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))
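All of the snippets above assume surrounding test or script scaffolding. For reference, here is a minimal, self-contained sketch of the shared pattern, written against the API exactly as it appears in these examples. The import path is an assumption (FunMC has moved between TFP releases, from `discussion.fun_mcmc` to the standalone `fun_mc` package, and newer releases rename `hamiltonian_monte_carlo` to `hamiltonian_monte_carlo_step`); the target density and all settings here are illustrative.

import tensorflow as tf
# Assumed import path for the standalone FunMC package; in the examples
# above, `fun_mcmc` is already in scope.
from fun_mc import using_tensorflow as fun_mcmc


def target_log_prob_fn(x):
  # FunMC target log-prob functions return (log_prob, extra); here the
  # target is a standard normal and there is no extra state.
  return -0.5 * tf.reduce_sum(tf.square(x), axis=-1), ()


# One HMC step per call. A plain integer seed mirrors the `_test_seed()`
# calls above; under the JAX backend you would instead thread and split an
# explicit seed through the loop state, as in Examples #1-#3.
kernel = lambda state: fun_mcmc.hamiltonian_monte_carlo(
    state,
    step_size=0.1,
    num_integrator_steps=3,
    target_log_prob_fn=target_log_prob_fn,
    seed=42)

# Initialize 4 chains of dimension 2, run 100 steps, and trace the chain
# positions at every step; `chain` has shape [100, 4, 2].
_, chain = fun_mcmc.trace(
    state=fun_mcmc.hamiltonian_monte_carlo_init(
        tf.zeros([4, 2]), target_log_prob_fn),
    fn=kernel,
    num_steps=100,
    trace_fn=lambda state, extra: state.state)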