def model(X, Y, hypers):
    """GP regression with a sparsity-inducing hierarchical prior on input scales.

    :param X: (N, P) design matrix.
    :param Y: (N,) observed responses (or None for prior predictive).
    :param hypers: dict with keys 'expected_sparsity', 'alpha1', 'beta1',
        'alpha2', 'beta2', 'alpha3', 'alpha_obs', 'beta_obs', 'c'.
    """
    S, P, N = hypers['expected_sparsity'], X.shape[1], X.shape[0]

    # Global scale: phi is tuned so that roughly S of the P inputs are active.
    sigma = numpyro.sample("sigma", dist.HalfNormal(hypers['alpha3']))
    phi = sigma * (S / np.sqrt(N)) / (P - S)
    eta1 = numpyro.sample("eta1", dist.HalfCauchy(phi))

    msq = numpyro.sample("msq", dist.InverseGamma(hypers['alpha1'], hypers['beta1']))
    xisq = numpyro.sample("xisq", dist.InverseGamma(hypers['alpha2'], hypers['beta2']))

    # Scale for second-order (interaction) effects, derived from eta1.
    eta2 = np.square(eta1) * np.sqrt(xisq) / msq

    # Per-dimension local shrinkage; kappa rescales each input column.
    lam = numpyro.sample("lambda", dist.HalfCauchy(np.ones(P)))
    kappa = np.sqrt(msq) * lam / np.sqrt(msq + np.square(eta1 * lam))

    # sample observation noise (variance, added directly to the kernel diagonal)
    var_obs = numpyro.sample(
        "var_obs",
        dist.InverseGamma(hypers['alpha_obs'], hypers['beta_obs']))

    # compute kernel; `kernel` is defined elsewhere in this module
    kX = kappa * X
    k = kernel(kX, kX, eta1, eta2, hypers['c']) + var_obs * np.eye(N)
    assert k.shape == (N, N)

    # sample Y according to the standard gaussian process formula
    numpyro.sample("Y",
                   dist.MultivariateNormal(loc=np.zeros(X.shape[0]),
                                           covariance_matrix=k),
                   obs=Y)
def fulldyn_mixture_model(beliefs, y, mask):
    """Mixture model over M candidate dynamics for N subjects across T trials.

    Per-subject mixture weights select among M component models; preferences
    (lam) accumulate over trials and a latent AR(1) process x modulates the
    response precision gamma.

    :param beliefs: tuple of arrays; beliefs[0] has shape (M, T, N, ...),
        beliefs[2] holds observed choice indices, beliefs[-1] the initial
        preference counts c0.  # assumes this layout — TODO confirm vs caller
    :param y: observed responses conditioned into the scan's 'y' sites.
    :param mask: (T, N) validity mask for trials.
    """
    M, T, N, _ = beliefs[0].shape
    c0 = beliefs[-1]

    # Sparse Dirichlet prior over the M mixture components, per subject.
    tau = .5
    with npyro.plate('N', N):
        weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))
    assert weights.shape == (N, M)

    # Baseline response-precision parameter (softplus-linked below).
    mu = npyro.sample('mu', dist.Normal(5., 5.))

    # Initial preference vector: cumulative-sum construction enforces an
    # ordering on the first two entries; last two share the same scale.
    lam12 = npyro.sample('lam12', dist.HalfCauchy(1.).expand([2]).to_event(1))
    lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))
    _lam34 = jnp.expand_dims(lam34, -1)
    lam0 = npyro.deterministic(
        'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

    # Learning rate for preference updates.
    eta = npyro.sample('eta', dist.Beta(1, 10))

    # Stationary AR(1) parameters: rho = exp(-theta); sigma chosen so the
    # process has stationary std `scale`.
    scale = npyro.sample('scale', dist.HalfNormal(1.))
    theta = npyro.sample('theta', dist.HalfCauchy(5.))
    rho = jnp.exp(-theta)
    sigma = jnp.sqrt((1 - rho**2) / (2 * theta)) * scale

    x0 = jnp.zeros(N)

    def transition_fn(carry, t):
        # One trial: compute response logits, update preferences, advance x.
        lam_prev, x_prev = carry
        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

        # Normalized log-preferences feed the `logits` helper (defined elsewhere).
        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1),
                      jnp.expand_dims(U, -2))

        # Increment the chosen option's preference (masked, scaled by eta).
        lam_next = npyro.deterministic(
            'lams',
            lam_prev + nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(mask[t][..., None])
        with npyro.plate('subjects', N):
            y = npyro.sample('y', dist.MixtureSameFamily(mixing_dist, component_dist))
            noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise
        return (lam_next, x_next), None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)

    # Condition the per-trial 'y' sites on the observed responses.
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (lam_start, x0), jnp.arange(T))
def singlevariate_kf(T=None, T_forecast=15, obs=None): """Define Kalman Filter in a single variate fashion. Parameters ---------- T: int T_forecast: int Times to forecast ahead. obs: np.array observed variable (infected, deaths...) """ # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind) T = len(obs) if T is None else T beta = numpyro.sample(name="beta", fn=dist.Normal(loc=0.0, scale=1)) tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=0.1)) noises = numpyro.sample( "noises", fn=dist.Normal(0, 1.0), sample_shape=(T + T_forecast - 2,) ) sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1)) z_prev = numpyro.sample(name="z_1", fn=dist.Normal(loc=0, scale=0.1)) # Propagate the dynamics forward using jax.lax.scan carry = (beta, z_prev, tau) z_collection = [z_prev] carry, zs_exp = lax.scan(f, carry, noises, T + T_forecast - 2) z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0) # Sample the observed y (y_obs) and missing y (y_mis) numpyro.sample( name="y_obs", fn=dist.Normal(loc=z_collection[:T], scale=sigma), obs=obs ) numpyro.sample( name="y_pred", fn=dist.Normal(loc=z_collection[T:], scale=sigma), obs=None )
def model(T, T_forecast, X, obs=None):
    """State-space model with exogenous covariates X.

    Latent state z_t is advanced by the external transition function `f`
    over (X, noises); observations are Normal(z_t, sigma).

    :param T: number of observed time steps.
    :param T_forecast: horizon of predicted steps (sampled as 'y_pred').
    :param X: covariate matrix; X.shape[0] must equal the number of scanned
        steps and X.shape[1] the covariate dimension.
    :param obs: observed series or None for prior/posterior predictive.
    :return: the full latent trajectory (initial state + scanned states).
    """
    """ Define priors over delta, tau, noises, sigma, z_prev, eta """
    tau = numpyro.sample("tau", dist.HalfCauchy(10.0))
    # One innovation per row of X.
    noises = numpyro.sample(
        "noises",
        dist.Normal(jnp.zeros(X.shape[0]), 5.0 * jnp.ones(X.shape[0])),
    )
    sigma = numpyro.sample("sigma", dist.HalfCauchy(scale=5.0))
    delta = numpyro.sample("delta", dist.Normal(0, 5.0))
    z_prev = numpyro.sample("z_prev", dist.Normal(0, 3.0))
    # Regression weights for the covariates.
    eta = numpyro.sample(
        "eta",
        dist.Normal(jnp.zeros(X.shape[1]), 5.0 * jnp.ones(X.shape[1])),
    )

    """ Propagate the dynamics forward using jax.lax.scan """
    carry = (delta, eta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(
        f=f,
        init=carry,
        xs=(X, noises),
    )
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    """ Sample the observed_y (y_obs) and predicted_y (y_pred) - note that
    you don't need a pyro.plate! """
    numpyro.sample("y_obs", dist.Normal(z_collection[:T], sigma), obs=obs)
    numpyro.sample("y_pred", dist.Normal(z_collection[T:], sigma), obs=None)

    return z_collection
def twoh_c_kf(T=None, T_forecast=15, obs=None):
    """Define Kalman Filter with two hidden variates.

    The two latent series share correlated innovations (LKJ-Cholesky prior on
    their covariance) and are projected to the observed scalar via a learned
    2x1 loading vector c.

    :param T: number of observed steps; inferred from `obs` when None.
    :param T_forecast: steps to forecast ahead (sampled as 'y_pred').
    :param obs: observed series.
    """
    T = len(obs) if T is None else T
    # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind)
    #W = numpyro.sample(name="W", fn=dist.Normal(loc=jnp.zeros((2,4)), scale=jnp.ones((2,4))))
    beta = numpyro.sample(name="beta",
                          fn=dist.Normal(loc=jnp.array([0.,0.]), scale=jnp.ones(2)))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=.1))
    z_prev = numpyro.sample(name="z_1",
                            fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)))

    # Define LKJ prior
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.))
    Sigma_lower = jnp.matmul(jnp.diag(jnp.sqrt(tau)), L_Omega)  # lower cholesky factor of the covariance matrix
    noises = numpyro.sample("noises",
                            fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
                            sample_shape=(T+T_forecast,))

    # Propagate the dynamics forward using jax.lax.scan; `f` is the external
    # transition function shared by the KF variants in this module.
    carry = (beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, noises, T+T_forecast)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    # Loading vector mapping the 2-dim latent state to the scalar observation.
    c = numpyro.sample(name="c",
                       fn=dist.Normal(loc=jnp.array([[0.], [0.]]), scale=jnp.ones((2,1))))
    obs_mean = jnp.dot(z_collection[:T,:], c).squeeze()
    pred_mean = jnp.dot(z_collection[T:,:], c).squeeze()

    # Sample the observed y (y_obs)
    numpyro.sample(name="y_obs", fn=dist.Normal(loc=obs_mean, scale=sigma), obs=obs)
    numpyro.sample(name="y_pred", fn=dist.Normal(loc=pred_mean, scale=sigma), obs=None)
def model(X, Y, hypers):
    """Sparse GP regression (jnp variant): hierarchical shrinkage on input
    scales, with sigma**2 used directly as the observation-noise variance.

    :param X: (N, P) design matrix.
    :param Y: (N,) observed responses (or None).
    :param hypers: dict with keys 'expected_sparsity', 'alpha1', 'beta1',
        'alpha2', 'beta2', 'alpha3', 'c'.
    """
    S, P, N = hypers["expected_sparsity"], X.shape[1], X.shape[0]

    # Global shrinkage scale, calibrated to an expected sparsity of S inputs.
    sigma = numpyro.sample("sigma", dist.HalfNormal(hypers["alpha3"]))
    phi = sigma * (S / jnp.sqrt(N)) / (P - S)
    eta1 = numpyro.sample("eta1", dist.HalfCauchy(phi))

    msq = numpyro.sample("msq", dist.InverseGamma(hypers["alpha1"], hypers["beta1"]))
    xisq = numpyro.sample("xisq", dist.InverseGamma(hypers["alpha2"], hypers["beta2"]))

    # Scale for second-order interactions.
    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq

    # Local per-dimension shrinkage; kappa rescales each input column.
    lam = numpyro.sample("lambda", dist.HalfCauchy(jnp.ones(P)))
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    # compute kernel (`kernel` defined elsewhere); sigma doubles as obs noise
    kX = kappa * X
    k = kernel(kX, kX, eta1, eta2, hypers["c"]) + sigma**2 * jnp.eye(N)
    assert k.shape == (N, N)

    # sample Y according to the standard gaussian process formula
    numpyro.sample(
        "Y",
        dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=k),
        obs=Y,
    )
def GP(X, y): X = numpyro.deterministic("X", X) # Set informative priors on kernel hyperparameters. η = numpyro.sample("variance", dist.HalfCauchy(scale=5.0)) ℓ = numpyro.sample("length_scale", dist.Gamma(2.0, 1.0)) σ = numpyro.sample("obs_noise", dist.HalfCauchy(scale=5.0)) # Compute kernel K = rbf_kernel(X, X, η, ℓ) K = add_to_diagonal(K, σ) K = add_to_diagonal(K, wandb.config.jitter) # cholesky decomposition Lff = numpyro.deterministic("Lff", cholesky(K, lower=True)) # Sample y according to the standard gaussian process formula return numpyro.sample( "y", dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), scale_tril=Lff).expand_by( y.shape[:-1]) # for multioutput scenarios .to_event(y.ndim - 1), obs=y, )
def seir_model(initial_seir_state, num_days, infection_data, recovery_data,
               step_size=0.1):
    """SEIR epidemic model integrated with RK4, with Poisson observations.

    :param initial_seir_state: array (..., 4) of initial (S, E, I, R) counts.
    :param num_days: number of days simulated/observed.
    :param infection_data: daily infection counts or None.
    :param recovery_data: daily recovery counts or None.
    :param step_size: RK4 step in days (1/step_size steps per day).
    """
    if infection_data is not None:
        assert num_days == infection_data.shape[0]
    if recovery_data is not None:
        assert num_days == recovery_data.shape[0]

    # Weakly-informative priors over the epidemiological rates.
    beta = numpyro.sample('beta', dist.HalfCauchy(scale=1000.))    # transmission
    gamma = numpyro.sample('gamma', dist.HalfCauchy(scale=1000.))  # recovery
    sigma = numpyro.sample('sigma', dist.HalfCauchy(scale=1000.))  # incubation
    mu = numpyro.sample('mu', dist.HalfCauchy(scale=1000.))        # birth/death
    nu = np.array(0.0)  # No vaccine yet

    def seir_update(day, seir_state):
        # Standard SEIR ODE right-hand side with vital dynamics.
        s = seir_state[..., 0]
        e = seir_state[..., 1]
        i = seir_state[..., 2]
        r = seir_state[..., 3]
        n = s + e + i + r
        s_upd = mu * (n - s) - beta * (s * i / n) - nu * s
        e_upd = beta * (s * i / n) - (mu + sigma) * e
        i_upd = sigma * e - (mu + gamma) * i
        r_upd = gamma * i - mu * r + nu * s
        return np.stack((s_upd, e_upd, i_upd, r_upd), axis=-1)

    # Integrate, then keep one sub-step per day (the last of each day).
    # The +1e-3 keeps Poisson rates strictly positive.
    num_steps = int(num_days / step_size)
    sim = runge_kutta_4(seir_update, initial_seir_state, step_size, num_steps)
    sim = np.reshape(sim, (num_days, int(1 / step_size), 4))[:, -1, :] + 1e-3

    with numpyro.plate('data', num_days):
        numpyro.sample('infections', dist.Poisson(sim[:, 2]), obs=infection_data)
        numpyro.sample('recovery', dist.Poisson(sim[:, 3]), obs=recovery_data)
def multivariate_kf(T=None, T_forecast=15, obs=None):
    """Define Kalman Filter in a multivariate fashion.

    The "time-series" are correlated. To define these relationships in a
    efficient manner, the covariance matrix of h_t (or, equivalently, the
    noises) is drawn from a Cholesky decomposed matrix.

    Parameters
    ----------
    T: int
        Number of observed steps; inferred from `obs` when None.
    T_forecast: int
        Steps to forecast ahead.
    obs: np.array
        observed variables (infected, deaths...), shape (T, 2).
    """
    T = len(obs) if T is None else T
    beta = numpyro.sample(
        name="beta", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
    )
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(
        name="z_1", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
    )

    # Define LKJ prior on the innovation correlation structure.
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.0))
    Sigma_lower = jnp.matmul(
        jnp.diag(jnp.sqrt(tau)), L_Omega
    )  # lower cholesky factor of the covariance matrix
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast - 1,),
    )

    # Propagate the dynamics forward using jax.lax.scan (`f` defined elsewhere)
    carry = (beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, noises, T + T_forecast - 1)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    # Sample the observed y (y_obs) and missing y (y_mis), one pair of sites
    # per latent dimension.
    numpyro.sample(
        name="y_obs1",
        fn=dist.Normal(loc=z_collection[:T, 0], scale=sigma),
        obs=obs[:, 0],
    )
    numpyro.sample(
        name="y_pred1", fn=dist.Normal(loc=z_collection[T:, 0], scale=sigma), obs=None
    )
    numpyro.sample(
        name="y_obs2",
        fn=dist.Normal(loc=z_collection[:T, 1], scale=sigma),
        obs=obs[:, 1],
    )
    numpyro.sample(
        name="y_pred2", fn=dist.Normal(loc=z_collection[T:, 1], scale=sigma), obs=None
    )
def multih_kf(T=None, T_forecast=15, hidden=4, obs=None): """Define Kalman Filter: multiple hidden variables; just one time series. Parameters ---------- T: int T_forecast: int Times to forecast ahead. hidden: int number of variables in the latent space obs: np.array observed variable (infected, deaths...) """ # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind) T = len(obs) if T is None else T beta = numpyro.sample( name="beta", fn=dist.Normal(loc=jnp.zeros(hidden), scale=jnp.ones(hidden)) ) tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2))) sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1)) z_prev = numpyro.sample( name="z_1", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)) ) # Define LKJ prior L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.0)) Sigma_lower = jnp.matmul( jnp.diag(jnp.sqrt(tau)), L_Omega ) # lower cholesky factor of the covariance matrix noises = numpyro.sample( "noises", fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower), sample_shape=(T + T_forecast - 1,), ) # Propagate the dynamics forward using jax.lax.scan carry = (beta, z_prev, tau) z_collection = [z_prev] carry, zs_exp = lax.scan(f, carry, noises, T + T_forecast - 1) z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0) # Sample the observed y (y_obs) and missing y (y_mis) numpyro.sample( name="y_obs", fn=dist.Normal(loc=z_collection[:T, :].sum(axis=1), scale=sigma), obs=obs[:, 0], ) numpyro.sample( name="y_pred", fn=dist.Normal(loc=z_collection[T:, :].sum(axis=1), scale=sigma), obs=None )
def fulldyn_single_model(beliefs, y, mask):
    """Single-subject variant of the full-dynamics model: evolving preferences
    (lam) plus a latent AR(1) process modulating response precision gamma.

    :param beliefs: tuple of arrays; beliefs[0] has shape (T, ...),
        beliefs[2] holds observed choice indices, beliefs[-1] the initial
        preference counts c0.  # assumes this layout — TODO confirm vs caller
    :param y: observed responses conditioned into the scan's 'y' sites.
    :param mask: (T,) validity mask for trials.
    """
    T, _ = beliefs[0].shape
    c0 = beliefs[-1]

    # Baseline response-precision parameter (softplus-linked below).
    mu = npyro.sample('mu', dist.Normal(5., 5.))

    # Initial preferences: ordered first two entries via cumsum; last two tied.
    lam12 = npyro.sample('lam12', dist.HalfCauchy(1.).expand([2]).to_event(1))
    lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))
    _lam34 = jnp.expand_dims(lam34, -1)
    lam0 = npyro.deterministic(
        'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

    # Preference learning rate.
    eta = npyro.sample('eta', dist.Beta(1, 10))

    # Stationary AR(1): rho = exp(-theta); sigma gives stationary std `scale`.
    scale = npyro.sample('scale', dist.HalfNormal(1.))
    theta = npyro.sample('theta', dist.HalfCauchy(5.))
    rho = jnp.exp(-theta)
    sigma = jnp.sqrt((1 - rho**2) / (2 * theta)) * scale

    x0 = jnp.zeros(1)

    def transition_fn(carry, t):
        # One trial: compute choice logits, update preferences, advance x.
        lam_prev, x_prev = carry
        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1),
                      jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams',
            lam_prev + nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))
        noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise
        return (lam_next, x_next), None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)

    # Condition the per-trial 'y' sites on the observed responses.
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (lam_start, x0), jnp.arange(T))
def pacs_model(priors):
    """Probabilistic source-extraction model for two PACS bands.

    Source fluxes are mapped to map pixels via sparse pointing matrices; each
    band has its own background and confusion-noise term.

    :param priors: sequence of per-band prior objects providing amat_* sparse
        pointing data, flux bounds, bkg (mu, sig), snpix, snim and sim.
    """
    pointing_matrices = [([p.amat_row, p.amat_col], p.amat_data)
                         for p in priors]

    flux_lower = np.asarray([p.prior_flux_lower for p in priors]).T
    flux_upper = np.asarray([p.prior_flux_upper for p in priors]).T
    bkg_mu = np.asarray([p.bkg[0] for p in priors]).T
    bkg_sig = np.asarray([p.bkg[1] for p in priors]).T

    with numpyro.plate('bands', len(priors)):
        # NOTE(review): dist.HalfCauchy takes a single `scale` argument; the
        # second positional value 0.5 here looks erroneous (would land in
        # validate_args) — confirm the intended prior.
        sigma_conf = numpyro.sample('sigma_conf', dist.HalfCauchy(1.0, 0.5))
        bkg = numpyro.sample('bkg', dist.Normal(bkg_mu, bkg_sig))
        with numpyro.plate('nsrc', priors[0].nsrc):
            src_f = numpyro.sample('src_f', dist.Uniform(flux_lower, flux_upper))

    # Predicted map for each band: pointing matrix x fluxes, plus background.
    db_hat_psw = sp_matmul(pointing_matrices[0], src_f[:, 0][:, None],
                           priors[0].snpix).reshape(-1) + bkg[0]
    db_hat_pmw = sp_matmul(pointing_matrices[1], src_f[:, 1][:, None],
                           priors[1].snpix).reshape(-1) + bkg[1]

    # Total per-pixel noise: instrumental plus confusion, in quadrature.
    sigma_tot_psw = jnp.sqrt(
        jnp.power(priors[0].snim, 2) + jnp.power(sigma_conf[0], 2))
    sigma_tot_pmw = jnp.sqrt(
        jnp.power(priors[1].snim, 2) + jnp.power(sigma_conf[1], 2))

    with numpyro.plate('psw_pixels', priors[0].sim.size):  # as ind_psw:
        numpyro.sample("obs_psw", dist.Normal(db_hat_psw, sigma_tot_psw),
                       obs=priors[0].sim)
    with numpyro.plate('pmw_pixels', priors[1].sim.size):  # as ind_pmw:
        numpyro.sample("obs_pmw", dist.Normal(db_hat_pmw, sigma_tot_pmw),
                       obs=priors[1].sim)
def schools_model():
    """Centered hierarchical model for the eight-schools data (module-level
    `data` dict supplies J, sigma and y)."""
    # Population-level location and spread of the school effects.
    effect_loc = numpyro.sample("mu", dist.Normal(0, 5))
    effect_scale = numpyro.sample("tau", dist.HalfCauchy(5))

    # One latent treatment effect per school, drawn in a single batch.
    n_schools = data["J"]
    theta = numpyro.sample(
        "theta", dist.Normal(effect_loc, effect_scale), sample_shape=(n_schools,)
    )

    # Condition on the observed scores with known per-school noise.
    numpyro.sample("obs", dist.Normal(theta, data["sigma"]), obs=data["y"])
def schools_model():
    """Eight-schools hierarchical model, centered parameterization.

    Reads J, sigma and y from the module-level `data` mapping.
    """
    # Hyperpriors shared by all schools.
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))

    # Per-school latent effects, sampled as one batch of size J.
    school_count = data['J']
    theta = numpyro.sample('theta',
                           dist.Normal(mu, tau),
                           sample_shape=(school_count,))

    # Gaussian likelihood with fixed measurement noise.
    numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y'])
def eight_schools(J, sigma, y=None):
    """Eight-schools hierarchical model (centered parameterization).

    Fixes two defects in the original: the likelihood used an undefined
    name ``n`` in ``sample_shape=(n,)`` (a NameError at trace time), and it
    never passed observations. The plate already supplies the batch size, so
    ``sample_shape`` is dropped, and a backward-compatible ``y`` parameter
    conditions the 'scores' site on data.

    :param J: number of schools.
    :param sigma: per-school observation noise (array of length J).
    :param y: observed scores (array of length J), or None for the prior
        predictive. Defaults to None so existing callers keep working.
    """
    # Shared hyperpriors over the school effects.
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    with numpyro.plate("J", J):
        # Latent effect per school; the plate provides the batch dimension.
        theta = numpyro.sample("theta", dist.Normal(mu, tau))
        # Likelihood; observed when y is provided.
        numpyro.sample("scores", dist.Normal(theta, sigma), obs=y)
def partially_pooled_with_logit(at_bats: jnp.ndarray,
                                hits: Optional[jnp.ndarray] = None) -> None:
    """Partially pooled binomial model with a logit link.

    Every player's hit logit is drawn from a shared Normal(loc, scale)
    prior, so information is pooled across players.

    :param at_bats: number of at bats per player.
    :param hits: observed hits per player, or None for the prior predictive.
    """
    # Shared prior over the per-player logits.
    pooled_loc = numpyro.sample("loc", dist.Normal(-1, 1))
    pooled_scale = numpyro.sample("scale", dist.HalfCauchy(1))

    n_players = at_bats.shape[0]
    with numpyro.plate("num_players", n_players):
        alpha = numpyro.sample("alpha", dist.Normal(pooled_loc, pooled_scale))
        numpyro.sample("obs", dist.Binomial(at_bats, logits=alpha), obs=hits)
def model(X, Y):
    """Bayesian linear regression with a Gaussian likelihood.

    :param X: (N, P) design matrix.
    :param Y: (N,) responses, or None for the prior predictive.
    """
    n_features = X.shape[1]

    # Observation noise and standard-normal priors on the weights.
    noise_scale = numpyro.sample("sigma", dist.HalfCauchy(1.0))
    weights = numpyro.sample(
        "beta", dist.Normal(jnp.zeros(n_features), jnp.ones(n_features))
    )

    # Row-wise linear predictor.
    predictor = jnp.sum(weights * X, axis=-1)
    numpyro.sample("obs", dist.Normal(predictor, noise_scale), obs=Y)
def numpyro_model(X, y):
    """Sparse GP (VFE) model; hyperparameters are sampled or declared as
    optimizable params depending on the closed-over `inference` string.

    Fixes two defects in the original:
    - ``inference == "map" or "vi_mf" or "vi_full"`` was always truthy
      (the bare strings are truthy), so the sampling branch ran for every
      scheme and the "mll"/error branches were unreachable. Replaced with a
      proper membership test.
    - the misspelled keyword ``onstraints=`` on the "obs_noise" param is
      corrected to ``constraints=``, matching the sibling params.

    :param X: training inputs.
    :param y: training targets.
    :return: the numpyro model produced by the SGPVFE object.
    """
    if inference in ("map", "vi_mf", "vi_full"):
        # Set priors on hyperparameters.
        η = numpyro.sample("variance", dist.HalfCauchy(scale=5.0))
        ℓ = numpyro.sample(
            "length_scale", dist.Gamma(2.0, 1.0), sample_shape=(n_features,)
        )
        σ = numpyro.sample("obs_noise", dist.HalfCauchy(scale=5.0))
    elif inference == "mll":
        # Point estimates with positivity constraints (type-II ML).
        η = numpyro.param(
            "variance", init_value=1.0, constraints=dist.constraints.positive
        )
        ℓ = numpyro.param(
            "length_scale",
            init_value=jnp.ones(n_features),
            constraints=dist.constraints.positive,
        )
        σ = numpyro.param(
            "obs_noise", init_value=0.01, constraints=dist.constraints.positive
        )
    else:
        raise ValueError(f"Unrecognized inference scheme: {inference}")

    # Inducing-point locations (always optimized).
    x_u = numpyro.param("x_u", init_value=X_u_init)

    # Kernel Function
    rbf_kernel = RBF(variance=η, length_scale=ℓ)

    # GP Model
    gp_model = SGPVFE(
        X=X,
        X_u=x_u,
        y=y,
        mean=zero_mean,
        kernel=rbf_kernel,
        obs_noise=σ,
        jitter=jitter,
    )

    # Sample y according SGP
    return gp_model.to_numpyro(y=y)
def sgt(y: jnp.ndarray, seasonality: int, future: int = 0) -> None:
    """Seasonal Global Trend (SGT) model with Student-t likelihood.

    Exponential-smoothing style recursion over level and seasonality factors,
    scanned over time; observed steps are conditioned via the 'y' sites and
    `future` extra steps are forecast into 'y_forecast'.

    :param y: observed positive time series.
    :param seasonality: period length (number of seasonal factors).
    :param future: number of steps to forecast beyond the data.
    """
    # Scale priors relative to the magnitude of the series.
    cauchy_sd = jnp.max(y) / 150

    nu = numpyro.sample("nu", dist.Uniform(2, 20))
    powx = numpyro.sample("powx", dist.Uniform(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
    offset_sigma = numpyro.sample(
        "offset_sigma",
        dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd))

    coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
    # Map Beta draw to the trend exponent range [-0.5, 1].
    pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
    pow_trend = 1.5 * pow_trend_beta - 0.5
    pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))

    # Smoothing coefficients and initial seasonal factors.
    level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
    s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))
    init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality] * 0.3))

    num_lim = y.shape[0]

    def transition_fn(
        carry: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], t: jnp.ndarray
    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:
        # One time step: expected value, likelihood, and state updates.
        level, s, moving_sum = carry
        season = s[0] * level**pow_season
        exp_val = level + coef_trend * level**pow_trend + season
        exp_val = jnp.clip(exp_val, a_min=0)
        # Beyond the observed range, feed the model its own expectation.
        y_t = jnp.where(t >= num_lim, exp_val, y[t])

        # Rolling sum over one season for the level update.
        moving_sum = moving_sum + y[t] - jnp.where(t >= seasonality,
                                                   y[t - seasonality], 0.0)
        level_p = jnp.where(t >= seasonality, moving_sum / seasonality,
                            y_t - season)
        level = level_sm * level_p + (1 - level_sm) * level
        level = jnp.clip(level, a_min=0)

        # Update the active seasonal factor and rotate the factor vector.
        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        # frozen (no update) after the observed range
        new_s = jnp.where(t >= num_lim, s[0], new_s)
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)

        # Heteroscedastic Student-t observation noise.
        omega = sigma * exp_val**powx + offset_sigma
        y_ = numpyro.sample("y", dist.StudentT(nu, exp_val, omega))
        return (level, s, moving_sum), y_

    level_init = y[0]
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init

    # Condition the scanned 'y' sites on the observations from step 1 on.
    with numpyro.handlers.condition(data={"y": y[1:]}):
        _, ys = scan(transition_fn, (level_init, s_init, moving_sum),
                     jnp.arange(1, num_lim + future))
    numpyro.deterministic("y_forecast", ys)
def model_bernoulli_likelihood(X, Y):
    """Logistic regression with a horseshoe prior on the coefficients.

    :param X: (N, D) design matrix.
    :param Y: (N,) binary responses, or None for the prior predictive.
    """
    n_features = X.shape[1]

    # Horseshoe prior: per-coefficient local scales plus one global scale.
    local_scales = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(n_features)))
    global_scale = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

    # Non-centered parameterization: draw unit-scale coefficients and rescale.
    # This coordinate transformation improves posterior geometry and makes
    # NUTS sampling more efficient.
    raw_betas = numpyro.sample("unscaled_betas",
                               dist.Normal(0.0, jnp.ones(n_features)))
    betas = numpyro.deterministic("betas",
                                  global_scale * local_scales * raw_betas)

    # Linear logits, then Bernoulli likelihood on the observations.
    linear_logits = jnp.dot(X, betas)
    numpyro.sample("Y", dist.Bernoulli(logits=linear_logits), obs=Y)
def model_w_c(T, T_forecast, x, obs=None):
    """Two-dimensional latent state-space model with covariates (via W) and a
    learned 2x1 loading vector c projecting the state to the scalar
    observation.

    :param T: number of observed steps.
    :param T_forecast: steps to forecast ahead (sampled as 'y_pred').
    :param x: covariate sequence consumed by the external transition
        function `f` together with the innovations.
    :param obs: observed series, or None.
    """
    # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind)
    W = numpyro.sample(name="W",
                       fn=dist.Normal(loc=jnp.zeros((2, 4)),
                                      scale=jnp.ones((2, 4))))
    beta = numpyro.sample(name="beta",
                          fn=dist.Normal(loc=jnp.array([0.0, 0.0]),
                                         scale=jnp.ones(2)))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(name="z_1",
                            fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)))

    # Define LKJ prior over the innovation correlation structure.
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.0))
    Sigma_lower = jnp.matmul(
        jnp.diag(jnp.sqrt(tau)),
        L_Omega)  # lower cholesky factor of the covariance matrix
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast, ),
    )

    # Propagate the dynamics forward using jax.lax.scan
    carry = (W, beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, (x, noises), T + T_forecast)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    # Loading vector mapping the 2-dim latent state to the scalar observation.
    c = numpyro.sample(name="c",
                       fn=dist.Normal(loc=jnp.array([[0.0], [0.0]]),
                                      scale=jnp.ones((2, 1))))
    obs_mean = jnp.dot(z_collection[:T, :], c).squeeze()
    pred_mean = jnp.dot(z_collection[T:, :], c).squeeze()

    # Sample the observed y (y_obs)
    numpyro.sample(name="y_obs", fn=dist.Normal(loc=obs_mean, scale=sigma), obs=obs)
    numpyro.sample(name="y_pred", fn=dist.Normal(loc=pred_mean, scale=sigma), obs=None)
def prefdyn_single_model(beliefs, y, mask):
    """Preference-dynamics model for a single subject: preferences (lam)
    accumulate over trials; response precision gamma is a static parameter
    (no latent AR(1) process, unlike fulldyn_single_model).

    :param beliefs: tuple of arrays; beliefs[0] has shape (T, ...),
        beliefs[2] holds observed choice indices, beliefs[-1] the initial
        preference counts c0.  # assumes this layout — TODO confirm vs caller
    :param y: observed responses conditioned into the scan's 'y' sites.
    :param mask: (T,) validity mask for trials.
    """
    T, _ = beliefs[0].shape
    c0 = beliefs[-1]

    # Initial preferences: ordered first two entries via cumsum; last two tied.
    lam12 = npyro.sample('lam12', dist.HalfCauchy(1.).expand([2]).to_event(1))
    lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))
    _lam34 = jnp.expand_dims(lam34, -1)
    lam0 = npyro.deterministic(
        'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

    # Learning rate and static response precision.
    eta = npyro.sample('eta', dist.Beta(1, 10))
    gamma = npyro.sample('gamma', dist.InverseGamma(2., 2.))

    def transition_fn(carry, t):
        # One trial: choice logits from normalized log-preferences, then
        # a masked preference update for the chosen option.
        lam_prev = carry

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1),
                      jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams',
            lam_prev + nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))
        return lam_next, None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)

    # Condition the per-trial 'y' sites on the observed responses.
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, lam_start, jnp.arange(T))
def model_normal_likelihood(X, Y):
    """Linear regression with a horseshoe prior and Gaussian noise.

    :param X: (N, D) design matrix.
    :param Y: (N,) responses, or None for the prior predictive.
    """
    n_features = X.shape[1]

    # Horseshoe prior: local (per-coefficient) and global shrinkage scales.
    local_scales = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(n_features)))
    global_scale = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

    # For a normal likelihood one could integrate the coefficients out (as in
    # sparse_regression.py), but that trick doesn't carry over to other
    # likelihoods (e.g. bernoulli), so the coefficients stay explicit here.
    # Non-centered parameterization for better posterior geometry:
    raw_betas = numpyro.sample("unscaled_betas",
                               dist.Normal(0.0, jnp.ones(n_features)))
    betas = numpyro.deterministic("betas",
                                  global_scale * local_scales * raw_betas)

    # Mean function from the linear coefficients.
    mean = jnp.dot(X, betas)

    # Noise parameterized through a Gamma prior on the precision.
    precision = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    noise_scale = 1.0 / jnp.sqrt(precision)

    numpyro.sample("Y", dist.Normal(mean, noise_scale), obs=Y)
def partially_pooled_with_logit(at_bats, hits=None):
    r"""
    Number of hits has a Binomial distribution with a logit link function.
    The logits $\alpha$ for each player share a common Normal prior whose
    mean and scale are themselves sampled, partially pooling across players.

    :param (jnp.DeviceArray) at_bats: Number of at bats for each player.
    :param (jnp.DeviceArray) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    # Hyperpriors shared by every player's logit.
    shared_loc = numpyro.sample("loc", dist.Normal(-1, 1))
    shared_scale = numpyro.sample("scale", dist.HalfCauchy(1))

    n_players = at_bats.shape[0]
    with numpyro.plate("num_players", n_players):
        alpha = numpyro.sample("alpha", dist.Normal(shared_loc, shared_scale))
        return numpyro.sample("obs", dist.Binomial(at_bats, logits=alpha),
                              obs=hits)
def model_noncentered(num: int,
                      sigma: np.ndarray,
                      y: Optional[np.ndarray] = None) -> None:
    """Non-centered eight-schools model via TransformReparam.

    theta is declared as an affine transform of a standard normal so the
    reparam handler samples the auxiliary variable instead of theta itself.

    :param num: number of schools.
    :param sigma: per-school observation noise.
    :param y: observed scores, or None for the prior predictive.
    """
    # Shared hyperpriors.
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))

    with numpyro.plate("num", num):
        # Reparameterize theta: sample a unit normal, shift/scale by (mu, tau).
        with numpyro.handlers.reparam(config={"theta": TransformReparam()}):
            base = dist.Normal(0.0, 1.0)
            shift_scale = dist.transforms.AffineTransform(mu, tau)
            theta = numpyro.sample(
                "theta", dist.TransformedDistribution(base, shift_scale)
            )
        numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
def partially_pooled_with_logit(at_bats, hits=None):
    r"""
    Number of hits has a Binomial distribution with a logit link function.
    The logits $\alpha$ for each player is normally distributed with the
    mean and scale parameters sharing a common prior.

    :param (np.DeviceArray) at_bats: Number of at bats for each player.
    :param (np.DeviceArray) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    n_players = at_bats.shape[0]

    # Shared (batched, shape (1,)) hyperpriors over the per-player logits.
    loc = sample("loc", dist.Normal(np.array([-1.]), np.array([1.])))
    scale = sample("scale", dist.HalfCauchy(np.array([1.])))

    # Broadcast the shared parameters so the last dim is one entry per player.
    target_shape = np.shape(loc)[:np.ndim(loc) - 1] + (n_players,)
    alpha = sample(
        "alpha",
        dist.Normal(np.broadcast_to(loc, target_shape),
                    np.broadcast_to(scale, target_shape)))

    return sample("obs", dist.Binomial(at_bats, logits=alpha), obs=hits)
def yearday_effect(day_of_year): slab_df = 50 # 100 in original case study slab_scale = 2 scale_global = 0.1 tau = sample("tau", dist.HalfNormal( 2 * scale_global)) # Original uses half-t with 100df c_aux = sample("c_aux", dist.InverseGamma(0.5 * slab_df, 0.5 * slab_df)) c = slab_scale * jnp.sqrt(c_aux) # Jan 1st: Day 0 # Feb 29th: Day 59 # Dec 31st: Day 365 with plate("plate_day_of_year", 366): lam = sample("lam", dist.HalfCauchy(scale=1)) lam_tilde = jnp.sqrt(c) * lam / jnp.sqrt(c + (tau * lam)**2) beta = sample("beta", dist.Normal(loc=0, scale=tau * lam_tilde)) return beta[day_of_year]
def ar_k(n_coefs, obs=None, X=None):
    """AR(k) autoregressive model with coefficients constrained to (-1, 1).

    The external `scan_fn` rolls the AR recursion forward; the observed
    series is conditioned through the 'Z' site.

    :param n_coefs: AR order k (number of lag coefficients).
    :param obs: observed series, or None for the prior predictive.
    :param X: unused here.  # NOTE(review): parameter accepted but never read
    :return: tuple (expected latent trajectory, last window of observations).
    """
    # AR coefficients, transformed into the stationarity-friendly (-1, 1) box.
    beta = numpyro.sample(name="beta",
                          sample_shape=(n_coefs, ),
                          fn=dist.TransformedDistribution(
                              dist.Normal(loc=0., scale=1),
                              transforms=dist.transforms.AffineTransform(
                                  loc=0,
                                  scale=1,
                                  domain=dist.constraints.interval(-1, 1))))
    # Innovation scale.
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=1))
    # Initial window of latent values to seed the recursion.
    z_init = numpyro.sample(name='z_init',
                            fn=dist.Normal(0, 1),
                            sample_shape=(n_coefs, ))
    obs_init = z_init[:n_coefs - 1]

    # Roll the AR dynamics forward (scan_fn defined elsewhere in this module).
    (beta, obs_last), zs_exp = scan_fn(n_coefs, beta, obs_init, obs)

    Z_exp = np.concatenate((z_init, zs_exp), axis=0)
    # Likelihood: observations centered on the expected trajectory.
    Z = numpyro.sample(name="Z", fn=dist.Normal(loc=Z_exp, scale=tau), obs=obs)
    return Z_exp, obs_last
def model(y_obs):
    """Infer the location and scale of normally distributed observations.

    :param y_obs: observed data, or None for the prior predictive.
    """
    location = numpyro.sample('mu', dist.Normal(0., 1.))
    spread = numpyro.sample("sigma", dist.HalfCauchy(3.))
    numpyro.sample("y", dist.Normal(location, spread), obs=y_obs)
def horseshoe_model(
        y_vals,
        gid,
        cid,
        N,  # array of number of y_vals in each gene
        slab_df=1,
        slab_scale=1,
        expected_large_covar_num=5,  # expected large covar num here is the prior on the number of conditions we expect to affect expression of a given gene
        condition_intercept=False):
    """Gene-by-condition expression model with a per-gene Finnish horseshoe
    prior on condition effects.

    :param y_vals: observed log2 signal values.
    :param gid: gene index per observation.
    :param cid: condition index per observation.
    :param N: array of number of y_vals in each gene.
    :param slab_df: slab degrees of freedom for the regularized horseshoe.
    :param slab_scale: slab scale.
    :param expected_large_covar_num: prior count of conditions expected to
        affect a given gene's expression.
    :param condition_intercept: if True, add a per-condition intercept.
    """
    gene_count = gid.max() + 1
    condition_count = cid.max() + 1

    # separate regularizing prior on intercept for each gene
    a_prior = dist.Normal(10., 10.)
    a = numpyro.sample("alpha", a_prior, sample_shape=(gene_count, ))

    # implement Finnish horseshoe
    half_slab_df = slab_df / 2
    variance = y_vals.var()
    slab_scale2 = slab_scale**2
    hs_shape = (gene_count, condition_count)

    # set up "local" horseshoe priors for each gene and condition
    beta_tilde = numpyro.sample(
        'beta_tilde', dist.Normal(0., 1.), sample_shape=hs_shape
    )  # beta_tilde contains betas for all hs parameters
    lambd = numpyro.sample(
        'lambd', dist.HalfCauchy(1.),
        sample_shape=hs_shape)  # lambd contains lambda for each hs covariate

    # set up global hyperpriors.
    # each gene gets its own hyperprior for regularization of large effects to
    # keep the sampling from wandering unfettered from 0.
    tau_tilde = numpyro.sample('tau_tilde',
                               dist.HalfCauchy(1.),
                               sample_shape=(gene_count, 1))
    c2_tilde = numpyro.sample('c2_tilde',
                              dist.InverseGamma(half_slab_df, half_slab_df),
                              sample_shape=(gene_count, 1))

    # Combine the tilde variables into actual effect sizes (helper defined
    # elsewhere in this module).
    bC = finnish_horseshoe(
        M=hs_shape[1],  # total number of conditions
        m0=expected_large_covar_num,  # number of condition we expect to affect expression of a given gene
        N=N,  # number of observations for the gene
        var=variance,
        half_slab_df=half_slab_df,
        slab_scale2=slab_scale2,
        tau_tilde=tau_tilde,
        c2_tilde=c2_tilde,
        lambd=lambd,
        beta_tilde=beta_tilde)

    # Record the condition effects as an (observed Delta) site for tracing.
    numpyro.sample("b_condition", dist.Delta(bC), obs=bC)

    if condition_intercept:
        a_C_prior = dist.Normal(0., 1.)
        a_C = numpyro.sample('a_condition',
                             a_C_prior,
                             sample_shape=(condition_count, ))
        mu = a[gid] + a_C[cid] + bC[gid, cid]
    else:
        # calculate implied log2(signal) for each gene/condition
        # by adding each gene's intercept (a) to each of that gene's
        # condition effects (bC).
        mu = a[gid] + bC[gid, cid]

    sig_prior = dist.Exponential(1.)
    sigma = numpyro.sample('sigma', sig_prior)

    return numpyro.sample('obs', dist.Normal(mu, sigma), obs=y_vals)