Example #1
0
def test_linear_regression():
    """Check that HMC recovers the parameters of a simple 1-D linear model.

    Synthetic data are drawn from ``y = 3 * x + eps`` with ``eps ~ N(0, 1)``.
    After discarding the first 1000 draws of each chain, the mean sampled
    ``coeffs`` must be within 10% (relative) of 3 and the mean sampled
    ``sigma`` within 10% of 1.
    """
    # Use a dedicated, seeded generator so the synthetic data (and therefore
    # the test outcome) are reproducible; the JAX key below is already fixed,
    # so this makes the whole test deterministic. It also avoids touching
    # NumPy's global random state.
    np_rng = np.random.default_rng(0)
    x_data = np_rng.normal(0, 5, size=1000).reshape(-1, 1)
    y_data = 3 * x_data + np_rng.normal(size=x_data.shape)

    kernel = HMC(
        step_size=0.001,
        num_integration_steps=90,
        inverse_mass_matrix=jnp.array([1.0, 1.0]),
    )

    observations = {"x": x_data, "predictions": y_data}
    rng_key = jax.random.PRNGKey(2)

    # Batch sampler running two chains.
    sampler = mcx.sampler(
        rng_key,
        linear_regression,
        kernel,
        num_chains=2,
        **observations,
    )
    trace = sampler.run(num_samples=3000)

    # Discard the first 1000 draws as burn-in and average the rest.
    # NOTE(review): assumes samples are laid out as (chain, draw, ...), so
    # axis 1 is the draw axis and one mean is produced per chain — confirm.
    mean_coeffs = np.asarray(
        jnp.mean(trace.raw.samples["coeffs"][:, 1000:], axis=1))
    mean_scale = np.asarray(
        jnp.mean(trace.raw.samples["sigma"][:, 1000:], axis=1))
    assert mean_coeffs == pytest.approx(3, rel=1e-1)
    assert mean_scale == pytest.approx(1, rel=1e-1)
Example #2
0
    def _fit(self, kernel, num_samples=1000, accelerate=True, **observations):
        """Sample from ``self.model`` with ``kernel`` and store the result.

        While it is impossible to provide a universal fitting mechanism, some
        are certainly better than others.

        Args:
            kernel: Sampling kernel forwarded to ``mcx.sampler``.
            num_samples: Number of samples to draw; forwarded to
                ``sampler.run``.
            accelerate: Forwarded unchanged as the second positional argument
                of ``sampler.run``.
            **observations: Observed data forwarded to ``mcx.sampler``.

        Returns:
            The trace returned by ``sampler.run``. It is also stored on
            ``self.trace``, and the sampler is stored on ``self.sampler``.
        """
        # Advance the stored key and hand a *different* subkey to the
        # sampler. The key consumed by the sampler is never split again on a
        # later call, which avoids reusing (and correlating) JAX random keys.
        self.rng_key, sampler_key = jax.random.split(self.rng_key)
        sampler = mcx.sampler(
            sampler_key,
            self.model,
            kernel,
            **observations,
        )
        # Honour the caller's ``num_samples``. It was previously ignored in
        # favour of a hard-coded 1000, which is the default value.
        trace = sampler.run(num_samples, accelerate)

        self.sampler = sampler
        self.trace = trace

        return trace
Example #3
0
def test_linear_regression_mvn():
    """Smoke test: sample a two-feature linear regression model.

    We only check that we can sample; the results are not checked.
    """
    # Dedicated, seeded generator so the synthetic data are reproducible
    # across runs (the JAX key below is already fixed) and NumPy's global
    # random state is left untouched.
    np_rng = np.random.default_rng(0)
    x_data = np_rng.multivariate_normal(
        [0, 1], [[1.0, 0.4], [0.4, 1.0]], size=1000
    )
    y_data = x_data @ np.array([3, 1]) + np_rng.normal(size=x_data.shape[0])

    kernel = HMC(num_integration_steps=90)

    rng_key = jax.random.PRNGKey(2)

    # Batch sampler running two chains; model arguments are passed as a
    # tuple and observations as a dict.
    sampler = mcx.sampler(
        rng_key,
        linear_regression_mvn,
        (x_data, ),
        {"predictions": y_data},
        kernel,
        num_chains=2,
    )
    trace = sampler.run(num_samples=3000)