def testMclachlanIntegratorStepReversible(self):

  def target_log_prob_fn(q):
    return -q**2, []

  def kinetic_energy_fn(p):
    return p**2., []

  seed = self._make_seed(_test_seed())

  state = 1.
  _, _, state_grads = fun_mcmc.call_potential_fn_with_grads(
      target_log_prob_fn,
      state,
  )

  # Take one forward step, then negate the momentum and take one step with
  # the time-reversed integrator; a reversible integrator should recover
  # the initial state.
  state_fwd, _ = _fwd_mclachlan_optimal_4th_order_step(
      integrator_step_state=fun_mcmc.IntegratorStepState(
          state=state,
          state_grads=state_grads,
          momentum=util.random_normal([], tf.float32, seed)),
      step_size=0.1,
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn)

  state_rev, _ = _rev_mclachlan_optimal_4th_order_step(
      integrator_step_state=state_fwd._replace(momentum=-state_fwd.momentum),
      step_size=0.1,
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn)

  self.assertAllClose(state, state_rev.state, atol=1e-6)
def testIntegratorStep(self, method, num_tlp_calls, num_tlp_calls_jax=None):
  tlp_call_counter = [0]

  def target_log_prob_fn(q):
    tlp_call_counter[0] += 1
    return -q**2, 1.

  def kinetic_energy_fn(p):
    return tf.abs(p)**3., 2.

  state = 1.
  _, _, state_grads = fun_mcmc.call_potential_fn_with_grads(
      target_log_prob_fn,
      state,
  )

  state, extras = method(
      integrator_step_state=fun_mcmc.IntegratorStepState(
          state=state, state_grads=state_grads, momentum=2.),
      step_size=0.1,
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn)

  if num_tlp_calls_jax is not None and self._is_on_jax:
    num_tlp_calls = num_tlp_calls_jax
  self.assertEqual(num_tlp_calls, tlp_call_counter[0])
  self.assertEqual(1., extras.state_extra)
  self.assertEqual(2., extras.kinetic_energy_extra)

  # A single step of a symplectic integrator should approximately conserve
  # the Hamiltonian: -log p(q) + K(p).
  initial_hamiltonian = -target_log_prob_fn(1.)[0] + kinetic_energy_fn(2.)[0]
  fin_hamiltonian = -target_log_prob_fn(state.state)[0] + kinetic_energy_fn(
      state.momentum)[0]

  self.assertAllClose(fin_hamiltonian, initial_hamiltonian, atol=0.2)
def testIntegratorStepReversible(self, method):

  def target_log_prob_fn(q):
    return -q**2, []

  def kinetic_energy_fn(p):
    return p**2., []

  seed = self._make_seed(_test_seed())

  state = self._constant(1.)
  _, _, state_grads = fun_mcmc.call_potential_fn_with_grads(
      target_log_prob_fn,
      state,
  )

  state_fwd, _ = method(
      integrator_step_state=fun_mcmc.IntegratorStepState(
          state=state,
          state_grads=state_grads,
          momentum=util.random_normal([], self._dtype, seed)),
      step_size=self._constant(0.1),
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn)

  state_rev, _ = method(
      integrator_step_state=state_fwd._replace(momentum=-state_fwd.momentum),
      step_size=self._constant(0.1),
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn)

  self.assertAllClose(state, state_rev.state, atol=1e-6)
def testSurrogateLossFnDecorator(self):

  @fun_mcmc.make_surrogate_loss_fn(loss_value=1.)
  def loss_fn(_):
    return 3., 2.

  ret, extra, grads = fun_mcmc.call_potential_fn_with_grads(loss_fn, 0.)
  self.assertAllClose(1., ret)
  self.assertAllClose(2., extra)
  self.assertAllClose(3., grads)
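# The contract exercised above: the decorated function returns
# `(grads, extra)`, and `make_surrogate_loss_fn` wraps it into a potential
# whose value is the fixed `loss_value` (1.), whose `extra` passes through
# unchanged (2.), and whose gradients are the returned `grads` (3.).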
def testRaggedIntegrator(self):
  # "Ragged" integration: each batch element takes a different number of
  # integrator steps.

  def target_log_prob_fn(q):
    return -q**2, q

  def kinetic_energy_fn(p):
    return tf.abs(p)**3., p

  integrator_fn = lambda state, num_steps: fun_mcmc.hamiltonian_integrator(  # pylint: disable=g-long-lambda
      state,
      num_steps=num_steps,
      integrator_step_fn=lambda state: fun_mcmc.leapfrog_step(  # pylint: disable=g-long-lambda
          state,
          step_size=0.1,
          target_log_prob_fn=target_log_prob_fn,
          kinetic_energy_fn=kinetic_energy_fn),
      kinetic_energy_fn=kinetic_energy_fn,
      integrator_trace_fn=lambda state, extra: (state, extra))

  state = tf.zeros([2])
  momentum = tf.ones([2])
  target_log_prob, _, state_grads = fun_mcmc.call_potential_fn_with_grads(
      target_log_prob_fn, state)

  start_state = fun_mcmc.IntegratorState(
      target_log_prob=target_log_prob,
      momentum=momentum,
      state=state,
      state_grads=state_grads,
      state_extra=state,
  )

  state_1 = integrator_fn(start_state, 1)
  state_2 = integrator_fn(start_state, 2)
  state_1_2 = integrator_fn(start_state, [1, 2])

  # Make sure integrators actually integrated to different points.
  self.assertFalse(np.all(state_1[0].state == state_2[0].state))

  # Ragged integration should be consistent with the non-ragged equivalent.
  def get_batch(state, idx):
    # For the integrator trace, we'll grab the final value.
    return util.map_tree(
        lambda x: x[idx] if len(x.shape) == 1 else x[-1, idx], state)

  self.assertAllClose(get_batch(state_1, 0), get_batch(state_1_2, 0))
  self.assertAllClose(get_batch(state_2, 0), get_batch(state_1_2, 1))

  # Ragged traces should be equal up to the number of steps for the batch
  # element.
  def get_slice(state, num, idx):
    return util.map_tree(lambda x: x[:num, idx], state[1].integrator_trace)

  self.assertAllClose(get_slice(state_1, 1, 0), get_slice(state_1_2, 1, 0))
  self.assertAllClose(get_slice(state_2, 2, 0), get_slice(state_1_2, 2, 1))
def testSurrogateLossFn(self, state):

  def grad_fn(*args, **kwargs):
    # This is uglier than user code due to the parameterized test...
    new_state = util.unflatten_tree(state, util.flatten_tree((args, kwargs)))
    return util.map_tree(lambda x: x + 1., new_state), new_state

  loss_fn = fun_mcmc.make_surrogate_loss_fn(grad_fn)

  # Mutate the state to make sure we didn't capture anything.
  state = util.map_tree(lambda x: x + 1., state)
  ret, extra, grads = fun_mcmc.call_potential_fn_with_grads(loss_fn, state)
  # The default loss value is 0.
  self.assertAllClose(0., ret)
  # The gradients of the surrogate loss are state + 1.
  self.assertAllClose(util.map_tree(lambda x: x + 1., state), grads)
  self.assertAllClose(state, extra)
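# For reference, a minimal sketch of the equivalent user code without the
# parameterized-test indirection; `my_grads_fn` is a hypothetical helper,
# not part of the fun_mcmc API:
#
#   def grad_fn(state):
#     # Return (surrogate gradients, extra) for a single-tensor state.
#     return my_grads_fn(state), state
#
#   loss_fn = fun_mcmc.make_surrogate_loss_fn(grad_fn)
#   ret, extra, grads = fun_mcmc.call_potential_fn_with_grads(loss_fn, state)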