예제 #1
0
        def trace():
            """Runs four HMC steps from a zero state and discards the trace.

            Reads `target_log_prob_fn` from the enclosing scope. The result of
            `fun_mcmc.trace` is not returned.
            """

            def hmc_step(state):
                # `test_seed()` is evaluated on every call of this step
                # function, exactly as in the equivalent lambda form.
                return fun_mcmc.hamiltonian_monte_carlo(
                    state,
                    step_size=0.1,
                    num_integrator_steps=3,
                    target_log_prob_fn=target_log_prob_fn,
                    seed=tfp_test_util.test_seed())

            initial_state = fun_mcmc.HamiltonianMonteCarloState(tf.zeros([1]))
            fun_mcmc.trace(
                state=initial_state,
                fn=hmc_step,
                num_steps=4,
                trace_fn=lambda *args: ())
예제 #2
0
        def trace():
            """Runs four HMC steps from an initialized zero state.

            The initial state is built with `hamiltonian_monte_carlo_init`, and
            the per-step trace is discarded (`trace_fn` returns an empty tuple).
            Reads `target_log_prob_fn` from the enclosing scope.
            """

            def hmc_step(state):
                return fun_mcmc.hamiltonian_monte_carlo(
                    state,
                    step_size=0.1,
                    num_integrator_steps=3,
                    target_log_prob_fn=target_log_prob_fn,
                    seed=tfp_test_util.test_seed())

            initial_state = fun_mcmc.hamiltonian_monte_carlo_init(
                state=tf.zeros([1]), target_log_prob_fn=target_log_prob_fn)
            fun_mcmc.trace(
                state=initial_state,
                fn=hmc_step,
                num_steps=4,
                trace_fn=lambda *args: ())
예제 #3
0
    def testTraceTrace(self):
        """Checks that `fun_mcmc.trace` can be nested inside another trace.

        The inner trace adds 1 twice and the outer trace runs it twice, so
        starting from 0 the expected final state is 4.
        """

        def no_trace(*args):
            return ()

        def increment(value):
            return value + 1., ()

        def inner_trace(value):
            return fun_mcmc.trace(value, increment, 2, no_trace)

        final_state, _ = fun_mcmc.trace(0., inner_trace, 2, no_trace)
        self.assertAllEqual(4., final_state)
예제 #4
0
    def testWrapTransitionKernel(self):
        """Drives a toy `TransitionKernel` through `transition_kernel_wrapper`.

        Each `one_step` adds 1 to every state part and 1 to the kernel results.
        `bootstrap_results` returns the sum of the state parts. The expected
        final kernel results (1. + 2.) match that initial sum (0. + 1.) plus one
        per step over two steps.
        """

        class TestKernel(tfp.mcmc.TransitionKernel):

            def one_step(self, current_state, previous_kernel_results):
                next_state = [part + 1 for part in current_state]
                return next_state, previous_kernel_results + 1

            def bootstrap_results(self, current_state):
                return sum(current_state)

            def is_calibrated(self):
                return True

        def kernel(state, pkr):
            return fun_mcmc.transition_kernel_wrapper(state, pkr, TestKernel())

        initial_state = {'x': 0., 'y': 1.}
        ((final_state, final_kr), _), _ = fun_mcmc.trace(
            (initial_state, None),
            kernel,
            2,
            trace_fn=lambda *args: ())
        self.assertAllEqual({'x': 2., 'y': 3.}, self.evaluate(final_state))
        self.assertAllEqual(1. + 2., self.evaluate(final_kr))
예제 #5
0
  def testTraceTrace(self):
    """Nests one `fun_mcmc.trace` inside another; eager mode only."""
    if not tf.executing_eagerly():
      return

    def no_trace(*args):
      return ()

    def inner_trace(value):
      # Two increments per inner trace; the outer trace runs it twice.
      return fun_mcmc.trace(value, lambda v: (v + 1., ()), 2, no_trace)

    final_state, _ = fun_mcmc.trace(0., inner_trace, 2, no_trace)
    self.assertAllEqual(4., final_state)
예제 #6
0
    def testTraceSingle(self):
        """Traces a scalar chain whose initial state is `None`."""

        def fun(x):
            # A `None` state is treated as 0.
            current = 0. if x is None else x
            return current + 1., 2 * current

        final_x, extra_trace = fun_mcmc.trace(
            state=None,
            fn=fun,
            num_steps=5,
            trace_fn=lambda _, xp1: xp1)

        self.assertAllEqual(5., final_x.numpy())
        self.assertAllEqual([0., 2., 4., 6., 8.], extra_trace.numpy())
예제 #7
0
    def testBasicHMC(self):
        """Checks HMC sample moments against exact draws; eager mode only.

        The target is a Softplus-transformed correlated 2-D Gaussian. The
        chain is run on the log-prob returned by `transform_log_prob_fn`. The
        trace records `state.state_extra[0]`, presumably the sample mapped back
        through the bijector (confirm against `transform_log_prob_fn`). The
        trace's mean and covariance are compared against 4096 exact draws from
        the target.
        """
        if not tf.executing_eagerly():
            return
        step_size = 0.2
        num_steps = 2000
        num_leapfrog_steps = 10
        state = tf.ones([16, 2])

        base_mean = [1., 0]
        base_cov = [[1, 0.5], [0.5, 1]]

        bijector = tfb.Softplus()
        base_dist = tfd.MultivariateNormalFullCovariance(
            loc=base_mean, covariance_matrix=base_cov)
        target_dist = bijector(base_dist)

        def orig_target_log_prob_fn(x):
            return target_dist.log_prob(x), ()

        target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
            orig_target_log_prob_fn, bijector, state)

        @tf.function
        def kernel(state):
            return fun_mcmc.hamiltonian_monte_carlo(
                state,
                step_size=step_size,
                num_leapfrog_steps=num_leapfrog_steps,
                target_log_prob_fn=target_log_prob_fn,
                seed=tfp_test_util.test_seed())

        _, chain = fun_mcmc.trace(
            state=fun_mcmc.HamiltonianMonteCarloState(state=state,
                                                      state_grads=None,
                                                      target_log_prob=None,
                                                      state_extra=None),
            fn=kernel,
            num_steps=num_steps,
            trace_fn=lambda state, extra: state.state_extra[0])

        sample_mean = tf.reduce_mean(input_tensor=chain, axis=[0, 1])
        sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

        true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())

        true_mean = tf.reduce_mean(input_tensor=true_samples, axis=0)
        # The reference covariance must come from the exact samples. Computing
        # it from `chain` would compare `sample_cov` with itself, and the
        # covariance assertion below could never fail.
        true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

        self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
        self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
예제 #8
0
    def testTraceNested(self):
        """Traces a two-part tuple state whose first part starts as `None`."""

        def fun(x, y):
            # The `None` first component is treated as 0.
            start_x = 0. if x is None else x
            return (start_x + 1., y + 2.), ()

        final_state, state_trace = fun_mcmc.trace(
            state=(None, 0.),
            fn=fun,
            num_steps=5,
            trace_fn=lambda xy, _: xy)
        x, y = final_state
        x_trace, y_trace = state_trace

        self.assertAllEqual(5., x)
        self.assertAllEqual(10., y)
        self.assertAllEqual([1., 2., 3., 4., 5.], x_trace)
        self.assertAllEqual([2., 4., 6., 8., 10.], y_trace)
예제 #9
0
    def testTraceSingle(self):
        """Traces a scalar chain starting from `None`; eager mode only.

        The final trace result is unpacked as `(state, extra)` and checked
        together with the per-step trace of the extra output.
        """
        if not tf.executing_eagerly():
            return

        def fun(x):
            current = 0. if x is None else x
            return current + 1., 2 * current

        final_result, x_trace = fun_mcmc.trace(
            state=None,
            fn=fun,
            num_steps=5,
            trace_fn=lambda _, xp1: xp1)
        x, e = final_result

        self.assertAllEqual(5., x.numpy())
        self.assertAllEqual(8., e.numpy())
        self.assertAllEqual([0., 2., 4., 6., 8.], x_trace.numpy())
예제 #10
0
        def computation(state):
            """Runs adaptive-step-size HMC on a Softplus-transformed Gaussian.

            Reads `base_mean`, `base_cov`, `step_size`, `num_leapfrog_steps`,
            `num_adapt_steps` and `num_steps` from the enclosing scope.

            Args:
              state: Initial chain state. It is passed through
                `fun_mcmc.transform_log_prob_fn` before sampling.

            Returns:
              A tuple `(chain, log_accept_ratio_trace, true_samples)`. The first
              two cover all `num_adapt_steps + num_steps` iterations. The last
              holds 4096 draws from the target distribution.
            """
            bijector = tfb.Softplus()
            target_dist = bijector(
                tfd.MultivariateNormalFullCovariance(
                    loc=base_mean, covariance_matrix=base_cov))

            def orig_target_log_prob_fn(x):
                return target_dist.log_prob(x), ()

            target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
                orig_target_log_prob_fn, bijector, state)

            def adapt_step_size(current_step_size, log_accept_ratio, step):
                # The adaptation rate decays polynomially from 0.01 towards 0
                # over `num_adapt_steps` steps.
                rate = tf.compat.v1.train.polynomial_decay(
                    0.01,
                    global_step=step,
                    power=0.5,
                    decay_steps=num_adapt_steps,
                    end_learning_rate=0.)
                # exp(min(0, r)) == min(1, exp(r)), averaged over the batch.
                mean_p_accept = tf.reduce_mean(
                    input_tensor=tf.exp(tf.minimum(0., log_accept_ratio)))
                return fun_mcmc.sign_adaptation(
                    current_step_size,
                    output=mean_p_accept,
                    set_point=0.9,
                    adaptation_rate=rate)

            def kernel(hmc_state, step_size, step):
                next_hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                    hmc_state,
                    step_size=step_size,
                    num_integrator_steps=num_leapfrog_steps,
                    target_log_prob_fn=target_log_prob_fn)
                next_step_size = adapt_step_size(
                    step_size, hmc_extra.log_accept_ratio, step)
                return (next_hmc_state, next_step_size, step + 1), hmc_extra

            def trace_fn(kernel_state, extra):
                return kernel_state[0].state_extra[0], extra.log_accept_ratio

            initial_kernel_state = (fun_mcmc.HamiltonianMonteCarloState(state),
                                    step_size, 0)
            _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
                initial_kernel_state,
                kernel,
                num_adapt_steps + num_steps,
                trace_fn=trace_fn)
            true_samples = target_dist.sample(
                4096, seed=tfp_test_util.test_seed())
            return chain, log_accept_ratio_trace, true_samples
예제 #11
0
 def fun(x):
     """Runs two increment steps from `x` through `fun_mcmc.trace`.

     Returns whatever `fun_mcmc.trace` returns; the per-step trace is empty.
     """
     def increment(value):
         return value + 1., ()

     return fun_mcmc.trace(x, increment, 2, lambda *args: ())
예제 #12
0
    def testAdaptiveStepSize(self):
        """Checks HMC with sign-based step-size adaptation; eager mode only.

        The step size is adapted towards a 0.9 mean acceptance probability,
        with an adaptation rate that decays to 0 over the first
        `num_adapt_steps` iterations. Those iterations are then dropped.
        The remaining chain's moments are compared against exact draws from
        the target, and its mean acceptance probability against 0.9.
        """
        if not tf.executing_eagerly():
            return
        step_size = 0.2
        num_steps = 2000
        num_adapt_steps = 1000
        num_leapfrog_steps = 10
        state = tf.ones([16, 2])

        base_mean = [1., 0]
        base_cov = [[1, 0.5], [0.5, 1]]

        bijector = tfb.Softplus()
        base_dist = tfd.MultivariateNormalFullCovariance(
            loc=base_mean, covariance_matrix=base_cov)
        target_dist = bijector(base_dist)

        def orig_target_log_prob_fn(x):
            return target_dist.log_prob(x), ()

        target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
            orig_target_log_prob_fn, bijector, state)

        @tf.function
        def kernel(hmc_state, step_size, step):
            hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                hmc_state,
                step_size=step_size,
                num_leapfrog_steps=num_leapfrog_steps,
                target_log_prob_fn=target_log_prob_fn)

            rate = tf.compat.v1.train.polynomial_decay(
                0.01,
                global_step=step,
                power=0.5,
                decay_steps=num_adapt_steps,
                end_learning_rate=0.)
            # exp(min(0, r)) == min(1, exp(r)), averaged over the batch.
            mean_p_accept = tf.reduce_mean(input_tensor=tf.exp(
                tf.minimum(0., hmc_extra.log_accept_ratio)))
            step_size = fun_mcmc.sign_adaptation(step_size,
                                                 output=mean_p_accept,
                                                 set_point=0.9,
                                                 adaptation_rate=rate)

            return (hmc_state, step_size, step + 1), hmc_extra

        _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
            (fun_mcmc.HamiltonianMonteCarloState(
                state=state,
                state_grads=None,
                target_log_prob=None,
                state_extra=None), step_size, 0),
            kernel,
            num_adapt_steps + num_steps,
            trace_fn=lambda state, extra:
            (state[0].state_extra[0], extra.log_accept_ratio))

        # Drop the adaptation phase; only post-adaptation draws are checked.
        log_accept_ratio_trace = log_accept_ratio_trace[num_adapt_steps:]
        chain = chain[num_adapt_steps:]

        sample_mean = tf.reduce_mean(input_tensor=chain, axis=[0, 1])
        sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

        true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())

        true_mean = tf.reduce_mean(input_tensor=true_samples, axis=0)
        # The reference covariance must come from the exact samples. Computing
        # it from `chain` would compare `sample_cov` with itself, and the
        # covariance assertion below could never fail.
        true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

        self.assertAllClose(true_mean, sample_mean, rtol=0.05, atol=0.05)
        self.assertAllClose(true_cov, sample_cov, rtol=0.05, atol=0.05)
        self.assertAllClose(tf.reduce_mean(
            input_tensor=tf.exp(tf.minimum(0., log_accept_ratio_trace))),
                            0.9,
                            rtol=0.1)