def test_minibatch_num_total_obs_not_given(self):
    batch_size = 20
    expected_scale = 1.
    result = minibatch(batch_size)
    self.assertAlmostEqual(expected_scale, result.scale)
def test_minibatch_scale_correct_over_single_sample(self):
    batch_size = 1
    num_obs_total = 100
    expected_scale = num_obs_total / batch_size
    result = minibatch(batch_size, num_obs_total=num_obs_total)
    self.assertAlmostEqual(expected_scale, result.scale)
def test_minibatch_scale_correct_for_true_minibatch(self):
    batch_size = 10
    num_obs_total = 100
    expected_scale = num_obs_total / batch_size
    result = minibatch(batch_size, num_obs_total=num_obs_total)
    self.assertAlmostEqual(expected_scale, result.scale)
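# The three tests above pin down the scaling rule: minibatch yields a handler
# whose `scale` attribute equals num_obs_total / batch_size, defaulting to 1
# when num_obs_total is not given. A commented sketch of the rule (illustrative,
# not part of the suite):
#
#     minibatch(10, num_obs_total=100).scale   # -> 10.0: each batch point
#                                              # stands in for 10 observations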
def test_minibatch_batch_size_deduced_from_array_and_num_total_obs_not_given(self):
    batch_size = 20
    expected_scale = 1.
    X = jnp.ones((batch_size, 3))
    result = minibatch(X)
    self.assertAlmostEqual(expected_scale, result.scale)
def test_minibatch_batch_size_deduced_from_array(self):
    batch_size = 20
    num_obs_total = 100
    expected_scale = num_obs_total / batch_size
    X = jnp.ones((batch_size, 3))
    result = minibatch(X, num_obs_total=num_obs_total)
    self.assertAlmostEqual(expected_scale, result.scale)
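# As the two tests above show, minibatch also accepts a data array in place of
# an integer batch size and deduces the batch size from the leading axis.
# A commented sketch (shapes are illustrative):
#
#     X = jnp.ones((25, 4))                    # batch of 25 examples, 4 features
#     minibatch(X, num_obs_total=1000).scale   # -> 1000 / 25 == 40.0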
def model_fn(X, N=None, num_obs_total=None):
    if N is None:
        N = jnp.shape(X)[0]
    if num_obs_total is None:
        num_obs_total = N

    mu = sample("theta", dist.Normal(1.))
    with minibatch(N, num_obs_total=num_obs_total):
        X = sample("X", dist.Normal(mu), obs=X, sample_shape=(N,))
    return X, mu
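# A hedged usage sketch for model_fn above: seeding the model with a PRNG key
# lets it run outside of inference and draw synthetic data. `seed` is standard
# numpyro API; the key value is arbitrary.
import jax.random
from numpyro.handlers import seed

seeded_model = seed(model_fn, jax.random.PRNGKey(0))
X_synthetic, mu = seeded_model(None, N=100)  # 100 draws around the sampled theta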
def model(batch_X, batch_y=None, num_obs_total=None):
    """Defines the generative probabilistic model: p(y|z,X)p(z).

    The model is conditioned on the observed data.

    :param batch_X: a batch of predictors
    :param batch_y: a batch of observations
    :param num_obs_total: total number of observations in the full data set
    """
    assert jnp.ndim(batch_X) == 2
    batch_size, d = jnp.shape(batch_X)
    assert batch_y is None or example_count(batch_y) == batch_size

    z_w = sample('w', dist.Normal(jnp.zeros((d,)), jnp.ones((d,))))  # prior is N(0, I)
    z_intercept = sample('intercept', dist.Normal(0., 1.))  # prior is N(0, 1)
    logits = batch_X.dot(z_w) + z_intercept

    with minibatch(batch_size, num_obs_total=num_obs_total):
        return sample('obs', dist.Bernoulli(logits=logits), obs=batch_y)
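# A hedged sketch of fitting the logistic-regression model above with plain
# numpyro SVI and an auto-guide; the repository's own inference loop (e.g. a
# privacy-preserving variant) may differ, and the autoguide import path varies
# across numpyro versions. Data arrays and sizes here are placeholders.
import jax
import jax.numpy as jnp
import numpyro
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDiagonalNormal

batch_X = jnp.ones((32, 5))
batch_y = jnp.zeros((32,))
svi = SVI(model, AutoDiagonalNormal(model), numpyro.optim.Adam(1e-3), Trace_ELBO())
svi_state = svi.init(jax.random.PRNGKey(0), batch_X, batch_y, num_obs_total=1000)
svi_state, loss = svi.update(svi_state, batch_X, batch_y, num_obs_total=1000)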
def guide(batch, z_dim, hidden_dim, out_dim=None, num_obs_total=None):
    """Defines the probabilistic guide for z (variational approximation to the
    posterior): q(z) ~ p(z|x).

    :param batch: a batch of observations
    :return: (named) sampled z from the variational (guide) distribution q(z)
    """
    assert jnp.ndim(batch) == 3
    batch_size = jnp.shape(batch)[0]
    # squash each data item into a one-dimensional array
    # (preserving only the batch size on the first axis)
    batch = jnp.reshape(batch, (batch_size, -1))
    out_dim = jnp.shape(batch)[1]

    encode = numpyro.module('encoder', encoder(hidden_dim, z_dim), (batch_size, out_dim))
    with minibatch(batch_size, num_obs_total=num_obs_total):
        # obtain mean and standard deviation for q(z) ~ p(z|x) from the encoder
        z_loc, z_std = encode(batch)
        z = sample('z', dist.Normal(z_loc, z_std))  # z follows q(z)
    return z
def model(obs=None, num_obs_total=None, d=None):
    """Defines the generative probabilistic model: p(x|z)p(z)."""
    if obs is not None:
        assert jnp.ndim(obs) == 2
        batch_size, d = jnp.shape(obs)
    else:
        assert num_obs_total is not None
        batch_size = num_obs_total
        assert d is not None

    z_mu = sample('mu', dist.Normal(jnp.zeros((d,)), 1.))
    x_var = .1
    with minibatch(batch_size, num_obs_total=num_obs_total):
        x = sample('obs', dist.Normal(z_mu, x_var), obs=obs, sample_shape=(batch_size,))
    return x
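# A hedged sketch for the model above: with obs=None the model runs
# generatively, so seeding it draws a synthetic data set of num_obs_total
# points in d dimensions. `seed` is standard numpyro API; sizes are illustrative.
import jax.random
from numpyro.handlers import seed

x_synthetic = seed(model, jax.random.PRNGKey(0))(num_obs_total=100, d=2)
# x_synthetic has shape (100, 2): 100 draws around the sampled mean z_mu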
def model(batch_or_batchsize, z_dim, hidden_dim, out_dim=None, num_obs_total=None):
    """Defines the generative probabilistic model: p(x|z)p(z).

    The model is conditioned on the observed data.

    :param batch_or_batchsize: a batch of observations, or an integer batch size
    :param z_dim: dimensions of the latent variable / code
    :param hidden_dim: dimensions of the hidden layers in the VAE
    :param out_dim: number of dimensions in a single output sample (flattened)
    :return: (named) sample x from the model observation distribution p(x|z)p(z)
    """
    if is_int_scalar(batch_or_batchsize):
        batch = None
        batch_size = batch_or_batchsize
        if out_dim is None:
            raise ValueError("if no batch is provided, out_dim must be given")
    else:
        batch = batch_or_batchsize
        assert jnp.ndim(batch) == 3
        batch_size = jnp.shape(batch)[0]
        # squash each data item into a one-dimensional array
        # (preserving only the batch size on the first axis)
        batch = jnp.reshape(batch, (batch_size, -1))
        out_dim = jnp.shape(batch)[1]

    decode = numpyro.module('decoder', decoder(hidden_dim, out_dim), (batch_size, z_dim))
    with minibatch(batch_size, num_obs_total=num_obs_total):
        z = sample('z', dist.Normal(jnp.zeros((z_dim,)), jnp.ones((z_dim,))))  # prior on z is N(0, I)
        # evaluate the decoder (p(x|z)) on the sampled z to get the means
        # of the output Bernoulli distribution
        img_loc = decode(z)
        # x is sampled from a Bernoulli distribution depending on z and
        # conditioned on the observed data
        x = sample('obs', dist.Bernoulli(img_loc), obs=batch)
    return x
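# A hedged sketch wiring the VAE model above together with the guide defined
# earlier into numpyro's SVI. The encoder/decoder network factories are assumed
# from this repository; shapes, dimensions and the optimizer are illustrative,
# and the repository's own inference loop may differ.
import jax
import jax.numpy as jnp
import numpyro
from numpyro.infer import SVI, Trace_ELBO

batch = jnp.ones((64, 28, 28))  # e.g. a batch of 64 image observations
svi = SVI(model, guide, numpyro.optim.Adam(1e-3), Trace_ELBO())
svi_state = svi.init(jax.random.PRNGKey(0), batch, z_dim=10, hidden_dim=400,
                     num_obs_total=60000)
svi_state, loss = svi.update(svi_state, batch, z_dim=10, hidden_dim=400,
                             num_obs_total=60000)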
def test_minibatch_rejects_float_batch_size_argument(self):
    batch_size = 10.
    with self.assertRaises(TypeError):
        minibatch(batch_size)
def test_minibatch_rejects_tuple_batch_size_argument(self):
    batch_size = (2, 3)
    with self.assertRaises(TypeError):
        minibatch(batch_size)
def test_minibatch_rejects_batch_size_none(self):
    batch_size = None
    with self.assertRaises(TypeError):
        minibatch(batch_size)
def test_model(X, num_obs_total):
    with minibatch(X, num_obs_total=num_obs_total):
        sample('test', MinibatchTests.DummyDist(), sample_shape=X.shape)
def test_model(batch_size, num_obs_total):
    with minibatch(batch_size, num_obs_total=num_obs_total):
        sample('test', MinibatchTests.DummyDist(), sample_shape=(batch_size,))