def testIntegratorStep(self, method, num_tlp_calls, num_tlp_calls_jax=None):
  """Runs one integrator step and checks call counts, extras and energy.

  `method` is called once from state=1., momentum=2. with step size 0.1. The
  test then checks three things: how many times the target log-prob was
  evaluated (`num_tlp_calls_jax` overrides `num_tlp_calls` on the JAX
  backend when given), that both potentials' extras are forwarded, and that
  the Hamiltonian changed by at most 0.2.
  """
  evaluations = [0]

  def target_log_prob_fn(q):
    evaluations[0] += 1
    return -q**2, 1.

  def kinetic_energy_fn(p):
    return tf.abs(p)**3., 2.

  def hamiltonian(q, p):
    return -target_log_prob_fn(q)[0] + kinetic_energy_fn(p)[0]

  start = fun_mcmc.IntegratorStepState(
      state=1., state_grads=None, momentum=2.)
  state, extras = method(
      integrator_step_state=start,
      step_size=0.1,
      target_log_prob_fn=target_log_prob_fn,
      kinetic_energy_fn=kinetic_energy_fn)

  # The backend is only queried when a JAX-specific count was supplied.
  use_jax_count = (
      num_tlp_calls_jax is not None and
      backend.get_backend() == backend.JAX)
  expected_calls = num_tlp_calls_jax if use_jax_count else num_tlp_calls
  self.assertEqual(expected_calls, evaluations[0])

  self.assertEqual(1., extras.state_extra)
  self.assertEqual(2., extras.kinetic_energy_extra)

  initial_hamiltonian = hamiltonian(1., 2.)
  fin_hamiltonian = hamiltonian(state.state, state.momentum)
  self.assertAllClose(fin_hamiltonian, initial_hamiltonian, atol=0.2)
def kernel(rwm_state, seed):
  """Performs one random-walk Metropolis transition.

  On the TensorFlow backend a fresh test seed is used for the step and `seed`
  passes through unchanged. On other backends `seed` is split into a step
  seed and the seed carried to the next iteration.

  Returns:
    `((new_rwm_state, next_seed), rwm_extra)`.
  """
  if backend.get_backend() != backend.TENSORFLOW:
    step_seed, seed = util.split_seed(seed, 2)
  else:
    step_seed = tfp_test_util.test_seed()
  next_state, step_extra = fun_mcmc.random_walk_metropolis(
      rwm_state,
      target_log_prob_fn=target_log_prob_fn,
      proposal_fn=proposal_fn,
      seed=step_seed)
  return (next_state, seed), step_extra
def kernel(hmc_state, seed):
  """Performs one Hamiltonian Monte Carlo transition.

  On the TensorFlow backend a fresh test seed drives the step and `seed`
  passes through unchanged. Elsewhere `seed` is split into a step seed and
  the seed carried forward.

  Returns:
    `((new_hmc_state, next_seed), hmc_extra)`.
  """
  if backend.get_backend() != backend.TENSORFLOW:
    step_seed, seed = util.split_seed(seed, 2)
  else:
    step_seed = tfp_test_util.test_seed()
  next_state, step_extra = fun_mcmc.hamiltonian_monte_carlo(
      hmc_state,
      step_size=step_size,
      num_integrator_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn,
      seed=step_seed)
  return (next_state, seed), step_extra
def testBasicHMC(self):
  """HMC on a diagonal Gaussian reproduces its mean and variance.

  A batch of 16 two-dimensional chains runs for 2000 steps. The first 1000
  steps are discarded, and the remaining samples' mean and variance are
  compared against 4096 exact draws from the same Gaussian.
  """
  step_size = 0.2
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2])

  base_mean = tf.constant([2., 3.])
  base_scale = tf.constant([2., 0.5])

  def target_log_prob_fn(x):
    standardized = (x - base_mean) / base_scale
    return -tf.reduce_sum(0.5 * tf.square(standardized), -1), ()

  def kernel(hmc_state, seed):
    if backend.get_backend() != backend.TENSORFLOW:
      step_seed, seed = util.split_seed(seed, 2)
    else:
      step_seed = tfp_test_util.test_seed()
    next_state, step_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=step_seed)
    return (next_state, seed), step_extra

  seed = tfp_test_util.test_seed()
  if backend.get_backend() != backend.TENSORFLOW:
    seed = self._make_seed(seed)

  def trace_fn(state, extra):
    del extra  # Unused.
    return state[0].state

  def run_chain(state, seed):
    return fun_mcmc.trace(
        state=(fun_mcmc.HamiltonianMonteCarloState(state), seed),
        fn=kernel,
        num_steps=num_steps,
        trace_fn=trace_fn)

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  _, chain = tf.function(run_chain)(state, seed)

  # Discard the warmup samples.
  chain = chain[1000:]

  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_var = tf.math.reduce_variance(chain, axis=[0, 1])

  standard_draws = util.random_normal(
      shape=[4096, 2], dtype=tf.float32, seed=seed)
  true_samples = standard_draws * base_scale + base_mean
  true_mean = tf.reduce_mean(true_samples, axis=0)
  true_var = tf.math.reduce_variance(true_samples, axis=0)

  for expected, actual in ((true_mean, sample_mean), (true_var, sample_var)):
    self.assertAllClose(expected, actual, rtol=0.1, atol=0.1)
def kernel(hmc_state, raac_state, seed):
  """Performs one HMC step and folds the new state into the running auto-covariance.

  On the TensorFlow backend the step seed comes from `_test_seed()` and `seed`
  passes through unchanged. Elsewhere `seed` is split.

  Returns:
    `((new_hmc_state, new_raac_state, next_seed), hmc_extra)`.
  """
  if backend.get_backend() != backend.TENSORFLOW:
    step_seed, seed = util.split_seed(seed, 2)
  else:
    step_seed = _test_seed()
  next_hmc_state, step_extra = fun_mcmc.hamiltonian_monte_carlo(
      hmc_state,
      step_size=step_size,
      num_integrator_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn,
      seed=step_seed)
  next_raac_state, _ = fun_mcmc.running_approximate_auto_covariance_step(
      raac_state, next_hmc_state.state, axis=aggregation)
  return (next_hmc_state, next_raac_state, seed), step_extra
def testRandomWalkMetropolis(self):
  """RWM with a state-independent categorical proposal matches the target pmf.

  Sixteen integer chains over 4 categories run for 1000 steps. After
  discarding 500 warmup steps, the empirical category frequencies are
  compared against softmax(target_logits).
  """
  num_steps = 1000
  state = tf.ones([16], dtype=tf.int32)
  target_logits = tf.constant([1., 2., 3., 4.]) + 2.
  proposal_logits = tf.constant([4., 3., 2., 1.]) + 2.

  def target_log_prob_fn(x):
    return tf.gather(target_logits, x), ()

  def proposal_fn(x, seed):
    # The draw depends only on `x.shape[0]`, not on the values of `x`. The
    # returned bias is the difference of proposal logits at the proposed and
    # current states (presumably the log proposal-ratio correction that
    # random_walk_metropolis expects — confirm against its docs).
    current_logits = tf.gather(proposal_logits, x)
    draws = util.random_categorical(proposal_logits[tf.newaxis], x.shape[0],
                                    seed)[0]
    proposed_logits = tf.gather(proposal_logits, draws)
    log_bias = proposed_logits - current_logits
    return tf.cast(draws, x.dtype), ((), log_bias)

  def kernel(rwm_state, seed):
    if backend.get_backend() != backend.TENSORFLOW:
      step_seed, seed = util.split_seed(seed, 2)
    else:
      step_seed = tfp_test_util.test_seed()
    next_state, step_extra = fun_mcmc.random_walk_metropolis(
        rwm_state,
        target_log_prob_fn=target_log_prob_fn,
        proposal_fn=proposal_fn,
        seed=step_seed)
    return (next_state, seed), step_extra

  seed = tfp_test_util.test_seed()
  if backend.get_backend() != backend.TENSORFLOW:
    seed = self._make_seed(seed)

  def run_chain(state, seed):
    initial_rwm_state = fun_mcmc.random_walk_metropolis_init(
        state, target_log_prob_fn)
    return fun_mcmc.trace(
        state=(initial_rwm_state, seed),
        fn=kernel,
        num_steps=num_steps,
        trace_fn=lambda state, extra: state[0].state)

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  _, chain = tf.function(run_chain)(state, seed)

  # Discard the warmup samples.
  chain = chain[500:]

  sample_mean = tf.reduce_mean(tf.one_hot(chain, 4), axis=[0, 1])
  expected_pmf = tf.nn.softmax(target_logits)
  self.assertAllClose(expected_pmf, sample_mean, atol=0.1)
def trace(
    state: State,
    fn: TransitionOperator,
    num_steps: IntTensor,
    trace_fn: Callable[[State, TensorNest], TensorNest],
    parallel_iterations: int = 10,
) -> Tuple[State, TensorNest]:
  """`TransitionOperator` that runs `fn` repeatedly and traces its outputs.

  Args:
    state: A nest of `Tensor`s or None.
    fn: A `TransitionOperator`.
    num_steps: Number of steps to run the function for. Must be at least 1.
    trace_fn: Callable that takes the unpacked outputs of `fn` (the new state
      and its extra outputs) and returns a nest of `Tensor`s. These will be
      stacked and returned.
    parallel_iterations: Number of iterations of the while loop to run in
      parallel.

  Returns:
    state: The final state returned by `fn`.
    traces: Stacked outputs of `trace_fn`.
  """
  state = util.map_tree(
      lambda t: (t if t is None else tf.convert_to_tensor(t)), state)

  def wrapper(state):
    state, extra = util.map_tree(tf.convert_to_tensor,
                                 call_transition_operator(fn, state))
    trace_element = util.map_tree(tf.convert_to_tensor, trace_fn(state, extra))
    return state, trace_element

  # JAX tracing/pre-compilation isn't as stable as TF's, so we won't use it to
  # start.
  if (backend.get_backend() != backend.TENSORFLOW or
      any(e is None for e in util.flatten_tree(state)) or
      tf.executing_eagerly()):
    # The dtypes/shapes of the trace elements are only known after running
    # `fn` once, so the first step runs outside the loop and fills slot 0.
    state, first_trace = wrapper(state)
    trace_arrays = util.map_tree(
        lambda v: util.write_dynamic_array(  # pylint: disable=g-long-lambda
            util.make_dynamic_array(
                v.dtype, size=num_steps, element_shape=v.shape), 0, v),
        first_trace)
    start_idx = 1
  else:
    state_spec = util.map_tree(tf.TensorSpec.from_tensor, state)
    # We need the shapes and dtypes of the outputs of `wrapper` function to
    # create the `TensorArray`s, we can get it by pre-compiling the wrapper
    # function.
    wrapper = tf.function(autograph=False)(wrapper)
    concrete_wrapper = wrapper.get_concrete_function(state_spec)
    _, trace_dtypes = concrete_wrapper.output_dtypes
    _, trace_shapes = concrete_wrapper.output_shapes
    trace_arrays = util.map_tree(
        lambda dtype, shape: tf.TensorArray(  # pylint: disable=g-long-lambda
            dtype, size=num_steps, element_shape=shape),
        trace_dtypes,
        trace_shapes)
    wrapper = lambda state: concrete_wrapper(*util.flatten_tree(state))
    # Nothing has been written yet, so the loop starts at slot 0.
    start_idx = 0

  def body(i, state, trace_arrays):
    state, trace_element = wrapper(state)
    trace_arrays = util.map_tree(lambda a, v: util.write_dynamic_array(a, i, v),
                                 trace_arrays, trace_element)
    return i + 1, state, trace_arrays

  def cond(i, *_):
    return i < num_steps

  _, state, trace_arrays = tf.while_loop(
      cond=cond,
      body=body,
      loop_vars=(start_idx, state, trace_arrays),
      parallel_iterations=parallel_iterations)

  stacked_trace = util.map_tree(util.snapshot_dynamic_array, trace_arrays)

  # TensorFlow often loses the static shape information.
  if backend.get_backend() == backend.TENSORFLOW:
    # None when `num_steps` is not statically known; the leading dimension is
    # then left unknown while the trailing dimensions are kept.
    static_length = tf.get_static_value(num_steps)

    def _merge_static_length(x):
      x.set_shape(tf.TensorShape([static_length]).concatenate(x.shape[1:]))
      return x

    stacked_trace = util.map_tree(_merge_static_length, stacked_trace)

  return state, stacked_trace
def _wrapper(self, *args, **kwargs):
  """Calls `fn` unless the active backend is JAX.

  `fn` is a free variable (presumably the decorated test method from an
  enclosing decorator — not visible here). Under JAX nothing is called and
  None is returned.
  """
  on_jax = backend.get_backend() == backend.JAX
  if on_jax:
    return None
  return fn(self, *args, **kwargs)
def testRunningApproximateAutoCovariance(self, state_shape, event_ndims,
                                         aggregation):
  """Running auto-covariance agrees with batch statistics on an HMC chain.

  An HMC chain on a standard Gaussian is sampled while the running
  approximate auto-covariance is updated at every step. The running lag-0
  value is compared with the chain's variance, and the running
  auto-covariance normalized by lag 0 is compared with
  `tfp.stats.auto_correlation` of the traced chain.
  """
  # We'll use HMC as the source of our chain.
  # While HMC is being sampled, we also compute the running autocovariance.
  step_size = 0.2
  num_steps = 1000
  num_leapfrog_steps = 10
  max_lags = 300

  state = tf.constant(np.zeros(state_shape).astype(np.float32))

  def target_log_prob_fn(x):
    lp = -0.5 * tf.square(x)
    if event_ndims is not None:
      lp = tf.reduce_sum(lp, -1)
    return lp, ()

  def kernel(hmc_state, raac_state, seed):
    if backend.get_backend() != backend.TENSORFLOW:
      step_seed, seed = util.split_seed(seed, 2)
    else:
      step_seed = _test_seed()
    next_hmc_state, step_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=step_seed)
    next_raac_state, _ = fun_mcmc.running_approximate_auto_covariance_step(
        raac_state, next_hmc_state.state, axis=aggregation)
    return (next_hmc_state, next_raac_state, seed), step_extra

  seed = _test_seed()
  if backend.get_backend() != backend.TENSORFLOW:
    seed = self._make_seed(seed)

  def run_chain(state, seed):
    initial_hmc_state = fun_mcmc.HamiltonianMonteCarloState(state)
    initial_raac_state = fun_mcmc.running_approximate_auto_covariance_init(
        max_lags=max_lags,
        state_shape=state_shape,
        dtype=state.dtype,
        axis=aggregation)
    return fun_mcmc.trace(
        state=(initial_hmc_state, initial_raac_state, seed),
        fn=kernel,
        num_steps=num_steps,
        trace_fn=lambda state, extra: state[0].state)

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  (_, raac_state, _), chain = tf.function(run_chain)(state, seed)

  # Axis 0 of `chain` is time, so user aggregation axes shift by one.
  aggregation_axes = () if aggregation is None else tuple(
      a + 1 for a in util.flatten_tree(aggregation))
  true_aggregation = (0,) + aggregation_axes

  chain_np = np.array(chain)
  true_variance = np.array(tf.math.reduce_variance(chain_np, true_aggregation))
  true_autocov = np.array(
      tfp.stats.auto_correlation(chain_np, axis=0, max_lags=max_lags))
  if aggregation is not None:
    true_autocov = tf.reduce_mean(true_autocov, list(aggregation_axes))

  self.assertAllClose(true_variance, raac_state.auto_covariance[0], 1e-5)
  normalized_autocov = (
      raac_state.auto_covariance / raac_state.auto_covariance[0])
  self.assertAllClose(true_autocov, normalized_autocov, atol=0.1)