def testBasicHMC(self):
  """HMC on a diagonal Gaussian should recover its mean and variance.

  Runs 16 parallel chains for 2000 steps, drops the first 1000 as warmup and
  compares the pooled sample moments against moments of exact Gaussian draws.
  """
  step_size = 0.2
  num_steps = 2000
  num_leapfrog_steps = 10
  num_warmup_steps = 1000
  state = tf.ones([16, 2])

  base_mean = tf.constant([2., 3.])
  base_scale = tf.constant([2., 0.5])

  def target_log_prob_fn(x):
    # Unnormalized diagonal-Gaussian log density; no extra outputs.
    standardized = (x - base_mean) / base_scale
    return -tf.reduce_sum(0.5 * tf.square(standardized), -1), ()

  def kernel(hmc_state, seed):
    # One HMC transition; the seed is threaded through the trace state.
    if backend.get_backend() != backend.TENSORFLOW:
      # Non-TF backends: split into a per-step seed and the next carried seed.
      hmc_seed, seed = util.split_seed(seed, 2)
    else:
      hmc_seed = tfp_test_util.test_seed()
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    return (hmc_state, seed), hmc_extra

  seed = tfp_test_util.test_seed()
  if backend.get_backend() != backend.TENSORFLOW:
    seed = self._make_seed(seed)

  def run_chain(init_state, init_seed):
    # Trace only the HMC position at every step.
    return fun_mcmc.trace(
        state=(fun_mcmc.HamiltonianMonteCarloState(init_state), init_seed),
        fn=kernel,
        num_steps=num_steps,
        trace_fn=lambda state, extra: state[0].state)

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  _, chain = tf.function(run_chain)(state, seed)

  # Discard the warmup samples.
  chain = chain[num_warmup_steps:]
  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_var = tf.math.reduce_variance(chain, axis=[0, 1])

  standard_draws = util.random_normal(
      shape=[4096, 2], dtype=tf.float32, seed=seed)
  true_samples = standard_draws * base_scale + base_mean
  true_mean = tf.reduce_mean(true_samples, axis=0)
  true_var = tf.math.reduce_variance(true_samples, axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_var, sample_var, rtol=0.1, atol=0.1)
def trace():
  """Runs four HMC steps from a single zero state; the result is discarded.

  NOTE(review): `target_log_prob_fn` is a free variable resolved from the
  enclosing scope, which is not visible here.
  """

  def kernel(state):
    # One HMC transition with fixed step size and integrator length.
    return fun_mcmc.hamiltonian_monte_carlo(
        state,
        step_size=0.1,
        num_integrator_steps=3,
        target_log_prob_fn=target_log_prob_fn,
        seed=tfp_test_util.test_seed())

  initial_state = fun_mcmc.HamiltonianMonteCarloState(tf.zeros([1]))
  fun_mcmc.trace(
      state=initial_state,
      fn=kernel,
      num_steps=4,
      trace_fn=lambda *args: ())
def testPreconditionedHMC(self):
  """HMC through `transform_log_prob_fn` should match a Softplus-pushed MVN.

  Samples in the space produced by `fun_mcmc.transform_log_prob_fn`, traces
  `state_extra[0]` and compares the post-warmup mean and covariance against
  exact draws from the target distribution.
  """
  step_size = 0.2
  num_steps = 2000
  num_leapfrog_steps = 10
  num_warmup_steps = 1000
  state = tf.ones([16, 2])

  base_mean = [1., 0]
  base_cov = [[1, 0.5], [0.5, 1]]

  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  # `state` is replaced by the initial state returned by the transform
  # (presumably its pre-image under `bijector` -- TODO confirm).
  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  # pylint: disable=g-long-lambda
  kernel = tf.function(lambda state: fun_mcmc.hamiltonian_monte_carlo(
      state,
      step_size=step_size,
      num_integrator_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn,
      seed=tfp_test_util.test_seed()))

  _, chain = fun_mcmc.trace(
      state=fun_mcmc.HamiltonianMonteCarloState(state),
      fn=kernel,
      num_steps=num_steps,
      # Assumed to be the sample in the original (target) space -- see
      # transform_log_prob_fn.
      trace_fn=lambda state, extra: state.state_extra[0])

  # Discard the warmup samples.
  chain = chain[num_warmup_steps:]

  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

  true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())

  true_mean = tf.reduce_mean(true_samples, axis=0)
  # The reference covariance must come from the exact samples. Computing it
  # from `chain` (as before) made the covariance assertion compare the chain
  # against itself, so it could never fail.
  true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
def computation(state):
  """Runs HMC with an Adam-adapted log step size on a Softplus-pushed MVN.

  NOTE(review): `base_mean`, `base_cov`, `step_size`, `num_leapfrog_steps`,
  `num_adapt_steps` and `num_steps` are free variables from the enclosing
  scope (not visible here).

  Returns:
    A tuple `(chain, log_accept_ratio_trace, true_samples)`: the traced
    `state_extra[0]` values, the traced HMC log acceptance ratios, and 4096
    exact draws from the target distribution.
  """
  bijector = tfb.Softplus()
  target_dist = bijector(
      tfd.MultivariateNormalFullCovariance(
          loc=base_mean, covariance_matrix=base_cov))

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def kernel(hmc_state, log_step_size_state, step):
    # The optimizer state holds the *log* step size; exponentiate for HMC.
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=tf.exp(log_step_size_state.state),
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn)

    # Learning rate decays from 0.01 to 0 over the adaptation phase.
    rate = tf.compat.v1.train.polynomial_decay(
        0.01,
        global_step=step,
        power=0.5,
        decay_steps=num_adapt_steps,
        end_learning_rate=0.)

    accept_prob = tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio))
    mean_p_accept = tf.reduce_mean(accept_prob)

    # Adam step on the log step size using the surrogate loss
    # `0.9 - mean_p_accept` (0.9 target acceptance; TODO confirm the sign
    # convention of make_surrogate_loss_fn).
    loss_fn = fun_mcmc.make_surrogate_loss_fn(
        lambda _: (0.9 - mean_p_accept, ()))
    log_step_size_state, _ = fun_mcmc.adam_step(
        log_step_size_state, loss_fn, learning_rate=rate)

    next_state = (hmc_state, log_step_size_state, step + 1)
    traced = (hmc_state.state_extra[0], hmc_extra.log_accept_ratio)
    return next_state, traced

  init_state = (
      fun_mcmc.HamiltonianMonteCarloState(state),
      fun_mcmc.adam_init(tf.math.log(step_size)),
      0,
  )
  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      state=init_state,
      fn=kernel,
      num_steps=num_adapt_steps + num_steps,
  )

  true_samples = target_dist.sample(4096, seed=_test_seed())
  return chain, log_accept_ratio_trace, true_samples
def computation(state):
  """Runs HMC with sign-adapted step size on a Softplus-pushed MVN.

  NOTE(review): `base_mean`, `base_cov`, `step_size`, `num_leapfrog_steps`,
  `num_adapt_steps` and `num_steps` are free variables from the enclosing
  scope (not visible here).

  Returns:
    A tuple `(chain, log_accept_ratio_trace, true_samples)`: the traced
    `state_extra[0]` values, the traced HMC log acceptance ratios, and 4096
    exact draws from the target distribution.
  """
  bijector = tfb.Softplus()
  target_dist = bijector(
      tfd.MultivariateNormalFullCovariance(
          loc=base_mean, covariance_matrix=base_cov))

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def kernel(hmc_state, current_step_size, step):
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=current_step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn)

    # Adaptation rate decays from 0.01 to 0 over the adaptation phase.
    rate = tf.compat.v1.train.polynomial_decay(
        0.01,
        global_step=step,
        power=0.5,
        decay_steps=num_adapt_steps,
        end_learning_rate=0.)

    accept_prob = tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio))
    mean_p_accept = tf.reduce_mean(accept_prob)

    # Nudge the step size so mean acceptance moves toward the 0.9 set point.
    next_step_size = fun_mcmc.sign_adaptation(
        current_step_size,
        output=mean_p_accept,
        set_point=0.9,
        adaptation_rate=rate)
    return (hmc_state, next_step_size, step + 1), hmc_extra

  def trace_fn(state, extra):
    # Record the traced sample and the log acceptance ratio for each step.
    return state[0].state_extra[0], extra.log_accept_ratio

  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      state=(fun_mcmc.HamiltonianMonteCarloState(state), step_size, 0),
      fn=kernel,
      num_steps=num_adapt_steps + num_steps,
      trace_fn=trace_fn)

  true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())
  return chain, log_accept_ratio_trace, true_samples
def testRunningApproximateAutoCovariance(self, state_shape, event_ndims,
                                         aggregation):
  """Running approximate autocovariance should match batch statistics.

  HMC on an (unnormalized) standard normal supplies the chain. While it is
  sampled, the running autocovariance state is updated with every new HMC
  state, and the result is compared against the variance and the
  autocorrelation computed from the full traced chain.
  """
  step_size = 0.2
  num_steps = 1000
  num_leapfrog_steps = 10
  max_lags = 300

  state = tf.constant(np.zeros(state_shape).astype(np.float32))

  def target_log_prob_fn(x):
    log_density = -0.5 * tf.square(x)
    if event_ndims is not None:
      # Treat the last axis as the event and sum over it.
      log_density = tf.reduce_sum(log_density, -1)
    return log_density, ()

  def kernel(hmc_state, raac_state, seed):
    # One HMC step followed by one running-autocovariance update.
    if backend.get_backend() != backend.TENSORFLOW:
      # Non-TF backends: split into a per-step seed and the next carried seed.
      hmc_seed, seed = util.split_seed(seed, 2)
    else:
      hmc_seed = _test_seed()
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    raac_state, _ = fun_mcmc.running_approximate_auto_covariance_step(
        raac_state, hmc_state.state, axis=aggregation)
    return (hmc_state, raac_state, seed), hmc_extra

  seed = _test_seed()
  if backend.get_backend() != backend.TENSORFLOW:
    seed = self._make_seed(seed)

  def run_chain(init_state, init_seed):
    # Trace only the HMC position; the autocovariance is carried in the state.
    trace_state = (
        fun_mcmc.HamiltonianMonteCarloState(init_state),
        fun_mcmc.running_approximate_auto_covariance_init(
            max_lags=max_lags,
            state_shape=state_shape,
            dtype=init_state.dtype,
            axis=aggregation),
        init_seed,
    )
    return fun_mcmc.trace(
        state=trace_state,
        fn=kernel,
        num_steps=num_steps,
        trace_fn=lambda state, extra: state[0].state)

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  (_, raac_state, _), chain = tf.function(run_chain)(state, seed)

  # Axis 0 of the traced chain indexes steps, so the aggregation axes shift
  # by one relative to a single state.
  if aggregation is None:
    shifted_axes = []
  else:
    shifted_axes = [a + 1 for a in util.flatten_tree(aggregation)]
  true_aggregation = (0,) + tuple(shifted_axes)

  chain_np = np.array(chain)
  true_variance = np.array(
      tf.math.reduce_variance(chain_np, true_aggregation))
  true_autocov = np.array(
      tfp.stats.auto_correlation(chain_np, axis=0, max_lags=max_lags))
  if aggregation is not None:
    true_autocov = tf.reduce_mean(true_autocov, shifted_axes)

  self.assertAllClose(true_variance, raac_state.auto_covariance[0], 1e-5)
  self.assertAllClose(
      true_autocov,
      raac_state.auto_covariance / raac_state.auto_covariance[0],
      atol=0.1)