コード例 #1
0
    def test_multiple_chains(self, make_kernel, target_accept_rate):
        """Checks a kernel's moments and acceptance rate over parallel chains.

        Runs `num_chains` chains in parallel via `jax.vmap` and pools their
        draws. It then compares the pooled mean and covariance against draws
        from `self.model.sample`. It also checks that the mean harvested
        `accept_prob` metric is close to `target_accept_rate`.

        Args:
          make_kernel: Callable that builds an MCMC kernel from the
            unconstrained log-prob function (presumably supplied by test
            parameterization).
          target_accept_rate: Expected mean acceptance probability of the
            kernel built by `make_kernel`.
        """
        num_chains = 16
        num_samples = 4000
        # Draws taken from the model's own sampler to act as the reference.
        num_reference_samples = 4096

        sample_key, chain_key, init_key = random.split(self._seed, 3)
        unconstrained_log_prob = self._make_unconstrained_log_prob()
        # One initial state per chain, each initialized from its own key.
        initial_states = jax.vmap(self._initialize_state)(
            random.split(init_key, num_chains))

        kernel = make_kernel(unconstrained_log_prob)
        # `harvest` collects values tagged `MCMC_METRICS` (read back below as
        # metrics['kernel']['accept_prob']); `vmap` maps over per-chain keys
        # and initial states.
        sample_chain = jax.jit(
            jax.vmap(
                harvest.harvest(kernels.sample_chain(kernel, num_samples),
                                tag=kernels.MCMC_METRICS)))

        true_samples = self.model.sample(sample_shape=num_reference_samples,
                                         seed=sample_key)
        # NOTE(review): the leading `{}` is presumably the harvest "plant"
        # dict (nothing injected) — confirm against the harvest API.
        samples, metrics = sample_chain({},
                                        random.split(chain_key, num_chains),
                                        initial_states)
        # Merge the chain and sample axes into one leading axis so the pooled
        # draws line up with `true_samples`.
        samples = tf.nest.map_structure(
            lambda s, shape: s.reshape([num_chains * num_samples] +
                                       list(shape)),
            samples, self.model.event_shape)

        # `assert_allclose(actual, desired, ...)`: the reference value goes in
        # `desired` so that `rtol` scales with it and failure messages label
        # the MCMC estimate as ACTUAL.
        onp.testing.assert_allclose(samples.mean(axis=0),
                                    true_samples.mean(axis=0),
                                    rtol=0.1,
                                    atol=0.1)
        onp.testing.assert_allclose(np.cov(samples.T),
                                    np.cov(true_samples.T),
                                    rtol=0.1,
                                    atol=0.1)
        onp.testing.assert_allclose(metrics['kernel']['accept_prob'].mean(),
                                    target_accept_rate,
                                    atol=1e-2,
                                    rtol=1e-2)
コード例 #2
0
    def test_single_chain(self, make_kernel, target_accept_rate):
        """Checks a kernel's moments and acceptance rate on one long chain.

        Runs a single chain of `num_samples` draws and compares its mean and
        covariance against draws from `self.model.sample`. It also checks that
        the mean harvested `accept_prob` metric is close to
        `target_accept_rate`.

        Args:
          make_kernel: Callable that builds an MCMC kernel from the
            unconstrained log-prob function (presumably supplied by test
            parameterization).
          target_accept_rate: Expected mean acceptance probability of the
            kernel built by `make_kernel`.
        """
        num_samples = 20000
        # Draws taken from the model's own sampler to act as the reference.
        num_reference_samples = 4096

        sample_key, chain_key, init_key = random.split(self._seed, 3)
        unconstrained_log_prob = self._make_unconstrained_log_prob()
        initial_state = self._initialize_state(init_key)
        kernel = make_kernel(unconstrained_log_prob)
        # `harvest` collects values tagged `MCMC_METRICS` (read back below as
        # metrics['kernel']['accept_prob']) alongside the samples.
        sample_chain = jax.jit(
            harvest.harvest(kernels.sample_chain(kernel, num_samples),
                            tag=kernels.MCMC_METRICS))

        true_samples = self.model.sample(sample_shape=num_reference_samples,
                                         seed=sample_key)
        # NOTE(review): the leading `{}` is presumably the harvest "plant"
        # dict (nothing injected) — confirm against the harvest API.
        samples, metrics = sample_chain({}, chain_key, initial_state)

        # `assert_allclose(actual, desired, ...)`: the reference value goes in
        # `desired` so that `rtol` scales with it and failure messages label
        # the MCMC estimate as ACTUAL. Tolerances are looser than in the
        # multi-chain test.
        onp.testing.assert_allclose(samples.mean(axis=0),
                                    true_samples.mean(axis=0),
                                    rtol=0.5,
                                    atol=0.1)
        onp.testing.assert_allclose(np.cov(samples.T),
                                    np.cov(true_samples.T),
                                    rtol=0.5,
                                    atol=0.1)
        onp.testing.assert_allclose(metrics['kernel']['accept_prob'].mean(),
                                    target_accept_rate,
                                    atol=1e-2,
                                    rtol=1e-2)
コード例 #3
0
ファイル: kernels_test.py プロジェクト: mederrata/probability
    def test_single_chain(self, make_kernel, target_accept_rate):
        """Checks a kernel's sample moments on one long chain.

        Runs a single chain of `num_samples` draws and compares its mean and
        covariance against draws from `self.model.sample`. This version of
        the test does not check the acceptance rate.

        Args:
          make_kernel: Callable that builds an MCMC kernel from the
            unconstrained log-prob function (presumably supplied by test
            parameterization).
          target_accept_rate: Accepted for signature compatibility with the
            parameterized test cases; not used in this body.
        """
        num_samples = 20000
        # Draws taken from the model's own sampler to act as the reference.
        num_reference_samples = 4096

        sample_key, chain_key, init_key = random.split(self._seed, 3)
        unconstrained_log_prob = self._make_unconstrained_log_prob()
        initial_state = self._initialize_state(init_key)
        kernel = make_kernel(unconstrained_log_prob)
        sample_chain = jax.jit(kernels.sample_chain(kernel, num_samples))

        true_samples = self.model.sample(sample_shape=num_reference_samples,
                                         seed=sample_key)
        samples = sample_chain(chain_key, initial_state)

        # `assert_allclose(actual, desired, ...)`: the reference value goes in
        # `desired` so that `rtol` scales with it and failure messages label
        # the MCMC estimate as ACTUAL.
        onp.testing.assert_allclose(samples.mean(axis=0),
                                    true_samples.mean(axis=0),
                                    rtol=0.5,
                                    atol=0.1)
        onp.testing.assert_allclose(np.cov(samples.T),
                                    np.cov(true_samples.T),
                                    rtol=0.5,
                                    atol=0.1)