def flax_model_by_kwargs(x, y):
    """Regression model whose flax network is registered via keyword inputs.

    A 100-unit ``Dense`` layer (legacy ``flax.nn`` API) is wrapped with
    ``flax_module`` under the name ``"nn"``. The example batch is passed as the
    ``inputs=x`` keyword (presumably used for initialisation in place of an
    explicit ``input_shape`` — confirm against the ``flax_module`` version in
    use). The network output is the mean of a Normal likelihood with scale 0.1
    for the observed ``y``.
    """
    import flax

    # ``Dense.partial`` pre-binds the layer width in the legacy ``flax.nn`` API.
    dense_layer = flax.nn.Dense.partial(features=100)
    net = flax_module("nn", dense_layer, inputs=x)
    loc = net(x)
    likelihood = numpyro.distributions.Normal(loc, 0.1)
    numpyro.sample("y", likelihood, obs=y)
def flax_model_by_shape(x, y):
    """Regression model whose flax (Linen) network is initialised from a shape.

    A 100-unit ``flax.linen.Dense`` layer is wrapped with ``flax_module`` under
    the name ``"nn"``. The wrapper is given ``input_shape=(100,)`` rather than
    example inputs. The network output is the mean of a Normal likelihood with
    scale 0.1 for the observed ``y``.
    """
    import flax

    dense_layer = flax.linen.Dense(features=100)
    net = flax_module("nn", dense_layer, input_shape=(100,))
    numpyro.sample(
        "y",
        numpyro.distributions.Normal(net(x), 0.1),
        obs=y,
    )
def guide(docs, hyperparams, is_training=False, nn_framework="flax"):
    """Variational guide that amortises ``theta`` with a neural encoder.

    A flax or haiku encoder maps each subsampled document to Dirichlet
    concentration parameters, and ``theta`` is sampled from that Dirichlet.

    Args:
        docs: document array whose leading axis is indexed by the
            ``"documents"`` plate. It is presumably ``(num_docs, vocab_size)``,
            given the encoder's ``(1, vocab_size)`` init shape — confirm
            against the caller.
        hyperparams: mapping providing ``"vocab_size"``, ``"num_topics"``,
            ``"hidden"``, ``"dropout_rate"`` and ``"batch_size"``.
        is_training: forwarded to the encoder when it is applied.
        nn_framework: either ``"flax"`` or ``"haiku"``.

    Raises:
        ValueError: if ``nn_framework`` is not a supported choice.
    """
    if nn_framework not in ("flax", "haiku"):
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")

    encoder_args = (
        hyperparams["vocab_size"],
        hyperparams["num_topics"],
        hyperparams["hidden"],
        hyperparams["dropout_rate"],
    )
    # A single dummy row is enough to initialise the encoder's parameters.
    init_shape = (1, hyperparams["vocab_size"])

    if nn_framework == "flax":
        encoder = flax_module(
            "encoder",
            FlaxEncoder(*encoder_args),
            input_shape=init_shape,
            apply_rng=["dropout"],  # dropout layers need their own PRNG stream
            mutable=["batch_stats"],  # BatchNorm keeps mutable running stats
            # BatchNorm is only initialised correctly in training mode.
            is_training=True,
        )
    else:
        encoder = haiku_module(
            "encoder",
            # The stateful transform is needed because BatchNorm carries state.
            hk.transform_with_state(HaikuEncoder(*encoder_args)),
            input_shape=init_shape,
            apply_rng=True,
            # BatchNorm is only initialised correctly in training mode.
            is_training=True,
        )

    num_docs = docs.shape[0]
    with numpyro.plate("documents", num_docs,
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        # Exactly one key is drawn per call, after subsampling, for either framework.
        step_key = numpyro.prng_key()
        if nn_framework == "flax":
            concentration = encoder(batch_docs, is_training,
                                    rngs={"dropout": step_key})
        else:
            concentration = encoder(step_key, batch_docs, is_training)
        numpyro.sample("theta", dist.Dirichlet(concentration))
def model(docs, hyperparams, is_training=False, nn_framework="flax"):
    """Generative topic model with a neural decoder.

    For each subsampled document, topic proportions ``theta`` are drawn from a
    Dirichlet with all-ones concentration of size ``num_topics``. A flax or
    haiku decoder turns ``theta`` into logits. The document row is then scored
    under a Multinomial whose total count is the row's sum over the last axis.

    Args:
        docs: document array whose leading axis is indexed by the
            ``"documents"`` plate. Presumably these are bag-of-words counts of
            shape ``(num_docs, vocab_size)`` — confirm against the caller.
        hyperparams: mapping providing ``"vocab_size"``, ``"num_topics"``,
            ``"dropout_rate"`` and ``"batch_size"``.
        is_training: forwarded to the decoder when it is applied.
        nn_framework: either ``"flax"`` or ``"haiku"``.

    Raises:
        ValueError: if ``nn_framework`` is not a supported choice.
    """
    if nn_framework not in ("flax", "haiku"):
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")

    decoder_args = (hyperparams["vocab_size"], hyperparams["dropout_rate"])
    # A single dummy row of topic proportions initialises the decoder.
    init_shape = (1, hyperparams["num_topics"])

    if nn_framework == "flax":
        decoder = flax_module(
            "decoder",
            FlaxDecoder(*decoder_args),
            input_shape=init_shape,
            apply_rng=["dropout"],  # dropout layers need their own PRNG stream
            mutable=["batch_stats"],  # BatchNorm keeps mutable running stats
            # BatchNorm is only initialised correctly in training mode.
            is_training=True,
        )
    else:
        decoder = haiku_module(
            "decoder",
            # The stateful transform is needed because BatchNorm carries state.
            hk.transform_with_state(HaikuDecoder(*decoder_args)),
            input_shape=init_shape,
            apply_rng=True,
            # BatchNorm is only initialised correctly in training mode.
            is_training=True,
        )

    num_docs = docs.shape[0]
    with numpyro.plate("documents", num_docs,
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        topic_prior = dist.Dirichlet(jnp.ones(hyperparams["num_topics"]))
        theta = numpyro.sample("theta", topic_prior)
        # Exactly one key is drawn per call, after theta, for either framework.
        step_key = numpyro.prng_key()
        if nn_framework == "flax":
            logits = decoder(theta, is_training, rngs={"dropout": step_key})
        else:
            logits = decoder(step_key, theta, is_training)
        doc_lengths = batch_docs.sum(-1)
        numpyro.sample(
            "obs",
            dist.Multinomial(doc_lengths, logits=logits),
            obs=batch_docs,
        )
def model():
    """Model wrapping ``Net`` with optional dropout RNG and BatchNorm state.

    ``dropout`` and ``batchnorm`` are free variables read from the enclosing
    scope. They control whether ``flax_module`` is asked for a ``"dropout"``
    PRNG stream and a mutable ``"batch_stats"`` collection. The network output
    on a standard-normal ``(4, 3)`` input is recorded as deterministic ``"y"``.
    """
    rng_streams = ["dropout"] if dropout else None
    mutable_collections = ["batch_stats"] if batchnorm else None
    net = flax_module(
        "nn",
        Net(),
        apply_rng=rng_streams,
        mutable=mutable_collections,
        input_shape=(4, 3),
    )
    x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))

    call_kwargs = {}
    if dropout:
        # Only supply a dropout key when a dropout stream was requested.
        call_kwargs["rngs"] = {"dropout": numpyro.prng_key()}
    numpyro.deterministic("y", net(x, **call_kwargs))
def flax_model(x, y):
    """Regression model built on a legacy ``flax.nn`` Dense layer.

    A 100-unit ``Dense`` layer is wrapped with ``flax_module`` under the name
    ``"nn"``. The network output is the mean of a Normal likelihood with
    scale 0.1 for the observed ``y``.
    """
    dense_layer = flax.nn.Dense.partial(features=100)
    # NOTE(review): ``(100, )`` is passed positionally and is presumably the
    # input shape; this depends on the ``flax_module`` signature of the
    # numpyro version in use — confirm.
    net = flax_module("nn", dense_layer, (100, ))
    loc = net(x)
    numpyro.sample("y", numpyro.distributions.Normal(loc, 0.1), obs=y)