def testWrapTransitionKernel(self):

  class TestKernel(tfp.mcmc.TransitionKernel):

    def one_step(self, current_state, previous_kernel_results):
      return [x + 1 for x in current_state], previous_kernel_results + 1

    def bootstrap_results(self, current_state):
      return sum(current_state)

    def is_calibrated(self):
      return True

  def kernel(state, pkr):
    return util_tfp.transition_kernel_wrapper(state, pkr, TestKernel())

  state = {'x': self._constant(0.), 'y': self._constant(1.)}
  kr = 1.
  (final_state, final_kr), _ = fun_mc.trace(
      (state, kr),
      kernel,
      2,
      trace_fn=lambda *args: (),
  )
  self.assertAllEqual({
      'x': 2.,
      'y': 3.
  }, util.map_tree(np.array, final_state))
  self.assertAllEqual(1. + 2., final_kr)
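
# A minimal sketch (illustrative only, not from the library's tests) of wrapping
# a real TFP kernel instead of the toy `TestKernel` above. Assumptions: the
# initial kernel results can be taken from the TFP kernel's own
# `bootstrap_results`, and no explicit seed is threaded through the wrapper (as
# in the test above), so this sketch is TF-oriented; under JAX a seed would have
# to be plumbed in explicitly. The helper name is hypothetical.
def _example_wrap_random_walk_metropolis():
  # TFP kernels expect a plain log-prob, without the `extra` tuple that fun_mc
  # potential functions return.
  target_log_prob_fn = lambda x: -0.5 * tf.reduce_sum(tf.square(x), -1)
  rwm_kernel = tfp.mcmc.RandomWalkMetropolis(target_log_prob_fn)

  def kernel(state, pkr):
    return util_tfp.transition_kernel_wrapper(state, pkr, rwm_kernel)

  state = tf.zeros([4, 2])
  pkr = rwm_kernel.bootstrap_results(state)
  (final_state, final_pkr), _ = fun_mc.trace(
      (state, pkr),
      kernel,
      10,
      trace_fn=lambda *args: (),
  )
  return final_state, final_pkr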
def testInteractiveIterationAxis1(self):

  def kernel(x):
    return x + 1, x

  state, trace = prefab.interactive_trace(
      0.,
      lambda x: fun_mc.trace(x, kernel, 5),
      20,
      iteration_axis=1,
      progress_bar_fn=None)

  self.assertAllClose(100., state)
  self.assertEqual([100], list(trace.shape))
  self.assertAllClose(99., trace[-1])
def testPHMC(self):
  step_size = self._constant(0.2)
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2], dtype=self._dtype)

  base_mean = self._constant([2., 3.])
  base_scale = self._constant([2., 0.5])

  def target_log_prob_fn(x):
    return -tf.reduce_sum(
        0.5 * tf.square((x - base_mean) / base_scale), -1), ()

  def kernel(phmc_state, seed):
    phmc_seed, seed = util.split_seed(seed, 2)
    phmc_state, _ = prefab.persistent_hamiltonian_monte_carlo_step(
        phmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        noise_fraction=self._constant(0.5),
        mh_drift=self._constant(0.127),
        seed=phmc_seed)
    return (phmc_state, seed), phmc_state.state

  seed = self._make_seed(_test_seed())

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  _, chain = tf.function(lambda state, seed: fun_mc.trace(  # pylint: disable=g-long-lambda
      state=(prefab.persistent_hamiltonian_monte_carlo_init(
          state, target_log_prob_fn), seed),
      fn=kernel,
      num_steps=num_steps))(state, seed)
  # Discard the warmup samples.
  chain = chain[1000:]

  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_var = tf.math.reduce_variance(chain, axis=[0, 1])

  true_samples = util.random_normal(
      shape=[4096, 2], dtype=self._dtype, seed=seed) * base_scale + base_mean

  true_mean = tf.reduce_mean(true_samples, axis=0)
  true_var = tf.math.reduce_variance(true_samples, axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_var, sample_var, rtol=0.1, atol=0.1)
def testSGAHMC(self):

  @tfd.JointDistributionCoroutine
  def model():
    x = yield Root(tfd.Normal(self._constant(0.), 1.))
    yield tfd.Sample(tfd.Normal(x, 1.), 2)

  def target_log_prob_fn(x):
    return model.log_prob(x), ()

  @tf.function
  def kernel(sga_hmc_state, step, seed):
    adapt = step < num_adapt_steps
    seed, hmc_seed = util.split_seed(seed, 2)
    sga_hmc_state, sga_hmc_extra = sga_hmc.stochastic_gradient_ascent_hmc_step(
        sga_hmc_state,
        scalar_step_size=0.1,
        target_log_prob_fn=target_log_prob_fn,
        criterion_fn=sga_hmc.chees_criterion,
        adapt=adapt,
        seed=hmc_seed,
    )

    return (sga_hmc_state, step + 1, seed
           ), sga_hmc_extra.trajectory_length_params.mean_trajectory_length()

  init_trajectory_length = self._constant(0.1)
  num_adapt_steps = 10
  _, trajectory_length = fun_mc.trace(
      (sga_hmc.stochastic_gradient_ascent_hmc_init(
          util.map_tree_up_to(
              model.dtype,
              lambda dtype, shape: tf.zeros(  # pylint: disable=g-long-lambda
                  (16,) + tuple(shape), dtype),
              model.dtype,
              model.event_shape),
          target_log_prob_fn,
          init_trajectory_length=init_trajectory_length), 0,
       self._make_seed(_test_seed())), kernel, num_adapt_steps + 2)

  # We expect it to increase as part of adaptation.
  self.assertAllGreater(trajectory_length[-1], init_trajectory_length)
  # After adaptation is done, the trajectory length should remain constant.
  self.assertAllClose(trajectory_length[-1], trajectory_length[-2])
def testStepSizeAdaptation(self):

  def log_accept_ratio_fn(step_size):
    return -step_size**2

  def kernel(ssa_state, seed):
    normal_seed, seed = util.split_seed(seed, 2)
    log_accept_ratio = (
        log_accept_ratio_fn(ssa_state.step_size()) +
        0.01 * util.random_normal([4], self._dtype, normal_seed))
    ssa_state, ssa_extra = prefab.step_size_adaptation_step(
        ssa_state, log_accept_ratio, num_adaptation_steps=100)
    return (ssa_state, seed), (ssa_extra.accept_prob, ssa_state.step_size(),
                               ssa_state.step_size(num_adaptation_steps=100))

  seed = self._make_seed(_test_seed())

  _, (p_accept, step_size, rms_step_size) = fun_mc.trace(
      (prefab.step_size_adaptation_init(tf.constant(0.1, self._dtype)), seed),
      kernel, 200)

  self.assertAllClose(0.8, p_accept[100], atol=0.1)
  self.assertAllClose(step_size[100], step_size[150])
  self.assertAllClose(rms_step_size[100], rms_step_size[150])
def interactive_trace(
    state: 'fun_mc.State',
    fn: 'fun_mc.TransitionOperator',
    num_steps: 'fun_mc.IntTensor',
    trace_mask: 'fun_mc.BooleanNest' = True,
    iteration_axis: int = 0,
    block_until_ready: 'bool' = True,
    progress_bar_fn: 'Callable[[Iterable[Any]], Iterable[Any]]' = (
        _tqdm_progress_bar_fn),
) -> 'Tuple[fun_mc.State, fun_mc.TensorNest]':
  """Wrapper around `fun_mc.trace`, suited for interactive work.

  This is accomplished through unrolling `fun_mc.trace`, as well as optionally
  using a progress bar (TQDM by default).

  Args:
    state: A nest of `Tensor`s or None.
    fn: A `TransitionOperator`.
    num_steps: Number of steps to run the function for. Must be greater than 1.
    trace_mask: A potentially shallow nest with boolean leaves applied to the
      `extra` return value of `fn`. This controls whether or not to actually
      trace the quantities in `extra`. For subtrees of `extra` where the mask
      leaf is `True`, those subtrees are traced (i.e. the corresponding subtrees
      in `traces` will contain an extra leading dimension equalling
      `num_steps`). For subtrees of `extra` where the mask leaf is `False`,
      those subtrees are merely propagated, and their corresponding subtrees in
      `traces` correspond to their final value.
    iteration_axis: Integer. Indicates the axis of the trace outputs that
      should be flattened with the first axis. This is most useful when `fn` is
      itself a call to `fun_mc.trace`. E.g. if the trace has shape
      `[num_steps, 2, 5]` and `iteration_axis=2`, the trace outputs will be
      reshaped/transposed to `[2, 5 * num_steps]`. A value of 0 disables this
      operation.
    block_until_ready: Whether to wait for the computation to finish between
      steps. This results in smoother progress bars under, e.g., JAX.
    progress_bar_fn: A callable that will be called with an iterable with
      length `num_steps` and which returns another iterable with the same
      length. This will be advanced for every step taken. If None, no progress
      bar is shown. Default: `lambda it: tqdm.tqdm(it, leave=True)`.

  Returns:
    state: The final state returned by `fn`.
    traces: A nest with the same structure as the extra return value of `fn`,
      but with leaves replaced with stacked and unstacked values according to
      the `trace_mask`.
  """
  num_steps = tf.get_static_value(num_steps)
  if num_steps is None:
    raise ValueError(
        'Interactive tracing requires `num_steps` to be statically known.')

  if progress_bar_fn is None:
    pbar = None
  else:
    pbar = iter(progress_bar_fn(range(num_steps)))

  def fn_with_progress(state):
    state, extra = fun_mc.call_transition_operator(fn, state)
    if block_until_ready:
      state, extra = util.block_until_ready((state, extra))
    if pbar is not None:
      try:
        next(pbar)
      except StopIteration:
        pass
    return [state], extra

  [state], trace = fun_mc.trace(
      # Wrap the state in a singleton list to simplify implementation of
      # `fn_with_progress`.
      state=[state],
      fn=fn_with_progress,
      num_steps=num_steps,
      trace_mask=trace_mask,
      unroll=True,
  )

  if iteration_axis != 0:

    def fix_part(x):
      x = util.move_axis(x, 0, iteration_axis - 1)
      x = tf.reshape(
          x,
          tuple(x.shape[:iteration_axis - 1]) + (-1,) +
          tuple(x.shape[iteration_axis + 1:]))
      return x

    trace = util.map_tree(fix_part, trace)
  return state, trace
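
# A minimal usage sketch of `interactive_trace` (illustrative only; the helper
# name `_example_interactive_trace` is hypothetical). The pattern mirrors
# `testInteractiveIterationAxis1` above: each outer step runs 5 inner
# `fun_mc.trace` steps, and `iteration_axis=1` merges the inner iteration axis
# into the outer step axis, yielding a single trace of length 20 * 5 = 100.
def _example_interactive_trace():

  def kernel(x):
    # A trivial deterministic transition operator: (new state, traced value).
    return x + 1, x

  state, trace = interactive_trace(
      0.,
      lambda x: fun_mc.trace(x, kernel, 5),
      20,
      iteration_axis=1,
      progress_bar_fn=None)
  # `state` is 100. and `trace` has shape [100].
  return state, trace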
def testAdaptiveHMC(self):
  num_chains = 16
  num_steps = 4000
  num_warmup_steps = num_steps // 2
  num_adapt_steps = int(0.8 * num_warmup_steps)

  # Setup the model and state constraints.
  model = tfp.distributions.JointDistributionSequential([
      tfp.distributions.Normal(loc=self._constant(0.), scale=1.),
      tfp.distributions.Independent(
          tfp.distributions.LogNormal(
              loc=self._constant([1., 1.]), scale=0.5), 1),
  ])
  bijector = [tfp.bijectors.Identity(), tfp.bijectors.Exp()]
  transform_fn = util_tfp.bijector_to_transform_fn(
      bijector, model.dtype, batch_ndims=1)

  def target_log_prob_fn(*x):
    return model.log_prob(x), ()

  # Start out at zeros (in the unconstrained space).
  state, _ = transform_fn(*[
      tf.zeros([num_chains] + list(e), dtype=self._dtype)
      for e in model.event_shape
  ])

  reparam_log_prob_fn, reparam_state = fun_mc.reparameterize_potential_fn(
      target_log_prob_fn, transform_fn, state)

  # Define the kernel.
  def kernel(adaptive_hmc_state, seed):
    hmc_seed, seed = util.split_seed(seed, 2)
    adaptive_hmc_state, adaptive_hmc_extra = (
        prefab.adaptive_hamiltonian_monte_carlo_step(
            adaptive_hmc_state,
            target_log_prob_fn=reparam_log_prob_fn,
            num_adaptation_steps=num_adapt_steps,
            seed=hmc_seed))

    return (adaptive_hmc_state,
            seed), (adaptive_hmc_extra.state, adaptive_hmc_extra.is_accepted,
                    adaptive_hmc_extra.step_size)

  seed = self._make_seed(_test_seed())

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  _, (state_chain, is_accepted_chain, _) = tf.function(
      lambda reparam_state, seed: fun_mc.trace(  # pylint: disable=g-long-lambda
          state=(prefab.adaptive_hamiltonian_monte_carlo_init(
              reparam_state, reparam_log_prob_fn), seed),
          fn=kernel,
          num_steps=num_steps))(reparam_state, seed)

  # Discard the warmup samples.
  state_chain = [s[num_warmup_steps:] for s in state_chain]
  is_accepted_chain = is_accepted_chain[num_warmup_steps:]

  accept_rate = tf.reduce_mean(tf.cast(is_accepted_chain, tf.float32))
  rhat = tfp.mcmc.potential_scale_reduction(state_chain)
  sample_mean = [tf.reduce_mean(s, axis=[0, 1]) for s in state_chain]
  sample_var = [tf.math.reduce_variance(s, axis=[0, 1]) for s in state_chain]

  self.assertAllAssertsNested(lambda rhat: self.assertAllLess(rhat, 1.1), rhat)
  self.assertAllClose(0.8, accept_rate, atol=0.05)
  self.assertAllClose(model.mean(), sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(model.variance(), sample_var, rtol=0.1, atol=0.1)
def inner(x):
  state, trace = fun_mc.trace(x, kernel, 5)
  trace = tf.transpose(trace, [1, 0])
  return state, trace