def testBasicHMC(self):
  """Checks that HMC recovers the moments of a Softplus-transformed MVN.

  Runs HMC in the unconstrained space (via `transform_log_prob_fn`) and
  compares the chain's sample mean/covariance against Monte Carlo estimates
  from direct samples of the target distribution.
  """
  if not tf.executing_eagerly():
    return
  step_size = 0.2
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2])

  base_mean = [1., 0]
  base_cov = [[1, 0.5], [0.5, 1]]

  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    # The trailing `()` is the (empty) `state_extra` payload.
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  # pylint: disable=g-long-lambda
  kernel = tf.function(lambda state: fun_mcmc.hamiltonian_monte_carlo(
      state,
      step_size=step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn,
      seed=tfp_test_util.test_seed()))

  _, chain = fun_mcmc.trace(
      state=fun_mcmc.HamiltonianMonteCarloState(
          state=state,
          state_grads=None,
          target_log_prob=None,
          state_extra=None),
      fn=kernel,
      num_steps=num_steps,
      # state_extra[0] holds the constrained-space state.
      trace_fn=lambda state, extra: state.state_extra[0])

  sample_mean = tf.reduce_mean(input_tensor=chain, axis=[0, 1])
  sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

  true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())

  true_mean = tf.reduce_mean(input_tensor=true_samples, axis=0)
  # BUG FIX: the reference covariance must be estimated from the independent
  # `true_samples`, not from `chain` — otherwise the assertion below compares
  # the chain's covariance to itself and passes vacuously.
  true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
def computation(state):
  """Builds the target, runs step-size-adapted HMC, and draws reference samples.

  Reads these names from the enclosing scope: `base_mean`, `base_cov`,
  `num_leapfrog_steps`, `num_adapt_steps`, `num_steps`, and `step_size`.

  Returns:
    A `(chain, log_accept_ratio_trace, true_samples)` tuple.
  """
  transform = tfb.Softplus()
  mvn = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = transform(mvn)

  def raw_log_prob_fn(x):
    # The trailing `()` is the (empty) `state_extra` payload.
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      raw_log_prob_fn, transform, state)

  def kernel(hmc_state, step_size, step):
    # One HMC transition.
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn)
    # Polynomially decaying adaptation rate: adaptation stops after
    # `num_adapt_steps` steps (end rate is 0).
    rate = tf.compat.v1.train.polynomial_decay(
        0.01,
        global_step=step,
        power=0.5,
        decay_steps=num_adapt_steps,
        end_learning_rate=0.)
    accept_prob = tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio))
    mean_p_accept = tf.reduce_mean(input_tensor=accept_prob)
    # Nudge the step size toward a 0.9 mean acceptance probability.
    step_size = fun_mcmc.sign_adaptation(
        step_size,
        output=mean_p_accept,
        set_point=0.9,
        adaptation_rate=rate)
    return (hmc_state, step_size, step + 1), hmc_extra

  initial_state = (fun_mcmc.HamiltonianMonteCarloState(state), step_size, 0)
  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      initial_state,
      kernel,
      num_adapt_steps + num_steps,
      trace_fn=lambda state, extra: (state[0].state_extra[0],
                                     extra.log_accept_ratio))
  true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())
  return chain, log_accept_ratio_trace, true_samples
def testTransformLogProbFn(self):
  """Checks `transform_log_prob_fn` on a two-variable Normal target.

  Verifies that the returned initial state is the bijector-inverted input,
  that the transformed log-prob reports the original-space point, and that
  it equals the base log-prob plus the forward log-det-Jacobian terms.
  """

  def log_prob_fn(x, y):
    lp = tfd.Normal(0., 1.).log_prob(x) + tfd.Normal(1., 1.).log_prob(y)
    return lp, ()

  bijectors = [tfb.AffineScalar(scale=2.), tfb.AffineScalar(scale=3.)]

  (transformed_log_prob_fn,
   transformed_init_state) = fun_mcmc.transform_log_prob_fn(
       log_prob_fn, bijectors, [2., 3.])

  # [2., 3.] pulled back through scales (2., 3.) gives [1., 1.].
  self.assertIsInstance(transformed_init_state, list)
  self.assertAllClose([1., 1.], transformed_init_state)

  tlp, (orig_space, _) = transformed_log_prob_fn(1., 1.)
  self.assertIsInstance(orig_space, list)

  # Expected value: base log-prob at the forward-mapped point, corrected by
  # the per-bijector forward log-det-Jacobians at the transformed point.
  jacobian_correction = sum(
      b.forward_log_det_jacobian(1., event_ndims=0) for b in bijectors)
  lp = log_prob_fn(2., 3.)[0] + jacobian_correction

  self.assertAllClose([2., 3.], orig_space)
  self.assertAllClose(lp, tlp)
def testAdaptiveStepSize(self):
  """Checks HMC with sign-adapted step size against direct target samples.

  Adapts the step size toward a 0.9 acceptance rate for `num_adapt_steps`,
  discards the adaptation segment, then compares the post-adaptation chain's
  mean/covariance and acceptance probability against their targets.
  """
  if not tf.executing_eagerly():
    return
  step_size = 0.2
  num_steps = 2000
  num_adapt_steps = 1000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2])

  base_mean = [1., 0]
  base_cov = [[1, 0.5], [0.5, 1]]

  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    # The trailing `()` is the (empty) `state_extra` payload.
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  @tf.function
  def kernel(hmc_state, step_size, step):
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn)
    # Polynomially decaying adaptation rate; reaches 0 at `num_adapt_steps`,
    # freezing the step size for the measurement segment of the chain.
    rate = tf.compat.v1.train.polynomial_decay(
        0.01,
        global_step=step,
        power=0.5,
        decay_steps=num_adapt_steps,
        end_learning_rate=0.)
    mean_p_accept = tf.reduce_mean(
        input_tensor=tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))
    # Nudge the step size toward a 0.9 mean acceptance probability.
    step_size = fun_mcmc.sign_adaptation(
        step_size,
        output=mean_p_accept,
        set_point=0.9,
        adaptation_rate=rate)
    return (hmc_state, step_size, step + 1), hmc_extra

  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      (fun_mcmc.HamiltonianMonteCarloState(
          state=state,
          state_grads=None,
          target_log_prob=None,
          state_extra=None), step_size, 0),
      kernel,
      num_adapt_steps + num_steps,
      trace_fn=lambda state, extra: (state[0].state_extra[0],
                                     extra.log_accept_ratio))
  # Discard the adaptation segment before computing statistics.
  log_accept_ratio_trace = log_accept_ratio_trace[num_adapt_steps:]
  chain = chain[num_adapt_steps:]

  sample_mean = tf.reduce_mean(input_tensor=chain, axis=[0, 1])
  sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

  true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())

  true_mean = tf.reduce_mean(input_tensor=true_samples, axis=0)
  # BUG FIX: the reference covariance must be estimated from the independent
  # `true_samples`, not from `chain` — otherwise the assertion below compares
  # the chain's covariance to itself and passes vacuously.
  true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.05, atol=0.05)
  self.assertAllClose(true_cov, sample_cov, rtol=0.05, atol=0.05)
  # The adapted acceptance probability should sit near the 0.9 set point.
  self.assertAllClose(
      tf.reduce_mean(
          input_tensor=tf.exp(tf.minimum(0., log_accept_ratio_trace))),
      0.9,
      rtol=0.1)