def testMclachlanIntegratorStepReversible(self):
  """Checks that the McLachlan forward/reverse step pair is reversible.

  Starting from state 1. (with gradients computed up front) and a seeded
  random momentum, one forward step is taken. The reverse step is then applied
  to the result with its momentum negated, and the recovered state must match
  the starting state to within 1e-6.
  """

  def target_log_prob_fn(q):
    return -q**2, []

  def kinetic_energy_fn(p):
    return p**2., []

  # Arguments shared verbatim by the forward and the reverse step.
  shared_step_kwargs = dict(
      step_size=0.1,
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn,
  )

  seed = self._make_seed(_test_seed())
  start = 1.
  _, _, start_grads = fun_mcmc.call_potential_fn_with_grads(
      target_log_prob_fn,
      start,
  )
  initial = fun_mcmc.IntegratorStepState(
      state=start,
      state_grads=start_grads,
      momentum=util.random_normal([], tf.float32, seed))

  forward, _ = _fwd_mclachlan_optimal_4th_order_step(
      integrator_step_state=initial, **shared_step_kwargs)
  # Flip the momentum so the reverse step should retrace the trajectory.
  flipped = forward._replace(momentum=-forward.momentum)
  backward, _ = _rev_mclachlan_optimal_4th_order_step(
      integrator_step_state=flipped, **shared_step_kwargs)

  self.assertAllClose(start, backward.state, atol=1e-6)
def testIntegratorStep(self, method, num_tlp_calls, num_tlp_calls_jax=None):
  """Takes one integrator step from (q=1., p=2.) and checks its bookkeeping.

  Verifies that:
  * `target_log_prob_fn` was invoked the expected number of times. The counter
    also sees the call made while computing the initial gradients. On JAX,
    `num_tlp_calls_jax` overrides the expectation when it is given.
  * the extras returned by the two potential functions are passed through.
  * the Hamiltonian changes by no more than 0.2 over the step.
  """
  calls = [0]

  def target_log_prob_fn(q):
    calls[0] += 1
    return -q**2, 1.

  def kinetic_energy_fn(p):
    return tf.abs(p)**3., 2.

  def hamiltonian(q, p):
    return -target_log_prob_fn(q)[0] + kinetic_energy_fn(p)[0]

  q0 = 1.
  _, _, q0_grads = fun_mcmc.call_potential_fn_with_grads(
      target_log_prob_fn,
      q0,
  )
  stepped, extras = method(
      integrator_step_state=fun_mcmc.IntegratorStepState(
          state=q0, state_grads=q0_grads, momentum=2.),
      step_size=0.1,
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn)

  expected_calls = (
      num_tlp_calls_jax
      if num_tlp_calls_jax is not None and self._is_on_jax else num_tlp_calls)
  # Checked before the Hamiltonian evaluations below, which bump the counter.
  self.assertEqual(expected_calls, calls[0])
  self.assertEqual(1., extras.state_extra)
  self.assertEqual(2., extras.kinetic_energy_extra)

  energy_before = hamiltonian(1., 2.)
  energy_after = hamiltonian(stepped.state, stepped.momentum)
  self.assertAllClose(energy_after, energy_before, atol=0.2)
def testIntegratorStepReversible(self, method):
  """Checks that `method` retraces its step when the momentum is negated.

  One step is taken from state 1. with a seeded random momentum and no
  precomputed gradients. Stepping again with the momentum flipped must bring
  the state back to 1. within 1e-6.
  """

  def target_log_prob_fn(q):
    return -q**2, []

  def kinetic_energy_fn(p):
    return p**2., []

  shared_step_kwargs = dict(
      step_size=0.1,
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn,
  )

  seed = self._make_seed(tfp_test_util.test_seed())
  initial = fun_mcmc.IntegratorStepState(
      state=1.,
      state_grads=None,
      momentum=util.random_normal([], tf.float32, seed))

  forward, _ = method(integrator_step_state=initial, **shared_step_kwargs)
  backward, _ = method(
      integrator_step_state=forward._replace(momentum=-forward.momentum),
      **shared_step_kwargs)

  self.assertAllClose(1., backward.state, atol=1e-6)
def testIntegratorStep(self, method, num_tlp_calls):
  """Takes one step from (q=1., p=2.) without precomputed gradients.

  Asserts that `target_log_prob_fn` was called exactly `num_tlp_calls` times
  during the step, and that both extras are passed through. It also asserts
  that the Hamiltonian is conserved to within 0.2.
  """
  calls = [0]

  def target_log_prob_fn(q):
    calls[0] += 1
    return -q**2, 1.

  def kinetic_energy_fn(p):
    return tf.abs(p)**3., 2.

  def hamiltonian(q, p):
    return -target_log_prob_fn(q)[0] + kinetic_energy_fn(p)[0]

  stepped, extras = method(
      integrator_step_state=fun_mcmc.IntegratorStepState(
          state=1., state_grads=None, momentum=2.),
      step_size=0.1,
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn)

  # Counter is checked before hamiltonian() adds further calls.
  self.assertEqual(num_tlp_calls, calls[0])
  self.assertEqual(1., extras.state_extra)
  self.assertEqual(2., extras.kinetic_energy_extra)

  energy_before = hamiltonian(1., 2.)
  energy_after = hamiltonian(stepped.state, stepped.momentum)
  self.assertAllClose(energy_after, energy_before, atol=0.2)
def testIntegratorStepReversible(self, method):
  """Checks that `method` is reversible under momentum negation.

  The start state, step size and momentum dtype come from the test's
  `_constant`/`_dtype` helpers. Gradients at the start are computed up front.
  A forward step followed by a step with the momentum flipped must return to
  the start state within 1e-6.
  """

  def target_log_prob_fn(q):
    return -q**2, []

  def kinetic_energy_fn(p):
    return p**2., []

  seed = self._make_seed(_test_seed())
  start = self._constant(1.)
  _, _, start_grads = fun_mcmc.call_potential_fn_with_grads(
      target_log_prob_fn,
      start,
  )
  initial = fun_mcmc.IntegratorStepState(
      state=start,
      state_grads=start_grads,
      momentum=util.random_normal([], self._dtype, seed))

  forward, _ = method(
      integrator_step_state=initial,
      step_size=self._constant(0.1),
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn)
  backward, _ = method(
      integrator_step_state=forward._replace(momentum=-forward.momentum),
      step_size=self._constant(0.1),
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn)

  self.assertAllClose(start, backward.state, atol=1e-6)
def testMclachlanIntegratorStepReversible(self):
  """Checks that the McLachlan forward/reverse step pair is reversible.

  A forward step is taken from state 1. (no precomputed gradients) with a
  seeded random momentum. The reverse step is then applied to the result with
  its momentum negated, and the recovered state must equal 1. within 1e-6.
  """

  def target_log_prob_fn(q):
    return -q**2, []

  def kinetic_energy_fn(p):
    return p**2., []

  # Draw the momentum from a fixed test seed, as the other reversibility tests
  # do. The previous unseeded `tf.random.normal([])` made failures
  # non-reproducible.
  seed = self._make_seed(_test_seed())
  state_fwd, _ = _fwd_mclachlan_optimal_4th_order_step(
      integrator_step_state=fun_mcmc.IntegratorStepState(
          state=1.,
          state_grads=None,
          momentum=util.random_normal([], tf.float32, seed)),
      step_size=0.1,
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn)
  # Negating the momentum should make the reverse step retrace the trajectory.
  state_rev, _ = _rev_mclachlan_optimal_4th_order_step(
      integrator_step_state=state_fwd._replace(
          momentum=-state_fwd.momentum),
      step_size=0.1,
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn)
  self.assertAllClose(1., state_rev.state, atol=1e-6)