Example 1
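All of the snippets on this page are excerpted from larger programs and carry no imports of their own. Unless a snippet says otherwise, they appear to assume a standard NumPyro preamble along these lines (inferred from the code, not shown in any of the originals):

import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro import handlers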
def model_nested_plates_2():
    outer = numpyro.plate('outer', 10)
    inner = numpyro.plate('inner', 5, dim=-3)
    with outer:
        x = numpyro.sample('x', dist.Normal(0., 1.))
        assert x.shape == (10, )
    with inner:
        y = numpyro.sample('y', dist.Normal(0., 1.))
        assert y.shape == (5, 1, 1)
        z = numpyro.deterministic('z', x**2)
        assert z.shape == (10, )

    with outer, inner:
        xy = numpyro.sample('xy', dist.Normal(0., 1.), sample_shape=(10, ))
        assert xy.shape == (5, 1, 10)
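The shape assertions can be exercised directly by running the model under a seed handler; a minimal sketch, assuming the preamble above:

with handlers.seed(rng_seed=0):
    model_nested_plates_2()  # executes all of the asserts above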
Example 2
def model_subsample_1():
    outer = numpyro.plate("outer", 20, subsample_size=10)
    inner = numpyro.plate("inner", 10, subsample_size=5, dim=-3)
    with outer:
        x = numpyro.sample("x", dist.Normal(0.0, 1.0))
        assert x.shape == (10, )
    with inner:
        y = numpyro.sample("y", dist.Normal(0.0, 1.0))
        assert y.shape == (5, 1, 1)
        z = numpyro.deterministic("z", x**2)
        assert z.shape == (10, )

    with outer, inner:
        xy = numpyro.sample("xy", dist.Normal(0.0, 1.0))
        assert xy.shape == (5, 1, 10)
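Under a subsampled plate, sample sites take the subsample's shape, and NumPyro rescales their log-density by size / subsample_size. One way to inspect this (a sketch; trace sites carry the rescaling factor under their "scale" entry):

with handlers.seed(rng_seed=0):
    tr = handlers.trace(model_subsample_1).get_trace()
print(tr["x"]["scale"])  # 20 / 10 = 2.0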
Example 3
def model(loc, scale, dist_type="Normal"):
    # `shape` and `event_dim` are free variables supplied by the enclosing
    # scope in the original source; `dist_type` is assumed to be a parameter
    # here (the scraped snippet compared the string literal "dist_type" to
    # "Normal", which is always false)
    with numpyro.plate_stack("plates", shape[: len(shape) - event_dim]):
        with numpyro.plate("particles", 10000):
            if dist_type == "Normal":
                numpyro.sample("x", dist.Normal(loc, scale).to_event(event_dim))
            else:
                numpyro.sample("x", dist.StudentT(10.0, loc, scale).to_event(event_dim))
Example 4
def birthdays_model(
    x,
    day_of_week,
    day_of_year,
    memorial_days_indicator,
    labour_days_indicator,
    thanksgiving_days_indicator,
    w0,
    L,
    M1,
    M2,
    M3,
    y=None,
):
    intercept = sample("intercept", dist.Normal(0, 1))
    f1 = scope(trend_gp, "trend")(x, L, M1)
    f2 = scope(year_gp, "year")(x, w0, M2)
    g3 = scope(trend_gp,
               "week-trend")(x, L, M3)  # length ~ lognormal(-1, 1) in original
    weekday = scope(weekday_effect, "week")(day_of_week)
    yearday = scope(yearday_effect, "day")(day_of_year)

    # --- special days
    memorial = scope(special_effect, "memorial")(memorial_days_indicator)
    labour = scope(special_effect, "labour")(labour_days_indicator)
    thanksgiving = scope(special_effect,
                         "thanksgiving")(thanksgiving_days_indicator)

    day = yearday + memorial + labour + thanksgiving
    # --- Combine components
    f = deterministic("f", intercept + f1 + f2 + jnp.exp(g3) * weekday + day)
    sigma = sample("sigma", dist.HalfNormal(0.5))
    with plate("obs", x.shape[0]):
        sample("y", dist.Normal(f, sigma), obs=y)
Example 5
def weekday_effect(day_of_week):
    with plate("plate_day_of_week", 6):
        weekday = sample("_beta", dist.Normal(0, 1))

    monday = jnp.array([-jnp.sum(weekday)])  # Monday = 0 in original
    beta = deterministic("beta", jnp.concatenate((monday, weekday)))
    return beta[day_of_week]
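The trick here is a sum-to-zero parameterization: only six free coefficients are sampled, and Monday is pinned to their negative sum, so the seven weekday effects cannot drift together with the intercept. A quick standalone check of the construction, with a made-up w:

w = jnp.array([0.5, -0.2, 0.1, 0.3, -0.4, 0.2])
beta = jnp.concatenate((jnp.array([-jnp.sum(w)]), w))
assert jnp.allclose(jnp.sum(beta), 0.0)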
Example 6
def guide(batch, hidden_dim=400, z_dim=100):
    batch = jnp.reshape(batch, (batch.shape[0], -1))
    batch_dim, out_dim = jnp.shape(batch)
    encode = numpyro.module("encoder", encoder(hidden_dim, z_dim), (batch_dim, out_dim))
    z_loc, z_std = encode(batch)
    with numpyro.plate("batch", batch_dim):
        return numpyro.sample("z", dist.Normal(z_loc, z_std).to_event(1))
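A guide like this is one half of an SVI setup; a minimal wiring sketch, where the generative `model` and a `batch` of data are assumed from the surrounding VAE example:

from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

svi = SVI(model, guide, Adam(1e-3), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0), batch)
svi_state, loss = svi.update(svi_state, batch)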
Example 7
def sample_beta_PY(alpha: float, sigma: float = 0, T: int = 10) -> jnp.ndarray:
    with numpyro.plate("beta_plate", T - 1):
        beta = numpyro.sample("beta",
                              Beta(1 - sigma, alpha + sigma * np.arange(1, T)))
    assert beta.shape == (T - 1, ), (beta.shape, T)

    return beta
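The function returns the raw stick fragments; downstream code (such as the Gibbs updates later on this page, which call mix_weights(beta)) turns them into mixture weights by stick-breaking. A plausible sketch of that helper, which is not defined in any snippet here:

def mix_weights(beta):
    # stick-breaking: w_k = beta_k * prod_{j<k} (1 - beta_j)
    beta1m_cumprod = jnp.cumprod(1 - beta, axis=-1)
    return jnp.pad(beta, (0, 1), constant_values=1.0) * jnp.pad(beta1m_cumprod, (1, 0), constant_values=1.0)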
Example 8
def model(data, obs, subsample_size):
    n, m = data.shape
    theta = numpyro.sample('theta', dist.Normal(jnp.zeros(m), .5 * jnp.ones(m)))
    with numpyro.plate('N', n, subsample_size=subsample_size):
        batch_feats = numpyro.subsample(data, event_dim=1)
        batch_obs = numpyro.subsample(obs, event_dim=0)
        numpyro.sample('obs', dist.Bernoulli(logits=theta @ batch_feats.T), obs=batch_obs)
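Models with subsampled plates are typically fit with SVI (or with HMCECS for MCMC); a minimal SVI sketch, with `data` and `obs` assumed user-supplied:

from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam

guide = AutoNormal(model)
svi = SVI(model, guide, Adam(1e-2), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, data, obs, 64)  # subsample_size=64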
Example 9
 def model(data):
     x = numpyro.sample("x",
                        dist.Bernoulli(0.5),
                        infer={"enumerate": "parallel"})
     with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-1):
         batch = numpyro.subsample(data, event_dim=0)
         numpyro.sample("obs", dist.Normal(x, 1), obs=batch)
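The infer={"enumerate": "parallel"} annotation marks x to be summed out of the density rather than sampled, so NUTS only has to deal with the continuous sites (NumPyro's enumeration machinery relies on the funsor backend, so this is a sketch under that assumption), with `data` assumed user-supplied:

from numpyro.infer import MCMC, NUTS

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), data)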
Example 10
def model(num: int, sigma: np.ndarray, y: Optional[np.ndarray] = None) -> None:

    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.Normal(0, 5))
    with numpyro.plate("num", num):
        theta = numpyro.sample("theta", dist.Normal(mu, tau))
        numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
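This is the classic eight-schools model (though note that tau is given an unconstrained Normal prior here, where the canonical version uses a positive prior such as HalfCauchy). A run sketch using Rubin's original data:

from numpyro.infer import MCMC, NUTS

y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), num=8, sigma=sigma, y=y)
mcmc.print_summary()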
Example 11
    def model(self, home_team, away_team):
        # `pyro` and `np` below are presumably NumPyro and jax.numpy imported
        # under those names in the original source
        sigma_a = pyro.sample("sigma_a", dist.HalfNormal(1.0))
        sigma_b = pyro.sample("sigma_b", dist.HalfNormal(1.0))
        mu_b = pyro.sample("mu_b", dist.Normal(0.0, 1.0))
        rho_raw = pyro.sample("rho_raw", dist.Beta(2, 2))
        rho = 2.0 * rho_raw - 1.0

        log_gamma = pyro.sample("log_gamma", dist.Normal(0, 1))

        with pyro.plate("teams", self.n_teams):
            abilities = pyro.sample(
                "abilities",
                dist.MultivariateNormal(
                    np.array([0.0, mu_b]),
                    covariance_matrix=np.array([
                        [sigma_a**2.0, rho * sigma_a * sigma_b],
                        [rho * sigma_a * sigma_b, sigma_b**2.0],
                    ]),
                ),
            )

        log_a = abilities[:, 0]
        log_b = abilities[:, 1]
        home_inds = np.array([self.team_to_index[team] for team in home_team])
        away_inds = np.array([self.team_to_index[team] for team in away_team])
        home_rate = np.exp(log_a[home_inds] + log_b[away_inds] + log_gamma)
        away_rate = np.exp(log_a[away_inds] + log_b[home_inds])

        pyro.sample("home_goals", dist.Poisson(home_rate).to_event(1))
        pyro.sample("away_goals", dist.Poisson(away_rate).to_event(1))
Example 12
 def model(data, subsample_size):
     mean = numpyro.sample("mean", dist.Normal().expand((3,)).to_event(1))
     with numpyro.plate(
         "batch", data.shape[0], dim=-2, subsample_size=subsample_size
     ):
         sub_data = numpyro.subsample(data, 0)
         numpyro.sample("obs", dist.Normal(mean, 1), obs=sub_data)
Example 13
def fulldyn_mixture_model(beliefs, y, mask):
    M, T, N, _ = beliefs[0].shape

    c0 = beliefs[-1]

    tau = .5
    with npyro.plate('N', N):
        weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))

        assert weights.shape == (N, M)

        mu = npyro.sample('mu', dist.Normal(5., 5.))

        lam12 = npyro.sample('lam12',
                             dist.HalfCauchy(1.).expand([2]).to_event(1))
        lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))

        _lam34 = jnp.expand_dims(lam34, -1)
        lam0 = npyro.deterministic(
            'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

        eta = npyro.sample('eta', dist.Beta(1, 10))

        scale = npyro.sample('scale', dist.HalfNormal(1.))
        theta = npyro.sample('theta', dist.HalfCauchy(5.))
        rho = jnp.exp(-theta)
        sigma = jnp.sqrt((1 - rho**2) / (2 * theta)) * scale

    x0 = jnp.zeros(N)

    def transition_fn(carry, t):
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))

        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams', lam_prev +
            nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(
            mask[t][..., None])
        with npyro.plate('subjects', N):
            y = npyro.sample(
                'y', dist.MixtureSameFamily(mixing_dist, component_dist))
            noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise

        return (lam_next, x_next), None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (lam_start, x0), jnp.arange(T))
Example 14
 def transition_fn(x, y):
     probs = transition_probs[x]
     x = numpyro.sample("x", dist.Categorical(probs))
     with numpyro.plate("D", D, dim=-1):
         w = numpyro.sample("w", dist.Bernoulli(0.6))
         numpyro.sample("y", dist.Normal(Vindex(locs)[x, w], 1), obs=y)
     return x, None
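`transition_probs`, `locs`, `D` and the observations are assumed from the surrounding example; a transition function of this shape is typically driven by NumPyro's scan, along these lines:

from numpyro.contrib.control_flow import scan

x_init = 0
x_last, _ = scan(transition_fn, x_init, ys)  # ys: the observation sequence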
Example 15
def model(X, Y, D_H, D_Y=1):
    N, D_X = X.shape

    # sample first layer (we put unit normal priors on all weights)
    w1 = numpyro.sample(
        "w1", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))))
    assert w1.shape == (D_X, D_H)
    z1 = nonlin(jnp.matmul(X, w1))  # <= first layer of activations
    assert z1.shape == (N, D_H)

    # sample second layer
    w2 = numpyro.sample(
        "w2", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H))))
    assert w2.shape == (D_H, D_H)
    z2 = nonlin(jnp.matmul(z1, w2))  # <= second layer of activations
    assert z2.shape == (N, D_H)

    # sample final layer of weights and neural network output
    w3 = numpyro.sample(
        "w3", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y))))
    assert w3.shape == (D_H, D_Y)
    z3 = jnp.matmul(z2, w3)  # <= output of the neural network
    assert z3.shape == (N, D_Y)

    if Y is not None:
        assert z3.shape == Y.shape

    # we put a prior on the observation noise
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)

    # observe data
    with numpyro.plate("data", N):
        # note we use to_event(1) because each observation has shape (D_Y,)
        numpyro.sample("Y", dist.Normal(z3, sigma_obs).to_event(1), obs=Y)
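`nonlin` is not defined in the snippet; in NumPyro's Bayesian neural network example it is simply a tanh, which is a reasonable reading here:

def nonlin(x):
    # elementwise nonlinearity applied between layers
    return jnp.tanh(x)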
Example 16
 def guide():
     alpha_q_log = numpyro.param("alpha_q_log", log_alpha_n + 0.17)
     beta_q_log = numpyro.param("beta_q_log", log_beta_n - 0.143)
     alpha_q, beta_q = jnp.exp(alpha_q_log), jnp.exp(beta_q_log)
     # FakeGamma is presumably a stand-in for dist.Gamma in the original test
     numpyro.sample("lambda_latent", FakeGamma(alpha_q, beta_q))
     with numpyro.plate("data", len(data)):
         pass  # the guide declares the model's data plate but places no sites in it
Example 17
 def model(z=None):
     p = numpyro.param("p", np.array([0.75, 0.25]))
     iz = numpyro.sample("z", dist.Categorical(p), obs=z)
     z = jnp.array([0.0, 1.0])[iz]
     logger.info("z.shape = {}".format(z.shape))
     with numpyro.plate("data", 3):
         numpyro.sample("x", dist.Normal(z, 1.0), obs=data)
Example 18
    def gibbs_fn(rng_key: random.PRNGKey, gibbs_sites: Dict[str, jnp.ndarray],
                 hmc_sites: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        beta = hmc_sites['beta']
        mu = hmc_sites['mu']
        theta = hmc_sites['theta']
        L_omega = hmc_sites['L_omega']
        L_Omega = jnp.sqrt(theta.T[:, :, None]) * L_omega

        T, _ = mu.shape

        assert beta.shape == (T - 1, )
        assert mu.shape == (T, Ndim)
        assert theta.shape == (Ndim, T)
        assert L_omega.shape == (T, Ndim, Ndim)
        assert L_Omega.shape == (T, Ndim, Ndim)

        log_probs = MultivariateNormal(loc=mu, scale_tril=L_Omega).log_prob(data[:, None])
        assert log_probs.shape == (Npoints, T)

        log_weights = jnp.log(mix_weights(beta))
        assert log_weights.shape == (T, )

        logits = log_probs + log_weights[None, :]
        assert logits.shape == (Npoints, T)

        with numpyro.plate("z", Npoints):
            z = CategoricalLogits(logits).sample(rng_key)
        assert z.shape == (Npoints, )
        return {'z': z}
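A gibbs_fn with this signature plugs into HMCGibbs, which alternates the hand-written Gibbs update for z with HMC moves on the remaining sites; a wiring sketch, with `model` assumed from the surrounding Dirichlet-process-mixture example:

from numpyro.infer import HMCGibbs, MCMC, NUTS

kernel = HMCGibbs(NUTS(model), gibbs_fn=gibbs_fn, gibbs_sites=["z"])
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0))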
Example 19
 def model(data, mask):
     with numpyro.plate('N', N):
         x = numpyro.sample('x', dist.Normal(0, 1))
         with handlers.mask(mask=mask):
             numpyro.sample('y', dist.Delta(x, log_density=1.))
             with handlers.scale(scale=2):
                 numpyro.sample('obs', dist.Normal(x, 1), obs=data)
Example 20
 def guide():
     loc = numpyro.param("loc", np.zeros(()))
     scale = numpyro.param("scale", np.ones(()), constraint=constraints.positive)
     x = numpyro.sample("x", dist.Normal(loc, scale))
     with numpyro.plate("plate", len(data)):
         with handlers.mask(mask=np.invert(mask)):
             numpyro.sample("y_unobserved", dist.Normal(x, 1.0))
Example 21
    def gibbs_fn(rng_key: random.PRNGKey, gibbs_sites: Dict[str, jnp.ndarray],
                 hmc_sites: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        beta = hmc_sites['beta']
        mu = hmc_sites['mu']
        sigma2 = hmc_sites['sigma2']

        T, = mu.shape
        assert beta.shape == (T - 1, )
        assert sigma2.shape == (T, )

        log_probs = Normal(loc=mu, scale=jnp.sqrt(sigma2)).log_prob(data[:, None])
        assert log_probs.shape == (Npoints, T)

        log_weights = jnp.log(mix_weights(beta))
        assert log_weights.shape == (T, )

        logits = log_probs + log_weights[None, :]
        assert logits.shape == (Npoints, T)

        with numpyro.plate("z", Npoints):
            z = CategoricalLogits(logits).sample(rng_key)
            assert z.shape == (Npoints, )

        return {'z': z}
Example 22
 def guide():
     loc = numpyro.param("loc", np.zeros(3))
     cov = numpyro.param("cov", np.eye(3), constraint=constraints.positive_definite)
     x = numpyro.sample("x", dist.MultivariateNormal(loc, cov))
     with numpyro.plate("plate", len(data)):
         with handlers.mask(mask=np.invert(mask)):
             numpyro.sample("y_unobserved", dist.MultivariateNormal(x, np.eye(3)))
Example 23
def seir_model(initial_seir_state,
               num_days,
               infection_data,
               recovery_data,
               step_size=0.1):
    if infection_data is not None:
        assert num_days == infection_data.shape[0]
    if recovery_data is not None:
        assert num_days == recovery_data.shape[0]
    beta = numpyro.sample('beta', dist.HalfCauchy(scale=1000.))
    gamma = numpyro.sample('gamma', dist.HalfCauchy(scale=1000.))
    sigma = numpyro.sample('sigma', dist.HalfCauchy(scale=1000.))
    mu = numpyro.sample('mu', dist.HalfCauchy(scale=1000.))
    nu = np.array(0.0)  # No vaccine yet

    def seir_update(day, seir_state):
        s = seir_state[..., 0]
        e = seir_state[..., 1]
        i = seir_state[..., 2]
        r = seir_state[..., 3]
        n = s + e + i + r
        s_upd = mu * (n - s) - beta * (s * i / n) - nu * s
        e_upd = beta * (s * i / n) - (mu + sigma) * e
        i_upd = sigma * e - (mu + gamma) * i
        r_upd = gamma * i - mu * r + nu * s
        return np.stack((s_upd, e_upd, i_upd, r_upd), axis=-1)

    num_steps = int(num_days / step_size)
    sim = runge_kutta_4(seir_update, initial_seir_state, step_size, num_steps)
    sim = np.reshape(sim, (num_days, int(1 / step_size), 4))[:, -1, :] + 1e-3
    with numpyro.plate('data', num_days):
        numpyro.sample('infections',
                       dist.Poisson(sim[:, 2]),
                       obs=infection_data)
        numpyro.sample('recovery', dist.Poisson(sim[:, 3]), obs=recovery_data)
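`runge_kutta_4` is not defined in the snippet. A plausible fixed-step implementation, assuming (as the np.stack and np.reshape calls on sampled values suggest) that np is jax.numpy imported under the older `as np` convention:

def runge_kutta_4(f, y0, dt, num_steps):
    # classic fixed-step RK4; stacks the state after every step, giving a
    # result of shape (num_steps,) + y0.shape
    ys = []
    y = y0
    for i in range(num_steps):
        t = i * dt
        k1 = f(t, y)
        k2 = f(t + dt / 2, y + dt * k1 / 2)
        k3 = f(t + dt / 2, y + dt * k2 / 2)
        k4 = f(t + dt, y + dt * k3)
        y = y + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        ys.append(y)
    return np.stack(ys)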
Example 24
    def model(self, *views: np.ndarray):
        n = views[0].shape[0]
        p = [view.shape[1] for view in views]
        # mean of column in each view of data (p_1,)
        mu = [
            numpyro.sample("mu_" + str(i),
                           dist.MultivariateNormal(0., 10 * jnp.eye(p_)))
            for i, p_ in enumerate(p)
        ]
        """
        Generates cholesky factors of correlation matrices using an LKJ prior.

        The expected use is to combine it with a vector of variances and pass it
        to the scale_tril parameter of a multivariate distribution such as MultivariateNormal.

        E.g., if theta is a (positive) vector of covariances with the same dimensionality
        as this distribution, and Omega is sampled from this distribution,
        scale_tril=torch.mm(torch.diag(sqrt(theta)), Omega)
        """
        psi = [
            numpyro.sample("psi_" + str(i), dist.LKJCholesky(p_))
            for i, p_ in enumerate(p)
        ]
        # sample weights to get from latent to data space (k,p)
        with numpyro.plate("plate_views", self.latent_dims):
            self.weights_list = [
                numpyro.sample(
                    "W_" + str(i),
                    dist.MultivariateNormal(0., jnp.diag(jnp.ones(p_))))
                for i, p_ in enumerate(p)
            ]
        with numpyro.plate("plate_i", n):
            # sample latent z - normally distributed (n, k)
            z = numpyro.sample(
                "z",
                dist.MultivariateNormal(0.,
                                        jnp.diag(jnp.ones(self.latent_dims))))
            # sample from multivariate normal and observe data
            [
                numpyro.sample("obs" + str(i),
                               dist.MultivariateNormal((z @ W_) + mu_,
                                                       scale_tril=psi_),
                               obs=X_)
                for i, (
                    X_, psi_, mu_,
                    W_) in enumerate(zip(views, psi, mu, self.weights_list))
            ]
Example 25
def partially_pooled_with_logit(at_bats: jnp.ndarray, hits: Optional[jnp.ndarray] = None) -> None:

    loc = numpyro.sample("loc", dist.Normal(-1, 1))
    scale = numpyro.sample("scale", dist.HalfCauchy(1))
    num_players = at_bats.shape[0]
    with numpyro.plate("num_players", num_players):
        alpha = numpyro.sample("alpha", dist.Normal(loc, scale))
        numpyro.sample("obs", dist.Binomial(at_bats, logits=alpha), obs=hits)
Example 26
def partially_pooled(at_bats: jnp.ndarray, hits: Optional[jnp.ndarray] = None) -> None:

    m = numpyro.sample("m", dist.Uniform(0, 1))
    kappa = numpyro.sample("kappa", dist.Pareto(1, 1.5))
    num_players = at_bats.shape[0]
    with numpyro.plate("num_players", num_players):
        phi = numpyro.sample("phi", dist.Beta(m * kappa, (1 - m) * kappa))
        numpyro.sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)
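Both baseball models are typically fit with NUTS and then checked with Predictive; a sketch, with `at_bats` and `hits` assumed user-supplied:

from numpyro.infer import MCMC, NUTS, Predictive

mcmc = MCMC(NUTS(partially_pooled), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), at_bats, hits)
predictive = Predictive(partially_pooled, mcmc.get_samples())
preds = predictive(random.PRNGKey(1), at_bats)["obs"]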