def trace():
  """Runs four HMC steps from a zero state, discarding every traced value."""

  def hmc_step(state):
    # The seed is requested anew on every invocation of the step.
    return fun_mcmc.hamiltonian_monte_carlo(
        state,
        step_size=0.1,
        num_integrator_steps=3,
        target_log_prob_fn=target_log_prob_fn,
        seed=tfp_test_util.test_seed())

  def discard(*unused_args):
    return ()

  initial_state = fun_mcmc.HamiltonianMonteCarloState(tf.zeros([1]))
  fun_mcmc.trace(
      state=initial_state, fn=hmc_step, num_steps=4, trace_fn=discard)
def trace():
  """Runs four HMC steps from an initialized zero state, tracing nothing."""

  def hmc_step(state):
    # The raw test seed is converted by `self._make_seed` on every call.
    return fun_mcmc.hamiltonian_monte_carlo(
        state,
        step_size=0.1,
        num_integrator_steps=3,
        target_log_prob_fn=target_log_prob_fn,
        seed=self._make_seed(tfp_test_util.test_seed()))

  def discard(*unused_args):
    return ()

  initial_state = fun_mcmc.hamiltonian_monte_carlo_init(
      state=tf.zeros([1]), target_log_prob_fn=target_log_prob_fn)
  fun_mcmc.trace(
      state=initial_state, fn=hmc_step, num_steps=4, trace_fn=discard)
def trace():
  """Runs four HMC steps in the test dtype, discarding all traced values."""

  def hmc_step(state):
    # Step size and seed are rebuilt on every invocation of the step.
    return fun_mcmc.hamiltonian_monte_carlo(
        state,
        step_size=self._constant(0.1),
        num_integrator_steps=3,
        target_log_prob_fn=target_log_prob_fn,
        seed=_test_seed())

  def discard(*unused_args):
    return ()

  initial_state = fun_mcmc.hamiltonian_monte_carlo_init(
      tf.zeros([1], dtype=self._dtype), target_log_prob_fn)
  fun_mcmc.trace(
      state=initial_state, fn=hmc_step, num_steps=4, trace_fn=discard)
def testRunningCovarianceMaxPoints(self):
  """Checks windowed running covariance across a change in distribution."""
  window_size = 100
  rng = np.random.RandomState(_test_seed())
  # Draw order is kept: the standard-normal head first, then the shifted and
  # rescaled tail that is ten windows long.
  head = rng.randn(window_size, 2)
  tail = (
      np.array([1., 2.]) +
      np.array([2., 3.]) * rng.randn(window_size * 10, 2))
  data = tf.convert_to_tensor(np.concatenate([head, tail], axis=0))

  def step(stats, index):
    stats, _ = fun_mcmc.running_covariance_step(
        stats, data[index], window_size=window_size)
    return (stats, index + 1), (stats.mean, stats.covariance)

  initial_state = (fun_mcmc.running_covariance_init([2], data.dtype), 0)
  _, (mean_trace, cov_trace) = fun_mcmc.trace(
      state=initial_state, fn=step, num_steps=len(data))

  # Until the window fills up, the running statistics are expected to be
  # exact.
  self.assertAllClose(
      np.mean(data[:window_size], axis=0), mean_trace[window_size - 1])
  self.assertAllClose(
      _gen_cov(data[:window_size], axis=0), cov_trace[window_size - 1])
  # Past the window the estimate behaves like an exponential moving average
  # and follows the new distribution. Only ~window_size points contribute, so
  # the tolerances are loose.
  self.assertAllClose(np.array([1., 2.]), mean_trace[-1], atol=0.2)
  self.assertAllClose(
      np.array([[4., 0.], [0., 9.]]), cov_trace[-1], atol=1.)
def testTraceTrace(self):
  """Checks that `trace` can be nested inside the step of another `trace`."""

  def discard(*unused_args):
    return ()

  def increment(value):
    return value + 1., ()

  def inner_trace(value):
    # Two increments per outer step.
    return fun_mcmc.trace(value, increment, 2, discard)

  final_value, _ = fun_mcmc.trace(0., inner_trace, 2, discard)
  self.assertAllEqual(4., final_value)
def SanitizedAutoCorrelationMean(x, axis, reduce_axis, max_lags=None, **kwargs):
  """Averages `SanitizedAutoCorrelation` over the slices along `reduce_axis`.

  Each slice of `x` along `reduce_axis` is passed through
  `SanitizedAutoCorrelation` and the results are accumulated with the
  `fun_mcmc` running mean.

  Args:
    x: Tensor with a statically known shape.
    axis: Axis handed to `SanitizedAutoCorrelation`. It indexes into a slice,
      i.e. into the dimensions of `x` with `reduce_axis` removed.
    reduce_axis: Axis of `x` to average over. Negative values are not
      normalized, so a non-negative index is expected.
    max_lags: Optional maximum lag. When given, dimension `axis` of the result
      has size `max_lags + 1`.
    **kwargs: Forwarded to `SanitizedAutoCorrelation`.

  Returns:
    The running mean of the per-slice auto-correlations.
  """
  shape_arr = np.array(list(x.shape))
  kept_axes = [i for i in range(len(shape_arr)) if i != reduce_axis]
  # Fancy indexing copies, so the in-place edit below leaves `shape_arr` alone.
  mean_shape = shape_arr[kept_axes]
  if max_lags is not None:
    mean_shape[axis] = max_lags + 1
  mean_state = fun_mcmc.running_mean_init(mean_shape, x.dtype)

  # Move `reduce_axis` to the front while keeping the remaining axes in their
  # original order. Merely swapping axis 0 with `reduce_axis` permutes the
  # remaining axes whenever `reduce_axis >= 2`, which makes the slices disagree
  # with `mean_shape` (and with the meaning of `axis`).
  x = tf.transpose(x, [reduce_axis] + kept_axes)
  x_arr = tf.TensorArray(x.dtype, x.shape[0]).unstack(x)

  def step(state):
    # `num_points` counts the slices consumed so far, so it doubles as the
    # index of the next slice to read.
    return fun_mcmc.running_mean_step(
        state,
        SanitizedAutoCorrelation(
            x_arr.read(state.num_points), axis, max_lags=max_lags, **kwargs))

  mean_state, _ = fun_mcmc.trace(
      state=mean_state,
      fn=step,
      num_steps=x.shape[0],
      trace_fn=lambda *_: ())
  return mean_state.mean
def testWrapTransitionKernel(self):
  """Checks `transition_kernel_wrapper` with a toy `TransitionKernel`."""

  class TestKernel(tfp.mcmc.TransitionKernel):
    """Adds one to every state part and to the kernel results."""

    def one_step(self, current_state, previous_kernel_results):
      next_state = [part + 1 for part in current_state]
      return next_state, previous_kernel_results + 1

    def bootstrap_results(self, current_state):
      return sum(current_state)

    def is_calibrated(self):
      return True

  def step(state, kernel_results):
    return fun_mcmc.transition_kernel_wrapper(
        state, kernel_results, TestKernel())

  initial = ({'x': 0., 'y': 1.}, 1.)

  (final_state, final_kr), _ = fun_mcmc.trace(
      initial, step, 2, trace_fn=lambda *args: ())

  expected_state = {'x': 2., 'y': 3.}
  self.assertAllEqual(expected_state, util.map_tree(np.array, final_state))
  # Kernel results start at 1 and are incremented once per step.
  self.assertAllEqual(1. + 2., final_kr)
def testRunningVarianceMaxPoints(self):
  """Checks windowed running variance across a change in distribution."""
  window_size = 100
  rng = np.random.RandomState(_test_seed())
  # Draw order is kept: the standard-normal head first, then the shifted and
  # rescaled tail that is ten windows long.
  head = rng.randn(window_size)
  tail = 1. + 2. * rng.randn(window_size * 10)
  data = tf.convert_to_tensor(np.concatenate([head, tail], axis=0))

  def step(stats, index):
    stats, _ = fun_mcmc.running_variance_step(
        stats, data[index], window_size=window_size)
    return (stats, index + 1), (stats.mean, stats.variance)

  initial_state = (fun_mcmc.running_variance_init([], data.dtype), 0)
  _, (mean_trace, var_trace) = fun_mcmc.trace(
      state=initial_state, fn=step, num_steps=len(data))

  # Until the window fills up, the running statistics are expected to be
  # exact.
  self.assertAllClose(
      np.mean(data[:window_size]), mean_trace[window_size - 1])
  self.assertAllClose(np.var(data[:window_size]), var_trace[window_size - 1])
  # Past the window the estimate behaves like an exponential moving average
  # and follows the new distribution. Only ~window_size points contribute, so
  # the tolerances are loose.
  self.assertAllClose(1., mean_trace[-1], atol=0.2)
  self.assertAllClose(4., var_trace[-1], atol=0.8)
def testPreconditionedHMC(self):
  """Checks HMC moments on a Softplus-transformed correlated Gaussian.

  The chain runs on the state produced by `transform_log_prob_fn`, and the
  traced values are `state_extra[0]` of the HMC state. Their mean and
  covariance are compared against estimates from direct draws of the target
  distribution.
  """
  step_size = self._constant(0.2)
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2], dtype=self._dtype)

  base_mean = self._constant([1., 0])
  base_cov = self._constant([[1, 0.5], [0.5, 1]])

  bijector = tfp.bijectors.Softplus()
  base_dist = tfp.distributions.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  # pylint: disable=g-long-lambda
  def kernel(hmc_state, seed):
    # Off JAX a test seed is drawn per step; on JAX the carried seed is split
    # so that each step consumes a distinct one.
    if not self._is_on_jax:
      hmc_seed = _test_seed()
    else:
      hmc_seed, seed = util.split_seed(seed, 2)
    hmc_state, _ = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    return (hmc_state, seed), hmc_state.state_extra[0]

  if not self._is_on_jax:
    seed = _test_seed()
  else:
    seed = self._make_seed(_test_seed())

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
      state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
             seed),
      fn=kernel,
      num_steps=num_steps))(state, seed)
  # Discard the warmup samples.
  chain = chain[1000:]

  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

  true_samples = target_dist.sample(4096, seed=self._make_seed(_test_seed()))

  true_mean = tf.reduce_mean(true_samples, axis=0)
  # The reference covariance must come from the direct draws. Computing it
  # from `chain` would compare the sample covariance with itself and make the
  # assertion below vacuous.
  true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
def testTraceTrace(self):
  """Checks nested `trace` calls where the inner one is not stacked."""

  def increment(value):
    return value + 1., value + 1.

  def inner_trace(value):
    # `trace_mask=False` is what the inner call uses for its traced output.
    return fun_mcmc.trace(value, increment, 2, trace_mask=False)

  final_value, outer_trace = fun_mcmc.trace(0., inner_trace, 2)
  self.assertAllEqual(4., final_value)
  self.assertAllEqual([2., 4.], outer_trace)
def testTraceMask(self):
  """Checks per-output and scalar `trace_mask` settings of `trace`."""

  def step(value):
    return value + 1, (2 * value, 3 * value)

  # Per-output mask: the assertions expect a full history for the first
  # output and a single value for the second.
  final, (doubled, tripled) = fun_mcmc.trace(
      state=0, fn=step, num_steps=3, trace_mask=(True, False))
  self.assertAllEqual(3, final)
  self.assertAllEqual([0, 2, 4], doubled)
  self.assertAllEqual(6, tripled)

  # Scalar mask: both outputs are expected to hold a single value.
  final, (doubled, tripled) = fun_mcmc.trace(
      state=0, fn=step, num_steps=3, trace_mask=False)
  self.assertAllEqual(3, final)
  self.assertAllEqual(4, doubled)
  self.assertAllEqual(6, tripled)
def testTraceSingle(self):
  """Checks tracing a single extra output via `trace_fn`."""

  def step(value):
    return value + 1., 2 * value

  def pick_extra(unused_state, extra):
    return extra

  final_value, extra_trace = fun_mcmc.trace(
      state=0., fn=step, num_steps=5, trace_fn=pick_extra)

  self.assertAllEqual(5., final_value)
  self.assertAllEqual([0., 2., 4., 6., 8.], extra_trace)
def testTraceSingle(self):
  """Checks tracing a single extra output when the initial state is None."""

  def step(value):
    # A `None` state is treated as zero.
    value = 0. if value is None else value
    return value + 1., 2 * value

  def pick_extra(unused_state, extra):
    return extra

  final_value, extra_trace = fun_mcmc.trace(
      state=None, fn=step, num_steps=5, trace_fn=pick_extra)

  self.assertAllEqual(5., final_value.numpy())
  self.assertAllEqual([0., 2., 4., 6., 8.], extra_trace.numpy())
def testTraceNested(self):
  """Checks tracing of a tuple-valued state."""

  def step(first, second):
    return (first + 1., second + 2.), ()

  def pick_state(state, unused_extra):
    return state

  (first, second), (first_trace, second_trace) = fun_mcmc.trace(
      state=(0., 0.), fn=step, num_steps=5, trace_fn=pick_state)

  self.assertAllEqual(5., first)
  self.assertAllEqual(10., second)
  self.assertAllEqual([1., 2., 3., 4., 5.], first_trace)
  self.assertAllEqual([2., 4., 6., 8., 10.], second_trace)
def testBasicHMC(self):
  """Checks HMC sample moments against direct draws from a diagonal normal."""
  step_size = 0.2
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2])

  base_mean = tf.constant([2., 3.])
  base_scale = tf.constant([2., 0.5])

  def target_log_prob_fn(x):
    standardized = (x - base_mean) / base_scale
    return -tf.reduce_sum(0.5 * tf.square(standardized), -1), ()

  def kernel(hmc_state, seed):
    # On JAX the carried seed is split so that each step consumes a distinct
    # one; otherwise a test seed is drawn per step.
    if self._is_on_jax:
      hmc_seed, seed = util.split_seed(seed, 2)
    else:
      hmc_seed = _test_seed()
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    return (hmc_state, seed), hmc_extra

  if self._is_on_jax:
    seed = self._make_seed(_test_seed())
  else:
    seed = _test_seed()

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  @tf.function
  def run_chain(state, seed):
    initial = (
        fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
        seed)
    return fun_mcmc.trace(
        state=initial,
        fn=kernel,
        num_steps=num_steps,
        trace_fn=lambda state, extra: state[0].state)

  _, chain = run_chain(state, seed)
  # The first half of the chain is treated as warmup and dropped.
  chain = chain[1000:]

  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_var = tf.math.reduce_variance(chain, axis=[0, 1])

  standard_draws = util.random_normal(
      shape=[4096, 2], dtype=tf.float32, seed=seed)
  true_samples = standard_draws * base_scale + base_mean

  true_mean = tf.reduce_mean(true_samples, axis=0)
  true_var = tf.math.reduce_variance(true_samples, axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_var, sample_var, rtol=0.1, atol=0.1)
def computation(state, seed):
  """Runs HMC with an Adam-adapted step size on a Softplus-transformed target.

  Args:
    state: Initial chain state, passed through `transform_log_prob_fn`.
    seed: Seed carried through the chain.

  Returns:
    A tuple of the traced `state_extra[0]` values, the traced log accept
    ratios, and 4096 direct draws from the target distribution.
  """
  bijector = tfp.bijectors.Softplus()
  base_dist = tfp.distributions.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def adaptive_step(hmc_state, step_size_state, step, seed):
    if self._is_on_jax:
      hmc_seed, seed = util.split_seed(seed, 2)
    else:
      hmc_seed = _test_seed()

    # The optimizer state holds the log of the step size.
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=tf.exp(step_size_state.state),
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)

    learning_rate = fun_mcmc.prefab._polynomial_decay(  # pylint: disable=protected-access
        step=step,
        step_size=self._constant(0.01),
        power=0.5,
        decay_steps=num_adapt_steps,
        final_step_size=0.)
    clipped_log_ratio = tf.minimum(
        self._constant(0.), hmc_extra.log_accept_ratio)
    mean_p_accept = tf.reduce_mean(tf.exp(clipped_log_ratio))
    # The surrogate loss is the gap between 0.9 and the mean acceptance
    # probability.
    loss_fn = fun_mcmc.make_surrogate_loss_fn(
        lambda _: (0.9 - mean_p_accept, ()))
    step_size_state, _ = fun_mcmc.adam_step(
        step_size_state, loss_fn, learning_rate=learning_rate)

    next_state = (hmc_state, step_size_state, step + 1, seed)
    traced = (hmc_state.state_extra[0], hmc_extra.log_accept_ratio)
    return next_state, traced

  initial = (
      fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
      fun_mcmc.adam_init(tf.math.log(step_size)),
      0,
      seed,
  )
  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      state=initial, fn=adaptive_step, num_steps=num_adapt_steps + num_steps)
  true_samples = target_dist.sample(
      4096, seed=self._make_seed(_test_seed()))
  return chain, log_accept_ratio_trace, true_samples
def testGradientDescent(self):
  """Checks that gradient descent approaches the minimum at (1, 2)."""

  def loss_fn(x, y):
    return tf.square(x - 1.) + tf.square(y - 2.), []

  def descent_step(gd_state):
    return fun_mcmc.gradient_descent_step(
        gd_state, loss_fn, learning_rate=0.01)

  def state_and_loss(state, extra):
    return [state.state, extra.loss]

  initial = fun_mcmc.GradientDescentState([tf.zeros([]), tf.zeros([])])
  _, [(x_trace, y_trace), loss_trace] = fun_mcmc.trace(
      initial, descent_step, num_steps=1000, trace_fn=state_and_loss)

  self.assertAllClose(1., x_trace[-1], atol=1e-3)
  self.assertAllClose(2., y_trace[-1], atol=1e-3)
  self.assertAllClose(0., loss_trace[-1], atol=1e-3)
def testRandomWalkMetropolis(self):
  """Checks RWM on a four-way categorical target with a fixed proposal."""
  num_steps = 1000
  state = tf.ones([16], dtype=tf.int32)
  target_logits = tf.constant([1., 2., 3., 4.]) + 2.
  proposal_logits = tf.constant([4., 3., 2., 1.]) + 2.

  def target_log_prob_fn(x):
    return tf.gather(target_logits, x), ()

  def proposal_fn(x, seed):
    current_logits = tf.gather(proposal_logits, x)
    # One categorical draw per element of `x`.
    draws = util.random_categorical(
        proposal_logits[tf.newaxis], x.shape[0], seed)
    proposal = draws[0]
    proposed_logits = tf.gather(proposal_logits, proposal)
    logit_difference = proposed_logits - current_logits
    return tf.cast(proposal, x.dtype), ((), logit_difference)

  def kernel(rwm_state, seed):
    # Off TensorFlow the carried seed is split so that each step consumes a
    # distinct one; on TensorFlow a test seed is drawn per step.
    if backend.get_backend() != backend.TENSORFLOW:
      rwm_seed, seed = util.split_seed(seed, 2)
    else:
      rwm_seed = tfp_test_util.test_seed()
    rwm_state, rwm_extra = fun_mcmc.random_walk_metropolis(
        rwm_state,
        target_log_prob_fn=target_log_prob_fn,
        proposal_fn=proposal_fn,
        seed=rwm_seed)
    return (rwm_state, seed), rwm_extra

  if backend.get_backend() != backend.TENSORFLOW:
    seed = self._make_seed(tfp_test_util.test_seed())
  else:
    seed = tfp_test_util.test_seed()

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  @tf.function
  def run_chain(state, seed):
    initial = (
        fun_mcmc.random_walk_metropolis_init(state, target_log_prob_fn),
        seed)
    return fun_mcmc.trace(
        state=initial,
        fn=kernel,
        num_steps=num_steps,
        trace_fn=lambda state, extra: state[0].state)

  _, chain = run_chain(state, seed)
  # The first half of the chain is treated as warmup and dropped.
  chain = chain[500:]

  # Empirical frequency of each category over steps and chains.
  sample_mean = tf.reduce_mean(tf.one_hot(chain, 4), axis=[0, 1])
  self.assertAllClose(tf.nn.softmax(target_logits), sample_mean, atol=0.1)
def testPreconditionedHMC(self):
  """Checks HMC moments on a Softplus-transformed correlated Gaussian.

  The chain runs on the state produced by `transform_log_prob_fn`, and the
  traced values are `state_extra[0]` of the HMC state. Their mean and
  covariance are compared against estimates from direct draws of the target
  distribution.
  """
  step_size = 0.2
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2])

  base_mean = [1., 0]
  base_cov = [[1, 0.5], [0.5, 1]]

  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  # pylint: disable=g-long-lambda
  kernel = tf.function(lambda state: fun_mcmc.hamiltonian_monte_carlo(
      state,
      step_size=step_size,
      num_integrator_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn,
      seed=_test_seed()))

  _, chain = fun_mcmc.trace(
      state=fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
      fn=kernel,
      num_steps=num_steps,
      trace_fn=lambda state, extra: state.state_extra[0])
  # Discard the warmup samples.
  chain = chain[1000:]

  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

  true_samples = target_dist.sample(4096, seed=_test_seed())

  true_mean = tf.reduce_mean(true_samples, axis=0)
  # The reference covariance must come from the direct draws. Computing it
  # from `chain` would compare the sample covariance with itself and make the
  # assertion below vacuous.
  true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
def testAdam(self):
  """Checks that Adam approaches the minimum of the quadratic at (1, 2)."""

  def loss_fn(x, y):
    return tf.square(x - 1.) + tf.square(y - 2.), []

  def adam_update(adam_state):
    # The learning rate constant is rebuilt on every invocation.
    return fun_mcmc.adam_step(
        adam_state, loss_fn, learning_rate=self._constant(0.01))

  def state_and_loss(state, extra):
    return [state.state, extra.loss]

  initial = fun_mcmc.adam_init([self._constant(0.), self._constant(0.)])
  _, [(x_trace, y_trace), loss_trace] = fun_mcmc.trace(
      initial, adam_update, num_steps=1000, trace_fn=state_and_loss)

  self.assertAllClose(1., x_trace[-1], atol=1e-3)
  self.assertAllClose(2., y_trace[-1], atol=1e-3)
  self.assertAllClose(0., loss_trace[-1], atol=1e-3)
def computation(state):
  """Runs HMC with an Adam-adapted step size on a Softplus-transformed target.

  Args:
    state: Initial chain state, passed through `transform_log_prob_fn`.

  Returns:
    A tuple of the traced `state_extra[0]` values, the traced log accept
    ratios, and 4096 direct draws from the target distribution.
  """
  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def adaptive_step(hmc_state, step_size_state, step):
    # The optimizer state holds the log of the step size.
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=tf.exp(step_size_state.state),
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn)

    learning_rate = tf.compat.v1.train.polynomial_decay(
        0.01,
        global_step=step,
        power=0.5,
        decay_steps=num_adapt_steps,
        end_learning_rate=0.)
    clipped_log_ratio = tf.minimum(0., hmc_extra.log_accept_ratio)
    mean_p_accept = tf.reduce_mean(tf.exp(clipped_log_ratio))
    # The surrogate loss is the gap between 0.9 and the mean acceptance
    # probability.
    loss_fn = fun_mcmc.make_surrogate_loss_fn(
        lambda _: (0.9 - mean_p_accept, ()))
    step_size_state, _ = fun_mcmc.adam_step(
        step_size_state, loss_fn, learning_rate=learning_rate)

    next_state = (hmc_state, step_size_state, step + 1)
    traced = (hmc_state.state_extra[0], hmc_extra.log_accept_ratio)
    return next_state, traced

  initial = (
      fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
      fun_mcmc.adam_init(tf.math.log(step_size)),
      0,
  )
  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      state=initial, fn=adaptive_step, num_steps=num_adapt_steps + num_steps)
  true_samples = target_dist.sample(4096, seed=_test_seed())
  return chain, log_accept_ratio_trace, true_samples
def testRunningMean(self, shape, aggregation):
  """Checks the running mean against `np.mean` over the same axes."""
  rng = np.random.RandomState(_test_seed())
  data = tf.convert_to_tensor(rng.randn(*shape))

  def step(mean_state, index):
    mean_state, _ = fun_mcmc.running_mean_step(
        mean_state, data[index], axis=aggregation)
    return (mean_state, index + 1), ()

  # The step sees one slice at a time, so its axes are shifted by one relative
  # to `data`; axis 0 is always reduced because it is the stepping axis.
  if aggregation is None:
    shifted_axes = ()
  else:
    shifted_axes = tuple(a + 1 for a in util.flatten_tree(aggregation))
  true_aggregation = (0,) + shifted_axes
  true_mean = np.mean(data, true_aggregation)

  initial = (fun_mcmc.running_mean_init(true_mean.shape, data.dtype), 0)
  (final_mean_state, _), _ = fun_mcmc.trace(
      state=initial,
      fn=step,
      num_steps=len(data),
      trace_fn=lambda *args: ())

  self.assertAllClose(true_mean, final_mean_state.mean)
def computation(state):
  """Runs HMC with sign-based step size adaptation on a Softplus target.

  Args:
    state: Initial chain state, passed through `transform_log_prob_fn`.

  Returns:
    A tuple of the traced `state_extra[0]` values, the traced log accept
    ratios, and 4096 direct draws from the target distribution.
  """
  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def adaptive_step(hmc_state, current_step_size, step):
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=current_step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn)

    adaptation_rate = tf.compat.v1.train.polynomial_decay(
        0.01,
        global_step=step,
        power=0.5,
        decay_steps=num_adapt_steps,
        end_learning_rate=0.)
    clipped_log_ratio = tf.minimum(0., hmc_extra.log_accept_ratio)
    mean_p_accept = tf.reduce_mean(tf.exp(clipped_log_ratio))
    # The step size is adapted with 0.9 as the acceptance set point.
    next_step_size = fun_mcmc.sign_adaptation(
        current_step_size,
        output=mean_p_accept,
        set_point=0.9,
        adaptation_rate=adaptation_rate)

    return (hmc_state, next_step_size, step + 1), hmc_extra

  def chain_and_log_accept_ratio(state, extra):
    return state[0].state_extra[0], extra.log_accept_ratio

  initial = (fun_mcmc.HamiltonianMonteCarloState(state), step_size, 0)
  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      initial,
      adaptive_step,
      num_adapt_steps + num_steps,
      trace_fn=chain_and_log_accept_ratio)
  true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())
  return chain, log_accept_ratio_trace, true_samples
def testPotentialScaleReduction(self, chain_shape, independent_chain_ndims):
  """Checks the streaming R-hat against `tfp.mcmc.potential_scale_reduction`."""
  rng = np.random.RandomState(_test_seed())
  # Per-chain means are drawn first and broadcast along the leading axis.
  chain_means = rng.randn(1, *chain_shape[1:]).astype(np.float32)
  noise = rng.randn(*chain_shape).astype(np.float32)
  chains_np = 0.4 * noise + chain_means

  true_rhat = tfp.mcmc.potential_scale_reduction(
      chains_np, independent_chain_ndims=independent_chain_ndims)

  chains = tf.convert_to_tensor(chains_np)

  def step(psr_state):
    # `num_points` is used as the index of the next slice to consume.
    return fun_mcmc.potential_scale_reduction_step(
        psr_state, chains[psr_state.num_points])

  initial = fun_mcmc.potential_scale_reduction_init(
      chain_shape[1:], tf.float32)
  final_psr_state, _ = fun_mcmc.trace(
      state=initial,
      fn=step,
      num_steps=chain_shape[0],
      trace_fn=lambda *_: ())

  running_rhat = fun_mcmc.potential_scale_reduction_extract(
      final_psr_state, independent_chain_ndims=independent_chain_ndims)
  self.assertAllClose(true_rhat, running_rhat)
def testRunningCovariance(self, shape, aggregation):
  """Checks running mean and covariance against batch reference values.

  Args:
    shape: Shape of the random data; the leading axis is the stepping axis.
    aggregation: Axis argument forwarded to `running_covariance_step`, or
      `None`.
  """
  # Seeded generator, matching the sibling running-statistics tests. The
  # global, unseeded `np.random.randn` made failures impossible to reproduce.
  rng = np.random.RandomState(_test_seed())
  data = tf.convert_to_tensor(rng.randn(*shape))

  # The step sees one slice at a time, so its axes are shifted by one relative
  # to `data`; axis 0 is always reduced because it is the stepping axis.
  true_aggregation = (0,) + (() if aggregation is None else tuple(
      [a + 1 for a in util.flatten_tree(aggregation)]))
  true_mean = np.mean(data, true_aggregation)
  true_cov = _gen_cov(data, true_aggregation)

  def kernel(rcs, idx):
    rcs, _ = fun_mcmc.running_covariance_step(rcs, data[idx], axis=aggregation)
    return (rcs, idx + 1), ()

  (rcs, _), _ = fun_mcmc.trace(
      state=(fun_mcmc.running_covariance_init(true_mean.shape, data[0].dtype),
             0),
      fn=kernel,
      num_steps=len(data),
      trace_fn=lambda *args: ())
  self.assertAllClose(true_mean, rcs.mean)
  self.assertAllClose(true_cov, rcs.covariance)
def testRunningVariance(self, shape, aggregation):
  """Checks running mean and variance against `np.mean` and `np.var`."""
  rng = np.random.RandomState(_test_seed())
  data = self._constant(rng.randn(*shape))

  # The step sees one slice at a time, so its axes are shifted by one relative
  # to `data`; axis 0 is always reduced because it is the stepping axis.
  if aggregation is None:
    shifted_axes = ()
  else:
    shifted_axes = tuple(a + 1 for a in util.flatten_tree(aggregation))
  true_aggregation = (0,) + shifted_axes
  true_mean = np.mean(data, true_aggregation)
  true_var = np.var(data, true_aggregation)

  def step(var_state, index):
    var_state, _ = fun_mcmc.running_variance_step(
        var_state, data[index], axis=aggregation)
    return (var_state, index + 1), ()

  initial = (
      fun_mcmc.running_variance_init(true_mean.shape, data[0].dtype), 0)
  (final_var_state, _), _ = fun_mcmc.trace(
      state=initial,
      fn=step,
      num_steps=len(data),
      trace_fn=lambda *args: ())

  self.assertAllClose(true_mean, final_var_state.mean)
  self.assertAllClose(true_var, final_var_state.variance)
def testSimpleDualAverages(self):
  """Checks that averaged dual-averaging iterates approach (1, 2)."""

  def loss_fn(x, y):
    return tf.square(x - 1.) + tf.square(y - 2.), []

  def step(sda_state, mean_state):
    sda_state, _ = fun_mcmc.simple_dual_averages_step(sda_state, loss_fn, 1.)
    # The assertions are made on the running mean of the iterates, which is
    # what gets traced.
    mean_state, _ = fun_mcmc.running_mean_step(mean_state, sda_state.state)
    return (sda_state, mean_state), mean_state.mean

  initial_sda = fun_mcmc.simple_dual_averages_init(
      [self._constant(0.), self._constant(0.)])
  initial_mean = fun_mcmc.running_mean_init(
      [[], []], [self._dtype, self._dtype])

  _, (x_trace, y_trace) = fun_mcmc.trace(
      (initial_sda, initial_mean), step, num_steps=1000)

  self.assertAllClose(1., x_trace[-1], atol=1e-1)
  self.assertAllClose(2., y_trace[-1], atol=1e-1)
def testRunningMeanMaxPoints(self):
  """Checks the windowed running mean across a change in distribution."""
  window_size = 100
  rng = np.random.RandomState(_test_seed())
  # Draw order is kept: the standard-normal head first, then the shifted and
  # rescaled tail that is ten windows long.
  head = rng.randn(window_size)
  tail = 1. + 2. * rng.randn(window_size * 10)
  data = self._constant(np.concatenate([head, tail], axis=0))

  def step(mean_state, index):
    mean_state, _ = fun_mcmc.running_mean_step(
        mean_state, data[index], window_size=window_size)
    return (mean_state, index + 1), mean_state.mean

  initial = (fun_mcmc.running_mean_init([], data.dtype), 0)
  _, mean_trace = fun_mcmc.trace(
      state=initial, fn=step, num_steps=len(data))

  # Until the window fills up, the running mean is expected to be exact.
  self.assertAllClose(
      np.mean(data[:window_size]), mean_trace[window_size - 1])
  # Past the window the estimate behaves like an exponential moving average
  # and follows the new distribution. Only ~window_size points contribute, so
  # the tolerance is loose.
  self.assertAllClose(1., mean_trace[-1], atol=0.2)
def trace_n(num_steps):
  """Returns the final state after incrementing zero `num_steps` times."""

  def increment(value):
    return value + 1, ()

  result = fun_mcmc.trace(0, increment, num_steps)
  return result[0]
def fun(x):
  """Runs two increment steps from `x` via `trace` with `trace_mask=False`."""

  def increment(value):
    return value + 1., value + 1.

  return fun_mcmc.trace(x, increment, 2, trace_mask=False)