def testRunningVarianceMaxPoints(self):
  """Running variance is exact up to `window_size`, then tracks a shift.

  The first `window_size` samples are drawn from N(0, 1) and the following
  `10 * window_size` samples from N(1, 2**2).
  """
  window_size = 100
  rng = np.random.RandomState(_test_seed())
  # Draw order matters for reproducibility: the N(0, 1) prefix comes first.
  head = rng.randn(window_size)
  tail = 1. + 2. * rng.randn(window_size * 10)
  data = self._constant(np.concatenate([head, tail], axis=0))

  def kernel(rvs, idx):
    rvs, _ = fun_mcmc.running_variance_step(
        rvs, data[idx], window_size=window_size)
    return (rvs, idx + 1), (rvs.mean, rvs.variance)

  init_state = (fun_mcmc.running_variance_init([], data.dtype), 0)
  _, (mean, var) = fun_mcmc.trace(
      state=init_state,
      fn=kernel,
      num_steps=len(data),
  )

  # While the window is still filling, the estimate must equal the exact
  # sample mean/variance of everything seen so far.
  prefix = data[:window_size]
  self.assertAllClose(np.mean(prefix), mean[window_size - 1])
  self.assertAllClose(np.var(prefix), var[window_size - 1])

  # Past the window, the estimator acts as an exponential moving average over
  # roughly `window_size` points. It therefore follows the post-shift
  # distribution only loosely, which is why the tolerances are generous.
  final_mean, final_var = mean[-1], var[-1]
  self.assertAllClose(1., final_mean, atol=0.2)
  self.assertAllClose(4., final_var, atol=0.8)
def adaptive_hamiltonian_monte_carlo_init(
    state: 'fun_mc.TensorNest',
    target_log_prob_fn: 'fun_mc.PotentialFn',
    step_size: 'fun_mc.FloatTensor' = 1e-2,
    initial_mean: 'fun_mc.FloatNest' = 0.,
    initial_scale: 'fun_mc.FloatNest' = 1.,
    scale_smoothing_steps: 'fun_mc.IntTensor' = 10,
) -> 'AdaptiveHamiltonianMonteCarloState':
  """Initializes `AdaptiveHamiltonianMonteCarloState`.

  Builds three pieces of state:
  - the underlying HMC state for `state`;
  - a running-variance accumulator pre-seeded with `initial_mean` and
    `initial_scale`;
  - an Adam optimizer state over the logarithm of `step_size`.

  Args:
    state: Initial state of the chain.
    target_log_prob_fn: Target log prob fn.
    step_size: Initial scalar step size.
    initial_mean: Initial mean for computing the running variance estimate.
      Must broadcast structurally and tensor-wise with state.
    initial_scale: Initial scale for computing the running variance estimate.
      Must broadcast structurally and tensor-wise with state.
    scale_smoothing_steps: How much weight to assign to the `initial_mean` and
      `initial_scale`. Increase this to stabilize early adaptation.

  Returns:
    adaptive_hmc_state: State of the `adaptive_hamiltonian_monte_carlo_step`
      `TransitionOperator`.
  """
  hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn)
  chain_state = hmc_state.state
  # The step size is created in the dtype of the first (flattened) state part.
  first_part_dtype = util.flatten_tree(chain_state)[0].dtype

  # Leading dimensions shared with the target log prob are treated as chain
  # (batch) dimensions. The remaining dimensions of each state part are what
  # the running variance is tracked over.
  batch_shape = hmc_state.target_log_prob.shape
  chain_ndims = len(batch_shape)
  event_shapes = util.map_tree(lambda part: part.shape[chain_ndims:],
                               chain_state)
  part_dtypes = util.map_tree(lambda part: part.dtype, chain_state)
  empty_var_state = fun_mc.running_variance_init(
      shape=event_shapes,
      dtype=part_dtypes,
  )

  mean_nest = fun_mc.maybe_broadcast_structure(initial_mean, state)
  scale_nest = fun_mc.maybe_broadcast_structure(initial_scale, state)

  # Seed the accumulator with pseudo-observations. The first real updates can
  # be very different from the stationary distribution, so starting from a
  # non-zero point count smooths early adaptation.
  # TODO(siege): Add a pseudo-update functionality to avoid fiddling with the
  # internals here.
  def _pseudo_count(count):
    # One batch of `scale_smoothing_steps` pseudo-points per chain element.
    return int(np.prod(batch_shape)) * tf.cast(scale_smoothing_steps,
                                               count.dtype)

  def _seed_mean(mean, init_mean):
    return tf.ones_like(mean) * init_mean

  def _seed_variance(variance, init_scale):
    # The accumulator stores a variance, so the user-facing scale is squared.
    return tf.ones_like(variance) * init_scale**2

  running_var_state = empty_var_state._replace(
      num_points=util.map_tree(_pseudo_count, empty_var_state.num_points),
      mean=util.map_tree(_seed_mean, empty_var_state.mean, mean_nest),
      variance=util.map_tree(_seed_variance, empty_var_state.variance,
                             scale_nest),
  )

  # The step size is adapted in log space.
  initial_log_step_size = tf.math.log(
      tf.convert_to_tensor(step_size, dtype=first_part_dtype))
  return AdaptiveHamiltonianMonteCarloState(
      hmc_state=hmc_state,
      running_var_state=running_var_state,
      log_step_size_opt_state=fun_mc.adam_init(initial_log_step_size),
      step=tf.zeros([], tf.int32),
  )
def testRunningVariance(self, shape, aggregation):
  """Streaming running variance matches NumPy's batch mean and variance."""
  rng = np.random.RandomState(_test_seed())
  data = self._constant(rng.randn(*shape))

  # Axis 0 of `data` is the streaming axis. `aggregation` is applied to each
  # streamed element `data[idx]`, so its axes are shifted by one relative to
  # the axes of the full `data` array.
  if aggregation is None:
    extra_axes = ()
  else:
    extra_axes = tuple(a + 1 for a in util.flatten_tree(aggregation))
  reduce_axes = (0,) + extra_axes
  true_mean = np.mean(data, reduce_axes)
  true_var = np.var(data, reduce_axes)

  def kernel(rvs, idx):
    rvs, _ = fun_mcmc.running_variance_step(rvs, data[idx], axis=aggregation)
    return (rvs, idx + 1), ()

  init_state = (fun_mcmc.running_variance_init(true_mean.shape, data[0].dtype),
                0)
  (rvs, _), _ = fun_mcmc.trace(
      state=init_state,
      fn=kernel,
      num_steps=len(data),
      trace_fn=lambda *args: ())

  self.assertAllClose(true_mean, rvs.mean)
  self.assertAllClose(true_var, rvs.variance)