def computation(state, seed):
  bijector = tfp.bijectors.Softplus()
  base_dist = tfp.distributions.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  # Reparameterize the target through the bijector so HMC runs in the
  # unconstrained space.
  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def kernel(hmc_state, step_size_state, step, seed):
    # JAX threads explicit PRNG keys through the kernel; the TF backend
    # uses the stateful test seed.
    if not self._is_on_jax:
      hmc_seed = _test_seed()
    else:
      hmc_seed, seed = util.split_seed(seed, 2)
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=tf.exp(step_size_state.state),
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)

    # Adapt the log step size towards a 0.9 acceptance rate, with a
    # polynomially decaying adaptation learning rate.
    rate = fun_mcmc.prefab._polynomial_decay(  # pylint: disable=protected-access
        step=step,
        step_size=self._constant(0.01),
        power=0.5,
        decay_steps=num_adapt_steps,
        final_step_size=0.)
    mean_p_accept = tf.reduce_mean(
        tf.exp(tf.minimum(self._constant(0.), hmc_extra.log_accept_ratio)))
    loss_fn = fun_mcmc.make_surrogate_loss_fn(
        lambda _: (0.9 - mean_p_accept, ()))
    step_size_state, _ = fun_mcmc.adam_step(
        step_size_state, loss_fn, learning_rate=rate)

    return ((hmc_state, step_size_state, step + 1, seed),
            (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))

  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
             fun_mcmc.adam_init(tf.math.log(step_size)), 0, seed),
      fn=kernel,
      num_steps=num_adapt_steps + num_steps,
  )
  true_samples = target_dist.sample(
      4096, seed=self._make_seed(_test_seed()))
  return chain, log_accept_ratio_trace, true_samples
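# A minimal sketch (not the fun_mcmc implementation) of the adaptation rule
# the surrogate loss above encodes, assuming make_surrogate_loss_fn yields a
# loss whose gradient w.r.t. the log step size is the returned value
# (0.9 - mean_p_accept); adam_step then applies Adam's momentum and
# normalization on top of this plain descent step. All names below are local
# to the sketch.
import tensorflow.compat.v2 as tf


def adapt_log_step_size(log_step_size, mean_p_accept, learning_rate):
  # Acceptance below the 0.9 target gives a positive gradient, so the
  # descent step shrinks the step size (and vice versa).
  grad = 0.9 - mean_p_accept
  return log_step_size - learning_rate * grad


log_step_size = tf.math.log(tf.constant(0.5))
log_step_size = adapt_log_step_size(
    log_step_size, mean_p_accept=tf.constant(0.6), learning_rate=0.01)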
def testAdam(self):

  def loss_fn(x, y):
    return tf.square(x - 1.) + tf.square(y - 2.), []

  _, [(x, y), loss] = fun_mcmc.trace(
      fun_mcmc.adam_init([tf.zeros([]), tf.zeros([])]),
      lambda adam_state: fun_mcmc.adam_step(  # pylint: disable=g-long-lambda
          adam_state, loss_fn, learning_rate=0.01),
      num_steps=1000,
      trace_fn=lambda state, extra: [state.state, extra.loss])

  self.assertAllClose(1., x[-1], atol=1e-3)
  self.assertAllClose(2., y[-1], atol=1e-3)
  self.assertAllClose(0., loss[-1], atol=1e-3)
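# A rough sketch of what fun_mcmc.trace does in testAdam above: drive
# adam_step in an explicit Python loop, collecting the traced values by hand.
# Assumes the same fun_mcmc module and eager-mode TF; the loop body reuses
# only calls that appear in the test itself.
import tensorflow.compat.v2 as tf


def loss_fn(x, y):
  return tf.square(x - 1.) + tf.square(y - 2.), []


adam_state = fun_mcmc.adam_init([tf.zeros([]), tf.zeros([])])
trace = []
for _ in range(1000):
  adam_state, extra = fun_mcmc.adam_step(
      adam_state, loss_fn, learning_rate=0.01)
  trace.append([adam_state.state, extra.loss])
# trace[-1] approaches [[1., 2.], 0.], matching the assertions above.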
def computation(state):
  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  # Reparameterize the target through the bijector so HMC runs in the
  # unconstrained space.
  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def kernel(hmc_state, step_size_state, step):
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=tf.exp(step_size_state.state),
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn)

    # Adapt the log step size towards a 0.9 acceptance rate, with a
    # polynomially decaying adaptation learning rate.
    rate = tf.compat.v1.train.polynomial_decay(
        0.01,
        global_step=step,
        power=0.5,
        decay_steps=num_adapt_steps,
        end_learning_rate=0.)
    mean_p_accept = tf.reduce_mean(
        tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))
    loss_fn = fun_mcmc.make_surrogate_loss_fn(
        lambda _: (0.9 - mean_p_accept, ()))
    step_size_state, _ = fun_mcmc.adam_step(
        step_size_state, loss_fn, learning_rate=rate)

    return ((hmc_state, step_size_state, step + 1),
            (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))

  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
             fun_mcmc.adam_init(tf.math.log(step_size)), 0),
      fn=kernel,
      num_steps=num_adapt_steps + num_steps,
  )
  true_samples = target_dist.sample(4096, seed=_test_seed())
  return chain, log_accept_ratio_trace, true_samples
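# A hedged sketch of the schedule tf.compat.v1.train.polynomial_decay
# computes above (with cycle=False): the rate decays from 0.01 to 0. over
# num_adapt_steps with power 0.5. Clipping step to decay_steps matches TF's
# documented behavior; names below are local to the sketch.
import tensorflow.compat.v2 as tf


def polynomial_decay(step, initial_rate=0.01, final_rate=0.,
                     decay_steps=1000, power=0.5):
  frac = tf.minimum(tf.cast(step, tf.float32) / decay_steps, 1.)
  return (initial_rate - final_rate) * (1. - frac)**power + final_rate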