Esempio n. 1
0
 def model():
     """Sample a latent ``x`` and record a haiku network's output as site ``y``.

     Relies on names from the enclosing scope that are not visible here:
     ``batchnorm`` picks ``hk.transform_with_state`` over ``hk.transform``,
     ``fn`` is the function handed to that transform, and ``dropout`` is passed
     as ``apply_rng`` and decides whether the network is called with an
     explicit PRNG key.
     """
     transformer = hk.transform_with_state if batchnorm else hk.transform
     net = haiku_module("nn", transformer(fn), apply_rng=dropout, input_shape=(4, 3))
     # A (4, 3) standard-normal draw, treated as one event via to_event(2).
     x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
     # When apply_rng is on, the wrapped network takes an RNG key first.
     call_args = (numpyro.prng_key(), x) if dropout else (x,)
     y = net(*call_args)
     numpyro.deterministic("y", y)
Esempio n. 2
0
def flax_model_by_kwargs(x, y):
    """Normal-likelihood model whose mean is produced by a flax ``Dense`` layer.

    The 100-feature Dense layer is registered with ``flax_module``, which is
    given the example input by keyword (``inputs=x``). The layer output is the
    mean of a Normal with fixed scale 0.1, observed at ``y``.
    """
    import flax

    dense_layer = flax.linen.Dense(features=100)
    net = flax_module("nn", dense_layer, inputs=x)
    predicted_mean = net(x)
    likelihood = numpyro.distributions.Normal(predicted_mean, 0.1)
    numpyro.sample("y", likelihood, obs=y)
Esempio n. 3
0
def haiku_model_by_kwargs_1(x, y):
    """Normal-likelihood model whose mean is produced by a haiku ``Linear(100)``.

    The transformed network is registered with ``haiku_module``, which is given
    the example input by keyword (``x=x``). The output is the mean of a Normal
    with fixed scale 0.1, observed at ``y``.
    """
    import haiku as hk

    # The parameter must be named ``x``: haiku_module receives the example
    # input as the keyword argument ``x=x``.
    def forward(x):
        return hk.Linear(100)(x)

    net = haiku_module("nn", hk.transform(forward), x=x)
    predicted_mean = net(x)
    likelihood = numpyro.distributions.Normal(predicted_mean, 0.1)
    numpyro.sample("y", likelihood, obs=y)
Esempio n. 4
0
 def model(data, labels):
     """Bernoulli-likelihood model whose network parameters get priors via ``random_module``.

     Relies on names from the enclosing scope that are not visible here:
     ``linear_module``, the parameter keys ``bias_name`` / ``weight_name``, and
     extra ``kwargs`` forwarded to ``random_module``. The bias gets a
     default-parameterised Cauchy prior and the weight a default-parameterised
     Normal prior.
     """
     priors = {
         bias_name: dist.Cauchy(),
         weight_name: dist.Normal(),
     }
     net = random_module("nn", linear_module, priors, **kwargs)
     # Drop the trailing output axis; presumably so logits match labels' shape.
     logits = net(data).squeeze(-1)
     numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)
Esempio n. 5
0
def haiku_model_by_kwargs_2(w, x, y):
    """Normal-likelihood model whose mean sums linear projections of ``w`` and ``x``.

    The two-input haiku network is registered with ``haiku_module``, which is
    given both example inputs by keyword (``w=w, x=x``). The output is the mean
    of a Normal with fixed scale 0.1, observed at ``y``.
    """
    import haiku as hk

    # Keep the class name and sub-module names: Haiku derives parameter paths
    # from them.
    class TestHaikuModule(hk.Module):
        def __init__(self, dim: int = 100) -> None:
            super().__init__()
            self._width = dim

        def __call__(self, w, x):
            w_proj = hk.Linear(self._width, name="w_linear")(w)
            x_proj = hk.Linear(self._width, name="x_linear")(x)
            return w_proj + x_proj

    # Parameter names must stay ``w`` and ``x`` to match the keyword example
    # inputs passed to haiku_module.
    def forward(w, x):
        return TestHaikuModule(100)(w, x)

    net = haiku_module("nn", hk.transform(forward), w=w, x=x)
    predicted_mean = net(w, x)
    likelihood = numpyro.distributions.Normal(predicted_mean, 0.1)
    numpyro.sample("y", likelihood, obs=y)