def haiku_model_by_kwargs_1(x, y):
    """Regress ``y`` on the output of a Haiku linear layer applied to ``x``.

    The network is registered through ``haiku_module`` under the name
    ``"nn"``; the input is handed over as the keyword argument ``x=x``.

    :param x: input passed to the network (and to ``haiku_module`` as ``x``).
    :param y: observed values for the ``"y"`` sample site.
    """
    import haiku as hk

    # The parameter must stay named ``x``: the input is forwarded to
    # ``haiku_module`` by keyword below.
    def _linear(x):
        return hk.Linear(100)(x)

    net = haiku_module("nn", hk.transform(_linear), x=x)
    likelihood = numpyro.distributions.Normal(net(x), 0.1)
    numpyro.sample("y", likelihood, obs=y)
def haiku_model(x, y):
    """Regress ``y`` on the output of a Haiku linear layer applied to ``x``.

    The network is registered through ``haiku_module`` under the name
    ``"nn"`` with ``input_shape=(100,)``.

    :param x: input passed to the network.
    :param y: observed values for the ``"y"`` sample site.
    """
    import haiku as hk

    def _linear(x):
        return hk.Linear(100)(x)

    net = haiku_module("nn", hk.transform(_linear), input_shape=(100,))
    likelihood = numpyro.distributions.Normal(net(x), 0.1)
    numpyro.sample("y", likelihood, obs=y)
def model():
    """Apply the Haiku-wrapped ``fn`` to a sampled ``(4, 3)`` input.

    Relies on ``fn``, ``batchnorm`` and ``dropout`` from the enclosing scope.
    The network output is recorded at the deterministic site ``"y"``.
    """
    # ``transform_with_state`` is chosen when ``batchnorm`` is truthy.
    if batchnorm:
        wrapped = hk.transform_with_state(fn)
    else:
        wrapped = hk.transform(fn)
    net = haiku_module("nn", wrapped, apply_rng=dropout, input_shape=(4, 3))

    prior = dist.Normal(0, 1).expand([4, 3]).to_event(2)
    x = numpyro.sample("x", prior)

    # With dropout enabled the network is called with a PRNG key first.
    call_args = (numpyro.prng_key(), x) if dropout else (x,)
    numpyro.deterministic("y", net(*call_args))
def guide(docs, hyperparams, is_training=False, nn_framework="flax"):
    """Sample ``theta`` from a Dirichlet parameterised by an encoder network.

    The encoder is registered under the name ``"encoder"`` through
    ``flax_module`` or ``haiku_module`` depending on ``nn_framework``. It is
    applied to a subsampled batch of ``docs`` and its output is used as the
    concentration of the ``"theta"`` sample site.

    :param docs: documents; ``docs.shape[0]`` is the size of the
        ``"documents"`` plate.
    :param hyperparams: mapping read for the keys ``"vocab_size"``,
        ``"num_topics"``, ``"hidden"``, ``"dropout_rate"`` and
        ``"batch_size"``.
    :param is_training: forwarded to the encoder on every call (the network
        itself is always initialised with ``is_training=True``).
    :param nn_framework: ``"flax"`` or ``"haiku"``.
    :raises ValueError: if ``nn_framework`` is neither ``"flax"`` nor
        ``"haiku"``.
    """
    use_flax = nn_framework == "flax"
    use_haiku = not use_flax and nn_framework == "haiku"
    if not (use_flax or use_haiku):
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")

    vocab_size = hyperparams["vocab_size"]
    encoder_args = (
        vocab_size,
        hyperparams["num_topics"],
        hyperparams["hidden"],
        hyperparams["dropout_rate"],
    )
    input_shape = (1, vocab_size)

    if use_flax:
        encoder = flax_module(
            "encoder",
            FlaxEncoder(*encoder_args),
            input_shape=input_shape,
            # dropout layers need a PRNGKey made available to them
            apply_rng=["dropout"],
            # BatchNorm layers carry mutable state
            mutable=["batch_stats"],
            # BatchNorm is only initialised properly with is_training=True
            is_training=True,
        )
    else:
        encoder = haiku_module(
            "encoder",
            # BatchNorm requires `transform_with_state`
            hk.transform_with_state(HaikuEncoder(*encoder_args)),
            input_shape=input_shape,
            apply_rng=True,
            # BatchNorm is only initialised properly with is_training=True
            is_training=True,
        )

    num_docs = docs.shape[0]
    documents = numpyro.plate(
        "documents", num_docs, subsample_size=hyperparams["batch_size"])
    with documents:
        batch_docs = numpyro.subsample(docs, event_dim=1)

        dropout_key = numpyro.prng_key()
        if use_flax:
            concentration = encoder(
                batch_docs, is_training, rngs={"dropout": dropout_key})
        else:
            concentration = encoder(dropout_key, batch_docs, is_training)

        numpyro.sample("theta", dist.Dirichlet(concentration))
def model(docs, hyperparams, is_training=False, nn_framework="flax"):
    """Generate documents from Dirichlet-distributed ``theta`` via a decoder.

    The decoder is registered under the name ``"decoder"`` through
    ``flax_module`` or ``haiku_module`` depending on ``nn_framework``. Inside
    the ``"documents"`` plate, ``theta`` is drawn from a Dirichlet with unit
    concentration, decoded into logits, and a subsampled batch of ``docs`` is
    observed under a Multinomial whose total count is the batch's sum over the
    last axis.

    :param docs: documents; ``docs.shape[0]`` is the size of the
        ``"documents"`` plate.
    :param hyperparams: mapping read for the keys ``"vocab_size"``,
        ``"dropout_rate"``, ``"num_topics"`` and ``"batch_size"``.
    :param is_training: forwarded to the decoder on every call (the network
        itself is always initialised with ``is_training=True``).
    :param nn_framework: ``"flax"`` or ``"haiku"``.
    :raises ValueError: if ``nn_framework`` is neither ``"flax"`` nor
        ``"haiku"``.
    """
    use_flax = nn_framework == "flax"
    use_haiku = not use_flax and nn_framework == "haiku"
    if not (use_flax or use_haiku):
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")

    decoder_args = (hyperparams["vocab_size"], hyperparams["dropout_rate"])
    num_topics = hyperparams["num_topics"]
    input_shape = (1, num_topics)

    if use_flax:
        decoder = flax_module(
            "decoder",
            FlaxDecoder(*decoder_args),
            input_shape=input_shape,
            # dropout layers need a PRNGKey made available to them
            apply_rng=["dropout"],
            # BatchNorm layers carry mutable state
            mutable=["batch_stats"],
            # BatchNorm is only initialised properly with is_training=True
            is_training=True,
        )
    else:
        decoder = haiku_module(
            "decoder",
            # BatchNorm requires `transform_with_state`
            hk.transform_with_state(HaikuDecoder(*decoder_args)),
            input_shape=input_shape,
            apply_rng=True,
            # BatchNorm is only initialised properly with is_training=True
            is_training=True,
        )

    num_docs = docs.shape[0]
    documents = numpyro.plate(
        "documents", num_docs, subsample_size=hyperparams["batch_size"])
    with documents:
        batch_docs = numpyro.subsample(docs, event_dim=1)
        theta_prior = dist.Dirichlet(jnp.ones(num_topics))
        theta = numpyro.sample("theta", theta_prior)

        # The key is drawn only after `theta` has been sampled.
        dropout_key = numpyro.prng_key()
        if use_flax:
            logits = decoder(theta, is_training, rngs={"dropout": dropout_key})
        else:
            logits = decoder(dropout_key, theta, is_training)

        word_likelihood = dist.Multinomial(batch_docs.sum(-1), logits=logits)
        numpyro.sample("obs", word_likelihood, obs=batch_docs)
def haiku_model_by_kwargs_2(w, x, y):
    """Regress ``y`` on the sum of two Haiku linear layers applied to ``w`` and ``x``.

    The network is registered through ``haiku_module`` under the name
    ``"nn"``; both inputs are handed over as keyword arguments ``w=w`` and
    ``x=x``.

    :param w: input to the layer named ``"w_linear"``.
    :param x: input to the layer named ``"x_linear"``.
    :param y: observed values for the ``"y"`` sample site.
    """
    import haiku as hk

    class TestHaikuModule(hk.Module):
        """Sum of two independent linear projections of equal width."""

        def __init__(self, dim: int = 100):
            super().__init__()
            self._width = dim

        def __call__(self, w, x):
            # The ``w`` layer is built and applied before the ``x`` layer.
            return (
                hk.Linear(self._width, name="w_linear")(w)
                + hk.Linear(self._width, name="x_linear")(x)
            )

    # Parameter names must stay ``w`` and ``x``: the inputs are forwarded to
    # ``haiku_module`` by keyword below.
    def _forward(w, x):
        return TestHaikuModule(100)(w, x)

    net = haiku_module("nn", hk.transform(_forward), w=w, x=x)
    likelihood = numpyro.distributions.Normal(net(w, x), 0.1)
    numpyro.sample("y", likelihood, obs=y)