Esempio n. 1
0
        def trace():
            """Drives an HMC kernel through `fun_mcmc.trace` for four steps.

            The trace function returns an empty tuple, so nothing is
            recorded, and the result of `fun_mcmc.trace` is discarded.
            """

            def kernel(state):
                # Single HMC transition with a fixed step size and three
                # integrator steps against the enclosing target density.
                return fun_mcmc.hamiltonian_monte_carlo(
                    state,
                    step_size=0.1,
                    num_integrator_steps=3,
                    target_log_prob_fn=target_log_prob_fn,
                    seed=tfp_test_util.test_seed())

            initial_state = fun_mcmc.HamiltonianMonteCarloState(tf.zeros([1]))
            fun_mcmc.trace(
                state=initial_state,
                fn=kernel,
                num_steps=4,
                trace_fn=lambda *args: ())
Esempio n. 2
0
        def trace():
            """Drives an HMC kernel through `fun_mcmc.trace` for four steps.

            The starting state is built by `hamiltonian_monte_carlo_init`
            from a zero vector. The trace function returns an empty tuple,
            so nothing is recorded, and the result of `fun_mcmc.trace` is
            discarded.
            """

            def kernel(state):
                # Single HMC transition with a fixed step size and three
                # integrator steps against the enclosing target density.
                return fun_mcmc.hamiltonian_monte_carlo(
                    state,
                    step_size=0.1,
                    num_integrator_steps=3,
                    target_log_prob_fn=target_log_prob_fn,
                    seed=tfp_test_util.test_seed())

            initial_state = fun_mcmc.hamiltonian_monte_carlo_init(
                state=tf.zeros([1]), target_log_prob_fn=target_log_prob_fn)
            fun_mcmc.trace(
                state=initial_state,
                fn=kernel,
                num_steps=4,
                trace_fn=lambda *args: ())
Esempio n. 3
0
    def testBasicHMC(self):
        """Checks HMC sample moments against direct samples from the target.

        The target is a correlated 2-D multivariate normal pushed through a
        Softplus bijector. Sixteen chains are run for `num_steps` HMC steps,
        and the mean and covariance pooled over all steps and chains are
        compared with the same statistics estimated from independent draws
        of the target distribution.

        The test body only runs in eager mode; otherwise it returns
        immediately without asserting anything.
        """
        if not tf.executing_eagerly():
            return
        step_size = 0.2
        num_steps = 2000
        num_leapfrog_steps = 10
        # 16 chains, each a 2-D state.
        state = tf.ones([16, 2])

        base_mean = [1., 0]
        base_cov = [[1, 0.5], [0.5, 1]]

        bijector = tfb.Softplus()
        base_dist = tfd.MultivariateNormalFullCovariance(
            loc=base_mean, covariance_matrix=base_cov)
        target_dist = bijector(base_dist)

        def orig_target_log_prob_fn(x):
            # Returns (log_prob, extra) with no extra outputs.
            return target_dist.log_prob(x), ()

        # Rewrap the log-prob and the initial state using the bijector.
        # NOTE(review): presumably this lets HMC operate in the unconstrained
        # space -- confirm against `transform_log_prob_fn`.
        target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
            orig_target_log_prob_fn, bijector, state)

        # pylint: disable=g-long-lambda
        kernel = tf.function(lambda state: fun_mcmc.hamiltonian_monte_carlo(
            state,
            step_size=step_size,
            num_leapfrog_steps=num_leapfrog_steps,
            target_log_prob_fn=target_log_prob_fn,
            seed=tfp_test_util.test_seed()))

        # Trace the first element of `state_extra` at every step.
        # NOTE(review): presumably this is the sample in the original
        # (constrained) space -- confirm against `transform_log_prob_fn`.
        _, chain = fun_mcmc.trace(
            state=fun_mcmc.HamiltonianMonteCarloState(state=state,
                                                      state_grads=None,
                                                      target_log_prob=None,
                                                      state_extra=None),
            fn=kernel,
            num_steps=num_steps,
            trace_fn=lambda state, extra: state.state_extra[0])

        # Pool over the step axis and the chain axis.
        sample_mean = tf.reduce_mean(input_tensor=chain, axis=[0, 1])
        sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

        true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())

        true_mean = tf.reduce_mean(input_tensor=true_samples, axis=0)
        # The reference covariance must come from the independent draws, not
        # from the chain; otherwise the covariance check below would compare
        # `sample_cov` with itself and could never fail.
        true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

        self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
        self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
Esempio n. 4
0
            def kernel(hmc_state, step_size, step):
                """Takes one HMC step and adapts the step size afterwards.

                Args:
                  hmc_state: Current HMC state passed to
                    `fun_mcmc.hamiltonian_monte_carlo`.
                  step_size: Step size used for this transition.
                  step: Step counter, used as the decay schedule's global step.

                Returns:
                  A `((new_hmc_state, new_step_size, step + 1), hmc_extra)`
                  pair.
                """
                next_hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                    hmc_state,
                    step_size=step_size,
                    num_integrator_steps=num_leapfrog_steps,
                    target_log_prob_fn=target_log_prob_fn)

                # Adaptation rate decays polynomially from 0.01 to 0 over
                # `num_adapt_steps` steps.
                adaptation_rate = tf.compat.v1.train.polynomial_decay(
                    0.01,
                    global_step=step,
                    power=0.5,
                    decay_steps=num_adapt_steps,
                    end_learning_rate=0.)

                # Mean of exp(min(0, log_accept_ratio)) over all elements.
                capped_log_ratio = tf.minimum(0., hmc_extra.log_accept_ratio)
                accept_prob = tf.exp(capped_log_ratio)
                mean_accept_prob = tf.reduce_mean(input_tensor=accept_prob)

                # Adapt the step size towards a set point of 0.9 for the
                # mean acceptance probability.
                next_step_size = fun_mcmc.sign_adaptation(
                    step_size,
                    output=mean_accept_prob,
                    set_point=0.9,
                    adaptation_rate=adaptation_rate)

                next_loop_state = (next_hmc_state, next_step_size, step + 1)
                return next_loop_state, hmc_extra