def trace():
  kernel = lambda state: fun_mcmc.hamiltonian_monte_carlo(
      state,
      step_size=0.1,
      num_integrator_steps=3,
      target_log_prob_fn=target_log_prob_fn,
      seed=tfp_test_util.test_seed())

  fun_mcmc.trace(
      state=fun_mcmc.HamiltonianMonteCarloState(tf.zeros([1])),
      fn=kernel,
      num_steps=4,
      trace_fn=lambda *args: ())
def trace():
  # pylint: disable=g-long-lambda
  kernel = lambda state: fun_mcmc.hamiltonian_monte_carlo(
      state,
      step_size=0.1,
      num_integrator_steps=3,
      target_log_prob_fn=target_log_prob_fn,
      seed=tfp_test_util.test_seed())

  fun_mcmc.trace(
      state=fun_mcmc.hamiltonian_monte_carlo_init(
          state=tf.zeros([1]), target_log_prob_fn=target_log_prob_fn),
      fn=kernel,
      num_steps=4,
      trace_fn=lambda *args: ())
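# The two `trace` variants above differ only in how the initial state is
# built: constructing `HamiltonianMonteCarloState` directly leaves the
# log-prob fields unset, while `hamiltonian_monte_carlo_init` evaluates the
# target once up front. A minimal sketch of what such an init helper does
# (hypothetical `_hmc_init_sketch`; the real fun_mcmc implementation may
# differ in detail):
def _hmc_init_sketch(state, target_log_prob_fn):
  # Evaluate the target log-prob (and its extra outputs) at `state`, and
  # compute the gradients the leapfrog integrator needs on its first step.
  with tf.GradientTape() as tape:
    tape.watch(state)
    target_log_prob, state_extra = target_log_prob_fn(state)
  state_grads = tape.gradient(target_log_prob, state)
  return fun_mcmc.HamiltonianMonteCarloState(
      state=state,
      state_grads=state_grads,
      target_log_prob=target_log_prob,
      state_extra=state_extra)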
def testTraceTrace(self):

  def fun(x):
    return fun_mcmc.trace(x, lambda x: (x + 1., ()), 2, lambda *args: ())

  x, _ = fun_mcmc.trace(0., fun, 2, lambda *args: ())
  self.assertAllEqual(4., x)
def testWrapTransitionKernel(self):

  class TestKernel(tfp.mcmc.TransitionKernel):

    def one_step(self, current_state, previous_kernel_results):
      return [x + 1 for x in current_state], previous_kernel_results + 1

    def bootstrap_results(self, current_state):
      return sum(current_state)

    def is_calibrated(self):
      return True

  def kernel(state, pkr):
    return fun_mcmc.transition_kernel_wrapper(state, pkr, TestKernel())

  ((final_state, final_kr), _), _ = fun_mcmc.trace(
      ({'x': 0., 'y': 1.}, None),
      kernel,
      2,
      trace_fn=lambda *args: ())
  self.assertAllEqual({'x': 2., 'y': 3.}, self.evaluate(final_state))
  self.assertAllEqual(1. + 2., self.evaluate(final_kr))
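# `transition_kernel_wrapper` adapts a stateful `tfp.mcmc.TransitionKernel`
# (with its `bootstrap_results`/`one_step` protocol) to fun_mcmc's purely
# functional convention, carrying the kernel results alongside the chain
# state. A minimal sketch matching what the test exercises (hypothetical
# `_tk_wrapper_sketch`; the real wrapper may handle structures differently):
def _tk_wrapper_sketch(current_state, previous_kernel_results, kernel):
  flat_state = tf.nest.flatten(current_state)
  # First call: bootstrap the kernel results from the initial state.
  if previous_kernel_results is None:
    previous_kernel_results = kernel.bootstrap_results(flat_state)
  flat_state, kernel_results = kernel.one_step(
      flat_state, previous_kernel_results)
  new_state = tf.nest.pack_sequence_as(current_state, flat_state)
  # fun_mcmc-style return: the functional state is the (state, results)
  # pair, and there are no extra outputs.
  return (new_state, kernel_results), ()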
def testTraceTrace(self):
  if not tf.executing_eagerly():
    return

  def fun(x):
    return fun_mcmc.trace(x, lambda x: (x + 1., ()), 2, lambda *args: ())

  x, _ = fun_mcmc.trace(0., fun, 2, lambda *args: ())
  self.assertAllEqual(4., x)
def testTraceSingle(self):

  def fun(x):
    if x is None:
      x = 0.
    return x + 1., 2 * x

  x, e_trace = fun_mcmc.trace(
      state=None, fn=fun, num_steps=5, trace_fn=lambda _, xp1: xp1)

  self.assertAllEqual(5., x.numpy())
  self.assertAllEqual([0., 2., 4., 6., 8.], e_trace.numpy())
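# The contract these `trace` tests exercise: repeatedly apply `fn` to the
# state, and stack whatever `trace_fn` extracts from each `(state, extra)`
# pair along a new leading axis. A minimal eager-mode sketch (hypothetical
# `_trace_sketch`; the real fun_mcmc.trace builds a graph-compatible loop):
def _trace_sketch(state, fn, num_steps, trace_fn):
  traces = []
  for _ in range(num_steps):
    state, extra = fn(state)
    traces.append(trace_fn(state, extra))
  # Stack each traced leaf across steps, preserving the trace structure.
  stacked = tf.nest.map_structure(lambda *leaves: tf.stack(leaves), *traces)
  return state, stacked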
def testBasicHMC(self):
  if not tf.executing_eagerly():
    return
  step_size = 0.2
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2])

  base_mean = [1., 0]
  base_cov = [[1, 0.5], [0.5, 1]]

  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  # pylint: disable=g-long-lambda
  kernel = tf.function(lambda state: fun_mcmc.hamiltonian_monte_carlo(
      state,
      step_size=step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn,
      seed=tfp_test_util.test_seed()))

  _, chain = fun_mcmc.trace(
      state=fun_mcmc.HamiltonianMonteCarloState(
          state=state,
          state_grads=None,
          target_log_prob=None,
          state_extra=None),
      fn=kernel,
      num_steps=num_steps,
      trace_fn=lambda state, extra: state.state_extra[0])

  sample_mean = tf.reduce_mean(input_tensor=chain, axis=[0, 1])
  sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

  true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())
  true_mean = tf.reduce_mean(input_tensor=true_samples, axis=0)
  true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
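# `transform_log_prob_fn` moves the chain into the bijector's unconstrained
# domain: it pulls the initial state back through `bijector.inverse`, and the
# returned log-prob fn adds the forward log-det-Jacobian correction. The
# tests trace `state_extra[0]` because the transformed fn reports the
# constrained-space point as the first element of its extra outputs. A
# minimal single-part sketch (hypothetical `_transform_sketch`, assuming a
# vector event with `event_ndims=1`; the real fn handles nested bijectors):
def _transform_sketch(log_prob_fn, bijector, init_state):
  def transformed_log_prob_fn(x):
    y = bijector.forward(x)
    lp, extra = log_prob_fn(y)
    lp += bijector.forward_log_det_jacobian(x, event_ndims=1)
    # Expose the constrained-space point as the first extra output.
    return lp, (y, extra)
  return transformed_log_prob_fn, bijector.inverse(init_state)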
def testTraceNested(self):

  def fun(x, y):
    if x is None:
      x = 0.
    return (x + 1., y + 2.), ()

  (x, y), (x_trace, y_trace) = fun_mcmc.trace(
      state=(None, 0.), fn=fun, num_steps=5, trace_fn=lambda xy, _: xy)

  self.assertAllEqual(5., x)
  self.assertAllEqual(10., y)
  self.assertAllEqual([1., 2., 3., 4., 5.], x_trace)
  self.assertAllEqual([2., 4., 6., 8., 10.], y_trace)
def testTraceSingle(self):
  if not tf.executing_eagerly():
    return

  def fun(x):
    if x is None:
      x = 0.
    return x + 1., 2 * x

  (x, e), x_trace = fun_mcmc.trace(
      state=None, fn=fun, num_steps=5, trace_fn=lambda _, xp1: xp1)

  self.assertAllEqual(5., x.numpy())
  self.assertAllEqual(8., e.numpy())
  self.assertAllEqual([0., 2., 4., 6., 8.], x_trace.numpy())
def computation(state):
  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def kernel(hmc_state, step_size, step):
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn)

    rate = tf.compat.v1.train.polynomial_decay(
        0.01,
        global_step=step,
        power=0.5,
        decay_steps=num_adapt_steps,
        end_learning_rate=0.)
    mean_p_accept = tf.reduce_mean(
        input_tensor=tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))
    step_size = fun_mcmc.sign_adaptation(
        step_size,
        output=mean_p_accept,
        set_point=0.9,
        adaptation_rate=rate)

    return (hmc_state, step_size, step + 1), hmc_extra

  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      (fun_mcmc.HamiltonianMonteCarloState(state), step_size, 0),
      kernel,
      num_adapt_steps + num_steps,
      trace_fn=lambda state, extra: (state[0].state_extra[0],
                                     extra.log_accept_ratio))
  true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())
  return chain, log_accept_ratio_trace, true_samples
def fun(x):
  return fun_mcmc.trace(x, lambda x: (x + 1., ()), 2, lambda *args: ())
def testAdaptiveStepSize(self):
  if not tf.executing_eagerly():
    return
  step_size = 0.2
  num_steps = 2000
  num_adapt_steps = 1000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2])

  base_mean = [1., 0]
  base_cov = [[1, 0.5], [0.5, 1]]

  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  @tf.function
  def kernel(hmc_state, step_size, step):
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn)

    rate = tf.compat.v1.train.polynomial_decay(
        0.01,
        global_step=step,
        power=0.5,
        decay_steps=num_adapt_steps,
        end_learning_rate=0.)
    mean_p_accept = tf.reduce_mean(
        input_tensor=tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))
    step_size = fun_mcmc.sign_adaptation(
        step_size,
        output=mean_p_accept,
        set_point=0.9,
        adaptation_rate=rate)

    return (hmc_state, step_size, step + 1), hmc_extra

  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      (fun_mcmc.HamiltonianMonteCarloState(
          state=state,
          state_grads=None,
          target_log_prob=None,
          state_extra=None), step_size, 0),
      kernel,
      num_adapt_steps + num_steps,
      trace_fn=lambda state, extra: (state[0].state_extra[0],
                                     extra.log_accept_ratio))
  # Discard the adaptation phase before computing summary statistics.
  log_accept_ratio_trace = log_accept_ratio_trace[num_adapt_steps:]
  chain = chain[num_adapt_steps:]

  sample_mean = tf.reduce_mean(input_tensor=chain, axis=[0, 1])
  sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

  true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())
  true_mean = tf.reduce_mean(input_tensor=true_samples, axis=0)
  true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.05, atol=0.05)
  self.assertAllClose(true_cov, sample_cov, rtol=0.05, atol=0.05)
  self.assertAllClose(
      tf.reduce_mean(
          input_tensor=tf.exp(tf.minimum(0., log_accept_ratio_trace))),
      0.9,
      rtol=0.1)
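# `sign_adaptation` nudges a control (here the step size) up or down by a
# multiplicative factor depending on whether the observed output (the mean
# acceptance probability) is above or below the set point, which is why the
# final assertion expects the chain to settle near the 0.9 target. A minimal
# sketch, assuming a simple multiplicative rule (the real fun_mcmc
# implementation may differ in detail):
def _sign_adaptation_sketch(control, output, set_point, adaptation_rate=0.01):
  # Multiply by (1 + rate) when output exceeds the set point; divide by it
  # (via the -1 exponent from tf.sign) when output falls below.
  return control * (1. + adaptation_rate)**tf.sign(output - set_point)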