def test_multiple_chains(self, make_kernel, target_accept_rate):
    """Check pooled moments and acceptance rate across vmapped chains.

    Runs 16 chains of 4000 steps each through a jitted, vmapped
    ``kernels.sample_chain`` wrapped in ``harvest.harvest`` (tag
    ``kernels.MCMC_METRICS``). The draws are reshaped to
    ``[16 * 4000] + event_shape`` and their mean and covariance are compared
    against 4096 draws taken directly from ``self.model``. The mean of
    ``metrics['kernel']['accept_prob']`` is compared with
    ``target_accept_rate``.

    Args:
        make_kernel: Callable that receives the unconstrained log-prob
            function and returns the kernel given to
            ``kernels.sample_chain``.
        target_accept_rate: Value the mean harvested ``accept_prob`` must
            match within ``atol=1e-2, rtol=1e-2``.
    """
    n_chains = 16
    n_steps = 4000
    sample_key, chain_key, init_key = random.split(self._seed, 3)

    log_prob = self._make_unconstrained_log_prob()
    init_keys = random.split(init_key, n_chains)
    initial_states = jax.vmap(self._initialize_state)(init_keys)
    kernel = make_kernel(log_prob)

    # Wrap the chain so the tagged metrics come back next to the samples,
    # then vectorise over chains and compile.
    harvested_chain = harvest.harvest(
        kernels.sample_chain(kernel, n_steps), tag=kernels.MCMC_METRICS)
    run_chains = jax.jit(jax.vmap(harvested_chain))

    reference = self.model.sample(sample_shape=4096, seed=sample_key)

    chain_keys = random.split(chain_key, n_chains)
    # NOTE(review): the leading empty dict is the first positional argument
    # of the harvested function -- presumably "nothing to plant"; confirm
    # against the harvest API.
    draws, metrics = run_chains({}, chain_keys, initial_states)

    def _pool_chains(part, event_shape):
        # Merge the chain and step axes into one leading sample axis.
        return part.reshape([n_chains * n_steps] + list(event_shape))

    draws = tf.nest.map_structure(
        _pool_chains, draws, self.model.event_shape)

    onp.testing.assert_allclose(
        reference.mean(axis=0), draws.mean(axis=0), rtol=0.1, atol=0.1)
    onp.testing.assert_allclose(
        np.cov(reference.T), np.cov(draws.T), rtol=0.1, atol=0.1)

    mean_accept = metrics['kernel']['accept_prob'].mean()
    onp.testing.assert_allclose(
        target_accept_rate, mean_accept, atol=1e-2, rtol=1e-2)
def test_single_chain(self, make_kernel, target_accept_rate):
    """Check moments and acceptance rate for one long harvested chain.

    Runs a single 20000-step chain through a jitted
    ``kernels.sample_chain`` wrapped in ``harvest.harvest`` (tag
    ``kernels.MCMC_METRICS``). The mean and covariance of its draws are
    compared against 4096 draws taken directly from ``self.model``, and the
    mean of ``metrics['kernel']['accept_prob']`` is compared with
    ``target_accept_rate``.

    NOTE(review): another ``test_single_chain`` definition is visible in
    this file; if both live in the same class the later one shadows the
    earlier -- confirm.

    Args:
        make_kernel: Callable that receives the unconstrained log-prob
            function and returns the kernel given to
            ``kernels.sample_chain``.
        target_accept_rate: Value the mean harvested ``accept_prob`` must
            match within ``atol=1e-2, rtol=1e-2``.
    """
    n_steps = 20000
    sample_key, chain_key, init_key = random.split(self._seed, 3)

    log_prob = self._make_unconstrained_log_prob()
    start_state = self._initialize_state(init_key)
    kernel = make_kernel(log_prob)

    # Wrap the chain so the tagged metrics come back next to the samples.
    harvested_chain = harvest.harvest(
        kernels.sample_chain(kernel, n_steps), tag=kernels.MCMC_METRICS)
    run_chain = jax.jit(harvested_chain)

    reference = self.model.sample(sample_shape=4096, seed=sample_key)
    # NOTE(review): the leading empty dict is the first positional argument
    # of the harvested function -- presumably "nothing to plant"; confirm
    # against the harvest API.
    draws, metrics = run_chain({}, chain_key, start_state)

    onp.testing.assert_allclose(
        reference.mean(axis=0), draws.mean(axis=0), rtol=0.5, atol=0.1)
    onp.testing.assert_allclose(
        np.cov(reference.T), np.cov(draws.T), rtol=0.5, atol=0.1)

    mean_accept = metrics['kernel']['accept_prob'].mean()
    onp.testing.assert_allclose(
        target_accept_rate, mean_accept, atol=1e-2, rtol=1e-2)
def test_single_chain(self, make_kernel, target_accept_rate):
    """Check that one long chain reproduces the model's mean and covariance.

    Runs a single 20000-step chain through a jitted
    ``kernels.sample_chain`` and compares the mean and covariance of its
    draws against 4096 draws taken directly from ``self.model``. No metrics
    are collected in this variant.

    NOTE(review): another ``test_single_chain`` definition is visible in
    this file; if both live in the same class the later one shadows the
    earlier -- confirm.

    Args:
        make_kernel: Callable that receives the unconstrained log-prob
            function and returns the kernel given to
            ``kernels.sample_chain``.
        target_accept_rate: Not used by this test body; kept as part of the
            signature.
    """
    n_steps = 20000
    sample_key, chain_key, init_key = random.split(self._seed, 3)

    log_prob = self._make_unconstrained_log_prob()
    start_state = self._initialize_state(init_key)
    kernel = make_kernel(log_prob)

    chain_fn = kernels.sample_chain(kernel, n_steps)
    run_chain = jax.jit(chain_fn)

    reference = self.model.sample(sample_shape=4096, seed=sample_key)
    draws = run_chain(chain_key, start_state)

    onp.testing.assert_allclose(
        reference.mean(axis=0), draws.mean(axis=0), rtol=0.5, atol=0.1)
    onp.testing.assert_allclose(
        np.cov(reference.T), np.cov(draws.T), rtol=0.5, atol=0.1)