コード例 #1
0
def flax_model_by_kwargs(x, y):
    """Normal regression model whose Flax layer is set up from example inputs.

    A ``flax.linen.Dense`` layer with 100 output features is registered
    through ``flax_module`` under the name ``"nn"``. Instead of declaring an
    ``input_shape``, the example input ``x`` is handed to ``flax_module`` as
    the ``inputs`` keyword argument (presumably consumed when the module's
    parameters are initialised -- confirm against ``flax_module``). The layer
    output is the mean of a Normal likelihood with scale 0.1, observed at
    ``y``.

    Args:
        x: input batch fed to the Dense layer.
        y: observed values for the ``"y"`` sample site.
    """
    import flax

    # Use the Linen API (as flax_model_by_shape does); the legacy
    # ``flax.nn.Dense.partial`` API is no longer available in Flax.
    # ``inputs`` matches the argument name of ``flax.linen.Dense.__call__``.
    linear_module = flax.linen.Dense(features=100)
    nn = flax_module("nn", linear_module, inputs=x)
    mean = nn(x)
    numpyro.sample("y", numpyro.distributions.Normal(mean, 0.1), obs=y)
コード例 #2
0
def flax_model_by_shape(x, y):
    """Normal regression model whose Flax layer is set up from a fixed shape.

    A ``flax.linen.Dense`` layer (100 output features) is registered via
    ``flax_module`` under the name ``"nn"``, declaring ``input_shape=(100,)``
    rather than passing example inputs. Its output on ``x`` is the mean of a
    Normal likelihood with scale 0.1, observed at ``y``.

    Args:
        x: input batch fed to the Dense layer.
        y: observed values for the ``"y"`` sample site.
    """
    import flax

    dense_layer = flax.linen.Dense(features=100)
    net = flax_module("nn", dense_layer, input_shape=(100,))
    likelihood = numpyro.distributions.Normal(net(x), 0.1)
    numpyro.sample("y", likelihood, obs=y)
コード例 #3
0
def guide(docs, hyperparams, is_training=False, nn_framework="flax"):
    """Variational guide over per-document topic proportions ``theta``.

    An encoder network (Flax or Haiku, selected by ``nn_framework``) maps each
    subsampled row of ``docs`` to the concentration of a Dirichlet
    distribution, from which the ``"theta"`` site is sampled.

    Args:
        docs: array whose leading axis indexes documents; each row is
            subsampled as a single event (``event_dim=1``).
        hyperparams: mapping providing ``"vocab_size"``, ``"num_topics"``,
            ``"hidden"``, ``"dropout_rate"`` and ``"batch_size"``.
        is_training: forwarded to the encoder on the forward pass.
        nn_framework: ``"flax"`` or ``"haiku"``.

    Raises:
        ValueError: if ``nn_framework`` is neither ``"flax"`` nor ``"haiku"``.
    """
    # Unknown frameworks are rejected before any module or site is created.
    if nn_framework not in ("flax", "haiku"):
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")

    use_flax = nn_framework == "flax"
    vocab_size = hyperparams["vocab_size"]
    encoder_args = (
        vocab_size,
        hyperparams["num_topics"],
        hyperparams["hidden"],
        hyperparams["dropout_rate"],
    )
    # Initialisation uses a one-row input shape and is_training=True, which
    # BatchNorm needs in order to be initialised properly.
    if use_flax:
        encoder = flax_module(
            "encoder",
            FlaxEncoder(*encoder_args),
            input_shape=(1, vocab_size),
            apply_rng=["dropout"],  # make a PRNGKey available to dropout
            mutable=["batch_stats"],  # BatchNorm layers carry mutable state
            is_training=True,
        )
    else:
        encoder = haiku_module(
            "encoder",
            # transform_with_state is required because of BatchNorm
            hk.transform_with_state(HaikuEncoder(*encoder_args)),
            input_shape=(1, vocab_size),
            apply_rng=True,
            is_training=True,
        )

    with numpyro.plate("documents", docs.shape[0],
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        # One fresh key per forward pass: Flax receives it through ``rngs``,
        # Haiku as the leading positional argument.
        key = numpyro.prng_key()
        if use_flax:
            concentration = encoder(batch_docs, is_training,
                                    rngs={"dropout": key})
        else:
            concentration = encoder(key, batch_docs, is_training)
        numpyro.sample("theta", dist.Dirichlet(concentration))
コード例 #4
0
def model(docs, hyperparams, is_training=False, nn_framework="flax"):
    """Generative model for document word-count vectors.

    For every subsampled document, topic proportions ``theta`` are drawn from
    a flat Dirichlet prior (all-ones concentration of length
    ``hyperparams["num_topics"]``). A decoder network (Flax or Haiku, selected
    by ``nn_framework``) maps ``theta`` to logits. The document row is then
    observed under a Multinomial whose total count is that row's sum.

    Args:
        docs: array whose leading axis indexes documents; each row is
            subsampled as a single event (``event_dim=1``) and used as the
            observed count vector.
        hyperparams: mapping providing ``"vocab_size"``, ``"num_topics"``,
            ``"dropout_rate"`` and ``"batch_size"``.
        is_training: forwarded to the decoder on the forward pass.
        nn_framework: ``"flax"`` or ``"haiku"``.

    Raises:
        ValueError: if ``nn_framework`` is neither ``"flax"`` nor ``"haiku"``.
    """
    # Unknown frameworks are rejected before any module or site is created.
    if nn_framework not in ("flax", "haiku"):
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")

    use_flax = nn_framework == "flax"
    decoder_args = (hyperparams["vocab_size"], hyperparams["dropout_rate"])
    num_topics = hyperparams["num_topics"]
    # Initialisation uses a one-row input shape and is_training=True, which
    # BatchNorm needs in order to be initialised properly.
    if use_flax:
        decoder = flax_module(
            "decoder",
            FlaxDecoder(*decoder_args),
            input_shape=(1, num_topics),
            apply_rng=["dropout"],  # make a PRNGKey available to dropout
            mutable=["batch_stats"],  # BatchNorm layers carry mutable state
            is_training=True,
        )
    else:
        decoder = haiku_module(
            "decoder",
            # transform_with_state is required because of BatchNorm
            hk.transform_with_state(HaikuDecoder(*decoder_args)),
            input_shape=(1, num_topics),
            apply_rng=True,
            is_training=True,
        )

    with numpyro.plate("documents", docs.shape[0],
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        theta = numpyro.sample("theta",
                               dist.Dirichlet(jnp.ones(num_topics)))
        # One fresh key per forward pass: Flax receives it through ``rngs``,
        # Haiku as the leading positional argument.
        key = numpyro.prng_key()
        if use_flax:
            logits = decoder(theta, is_training, rngs={"dropout": key})
        else:
            logits = decoder(key, theta, is_training)
        numpyro.sample(
            "obs",
            dist.Multinomial(batch_docs.sum(-1), logits=logits),
            obs=batch_docs,
        )
コード例 #5
0
 def model():
     """Run a (4, 3) standard-normal draw through ``Net`` and record ``"y"``.

     ``dropout`` and ``batchnorm`` come from the enclosing scope. They decide
     whether the module is given a ``"dropout"`` PRNG stream and a mutable
     ``"batch_stats"`` collection.
     """
     rng_collections = ["dropout"] if dropout else None
     mutable_collections = ["batch_stats"] if batchnorm else None
     net = flax_module(
         "nn",
         Net(),
         apply_rng=rng_collections,
         mutable=mutable_collections,
         input_shape=(4, 3),
     )
     x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
     # A dropout key is only drawn (and passed) when dropout is enabled.
     call_kwargs = {"rngs": {"dropout": numpyro.prng_key()}} if dropout else {}
     numpyro.deterministic("y", net(x, **call_kwargs))
コード例 #6
0
def flax_model(x, y):
    """Normal regression model with a Flax Dense layer registered via flax_module.

    A ``flax.linen.Dense`` layer with 100 output features is registered under
    the name ``"nn"`` from a declared input shape of ``(100,)``. Its output on
    ``x`` is the mean of a Normal likelihood with scale 0.1, observed at ``y``.

    Args:
        x: input batch fed to the Dense layer.
        y: observed values for the ``"y"`` sample site.
    """
    # Imported locally so the block does not depend on a module-level import.
    import flax

    # The legacy ``flax.nn.Dense.partial`` API is gone from Flax; use Linen.
    linear_module = flax.linen.Dense(features=100)
    # Pass the shape by keyword: newer ``flax_module`` signatures may treat
    # extra positional arguments as example inputs, not as ``input_shape``.
    nn = flax_module("nn", linear_module, input_shape=(100,))
    mean = nn(x)
    numpyro.sample("y", numpyro.distributions.Normal(mean, 0.1), obs=y)