Example #1
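Smoke test: traces four Hamiltonian Monte Carlo steps from a zero initial state via fun_mcmc.trace, discarding the trace (trace_fn returns an empty tuple). All four snippets are test fragments and omit their imports; a plausible preamble, assuming the TFP-era layout of fun_mcmc (exact module paths vary across versions, so treat this as a sketch):

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

# Assumed module paths; adjust to your TFP version (fun_mcmc later moved to
# the standalone `fun_mc` package).
from tensorflow_probability.python.experimental import fun_mcmc
from tensorflow_probability.python.internal import test_util as tfp_test_util

tfb = tfp.bijectors
tfd = tfp.distributions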
        def trace():
            kernel = lambda state: fun_mcmc.hamiltonian_monte_carlo(
                state,
                step_size=0.1,
                num_integrator_steps=3,
                target_log_prob_fn=target_log_prob_fn,
                seed=tfp_test_util.test_seed())

            fun_mcmc.trace(state=fun_mcmc.HamiltonianMonteCarloState(
                tf.zeros([1])),
                           fn=kernel,
                           num_steps=4,
                           trace_fn=lambda *args: ())
Example #2
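testBasicHMC draws 2000 HMC steps across 16 parallel chains targeting a Softplus-transformed bivariate normal. transform_log_prob_fn pulls the target density back through the bijector so HMC runs in unconstrained space; trace_fn reads state.state_extra[0], which the transformed log-prob function fills with the samples mapped back to the constrained space. The test then checks the chain's mean and covariance against moments estimated from 4096 direct draws of the target.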
    def testBasicHMC(self):
        if not tf.executing_eagerly():
            return  # Eager-only test.
        step_size = 0.2
        num_steps = 2000
        num_leapfrog_steps = 10
        state = tf.ones([16, 2])

        base_mean = [1., 0]
        base_cov = [[1, 0.5], [0.5, 1]]

        bijector = tfb.Softplus()
        base_dist = tfd.MultivariateNormalFullCovariance(
            loc=base_mean, covariance_matrix=base_cov)
        target_dist = bijector(base_dist)

        def orig_target_log_prob_fn(x):
            return target_dist.log_prob(x), ()

        target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
            orig_target_log_prob_fn, bijector, state)

        # pylint: disable=g-long-lambda
        kernel = tf.function(lambda state: fun_mcmc.hamiltonian_monte_carlo(
            state,
            step_size=step_size,
            num_leapfrog_steps=num_leapfrog_steps,
            target_log_prob_fn=target_log_prob_fn,
            seed=tfp_test_util.test_seed()))

        _, chain = fun_mcmc.trace(
            state=fun_mcmc.HamiltonianMonteCarloState(state=state,
                                                      state_grads=None,
                                                      target_log_prob=None,
                                                      state_extra=None),
            fn=kernel,
            num_steps=num_steps,
            trace_fn=lambda state, extra: state.state_extra[0])

        sample_mean = tf.reduce_mean(input_tensor=chain, axis=[0, 1])
        sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

        true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())

        true_mean = tf.reduce_mean(input_tensor=true_samples, axis=0)
        true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

        self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
        self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
Example #3
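A fragment showing step-size adaptation: each kernel step takes one HMC transition, then nudges step_size with fun_mcmc.sign_adaptation toward a 0.9 acceptance probability, with the adaptation rate annealed to zero by a polynomial decay over num_adapt_steps. The kernel follows fun_mcmc's functional contract, returning (new_state, extra). Outer names (base_mean, base_cov, step_size, num_leapfrog_steps, num_adapt_steps, num_steps) are closed over from the enclosing test.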
        def computation(state):
            bijector = tfb.Softplus()
            base_dist = tfd.MultivariateNormalFullCovariance(
                loc=base_mean, covariance_matrix=base_cov)
            target_dist = bijector(base_dist)

            def orig_target_log_prob_fn(x):
                return target_dist.log_prob(x), ()

            target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
                orig_target_log_prob_fn, bijector, state)

            def kernel(hmc_state, step_size, step):
                hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                    hmc_state,
                    step_size=step_size,
                    num_integrator_steps=num_leapfrog_steps,
                    target_log_prob_fn=target_log_prob_fn)

                rate = tf.compat.v1.train.polynomial_decay(
                    0.01,
                    global_step=step,
                    power=0.5,
                    decay_steps=num_adapt_steps,
                    end_learning_rate=0.)
                mean_p_accept = tf.reduce_mean(input_tensor=tf.exp(
                    tf.minimum(0., hmc_extra.log_accept_ratio)))
                step_size = fun_mcmc.sign_adaptation(step_size,
                                                     output=mean_p_accept,
                                                     set_point=0.9,
                                                     adaptation_rate=rate)

                return (hmc_state, step_size, step + 1), hmc_extra

            _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
                (fun_mcmc.HamiltonianMonteCarloState(state), step_size, 0),
                kernel,
                num_adapt_steps + num_steps,
                trace_fn=lambda state, extra:
                (state[0].state_extra[0], extra.log_accept_ratio))
            true_samples = target_dist.sample(4096,
                                              seed=tfp_test_util.test_seed())
            return chain, log_accept_ratio_trace, true_samples
Example #4
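testAdaptiveStepSize is the complete test built around the same adaptive kernel as Example #3: after discarding the first num_adapt_steps samples as warmup, it compares the chain's moments against direct draws of the target and asserts that the realized mean acceptance probability is close to the 0.9 set point.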
    def testAdaptiveStepSize(self):
        if not tf.executing_eagerly():
            return  # Eager-only test.
        step_size = 0.2
        num_steps = 2000
        num_adapt_steps = 1000
        num_leapfrog_steps = 10
        state = tf.ones([16, 2])

        base_mean = [1., 0]
        base_cov = [[1, 0.5], [0.5, 1]]

        bijector = tfb.Softplus()
        base_dist = tfd.MultivariateNormalFullCovariance(
            loc=base_mean, covariance_matrix=base_cov)
        target_dist = bijector(base_dist)

        def orig_target_log_prob_fn(x):
            return target_dist.log_prob(x), ()

        target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
            orig_target_log_prob_fn, bijector, state)

        @tf.function
        def kernel(hmc_state, step_size, step):
            hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                hmc_state,
                step_size=step_size,
                num_leapfrog_steps=num_leapfrog_steps,
                target_log_prob_fn=target_log_prob_fn)

            rate = tf.compat.v1.train.polynomial_decay(
                0.01,
                global_step=step,
                power=0.5,
                decay_steps=num_adapt_steps,
                end_learning_rate=0.)
            mean_p_accept = tf.reduce_mean(input_tensor=tf.exp(
                tf.minimum(0., hmc_extra.log_accept_ratio)))
            step_size = fun_mcmc.sign_adaptation(step_size,
                                                 output=mean_p_accept,
                                                 set_point=0.9,
                                                 adaptation_rate=rate)

            return (hmc_state, step_size, step + 1), hmc_extra

        _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
            (fun_mcmc.HamiltonianMonteCarloState(
                state=state,
                state_grads=None,
                target_log_prob=None,
                state_extra=None), step_size, 0),
            kernel,
            num_adapt_steps + num_steps,
            trace_fn=lambda state, extra:
            (state[0].state_extra[0], extra.log_accept_ratio))

        log_accept_ratio_trace = log_accept_ratio_trace[num_adapt_steps:]
        chain = chain[num_adapt_steps:]

        sample_mean = tf.reduce_mean(input_tensor=chain, axis=[0, 1])
        sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

        true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())

        true_mean = tf.reduce_mean(input_tensor=true_samples, axis=0)
        true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

        self.assertAllClose(true_mean, sample_mean, rtol=0.05, atol=0.05)
        self.assertAllClose(true_cov, sample_cov, rtol=0.05, atol=0.05)
        self.assertAllClose(tf.reduce_mean(
            input_tensor=tf.exp(tf.minimum(0., log_accept_ratio_trace))),
                            0.9,
                            rtol=0.1)