def test_linear_regression():
    """Sample a univariate linear regression with HMC and check posterior means."""
    predictors = np.random.normal(0, 5, size=1000).reshape(-1, 1)
    targets = 3 * predictors + np.random.normal(size=predictors.shape)

    kernel = HMC(
        step_size=0.001,
        num_integration_steps=90,
        inverse_mass_matrix=jnp.array([1.0, 1.0]),
    )

    observations = {"x": predictors, "predictions": targets}
    rng_key = jax.random.PRNGKey(2)

    # Batch sampler
    sampler = mcx.sampler(
        rng_key,
        linear_regression,
        kernel,
        num_chains=2,
        **observations,
    )
    trace = sampler.run(num_samples=3000)

    # Drop the first 1000 draws per chain as warmup before averaging.
    coeff_draws = trace.raw.samples["coeffs"][:, 1000:]
    scale_draws = trace.raw.samples["sigma"][:, 1000:]
    mean_coeffs = np.asarray(jnp.mean(coeff_draws, axis=1))
    mean_scale = np.asarray(jnp.mean(scale_draws, axis=1))

    assert mean_coeffs == pytest.approx(3, 1e-1)
    assert mean_scale == pytest.approx(1, 1e-1)
def _fit(self, kernel, num_samples=1000, accelerate=True, **observations):
    """Fit the model by sampling from the posterior with the given kernel.

    While it is impossible to provide a universal fitting mechanism, some
    are certainly better than others.

    Parameters
    ----------
    kernel
        The MCMC transition kernel used to sample from the posterior.
    num_samples
        Number of posterior samples to draw per chain.
    accelerate
        Passed through to ``sampler.run`` to toggle the accelerated path.
    **observations
        Observed data forwarded to ``mcx.sampler``.

    Returns
    -------
    The trace of posterior samples; also stored on ``self.trace``.
    """
    # Split the key so repeated fits use fresh, independent randomness.
    _, self.rng_key = jax.random.split(self.rng_key)
    sampler = mcx.sampler(
        self.rng_key,
        self.model,
        kernel,
        **observations,
    )
    # Bug fix: honor the `num_samples` argument instead of a hard-coded 1000.
    trace = sampler.run(num_samples, accelerate)
    self.sampler = sampler
    self.trace = trace
    return trace
def test_linear_regression_mvn():
    """Smoke-test HMC sampling on a multivariate-normal linear regression."""
    # We only check that we can sample, but the results are not checked.
    covariance = [[1.0, 0.4], [0.4, 1.0]]
    design = np.random.multivariate_normal([0, 1], covariance, size=1000)
    noise = np.random.normal(size=design.shape[0])
    responses = design @ np.array([3, 1]) + noise

    kernel = HMC(num_integration_steps=90, )
    rng_key = jax.random.PRNGKey(2)

    # Batch sampler
    sampler = mcx.sampler(
        rng_key,
        linear_regression_mvn,
        (design, ),
        {"predictions": responses},
        kernel,
        num_chains=2,
    )
    trace = sampler.run(num_samples=3000)