def sgt(y: jnp.ndarray, seasonality: int, future: int = 0) -> None:
    """NumPyro model for a seasonal series with a global trend (presumably the
    "SGT" model; the name is the only hint).

    The model samples global priors and then runs ``transition_fn`` with
    ``scan`` over time steps ``t = 1, ..., len(y) + future - 1``. At each step
    it draws a Student-T ``"y"`` observation. The observed values ``y[1:]`` are
    attached to that site with ``numpyro.handlers.condition``.

    Args:
        y: observed series. It is indexed along its first axis (``y[t]``,
            ``y[:seasonality]``, ``y.shape[0]``), so presumably it is 1-D.
            TODO confirm.
        seasonality: length of the seasonal cycle. It is the number of initial
            seasonal factors and the moving-average window.
        future: number of extra steps to run past the end of ``y``.

    Returns:
        None. Results are exposed through the NumPyro sites, including the
        deterministic ``"y_forecast"`` site.

    Note:
        Reordering the ``numpyro.sample`` statements can change results for a
        fixed seed, because each site consumes randomness in call order.
    """
    # Scale shared by the Cauchy-family priors below, proportional to the
    # maximum of the data.
    cauchy_sd = jnp.max(y) / 150
    # Student-T degrees of freedom and the power used in the noise scale.
    nu = numpyro.sample("nu", dist.Uniform(2, 20))
    powx = numpyro.sample("powx", dist.Uniform(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
    offset_sigma = numpyro.sample(
        "offset_sigma", dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd))
    coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
    pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
    # Map the Beta(1, 1) draw from [0, 1] onto a trend exponent in [-0.5, 1].
    pow_trend = 1.5 * pow_trend_beta - 0.5
    pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))
    # Smoothing weights for the level and seasonal updates.
    level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
    s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))
    # One initial seasonal factor per position in the first cycle. The prior
    # scale is 30% of the first `seasonality` observations.
    init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality] * 0.3))
    # Number of observed time steps. Steps t >= num_lim are forecast steps.
    num_lim = y.shape[0]

    def transition_fn(
        carry: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], t: jnp.ndarray
    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:
        """Advance (level, seasonal factors, moving sum) by one step and sample "y"."""
        level, s, moving_sum = carry
        # s[0] is the seasonal factor for the current step.
        season = s[0] * level**pow_season
        exp_val = level + coef_trend * level**pow_trend + season
        exp_val = jnp.clip(exp_val, a_min=0)
        # Use the observation while it exists and the expected value when forecasting.
        y_t = jnp.where(t >= num_lim, exp_val, y[t])
        # Running sum over the last `seasonality` observations.
        # NOTE(review): for t >= num_lim this reads y[t] past the end of y.
        # JAX clamps out-of-bounds gather indices by default, so the forecast
        # steps keep adding the last observed value instead of y_t. Confirm
        # this is intended.
        moving_sum = moving_sum + y[t] - jnp.where(t >= seasonality, y[t - seasonality], 0.0)
        # After the first full cycle, the level target is the moving average.
        # Before that, it is the observation minus the seasonal component.
        level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)
        level = level_sm * level_p + (1 - level_sm) * level
        # Keep level non-negative (presumably so level**pow_* stays real-valued).
        level = jnp.clip(level, a_min=0)
        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        # When forecasting, reuse the current seasonal factor unchanged.
        new_s = jnp.where(t >= num_lim, s[0], new_s)
        # Rotate the seasonal buffer: drop the used factor and append the new
        # one at the end.
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)
        # The noise scale grows with the expected value, plus a small positive offset.
        omega = sigma * exp_val**powx + offset_sigma
        y_ = numpyro.sample("y", dist.StudentT(nu, exp_val, omega))
        return (level, s, moving_sum), y_

    level_init = y[0]
    # Rotate init_s left by one. Because scan starts at t = 1, s[0] at the
    # first step is init_s[1].
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    # Condition the per-step "y" site on y[1:]. How steps beyond the observed
    # data are handled comes from numpyro's scan/condition. Presumably they
    # are sampled, which yields the forecast. TODO confirm.
    with numpyro.handlers.condition(data={"y": y[1:]}):
        _, ys = scan(transition_fn, (level_init, s_init, moving_sum),
                     jnp.arange(1, num_lim + future))
    # NOTE(review): this records the whole trajectory (steps 1 .. num_lim +
    # future - 1), not only the last `future` steps, despite the
    # "y_forecast" name. Confirm what callers expect.
    numpyro.deterministic("y_forecast", ys)
def model(data, labels):
    """Bernoulli likelihood over logits from a module with random parameters.

    ``random_module`` lifts ``linear_module`` into a module whose bias and
    weight receive the priors below. Extra construction options come from the
    enclosing scope's ``kwargs``. Presumably ``random_module`` samples the
    parameters from these priors; confirm against its definition.

    Args:
        data: inputs passed to the lifted module.
        labels: observed binary outcomes for the ``"y"`` site.
    """
    # Prior per parameter name: heavy-tailed for the bias, normal for the weight.
    param_priors = {bias_name: dist.Cauchy(), weight_name: dist.Normal()}
    bayes_net = random_module("nn", linear_module, param_priors, **kwargs)
    # Drop the trailing singleton output axis so the logits match the labels.
    scores = bayes_net(data).squeeze(-1)
    numpyro.sample("y", dist.Bernoulli(logits=scores), obs=labels)
def transition_fn(
        carry: Tuple[jnp.ndarray], t: jnp.ndarray) -> Tuple[Tuple[jnp.ndarray], None]:
    """Run one step of a latent random-walk model with covariates.

    The ``(carry, t) -> (new_carry, output)`` signature suggests this is driven
    by ``scan``. It uses ``z_dim``, ``covariates``, ``weight``, ``bias`` and
    ``sigma`` from the enclosing scope.

    Args:
        carry: one-element tuple holding the previous latent state ``z``.
        t: current time index, used to select ``covariates[t]``.

    Returns:
        ``((z,), None)``: the new latent state as the next carry, and no
        per-step output. The annotation previously claimed a ``jnp.ndarray``
        output, but the function always returns ``None``.
    """
    z_prev, *_ = carry
    # Latent state: Gaussian step from the previous state, unit scale per dimension.
    z = numpyro.sample("z", dist.Normal(z_prev, jnp.ones(z_dim)))
    # Observation location: latent state plus a linear covariate effect.
    loc = z + jnp.matmul(covariates[t], weight) + bias
    numpyro.sample("x", dist.Cauchy(loc, sigma))
    return (z,), None
def __init__(self, formula, data, link, family, prior=None, term_priors=None,
             guess=None, theta=None, name="y"):
    """Build the design matrix for ``formula`` and, if needed, sample coefficients.

    Args:
        formula: formula passed to ``patsy.dmatrix`` together with ``data``.
        data: data passed to ``patsy.dmatrix``.
        link: stored unchanged on ``self.link``.
        family: stored unchanged on ``self.family``.
        prior: distribution used for every design term when ``term_priors`` is
            None. Defaults to ``dist.Cauchy(0, 1)``.
        term_priors: optional iterable of distributions, one per design term,
            in the order of ``self.X.design_info.terms``. When given, ``prior``
            is ignored.
        guess: stored unchanged on ``self.guess``.
        theta: coefficient vector. When None, it is sampled term by term from
            the priors, with one sample site per term.
        name: prefix for the sample-site names (``name + "_" + term.name()``).

    Raises:
        ValueError: if ``term_priors`` does not contain exactly one prior per
            design term. Previously ``zip`` silently truncated the mismatch and
            left coefficients at zero.
    """
    # `jax.ops.index_update` has been removed from JAX. Build theta with the
    # `.at[...].set` API instead.
    import jax.numpy as jnp

    self.X = patsy.dmatrix(formula, data)
    self.link = link
    self.family = family
    self.name = name
    self.guess = guess
    if theta is None:
        # Sample theta from the prior(s), one block of columns per design term.
        _, d = self.X.shape
        theta = jnp.zeros(d)
        info = self.X.design_info
        column_names = info.column_names
        if term_priors is None:
            default_prior = prior if prior is not None else dist.Cauchy(0, 1)
            term_priors = [default_prior] * len(info.terms)
        else:
            term_priors = list(term_priors)
            if len(term_priors) != len(info.terms):
                raise ValueError(
                    f"term_priors has {len(term_priors)} entries but the design "
                    f"has {len(info.terms)} terms"
                )
        for term, term_prior in zip(info.terms, term_priors):
            term_slice = info.term_slices[term]
            num_cols = len(column_names[term_slice])
            theta_term = numpyro.sample(
                name + "_" + term.name(), term_prior, sample_shape=(num_cols,)
            )
            # Functional update: returns a new array with this term's columns filled.
            theta = theta.at[term_slice].set(theta_term)
    self.theta = theta
def model(data, labels):
    """Bernoulli likelihood over logits from a module with random parameters.

    ``linear_module`` is lifted by ``random_module`` with priors on its bias and
    weight. The module is initialized for inputs of shape ``(dim,)``, where
    ``dim`` comes from the enclosing scope.

    Args:
        data: inputs passed to the lifted module.
        labels: observed binary outcomes for the ``"y"`` site.
    """
    param_priors = {bias_name: dist.Cauchy(), weight_name: dist.Normal()}
    bayes_net = random_module(
        "nn", linear_module, prior=param_priors, input_shape=(dim,)
    )
    # Squeeze the trailing output axis so there is one logit per example.
    numpyro.sample(
        "y",
        dist.Bernoulli(logits=bayes_net(data).squeeze(-1)),
        obs=labels,
    )