Beispiel #1
0
def model_4(sequences, lengths, args, include_prior=True):
    """Two-chain hidden-state model in which ``x`` transitions depend on ``w``.

    At every time step ``w`` is drawn from ``probs_w[w_prev]``, ``x`` is drawn
    from ``probs_x`` indexed by the pair ``(w, x_prev)``, and the Bernoulli
    observations ``y`` use ``probs_y`` indexed by ``(w, x)``.

    Args:
        sequences: array unpacked as ``(num_sequences, max_length, data_dim)``;
            its first two axes are swapped so the scan runs over the second.
        lengths: compared against the step counter; a step ``t`` is masked
            out wherever ``t < lengths`` is false.
        args: object exposing ``hidden_dim``; the truncated square root of
            that value is the number of states used for each chain.
        include_prior: forwarded to ``mask`` around the three parameter
            priors.
    """
    num_sequences, _, data_dim = sequences.shape
    # The hidden-state budget is split between the w chain and the x chain.
    hidden_dim = int(args.hidden_dim**0.5)
    concentration = 0.9 * jnp.eye(hidden_dim) + 0.1
    emission_shape = [hidden_dim, hidden_dim, data_dim]

    with mask(mask=include_prior):
        probs_w = numpyro.sample(
            "probs_w", dist.Dirichlet(concentration).to_event(1))
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(concentration).expand_by([hidden_dim]).to_event(2),
        )
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand(emission_shape).to_event(3),
        )

    def transition_fn(carry, y):
        w_last, x_last, step = carry
        # Trailing axis added so the mask broadcasts against the tones axis.
        active = (step < lengths)[..., None]
        with numpyro.plate("sequences", num_sequences, dim=-2), \
                mask(mask=active):
            w_new = numpyro.sample("w", dist.Categorical(probs_w[w_last]))
            x_transition = Vindex(probs_x)[w_new, x_last]
            x_new = numpyro.sample("x", dist.Categorical(x_transition))
            with numpyro.plate("tones", data_dim, dim=-1) as tone_idx:
                emission = dist.Bernoulli(probs_y[w_new, x_new, tone_idx])
                numpyro.sample("y", emission, obs=y)
        return (w_new, x_new, step + 1), None

    initial_state = (
        jnp.zeros((num_sequences, 1), dtype=jnp.int32),
        jnp.zeros((num_sequences, 1), dtype=jnp.int32),
        0,
    )
    # Time axis goes first so that scan iterates over time steps.
    scan(transition_fn, initial_state, jnp.swapaxes(sequences, 0, 1))
Beispiel #2
0
 def guide():
     """Guide with a multivariate-normal ``x`` and a masked ``y_unobserved``.

     ``loc`` and ``cov`` are registered as parameters (``cov`` constrained to
     be positive definite). ``data`` and ``mask`` are free variables taken
     from the surrounding scope; ``y_unobserved`` is sampled under the
     inverted ``mask``.
     """
     mean = numpyro.param("loc", np.zeros(3))
     covariance = numpyro.param(
         "cov", np.eye(3), constraint=constraints.positive_definite
     )
     x = numpyro.sample("x", dist.MultivariateNormal(mean, covariance))
     with numpyro.plate("plate", len(data)), \
             handlers.mask(mask=np.invert(mask)):
         unobserved_dist = dist.MultivariateNormal(x, np.eye(3))
         numpyro.sample("y_unobserved", unobserved_dist)
Beispiel #3
0
 def model(data, mask):
     """Model with a masked Delta site and a masked, doubly-scaled observation.

     ``N`` is a free variable from the surrounding scope and sizes the plate.
     ``x`` is sampled unmasked; ``y`` and ``obs`` are sampled under ``mask``,
     and ``obs`` additionally sits under a ``scale`` handler with factor 2.
     """
     with numpyro.plate("N", N):
         latent = numpyro.sample("x", dist.Normal(0, 1))
         with handlers.mask(mask=mask):
             numpyro.sample("y", dist.Delta(latent, log_density=1.0))
             with handlers.scale(scale=2):
                 numpyro.sample("obs", dist.Normal(latent, 1), obs=data)
Beispiel #4
0
 def model(data, mask):
     """Sample ``x`` in a plate, then a masked Delta site and a scaled likelihood.

     The plate size ``N`` is read from the surrounding scope. Both ``y`` and
     ``obs`` are masked by ``mask``; only ``obs`` is wrapped in the
     ``scale=2`` handler.
     """
     with numpyro.plate("N", N):
         standard_normal = dist.Normal(0, 1)
         x_value = numpyro.sample("x", standard_normal)
         with handlers.mask(mask=mask):
             point_mass = dist.Delta(x_value, log_density=1.0)
             numpyro.sample("y", point_mass)
             with handlers.scale(scale=2):
                 likelihood = dist.Normal(x_value, 1)
                 numpyro.sample("obs", likelihood, obs=data)
Beispiel #5
0
def model_6(sequences, lengths, args, include_prior=False):
    """Hidden-state model whose transition depends on the two previous states.

    The transition table ``probs_x`` is indexed by the pair of the two most
    recent states, and the scan is run with ``history=2``.

    Args:
        sequences: array unpacked as ``(num_sequences, max_length, data_dim)``;
            its first two axes are swapped before scanning.
        lengths: compared against the step counter to mask finished steps.
        args: object exposing ``hidden_dim``, the number of hidden states.
        include_prior: forwarded to ``mask`` around the parameter priors.
    """
    num_sequences, _, data_dim = sequences.shape
    n_states = args.hidden_dim

    with mask(mask=include_prior):
        # The complete transition tensor is parameterized explicitly; it
        # holds hidden_dim cubed entries.
        transition_prior = dist.Dirichlet(0.9 * jnp.eye(n_states) + 0.1)
        probs_x = numpyro.sample(
            "probs_x",
            transition_prior.expand([n_states, n_states]).to_event(2),
        )

        emission_prior = dist.Beta(0.1, 0.9).expand([n_states, data_dim])
        probs_y = numpyro.sample("probs_y", emission_prior.to_event(2))

    def transition_fn(carry, y):
        x_older, x_last, step = carry
        active = (step < lengths)[..., None]
        with numpyro.plate("sequences", num_sequences, dim=-2), \
                mask(mask=active):
            x_new = numpyro.sample(
                "x", dist.Categorical(Vindex(probs_x)[x_older, x_last]))
            with numpyro.plate("tones", data_dim, dim=-1):
                emission_probs = probs_y[x_new.squeeze(-1)]
                numpyro.sample("y", dist.Bernoulli(emission_probs), obs=y)
        # Shift the window: the previous "current" state becomes the older one.
        return (x_last, x_new, step + 1), None

    first_state = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    second_state = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    time_major = jnp.swapaxes(sequences, 0, 1)
    scan(transition_fn, (first_state, second_state, 0), time_major, history=2)
Beispiel #6
0
def model_1(sequences, lengths, args, include_prior=True):
    """Single-chain hidden-state model with Bernoulli observations.

    Args:
        sequences: array unpacked as ``(num_sequences, max_length, data_dim)``.
        lengths: compared against the step counter; steps where
            ``t < lengths`` is false are masked out.
        args: object exposing ``hidden_dim``, the number of hidden states.
        include_prior: forwarded to ``mask`` around the parameter priors.
    """
    num_sequences, _, data_dim = sequences.shape
    n_states = args.hidden_dim

    with mask(mask=include_prior):
        transition_prior = dist.Dirichlet(0.9 * jnp.eye(n_states) + 0.1)
        probs_x = numpyro.sample("probs_x", transition_prior.to_event(1))
        emission_prior = dist.Beta(0.1, 0.9).expand([n_states, data_dim])
        probs_y = numpyro.sample("probs_y", emission_prior.to_event(2))

    def transition_fn(carry, y):
        x_last, step = carry
        active = (step < lengths)[..., None]
        with numpyro.plate("sequences", num_sequences, dim=-2), \
                mask(mask=active):
            x_new = numpyro.sample("x", dist.Categorical(probs_x[x_last]))
            with numpyro.plate("tones", data_dim, dim=-1):
                emission_probs = probs_y[x_new.squeeze(-1)]
                numpyro.sample("y", dist.Bernoulli(emission_probs), obs=y)
        return (x_new, step + 1), None

    start_state = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # The time axis of `sequences` is moved to the front because scan
    # iterates over the leading axis.
    time_major = jnp.swapaxes(sequences, 0, 1)
    scan(transition_fn, (start_state, 0), time_major)
Beispiel #7
0
def model_3(sequences, lengths, args, include_prior=True):
    """Model with two independent hidden chains ``w`` and ``x`` sharing emissions.

    Each chain transitions from its own previous state only; the Bernoulli
    observations ``y`` use ``probs_y`` indexed by the pair ``(w, x)``.

    Args:
        sequences: array unpacked as ``(num_sequences, max_length, data_dim)``;
            its first two axes are swapped before scanning.
        lengths: compared against the step counter; steps where
            ``t < lengths`` is false are masked out.
        args: object exposing ``hidden_dim``; the truncated square root of
            that value is the number of states used for each chain.
        include_prior: forwarded to ``mask`` around the parameter priors.
    """
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with mask(mask=include_prior):
        probs_w = numpyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1))
        # Both leading axes are indexed by chain states in [0, hidden_dim),
        # so each needs exactly hidden_dim entries. A shorter axis would be
        # indexed out of bounds, which JAX clamps silently instead of
        # raising, making distinct states share emission probabilities.
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim,
                                        data_dim]).to_event(3),
        )

    def transition_fn(carry, y):
        w_prev, x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
                x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
                # Note the broadcasting tricks here: to index probs_y on tensors w and x,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension.
                with numpyro.plate("tones", data_dim, dim=-1) as tones:
                    numpyro.sample("y",
                                   dist.Bernoulli(probs_y[w, x, tones]),
                                   obs=y)
        return (w, x, t + 1), None

    w_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # Time axis goes first so that scan iterates over time steps.
    scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1))
Beispiel #8
0
 def guide():
     """Guide with a scalar Normal ``x`` and a masked ``y_unobserved`` site.

     ``loc`` and ``scale`` are registered as scalar parameters (``scale``
     constrained positive). ``data`` and ``mask`` are free variables from the
     surrounding scope; ``y_unobserved`` is sampled under the inverted
     ``mask``.
     """
     mean = numpyro.param("loc", np.zeros(()))
     std = numpyro.param(
         "scale", np.ones(()), constraint=constraints.positive
     )
     x = numpyro.sample("x", dist.Normal(mean, std))
     with numpyro.plate("plate", len(data)), \
             handlers.mask(mask=np.invert(mask)):
         numpyro.sample("y_unobserved", dist.Normal(x, 1.0))
Beispiel #9
0
 def transition_fn(carry, y):
     """Advance a single hidden chain by one step and observe ``y``.

     ``carry`` is ``(x_prev, t)``; the function returns ``((x, t + 1), None)``.
     ``num_sequences``, ``lengths``, ``data_dim``, ``probs_x`` and ``probs_y``
     come from the surrounding scope.
     """
     x_last, step = carry
     # Trailing axis added so the mask broadcasts against the tones axis.
     active = (step < lengths)[..., None]
     with numpyro.plate("sequences", num_sequences, dim=-2), \
             mask(mask=active):
         x_new = numpyro.sample("x", dist.Categorical(probs_x[x_last]))
         with numpyro.plate("tones", data_dim, dim=-1):
             emission_probs = probs_y[x_new.squeeze(-1)]
             numpyro.sample("y", dist.Bernoulli(emission_probs), obs=y)
     return (x_new, step + 1), None
Beispiel #10
0
 def transition_fn(carry, y):
     """Advance chains ``w`` and ``x`` by one step, with ``x`` depending on ``w``.

     ``carry`` is ``(w_prev, x_prev, t)``; the function returns
     ``((w, x, t + 1), None)``. The probability tables, ``lengths``,
     ``num_sequences`` and ``data_dim`` come from the surrounding scope.
     """
     w_last, x_last, step = carry
     active = (step < lengths)[..., None]
     with numpyro.plate("sequences", num_sequences, dim=-2), \
             mask(mask=active):
         w_new = numpyro.sample("w", dist.Categorical(probs_w[w_last]))
         x_transition = Vindex(probs_x)[w_new, x_last]
         x_new = numpyro.sample("x", dist.Categorical(x_transition))
         with numpyro.plate("tones", data_dim, dim=-1) as tone_idx:
             emission = dist.Bernoulli(probs_y[w_new, x_new, tone_idx])
             numpyro.sample("y", emission, obs=y)
     return (w_new, x_new, step + 1), None
Beispiel #11
0
def model_bpinn(p,
                t,
                Y,
                F,
                data_type,
                D_H,
                u_sigma=None,
                f_sigma=None,
                sigma_w=1):
    """Three-layer network with Normal weight priors and two observed sites.

    Args:
        p: passed to ``mu_grad`` / ``second_grad``; its first axis sizes the
            ``observations`` plate and it is subtracted in the ``F`` mean.
        t: passed to ``mu_grad`` / ``second_grad`` alongside ``p``.
        Y: observed values for site ``"Y"``.
        F: observed values for site ``"F"``.
        data_type: mask applied to site ``"Y"`` only.
        D_H: hidden-layer width used in the weight and bias shapes.
        u_sigma: scale of the ``"Y"`` likelihood; when ``None`` it is derived
            from a sampled ``Gamma(3, 1)`` precision ``prec_u``.
        f_sigma: scale of the ``"F"`` likelihood; when ``None`` it is derived
            from a sampled ``Gamma(3, 1)`` precision ``prec_f``.
        sigma_w: scale of every weight and bias prior.

    Returns:
        Tuple ``(u_mu, f_mu)``: the mean used for ``"Y"`` and the mean used
        for ``"F"``.
    """
    # Fixed coefficients of the F-site mean.
    m, d, B = 0.15, 0.15, 0.2
    D_X, D_Y = 2, 1

    def _normal_prior(site, shape):
        # Zero-mean Normal with scale sigma_w on every entry of `shape`.
        return numpyro.sample(
            site, dist.Normal(jnp.zeros(shape), sigma_w * jnp.ones(shape)))

    # First layer.
    w1 = _normal_prior("w1", (D_X, D_H))
    b1 = _normal_prior("b1", (D_H, 1))
    # Second layer.
    w2 = _normal_prior("w2", (D_H, D_H))
    b2 = _normal_prior("b2", (D_H, 1))
    # Final layer.
    w3 = _normal_prior("w3", (D_H, D_Y))
    b3 = _normal_prior("b3", (D_Y, 1))

    weights = (w1, b1, w2, b2, w3, b3)
    u_mu, dudt = mu_grad(p, t, *weights)
    dudtt = second_grad(p, t, *weights)

    # Observation-noise priors, sampled only for scales the caller left out.
    if u_sigma is None:
        u_sigma = 1.0 / jnp.sqrt(
            numpyro.sample("prec_u", dist.Gamma(3.0, 1.0)))
    if f_sigma is None:
        f_sigma = 1.0 / jnp.sqrt(
            numpyro.sample("prec_f", dist.Gamma(3.0, 1.0)))

    with numpyro.plate('observations', p.shape[0]):
        with handlers.mask(mask=data_type):
            numpyro.sample("Y", dist.Normal(u_mu, u_sigma), obs=Y)
        # Mean of the F site; per the original author's note this forcing
        # physics term is always 0.
        f_mu = m * dudtt + d * dudt + B * jnp.sin(u_mu) - p
        numpyro.sample("F", dist.Normal(f_mu, f_sigma), obs=F)

    return u_mu, f_mu
Beispiel #12
0
 def transition_fn(carry, y):
     """Advance a chain whose transition uses the two previous states.

     ``carry`` is ``(x_prev, x_curr, t)``. The new state is drawn from
     ``probs_x`` indexed by both carried states, and the returned carry
     shifts the window to ``(x_curr, new_state, t + 1)``.
     """
     x_older, x_last, step = carry
     active = (step < lengths)[..., None]
     with numpyro.plate("sequences", num_sequences, dim=-2), \
             mask(mask=active):
         transition_probs = Vindex(probs_x)[x_older, x_last]
         x_new = numpyro.sample("x", dist.Categorical(transition_probs))
         with numpyro.plate("tones", data_dim, dim=-1):
             emission_probs = probs_y[x_new.squeeze(-1)]
             numpyro.sample("y", dist.Bernoulli(emission_probs), obs=y)
     return (x_last, x_new, step + 1), None
Beispiel #13
0
 def transition_fn(carry, y):
     """Advance two independent chains ``w`` and ``x`` and observe ``y``.

     ``carry`` is ``(w_prev, x_prev, t)``; the function returns
     ``((w, x, t + 1), None)``. Each chain transitions from its own previous
     state, and ``probs_y`` is indexed by both new states.
     """
     w_last, x_last, step = carry
     active = (step < lengths)[..., None]
     with numpyro.plate("sequences", num_sequences, dim=-2), \
             mask(mask=active):
         w_new = numpyro.sample("w", dist.Categorical(probs_w[w_last]))
         x_new = numpyro.sample("x", dist.Categorical(probs_x[x_last]))
         # Indexing probs_y by the tensors w and x needs a third index for
         # the tones axis; the tones plate supplies that index.
         with numpyro.plate("tones", data_dim, dim=-1) as tone_idx:
             emission = dist.Bernoulli(probs_y[w_new, x_new, tone_idx])
             numpyro.sample("y", emission, obs=y)
     return (w_new, x_new, step + 1), None
Beispiel #14
0
    def transition_fn(carry, y):
        """Advance the latent state ``z`` by one step and observe ``y``.

        ``carry`` is ``(first_capture_mask, z)``. Both sample sites are masked
        by ``first_capture_mask``, which is then OR-ed with ``y`` cast to bool
        before being returned in the new carry. ``N``, ``phi`` and ``rho``
        come from the surrounding scope.
        """
        seen_mask, z_last = carry
        with numpyro.plate("animals", N, dim=-1), \
                handlers.mask(mask=seen_mask):
            # Where the mask is False the probability is exactly 1.
            z_probs = seen_mask * phi * z_last + (1 - seen_mask)
            # NumPyro exactly sums out the discrete states z_t.
            z_new = numpyro.sample(
                "z", dist.Bernoulli(dist.util.clamp_probs(z_probs)))
            y_probs = rho * z_new
            numpyro.sample(
                "y", dist.Bernoulli(dist.util.clamp_probs(y_probs)), obs=y)

        updated_mask = seen_mask | y.astype(bool)
        return (updated_mask, z_new), None
Beispiel #15
0
    def transition_fn(
            carry: Tuple[jnp.ndarray, jnp.ndarray], y: jnp.ndarray
    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], None]:
        """Advance the hidden chain by one time step and observe ``y``.

        Args:
            carry: pair ``(x_prev, t)`` holding the previous hidden state and
                the step counter.
            y: observations for the current step.

        Returns:
            ``((x, t + 1), None)``: the updated carry and no per-step output.
        """

        x_prev, t = carry
        with numpyro.plate("sequence", batch, dim=-2):
            # Steps where `t < lengths` is false are masked out; the trailing
            # axis lets the mask broadcast against the tones axis.
            with mask(mask=(t < lengths)[..., None]):
                x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
                with numpyro.plate("tones", data_dim, dim=-1):
                    numpyro.sample("y",
                                   dist.Bernoulli(probs_y[x.squeeze(-1)]),
                                   obs=y)
        return (x, t + 1), None
Beispiel #16
0
    def transition_fn(carry, y):
        """Advance ``z`` by one step using a per-step ``phi_logit`` draw.

        ``phi_logit`` is sampled under a ``LocScaleReparam(0)`` reparam
        handler and passed through ``expit`` before use. ``carry`` is
        ``(first_capture_mask, z)``; the mask is OR-ed with ``y`` cast to
        bool before being returned.
        """
        seen_mask, z_last = carry
        reparam_config = {"phi_logit": LocScaleReparam(0)}
        with handlers.reparam(config=reparam_config):
            logit_draw = numpyro.sample(
                "phi_logit", dist.Normal(phi_logit_mean, phi_sigma))
        phi_step = expit(logit_draw)

        with numpyro.plate("animals", N, dim=-1), \
                handlers.mask(mask=seen_mask):
            # Where the mask is False the probability is exactly 1.
            z_probs = seen_mask * phi_step * z_last + (1 - seen_mask)
            # NumPyro exactly sums out the discrete states z_t.
            z_new = numpyro.sample(
                "z", dist.Bernoulli(dist.util.clamp_probs(z_probs)))
            y_probs = rho * z_new
            numpyro.sample(
                "y", dist.Bernoulli(dist.util.clamp_probs(y_probs)), obs=y)

        updated_mask = seen_mask | y.astype(bool)
        return (updated_mask, z_new), None
Beispiel #17
0
    def transition_fn(carry, y):
        """Advance ``z`` by one step using a per-step Uniform ``phi`` draw.

        ``carry`` is ``(first_capture_mask, z)``; the mask is OR-ed with ``y``
        cast to bool before being returned. ``N`` and ``rho`` come from the
        surrounding scope.
        """
        seen_mask, z_last = carry
        # phi is drawn outside the plate because one value is shared across
        # all N individuals.
        phi_step = numpyro.sample("phi", dist.Uniform(0.0, 1.0))

        with numpyro.plate("animals", N, dim=-1), \
                handlers.mask(mask=seen_mask):
            # Where the mask is False the probability is exactly 1.
            z_probs = seen_mask * phi_step * z_last + (1 - seen_mask)
            # NumPyro exactly sums out the discrete states z_t.
            z_new = numpyro.sample(
                "z", dist.Bernoulli(dist.util.clamp_probs(z_probs)))
            y_probs = rho * z_new
            numpyro.sample(
                "y", dist.Bernoulli(dist.util.clamp_probs(y_probs)), obs=y)

        updated_mask = seen_mask | y.astype(bool)
        return (updated_mask, z_new), None
Beispiel #18
0
def test_get_mask_optimization():
    """Check that ``numpyro.get_mask()`` lets code skip work when masked out.

    The model and guide record which branches ran in ``called``. The
    ``*-sometimes`` branches must run in a plain trace/replay, and must be
    skipped both under ``mask(mask=False)`` and inside a parallel
    ``Predictive`` call.
    """
    called = set()

    def model():
        with numpyro.handlers.seed(rng_seed=0):
            x = numpyro.sample("x", dist.Normal(0, 1))
            numpyro.sample("y", dist.Normal(x, 1), obs=0.)
            called.add("model-always")
            if numpyro.get_mask() is not False:
                called.add("model-sometimes")
                numpyro.factor("f", x + 1)

    def guide():
        with numpyro.handlers.seed(rng_seed=1):
            x = numpyro.sample("x", dist.Normal(0, 1))
            called.add("guide-always")
            if numpyro.get_mask() is not False:
                called.add("guide-sometimes")
                numpyro.factor("g", 2 - x)

    def run_guide_then_replay_model():
        guide_trace = handlers.trace(guide).get_trace()
        handlers.replay(model, guide_trace)()

    def check(sometimes_expected):
        for tag in ("model-always", "guide-always"):
            assert tag in called
        for tag in ("model-sometimes", "guide-sometimes"):
            if sometimes_expected:
                assert tag in called
            else:
                assert tag not in called

    # No mask: every branch runs.
    run_guide_then_replay_model()
    check(sometimes_expected=True)

    # Explicit mask=False: the optional branches are skipped.
    called.clear()
    with handlers.mask(mask=False):
        run_guide_then_replay_model()
    check(sometimes_expected=False)

    # Parallel Predictive: the optional branches are skipped as well.
    called.clear()
    predictive = Predictive(model, guide=guide, num_samples=2, parallel=True)
    predictive(random.PRNGKey(2))
    check(sometimes_expected=False)
Beispiel #19
0
    def transition_fn(carry, y):
        """Advance ``z`` by one step using a per-step ``phi_gamma`` offset.

        ``phi_gamma`` is drawn from ``Normal(0, 10)``, added to ``phi_beta``
        and passed through ``expit``. ``carry`` is
        ``(first_capture_mask, z)``; the mask is OR-ed with ``y`` cast to
        bool before being returned.
        """
        seen_mask, z_last = carry
        gamma_step = numpyro.sample("phi_gamma", dist.Normal(0.0, 10.0))
        phi_step = expit(phi_beta + gamma_step)

        with numpyro.plate("animals", N, dim=-1), \
                handlers.mask(mask=seen_mask):
            # Where the mask is False the probability is exactly 1.
            z_probs = seen_mask * phi_step * z_last + (1 - seen_mask)
            # NumPyro exactly sums out the discrete states z_t.
            z_dist = dist.Bernoulli(dist.util.clamp_probs(z_probs))
            z_new = numpyro.sample(
                "z", z_dist, infer={"enumerate": "parallel"})
            y_dist = dist.Bernoulli(dist.util.clamp_probs(rho * z_new))
            numpyro.sample("y", y_dist, obs=y)

        updated_mask = seen_mask | y.astype(bool)
        return (updated_mask, z_new), None
Beispiel #20
0
def masked_model(x, y, data_type):
    """Observe ``y`` around ``x`` with unit scale, masked by ``data_type``.

    Args:
        x: location of the Normal likelihood.
        y: observed values for site ``"Y"``.
        data_type: mask passed to ``handlers.mask`` around site ``"Y"``.

    Returns:
        The value of site ``"Y"`` as returned by ``numpyro.sample``.
    """
    import numpy as np  # local import: only needed to infer the plate size

    # numpyro.plate requires an explicit size. With no `dim` given the plate
    # uses the rightmost batch dimension, so take the last entry of the
    # shape the inputs broadcast to (1 when every input is a scalar).
    batch_shape = np.broadcast_shapes(
        np.shape(x), np.shape(y), np.shape(data_type))
    plate_size = batch_shape[-1] if batch_shape else 1

    with numpyro.plate('data', plate_size):
        with handlers.mask(mask=data_type):
            Y = numpyro.sample("Y", dist.Normal(x, 1.), obs=y)
    return Y
Beispiel #21
0
 def model():
     """Add a ``-inf`` factor under a length-10 mask that is all ``False``."""
     all_masked = jnp.zeros(10, dtype=bool)
     with handlers.mask(mask=all_masked):
         numpyro.factor("inf", -jnp.inf)
Beispiel #22
0
 def model(z=None):
     """Sample a categorical ``z`` and observe ``data`` around it under a mask.

     ``z`` is passed as ``obs`` to the ``"z"`` site, so ``None`` means it is
     sampled. ``logger``, ``mask`` and ``data`` are free variables from the
     surrounding scope.
     """
     probs = numpyro.param("p", np.array([0.75, 0.25]))
     z = numpyro.sample("z", dist.Categorical(probs), obs=z)
     logger.info("z.shape = {}".format(z.shape))
     with numpyro.plate("data", 3):
         with handlers.mask(mask=mask):
             numpyro.sample("x", dist.Normal(z, 1.0), obs=data)