def trace():
  """Runs a 4-step HMC chain with `fun_mcmc.trace`, discarding every trace.

  Uses `self`, `target_log_prob_fn`, `fun_mcmc`, `tf` and `_test_seed` from the
  enclosing scope. Returns None; the result of `fun_mcmc.trace` is dropped.
  """

  def hmc_step(hmc_state):
    # The step-size constant and the seed are re-created on every transition.
    return fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=self._constant(0.1),
        num_integrator_steps=3,
        target_log_prob_fn=target_log_prob_fn,
        seed=_test_seed())

  initial_position = tf.zeros([1], dtype=self._dtype)
  initial_state = fun_mcmc.hamiltonian_monte_carlo_init(
      initial_position, target_log_prob_fn)
  fun_mcmc.trace(
      state=initial_state,
      fn=hmc_step,
      num_steps=4,
      trace_fn=lambda *unused_args: ())
def testRunningCovarianceMaxPoints(self):
  """Windowed running covariance is exact early and tracks a later shift."""
  window_size = 100
  rng = np.random.RandomState(_test_seed())
  # The first `window_size` points are standard normal. The remaining ones are
  # drawn with mean [1, 2] and per-coordinate scale [2, 3].
  head = rng.randn(window_size, 2)
  tail_noise = rng.randn(window_size * 10, 2)
  tail = np.array([1., 2.]) + np.array([2., 3.]) * tail_noise
  data = self._constant(np.concatenate([head, tail], axis=0))

  def step(running_cov_state, index):
    running_cov_state, _ = fun_mcmc.running_covariance_step(
        running_cov_state, data[index], window_size=window_size)
    traced = (running_cov_state.mean, running_cov_state.covariance)
    return (running_cov_state, index + 1), traced

  init_state = (fun_mcmc.running_covariance_init([2], data.dtype), 0)
  _, (mean_trace, cov_trace) = fun_mcmc.trace(
      state=init_state,
      fn=step,
      num_steps=len(data),
  )

  last_exact = window_size - 1
  exact_points = data[:window_size]
  # Within the first window the running statistics are exact.
  self.assertAllClose(np.mean(exact_points, axis=0), mean_trace[last_exact])
  self.assertAllClose(_gen_cov(exact_points, axis=0), cov_trace[last_exact])
  # After the window, an exponential moving average follows the shifted
  # distribution. It effectively averages only ~window_size points, so the
  # tolerances are loose.
  self.assertAllClose(np.array([1., 2.]), mean_trace[-1], atol=0.2)
  self.assertAllClose(np.array([[4., 0.], [0., 9.]]), cov_trace[-1], atol=1.)
def testRunningVarianceMaxPoints(self):
  """Windowed running variance is exact early and tracks a later shift."""
  window_size = 100
  rng = np.random.RandomState(_test_seed())
  # Standard normal for the first window, then mean 1 / stddev 2 afterwards.
  head = rng.randn(window_size)
  tail = 1. + 2. * rng.randn(window_size * 10)
  data = self._constant(np.concatenate([head, tail], axis=0))

  def step(running_var_state, index):
    running_var_state, _ = fun_mcmc.running_variance_step(
        running_var_state, data[index], window_size=window_size)
    return ((running_var_state, index + 1),
            (running_var_state.mean, running_var_state.variance))

  _, (mean_trace, var_trace) = fun_mcmc.trace(
      state=(fun_mcmc.running_variance_init([], data.dtype), 0),
      fn=step,
      num_steps=len(data),
  )

  last_exact = window_size - 1
  exact_points = data[:window_size]
  # Within the first window the running statistics are exact.
  self.assertAllClose(np.mean(exact_points), mean_trace[last_exact])
  self.assertAllClose(np.var(exact_points), var_trace[last_exact])
  # Past the window, an exponential moving average over ~window_size points
  # follows the new distribution, hence the loose tolerances.
  self.assertAllClose(1., mean_trace[-1], atol=0.2)
  self.assertAllClose(4., var_trace[-1], atol=0.8)
def testWrapTransitionKernel(self):
  """A TFP `TransitionKernel` can drive `fun_mcmc.trace` via the wrapper."""

  class _IncrementKernel(tfp.mcmc.TransitionKernel):
    """Adds 1 to every state part and 1 to the kernel results per step."""

    def one_step(self, current_state, previous_kernel_results):
      bumped = [part + 1 for part in current_state]
      return bumped, previous_kernel_results + 1

    def bootstrap_results(self, current_state):
      return sum(current_state)

    def is_calibrated(self):
      return True

  def wrapped_step(state, kernel_results):
    return util_tfp.transition_kernel_wrapper(state, kernel_results,
                                              _IncrementKernel())

  initial_state = {'x': self._constant(0.), 'y': self._constant(1.)}
  initial_kernel_results = 1.
  (final_state, final_kernel_results), _ = fun_mcmc.trace(
      state=(initial_state, initial_kernel_results),
      fn=wrapped_step,
      num_steps=2,
      trace_fn=lambda *unused_args: (),
  )

  # Two steps: each part gains 2, and the kernel results gain 2.
  expected_state = {'x': 2., 'y': 3.}
  self.assertAllEqual(expected_state, util.map_tree(np.array, final_state))
  self.assertAllEqual(1. + 2., final_kernel_results)
def testTraceMask(self, unroll):
  """`trace_mask` selects which `extra` leaves are stacked vs. kept final."""

  def increment(x):
    return x + 1, (2 * x, 3 * x)

  # (mask, expected first trace). The second leaf is always masked off, so it
  # holds only its final value, 3 * 2 = 6.
  cases = [
      ((True, False), [0, 2, 4]),
      (False, 4),
  ]
  for mask, expected_first in cases:
    final, (first, second) = fun_mcmc.trace(
        state=0,
        fn=increment,
        num_steps=3,
        trace_mask=mask,
        unroll=unroll)
    self.assertAllEqual(3, final)
    self.assertAllEqual(expected_first, first)
    self.assertAllEqual(6, second)
def testTraceTrace(self, unroll):
  """An inner `fun_mcmc.trace` can serve as the transition of an outer one."""

  def add_one(value):
    return value + 1., value + 1.

  def inner_trace(value):
    # Two inner steps. trace_mask=False keeps only the last `extra` value.
    return fun_mcmc.trace(
        value, add_one, 2, trace_mask=False, unroll=unroll)

  final, outer_trace = fun_mcmc.trace(0., inner_trace, 2)
  self.assertAllEqual(4., final)
  self.assertAllEqual([2., 4.], outer_trace)
def testPreconditionedHMC(self):
  """HMC on a Softplus-transformed MVN matches reference mean and covariance.

  The log-density is transformed with `fun_mcmc.transform_log_prob_fn`, and the
  chain traces `hmc_state.state_extra[0]`. Because those values are compared
  against `target_dist` samples, they are presumably the states mapped back
  through the bijector.
  """
  step_size = self._constant(0.2)
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2], dtype=self._dtype)

  base_mean = self._constant([1., 0])
  base_cov = self._constant([[1, 0.5], [0.5, 1]])

  bijector = tfp.bijectors.Softplus()
  base_dist = tfp.distributions.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def kernel(hmc_state, seed):
    hmc_seed, seed = util.split_seed(seed, 2)
    hmc_state, _ = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    return (hmc_state, seed), hmc_state.state_extra[0]

  seed = self._make_seed(_test_seed())

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
      state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
             seed),
      fn=kernel,
      num_steps=num_steps))(state, seed)
  # Discard the warmup samples.
  chain = chain[1000:]

  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

  true_samples = target_dist.sample(4096, seed=self._make_seed(_test_seed()))

  true_mean = tf.reduce_mean(true_samples, axis=0)
  # The reference covariance must come from `true_samples`. Computing it from
  # `chain` would compare the sample covariance with itself, and the assertion
  # below could never fail.
  true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
def testTraceSingle(self, unroll):
  """`trace_fn` may select just the `extra` output, which is then stacked."""

  def step(x):
    return x + 1., 2 * x

  final, doubled_trace = fun_mcmc.trace(
      state=0.,
      fn=step,
      num_steps=5,
      trace_fn=lambda unused_state, extra: extra,
      unroll=unroll)
  self.assertAllEqual(5., final)
  # `extra` is 2 * x evaluated at x = 0, 1, 2, 3, 4.
  self.assertAllEqual([0., 2., 4., 6., 8.], doubled_trace)
def computation(state, seed):
  """Runs step-size-adapted HMC on a Softplus-transformed MVN.

  Relies on `base_mean`, `base_cov`, `step_size`, `num_leapfrog_steps`,
  `num_adapt_steps`, `num_steps` and `self` from the enclosing scope.

  Args:
    state: Initial chain state. It is passed through
      `fun_mcmc.transform_log_prob_fn` before sampling; presumably it is given
      in the constrained space.
    seed: Chain seed. It is split once per step.

  Returns:
    A tuple `(chain, log_accept_ratio_trace, true_samples)`:
    - `chain`: the traced `hmc_state.state_extra[0]` values.
    - `log_accept_ratio_trace`: the per-step log acceptance ratios.
    - `true_samples`: 4096 draws from the target distribution.
  """
  bijector = tfp.bijectors.Softplus()
  base_dist = tfp.distributions.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def constrained_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      constrained_log_prob_fn, bijector, state)

  def kernel(hmc_state, step_size_state, step, seed):
    hmc_seed, seed = util.split_seed(seed, 2)
    # The Adam state holds the log of the step size.
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=tf.exp(step_size_state.state),
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)

    learning_rate = prefab._polynomial_decay(  # pylint: disable=protected-access
        step=step,
        step_size=self._constant(0.01),
        power=0.5,
        decay_steps=num_adapt_steps,
        final_step_size=0.)
    # exp(min(0, log r)) == min(1, r): the per-chain acceptance probability.
    mean_p_accept = tf.reduce_mean(
        tf.exp(tf.minimum(self._constant(0.), hmc_extra.log_accept_ratio)))
    # Surrogate loss built from 0.9 - mean_p_accept. This presumably adapts the
    # step size so that the mean acceptance probability approaches 0.9.
    loss_fn = fun_mcmc.make_surrogate_loss_fn(
        lambda _: (0.9 - mean_p_accept, ()))
    step_size_state, _ = fun_mcmc.adam_step(
        step_size_state, loss_fn, learning_rate=learning_rate)

    next_state = (hmc_state, step_size_state, step + 1, seed)
    traced = (hmc_state.state_extra[0], hmc_extra.log_accept_ratio)
    return next_state, traced

  initial_state = (
      fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
      fun_mcmc.adam_init(tf.math.log(step_size)),
      0,
      seed,
  )
  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      state=initial_state,
      fn=kernel,
      num_steps=num_adapt_steps + num_steps,
  )
  true_samples = target_dist.sample(
      4096, seed=self._make_seed(_test_seed()))
  return chain, log_accept_ratio_trace, true_samples
def testTraceNested(self, unroll):
  """A tuple state is traced component-wise when `trace_fn` returns it."""

  def advance(x, y):
    return (x + 1., y + 2.), ()

  def select_state(state, unused_extra):
    return state

  (x_final, y_final), (x_trace, y_trace) = fun_mcmc.trace(
      state=(0., 0.),
      fn=advance,
      num_steps=5,
      trace_fn=select_state,
      unroll=unroll)
  self.assertAllEqual(5., x_final)
  self.assertAllEqual(10., y_final)
  self.assertAllEqual([1., 2., 3., 4., 5.], x_trace)
  self.assertAllEqual([2., 4., 6., 8., 10.], y_trace)
def testBasicHMC(self, unroll):
  """HMC on a diagonal Gaussian recovers the target mean and variance."""
  step_size = self._constant(0.2)
  num_steps = 2000
  num_leapfrog_steps = 10
  initial_position = tf.ones([16, 2], dtype=self._dtype)

  base_mean = self._constant([2., 3.])
  base_scale = self._constant([2., 0.5])

  def target_log_prob_fn(x):
    standardized = (x - base_mean) / base_scale
    return -tf.reduce_sum(0.5 * tf.square(standardized), -1), ()

  def kernel(hmc_state, seed):
    hmc_seed, next_seed = util.split_seed(seed, 2)
    hmc_state, _ = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        unroll_integrator=unroll,
        seed=hmc_seed)
    return (hmc_state, next_seed), hmc_state.state

  def run_chain(position, seed):
    hmc_init = fun_mcmc.hamiltonian_monte_carlo_init(position,
                                                     target_log_prob_fn)
    return fun_mcmc.trace(
        state=(hmc_init, seed), fn=kernel, num_steps=num_steps)

  seed = self._make_seed(_test_seed())
  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  _, chain = tf.function(run_chain)(initial_position, seed)

  # Drop the first 1000 steps as warmup.
  chain = chain[1000:]
  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_var = tf.math.reduce_variance(chain, axis=[0, 1])

  reference = util.random_normal(
      shape=[4096, 2], dtype=self._dtype, seed=seed) * base_scale + base_mean
  true_mean = tf.reduce_mean(reference, axis=0)
  true_var = tf.math.reduce_variance(reference, axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_var, sample_var, rtol=0.1, atol=0.1)
def testGradientDescent(self):
  """Gradient descent converges to the minimum of a separable quadratic."""

  def loss_fn(x, y):
    return tf.square(x - 1.) + tf.square(y - 2.), []

  def descend(gd_state):
    # The learning-rate constant is rebuilt on each step.
    return fun_mcmc.gradient_descent_step(
        gd_state, loss_fn, learning_rate=self._constant(0.01))

  def record(state, extra):
    return [state.state, extra.loss]

  start = fun_mcmc.GradientDescentState(
      [self._constant(0.), self._constant(0.)])
  _, [(x_trace, y_trace), loss_trace] = fun_mcmc.trace(
      state=start, fn=descend, num_steps=1000, trace_fn=record)

  for expected, trace in ((1., x_trace), (2., y_trace), (0., loss_trace)):
    self.assertAllClose(expected, trace[-1], atol=1e-3)
def testRunningMean(self, shape, aggregation):
  """Streaming mean over the rows of `data` equals `np.mean` on the batch."""
  rng = np.random.RandomState(_test_seed())
  data = self._constant(rng.randn(*shape))

  def step(running_mean_state, index):
    running_mean_state, _ = fun_mcmc.running_mean_step(
        running_mean_state, data[index], axis=aggregation)
    return (running_mean_state, index + 1), ()

  # Axis 0 indexes the streamed points. `aggregation` axes refer to one point,
  # so they shift by one when reducing the full `data` array.
  if aggregation is None:
    extra_axes = ()
  else:
    extra_axes = tuple(a + 1 for a in util.flatten_tree(aggregation))
  true_mean = np.mean(data, (0,) + extra_axes)

  (final_state, _), _ = fun_mcmc.trace(
      state=(fun_mcmc.running_mean_init(true_mean.shape, data.dtype), 0),
      fn=step,
      num_steps=len(data),
      trace_fn=lambda *unused_args: ())
  self.assertAllClose(true_mean, final_state.mean)
def testRandomWalkMetropolis(self):
  """Metropolis on 4 categories reproduces the softmax of the target logits."""
  num_steps = 1000
  state = tf.ones([16], dtype=tf.int32)

  target_logits = self._constant([1., 2., 3., 4.]) + 2.
  proposal_logits = self._constant([4., 3., 2., 1.]) + 2.

  def target_log_prob_fn(x):
    return tf.gather(target_logits, x), ()

  def proposal_fn(x, seed):
    current_logits = tf.gather(proposal_logits, x)
    # One categorical draw per chain. The draw does not depend on the values
    # in `x`.
    draws = util.random_categorical(proposal_logits[tf.newaxis], x.shape[0],
                                    seed)[0]
    proposed_logits = tf.gather(proposal_logits, draws)
    log_proposal_correction = proposed_logits - current_logits
    return tf.cast(draws, x.dtype), ((), log_proposal_correction)

  def kernel(rwm_state, seed):
    rwm_seed, next_seed = util.split_seed(seed, 2)
    rwm_state, rwm_extra = fun_mcmc.random_walk_metropolis(
        rwm_state,
        target_log_prob_fn=target_log_prob_fn,
        proposal_fn=proposal_fn,
        seed=rwm_seed)
    return (rwm_state, next_seed), rwm_extra

  def run_chain(state, seed):
    rwm_init = fun_mcmc.random_walk_metropolis_init(state, target_log_prob_fn)
    return fun_mcmc.trace(
        state=(rwm_init, seed),
        fn=kernel,
        num_steps=num_steps,
        trace_fn=lambda chain_state, unused_extra: chain_state[0].state)

  seed = self._make_seed(_test_seed())
  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  _, chain = tf.function(run_chain)(state, seed)

  # Drop the first 500 steps as warmup.
  chain = chain[500:]
  # Empirical frequency of each category, pooled over steps and chains.
  frequencies = tf.reduce_mean(tf.one_hot(chain, 4), axis=[0, 1])
  self.assertAllClose(tf.nn.softmax(target_logits), frequencies, atol=0.11)
def testSimpleDualAverages(self):
  """The running mean of dual-averaging iterates approaches the minimizer."""

  def loss_fn(x, y):
    return tf.square(x - 1.) + tf.square(y - 2.), []

  def kernel(sda_state, rms_state):
    sda_state, _ = fun_mcmc.simple_dual_averages_step(sda_state, loss_fn, 1.)
    rms_state, _ = fun_mcmc.running_mean_step(rms_state, sda_state.state)
    return (sda_state, rms_state), rms_state.mean

  sda_init = fun_mcmc.simple_dual_averages_init(
      [self._constant(0.), self._constant(0.)])
  rms_init = fun_mcmc.running_mean_init([[], []], [self._dtype, self._dtype])
  _, (x_trace, y_trace) = fun_mcmc.trace(
      state=(sda_init, rms_init),
      fn=kernel,
      num_steps=1000,
  )

  for expected, trace in ((1., x_trace), (2., y_trace)):
    self.assertAllClose(expected, trace[-1], atol=1e-1)
def testPotentialScaleReduction(self, chain_shape, independent_chain_ndims):
  """Running R-hat should agree with `tfp.mcmc.potential_scale_reduction`.

  Currently skipped; the code after `skipTest` does not execute.
  """
  self.skipTest('https://github.com/tensorflow/probability/issues/1054')
  # TODO(siege): Update fun_mcmc.potential_scale_reduction.
  rng = np.random.RandomState(_test_seed())
  means_shape = (1,) + chain_shape[1:]
  chain_means = rng.randn(*means_shape).astype(np.float32)
  chains_np = 0.4 * rng.randn(*chain_shape).astype(np.float32) + chain_means
  true_rhat = tfp.mcmc.potential_scale_reduction(
      chains_np, independent_chain_ndims=independent_chain_ndims)

  chains_tensor = self._constant(chains_np)

  def accumulate(psrs):
    # Feed the next point, indexed by how many points have been seen so far.
    return fun_mcmc.potential_scale_reduction_step(
        psrs, chains_tensor[psrs.num_points])

  psrs, _ = fun_mcmc.trace(
      state=fun_mcmc.potential_scale_reduction_init(chain_shape[1:],
                                                    self._dtype),
      fn=accumulate,
      num_steps=chain_shape[0],
      trace_fn=lambda *_: ())

  running_rhat = fun_mcmc.potential_scale_reduction_extract(
      psrs, independent_chain_ndims=independent_chain_ndims)
  self.assertAllClose(true_rhat, running_rhat)
def fun(x):
  """Runs 2 inner `fun_mcmc.trace` steps from `x`, keeping only the last extra.

  Uses `unroll` from the enclosing scope.
  """

  def add_one(value):
    return value + 1., value + 1.

  return fun_mcmc.trace(x, add_one, 2, trace_mask=False, unroll=unroll)
def testRunningApproximateAutoCovariance(self, state_shape, event_ndims,
                                         aggregation):
  """Running approximate auto-covariance of an HMC chain matches batch stats.

  HMC supplies the chain. While it samples, a running approximate
  auto-covariance is accumulated alongside it. The result is then compared
  with the variance and auto-correlation computed on the full traced chain.
  """
  step_size = 0.2
  num_steps = 1000
  num_leapfrog_steps = 10
  max_lags = 300

  state = tf.zeros(state_shape, dtype=self._dtype)

  def target_log_prob_fn(x):
    lp = -0.5 * tf.square(x)
    if event_ndims is None:
      return lp, ()
    return tf.reduce_sum(lp, -1), ()

  def kernel(hmc_state, raac_state, seed):
    hmc_seed, next_seed = util.split_seed(seed, 2)
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    raac_state, _ = fun_mcmc.running_approximate_auto_covariance_step(
        raac_state, hmc_state.state, axis=aggregation)
    return (hmc_state, raac_state, next_seed), hmc_extra

  def run_chain(state, seed):
    initial = (
        fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
        fun_mcmc.running_approximate_auto_covariance_init(
            max_lags=max_lags,
            state_shape=state_shape,
            dtype=state.dtype,
            axis=aggregation),
        seed,
    )
    return fun_mcmc.trace(
        state=initial,
        fn=kernel,
        num_steps=num_steps,
        trace_fn=lambda chain_state, unused_extra: chain_state[0].state)

  seed = self._make_seed(_test_seed())
  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  (_, raac_state, _), chain = tf.function(run_chain)(state, seed)

  # Axis 0 of `chain` is time. `aggregation` axes refer to a single state, so
  # they shift by one against the stacked chain.
  if aggregation is None:
    shifted_axes = []
  else:
    shifted_axes = [a + 1 for a in util.flatten_tree(aggregation)]
  true_aggregation = (0,) + tuple(shifted_axes)

  chain_np = np.array(chain)
  true_variance = np.array(
      tf.math.reduce_variance(chain_np, true_aggregation))
  true_autocov = np.array(
      tfp.stats.auto_correlation(chain_np, axis=0, max_lags=max_lags))
  if aggregation is not None:
    true_autocov = tf.reduce_mean(true_autocov, shifted_axes)

  # `auto_covariance[0]` (presumably lag 0) should match the chain variance.
  self.assertAllClose(true_variance, raac_state.auto_covariance[0], 1e-5)
  normalized_autocov = (
      raac_state.auto_covariance / raac_state.auto_covariance[0])
  self.assertAllClose(true_autocov, normalized_autocov, atol=0.1)
def testAdaptiveHMC(self):
  """Adaptive HMC on a constrained joint model converges and mixes well."""
  num_chains = 16
  num_steps = 4000
  num_warmup_steps = num_steps // 2
  num_adapt_steps = int(0.8 * num_warmup_steps)

  # Model: a Normal and an independent pair of LogNormals. The LogNormal part
  # is sampled through an Exp bijector.
  model = tfp.distributions.JointDistributionSequential([
      tfp.distributions.Normal(loc=self._constant(0.), scale=1.),
      tfp.distributions.Independent(
          tfp.distributions.LogNormal(loc=self._constant([1., 1.]), scale=0.5),
          1),
  ])
  bijector = [tfp.bijectors.Identity(), tfp.bijectors.Exp()]
  transform_fn = util_tfp.bijector_to_transform_fn(
      bijector, model.dtype, batch_ndims=1)

  def target_log_prob_fn(*x):
    return model.log_prob(x), ()

  # Start out at zeros (in the unconstrained space).
  unconstrained_zeros = [
      tf.zeros([num_chains] + list(event_shape), dtype=self._dtype)
      for event_shape in model.event_shape
  ]
  state, _ = transform_fn(*unconstrained_zeros)

  reparam_log_prob_fn, reparam_state = fun_mcmc.reparameterize_potential_fn(
      target_log_prob_fn, transform_fn, state)

  def kernel(adaptive_hmc_state, seed):
    hmc_seed, next_seed = util.split_seed(seed, 2)
    adaptive_hmc_state, adaptive_hmc_extra = (
        prefab.adaptive_hamiltonian_monte_carlo_step(
            adaptive_hmc_state,
            target_log_prob_fn=reparam_log_prob_fn,
            num_adaptation_steps=num_adapt_steps,
            seed=hmc_seed))
    traced = (adaptive_hmc_extra.state, adaptive_hmc_extra.is_accepted,
              adaptive_hmc_extra.step_size)
    return (adaptive_hmc_state, next_seed), traced

  def run_chain(reparam_state, seed):
    ahmc_init = prefab.adaptive_hamiltonian_monte_carlo_init(
        reparam_state, reparam_log_prob_fn)
    return fun_mcmc.trace(
        state=(ahmc_init, seed), fn=kernel, num_steps=num_steps)

  seed = self._make_seed(_test_seed())
  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  _, (state_chain, is_accepted_chain, _) = tf.function(run_chain)(
      reparam_state, seed)

  # Keep only the post-warmup part of every trace.
  state_chain = [part[num_warmup_steps:] for part in state_chain]
  is_accepted_chain = is_accepted_chain[num_warmup_steps:]

  accept_rate = tf.reduce_mean(tf.cast(is_accepted_chain, tf.float32))
  rhat = tfp.mcmc.potential_scale_reduction(state_chain)
  sample_mean = [tf.reduce_mean(part, axis=[0, 1]) for part in state_chain]
  sample_var = [
      tf.math.reduce_variance(part, axis=[0, 1]) for part in state_chain
  ]

  self.assertAllAssertsNested(lambda rhat: self.assertAllLess(rhat, 1.1), rhat)
  self.assertAllClose(0.8, accept_rate, atol=0.05)
  self.assertAllClose(model.mean(), sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(model.variance(), sample_var, rtol=0.1, atol=0.1)
def interactive_trace(
    state: 'fun_mc.State',
    fn: 'fun_mc.TransitionOperator',
    num_steps: 'fun_mc.IntTensor',
    trace_mask: 'fun_mc.BooleanNest' = True,
    block_until_ready: 'bool' = True,
    progress_bar_fn: 'Callable[[Iterable[Any]], Iterable[Any]]' = (
        _tqdm_progress_bar_fn),
) -> 'Tuple[fun_mc.State, fun_mc.TensorNest]':
  """Wrapper around fun_mcmc.trace, suited for interactive work.

  This is accomplished through unrolling fun_mcmc.trace, as well as optionally
  using a progress bar (TQDM by default).

  Args:
    state: A nest of `Tensor`s or None.
    fn: A `TransitionOperator`.
    num_steps: Number of steps to run the function for. Must be greater than
      1. Must be statically known.
    trace_mask: A potentially shallow nest with boolean leaves applied to the
      `extra` return value of `fn`. This controls whether or not to actually
      trace the quantities in `extra`. For subtrees of `extra` where the mask
      leaf is `True`, those subtrees are traced (i.e. the corresponding
      subtrees in `traces` will contain an extra leading dimension equalling
      `num_steps`). For subtrees of `extra` where the mask leaf is `False`,
      those subtrees are merely propagated, and their corresponding subtrees
      in `traces` correspond to their final value.
    block_until_ready: Whether to wait for the computation to finish between
      steps. This results in smoother progress bars under, e.g., JAX.
    progress_bar_fn: A callable that will be called with an iterable with
      length `num_steps` and which returns another iterable with the same
      length. This will be advanced for every step taken. If None, no progress
      bar is shown. Default: `lambda it: tqdm.tqdm(it, leave=True)`.

  Returns:
    state: The final state returned by `fn`.
    traces: A nest with the same structure as the extra return value of `fn`,
      but with leaves replaced with stacked and unstacked values according to
      the `trace_mask`.

  Raises:
    ValueError: If `num_steps` is not statically known.
  """
  num_steps = tf.get_static_value(num_steps)
  if num_steps is None:
    raise ValueError(
        'Interactive tracing requires `num_steps` to be statically known.')

  if progress_bar_fn is None:
    pbar = None
  else:
    pbar = iter(progress_bar_fn(range(num_steps)))

  def fn_with_progress(state):
    state, extra = fun_mc.call_transition_operator(fn, state)
    if block_until_ready:
      state, extra = util.block_until_ready((state, extra))
    if pbar is not None:
      # Advance the bar one tick. Passing a default makes extra calls beyond
      # `num_steps` a no-op instead of raising StopIteration.
      next(pbar, None)
    return [state], extra

  [state], trace = fun_mc.trace(
      # Wrap the state in a singleton list to simplify implementation of
      # `fn_with_progress`.
      state=[state],
      fn=fn_with_progress,
      num_steps=num_steps,
      trace_mask=trace_mask,
      unroll=True,
  )

  if pbar is not None:
    # Drain whatever remains of the progress-bar iterator. A generator-based
    # bar (e.g. tqdm) is left suspended at its last `yield` after `num_steps`
    # calls to `next`. Its final count update and `close()` would otherwise
    # only run when it is garbage collected, leaving the bar at N-1/N.
    for _ in pbar:
      pass

  return state, trace
def trace_n(num_steps):
  """Counts up from 0 for `num_steps` steps and returns the final count."""

  def increment(count):
    return count + 1, ()

  final_count, _ = fun_mcmc.trace(0, increment, num_steps)
  return final_count