def test_sample_prior_predictive_without_intermediates(self):
    def model(N, d):
        x = sample("x", self.DistWithIntermediate(), sample_shape=(N, d))

    N, d = 100, 2
    rng_key = jax.random.PRNGKey(3781)
    samples = sample_prior_predictive(rng_key, model, (N, d))
    self.assertEqual((N, d), jnp.shape(samples['x']))
def create_toy_data(rng_key, N, d):
    ## Create some toy data
    mu_true = jnp.ones(d)
    samples = sample_prior_predictive(rng_key, model, (None, 2 * N, d), {'mu': mu_true})
    X = samples['obs']

    X_train = X[:N]
    X_test = X[N:]
    return X_train, X_test, mu_true
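# The helper above references a `model` that is not shown in this excerpt.
# A minimal sketch of a model consistent with its call -- an argument tuple
# of (None, 2 * N, d) and the site names 'mu' and 'obs' -- could look as
# follows; the argument names, priors, and noise scale are assumptions for
# illustration, not taken from the source:
def model(x, N, d):
    mu = sample("mu", dist.Normal(jnp.zeros(d), 1.))
    # x is None when drawing from the prior predictive, so 'obs' is sampled
    sample("obs", dist.Normal(mu, 1.), obs=x, sample_shape=(N,))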
def test_sample_prior_predictive_with_intermediates(self):
    def model(N, d):
        x = sample("x", self.DistWithIntermediate(), sample_shape=(N, d))

    N, d = 100, 2
    rng_key = jax.random.PRNGKey(2)
    samples = sample_prior_predictive(rng_key, model, (N, d), with_intermediates=True)
    self.assertEqual((N, d), jnp.shape(samples['x'][0]))
    self.assertEqual(2, len(samples['x']))
    self.assertEqual(1, len(samples['x'][1]))
    self.assertEqual((N, d, 2), jnp.shape(samples['x'][1][0]))
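# `self.DistWithIntermediate` is defined outside this excerpt, presumably on
# the test class. A minimal sketch consistent with the assertions above --
# one intermediate per sample site, with trailing dimension 2 -- might look
# like this; the internals are assumptions for illustration, not the source
# definition:
class DistWithIntermediate(dist.Distribution):
    support = dist.constraints.real

    def sample_with_intermediates(self, key, sample_shape=()):
        # draw an intermediate of shape sample_shape + (2,) and reduce it
        z = jax.random.normal(key, sample_shape + (2,))
        return jnp.sum(z, axis=-1), [z]

    def sample(self, key, sample_shape=()):
        return self.sample_with_intermediates(key, sample_shape)[0]

    def log_prob(self, value):
        # the sum of two standard normals has standard deviation sqrt(2)
        return dist.Normal(0., jnp.sqrt(2.)).log_prob(value)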
def test_sample_prior_predictive(self):
    def model(N, d):
        mu = sample("mu", dist.Normal(jnp.zeros(d)))
        x = sample("x", dist.Normal(mu), sample_shape=(N,))

    N, d = 100, 2
    rng_key = jax.random.PRNGKey(1836)
    samples = sample_prior_predictive(rng_key, model, (N, d))
    self.assertEqual((d,), jnp.shape(samples['mu']))
    self.assertEqual((N, d), jnp.shape(samples['x']))
    # crude check that samples come from the model distribution
    # (sample mean within 3 standard deviations)
    self.assertTrue(jnp.allclose(jnp.mean(samples['x'], axis=0), samples['mu'], atol=3 / jnp.sqrt(N)))
    self.assertTrue(jnp.allclose(samples['mu'], 0., atol=3.))
def test_sample_prior_predictive_with_substitute(self):
    def model(N, d):
        mu = sample("mu", dist.Normal(jnp.zeros(d)))
        x = sample("x", dist.Normal(mu), sample_shape=(N,))

    N, d = 100, 2
    mu_fixed = jnp.array([1., -.5])
    rng_key = jax.random.PRNGKey(235)
    samples = sample_prior_predictive(rng_key, model, (N, d), substitutes={'mu': mu_fixed})
    self.assertEqual((d,), jnp.shape(samples['mu']))
    self.assertEqual((N, d), jnp.shape(samples['x']))
    # crude check that samples come from the model distribution
    # (sample mean within 3 standard deviations)
    self.assertTrue(jnp.allclose(jnp.mean(samples['x'], axis=0), mu_fixed, atol=3 / jnp.sqrt(N)))
    self.assertTrue(jnp.allclose(samples['mu'], mu_fixed))
def create_toy_data(rng_key, N, d):
    ## Create some toy data
    X_rng_key, prior_pred_rng_key = jax.random.split(rng_key)
    X = jax.random.normal(X_rng_key, shape=(2 * N, d))

    sampled_data = sample_prior_predictive(prior_pred_rng_key, model, (X,))
    y = sampled_data['obs']
    w_true = sampled_data['w']
    intercept_true = sampled_data['intercept']

    X_train = X[:N]
    y_train = y[:N]
    X_test = X[N:]
    y_test = y[N:]
    return (X_train, y_train), (X_test, y_test), (w_true, intercept_true)
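# As above, `model` is defined elsewhere. A linear-regression sketch that
# matches the single-argument call and the 'w', 'intercept', and 'obs' site
# names read out above (priors and noise scale are assumptions):
def model(X):
    d = jnp.shape(X)[1]
    w = sample("w", dist.Normal(jnp.zeros(d), 1.))
    intercept = sample("intercept", dist.Normal(0., 1.))
    # one observation per row of X; shape (2 * N,) under the call above
    sample("obs", dist.Normal(X @ w + intercept, 1.))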
def create_toy_data(rng_key, N, d): """Creates some toy data (for training and testing)""" # To spice things up, it is imbalanced: # The last component has twice as many samples as the others. mus = jnp.array([-10. * jnp.ones(d), 10. * jnp.ones(d), -2. * jnp.ones(d)]) sigs = jnp.reshape(jnp.array([0.1, 1., 0.1]), (3,1)) pis = jnp.array([1/4, 1/4, 2/4]) samples = sample_prior_predictive(rng_key, model, (3, None, 2*N, d), substitutes={ 'pis': pis, 'mus': mus, 'sigs': sigs }, with_intermediates=True) X = samples['obs'][0] z = samples['obs'][1][0] z_train = z[:N] X_train = X[:N] z_test = z[N:] X_test = X[N:] latent_vals = (z_train, z_test, mus, sigs) return X_train, X_test, latent_vals