Beispiel #1
0
    def test_minibatch_num_total_obs_not_given(self):
        """Without ``num_obs_total`` the minibatch context applies scale 1."""
        handler = minibatch(20)
        self.assertAlmostEqual(1., handler.scale)
Beispiel #2
0
    def test_minibatch_scale_correct_over_single_sample(self):
        """A single observation out of 100 is scaled up by 100 / 1."""
        size, total = 1, 100
        handler = minibatch(size, num_obs_total=total)
        self.assertAlmostEqual(total / size, handler.scale)
Beispiel #3
0
    def test_minibatch_scale_correct_for_true_minibatch(self):
        """A batch of 10 out of 100 observations is scaled by 100 / 10."""
        size, total = 10, 100
        handler = minibatch(size, num_obs_total=total)
        self.assertAlmostEqual(total / size, handler.scale)
Beispiel #4
0
    def test_minibatch_batch_size_deduced_from_array_and_num_total_obs_not_given(
            self):
        """An array argument without ``num_obs_total`` yields scale 1."""
        data = jnp.ones((20, 3))
        handler = minibatch(data)
        self.assertAlmostEqual(1., handler.scale)
Beispiel #5
0
    def test_minibatch_batch_size_deduced_from_array(self):
        """Batch size is read from the leading axis of the array argument.

        The array is built from ``batch_size`` (not a repeated literal) so the
        expected scale and the data shape cannot drift apart.
        """
        batch_size = 20
        num_obs_total = 100
        expected_scale = num_obs_total / batch_size

        X = jnp.ones((batch_size, 3))
        result = minibatch(X, num_obs_total=num_obs_total)

        self.assertAlmostEqual(expected_scale, result.scale)
Beispiel #6
0
        def model_fn(X, N=None, num_obs_total=None):
            """Toy model: Normal prior on 'theta', observed 'X' inside minibatch.

            ``N`` defaults to the leading dimension of ``X`` and
            ``num_obs_total`` defaults to ``N`` (i.e. no rescaling).
            Returns the 'X' site value and the sampled mean.
            """
            batch_len = jnp.shape(X)[0] if N is None else N
            total = batch_len if num_obs_total is None else num_obs_total

            mu = sample("theta", dist.Normal(1.))
            with minibatch(batch_len, num_obs_total=total):
                X = sample("X", dist.Normal(mu), obs=X,
                           sample_shape=(batch_len, ))
            return X, mu
Beispiel #7
0
def model(batch_X, batch_y=None, num_obs_total=None):
    """Defines the generative probabilistic model: p(y|z,X)p(z)

    Logistic regression with standard-normal priors on the weights and the
    intercept. The model is conditioned on the observed data when given.

    :param batch_X: a batch of predictors; must be 2-D (examples x features)
    :param batch_y: a batch of observations, one per row of ``batch_X``, or
        None to leave the 'obs' site unobserved
    :param num_obs_total: total number of observations in the full data set,
        passed to ``minibatch`` to rescale the likelihood; None means the
        batch is treated as the whole data set
    :return: the value of the 'obs' sample site
    """
    assert jnp.ndim(batch_X) == 2
    batch_size, d = jnp.shape(batch_X)
    assert batch_y is None or example_count(batch_y) == batch_size

    weights = sample('w', dist.Normal(jnp.zeros((d,)), jnp.ones((d,))))  # N(0,I)
    intercept = sample('intercept', dist.Normal(0, 1))  # N(0,1)
    logits = batch_X.dot(weights) + intercept

    with minibatch(batch_size, num_obs_total=num_obs_total):
        return sample('obs', dist.Bernoulli(logits=logits), obs=batch_y)
Beispiel #8
0
def guide(batch, z_dim, hidden_dim, out_dim=None, num_obs_total=None):
    """Defines the variational guide q(z|x), approximating the posterior p(z|x).

    :param batch: a batch of observations; must be 3-D, and each item is
        flattened to a one-dimensional vector before encoding
    :param z_dim: dimensions of the latent variable / code
    :param hidden_dim: dimensions of the hidden layers in the encoder
    :param out_dim: ignored; the flattened item size is always recomputed from
        ``batch`` (presumably accepted so the signature mirrors the model's)
    :param num_obs_total: total number of observations in the full data set,
        passed to ``minibatch``; None means no rescaling
    :return: (named) sampled z from the variational (guide) distribution q(z|x)
    """
    assert (jnp.ndim(batch) == 3)
    batch_size = jnp.shape(batch)[0]
    # squash each data item into a one-dimensional array (preserving only the
    # batch size on the first axis)
    batch = jnp.reshape(batch, (batch_size, -1))
    out_dim = jnp.shape(batch)[1]

    encode = numpyro.module('encoder', encoder(hidden_dim, z_dim),
                            (batch_size, out_dim))
    with minibatch(batch_size, num_obs_total=num_obs_total):
        # The encoder's second output is used directly as the Normal scale,
        # i.e. a standard deviation, not a variance.
        z_loc, z_std = encode(batch)
        z = sample('z', dist.Normal(z_loc, z_std))  # z ~ q(z|x)
        return z
Beispiel #9
0
def model(obs=None, num_obs_total=None, d=None):
    """Defines the generative probabilistic model: p(x|z)p(z)

    Observations are Normal around a latent mean ``mu`` with a N(0, 1) prior.

    :param obs: optional 2-D array of observations (rows are examples); when
        given, the batch size and ``d`` are taken from its shape
    :param num_obs_total: total number of observations in the full data set,
        passed to ``minibatch``; required when ``obs`` is None, in which case
        it is also used as the batch size
    :param d: dimension of a single observation; required when ``obs`` is
        None, overridden by the shape of ``obs`` otherwise
    :return: the value of the 'obs' site, of shape (batch_size, d)
    """
    if obs is not None:
        assert (jnp.ndim(obs) == 2)
        batch_size, d = jnp.shape(obs)
    else:
        assert (num_obs_total is not None)
        batch_size = num_obs_total
        assert (d is not None)

    z_mu = sample('mu', dist.Normal(jnp.zeros((d, )), 1.))
    # passed as the Normal scale, so this is a standard deviation
    x_std = .1
    with minibatch(batch_size, num_obs_total=num_obs_total):
        x = sample('obs',
                   dist.Normal(z_mu, x_std),
                   obs=obs,
                   sample_shape=(batch_size, ))
    return x
Beispiel #10
0
def model(batch_or_batchsize,
          z_dim,
          hidden_dim,
          out_dim=None,
          num_obs_total=None):
    """Defines the generative probabilistic model: p(x|z)p(z)

    When a batch is given, the model is conditioned on it.

    :param batch_or_batchsize: either a 3-D batch of observations (each item
        is flattened to one dimension) or an integer batch size, in which case
        nothing is observed
    :param z_dim: dimensions of the latent variable / code
    :param hidden_dim: dimensions of the hidden layers in the VAE
    :param out_dim: number of dimensions in a single output sample (flattened);
        required when only a batch size is given, recomputed from the batch
        otherwise
    :param num_obs_total: total number of observations in the full data set,
        passed to ``minibatch``; None means no rescaling
    :raises ValueError: if only a batch size is given and ``out_dim`` is None
    :return: (named) sample x from the model observation distribution
        p(x|z)p(z), with one row per batch item
    """
    if is_int_scalar(batch_or_batchsize):
        batch = None
        batch_size = batch_or_batchsize
        if out_dim is None:
            raise ValueError("if no batch is provided, out_dim must be given")
    else:
        batch = batch_or_batchsize
        assert (jnp.ndim(batch) == 3)
        batch_size = jnp.shape(batch)[0]
        # squash each data item into a one-dimensional array (preserving only
        # the batch size on the first axis)
        batch = jnp.reshape(batch, (batch_size, -1))
        out_dim = jnp.shape(batch)[1]

    # the decoder is initialised for inputs of shape (batch_size, z_dim)
    decode = numpyro.module('decoder', decoder(hidden_dim, out_dim),
                            (batch_size, z_dim))
    with minibatch(batch_size, num_obs_total=num_obs_total):
        # Draw one latent code per batch item. `minibatch` appears to only
        # rescale log probabilities (the other models here pass sample_shape
        # explicitly), so without sample_shape the prior would yield a single
        # z of shape (z_dim,), mismatching the decoder's (batch_size, z_dim)
        # input and the guide's per-item z.
        z = sample('z',
                   dist.Normal(jnp.zeros((z_dim, )), jnp.ones((z_dim, ))),
                   sample_shape=(batch_size, ))  # prior on z is N(0,I)
        # evaluate decoder (p(x|z)) on sampled z; its output is passed as the
        # first positional (probs) argument of the Bernoulli
        img_loc = decode(z)
        # outputs x are Bernoulli given z, conditioned on the batch if given
        x = sample('obs', dist.Bernoulli(img_loc), obs=batch)
        return x
Beispiel #11
0
 def test_minibatch_rejects_float_batch_size_argument(self):
     """A float is not accepted as a batch size."""
     self.assertRaises(TypeError, minibatch, 10.)
Beispiel #12
0
 def test_minibatch_rejects_tuple_batch_size_argument(self):
     """A tuple is not accepted as a batch size."""
     self.assertRaises(TypeError, minibatch, (2, 3))
Beispiel #13
0
 def test_minibatch_rejects_batch_size_none(self):
     """None is not accepted as a batch size."""
     self.assertRaises(TypeError, minibatch, None)
Beispiel #14
0
 def test_model(X, num_obs_total):
     """Sample a dummy 'test' site shaped like ``X`` inside a minibatch of ``X``."""
     with minibatch(X, num_obs_total=num_obs_total):
         dummy = MinibatchTests.DummyDist()
         sample('test', dummy, sample_shape=X.shape)
Beispiel #15
0
 def test_model(batch_size, num_obs_total):
     """Sample a dummy 'test' site of ``batch_size`` items inside a minibatch."""
     with minibatch(batch_size, num_obs_total=num_obs_total):
         dummy = MinibatchTests.DummyDist()
         sample('test', dummy, sample_shape=(batch_size, ))