def logistic_random_effects(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample("Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
        theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :] - theta
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
def model(self, home_team, away_team):
    sigma_a = pyro.sample("sigma_a", dist.HalfNormal(1.0))
    sigma_b = pyro.sample("sigma_b", dist.HalfNormal(1.0))
    mu_b = pyro.sample("mu_b", dist.Normal(0.0, 1.0))
    rho_raw = pyro.sample("rho_raw", dist.Beta(2, 2))
    rho = 2.0 * rho_raw - 1.0
    log_gamma = pyro.sample("log_gamma", dist.Normal(0, 1))

    with pyro.plate("teams", self.n_teams):
        abilities = pyro.sample(
            "abilities",
            dist.MultivariateNormal(
                np.array([0.0, mu_b]),
                covariance_matrix=np.array(
                    [
                        [sigma_a ** 2.0, rho * sigma_a * sigma_b],
                        [rho * sigma_a * sigma_b, sigma_b ** 2.0],
                    ]
                ),
            ),
        )

    log_a = abilities[:, 0]
    log_b = abilities[:, 1]

    home_inds = np.array([self.team_to_index[team] for team in home_team])
    away_inds = np.array([self.team_to_index[team] for team in away_team])
    home_rate = np.exp(log_a[home_inds] + log_b[away_inds] + log_gamma)
    away_rate = np.exp(log_a[away_inds] + log_b[home_inds])

    pyro.sample("home_goals", dist.Poisson(home_rate).to_event(1))
    pyro.sample("away_goals", dist.Poisson(away_rate).to_event(1))
def model(X: DeviceArray) -> DeviceArray:
    """Gamma-Poisson hierarchical model for daily sales forecasting

    Args:
        X: input data of shape (n_stores, n_days, n_features + 1), with the
            target in the last feature slot

    Returns:
        sampled daily sales, conditioned on the observed (non-NaN) targets
    """
    n_stores, n_days, n_features = X.shape
    n_features -= 1  # remove one dim for target
    eps = 1e-12  # numerical floor for unobserved entries

    plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
    plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
    plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

    disp_param_mu = numpyro.sample(Site.disp_param_mu,
                                   dist.Normal(loc=4., scale=1.))
    disp_param_sigma = numpyro.sample(Site.disp_param_sigma,
                                      dist.HalfNormal(scale=1.))

    with plate_stores:
        # non-centered parameterization of per-store dispersion
        disp_param_offsets = numpyro.sample(
            Site.disp_param_offsets,
            dist.Normal(loc=jnp.zeros((n_stores, 1)), scale=0.1))
        disp_params = disp_param_mu + disp_param_offsets * disp_param_sigma
        disp_params = numpyro.sample(Site.disp_params,
                                     dist.Delta(disp_params),
                                     obs=disp_params)

    with plate_features:
        coef_mus = numpyro.sample(
            Site.coef_mus,
            dist.Normal(loc=jnp.zeros(n_features), scale=jnp.ones(n_features)))
        coef_sigmas = numpyro.sample(
            Site.coef_sigmas, dist.HalfNormal(scale=2. * jnp.ones(n_features)))

        with plate_stores:
            # non-centered parameterization of per-store coefficients
            coef_offsets = numpyro.sample(
                Site.coef_offsets,
                dist.Normal(loc=jnp.zeros((n_stores, n_features)), scale=1.))
            coefs = coef_mus + coef_offsets * coef_sigmas
            coefs = numpyro.sample(Site.coefs, dist.Delta(coefs), obs=coefs)

    with plate_days, plate_stores:
        targets = X[..., -1]
        features = jnp.nan_to_num(X[..., :-1])  # padded features to 0
        is_observed = jnp.where(jnp.isnan(targets), jnp.zeros_like(targets),
                                jnp.ones_like(targets))
        not_observed = 1 - is_observed
        means = (is_observed * jnp.exp(
            jnp.sum(jnp.expand_dims(coefs, axis=1) * features, axis=2))
                 + not_observed * eps)

        betas = is_observed * jnp.exp(-disp_params) + not_observed
        alphas = means * betas
        return numpyro.sample(Site.days,
                              dist.GammaPoisson(alphas, betas),
                              obs=jnp.nan_to_num(targets))
def model(self, home_team, away_team, gameweek):
    n_gameweeks = max(gameweek) + 1
    sigma_0 = pyro.sample("sigma_0", dist.HalfNormal(5))
    sigma_b = pyro.sample("sigma_b", dist.HalfNormal(5))
    gamma = pyro.sample("gamma", dist.LogNormal(0, 1))
    b = pyro.sample("b", dist.Normal(0, 1))
    loc_mu_b = pyro.sample("loc_mu_b", dist.Normal(0, 1))
    scale_mu_b = pyro.sample("scale_mu_b", dist.HalfNormal(1))

    with pyro.plate("teams", self.n_teams):
        log_a0 = pyro.sample("log_a0", dist.Normal(0, sigma_0))
        mu_b = pyro.sample(
            "mu_b",
            dist.TransformedDistribution(
                dist.Normal(0, 1),
                dist.transforms.AffineTransform(loc_mu_b, scale_mu_b),
            ),
        )
        sigma_rw = pyro.sample("sigma_rw", dist.HalfNormal(0.1))

        with pyro.plate("random_walk", n_gameweeks - 1):
            diffs = pyro.sample(
                "diff",
                dist.TransformedDistribution(
                    dist.Normal(0, 1),
                    dist.transforms.AffineTransform(0, sigma_rw)),
            )

        diffs = np.vstack((log_a0, diffs))
        log_a = np.cumsum(diffs, axis=-2)

        with pyro.plate("weeks", n_gameweeks):
            log_b = pyro.sample(
                "log_b",
                dist.TransformedDistribution(
                    dist.Normal(0, 1),
                    dist.transforms.AffineTransform(mu_b + b * log_a, sigma_b),
                ),
            )

    pyro.sample("log_a", dist.Delta(log_a), obs=log_a)

    home_inds = np.array([self.team_to_index[team] for team in home_team])
    away_inds = np.array([self.team_to_index[team] for team in away_team])
    home_rate = np.clip(
        log_a[gameweek, home_inds] - log_b[gameweek, away_inds] + gamma, -7, 2)
    away_rate = np.clip(
        log_a[gameweek, away_inds] - log_b[gameweek, home_inds], -7, 2)

    pyro.sample("home_goals", dist.Poisson(np.exp(home_rate)))
    pyro.sample("away_goals", dist.Poisson(np.exp(away_rate)))
def model(a=None, b=None, z=None):
    int_term = numpyro.sample("a", dist.Normal(0.0, 0.2))
    x_term, y_term = 0.0, 0.0

    if a is not None:
        x = numpyro.sample("x", dist.HalfNormal(0.5))
        x_term = a * x
    if b is not None:
        y = numpyro.sample("y", dist.HalfNormal(0.5))
        y_term = b * y

    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    mu = int_term + x_term + y_term
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=z)
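# A minimal inference sketch for the conditional-predictor model above; the
# synthetic data and rng seeds are illustrative assumptions, not part of the
# original. Sites `x` and `y` only enter the trace when the corresponding
# predictor is passed, so the model graph changes with the arguments.
from jax import random
from numpyro.infer import MCMC, NUTS

rng_key = random.PRNGKey(0)
a_data = random.normal(rng_key, (100,))
z_data = 0.3 * a_data + 0.1 * random.normal(random.PRNGKey(1), (100,))

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(2), a=a_data, b=None, z=z_data)  # b omitted: no `y` site
mcmc.print_summary()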
def test_model_with_transformed_distribution():
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    params = {'x': jnp.array(-5.), 'y': jnp.array(7.)}
    model = handlers.seed(model, random.PRNGKey(0))
    inv_transforms = {
        'x': biject_to(x_prior.support),
        'y': biject_to(y_prior.support)
    }
    expected_samples = partial(transform_fn, inv_transforms)(params)
    expected_potential_energy = (
        -x_prior.log_prob(expected_samples['x'])
        - y_prior.log_prob(expected_samples['y'])
        - inv_transforms['x'].log_abs_det_jacobian(params['x'], expected_samples['x'])
        - inv_transforms['y'].log_abs_det_jacobian(params['y'], expected_samples['y'])
    )

    reparam_model = handlers.reparam(model, {'y': TransformReparam()})
    base_params = {'x': params['x'], 'y_base': params['y']}
    actual_samples = constrain_fn(
        handlers.seed(reparam_model, random.PRNGKey(0)), (), {}, base_params,
        return_deterministic=True)
    actual_potential_energy = potential_energy(reparam_model, (), {}, base_params)

    assert_allclose(expected_samples['x'], actual_samples['x'])
    assert_allclose(expected_samples['y'], actual_samples['y'])
    assert_allclose(actual_potential_energy, expected_potential_energy)
def hierarchical_dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = numpyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # pad 0 to the last item
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        with numpyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
def model(X, Y, hypers):
    S, P, N = hypers['expected_sparsity'], X.shape[1], X.shape[0]

    sigma = numpyro.sample("sigma", dist.HalfNormal(hypers['alpha3']))
    phi = sigma * (S / np.sqrt(N)) / (P - S)
    eta1 = numpyro.sample("eta1", dist.HalfCauchy(phi))

    msq = numpyro.sample("msq", dist.InverseGamma(hypers['alpha1'], hypers['beta1']))
    xisq = numpyro.sample("xisq", dist.InverseGamma(hypers['alpha2'], hypers['beta2']))

    eta2 = np.square(eta1) * np.sqrt(xisq) / msq

    lam = numpyro.sample("lambda", dist.HalfCauchy(np.ones(P)))
    kappa = np.sqrt(msq) * lam / np.sqrt(msq + np.square(eta1 * lam))

    # sample observation noise
    var_obs = numpyro.sample(
        "var_obs", dist.InverseGamma(hypers['alpha_obs'], hypers['beta_obs']))

    # compute kernel
    kX = kappa * X
    k = kernel(kX, kX, eta1, eta2, hypers['c']) + var_obs * np.eye(N)
    assert k.shape == (N, N)

    # sample Y according to the standard gaussian process formula
    numpyro.sample(
        "Y",
        dist.MultivariateNormal(loc=np.zeros(X.shape[0]), covariance_matrix=k),
        obs=Y)
def sample_y(dist_y, theta, y, sigma_obs=None):
    # use an identity check so that an explicit sigma_obs of 0 is not silently
    # replaced by a sampled value
    if sigma_obs is None:
        if dist_y == 'gamma':
            sigma_obs = numpyro.sample('sigma_obs', dist.Exponential(1))
        else:
            sigma_obs = numpyro.sample('sigma_obs', dist.HalfNormal(1))

    if dist_y == 'student':
        numpyro.sample('y',
                       dist.StudentT(numpyro.sample('nu_y', dist.Gamma(1, .1)),
                                     theta, sigma_obs),
                       obs=y)
    elif dist_y == 'normal':
        numpyro.sample('y', dist.Normal(theta, sigma_obs), obs=y)
    elif dist_y == 'lognormal':
        numpyro.sample('y', dist.LogNormal(theta, sigma_obs), obs=y)
    elif dist_y == 'gamma':
        numpyro.sample('y', dist.Gamma(jnp.exp(theta), sigma_obs), obs=y)
    elif dist_y == 'gamma_raw':
        numpyro.sample('y', dist.Gamma(theta, sigma_obs), obs=y)
    elif dist_y == 'poisson':
        numpyro.sample('y', dist.Poisson(theta), obs=y)
    elif dist_y == 'exponential':
        numpyro.sample('y', dist.Exponential(jnp.exp(theta)), obs=y)
    elif dist_y == 'exponential_raw':
        numpyro.sample('y', dist.Exponential(theta), obs=y)
    elif dist_y == 'uniform':
        numpyro.sample('y', dist.Uniform(0, 1), obs=y)
    else:
        raise NotImplementedError(f"Unsupported likelihood: {dist_y}")
def model(wave, mne, n, eta=E):
    pf = numpyro.sample('penum', dist.HalfNormal(jnp.ones(n)))
    y = numpyro.sample(
        'y', dist.Normal(0, 1),
        obs=jnp.power(wave - jnp.matmul(mne, pf), 2) + eta * jnp.sum(pf))
    return y
def test_model_with_transformed_distribution():
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution

    def model():
        sample('x', x_prior)
        sample('y', y_prior)

    params = {'x': np.array(-5.), 'y': np.array(7.)}
    model = seed(model, random.PRNGKey(0))
    inv_transforms = {
        'x': biject_to(x_prior.support),
        'y': biject_to(y_prior.support)
    }
    expected_samples = partial(transform_fn, inv_transforms)(params)
    expected_potential_energy = (
        -x_prior.log_prob(expected_samples['x'])
        - y_prior.log_prob(expected_samples['y'])
        - inv_transforms['x'].log_abs_det_jacobian(params['x'], expected_samples['x'])
        - inv_transforms['y'].log_abs_det_jacobian(params['y'], expected_samples['y'])
    )

    base_inv_transforms = {
        'x': biject_to(x_prior.support),
        'y': biject_to(y_prior.base_dist.support)
    }
    actual_samples = constrain_fn(
        seed(model, random.PRNGKey(0)), (), {}, base_inv_transforms, params)
    actual_potential_energy = potential_energy(
        model, (), {}, base_inv_transforms, params)

    assert_allclose(expected_samples['x'], actual_samples['x'])
    assert_allclose(expected_samples['y'], actual_samples['y'])
    assert_allclose(actual_potential_energy, expected_potential_energy)
def fulldyn_mixture_model(beliefs, y, mask):
    M, T, N, _ = beliefs[0].shape
    c0 = beliefs[-1]

    tau = .5
    with npyro.plate('N', N):
        weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))
        assert weights.shape == (N, M)

        mu = npyro.sample('mu', dist.Normal(5., 5.))

        lam12 = npyro.sample('lam12', dist.HalfCauchy(1.).expand([2]).to_event(1))
        lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))

        _lam34 = jnp.expand_dims(lam34, -1)
        lam0 = npyro.deterministic(
            'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

        eta = npyro.sample('eta', dist.Beta(1, 10))

        scale = npyro.sample('scale', dist.HalfNormal(1.))
        theta = npyro.sample('theta', dist.HalfCauchy(5.))

    # AR(1) coefficients: the exact one-step discretization of an
    # Ornstein-Uhlenbeck process with decay rate theta
    rho = jnp.exp(-theta)
    sigma = jnp.sqrt((1 - rho**2) / (2 * theta)) * scale

    x0 = jnp.zeros(N)

    def transition_fn(carry, t):
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))
        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))

        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1),
                      jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams',
            lam_prev + nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(
            mask[t][..., None])
        with npyro.plate('subjects', N):
            y = npyro.sample('y', dist.MixtureSameFamily(mixing_dist, component_dist))
            noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise

        return (lam_next, x_next), None

    lam_start = npyro.deterministic('lam_start', lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (lam_start, x0), jnp.arange(T))
def birthdays_model(
    x,
    day_of_week,
    day_of_year,
    memorial_days_indicator,
    labour_days_indicator,
    thanksgiving_days_indicator,
    w0,
    L,
    M1,
    M2,
    M3,
    y=None,
):
    intercept = sample("intercept", dist.Normal(0, 1))
    f1 = scope(trend_gp, "trend")(x, L, M1)
    f2 = scope(year_gp, "year")(x, w0, M2)
    g3 = scope(trend_gp, "week-trend")(x, L, M3)  # length ~ lognormal(-1, 1) in original
    weekday = scope(weekday_effect, "week")(day_of_week)
    yearday = scope(yearday_effect, "day")(day_of_year)

    # --- special days
    memorial = scope(special_effect, "memorial")(memorial_days_indicator)
    labour = scope(special_effect, "labour")(labour_days_indicator)
    thanksgiving = scope(special_effect, "thanksgiving")(thanksgiving_days_indicator)
    day = yearday + memorial + labour + thanksgiving

    # --- combine components
    f = deterministic("f", intercept + f1 + f2 + jnp.exp(g3) * weekday + day)
    sigma = sample("sigma", dist.HalfNormal(0.5))
    with plate("obs", x.shape[0]):
        sample("y", dist.Normal(f, sigma), obs=y)
def item_difficulty(annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        eta = numpyro.sample(
            "eta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(eta[c], chi[c]).to_event(1))
        theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            numpyro.sample("y", dist.Categorical(logits=theta), obs=annotations)
def model(X, Y, hypers):
    S, P, N = hypers["expected_sparsity"], X.shape[1], X.shape[0]

    sigma = numpyro.sample("sigma", dist.HalfNormal(hypers["alpha3"]))
    phi = sigma * (S / jnp.sqrt(N)) / (P - S)
    eta1 = numpyro.sample("eta1", dist.HalfCauchy(phi))

    msq = numpyro.sample("msq", dist.InverseGamma(hypers["alpha1"], hypers["beta1"]))
    xisq = numpyro.sample("xisq", dist.InverseGamma(hypers["alpha2"], hypers["beta2"]))

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq

    lam = numpyro.sample("lambda", dist.HalfCauchy(jnp.ones(P)))
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    # compute kernel
    kX = kappa * X
    k = kernel(kX, kX, eta1, eta2, hypers["c"]) + sigma**2 * jnp.eye(N)
    assert k.shape == (N, N)

    # sample Y according to the standard gaussian process formula
    numpyro.sample(
        "Y",
        dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=k),
        obs=Y,
    )
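# `kernel` is assumed to be defined elsewhere in scope; a sketch following the
# quadratic kernel used in NumPyro's sparse_regression.py example, which the
# model above mirrors. The jitter term is an assumption for numerical stability.
import jax.numpy as jnp

def kernel(X, Z, eta1, eta2, c, jitter=1.0e-4):
    eta1sq = jnp.square(eta1)
    eta2sq = jnp.square(eta2)
    XZ = jnp.matmul(X, Z.T)  # gram matrix of the (kappa-scaled) inputs
    k1 = 0.5 * eta2sq * jnp.square(1.0 + XZ)
    k2 = -0.5 * eta2sq * jnp.matmul(jnp.square(X), jnp.square(Z).T)
    k3 = (eta1sq - eta2sq) * XZ
    k4 = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        k4 = k4 + jitter * jnp.eye(X.shape[0])
    return k1 + k2 + k3 + k4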
def model_hierarchical_next(y, condition=None, group=None, treatment=None,
                            dist_y='normal', add_group_slope=False,
                            add_group_intercept=True):
    # Hyperpriors:
    a = numpyro.sample('a', dist.Normal(0., 5))
    sigma_a = numpyro.sample('sigma_a', dist.HalfNormal(1))
    b = numpyro.sample('b', dist.Normal(0., 1))
    sigma_b = numpyro.sample('sigma_b', dist.HalfNormal(1))

    with numpyro.plate('n_conditions', np.unique(condition).size):  # if add_group_slope else nullcontext()
        # Varying slopes:
        b_condition = numpyro.sample('slope_per_group', dist.Normal(b, sigma_b))

    with numpyro.plate('n_groups', np.unique(group).size):  # if add_group_intercept else nullcontext()
        # Varying intercepts:
        a_group = numpyro.sample('a_group', dist.Normal(a, sigma_a))

    theta = a_group[group] + b_condition[condition] * treatment
    sample_y(dist_y=dist_y, theta=theta, y=y)
def gmm(data, num_components=3):
    mus = numpyro.sample(
        'mus',
        dist.Normal(jnp.zeros(num_components),
                    jnp.ones(num_components) * 100.).to_event(1))
    sigmas = numpyro.sample(
        'sigmas', dist.HalfNormal(jnp.ones(num_components) * 100.).to_event(1))
    mixture_probs = numpyro.sample(
        'mixture_probs', dist.Dirichlet(jnp.ones(num_components) / num_components))

    with numpyro.plate('data', len(data), dim=-1):
        z = numpyro.sample('z', dist.Categorical(mixture_probs))
        numpyro.sample('ll', dist.Normal(mus[z], sigmas[z]), obs=data)
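# The `z` site above is discrete, so plain NUTS cannot sample it directly; a
# minimal inference sketch (synthetic data and seeds are illustrative
# assumptions) that Gibbs-samples the assignments around a NUTS kernel via
# numpyro's DiscreteHMCGibbs. Note mixture labels remain non-identifiable.
import jax.numpy as jnp
from jax import random
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

data = jnp.concatenate([
    2.0 + random.normal(random.PRNGKey(0), (50,)),
    -2.0 + random.normal(random.PRNGKey(1), (50,)),
])
kernel = DiscreteHMCGibbs(NUTS(gmm))
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(2), data, num_components=2)
mcmc.print_summary()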
def model_single(y, condition, group=None, dist_y='normal',
                 add_group_intercept=True, add_intercept=False, **kwargs):
    n_conditions = np.unique(condition).shape[0]
    sigma_neuron = numpyro.sample('sigma', dist.HalfNormal(1))
    a = numpyro.sample('mu', dist.Normal(0, 1))
    a_neuron_per_condition = numpyro.sample(
        'mu_per_condition', dist.Normal(jnp.tile(a, n_conditions), sigma_neuron))

    theta = a_neuron_per_condition[condition]
    if add_intercept:
        theta += a

    if group is not None and add_group_intercept:
        if dist_y == 'poisson':
            a_group = numpyro.sample(
                'mu_intercept_per_group',
                dist.HalfNormal(jnp.tile(10, np.unique(group).shape[0])))
        else:
            sigma_group = numpyro.sample('sigma_intercept_per_group',
                                         dist.HalfNormal(1))
            a_group = numpyro.sample(
                'mu_intercept_per_group',
                dist.Normal(jnp.tile(0, np.unique(group).shape[0]), sigma_group))
        theta += a_group[group]

    sample_y(dist_y=dist_y, theta=theta, y=y)
def holt_winters(y, n_seasons, future=0):
    T = y.shape[0]
    level_smoothing = numpyro.sample("level_smoothing", dist.Beta(1, 1))
    trend_smoothing = numpyro.sample("trend_smoothing", dist.Beta(1, 1))
    seasonality_smoothing = numpyro.sample("seasonality_smoothing", dist.Beta(1, 1))
    adj_seasonality_smoothing = seasonality_smoothing * (1 - level_smoothing)
    noise = numpyro.sample("noise", dist.HalfNormal(1))
    level_init = numpyro.sample("level_init", dist.Normal(0, 1))
    trend_init = numpyro.sample("trend_init", dist.Normal(0, 1))
    with numpyro.plate("n_seasons", n_seasons):
        seasonality_init = numpyro.sample("seasonality_init", dist.Normal(0, 1))

    def transition_fn(carry, t):
        previous_level, previous_trend, previous_seasonality = carry
        level = jnp.where(
            t < T,
            level_smoothing * (y[t] - previous_seasonality[0])
            + (1 - level_smoothing) * (previous_level + previous_trend),
            previous_level,
        )
        trend = jnp.where(
            t < T,
            trend_smoothing * (level - previous_level)
            + (1 - trend_smoothing) * previous_trend,
            previous_trend,
        )
        new_season = jnp.where(
            t < T,
            adj_seasonality_smoothing * (y[t] - (previous_level + previous_trend))
            + (1 - adj_seasonality_smoothing) * previous_seasonality[0],
            previous_seasonality[0],
        )
        step = jnp.where(t < T, 1, t - T + 1)
        mu = previous_level + step * previous_trend + previous_seasonality[0]
        pred = numpyro.sample("pred", dist.Normal(mu, noise))
        seasonality = jnp.concatenate(
            [previous_seasonality[1:], new_season[None]], axis=0)
        return (level, trend, seasonality), pred

    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            (level_init, trend_init, seasonality_init),
            jnp.arange(T + future),
        )
    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])
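# A minimal forecasting sketch for the Holt-Winters model above (the synthetic
# series and seeds are illustrative assumptions): fit with NUTS on the observed
# window, then draw the `y_forecast` deterministic site from the posterior
# predictive by passing future > 0.
import jax.numpy as jnp
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

t = jnp.arange(48.0)
y_train = 0.1 * t + jnp.sin(2 * jnp.pi * t / 12)  # trend plus yearly seasonality

mcmc = MCMC(NUTS(holt_winters), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), y_train, n_seasons=12)

predictive = Predictive(holt_winters, mcmc.get_samples())
forecast = predictive(random.PRNGKey(1), y_train, n_seasons=12, future=12)["y_forecast"]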
def glmm(dept, male, applications, admit):
    v_mu = numpyro.sample('v_mu', dist.Normal(0, np.array([4., 1.])))

    sigma = numpyro.sample('sigma', dist.HalfNormal(np.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2))
    scale_tril = sigma[..., np.newaxis] * L_Rho
    # non-centered parameterization
    num_dept = len(onp.unique(dept))
    z = numpyro.sample('z', dist.Normal(np.zeros((num_dept, 2)), 1))
    v = np.dot(scale_tril, z.T).T

    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
def glmm(dept, male, applications, admit):
    v_mu = sample('v_mu', dist.Normal(0, np.array([4., 1.])))

    sigma = sample('sigma', dist.HalfNormal(np.ones(2)))
    L_Rho = sample('L_Rho', dist.LKJCholesky(2))
    scale_tril = np.expand_dims(sigma, axis=-1) * L_Rho
    # non-centered parameterization
    num_dept = len(onp.unique(dept))
    z = sample('z', dist.Normal(np.zeros((num_dept, 2)), 1))
    v = np.squeeze(np.matmul(np.expand_dims(scale_tril, axis=-3),
                             np.expand_dims(z, axis=-1)),
                   axis=-1)

    logits = v_mu[..., :1] + v[..., dept, 0] + (v_mu[..., 1:] + v[..., dept, 1]) * male
    sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
def model_hierarchical_for_render(y, condition=None, group=None, treatment=None,
                                  dist_y='normal', add_group_slope=False,
                                  add_group_intercept=True, add_condition_slope=True):
    # Hyperpriors:
    a = numpyro.sample('hyper_a', dist.Normal(0., 5))
    sigma_a = numpyro.sample('hyper_sigma_a', dist.HalfNormal(1))

    # with numpyro.plate('n_groups1', np.unique(group).size) if add_group_slope else nullcontext():
    #     sigma_b = numpyro.sample('sigma_b', dist.HalfNormal(1))
    #     b = numpyro.sample('b', dist.Normal(0., 1))
    #     with numpyro.plate('n_conditions', np.unique(condition).size):
    #         # Varying slopes:
    #         b_condition = numpyro.sample('slope', dist.Normal(b, sigma_b))

    sigma_b = numpyro.sample('hyper_sigma_b_condition', dist.HalfNormal(1))
    b = numpyro.sample('hyper_b_condition', dist.Normal(0., 1))
    with (  # numpyro.plate('n_groups1', np.unique(group).size) if add_group_slope else
            numpyro.plate('n_conditions', np.unique(condition).size)):
        # Varying slopes:
        b_condition = numpyro.sample('slope_per_condition', dist.Normal(b, sigma_b))

    sigma_b = numpyro.sample('hyper_sigma_b_group', dist.HalfNormal(1))
    b = numpyro.sample('hyper_b_group', dist.Normal(0., 1))
    if add_group_slope:
        with numpyro.plate('n_groups', np.unique(group).size):
            # Varying slopes:
            b_group = numpyro.sample('b_group', dist.Normal(b, sigma_b))
    else:
        # use a zero vector so that b_group[group] below is always valid
        # (a scalar 0 here would raise a TypeError when indexed)
        b_group = jnp.zeros(np.unique(group).size)

    if add_group_intercept:
        with numpyro.plate('n_groups', np.unique(group).size):
            # Varying intercepts:
            a_group = numpyro.sample('a_group', dist.Normal(a, sigma_a))
        theta = a_group[group] + (b_condition[condition] + b_group[group]) * treatment
    else:
        theta = a + (b_condition[condition] + b_group[group]) * treatment

    sample_y(dist_y=dist_y, theta=theta, y=y)
def fulldyn_single_model(beliefs, y, mask):
    T, _ = beliefs[0].shape
    c0 = beliefs[-1]

    mu = npyro.sample('mu', dist.Normal(5., 5.))

    lam12 = npyro.sample('lam12', dist.HalfCauchy(1.).expand([2]).to_event(1))
    lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))

    _lam34 = jnp.expand_dims(lam34, -1)
    lam0 = npyro.deterministic(
        'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

    eta = npyro.sample('eta', dist.Beta(1, 10))

    scale = npyro.sample('scale', dist.HalfNormal(1.))
    theta = npyro.sample('theta', dist.HalfCauchy(5.))

    # AR(1) coefficients: the exact one-step discretization of an
    # Ornstein-Uhlenbeck process with decay rate theta
    rho = jnp.exp(-theta)
    sigma = jnp.sqrt((1 - rho**2) / (2 * theta)) * scale

    x0 = jnp.zeros(1)

    def transition_fn(carry, t):
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))
        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))

        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1),
                      jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams',
            lam_prev + nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))
        noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise

        return (lam_next, x_next), None

    lam_start = npyro.deterministic('lam_start', lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (lam_start, x0), jnp.arange(T))
def sample_intervention_effects(nCMs, intervention_prior=None):
    """
    Sample intervention effect parameters under the specified prior.

    :param nCMs: number of interventions
    :param intervention_prior: dictionary specifying the prior; must contain a
        `type` key plus the relevant hyperparameters (usually `scale`)
    :return: sampled intervention effect parameters
    """
    if intervention_prior is None:
        intervention_prior = {
            "type": "asymmetric_laplace",
            "scale": 30,
            "asymmetry": 0.5,
        }

    if intervention_prior["type"] == "trunc_normal":
        alpha_i = numpyro.sample(
            "alpha_i",
            dist.TruncatedNormal(low=-0.1, loc=jnp.zeros(nCMs),
                                 scale=intervention_prior["scale"]),
        )
    elif intervention_prior["type"] == "half_normal":
        alpha_i = numpyro.sample(
            "alpha_i",
            dist.HalfNormal(scale=jnp.ones(nCMs) * intervention_prior["scale"]),
        )
    elif intervention_prior["type"] == "normal":
        alpha_i = numpyro.sample(
            "alpha_i",
            dist.Normal(loc=jnp.zeros(nCMs), scale=intervention_prior["scale"]),
        )
    elif intervention_prior["type"] == "asymmetric_laplace":
        alpha_i = numpyro.sample(
            "alpha_i",
            AsymmetricLaplace(
                asymmetry=intervention_prior["asymmetry"],
                scale=jnp.ones(nCMs) * intervention_prior["scale"],
            ),
        )
    else:
        raise ValueError(
            "Intervention effect prior must take a value in "
            "[trunc_normal, normal, asymmetric_laplace, half_normal]"
        )

    return alpha_i
def glmm(dept, male, applications, admit=None):
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))

    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    scale_tril = sigma[..., jnp.newaxis] * L_Rho
    # non-centered parameterization
    num_dept = len(np.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((num_dept, 2)), 1))
    v = jnp.dot(scale_tril, z.T).T

    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
    if admit is None:
        # we use a Delta site to record probs for predictive distribution
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
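# A minimal posterior-predictive sketch for the glmm above; `mcmc` is assumed
# to be an already-fitted numpyro MCMC object and `dept`, `male`,
# `applications` the arrays used for fitting. With admit=None, the model
# records the Delta site `probs`, which Predictive returns alongside `admit`.
from jax import random
from numpyro.infer import Predictive

predictive = Predictive(glmm, mcmc.get_samples())
pred = predictive(random.PRNGKey(1), dept, male, applications)
admit_rate = pred["probs"].mean(axis=0)  # posterior mean admission probability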
def yearday_effect(day_of_year):
    slab_df = 50  # 100 in original case study
    slab_scale = 2
    scale_global = 0.1
    tau = sample("tau", dist.HalfNormal(2 * scale_global))  # original uses half-t with 100 df
    c_aux = sample("c_aux", dist.InverseGamma(0.5 * slab_df, 0.5 * slab_df))
    c = slab_scale * jnp.sqrt(c_aux)

    # Jan 1st: Day 0
    # Feb 29th: Day 59
    # Dec 31st: Day 365
    with plate("plate_day_of_year", 366):
        lam = sample("lam", dist.HalfCauchy(scale=1))
        lam_tilde = jnp.sqrt(c) * lam / jnp.sqrt(c + (tau * lam) ** 2)
        beta = sample("beta", dist.Normal(loc=0, scale=tau * lam_tilde))

    return beta[day_of_year]
def model(y):
    mu = numpyro.sample("mu", dist.Normal())
    sigma = numpyro.sample("sigma", dist.HalfNormal())
    with numpyro.plate("plate1", y.shape[0]):
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
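# A minimal end-to-end inference sketch for the simple Normal model above; the
# synthetic data and seeds are illustrative assumptions, not part of the original.
from jax import random
from numpyro.infer import MCMC, NUTS

y_data = 1.5 + 0.5 * random.normal(random.PRNGKey(0), (200,))
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(1), y_data)
mcmc.print_summary()  # posterior for mu and sigma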
def model(x, y):
    nn = numpyro.module("nn", Dense(1), (10,))
    mu = nn(x).squeeze(-1)
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
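# `Dense` is assumed to come from jax.example_libraries.stax. Because
# numpyro.module registers the network weights as learnable params rather than
# latent sample sites, a natural fit is SVI rather than MCMC; a minimal sketch
# with illustrative data and an AutoNormal guide.
from jax import random
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
import numpyro.optim

x_data = random.normal(random.PRNGKey(0), (100, 10))
y_data = x_data[:, 0] + 0.1 * random.normal(random.PRNGKey(1), (100,))

guide = AutoNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(2), 2000, x_data, y_data)  # optimizes net weights and guide jointly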
def model(X: DeviceArray) -> DeviceArray:
    """Gamma-Poisson hierarchical model for daily sales forecasting

    Args:
        X: input data of shape (n_stores, n_days, n_features + 1), with the
            target in the last feature slot

    Returns:
        sampled daily sales, conditioned on the observed (non-NaN) targets
    """
    n_stores, n_days, n_features = X.shape
    n_features -= 1  # remove one dim for target
    eps = 1e-12  # numerical floor for unobserved entries

    plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
    plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
    plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

    disp_param_mu = numpyro.sample(Site.disp_param_mu,
                                   dist.Normal(loc=4.0, scale=1.0))
    disp_param_sigma = numpyro.sample(Site.disp_param_sigma,
                                      dist.HalfNormal(scale=1.0))

    with plate_stores:
        with numpyro.handlers.reparam(config={Site.disp_params: TransformReparam()}):
            disp_params = numpyro.sample(
                Site.disp_params,
                dist.TransformedDistribution(
                    dist.Normal(loc=jnp.zeros((n_stores, 1)), scale=0.1),
                    dist.transforms.AffineTransform(disp_param_mu, disp_param_sigma),
                ),
            )

    with plate_features:
        coef_mus = numpyro.sample(
            Site.coef_mus,
            dist.Normal(loc=jnp.zeros(n_features), scale=jnp.ones(n_features)),
        )
        coef_sigmas = numpyro.sample(
            Site.coef_sigmas, dist.HalfNormal(scale=2.0 * jnp.ones(n_features)))

        with plate_stores:
            with numpyro.handlers.reparam(config={Site.coefs: TransformReparam()}):
                coefs = numpyro.sample(
                    Site.coefs,
                    dist.TransformedDistribution(
                        dist.Normal(loc=jnp.zeros((n_stores, n_features)), scale=1.0),
                        dist.transforms.AffineTransform(coef_mus, coef_sigmas),
                    ),
                )

    with plate_days, plate_stores:
        targets = X[..., -1]
        features = jnp.nan_to_num(X[..., :-1])  # padded features to 0
        is_observed = jnp.where(jnp.isnan(targets), jnp.zeros_like(targets),
                                jnp.ones_like(targets))
        not_observed = 1 - is_observed
        means = (is_observed * jnp.exp(
            jnp.sum(jnp.expand_dims(coefs, axis=1) * features, axis=2))
                 + not_observed * eps)

        betas = is_observed * jnp.exp(-disp_params) + not_observed
        alphas = means * betas
        return numpyro.sample(Site.days,
                              dist.GammaPoisson(alphas, betas),
                              obs=jnp.nan_to_num(targets))