Exemple #1
0
    def test_sample_prior_predictive_without_intermediates(self):
        """Check that site ``x`` comes back with shape ``(N, d)``.

        The model draws a single site from ``self.DistWithIntermediate()``;
        ``sample_prior_predictive`` is called without ``with_intermediates``,
        and the entry for ``x`` is asserted to have the sample shape directly.
        """
        num_samples, dim = 100, 2

        def model(N, d):
            # The returned value is not needed; only the recorded site matters.
            sample("x", self.DistWithIntermediate(), sample_shape=(N, d))

        prior_samples = sample_prior_predictive(
            jax.random.PRNGKey(3781), model, (num_samples, dim)
        )

        self.assertEqual((num_samples, dim), jnp.shape(prior_samples['x']))
Exemple #2
0
def create_toy_data(rng_key, N, d):
    """Create toy data and split it into a training and a test half.

    Calls ``sample_prior_predictive`` on the module-level ``model`` with model
    arguments ``(None, 2 * N, d)`` and ``{'mu': ones(d)}`` as the fourth
    positional argument (presumably the substitutes mapping - confirm against
    the callee's signature), then reads the ``'obs'`` site.

    Returns:
        A tuple ``(X_train, X_test, mu_true)`` where ``X_train`` is the first
        ``N`` entries of the ``'obs'`` samples, ``X_test`` is the remainder and
        ``mu_true`` is the all-ones vector of length ``d``.
    """
    true_mean = jnp.ones(d)
    model_args = (None, 2 * N, d)

    prior_draws = sample_prior_predictive(
        rng_key, model, model_args, {'mu': true_mean}
    )
    observations = prior_draws['obs']

    # First N entries are used for training, everything after for testing.
    return observations[:N], observations[N:], true_mean
Exemple #3
0
    def test_sample_prior_predictive_with_intermediates(self):
        """Check the structure of site ``x`` when intermediates are requested.

        With ``with_intermediates=True`` the entry for ``x`` is asserted to be
        a two-element container: element 0 has shape ``(N, d)`` and element 1
        is a one-element container whose item has shape ``(N, d, 2)``.
        """
        num_samples, dim = 100, 2

        def model(N, d):
            # The returned value is not needed; only the recorded site matters.
            sample("x", self.DistWithIntermediate(), sample_shape=(N, d))

        prior_samples = sample_prior_predictive(
            jax.random.PRNGKey(2),
            model,
            (num_samples, dim),
            with_intermediates=True,
        )
        site = prior_samples['x']

        self.assertEqual((num_samples, dim), jnp.shape(site[0]))
        self.assertEqual(2, len(site))
        self.assertEqual(1, len(site[1]))
        self.assertEqual((num_samples, dim, 2), jnp.shape(site[1][0]))
Exemple #4
0
    def test_sample_prior_predictive(self):
        """Check shapes and rough plausibility of prior predictive samples.

        The model draws ``mu`` from ``Normal(zeros(d))`` and ``N`` values of
        ``x`` from ``Normal(mu)``. Besides the shapes, the test crudely checks
        that the mean of ``x`` is within ``3 / sqrt(N)`` of the sampled ``mu``
        and that ``mu`` itself is within 3 of zero.
        """
        num_samples, dim = 100, 2

        def model(N, d):
            mu = sample("mu", dist.Normal(jnp.zeros(d)))
            sample("x", dist.Normal(mu), sample_shape=(N,))

        prior_samples = sample_prior_predictive(
            jax.random.PRNGKey(1836), model, (num_samples, dim)
        )

        mu_draw = prior_samples['mu']
        self.assertEqual((dim,), jnp.shape(mu_draw))
        x_draws = prior_samples['x']
        self.assertEqual((num_samples, dim), jnp.shape(x_draws))

        # Crude distribution check: the empirical mean must lie within three
        # times the standard deviation of the mean estimator.
        tolerance = 3 / jnp.sqrt(num_samples)
        empirical_mean = jnp.mean(x_draws, axis=0)
        self.assertTrue(jnp.allclose(empirical_mean, mu_draw, atol=tolerance))
        self.assertTrue(jnp.allclose(mu_draw, 0, atol=3.))
Exemple #5
0
    def test_sample_prior_predictive_with_substitute(self):
        """Check that a substituted ``mu`` is used and reported unchanged.

        ``mu`` is fixed to ``[1., -.5]`` via ``substitutes``. The test asserts
        the sample shapes, that the mean of ``x`` is within ``3 / sqrt(N)`` of
        the fixed value, and that the returned ``mu`` equals the fixed value.
        """
        num_samples, dim = 100, 2
        mu_fixed = jnp.array([1., -.5])

        def model(N, d):
            mu = sample("mu", dist.Normal(jnp.zeros(d)))
            sample("x", dist.Normal(mu), sample_shape=(N,))

        prior_samples = sample_prior_predictive(
            jax.random.PRNGKey(235),
            model,
            (num_samples, dim),
            substitutes={'mu': mu_fixed},
        )

        mu_draw = prior_samples['mu']
        self.assertEqual((dim,), jnp.shape(mu_draw))
        x_draws = prior_samples['x']
        self.assertEqual((num_samples, dim), jnp.shape(x_draws))

        # Crude distribution check: the empirical mean must lie within three
        # times the standard deviation of the mean estimator.
        tolerance = 3 / jnp.sqrt(num_samples)
        empirical_mean = jnp.mean(x_draws, axis=0)
        self.assertTrue(jnp.allclose(empirical_mean, mu_fixed, atol=tolerance))
        self.assertTrue(jnp.allclose(mu_draw, mu_fixed))
Exemple #6
0
def create_toy_data(rng_key, N, d):
    """Create toy inputs and outputs, split into training and test halves.

    The key is split in two: one half draws ``X`` of shape ``(2 * N, d)`` from
    a standard normal, the other is passed to ``sample_prior_predictive`` on
    the module-level ``model`` with ``X`` as its only argument. The ``'obs'``,
    ``'w'`` and ``'intercept'`` sites are read from the result.

    Returns:
        A tuple ``((X_train, y_train), (X_test, y_test), (w_true,
        intercept_true))``; the training parts are the first ``N`` entries and
        the test parts are everything after.
    """
    feature_key, predictive_key = jax.random.split(rng_key)
    features = jax.random.normal(feature_key, shape=(2 * N, d))

    prior_draws = sample_prior_predictive(predictive_key, model, (features, ))
    targets = prior_draws['obs']
    true_params = (prior_draws['w'], prior_draws['intercept'])

    train_set = (features[:N], targets[:N])
    test_set = (features[N:], targets[N:])
    return train_set, test_set, true_params
def create_toy_data(rng_key, N, d):
    """Create imbalanced three-component toy data for training and testing.

    Fixes ``'pis'``, ``'mus'`` and ``'sigs'`` via ``substitutes`` and calls
    ``sample_prior_predictive`` on the module-level ``model`` with arguments
    ``(3, None, 2 * N, d)`` and ``with_intermediates=True``. From the ``'obs'``
    site, element 0 is taken as ``X`` and the first item of element 1 as ``z``
    (presumably the per-sample component assignments - confirm against the
    model).

    Returns:
        A tuple ``(X_train, X_test, latent_vals)`` with
        ``latent_vals = (z_train, z_test, mus, sigs)``; training parts are the
        first ``N`` entries, test parts everything after.
    """
    # Component weights: the last component gets twice the weight of each of
    # the other two, which makes the data imbalanced.
    weights = jnp.array([0.25, 0.25, 0.5])
    means = jnp.array([-10. * jnp.ones(d), 10. * jnp.ones(d), -2. * jnp.ones(d)])
    # One scale per component, kept as a (3, 1) column.
    scales = jnp.array([0.1, 1., 0.1]).reshape((3, 1))

    fixed_sites = {'pis': weights, 'mus': means, 'sigs': scales}
    prior_draws = sample_prior_predictive(
        rng_key,
        model,
        (3, None, 2 * N, d),
        substitutes=fixed_sites,
        with_intermediates=True,
    )

    observations = prior_draws['obs'][0]
    assignments = prior_draws['obs'][1][0]

    latent_vals = (assignments[:N], assignments[N:], means, scales)
    return observations[:N], observations[N:], latent_vals