def kernel(hmc_state, step_size_state, step, seed):
  """Runs one HMC transition followed by one Adam update of the step size.

  The HMC step size is `exp(step_size_state.state)`, i.e. the Adam state holds
  the log of the step size. The adaptation signal handed to Adam is
  `0.9 - mean_p_accept`, where `mean_p_accept` is the mean of
  `exp(min(0, log_accept_ratio))` from the HMC step just taken.

  Args:
    hmc_state: State passed to `fun_mcmc.hamiltonian_monte_carlo`.
    step_size_state: State passed to `fun_mcmc.adam_step`.
    step: Current step counter, used by the learning-rate decay schedule.
    seed: Random seed; only consumed when running on JAX.

  Returns:
    A pair `(next_args, traced)`, where `next_args` is
    `(hmc_state, step_size_state, step + 1, seed)` and `traced` is
    `(hmc_state.state_extra[0], log_accept_ratio)`.
  """
  if self._is_on_jax:
    # Split the incoming seed; `seed` is rebound to the second half and
    # returned for the next call.
    hmc_seed, seed = util.split_seed(seed, 2)
  else:
    # Off JAX the incoming `seed` is passed through unchanged.
    hmc_seed = _test_seed()

  current_step_size = tf.exp(step_size_state.state)
  next_hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
      hmc_state,
      step_size=current_step_size,
      num_integrator_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn,
      seed=hmc_seed)

  # Learning rate decays polynomially from 0.01 to 0 over `num_adapt_steps`.
  adapt_rate = fun_mcmc.prefab._polynomial_decay(  # pylint: disable=protected-access
      step=step,
      step_size=self._constant(0.01),
      power=0.5,
      decay_steps=num_adapt_steps,
      final_step_size=0.)

  log_accept_ratio = hmc_extra.log_accept_ratio
  # Clip at zero so each acceptance probability is at most one.
  clipped_log_ratio = tf.minimum(self._constant(0.), log_accept_ratio)
  mean_p_accept = tf.reduce_mean(tf.exp(clipped_log_ratio))

  def _acceptance_gap(_):
    # The argument is ignored; the value does not depend on it directly.
    return 0.9 - mean_p_accept, ()

  loss_fn = fun_mcmc.make_surrogate_loss_fn(_acceptance_gap)
  next_step_size_state, _ = fun_mcmc.adam_step(
      step_size_state, loss_fn, learning_rate=adapt_rate)

  next_args = (next_hmc_state, next_step_size_state, step + 1, seed)
  traced = (next_hmc_state.state_extra[0], hmc_extra.log_accept_ratio)
  return next_args, traced
def testAdam(self):
  # Runs 1000 Adam steps (learning rate 0.01) from (0, 0) on the loss
  # (x - 1)^2 + (y - 2)^2 and checks that the final iterate is at (1, 2)
  # with zero loss, all within an absolute tolerance of 1e-3.

  def loss_fn(x, y):
    # Returns (loss, extra); there is nothing extra to report.
    return tf.square(x - 1.) + tf.square(y - 2.), []

  def adam_kernel(adam_state):
    return fun_mcmc.adam_step(adam_state, loss_fn, learning_rate=0.01)

  def trace_fn(state, extra):
    return [state.state, extra.loss]

  initial_state = fun_mcmc.adam_init([tf.zeros([]), tf.zeros([])])
  _, traced = fun_mcmc.trace(
      initial_state,
      adam_kernel,
      num_steps=1000,
      trace_fn=trace_fn)
  (x, y), loss = traced

  # Only the last traced step is checked.
  self.assertAllClose(1., x[-1], atol=1e-3)
  self.assertAllClose(2., y[-1], atol=1e-3)
  self.assertAllClose(0., loss[-1], atol=1e-3)
def kernel(hmc_state, step_size_state, step):
  """Runs one HMC transition followed by one Adam update of the step size.

  The HMC step size is `exp(step_size_state.state)`, i.e. the Adam state holds
  the log of the step size. The adaptation signal handed to Adam is
  `0.9 - mean_p_accept`, where `mean_p_accept` is the mean of
  `exp(min(0, log_accept_ratio))` from the HMC step just taken.

  Args:
    hmc_state: State passed to `fun_mcmc.hamiltonian_monte_carlo`.
    step_size_state: State passed to `fun_mcmc.adam_step`.
    step: Current step counter, used as the decay schedule's `global_step`.

  Returns:
    A pair `(next_args, traced)`, where `next_args` is
    `(hmc_state, step_size_state, step + 1)` and `traced` is
    `(hmc_state.state_extra[0], log_accept_ratio)`.
  """
  current_step_size = tf.exp(step_size_state.state)
  next_hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
      hmc_state,
      step_size=current_step_size,
      num_integrator_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn)

  # Learning rate decays polynomially from 0.01 to 0 over `num_adapt_steps`.
  adapt_rate = tf.compat.v1.train.polynomial_decay(
      0.01,
      global_step=step,
      power=0.5,
      decay_steps=num_adapt_steps,
      end_learning_rate=0.)

  log_accept_ratio = hmc_extra.log_accept_ratio
  # Clip at zero so each acceptance probability is at most one.
  clipped_log_ratio = tf.minimum(0., log_accept_ratio)
  mean_p_accept = tf.reduce_mean(tf.exp(clipped_log_ratio))

  def _acceptance_gap(_):
    # The argument is ignored; the value does not depend on it directly.
    return 0.9 - mean_p_accept, ()

  loss_fn = fun_mcmc.make_surrogate_loss_fn(_acceptance_gap)
  next_step_size_state, _ = fun_mcmc.adam_step(
      step_size_state, loss_fn, learning_rate=adapt_rate)

  next_args = (next_hmc_state, next_step_size_state, step + 1)
  traced = (next_hmc_state.state_extra[0], hmc_extra.log_accept_ratio)
  return next_args, traced