Example #1
0
def model(data, obs, subsample_size):
    """Logistic-regression model whose likelihood is evaluated on a subsample.

    A weight vector ``theta`` with one entry per column of ``data`` receives
    an independent ``Normal(0, 0.5)`` prior.  Inside the ``'N'`` plate a
    mini-batch of rows of ``data`` and of entries of ``obs`` is drawn, and the
    observations are modelled as Bernoulli with logits
    ``theta @ batch_feats.T``.

    :param data: 2-D array; ``data.shape`` is unpacked as ``(rows, features)``.
    :param obs: observed outcomes, subsampled along its leading axis
        (presumably one entry per row of ``data`` -- confirm with caller).
    :param subsample_size: mini-batch size requested from the plate.
    """
    num_rows, num_feats = data.shape

    prior_loc = jnp.zeros(num_feats)
    prior_scale = .5 * jnp.ones(num_feats)
    theta = numpyro.sample('theta', dist.Normal(prior_loc, prior_scale))

    with numpyro.plate('N', num_rows, subsample_size=subsample_size):
        # Feature rows keep their trailing axis as an event dimension.
        batch_feats = numpyro.subsample(data, event_dim=1)
        batch_obs = numpyro.subsample(obs, event_dim=0)
        batch_logits = theta @ batch_feats.T
        numpyro.sample('obs', dist.Bernoulli(logits=batch_logits), obs=batch_obs)
Example #2
0
def model(x, y=None, hidden_dim=50, subsample_size=100):
    """One-hidden-layer Bayesian neural network from section 5 of [1].

    All weights and biases share a zero-mean Normal prior whose scale is
    ``1 / sqrt(prec_nn)``, with ``prec_nn ~ Gamma(1.0, 0.1)``.  The
    observation noise scale is ``1 / sqrt(prec_obs)`` with its own
    ``Gamma(1.0, 0.1)`` prior.  The likelihood is evaluated on a subsample of
    the rows of ``x`` (and of ``y`` when it is given).

    :param x: 2-D input array; ``x.shape`` is unpacked as ``(rows, features)``.
    :param y: optional targets; when ``None`` the ``"y"`` site is unobserved.
    :param hidden_dim: number of hidden units.
    :param subsample_size: mini-batch size requested from the ``"data"`` plate.

    **References:**
        1. *Stein variational gradient descent: A general purpose bayesian inference algorithm*
        Qiang Liu and Dilin Wang (2016).
    """
    # hyper prior for precision of nn weights and biases
    prec_nn = numpyro.sample("prec_nn", Gamma(1.0, 0.1))
    # Every weight/bias prior below uses this same scale.
    weight_scale = 1.0 / jnp.sqrt(prec_nn)

    num_rows, num_feats = x.shape

    with numpyro.plate("l1_hidden", hidden_dim, dim=-1):
        # prior l1 bias term
        b1 = numpyro.sample("nn_b1", Normal(0.0, weight_scale))
        assert b1.shape == (hidden_dim,)

        with numpyro.plate("l1_feat", num_feats, dim=-2):
            # prior on l1 weights
            w1 = numpyro.sample("nn_w1", Normal(0.0, weight_scale))
            assert w1.shape == (num_feats, hidden_dim)

    with numpyro.plate("l2_hidden", hidden_dim, dim=-1):
        # prior on output weights
        w2 = numpyro.sample("nn_w2", Normal(0.0, weight_scale))

    # prior on output bias term
    b2 = numpyro.sample("nn_b2", Normal(0.0, weight_scale))

    # precision prior on observations
    prec_obs = numpyro.sample("prec_obs", Gamma(1.0, 0.1))

    with numpyro.plate("data", num_rows, subsample_size=subsample_size, dim=-1):
        batch_x = numpyro.subsample(x, event_dim=1)
        batch_y = None if y is None else numpyro.subsample(y, event_dim=0)

        # 1 hidden layer with ReLU activation
        hidden = jnp.maximum(batch_x @ w1 + b1, 0)
        obs_scale = 1.0 / jnp.sqrt(prec_obs)
        numpyro.sample("y", Normal(hidden @ w2 + b2, obs_scale), obs=batch_y)
Example #3
0
 def model(data, subsample_size):
     """Normal model with a 3-vector mean and a subsampled ``"batch"`` plate.

     ``mean`` is drawn from a default-parameter ``Normal`` expanded to shape
     ``(3,)`` and treated as a single event.  The plate is placed at
     ``dim=-2`` and the rows of ``data`` selected by it are observed under
     ``Normal(mean, 1)``.

     :param data: array whose leading axis length gives the plate size.
     :param subsample_size: mini-batch size requested from the plate.
     """
     mean_prior = dist.Normal().expand((3,)).to_event(1)
     mean = numpyro.sample("mean", mean_prior)

     batch_plate = numpyro.plate(
         "batch", data.shape[0], dim=-2, subsample_size=subsample_size
     )
     with batch_plate:
         sub_data = numpyro.subsample(data, event_dim=0)
         likelihood = dist.Normal(mean, 1)
         numpyro.sample("obs", likelihood, obs=sub_data)
Example #4
0
 def model(data):
     """Model mixing a parallel-enumerated discrete site with subsampling.

     ``x`` is a ``Bernoulli(0.5)`` site marked for parallel enumeration.  A
     subsample of 100 entries of ``data`` is then observed under
     ``Normal(x, 1)`` inside the ``"N"`` plate.

     :param data: array whose leading axis length gives the plate size.
     """
     x = numpyro.sample(
         "x", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}
     )

     num_rows = data.shape[0]
     with numpyro.plate("N", num_rows, subsample_size=100, dim=-1):
         batch = numpyro.subsample(data, event_dim=0)
         likelihood = dist.Normal(x, 1)
         numpyro.sample("obs", likelihood, obs=batch)
Example #5
0
def test_subsample_data():
    """Check that a subsampling plate yields ``subsample_size`` entries.

    Both indexing ``data`` with the indices returned by the plate and calling
    ``numpyro.subsample`` inside the plate must give arrays of shape
    ``(subsample_size,)``.
    """
    subsample_size = 7
    data = jnp.arange(100.0)
    expected_shape = (subsample_size,)

    with handlers.seed(rng_seed=0), numpyro.plate(
        "a", len(data), subsample_size=subsample_size
    ) as idx:
        assert data[idx].shape == expected_shape
        subsample_data = numpyro.subsample(data, event_dim=0)
        assert subsample_data.shape == expected_shape
Example #6
0
def guide(docs, hyperparams, is_training=False, nn_framework="flax"):
    """Guide that maps a subsample of documents to Dirichlet concentrations.

    An encoder network is registered under the name ``"encoder"`` through
    either ``flax_module`` or ``haiku_module``.  Inside the ``"documents"``
    plate a mini-batch of rows of ``docs`` is passed through the encoder and
    its output is used as the concentration of the ``"theta"`` Dirichlet
    site.

    :param docs: array of documents; subsampled along its leading axis with
        the trailing axis kept as an event dimension.
    :param hyperparams: mapping read for the keys ``"vocab_size"``,
        ``"num_topics"``, ``"hidden"``, ``"dropout_rate"`` and
        ``"batch_size"``.
    :param is_training: flag forwarded to the encoder call.
    :param nn_framework: ``"flax"`` or ``"haiku"``.
    :raises ValueError: if ``nn_framework`` is neither of the two choices.
    """
    if nn_framework not in ("flax", "haiku"):
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")
    use_flax = nn_framework == "flax"

    # Constructor arguments and input shape are the same for both back-ends.
    encoder_args = (
        hyperparams["vocab_size"],
        hyperparams["num_topics"],
        hyperparams["hidden"],
        hyperparams["dropout_rate"],
    )
    input_shape = (1, hyperparams["vocab_size"])

    if use_flax:
        encoder = flax_module(
            "encoder",
            FlaxEncoder(*encoder_args),
            input_shape=input_shape,
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:
        encoder = haiku_module(
            "encoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(HaikuEncoder(*encoder_args)),
            input_shape=input_shape,
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )

    documents_plate = numpyro.plate(
        "documents", docs.shape[0], subsample_size=hyperparams["batch_size"]
    )
    with documents_plate:
        batch_docs = numpyro.subsample(docs, event_dim=1)

        # One key per call; the two back-ends only differ in how it is passed.
        rng_key = numpyro.prng_key()
        if use_flax:
            concentration = encoder(
                batch_docs, is_training, rngs={"dropout": rng_key}
            )
        else:
            concentration = encoder(rng_key, batch_docs, is_training)

        numpyro.sample("theta", dist.Dirichlet(concentration))
Example #7
0
def test_subsample_replay():
    """Check that replaying a traced plate reuses its subsample.

    A plate with ``subsample_size=7`` is traced first.  A second plate of the
    same name, created without ``subsample_size`` while replaying that trace,
    must make ``numpyro.subsample`` return ``subsample_size`` entries.
    """
    subsample_size = 7
    data = jnp.arange(100.)

    # Record the subsampling plate.
    with handlers.trace() as guide_trace:
        with handlers.seed(rng_seed=0):
            with numpyro.plate("a", len(data), subsample_size=subsample_size):
                pass

    # Replay it against a plate that does not request subsampling itself.
    with handlers.seed(rng_seed=1):
        with handlers.replay(guide_trace=guide_trace):
            with numpyro.plate("a", len(data)):
                subsample_data = numpyro.subsample(data, event_dim=0)
                assert subsample_data.shape == (subsample_size,)
Example #8
0
def model(docs, hyperparams, is_training=False, nn_framework="flax"):
    """Topic model that decodes Dirichlet topic weights into word logits.

    A decoder network is registered under the name ``"decoder"`` through
    either ``flax_module`` or ``haiku_module``.  Inside the ``"documents"``
    plate, ``"theta"`` is drawn from a Dirichlet with all-ones concentration
    of length ``hyperparams["num_topics"]``, decoded into logits, and the
    subsampled rows of ``docs`` are observed under a Multinomial whose total
    count is each row's sum.

    :param docs: array of documents; subsampled along its leading axis with
        the trailing axis kept as an event dimension.
    :param hyperparams: mapping read for the keys ``"vocab_size"``,
        ``"dropout_rate"``, ``"num_topics"`` and ``"batch_size"``.
    :param is_training: flag forwarded to the decoder call.
    :param nn_framework: ``"flax"`` or ``"haiku"``.
    :raises ValueError: if ``nn_framework`` is neither of the two choices.
    """
    if nn_framework not in ("flax", "haiku"):
        raise ValueError(
            f"Invalid choice {nn_framework} for argument nn_framework")
    use_flax = nn_framework == "flax"

    # Constructor arguments and input shape are the same for both back-ends.
    decoder_args = (hyperparams["vocab_size"], hyperparams["dropout_rate"])
    input_shape = (1, hyperparams["num_topics"])

    if use_flax:
        decoder = flax_module(
            "decoder",
            FlaxDecoder(*decoder_args),
            input_shape=input_shape,
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:
        decoder = haiku_module(
            "decoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(HaikuDecoder(*decoder_args)),
            input_shape=input_shape,
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )

    documents_plate = numpyro.plate(
        "documents", docs.shape[0], subsample_size=hyperparams["batch_size"]
    )
    with documents_plate:
        batch_docs = numpyro.subsample(docs, event_dim=1)

        topic_prior = dist.Dirichlet(jnp.ones(hyperparams["num_topics"]))
        theta = numpyro.sample("theta", topic_prior)

        # One key per call; the two back-ends only differ in how it is passed.
        rng_key = numpyro.prng_key()
        if use_flax:
            logits = decoder(theta, is_training, rngs={"dropout": rng_key})
        else:
            logits = decoder(rng_key, theta, is_training)

        total_count = batch_docs.sum(-1)
        likelihood = dist.Multinomial(total_count, logits=logits)
        numpyro.sample("obs", likelihood, obs=batch_docs)
Example #9
0
def model_subsample_2():
    """Exercise two subsampling plates used separately and together.

    ``data`` has shape ``(10, 1, 20)``.  The ``'outer'`` plate subsamples 10
    of the 20 entries on the last axis; the ``'inner'`` plate, placed at
    ``dim=-3``, subsamples 5 of the 10 entries on the first axis.  Shape
    assertions check every site and the subsampled data.
    """
    data = jnp.ones((10, 1, 20))
    outer_size = data.shape[-1]
    inner_size = data.shape[-3]

    # Plates are created once (outer first) and re-entered below.
    outer = numpyro.plate('outer', outer_size, subsample_size=10)
    inner = numpyro.plate('inner', inner_size, subsample_size=5, dim=-3)

    with outer:
        x = numpyro.sample('x', dist.Normal(0., 1.))
        assert x.shape == (10,)

    with inner:
        y = numpyro.sample('y', dist.Normal(0., 1.))
        assert y.shape == (5, 1, 1)
        z = numpyro.deterministic('z', x ** 2)
        assert z.shape == (10,)

    with outer:
        with inner:
            xy = numpyro.sample('xy', dist.Normal(0., 1.))
            assert xy.shape == (5, 1, 10)
            subsample_data = numpyro.subsample(data, event_dim=0)
            assert subsample_data.shape == (5, 1, 10)
Example #10
0
def logistic_regression():
    """Small Bernoulli-logit model observed on a subsample of fixed data.

    ``x`` has a ``Normal(0, 1)`` prior and is used as the logit of the
    ``"obs"`` site.  The data are ``jnp.arange(10)``, of which the ``"N"``
    plate selects 2 entries per call.
    """
    x = numpyro.sample("x", dist.Normal(0, 1))
    data = jnp.arange(10)

    with numpyro.plate("N", 10, subsample_size=2):
        batch = numpyro.subsample(data, event_dim=0)
        likelihood = dist.Bernoulli(logits=x)
        numpyro.sample("obs", likelihood, obs=batch)