Exemple #1
0
def sgt(y: jnp.ndarray, seasonality: int, future: int = 0) -> None:
    """NumPyro model with a smoothed level, a power trend and a seasonal term.

    All global parameters are drawn with ``numpyro.sample``; the series is
    then walked step by step with ``scan``, emitting a Student-T sample site
    named ``"y"`` at every step. The stacked scan outputs are recorded as the
    deterministic site ``"y_forecast"``.

    Args:
        y: Observed series, indexed along its first axis. ``y[0]`` seeds the
            level and the moving sum; ``y[1:]`` is used to condition the
            ``"y"`` site.
        seasonality: Length of the seasonal buffer. It sets the number of
            initial seasonal factors (``y[:seasonality]``) and the window
            over which the moving sum is averaged.
        future: Number of extra steps to run past the end of ``y``.

    Returns:
        None. Results are exposed through the NumPyro sample and
        deterministic sites.
    """

    # Scale shared by the Cauchy-family priors, derived from the data maximum.
    cauchy_sd = jnp.max(y) / 150

    # Degrees of freedom of the Student-T observation distribution.
    nu = numpyro.sample("nu", dist.Uniform(2, 20))
    # Exponent applied to the expected value in the observation scale.
    powx = numpyro.sample("powx", dist.Uniform(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
    # Additive term in the observation scale; truncated below at 1e-10.
    offset_sigma = numpyro.sample(
        "offset_sigma",
        dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd))

    coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
    pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
    # Affine map of the Beta draw from [0, 1] onto [-0.5, 1].
    pow_trend = 1.5 * pow_trend_beta - 0.5
    pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))

    # Smoothing weights for the level and the seasonal factors.
    level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
    s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))
    # One initial seasonal factor per position in the first season, with a
    # prior scale proportional to the corresponding observation.
    init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality] * 0.3))

    # Number of observed points; steps with t >= num_lim are forecast steps.
    num_lim = y.shape[0]

    def transition_fn(
        carry: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], t: jnp.ndarray
    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:
        """Advance (level, seasonal buffer, moving sum) by one step ``t``.

        Returns the updated carry and the value of the ``"y"`` site at ``t``.
        """

        level, s, moving_sum = carry
        season = s[0] * level**pow_season
        exp_val = level + coef_trend * level**pow_trend + season
        # Keep the expected value non-negative.
        exp_val = jnp.clip(exp_val, a_min=0)
        # Past the observed range, fall back to the expected value.
        y_t = jnp.where(t >= num_lim, exp_val, y[t])

        # NOTE(review): this update reads y[t] (not y_t), so for
        # t >= num_lim it indexes past the end of y and relies on JAX's
        # out-of-bounds indexing behaviour -- confirm this is intended.
        # For t < seasonality, y[t - seasonality] is a negative index whose
        # value is discarded by the jnp.where mask.
        moving_sum = moving_sum + y[t] - jnp.where(t >= seasonality,
                                                   y[t - seasonality], 0.0)
        # Level proposal: moving-sum average once t >= seasonality,
        # otherwise the de-seasonalised current value.
        level_p = jnp.where(t >= seasonality, moving_sum / seasonality,
                            y_t - season)
        level = level_sm * level_p + (1 - level_sm) * level
        level = jnp.clip(level, a_min=0)

        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        # On forecast steps the current seasonal factor is reused unchanged.
        new_s = jnp.where(t >= num_lim, s[0], new_s)
        # Rotate the buffer: drop the head, append the updated factor.
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)

        # Observation scale grows with the expected value raised to powx.
        omega = sigma * exp_val**powx + offset_sigma
        y_ = numpyro.sample("y", dist.StudentT(nu, exp_val, omega))

        return (level, s, moving_sum), y_

    level_init = y[0]
    # Initial seasonal buffer is init_s rotated left by one position, since
    # the scan starts at t = 1.
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    # NOTE(review): the conditioning data has num_lim - 1 entries while the
    # scan runs num_lim - 1 + future steps; presumably the scan samples the
    # remaining steps -- confirm against the NumPyro scan documentation.
    with numpyro.handlers.condition(data={"y": y[1:]}):
        _, ys = scan(transition_fn, (level_init, s_init, moving_sum),
                     jnp.arange(1, num_lim + future))

    # Records every scan output, not only the trailing `future` steps.
    numpyro.deterministic("y_forecast", ys)
Exemple #2
0
 def model(data, labels):
     """Bernoulli model whose logits come from a randomized linear module.

     ``random_module`` is called with a Cauchy prior keyed by ``bias_name``
     and a Normal prior keyed by ``weight_name``; its result is applied to
     ``data``, the last axis is squeezed, and ``labels`` are observed at the
     ``"y"`` site.
     """
     priors = {
         bias_name: dist.Cauchy(),
         weight_name: dist.Normal(),
     }
     net = random_module("nn", linear_module, priors, **kwargs)
     scores = net(data).squeeze(-1)
     numpyro.sample("y", dist.Bernoulli(logits=scores), obs=labels)
Exemple #3
0
    def transition_fn(
            carry: Tuple[jnp.ndarray],
            t: jnp.ndarray) -> Tuple[Tuple[jnp.ndarray], jnp.ndarray]:
        """One scan step: draw latent ``z`` and the ``"x"`` site at index ``t``.

        ``z`` is Normal around the previous latent (first carry element) with
        unit scale of length ``z_dim``. ``"x"`` is Cauchy, located at ``z``
        plus ``covariates[t] @ weight`` plus ``bias``, with scale ``sigma``.
        Returns the new carry ``(z,)`` and no per-step output.
        """
        z_prev = carry[0]
        z = numpyro.sample("z", dist.Normal(z_prev, jnp.ones(z_dim)))

        x_loc = z + jnp.matmul(covariates[t], weight) + bias
        numpyro.sample("x", dist.Cauchy(x_loc, sigma))

        return (z, ), None
Exemple #4
0
    def __init__(self,
                 formula,
                 data,
                 link,
                 family,
                 prior=None,
                 term_priors=None,
                 guess=None,
                 theta=None,
                 name="y"):
        """Build the design matrix and obtain the coefficient vector.

        Args:
            formula: Formula passed to ``patsy.dmatrix`` together with
                ``data`` to build the design matrix ``self.X``.
            data: Data object passed to ``patsy.dmatrix``.
            link: Stored unchanged as ``self.link``.
            family: Stored unchanged as ``self.family``.
            prior: Prior applied to every term when ``term_priors`` is not
                given. Defaults to ``dist.Cauchy(0, 1)``.
            term_priors: Optional sequence of priors, paired in order with
                the terms of the design matrix. Pairing uses ``zip``, so
                extra entries on either side are ignored and unmatched
                terms keep a coefficient of zero.
            guess: Stored unchanged as ``self.guess``.
            theta: Coefficient vector. When given it is stored as-is and
                nothing is sampled.
            name: Prefix of the sample-site names, which take the form
                ``name + "_" + term.name()``.
        """
        self.X = patsy.dmatrix(formula, data)
        self.link = link
        self.family = family
        self.name = name
        self.guess = guess

        if theta is None:
            # Sample theta from the prior, one sample site per formula term.
            _, d = self.X.shape

            theta = np.zeros(d)

            info = self.X.design_info
            column_names = info.column_names

            if term_priors is None:
                default_prior = (prior if prior is not None
                                 else dist.Cauchy(0, 1))
                term_priors = [default_prior] * len(info.terms)

            # A distinct loop name avoids shadowing the `prior` parameter.
            for term, term_prior in zip(info.terms, term_priors):
                term_slice = info.term_slices[term]
                num_cols = len(column_names[term_slice])
                theta_term = numpyro.sample(name + "_" + term.name(),
                                            term_prior,
                                            sample_shape=(num_cols, ))
                # jax.ops.index_update has been removed from JAX; the
                # functional `.at[...].set(...)` update is its replacement.
                # asarray keeps this working if `theta` is a NumPy array.
                theta = jax.numpy.asarray(theta).at[term_slice].set(
                    theta_term)

        self.theta = theta
Exemple #5
0
 def model(data, labels):
     """Bernoulli model whose logits come from a randomized linear module.

     ``random_module`` receives a Cauchy prior keyed by ``bias_name`` and a
     Normal prior keyed by ``weight_name``, plus ``input_shape=(dim,)``. Its
     result is applied to ``data``, the last axis is squeezed, and ``labels``
     are observed at the ``"y"`` site.
     """
     priors = {bias_name: dist.Cauchy(), weight_name: dist.Normal()}
     net = random_module(
         "nn",
         linear_module,
         prior=priors,
         input_shape=(dim,),
     )
     scores = net(data).squeeze(-1)
     numpyro.sample("y", dist.Bernoulli(logits=scores), obs=labels)