Esempio n. 1
0
 def model(data=None):
     """Toy model that places a ``deterministic`` site inside nested plates.

     Two Beta(1, 1) probabilities are drawn (site ``"beta"``), their squares
     are recorded as the deterministic site ``"beta_sq"``, and ``data`` is
     scored under a Bernoulli likelihood (site ``"obs"``). ``N`` is read from
     the enclosing scope.
     """
     with numpyro.plate("dim", 2):
         probs = numpyro.sample("beta", dist.Beta(1.0, 1.0))
     with numpyro.plate("plate", N, dim=-2):
         numpyro.deterministic("beta_sq", probs**2)
         with numpyro.plate("dim", 2):
             likelihood = dist.Bernoulli(probs)
             numpyro.sample("obs", likelihood, obs=data)
Esempio n. 2
0
    def transition_fn(carry, t):
        """One ``scan`` step of the per-subject dynamic mixture model.

        ``carry`` is ``(lam_prev, x_prev)`` and ``t`` is the step index.
        ``mu``, ``beliefs``, ``mask``, ``eta``, ``weights``, ``N``, ``rho``,
        ``sigma`` and ``logits`` are read from the enclosing scope.
        Returns the next carry and ``None`` (nothing is stacked).
        """
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

        # log of lam normalised over its last axis
        log_total = jnp.log(lam_prev.sum(-1, keepdims=True))
        U = jnp.log(lam_prev) - log_total

        step_beliefs = (beliefs[0][:, t], beliefs[1][:, t])
        logs = logits(step_beliefs,
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))

        # only unmasked steps increment lam, by eta on the one-hot class
        gain = jnp.expand_dims(mask[t] * eta, -1)
        lam_next = npyro.deterministic(
            'lams', lam_prev + nn.one_hot(beliefs[2][t], 4) * gain)

        components = dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(
            mask[t][..., None])
        response = dist.MixtureSameFamily(dist.CategoricalProbs(weights),
                                          components)
        with npyro.plate('subjects', N):
            npyro.sample('y', response)
            noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise
        return (lam_next, x_next), None
Esempio n. 3
0
def GP(X, y, jitter=None):
    """Zero-mean Gaussian-process model with an RBF kernel.

    Args:
        X: Input locations, recorded as the deterministic site ``"X"``.
            ``X.shape[0]`` is the number of points.
        y: Observations. The last axis has one entry per point in ``X``;
            any leading axes are treated as independent outputs that share
            the same kernel (multi-output case).
        jitter: Constant added to the kernel diagonal to stabilise the
            Cholesky factorisation. ``None`` (the default) keeps the previous
            behaviour of reading ``wandb.config.jitter``.

    Returns:
        The value of the observed sample site ``"y"``.
    """
    X = numpyro.deterministic("X", X)

    # Priors on the kernel hyperparameters.
    η = numpyro.sample("variance", dist.HalfCauchy(scale=5.0))
    ℓ = numpyro.sample("length_scale", dist.Gamma(2.0, 1.0))
    σ = numpyro.sample("obs_noise", dist.HalfCauchy(scale=5.0))

    # Compute kernel
    K = rbf_kernel(X, X, η, ℓ)
    # NOTE(review): σ itself (not σ**2) is added to the diagonal, so
    # "obs_noise" acts as a noise variance - confirm this is intended.
    K = add_to_diagonal(K, σ)
    if jitter is None:
        # Only consult the global wandb run config when no jitter is passed,
        # so the model can also be used outside a wandb run.
        jitter = wandb.config.jitter
    K = add_to_diagonal(K, jitter)
    # Lower Cholesky factor, recorded as the deterministic site "Lff".
    Lff = numpyro.deterministic("Lff", cholesky(K, lower=True))

    # Zero-mean MVN over the points in X. expand_by/to_event turn the leading
    # axes of y (multi-output) into one joint event.
    prior = dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), scale_tril=Lff)
    return numpyro.sample(
        "y",
        prior.expand_by(y.shape[:-1]).to_event(y.ndim - 1),
        obs=y,
    )
Esempio n. 4
0
def SEIR_dynamics(T, params, x0, obs=None, death=None, suffix=""):
    '''Run SEIR dynamics for T time steps.

    Uses SEIRModel.run to run dynamics with pre-determined parameters.

    Args:
        T: number of time points produced by SEIRModel.run. The first one
            duplicates x0 and is dropped.
        params: tuple (beta0, sigma, gamma, rw_scale, drift, det_prob,
            det_noise_scale, death_prob, death_rate, det_prob_d).
        x0: initial state.
        obs: optional observations for column 6 of the state (site "y").
        death: optional observations for column 5 of the state (site "z").
        suffix: appended to every site name.

    Returns:
        (beta, x, y, z)
    '''
    (beta0, sigma, gamma, rw_scale, drift,
     det_prob, det_noise_scale, death_prob, death_rate, det_prob_d) = params

    beta = numpyro.sample(
        "beta" + suffix,
        ExponentialRandomWalk(loc=beta0, scale=rw_scale, drift=drift,
                              num_steps=T - 1))

    # Run ODE
    x = SEIRModel.run(T, x0, (beta, sigma, gamma, death_prob, death_rate))
    x = x[1:]  # first entry duplicates x0
    numpyro.deterministic("x" + suffix, x)

    # Noisy observations, down-weighted (0.5) / up-weighted (2.0) in the joint
    # log-density. ``scale=`` is the keyword current numpyro uses for this
    # handler; older releases called it ``scale_factor``.
    with numpyro.handlers.scale(scale=0.5):
        y = observe("y" + suffix, x[:, 6], det_prob, det_noise_scale, obs=obs)

    with numpyro.handlers.scale(scale=2.0):
        z = observe("z" + suffix, x[:, 5], det_prob_d, det_noise_scale,
                    obs=death)

    return beta, x, y, z
Esempio n. 5
0
    def dynamics(self, T, params, x0, confirmed=None, death=None, det_rate=None, suffix=""):
        '''Run SEIRD dynamics for T time steps.

        Args:
            T: number of time points produced by SEIRDModel.run. The first
                one duplicates x0 and is dropped.
            params: tuple (beta0, sigma, gamma, rw_scale, drift, det_prob,
                det_noise_scale, death_prob, death_rate, det_prob_d,
                det_prob_future).
            x0: initial state.
            confirmed: optional observations for column 6 of the state.
            death: optional observations for column 5 of the state.
            det_rate: accepted for interface compatibility; not used here.
            suffix: appended to site names. When it is ``"_future"`` the
                detection probability is ``det_prob_future`` instead of a
                sampled LogisticRandomWalk.

        Returns:
            (beta, det_prob_rw, x, y, z)
        '''
        (beta0, sigma, gamma, rw_scale, drift, det_prob, det_noise_scale,
         death_prob, death_rate, det_prob_d, det_prob_future) = params

        beta = numpyro.sample(
            "beta" + suffix,
            ExponentialRandomWalk(loc=beta0, scale=rw_scale, drift=drift,
                                  num_steps=T - 1))

        if suffix != "_future":
            det_prob_rw = numpyro.sample(
                "det_rate_rw" + suffix,
                LogisticRandomWalk(loc=det_prob, scale=rw_scale, drift=0,
                                   num_steps=T - 1))
        else:
            det_prob_rw = det_prob_future

        # Run ODE
        x = SEIRDModel.run(T, x0, (beta, sigma, gamma, death_prob, death_rate))
        x = x[1:]  # first entry duplicates x0
        numpyro.deterministic("x" + suffix, x)

        # Noisy observations. ``scale=`` is the keyword current numpyro uses
        # for this handler; older releases called it ``scale_factor``.
        with numpyro.handlers.scale(scale=0.5):
            y = observe("y" + suffix, x[:, 6], det_prob_rw, det_noise_scale,
                        obs=confirmed)

        with numpyro.handlers.scale(scale=2.0):
            z = observe("z" + suffix, x[:, 5], det_prob_d, det_noise_scale,
                        obs=death)

        return beta, det_prob_rw, x, y, z
Esempio n. 6
0
def fulldyn_mixture_model(beliefs, y, mask):
    """Per-subject mixture model with random-walk dynamics on gamma.

    Args:
        beliefs: sequence of arrays. ``beliefs[0]`` has shape
            ``(M, T, N, _)``: M mixture components, T scan steps and
            N subjects. ``beliefs[1]`` is indexed the same way,
            ``beliefs[2][t]`` is one-hot encoded into 4 classes, and
            ``beliefs[-1]`` is the offset ``c0`` added to the initial lam.
        y: observed responses; conditions the ``"y"`` site at every step.
        mask: ``mask[t]`` masks both the likelihood and the lam update at
            step ``t``.
    """
    n_comp, n_steps, n_subj, _ = beliefs[0].shape
    c0 = beliefs[-1]

    with npyro.plate('N', n_subj):
        # Dirichlet with concentration 0.5 over the mixture components.
        weights = npyro.sample('weights', dist.Dirichlet(.5 * jnp.ones(n_comp)))
        assert weights.shape == (n_subj, n_comp)

        mu = npyro.sample('mu', dist.Normal(5., 5.))

        lam12 = npyro.sample('lam12',
                             dist.HalfCauchy(1.).expand([2]).to_event(1))
        lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))
        lam34_col = jnp.expand_dims(lam34, -1)
        # Initial lam: cumulative lam12 followed by lam34 repeated twice.
        lam0 = npyro.deterministic(
            'lam0', jnp.concatenate([lam12.cumsum(-1), lam34_col, lam34_col], -1))

        eta = npyro.sample('eta', dist.Beta(1, 10))

        noise_scale = npyro.sample('scale', dist.HalfNormal(1.))
        theta = npyro.sample('theta', dist.HalfCauchy(5.))
        # AR(1) coefficient and innovation scale of the latent x process.
        rho = jnp.exp(-theta)
        sigma = jnp.sqrt((1 - rho**2) / (2 * theta)) * noise_scale

    def transition_fn(carry, t):
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))
        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))

        gain = jnp.expand_dims(mask[t] * eta, -1)
        lam_next = npyro.deterministic(
            'lams', lam_prev + nn.one_hot(beliefs[2][t], 4) * gain)

        response = dist.MixtureSameFamily(
            dist.CategoricalProbs(weights),
            dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(mask[t][..., None]))
        with npyro.plate('subjects', n_subj):
            npyro.sample('y', response)
            noise = npyro.sample('dw', dist.Normal(0., 1.))

        return (lam_next, rho * x_prev + sigma * noise), None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (lam_start, jnp.zeros(n_subj)), jnp.arange(n_steps))
Esempio n. 7
0
 def transition(state, i):
     """One step of a linear-Gaussian state-space model, for use with ``scan``.

     ``state`` is ``(x_prev, mu_prev)``; ``i`` (the step index) is unused.
     ``phi``, ``q``, ``beta`` and ``r`` come from the enclosing scope.
     Returns the new state and the per-step outputs ``(x, y)``; ``2 * y`` is
     also recorded as the deterministic site ``"y2"``.
     """
     x_prev, mu_prev = state
     x_new = numpyro.sample('x', dist.Normal(phi * x_prev, q))
     mu_new = beta * mu_prev + x_new
     y_new = numpyro.sample('y', dist.Normal(mu_new, r))
     numpyro.deterministic('y2', y_new * 2)
     next_state = (x_new, mu_new)
     return next_state, (x_new, y_new)
Esempio n. 8
0
def model_nested_plates_3():
    """Two plates declared up front, then entered together.

    ``inner`` occupies dim -2 and ``outer`` dim -1, so the ``"xy"`` site has
    shape ``(5, 10)``. The deterministic site ``"z"`` lives outside both plates.
    """
    outer = numpyro.plate("outer", 10, dim=-1)
    inner = numpyro.plate("inner", 5, dim=-2)
    numpyro.deterministic("z", 1.0)

    with inner:
        with outer:
            xy = numpyro.sample("xy", dist.Normal(jnp.zeros((5, 10)), 1.0))
            assert xy.shape == (5, 10)
Esempio n. 9
0
def model_nested_plates_3():
    """Nested plates (outer at dim -1, inner at dim -2) plus a scalar site.

    The ``'xy'`` site inside both plates is expected to have shape (5, 10).
    """
    expected_shape = (5, 10)
    outer = numpyro.plate('outer', 10, dim=-1)
    inner = numpyro.plate('inner', 5, dim=-2)
    numpyro.deterministic('z', 1.)
    with inner, outer:
        xy = numpyro.sample('xy', dist.Normal(jnp.zeros(expected_shape), 1.))
        assert xy.shape == expected_shape
Esempio n. 10
0
    def model(y=None):
        """Normal model with a sampled mean and a learnable positive scale.

        When ``y`` is given it is observed (one element per entry); otherwise
        a single value is drawn. The standardised residual
        ``(y - mu) / sigma`` is recorded as the deterministic site ``"z"``.
        """
        n = 1 if y is None else y.size

        mu = numpyro.sample("mu", dist.Normal(0, 5))
        sigma = numpyro.param("sigma", 1, constraint=constraints.positive)

        likelihood = dist.Normal(mu, sigma).expand((n,))
        y = numpyro.sample("y", likelihood, obs=y)
        numpyro.deterministic("z", (y - mu) / sigma)
Esempio n. 11
0
    def dynamics(self,
                 T,
                 params,
                 x0,
                 num_frozen=0,
                 confirmed=None,
                 death=None,
                 suffix=""):
        '''Run SEIRD dynamics for T time steps.

        Args:
            T: number of time points produced by SEIRDModel.run. The first
                row is the initial state; ``x[1:]`` is recorded.
            params: tuple (beta0, sigma, gamma, rw_scale, drift, det_prob0,
                confirmed_dispersion, death_dispersion, death_prob,
                death_rate, det_prob_d).
            x0: initial state.
            num_frozen: forwarded to ``frozen_random_walk``.
            confirmed: optional observations of the daily increments of
                column 6 (site "dy" + suffix).
            death: optional observations of the daily increments of
                column 5 (site "dz" + suffix).
            suffix: appended to every site name, so the function can be
                called more than once in the same model.

        Returns:
            (beta, det_prob, x, y, z)
        '''
        # NOTE(review): ``drift`` is unpacked but not used by this version -
        # confirm frozen_random_walk is meant to ignore it.
        (beta0, sigma, gamma, rw_scale, drift, det_prob0,
         confirmed_dispersion, death_dispersion, death_prob, death_rate,
         det_prob_d) = params

        rw = frozen_random_walk("rw" + suffix,
                                num_steps=T - 1,
                                num_frozen=num_frozen)

        # The site name carries ``suffix`` like every other site here. Without
        # it, a second call (e.g. suffix="_future") would register a duplicate
        # "beta" site, which numpyro rejects.
        beta = numpyro.deterministic("beta" + suffix,
                                     beta0 * np.exp(rw_scale * rw))

        det_prob = numpyro.sample(
            "det_prob" + suffix,
            LogisticRandomWalk(loc=det_prob0,
                               scale=rw_scale,
                               drift=0,
                               num_steps=T - 1))

        # Run ODE
        x = SEIRDModel.run(T, x0, (beta, sigma, gamma, death_prob, death_rate))

        numpyro.deterministic("x" + suffix, x[1:])

        # Observations are made on per-step increments of the state.
        x_diff = np.diff(x, axis=0)

        # Noisy observations
        with numpyro.handlers.scale(scale=0.5):
            y = observe_nb2("dy" + suffix,
                            x_diff[:, 6],
                            det_prob,
                            confirmed_dispersion,
                            obs=confirmed)

        with numpyro.handlers.scale(scale=2.0):
            z = observe_nb2("dz" + suffix,
                            x_diff[:, 5],
                            det_prob_d,
                            death_dispersion,
                            obs=death)

        return beta, det_prob, x, y, z
Esempio n. 12
0
 def model():
     transform = hk.transform_with_state if batchnorm else hk.transform
     nn = haiku_module("nn", transform(fn), apply_rng=dropout, input_shape=(4, 3))
     x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
     if dropout:
         y = nn(numpyro.prng_key(), x)
     else:
         y = nn(x)
     numpyro.deterministic("y", y)
Esempio n. 13
0
def SEIR_dynamics_hierarchical(T,
                               params,
                               x0,
                               obs=None,
                               death=None,
                               use_rw=True,
                               suffix=""):
    '''Run SEIRD dynamics for T time steps independently for each place.

    Uses SEIRDModel.run, vmapped over the leading "places" axis, to run
    dynamics with pre-determined parameters.

    Args:
        T: number of time points produced by SEIRDModel.run. The first one
            duplicates the initial value and is dropped.
        params: tuple (beta, sigma, gamma, det_rate, det_noise_scale,
            rw_loc, rw_scale, drift, death_prob, death_rate, det_prob_d).
            ``beta`` must have shape (num_places, T - 1). The other
            per-place parameters are indexed along their first axis.
        x0: per-place initial states (leading axis = places).
        obs: optional observations for column 6 of the state (site "y").
        death: optional observations for column 5 of the state (site "z").
        use_rw: if True, beta is multiplied by a per-place
            ExponentialRandomWalk sample; otherwise by ``rw_loc``.
        suffix: appended to site names.

    Returns:
        (rw, x, y, z)
    '''
    (beta, sigma, gamma, det_rate, det_noise_scale, rw_loc, rw_scale, drift,
     death_prob, death_rate, det_prob_d) = params

    num_places, T_minus_1 = beta.shape
    assert (T_minus_1 == T - 1)

    # prep for broadcasting over time
    sigma = sigma[:, None]
    gamma = gamma[:, None]
    det_rate = det_rate[:, None]
    death_prob = death_prob[:, None]
    death_rate = death_rate[:, None]
    det_prob_d = det_prob_d[:, None]

    if use_rw:
        with numpyro.plate("places", num_places):
            rw = numpyro.sample(
                "rw" + suffix,
                ExponentialRandomWalk(loc=rw_loc,
                                      scale=rw_scale,
                                      drift=drift,
                                      num_steps=T - 1))
    else:
        rw = rw_loc

    # Rebind rather than ``beta *= rw``: for a NumPy array ``*=`` would modify
    # the caller's ``params`` in place (and fail for integer dtypes).
    beta = beta * rw

    def _run_place(x0_p, beta_p, sigma_p, gamma_p, death_prob_p, death_rate_p):
        """Run SEIRDModel for a single place (mapped over places by vmap)."""
        return SEIRDModel.run(
            T, x0_p, (beta_p, sigma_p, gamma_p, death_prob_p, death_rate_p))

    # Run ODE
    x = jax.vmap(_run_place)(x0, beta, sigma, gamma, death_prob, death_rate)

    x = x[:, 1:, :]  # drop first time step from result (duplicates initial value)
    numpyro.deterministic("x" + suffix, x)

    # Noisy observations
    y = observe("y" + suffix, x[:, :, 6], det_rate, det_noise_scale, obs=obs)
    z = observe("z" + suffix,
                x[:, :, 5],
                det_prob_d,
                det_noise_scale,
                obs=death)

    return rw, x, y, z
Esempio n. 14
0
def sgt(y: jnp.ndarray, seasonality: int, future: int = 0) -> None:
    """Seasonal time-series model with a level-dependent trend and Student-t noise.

    The expected value at each step is ``level + coef_trend * level**pow_trend
    + s[0] * level**pow_season``. The seasonal factors are updated
    multiplicatively and the observation scale grows with the expected value.

    Args:
        y: Observed series. ``y[:seasonality]`` sets the prior scale of the
            initial seasonal factors.
        seasonality: Length of the seasonal cycle.
        future: Number of steps to forecast past the end of ``y``.

    The scanned ``"y"`` site is conditioned on ``y[1:]``. Every scanned
    prediction (in-sample and forecast) is recorded as the deterministic
    site ``"y_forecast"``. Returns ``None``.
    """

    # Heuristic: the Cauchy prior scales are tied to the largest observation.
    cauchy_sd = jnp.max(y) / 150

    nu = numpyro.sample("nu", dist.Uniform(2, 20))
    powx = numpyro.sample("powx", dist.Uniform(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
    offset_sigma = numpyro.sample(
        "offset_sigma",
        dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd))

    coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
    pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
    # Maps Beta(1, 1) on (0, 1) to pow_trend on (-0.5, 1).
    pow_trend = 1.5 * pow_trend_beta - 0.5
    pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))

    level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
    s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))
    # One initial seasonal factor per position in the first cycle; its prior
    # scale is 30% of the corresponding observation.
    init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality] * 0.3))

    num_lim = y.shape[0]

    def transition_fn(
        carry: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], t: jnp.ndarray
    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:
        """One scan step. carry = (level, seasonal buffer s, moving_sum)."""

        level, s, moving_sum = carry
        season = s[0] * level**pow_season
        exp_val = level + coef_trend * level**pow_trend + season
        exp_val = jnp.clip(exp_val, a_min=0)
        # Past the end of the data (forecasting), use the expected value in
        # place of an observation.
        y_t = jnp.where(t >= num_lim, exp_val, y[t])

        # Rolling sum of the last `seasonality` observations. Once a full
        # cycle is available its mean is the level proposal.
        # NOTE(review): for t >= num_lim, y[t] (and, for horizons longer than
        # `seasonality`, y[t - seasonality]) index past the end of y. JAX
        # clamps out-of-bounds gather indices, so forecasts keep adding the
        # last observation here rather than y_t - confirm this is intended.
        moving_sum = moving_sum + y[t] - jnp.where(t >= seasonality,
                                                   y[t - seasonality], 0.0)
        level_p = jnp.where(t >= seasonality, moving_sum / seasonality,
                            y_t - season)
        level = level_sm * level_p + (1 - level_sm) * level
        level = jnp.clip(level, a_min=0)

        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        # When forecasting, repeat the current seasonal factor unchanged.
        new_s = jnp.where(t >= num_lim, s[0], new_s)
        # Rotate the seasonal buffer: drop s[0], append the new factor.
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)

        # Observation scale grows with exp_val**powx, plus a small offset.
        omega = sigma * exp_val**powx + offset_sigma
        y_ = numpyro.sample("y", dist.StudentT(nu, exp_val, omega))

        return (level, s, moving_sum), y_

    level_init = y[0]
    # Rotate init_s by one so the first scan step (t=1) sees init_s[1] as s[0].
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    # y[1:] conditions the first num_lim - 1 steps. The remaining `future`
    # steps are presumably sampled (numpyro scan with shorter conditioning
    # data) - verify against the numpyro version in use.
    with numpyro.handlers.condition(data={"y": y[1:]}):
        _, ys = scan(transition_fn, (level_init, s_init, moving_sum),
                     jnp.arange(1, num_lim + future))

    numpyro.deterministic("y_forecast", ys)
Esempio n. 15
0
    def model(X, Y):
        """Bayesian linear regression with a log-normal noise scale.

        ``X`` is a 2-D design matrix whose last axis matches the ``"beta"``
        coefficients. ``Y`` holds the observed responses. The linear
        predictor is recorded as the deterministic site ``"mean"``.
        """
        _, num_features = X.shape

        log_sigma = numpyro.sample("log_sigma", dist.Normal(1.0))
        sigma = jnp.exp(log_sigma)
        prior = dist.Normal(jnp.zeros(num_features), jnp.ones(num_features))
        coefs = numpyro.sample("beta", prior)
        linear_pred = jnp.sum(coefs * X, axis=-1)
        numpyro.deterministic("mean", linear_pred)

        numpyro.sample("obs", dist.Normal(linear_pred, sigma), obs=Y)
    def dynamics(self,
                 T,
                 params,
                 x0,
                 num_frozen=0,
                 confirmed=None,
                 death=None,
                 suffix=""):
        '''Run SEIRD dynamics for T time steps.

        Args:
            T: number of time points produced by SEIRDModel.run. The first
                one duplicates x0 and is dropped.
            params: tuple (beta0, sigma, gamma, rw_scale, drift, det_prob0,
                det_noise_scale, death_prob, death_rate, det_prob_d).
            x0: initial state.
            num_frozen: forwarded to ``frozen_random_walk``.
            confirmed: optional observations for column 6 of the state.
            death: optional observations for column 5 of the state.
            suffix: appended to every site name, so the function can be
                called more than once in the same model.

        Returns:
            (beta, det_prob, x, y, z)
        '''
        # NOTE(review): ``drift`` is unpacked but not used by this version -
        # confirm frozen_random_walk is meant to ignore it.
        (beta0, sigma, gamma, rw_scale, drift, det_prob0, det_noise_scale,
         death_prob, death_rate, det_prob_d) = params

        rw = frozen_random_walk("rw" + suffix,
                                num_steps=T - 1,
                                num_frozen=num_frozen)

        # The site name carries ``suffix`` like every other site here. Without
        # it, a second call (e.g. suffix="_future") would register a duplicate
        # "beta" site, which numpyro rejects.
        beta = numpyro.deterministic("beta" + suffix,
                                     beta0 * np.exp(rw_scale * rw))

        det_prob = numpyro.sample(
            "det_prob" + suffix,
            LogisticRandomWalk(loc=det_prob0,
                               scale=rw_scale,
                               drift=0,
                               num_steps=T - 1))

        # Run ODE
        x = SEIRDModel.run(T, x0, (beta, sigma, gamma, death_prob, death_rate))
        x = x[1:]  # first entry duplicates x0
        numpyro.deterministic("x" + suffix, x)

        # Noisy observations. ``scale=`` is the keyword current numpyro uses
        # for this handler; older releases called it ``scale_factor``.
        with numpyro.handlers.scale(scale=0.5):
            y = observe("y" + suffix,
                        x[:, 6],
                        det_prob,
                        det_noise_scale,
                        obs=confirmed)

        with numpyro.handlers.scale(scale=2.0):
            z = observe("z" + suffix,
                        x[:, 5],
                        det_prob_d,
                        det_noise_scale,
                        obs=death)

        return beta, det_prob, x, y, z
Esempio n. 17
0
def holt_winters(y, n_seasons, future=0):
    """Additive Holt-Winters model with level, trend and seasonal states.

    Args:
        y: observed series of length ``T`` along its first axis.
        n_seasons: length of the seasonal cycle (size of the seasonal state).
        future: number of extra steps to predict past the data.

    The scanned ``"pred"`` site is conditioned on ``y``. Past the data the
    states are frozen and the trend is extrapolated. When ``future > 0`` the
    last ``future`` predictions are recorded as the deterministic site
    ``"y_forecast"``.
    """
    T = y.shape[0]
    level_smoothing = numpyro.sample("level_smoothing", dist.Beta(1, 1))
    trend_smoothing = numpyro.sample("trend_smoothing", dist.Beta(1, 1))
    seasonality_smoothing = numpyro.sample("seasonality_smoothing",
                                           dist.Beta(1, 1))
    adj_seasonality_smoothing = seasonality_smoothing * (1 - level_smoothing)
    noise = numpyro.sample("noise", dist.HalfNormal(1))
    level_init = numpyro.sample("level_init", dist.Normal(0, 1))
    trend_init = numpyro.sample("trend_init", dist.Normal(0, 1))
    with numpyro.plate("n_seasons", n_seasons):
        seasonality_init = numpyro.sample("seasonality_init",
                                          dist.Normal(0, 1))

    def transition_fn(carry, t):
        prev_level, prev_trend, prev_season = carry
        in_sample = t < T
        current_season = prev_season[0]
        base = prev_level + prev_trend

        level = jnp.where(
            in_sample,
            level_smoothing * (y[t] - current_season) +
            (1 - level_smoothing) * base,
            prev_level,
        )
        trend = jnp.where(
            in_sample,
            trend_smoothing * (level - prev_level) +
            (1 - trend_smoothing) * prev_trend,
            prev_trend,
        )
        season_update = jnp.where(
            in_sample,
            adj_seasonality_smoothing * (y[t] - base) +
            (1 - adj_seasonality_smoothing) * current_season,
            current_season,
        )
        # Horizon of the trend extrapolation: 1 in-sample, growing afterwards.
        horizon = jnp.where(in_sample, 1, t - T + 1)
        mu = prev_level + horizon * prev_trend + current_season
        pred = numpyro.sample("pred", dist.Normal(mu, noise))

        rotated = jnp.concatenate([prev_season[1:], season_update[None]], axis=0)
        return (level, trend, rotated), pred

    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            (level_init, trend_init, seasonality_init),
            jnp.arange(T + future),
        )

    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])
Esempio n. 18
0
 def model():
     """Wrap a Flax ``Net`` as a numpyro module and record its output.

     ``dropout`` and ``batchnorm`` come from the enclosing scope. With
     dropout the network receives a fresh PRNG key under ``"dropout"``.
     """
     rng_collections = ["dropout"] if dropout else None
     mutable_collections = ["batch_stats"] if batchnorm else None
     net = flax_module(
         "nn",
         Net(),
         apply_rng=rng_collections,
         mutable=mutable_collections,
         input_shape=(4, 3),
     )
     x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
     if not dropout:
         y = net(x)
     else:
         y = net(x, rngs={"dropout": numpyro.prng_key()})
     numpyro.deterministic("y", y)
Esempio n. 19
0
File: gp.py Progetto: sagar87/numgp
 def _build_conditional(self, Xs=None, Xconds=None, **kwargs):
     """Register the deterministic sites needed to sample from the conditional.

     Both inputs are first passed through ``cartesian``. The sites are named
     ``<name>_mean`` (mean at Xs), ``<name>_cond`` (mean at Xconds),
     ``<name>_Kss`` (cov at Xconds) and ``<name>_Ksx`` (cov between Xconds
     and Xs). Extra ``kwargs`` are accepted and ignored.
     """
     # NOTE(review): "Kss" is evaluated on Xconds and "Ksx" as cov(Xconds, Xs);
     # confirm these names match the intended train/test roles.
     grid_s = cartesian(Xs)
     grid_cond = cartesian(Xconds)
     npy.deterministic(f"{self.name}_mean", self.mean_func(grid_s))
     npy.deterministic(f"{self.name}_cond", self.mean_func(grid_cond))
     npy.deterministic(f"{self.name}_Kss", self.cov_func(grid_cond))
     npy.deterministic(f"{self.name}_Ksx", self.cov_func(grid_cond, grid_s))
Esempio n. 20
0
def fulldyn_single_model(beliefs, y, mask):
    """Single-agent model with random-walk dynamics on gamma.

    Args:
        beliefs: sequence of arrays. ``beliefs[0]`` is 2-D with one row per
            scan step. ``beliefs[1]`` is indexed the same way,
            ``beliefs[2][t]`` is one-hot encoded into 4 classes, and
            ``beliefs[-1]`` is the offset ``c0`` added to the initial lam.
        y: observed responses; conditions the ``"y"`` site at every step.
        mask: ``mask[t]`` masks both the likelihood and the lam update at
            step ``t``.
    """
    n_steps, _ = beliefs[0].shape
    c0 = beliefs[-1]

    mu = npyro.sample('mu', dist.Normal(5., 5.))

    lam12 = npyro.sample('lam12', dist.HalfCauchy(1.).expand([2]).to_event(1))
    lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))
    lam34_col = jnp.expand_dims(lam34, -1)
    # Initial lam: cumulative lam12 followed by lam34 repeated twice.
    lam0 = npyro.deterministic(
        'lam0', jnp.concatenate([lam12.cumsum(-1), lam34_col, lam34_col], -1))

    eta = npyro.sample('eta', dist.Beta(1, 10))

    noise_scale = npyro.sample('scale', dist.HalfNormal(1.))
    theta = npyro.sample('theta', dist.HalfCauchy(5.))
    # AR(1) coefficient and innovation scale of the latent x process.
    rho = jnp.exp(-theta)
    sigma = jnp.sqrt((1 - rho**2) / (2 * theta)) * noise_scale

    def transition_fn(carry, t):
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))
        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))

        gain = jnp.expand_dims(mask[t] * eta, -1)
        lam_next = npyro.deterministic(
            'lams', lam_prev + nn.one_hot(beliefs[2][t], 4) * gain)

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))
        noise = npyro.sample('dw', dist.Normal(0., 1.))

        return (lam_next, rho * x_prev + sigma * noise), None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (lam_start, jnp.zeros(1)), jnp.arange(n_steps))
Esempio n. 21
0
    def _glitch_amplitudes(self, nu):
        """Record and return the window-averaged glitch amplitudes.

        Args:
            nu: frequencies. Only ``nu.min()`` / ``nu.max()`` are used, and
                only when ``self.window_width == "full"``. Otherwise the window
                is ``self._nu_max.mean`` +/- ``self.window_width *
                self.background._delta_nu``.

        Returns:
            tuple: ``(he_amp, cz_amp)``, the values of the deterministic sites
            ``"he_amplitude"`` and ``"cz_amplitude"``.
        """
        if self.window_width == "full":
            bounds = (nu.min(), nu.max())
        else:
            centre = self._nu_max.mean
            half_width = self.window_width * self.background._delta_nu
            bounds = (centre - half_width, centre + half_width)

        he_amp = numpyro.deterministic(
            "he_amplitude", self.he_glitch._average_amplitude(*bounds))
        cz_amp = numpyro.deterministic(
            "cz_amplitude", self.cz_glitch._average_amplitude(*bounds))
        return he_amp, cz_amp
Esempio n. 22
0
def weekday_effect(day_of_week):
    """Sum-to-zero day-of-week effect, indexed by ``day_of_week``.

    Six effects are sampled freely. The seventh (Monday, index 0) is the
    negative of their sum, so all seven add up to zero. The full vector is
    recorded as the deterministic site ``"beta"``.
    """
    with plate("plate_day_of_week", 6):
        free_effects = sample("_beta", dist.Normal(0, 1))

    monday = -jnp.sum(free_effects)[None]  # Monday = 0 in original
    all_days = deterministic("beta", jnp.concatenate((monday, free_effects)))
    return all_days[day_of_week]
Esempio n. 23
0
def birthdays_model(
    x,
    day_of_week,
    day_of_year,
    memorial_days_indicator,
    labour_days_indicator,
    thanksgiving_days_indicator,
    w0,
    L,
    M1,
    M2,
    M3,
    y=None,
):
    """Additive model for daily birth counts.

    The mean is ``intercept + trend + year + exp(week_trend) * weekday + day``.
    Each component is built under its own ``scope`` prefix ("trend", "year",
    "week-trend", "week", "day", "memorial", "labour", "thanksgiving"). The
    mean is recorded as the deterministic site ``"f"``, and ``y`` is observed
    under Normal noise with scale ``sigma``.
    """
    intercept = sample("intercept", dist.Normal(0, 1))
    f1 = scope(trend_gp, "trend")(x, L, M1)
    f2 = scope(year_gp, "year")(x, w0, M2)
    # length ~ lognormal(-1, 1) in original
    g3 = scope(trend_gp, "week-trend")(x, L, M3)
    weekday = scope(weekday_effect, "week")(day_of_week)

    # --- special days, added onto the day-of-year effect in a fixed order
    day = scope(yearday_effect, "day")(day_of_year)
    special_days = (
        ("memorial", memorial_days_indicator),
        ("labour", labour_days_indicator),
        ("thanksgiving", thanksgiving_days_indicator),
    )
    for prefix, indicator in special_days:
        day = day + scope(special_effect, prefix)(indicator)

    # --- Combine components
    f = deterministic("f", intercept + f1 + f2 + jnp.exp(g3) * weekday + day)
    sigma = sample("sigma", dist.HalfNormal(0.5))
    with plate("obs", x.shape[0]):
        sample("y", dist.Normal(f, sigma), obs=y)
Esempio n. 24
0
def nondynamic_mixture_model(beliefs, y, mask):
    """Static mixture model over ``M`` components (no latent dynamics).

    Args:
        beliefs: sequence. ``beliefs[0]`` has shape ``(M, T, _)`` and
            ``beliefs[1]`` is indexed the same way.
        y: observed responses; conditions the ``"y"`` site at every step.
        mask: ``mask[t]`` masks the likelihood at step ``t``.
    """
    n_comp, n_steps, _ = beliefs[0].shape

    # Dirichlet with concentration 0.5 over the mixture components.
    weights = npyro.sample('weights', dist.Dirichlet(.5 * jnp.ones(n_comp)))
    assert weights.shape == (n_comp, )

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 5.))

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    # Expand the 3 probabilities to 4: p[0] and p[2] are each split in half,
    # so the entries still sum to 1.
    half_first = p[..., :1] / 2
    half_last = p[..., 2:] / 2
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate(
            [half_first, half_first + p[..., 1:2], half_last, half_last], -1))

    U = jnp.log(p0)

    def transition_fn(carry, t):
        step_beliefs = (beliefs[0][:, t], beliefs[1][:, t])
        logs = logits(step_beliefs,
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))
        response = dist.MixtureSameFamily(
            dist.CategoricalProbs(weights),
            dist.CategoricalLogits(logs).mask(mask[t]))
        npyro.sample('y', response)
        return None, None

    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, None, jnp.arange(n_steps))
def model(M, A, D=None):
    """Linear regression of ``D`` on the predictors ``M`` and ``A``.

    The linear predictor is recorded as the deterministic site ``"mu"``.
    The noise scale ``sigma`` has an Exponential(1) prior.
    """
    intercept = numpyro.sample("a", dist.Normal(0, 0.2))
    slope_m = numpyro.sample("bM", dist.Normal(0, 0.5))
    slope_a = numpyro.sample("bA", dist.Normal(0, 0.5))
    noise = numpyro.sample("sigma", dist.Exponential(1))
    linear_pred = intercept + slope_m * M + slope_a * A
    mu = numpyro.deterministic("mu", linear_pred)
    numpyro.sample("D", dist.Normal(mu, noise), obs=D)
Esempio n. 26
0
def nondyn_single_model(beliefs, y, mask):
    """Static (non-dynamic) model for a single agent.

    Args:
        beliefs: sequence. ``beliefs[0]`` is 2-D with one row per scan step,
            and ``beliefs[1]`` is indexed the same way.
        y: observed responses; conditions the ``"y"`` site at every step.
        mask: ``mask[t]`` masks the likelihood at step ``t``.
    """
    n_steps, _ = beliefs[0].shape

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 3.))

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    # Expand the 3 probabilities to 4: p[0] and p[2] are each split in half.
    half_first = p[..., :1] / 2
    half_last = p[..., 2:] / 2
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate(
            [half_first, half_first + p[..., 1:2], half_last, half_last], -1))

    U = jnp.log(p0)

    def transition_fn(carry, t):
        step_beliefs = (beliefs[0][t], beliefs[1][t])
        logs = logits(step_beliefs,
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))
        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))
        return None, None

    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, None, jnp.arange(n_steps))
Esempio n. 27
0
def observe_poisson(name, latent, det_prob, obs=None):
    """Observe ``latent`` through Poisson noise after thinning by ``det_prob``.

    Entries of ``obs`` that are non-finite or negative are masked out of the
    likelihood and replaced by 0. The Poisson mean
    ``det_prob * latent`` is recorded as the deterministic site
    ``"mean_" + name``.

    Returns:
        The value of the sample site ``name``.
    """
    valid = True
    if obs is not None:
        valid = np.isfinite(obs) & (obs >= 0)
        obs = np.where(valid, obs, 0.0)

    rate = np.broadcast_to(det_prob, latent.shape) * latent
    likelihood = dist.Poisson(rate)
    numpyro.deterministic("mean_" + name, rate)

    with numpyro.handlers.mask(mask_array=valid):
        return numpyro.sample(name, likelihood, obs=obs)
Esempio n. 28
0
def observe_nonrandom(name, latent, det_noise_scale, obs=None):
    """Observe ``latent`` directly under Normal noise truncated below at 0.

    The noise scale grows linearly with the mean
    (``det_noise_scale * latent + 1``). Entries of ``obs`` that are
    non-finite or negative are masked out of the likelihood and replaced
    by 0. The mean is recorded as the deterministic site ``"mean_" + name``.

    Returns:
        The value of the sample site ``name``.
    """
    mask = True

    if obs is not None:
        mask = np.isfinite(obs) & (obs >= 0)
        obs = np.where(mask, obs, 0.0)

    mean = latent
    scale = det_noise_scale * mean + 1
    # Pass keywords: the positional order of TruncatedNormal's arguments has
    # changed across numpyro releases. The intent is Normal(mean, scale)
    # truncated below at 0.
    d = dist.TruncatedNormal(low=0., loc=mean, scale=scale)

    numpyro.deterministic("mean_" + name, mean)

    with numpyro.handlers.mask(mask_array=mask):
        y = numpyro.sample(name, d, obs=obs)

    return y
Esempio n. 29
0
    def __call__(self) -> Callable:
        """Samples the convective zone glitch function.

        Draws ``log_a_cz``, ``log_tau_cz`` and ``phi_cz``. The amplitude and
        timescale ``10**log_a`` / ``10**log_tau`` are recorded as ``a_cz`` /
        ``tau_cz`` and stored on ``self._a`` / ``self._tau``; the phase is
        stored on ``self._phi``.

        Returns:
            function: The function :math:`f`, with
            ``f(nu) = self.amplitude(nu) * self.oscillation(nu)``.
        """
        log_a = numpyro.sample("log_a_cz", self.log_a)
        log_tau = numpyro.sample("log_tau_cz", self.log_tau)

        amplitude = 10**log_a
        self._a = numpyro.deterministic("a_cz", amplitude)
        timescale = 10**log_tau
        self._tau = numpyro.deterministic("tau_cz", timescale)
        self._phi = numpyro.sample("phi_cz", self.phi)

        def fn(nu):
            return self.amplitude(nu) * self.oscillation(nu)

        return fn
Esempio n. 30
0
def model_PMD(z, N, y=None, phi_prior=1 / 1000):
    """Beta-binomial damage model with exponential decay in ``|z|``.

    The damage frequency is ``Dz = A * (1 - q)**(|z| - 1) + c``, clipped to
    [0, 1]. Counts ``y`` out of ``N`` follow a BetaBinomial with mean ``Dz``
    and concentration ``phi = delta + 2``.

    Args:
        z: positions; only ``|z|`` is used.
        N: total counts per position (BetaBinomial ``total_count``).
        y: observed counts (optional).
        phi_prior: rate of the Exponential prior on ``delta``.
    """
    z = jnp.abs(z)

    q = numpyro.sample("q", dist.Beta(2, 3))  # mean = 0.4, shape = 5
    A = numpyro.sample("A", dist.Beta(2, 3))  # mean = 0.4, shape = 5
    c = numpyro.sample("c", dist.Beta(1, 9))  # mean = 0.1, shape = 10
    # Record the *clipped* value: A + c can exceed 1, and the "Dz" site should
    # report the same Dz that alpha/beta (and hence the likelihood) use.
    Dz = numpyro.deterministic("Dz", jnp.clip(A * (1 - q)**(z - 1) + c, 0, 1))
    D_max = numpyro.deterministic("D_max", A + c)  # pylint: disable=unused-variable

    delta = numpyro.sample("delta", dist.Exponential(phi_prior))
    phi = numpyro.deterministic("phi", delta + 2)

    alpha = numpyro.deterministic("alpha", Dz * phi)
    # NOTE(review): where Dz is clipped to 1, beta is exactly 0, which lies
    # outside the BetaBinomial concentration support - consider clipping Dz
    # to (eps, 1 - eps).
    beta = numpyro.deterministic("beta", (1 - Dz) * phi)

    numpyro.sample("obs", dist.BetaBinomial(alpha, beta, N), obs=y)