def model(data, obs, subsample_size):
    """Logistic regression whose likelihood is evaluated on a data mini-batch.

    ``theta`` receives an independent Normal(0, 0.5) prior per feature. Each
    evaluation, the ``"N"`` plate draws ``subsample_size`` of the rows of
    ``data`` (and the matching entries of ``obs``) and scores them under a
    Bernoulli likelihood with logits ``theta @ features.T``.

    Args:
        data: feature matrix, unpacked as ``(rows, features)``.
        obs: observations for the Bernoulli likelihood, one per row of ``data``.
        subsample_size: number of rows drawn per model evaluation.
    """
    num_rows, num_feats = data.shape
    prior_loc = jnp.zeros(num_feats)
    prior_scale = 0.5 * jnp.ones(num_feats)
    # NOTE(review): this prior has batch shape (num_feats,) outside any plate
    # and is not `.to_event(1)` — confirm that is intended.
    theta = numpyro.sample("theta", dist.Normal(prior_loc, prior_scale))

    with numpyro.plate("N", num_rows, subsample_size=subsample_size):
        # each row of `data` is one plate item (event_dim=1); `obs` is scalar per item
        feats = numpyro.subsample(data, event_dim=1)
        labels = numpyro.subsample(obs, event_dim=0)
        logits = theta @ feats.T
        numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
def model(x, y=None, hidden_dim=50, subsample_size=100):
    """BNN described in section 5 of [1].

    A single hidden layer of width ``hidden_dim`` with ReLU activation. All
    network weights and biases share a zero-mean Normal prior with standard
    deviation ``1 / sqrt(prec_nn)``, where ``prec_nn ~ Gamma(1.0, 0.1)``. The
    observation noise precision ``prec_obs`` has its own Gamma(1.0, 0.1)
    prior. The likelihood is evaluated on a random mini-batch of
    ``subsample_size`` rows of ``x``.

    :param x: design matrix, unpacked as ``(rows, features)``.
    :param y: targets for the rows of ``x``, or ``None`` to leave ``"y"`` unobserved.
    :param hidden_dim: width of the hidden layer.
    :param subsample_size: mini-batch size drawn by the ``"data"`` plate.

    **References:**
        1. *Stein variational gradient descent: A general purpose bayesian inference algorithm*
        Qiang Liu and Dilin Wang (2016).
    """
    # hyper prior for precision of nn weights and biases
    prec_nn = numpyro.sample("prec_nn", Gamma(1.0, 0.1))
    weight_scale = 1.0 / jnp.sqrt(prec_nn)

    num_rows, num_feats = x.shape

    with numpyro.plate("l1_hidden", hidden_dim, dim=-1):
        # prior l1 bias term
        b1 = numpyro.sample("nn_b1", Normal(0.0, weight_scale))
        assert b1.shape == (hidden_dim,)

        # nested plate: one weight per (feature, hidden unit) pair
        with numpyro.plate("l1_feat", num_feats, dim=-2):
            w1 = numpyro.sample("nn_w1", Normal(0.0, weight_scale))
            assert w1.shape == (num_feats, hidden_dim)

    with numpyro.plate("l2_hidden", hidden_dim, dim=-1):
        # prior on output weights
        w2 = numpyro.sample("nn_w2", Normal(0.0, weight_scale))

    # prior on output bias term (a single scalar, outside any plate)
    b2 = numpyro.sample("nn_b2", Normal(0.0, weight_scale))

    # precision prior on observations
    prec_obs = numpyro.sample("prec_obs", Gamma(1.0, 0.1))

    with numpyro.plate("data", num_rows, subsample_size=subsample_size, dim=-1):
        batch_x = numpyro.subsample(x, event_dim=1)
        batch_y = None if y is None else numpyro.subsample(y, event_dim=0)

        # 1 hidden layer with ReLU activation
        hidden = jnp.maximum(batch_x @ w1 + b1, 0)
        numpyro.sample(
            "y",
            Normal(hidden @ w2 + b2, 1.0 / jnp.sqrt(prec_obs)),
            obs=batch_y,
        )
def model(data, subsample_size):
    """Normal model with a 3-vector mean and a subsampled ``"batch"`` plate.

    ``mean`` is a single event of size 3 with a standard-Normal prior. The
    ``"batch"`` plate sits at batch dimension -2 and draws ``subsample_size``
    of the ``data.shape[0]`` items per evaluation.

    NOTE(review): assumes ``data`` is ``(N, 3)``, so that subsampling along
    dim -2 selects rows that line up with ``mean`` — confirm.
    """
    mean_prior = dist.Normal().expand((3,)).to_event(1)
    mean = numpyro.sample("mean", mean_prior)

    batch_plate = numpyro.plate(
        "batch", data.shape[0], dim=-2, subsample_size=subsample_size
    )
    with batch_plate:
        sub_data = numpyro.subsample(data, event_dim=0)
        numpyro.sample("obs", dist.Normal(mean, 1), obs=sub_data)
def model(data):
    """Normal observations whose location is an enumerated Bernoulli latent.

    ``x ~ Bernoulli(0.5)`` is marked for parallel enumeration. Each
    evaluation, the ``"N"`` plate draws 100 entries of ``data``, which are
    scored under ``Normal(x, 1)``.
    """
    enum_config = {"enumerate": "parallel"}
    x = numpyro.sample("x", dist.Bernoulli(0.5), infer=enum_config)

    num_obs = data.shape[0]
    with numpyro.plate("N", num_obs, subsample_size=100, dim=-1):
        batch = numpyro.subsample(data, event_dim=0)
        numpyro.sample("obs", dist.Normal(x, 1), obs=batch)
def test_subsample_data():
    """Subsampling 1-D data in a seeded plate yields ``subsample_size`` items.

    The check covers both indexing with the plate's yielded indices and
    calling ``numpyro.subsample``.
    """
    data = jnp.arange(100.0)
    subsample_size = 7
    expected_shape = (subsample_size,)

    # `with A, B:` enters A before evaluating B, same as nesting them
    with handlers.seed(rng_seed=0), numpyro.plate(
        "a", len(data), subsample_size=subsample_size
    ) as idx:
        assert data[idx].shape == expected_shape
        subsample_data = numpyro.subsample(data, event_dim=0)
        assert subsample_data.shape == expected_shape
def guide(docs, hyperparams, is_training=False, nn_framework="flax"):
    """Amortized guide for per-document topic proportions ``theta``.

    An encoder network (Flax or Haiku, selected by ``nn_framework``) maps
    each document in a random mini-batch to the concentration of a Dirichlet
    distribution over ``theta``.

    Args:
        docs: document matrix whose rows are subsampled (``event_dim=1``).
            NOTE(review): presumably a (num_docs, vocab_size) count matrix — confirm.
        hyperparams: mapping providing ``"vocab_size"``, ``"num_topics"``,
            ``"hidden"``, ``"dropout_rate"`` and ``"batch_size"``.
        is_training: passed to the encoder when it is applied.
        nn_framework: ``"flax"`` or ``"haiku"``.

    Raises:
        ValueError: if ``nn_framework`` is neither ``"flax"`` nor ``"haiku"``.
    """
    if nn_framework not in ("flax", "haiku"):
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")

    vocab_size = hyperparams["vocab_size"]
    encoder_args = (
        vocab_size,
        hyperparams["num_topics"],
        hyperparams["hidden"],
        hyperparams["dropout_rate"],
    )

    if nn_framework == "flax":
        encoder = flax_module(
            "encoder",
            FlaxEncoder(*encoder_args),
            input_shape=(1, vocab_size),
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:
        encoder = haiku_module(
            "encoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(HaikuEncoder(*encoder_args)),
            input_shape=(1, vocab_size),
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )

    with numpyro.plate("documents", docs.shape[0],
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        # the two frameworks take the PRNG key in different positions
        if nn_framework == "flax":
            concentration = encoder(
                batch_docs, is_training, rngs={"dropout": numpyro.prng_key()})
        else:
            concentration = encoder(numpyro.prng_key(), batch_docs, is_training)
        numpyro.sample("theta", dist.Dirichlet(concentration))
def test_subsample_replay():
    """A replayed plate reuses the subsample recorded in the guide trace.

    The replayed plate is declared without ``subsample_size`` and under a
    different seed. It must still have ``subsample_size`` items, which shows
    the subsample came from ``guide_trace`` rather than a fresh draw.
    """
    data = jnp.arange(100.)
    subsample_size = 7

    # record a plate's subsample indices (trace outermost, then seed, then plate)
    with handlers.trace() as guide_trace, handlers.seed(rng_seed=0), \
            numpyro.plate("a", len(data), subsample_size=subsample_size):
        pass

    with handlers.seed(rng_seed=1), handlers.replay(guide_trace=guide_trace):
        with numpyro.plate("a", len(data)):
            replayed = numpyro.subsample(data, event_dim=0)
            assert replayed.shape == (subsample_size,)
def model(docs, hyperparams, is_training=False, nn_framework="flax"):
    """Neural topic-model likelihood over a mini-batch of documents.

    For each subsampled document:
    - topic proportions are drawn as ``theta ~ Dirichlet(ones(num_topics))``;
    - a decoder network (Flax or Haiku, selected by ``nn_framework``) turns
      ``theta`` into Multinomial logits;
    - the document is scored with total count equal to its own sum over the
      last axis.

    Args:
        docs: document matrix whose rows are subsampled (``event_dim=1``).
            NOTE(review): presumably a (num_docs, vocab_size) count matrix — confirm.
        hyperparams: mapping providing ``"vocab_size"``, ``"num_topics"``,
            ``"dropout_rate"`` and ``"batch_size"``.
        is_training: passed to the decoder when it is applied.
        nn_framework: ``"flax"`` or ``"haiku"``.

    Raises:
        ValueError: if ``nn_framework`` is neither ``"flax"`` nor ``"haiku"``.
    """
    if nn_framework not in ("flax", "haiku"):
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")

    vocab_size = hyperparams["vocab_size"]
    num_topics = hyperparams["num_topics"]
    dropout_rate = hyperparams["dropout_rate"]

    if nn_framework == "flax":
        decoder = flax_module(
            "decoder",
            FlaxDecoder(vocab_size, dropout_rate),
            input_shape=(1, num_topics),
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:
        decoder = haiku_module(
            "decoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(HaikuDecoder(vocab_size, dropout_rate)),
            input_shape=(1, num_topics),
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )

    with numpyro.plate("documents", docs.shape[0],
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        theta = numpyro.sample("theta", dist.Dirichlet(jnp.ones(num_topics)))

        # the two frameworks take the PRNG key in different positions
        if nn_framework == "flax":
            logits = decoder(
                theta, is_training, rngs={"dropout": numpyro.prng_key()})
        else:
            logits = decoder(numpyro.prng_key(), theta, is_training)

        total_count = batch_docs.sum(-1)
        numpyro.sample(
            "obs", dist.Multinomial(total_count, logits=logits), obs=batch_docs)
def model_subsample_2():
    """Two subsampled plates used on their own and stacked together.

    - ``outer`` subsamples 10 of the 20 entries of ``data``'s last axis, with
      no explicit dim.
    - ``inner`` subsamples 5 of the 10 entries along dim -3.

    The assertions pin the site shapes produced by each plate combination.
    """
    data = jnp.ones((10, 1, 20))
    # both plates are constructed up front and entered (possibly repeatedly) later
    outer = numpyro.plate("outer", data.shape[-1], subsample_size=10)
    inner = numpyro.plate("inner", data.shape[-3], subsample_size=5, dim=-3)

    with outer:
        x = numpyro.sample("x", dist.Normal(0.0, 1.0))
        assert x.shape == (10,)

    with inner:
        y = numpyro.sample("y", dist.Normal(0.0, 1.0))
        assert y.shape == (5, 1, 1)
        # the deterministic site keeps x's shape; it is not broadcast by `inner`
        z = numpyro.deterministic("z", x**2)
        assert z.shape == (10,)

    with outer, inner:
        xy = numpyro.sample("xy", dist.Normal(0.0, 1.0))
        assert xy.shape == (5, 1, 10)
        subsample_data = numpyro.subsample(data, event_dim=0)
        assert subsample_data.shape == (5, 1, 10)
def logistic_regression():
    """Bernoulli likelihood with a scalar logit ``x`` over a subsampled plate.

    Each evaluation, the ``"N"`` plate draws 2 of the 10 entries of ``data``.
    """
    num_points = 10
    data = jnp.arange(num_points)
    # NOTE(review): `data` holds 0..9, but Bernoulli's support is {0, 1};
    # entries >= 2 fall outside the support — confirm this fixture is intended.
    x = numpyro.sample("x", dist.Normal(0, 1))
    with numpyro.plate("N", num_points, subsample_size=2):
        batch = numpyro.subsample(data, event_dim=0)
        numpyro.sample("obs", dist.Bernoulli(logits=x), obs=batch)