Exemplo n.º 1
0
    def model(self, home_team, away_team):
        """Correlated two-ability Poisson goal model.

        Every team gets a pair of log-abilities ``(log_a, log_b)`` drawn from a
        2-D normal with mean ``(0, mu_b)``, scales ``sigma_a``/``sigma_b`` and
        correlation ``rho``. A side's scoring log-rate is its own ``log_a`` plus
        the opponent's ``log_b``; the home side also adds ``log_gamma``.

        Args:
            home_team: home-team keys, looked up in ``self.team_to_index``.
            away_team: away-team keys, same length as ``home_team``.
        """
        # Hyperpriors. The order of sample statements is deliberate: under a
        # seed handler it decides which PRNG key each site receives.
        scale_a = pyro.sample("sigma_a", dist.HalfNormal(1.0))
        scale_b = pyro.sample("sigma_b", dist.HalfNormal(1.0))
        mean_b = pyro.sample("mu_b", dist.Normal(0.0, 1.0))
        corr_unit = pyro.sample("rho_raw", dist.Beta(2, 2))
        # Map the Beta draw from (0, 1) onto a correlation in (-1, 1).
        corr = 2.0 * corr_unit - 1.0

        home_log_boost = pyro.sample("log_gamma", dist.Normal(0, 1))

        cross = corr * scale_a * scale_b
        covariance = np.array([
            [scale_a**2.0, cross],
            [cross, scale_b**2.0],
        ])
        ability_prior = dist.MultivariateNormal(
            np.array([0.0, mean_b]), covariance_matrix=covariance)

        with pyro.plate("teams", self.n_teams):
            abilities = pyro.sample("abilities", ability_prior)

        log_a, log_b = abilities[:, 0], abilities[:, 1]
        hosts = np.array([self.team_to_index[name] for name in home_team])
        visitors = np.array([self.team_to_index[name] for name in away_team])

        pyro.sample(
            "home_goals",
            dist.Poisson(np.exp(log_a[hosts] + log_b[visitors] + home_log_boost)).to_event(1),
        )
        pyro.sample(
            "away_goals",
            dist.Poisson(np.exp(log_a[visitors] + log_b[hosts])).to_event(1),
        )
def seir_model(initial_seir_state,
               num_days,
               infection_data,
               recovery_data,
               step_size=0.1):
    """SEIR ODE model with Poisson observations of the I and R compartments.

    The dynamics are integrated with ``runge_kutta_4`` using ``step_size``
    (in days) per step. For each day, the state at the last sub-step plus 1e-3
    is used as the Poisson rate of that day's ``infections`` (I) and
    ``recovery`` (R) observations.

    Args:
        initial_seir_state: initial state whose last axis is (S, E, I, R);
            presumably shape (4,) -- confirm against ``runge_kutta_4``.
        num_days: number of observed days.
        infection_data: per-day infection counts of length ``num_days``, or
            ``None`` to sample them.
        recovery_data: per-day recovery counts of length ``num_days``, or
            ``None`` to sample them.
        step_size: integration step in days. ``1 / step_size`` must be a whole
            number, because the trajectory is regrouped per day.

    Raises:
        AssertionError: if an observation array's length differs from
            ``num_days``.
        ValueError: if ``step_size`` does not split a day into a whole number
            of steps.
    """
    if infection_data is not None:
        assert num_days == infection_data.shape[0]
    if recovery_data is not None:
        assert num_days == recovery_data.shape[0]

    # The trajectory is reshaped to (num_days, steps_per_day, 4) below, so a
    # day must split into a whole number of steps. Rounding (rather than int()
    # truncation of a float quotient) stops representation error in
    # 1 / step_size from dropping a step. A non-divisor step size now fails
    # here with a clear message instead of at the reshape.
    steps_per_day = int(round(1.0 / step_size))
    if steps_per_day < 1 or abs(steps_per_day * step_size - 1.0) > 1e-6:
        raise ValueError(
            f"step_size={step_size!r} must divide one day into a whole number of steps")

    beta = numpyro.sample('beta', dist.HalfCauchy(scale=1000.))
    gamma = numpyro.sample('gamma', dist.HalfCauchy(scale=1000.))
    sigma = numpyro.sample('sigma', dist.HalfCauchy(scale=1000.))
    mu = numpyro.sample('mu', dist.HalfCauchy(scale=1000.))
    nu = np.array(0.0)  # No vaccine yet

    def seir_update(day, seir_state):
        """Return the time derivative of (S, E, I, R) stacked on the last axis."""
        s = seir_state[..., 0]
        e = seir_state[..., 1]
        i = seir_state[..., 2]
        r = seir_state[..., 3]
        n = s + e + i + r
        s_upd = mu * (n - s) - beta * (s * i / n) - nu * s
        e_upd = beta * (s * i / n) - (mu + sigma) * e
        i_upd = sigma * e - (mu + gamma) * i
        r_upd = gamma * i - mu * r + nu * s
        return np.stack((s_upd, e_upd, i_upd, r_upd), axis=-1)

    num_steps = int(num_days) * steps_per_day
    sim = runge_kutta_4(seir_update, initial_seir_state, step_size, num_steps)
    # Keep the last sub-step of each day. The 1e-3 offset presumably keeps the
    # Poisson rates strictly positive.
    sim = np.reshape(sim, (num_days, steps_per_day, 4))[:, -1, :] + 1e-3
    with numpyro.plate('data', num_days):
        numpyro.sample('infections',
                       dist.Poisson(sim[:, 2]),
                       obs=infection_data)
        numpyro.sample('recovery', dist.Poisson(sim[:, 3]), obs=recovery_data)
def model(home_id, away_id, score1_obs=None, score2_obs=None):
    """Poisson attack/defence score model with a shared home advantage.

    Log-rates are ``alpha + home + attack[home_id] - defend[away_id]`` for the
    home score and ``alpha + attack[away_id] - defend[home_id]`` for the away
    score.

    Args:
        home_id: integer team index of the home side for each match.
        away_id: integer team index of the away side for each match; same
            length as ``home_id``.
        score1_obs: observed home scores, or ``None`` to sample them.
        score2_obs: observed away scores, or ``None`` to sample them.
    """
    # priors
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 1.0))
    sd_att = numpyro.sample(
        "sd_att",
        dist.FoldedDistribution(dist.StudentT(3.0, 0.0, 2.5)),
    )
    sd_def = numpyro.sample(
        "sd_def",
        dist.FoldedDistribution(dist.StudentT(3.0, 0.0, 2.5)),
    )

    home = numpyro.sample("home", dist.Normal(0.0, 1.0))  # home advantage

    # Team ids index ``attack``/``defend`` directly, so the plate must have
    # max(id) + 1 entries taken over BOTH id arrays. Counting distinct home
    # ids undercounts when a team never plays at home or ids have gaps. JAX
    # clamps out-of-range gather indices instead of raising, so that would go
    # unnoticed.
    all_ids = np.concatenate([np.ravel(home_id), np.ravel(away_id)])
    nt = int(all_ids.max()) + 1 if all_ids.size else 0

    # team-specific model parameters
    with numpyro.plate("plate_teams", nt):
        attack = numpyro.sample("attack", dist.Normal(0, sd_att))
        defend = numpyro.sample("defend", dist.Normal(0, sd_def))

    # likelihood
    theta1 = jnp.exp(alpha + home + attack[home_id] - defend[away_id])
    theta2 = jnp.exp(alpha + attack[away_id] - defend[home_id])

    with numpyro.plate("data", len(home_id)):
        numpyro.sample("s1", dist.Poisson(theta1), obs=score1_obs)
        numpyro.sample("s2", dist.Poisson(theta2), obs=score2_obs)
Exemplo n.º 4
0
    def model(self, home_team, away_team, gameweek):
        """Dynamic team-strength model with per-gameweek random-walk abilities.

        Each team's ``log_a`` is a Gaussian random walk over gameweeks: it
        starts at ``log_a0`` and takes steps of scale ``sigma_rw``, one scale
        per team. ``log_b`` is drawn per (gameweek, team) around
        ``mu_b + b * log_a``. A side's scoring log-rate is its own ``log_a``
        minus the opponent's ``log_b``; the home side also adds ``gamma``. The
        log-rate is clipped to [-7, 2] before ``exp``.

        Args:
            home_team: home-team keys, looked up in ``self.team_to_index``.
            away_team: away-team keys, same length as ``home_team``.
            gameweek: per-match gameweek index; presumably 0-based integers
                (``n_gameweeks = max(gameweek) + 1``) -- TODO confirm.
        """
        n_gameweeks = max(gameweek) + 1
        sigma_0 = pyro.sample("sigma_0", dist.HalfNormal(5))
        sigma_b = pyro.sample("sigma_b", dist.HalfNormal(5))
        # NOTE(review): gamma is added on the log-rate scale but drawn from a
        # LogNormal, so the home advantage is constrained positive -- confirm
        # this is intended.
        gamma = pyro.sample("gamma", dist.LogNormal(0, 1))

        # Coupling between a team's log_a and its log_b location.
        b = pyro.sample("b", dist.Normal(0, 1))

        loc_mu_b = pyro.sample("loc_mu_b", dist.Normal(0, 1))
        scale_mu_b = pyro.sample("scale_mu_b", dist.HalfNormal(1))

        with pyro.plate("teams", self.n_teams):
            # Per-team latents. The inner plates add a gameweek axis to the
            # left of the team axis (assuming default plate dim allocation).
            # The later log_a[gameweek, team] indexing relies on that layout.

            log_a0 = pyro.sample("log_a0", dist.Normal(0, sigma_0))
            # Equal in distribution to Normal(loc_mu_b, scale_mu_b).
            mu_b = pyro.sample(
                "mu_b",
                dist.TransformedDistribution(
                    dist.Normal(0, 1),
                    dist.transforms.AffineTransform(loc_mu_b, scale_mu_b),
                ),
            )
            sigma_rw = pyro.sample("sigma_rw", dist.HalfNormal(0.1))

            # Random-walk increments for gameweeks 1 .. n_gameweeks - 1.
            with pyro.plate("random_walk", n_gameweeks - 1):
                diffs = pyro.sample(
                    "diff",
                    dist.TransformedDistribution(
                        dist.Normal(0, 1),
                        dist.transforms.AffineTransform(0, sigma_rw)),
                )

            # Prepend the week-0 level as the first row, then cumulative-sum
            # down the week axis: row w = log_a0 + sum of the first w steps.
            diffs = np.vstack((log_a0, diffs))
            log_a = np.cumsum(diffs, axis=-2)

            # Equal in distribution to Normal(mu_b + b * log_a, sigma_b), one
            # draw per (gameweek, team).
            with pyro.plate("weeks", n_gameweeks):
                log_b = pyro.sample(
                    "log_b",
                    dist.TransformedDistribution(
                        dist.Normal(0, 1),
                        dist.transforms.AffineTransform(
                            mu_b + b * log_a, sigma_b),
                    ),
                )

        # Registers log_a as an observed Delta site. A Delta scored at its own
        # value adds zero log-density, so this presumably only exposes log_a in
        # the trace.
        pyro.sample("log_a", dist.Delta(log_a), obs=log_a)
        home_inds = np.array([self.team_to_index[team] for team in home_team])
        away_inds = np.array([self.team_to_index[team] for team in away_team])
        # Despite the names, these are log-rates. They are clipped before exp,
        # presumably to keep the Poisson rate in a numerically safe range.
        home_rate = np.clip(
            log_a[gameweek, home_inds] - log_b[gameweek, away_inds] + gamma,
            -7, 2)
        away_rate = np.clip(
            log_a[gameweek, away_inds] - log_b[gameweek, home_inds], -7, 2)

        # No obs here: presumably the caller conditions these sites on the
        # observed goals (e.g. with a condition handler) -- confirm.
        pyro.sample("home_goals", dist.Poisson(np.exp(home_rate)))
        pyro.sample("away_goals", dist.Poisson(np.exp(away_rate)))
Exemplo n.º 5
0
def sample_y(dist_y, theta, y, sigma_obs=None):
    """Add the likelihood site ``'y'`` for the family named by *dist_y*.

    Args:
        dist_y: one of ``'student'``, ``'normal'``, ``'lognormal'``,
            ``'gamma'``, ``'gamma_raw'``, ``'poisson'``, ``'exponential'``,
            ``'exponential_raw'``, ``'uniform'``.
        theta: linear predictor.
            - Location for student/normal/lognormal.
            - ``exp(theta)`` is the Gamma concentration for ``'gamma'``;
              ``theta`` itself for ``'gamma_raw'``.
            - Rate for poisson.
            - ``exp(theta)`` is the rate for ``'exponential'``; ``theta``
              itself for ``'exponential_raw'``.
            - Ignored for ``'uniform'``.
        y: observed values for site ``'y'``; ``None`` samples them.
        sigma_obs: observation scale (the Gamma rate for the gamma families).
            When ``None``, a ``'sigma_obs'`` site is sampled: Exponential(1)
            for ``'gamma'``, HalfNormal(1) otherwise. NOTE(review): it is
            sampled even for families that ignore it (poisson, exponential*,
            uniform). That leaves an unidentified latent; it is kept so
            existing traces retain the site.

    Raises:
        NotImplementedError: if *dist_y* is not a supported family.
    """
    # Test against None explicitly. ``not sigma_obs`` raises for multi-element
    # arrays and for traced JAX values (e.g. a scale sampled elsewhere and
    # passed in under MCMC/jit). It also silently overrode an explicit 0.
    if sigma_obs is None:
        if dist_y == 'gamma':
            sigma_obs = numpyro.sample('sigma_obs', dist.Exponential(1))
        else:
            sigma_obs = numpyro.sample('sigma_obs', dist.HalfNormal(1))

    if dist_y == 'student':
        # Degrees of freedom get their own Gamma(1, 0.1) prior.
        nu_y = numpyro.sample('nu_y', dist.Gamma(1, .1))
        numpyro.sample('y', dist.StudentT(nu_y, theta, sigma_obs), obs=y)
    elif dist_y == 'normal':
        numpyro.sample('y', dist.Normal(theta, sigma_obs), obs=y)
    elif dist_y == 'lognormal':
        numpyro.sample('y', dist.LogNormal(theta, sigma_obs), obs=y)
    elif dist_y == 'gamma':
        # numpyro's Gamma takes (concentration, rate).
        numpyro.sample('y', dist.Gamma(jnp.exp(theta), sigma_obs), obs=y)
    elif dist_y == 'gamma_raw':
        numpyro.sample('y', dist.Gamma(theta, sigma_obs), obs=y)
    elif dist_y == 'poisson':
        numpyro.sample('y', dist.Poisson(theta), obs=y)
    elif dist_y == 'exponential':
        numpyro.sample('y', dist.Exponential(jnp.exp(theta)), obs=y)
    elif dist_y == 'exponential_raw':
        numpyro.sample('y', dist.Exponential(theta), obs=y)
    elif dist_y == 'uniform':
        numpyro.sample('y', dist.Uniform(0, 1), obs=y)
    else:
        raise NotImplementedError(f"Unsupported dist_y: {dist_y!r}")
Exemplo n.º 6
0
def _load_data(num_seasons: int = 10,
               batch: int = 1,
               x_dim: int = 1) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Build a synthetic Poisson series with periodic bursts of large noise.

    Adapted from
    http://docs.pyro.ai/en/stable/_modules/pyro/contrib/examples/bart.html

    Each 70-step season has 65 steps of unit-scale Gaussian noise followed by
    5 steps of scale-50 noise, added on top of Poisson(100) counts. A fixed
    PRNG key makes the output reproducible.

    Returns:
        ``(x, t)``:
            - ``x`` has shape ``(70 * num_seasons, batch, x_dim)``.
            - ``t`` is the integer time index, shape
              ``(70 * num_seasons, batch, 1)``.
    """
    seq_len = 70 * num_seasons
    key_counts, key_noise = random.split(random.PRNGKey(1234), 2)

    counts = dist.Poisson(100).sample(key_counts, (seq_len, batch, x_dim))
    noise_scale = jnp.array(([1] * 65 + [50] * 5) * num_seasons)[:, None, None]
    x = counts + noise_scale * random.normal(key_noise, (seq_len, batch, x_dim))

    t = jnp.repeat(jnp.arange(len(x))[:, None, None], batch, axis=1)

    assert isinstance(x, jnp.ndarray)
    assert isinstance(t, jnp.ndarray)
    assert x.shape[0] == t.shape[0]
    assert x.shape[1] == t.shape[1]

    return x, t
Exemplo n.º 7
0
 def model(data):
     """Poisson switchpoint model with a continuous switch fraction ``tau``.

     Observations whose index is below ``tau * len(data)`` use ``lambda1``; the
     rest use ``lambda2``. Both rates get an Exponential prior whose rate is
     ``1 / mean(data)``.
     """
     n_obs = len(data)
     rate_prior = dist.Exponential(1 / jnp.mean(data.astype(np.float32)))
     lambda1 = numpyro.sample("lambda1", rate_prior)
     lambda2 = numpyro.sample("lambda2", rate_prior)
     tau = numpyro.sample("tau", dist.Uniform(0, 1))
     early = jnp.arange(n_obs) < tau * n_obs
     numpyro.sample("obs", dist.Poisson(jnp.where(early, lambda1, lambda2)), obs=data)
Exemplo n.º 8
0
 def model(data):
     """Poisson switchpoint model: rate ``lambda1`` before ``tau * len(data)``, ``lambda2`` after.

     Both rates share an Exponential prior with rate ``1 / mean(data)``;
     ``tau`` is Uniform(0, 1).
     """
     n_obs = len(data)
     rate_prior = dist.Exponential(1 / jnp.mean(data))
     lambda1 = numpyro.sample('lambda1', rate_prior)
     lambda2 = numpyro.sample('lambda2', rate_prior)
     tau = numpyro.sample('tau', dist.Uniform(0, 1))
     early = jnp.arange(n_obs) < tau * n_obs
     numpyro.sample('obs', dist.Poisson(jnp.where(early, lambda1, lambda2)), obs=data)
Exemplo n.º 9
0
 def model(count_data):
     """Poisson switchpoint model with a discrete switch day ``tau``.

     Days with index below ``tau`` use ``lambda_1``; the rest use
     ``lambda_2``. Both rates share an Exponential prior with rate
     ``1 / mean(count_data)``. ``tau`` is uniform over the day indices of the
     given data.
     """
     n_count_data = count_data.shape[0]
     alpha = 1 / jnp.mean(count_data.astype(np.float32))
     lambda_1 = numpyro.sample('lambda_1', dist.Exponential(alpha))
     lambda_2 = numpyro.sample('lambda_2', dist.Exponential(alpha))
     # Same as DiscreteUniform(0, n_count_data - 1). The support used to be
     # hard-coded to 70 days. That forbade late switch points in longer series
     # and allowed switch points past the end of shorter ones.
     tau = numpyro.sample('tau', dist.Categorical(logits=jnp.zeros(n_count_data)))
     idx = jnp.arange(n_count_data)
     lambda_ = jnp.where(tau > idx, lambda_1, lambda_2)
     with numpyro.plate("data", n_count_data):
         numpyro.sample('obs', dist.Poisson(lambda_), obs=count_data)
Exemplo n.º 10
0
def test_gamma_poisson_log_prob(shape):
    """GammaPoisson.log_prob must match a Monte Carlo Gamma-mixture of Poissons."""
    # The three global-RNG draws keep their original order.
    concentration = onp.exp(onp.random.normal(size=shape))
    rate = onp.exp(onp.random.normal(size=shape))
    counts = np.arange(15)

    n_draws = 300000
    mixing_rates = onp.random.gamma(concentration, 1 / rate, size=(n_draws,) + shape)
    # log p(k) ~= log( mean_j Poisson(k | lambda_j) ), lambda_j ~ Gamma(conc, rate)
    mc_estimate = logsumexp(dist.Poisson(mixing_rates).log_prob(counts), 0) - np.log(n_draws)
    exact = dist.GammaPoisson(concentration, rate).log_prob(counts)
    assert_allclose(exact, mc_estimate, rtol=0.05)
Exemplo n.º 11
0
def test_ZIP_log_prob(rate):
    """ZeroInflatedPoisson must reduce to its limiting cases at gate 0 and gate 1."""
    # gate == 0: no inflation, so the log-density equals a plain Poisson.
    no_inflation = dist.ZeroInflatedPoisson(0., rate)
    plain = dist.Poisson(rate)
    draws = no_inflation.sample(random.PRNGKey(0), (20,))
    assert_allclose(no_inflation.log_prob(draws), plain.log_prob(draws))

    # gate == 1: all mass sits on zero, matching Delta(0).
    all_zero = dist.ZeroInflatedPoisson(1., rate)
    point_mass = dist.Delta(0.)
    points = np.array([0., 1.])
    assert_allclose(all_zero.log_prob(points), point_mass.log_prob(points))
Exemplo n.º 12
0
def observe_poisson(name, latent, det_prob, obs=None):
    """Add Poisson site *name* with mean ``det_prob * latent``, masking invalid observations.

    Args:
        name: sample-site name. The mean is also recorded as the deterministic
            site ``"mean_" + name``.
        latent: latent quantity; its shape is the target shape for
            ``det_prob``.
        det_prob: multiplier broadcast to ``latent.shape``.
        obs: observed counts, or ``None``. Non-finite or negative entries are
            excluded from the likelihood and replaced by 0.

    Returns:
        The value of the sample site.
    """
    if obs is None:
        mask = True
    else:
        valid = np.isfinite(obs) & (obs >= 0)
        # Zero-filling presumably keeps the (masked-out) log-density finite.
        obs = np.where(valid, obs, 0.0)
        mask = valid

    mean = np.broadcast_to(det_prob, latent.shape) * latent
    likelihood = dist.Poisson(mean)
    numpyro.deterministic("mean_" + name, mean)

    with numpyro.handlers.mask(mask_array=mask):
        return numpyro.sample(name, likelihood, obs=obs)
Exemplo n.º 13
0
def _load_data(num_seasons: int = 100,
               batch: int = 1,
               x_dim: int = 1) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Load sequential data with seasonality and trend.

    The series has ``7 * num_seasons`` steps. It is built from:
        - Poisson(100) counts (fixed PRNG key);
        - a cumulative-uniform upward trend;
        - a repeating 7-step offset pattern (+50 for 5 steps, +1 for 2);
        - ``log1p`` of the above;
        - plus ``2 * t``, where ``t`` is a sine covariate spanning three
          periods over the whole series.

    Returns:
        ``(x, t)``:
            - ``x`` has shape ``(7 * num_seasons, batch, x_dim)``.
            - ``t`` has shape ``(7 * num_seasons, batch, 1)``.
    """
    seq_len = 7 * num_seasons

    # Exactly seq_len evenly spaced phases over [0, 6*pi). This replaces a
    # float-step arange hard-coded to 700 points: its length did not follow
    # num_seasons and is subject to floating-point rounding.
    t = jnp.sin(jnp.linspace(0, 6 * jnp.pi, seq_len, endpoint=False))[:, None, None]

    x = dist.Poisson(100).sample(random.PRNGKey(1234),
                                 (seq_len, batch, x_dim))
    # NOTE(review): the trend uses NumPy's global RNG, so unlike the Poisson
    # base it is only reproducible if the caller seeds np.random.
    x += jnp.array(np.random.rand(seq_len).cumsum(0)[:, None, None])
    x += jnp.array(([50] * 5 + [1] * 2) * num_seasons)[:, None, None]
    x = jnp.log1p(x)
    x += t * 2

    # Repeat over the batch axis (as the 70-step loader does) so the shape
    # check below also holds for batch > 1.
    t = jnp.repeat(t, batch, axis=1)

    assert isinstance(x, jnp.ndarray)
    assert isinstance(t, jnp.ndarray)
    assert x.shape[0] == t.shape[0]
    assert x.shape[1] == t.shape[1]

    return x, t
Exemplo n.º 14
0
 def likelihood_func(self, yhat):
     """Build the observation distribution: a Poisson whose rate is ``yhat``."""
     likelihood = dist.Poisson(yhat)
     return likelihood
Exemplo n.º 15
0
 def model():
     """Toy model mixing discrete and continuous latents, with one observed site.

     The Dirichlet draw gives the probabilities of a 3-way Categorical index.
     The index selects one of three standard-normal locations. ``exp`` of that
     location is the rate of a Poisson count, and the count is the mean of a
     unit-scale Normal observed at 1.0.
     """
     probs = numpyro.sample("a", dist.Dirichlet(jnp.ones(3)))
     idx = numpyro.sample("b", dist.Categorical(probs))
     locs = numpyro.sample("c", dist.Normal(jnp.zeros(3), 1).to_event(1))
     count = numpyro.sample("d", dist.Poisson(jnp.exp(locs[idx])))
     observed = jnp.ones(())
     numpyro.sample("e", dist.Normal(count, 1), obs=observed)
Exemplo n.º 16
0
def poisson_regression(x, N):
    """Poisson likelihood over ``N`` points with a Gamma(1, 1) prior on one shared rate.

    When *x* is given, the plate's subsample size is ``len(x)``; otherwise no
    subsample size is passed.
    """
    rate = numpyro.sample("param", dist.Gamma(1.0, 1.0))
    subsample = None if x is None else len(x)
    with numpyro.plate("batch", N, subsample):
        numpyro.sample("x", dist.Poisson(rate), obs=x)
Exemplo n.º 17
0
def model_hierarchical(y, condition=None, group=None, treatment=None, dist_y='normal', add_group_slope=False,
                       add_group_intercept=True,
                       add_condition_slope=True, group2=None, add_group2_slope=False,
                       center_intercept=True, center_slope=False, robust_slopes=False,
                       add_condition_intercept=False,
                       dist_slope=dist.Normal
                      ):
    """Hierarchical regression with optional per-condition / per-group terms.

    Builds ``theta = intercept + slope`` (``slope`` is multiplied by
    *treatment* when given) and passes it to ``sample_y`` with likelihood
    *dist_y*.

    Args:
        y: observations passed to ``sample_y``.
        condition, group, group2: integer codes per observation, used directly
            as indices into the per-level parameter vectors. Each vector is
            sized by the number of unique codes.
        treatment: optional multiplier applied to the total slope.
        dist_y: likelihood family name understood by ``sample_y``.
        add_*: switch the corresponding per-level intercept/slope terms on.
        center_intercept / center_slope: include a global intercept / slope.
            ``center_slope`` is forced on when there is no condition slope.
        robust_slopes: use StudentT(1, ., 100) for per-condition slopes.
        dist_slope: distribution class for slope priors (called as
            ``dist_slope(loc, scale)``).
    """
    n_subjects = np.unique(group).shape[0]

    n_conditions = np.unique(condition).shape[0]
    if (condition is None) or not n_conditions:
        add_condition_slope = False  # Override
        add_condition_intercept = False

    if center_intercept:
        intercept = numpyro.sample('intercept', dist.Normal(0, 100))
    else:
        intercept = 0
    if (group is not None) and add_group_intercept:
        if dist_y == 'poisson':
            print('poisson intercepts')
            # NOTE(review): Poisson(0) puts all mass at 0, and rate 0 is outside
            # the positive-rate constraint. Confirm this prior is intended.
            a_group = numpyro.sample('mu_intercept_per_group', dist.Poisson(jnp.tile(0, n_subjects)))
            # Index by group so the intercepts line up with observations, as in
            # the branch below. The bare (n_subjects,) vector only broadcast
            # when n_subjects happened to equal the number of observations.
            intercept += a_group[group]
        else:
            sigma_a_group = numpyro.sample('sigma_intercept_per_group', dist.HalfNormal(100))
            a_group = numpyro.sample('mu_intercept_per_group', dist.Normal(jnp.tile(0, n_subjects), 10))
            intercept += (a_group[group] * sigma_a_group)

    if add_condition_intercept:
        intercept_per_condition = numpyro.sample('intercept_per_condition',
                                                 dist.Normal(jnp.tile(0, n_conditions), 100))

        sigma_intercept_per_condition = numpyro.sample('sigma_intercept_per_condition', dist.HalfNormal(100))
        intercept += intercept_per_condition[condition] * sigma_intercept_per_condition

    if not add_condition_slope:
        center_slope = True  # override center_slope

    if center_slope:
        slope = numpyro.sample('slope', dist_slope(0, 100))
    else:
        slope = 0
    if add_condition_slope:
        if robust_slopes:
            # Robust slopes:
            b_per_condition = numpyro.sample('slope_per_condition',
                                             dist.StudentT(1, jnp.tile(0, n_conditions), 100))
        else:
            b_per_condition = numpyro.sample('slope_per_condition',
                                             dist_slope(jnp.tile(0, n_conditions), 100))

        sigma_b_condition = numpyro.sample('sigma_slope_per_condition', dist.HalfNormal(100))
        slope = slope + b_per_condition[condition] * sigma_b_condition

    if (group is not None) and add_group_slope:
        sigma_b_group = numpyro.sample('sigma_slope_per_group', dist.HalfNormal(100))
        b_per_group = numpyro.sample('slope_per_group', dist_slope(jnp.tile(0, n_subjects), 100))
        slope = slope + b_per_group[group] * sigma_b_group

    if (group2 is not None) and add_group2_slope:
        # Size and index by group2. This previously reused n_subjects and
        # `group`, so group2 never actually selected its own slopes.
        n_group2 = np.unique(group2).shape[0]
        sigma_b_group2 = numpyro.sample('sigma_slope_per_group2', dist.HalfNormal(100))
        b_per_group2 = numpyro.sample('slope_per_group2', dist_slope(jnp.tile(0, n_group2), 100))
        slope = slope + b_per_group2[group2] * sigma_b_group2

    # intercept stays the literal 0 only when no intercept term was added.
    if isinstance(intercept, int):
        print('Caution: No intercept')
    if treatment is not None:
        slope = slope * treatment

    sample_y(dist_y=dist_y, theta=intercept + slope, y=y)
Exemplo n.º 18
0
 def sample(self, key, sample_shape=()):
     """Draw Gamma-Poisson samples: Gamma rates first, then Poisson counts at those rates."""
     gamma_key, poisson_key = random.split(key)
     return dist.Poisson(self._gamma.sample(gamma_key, sample_shape)).sample(poisson_key)