def model(X, Y, hypers):
    # note: `np` in this snippet is jax.numpy; `kernel` is assumed to be
    # defined alongside the model
    S, P, N = hypers['expected_sparsity'], X.shape[1], X.shape[0]

    sigma = numpyro.sample("sigma", dist.HalfNormal(hypers['alpha3']))
    phi = sigma * (S / np.sqrt(N)) / (P - S)
    eta1 = numpyro.sample("eta1", dist.HalfCauchy(phi))

    msq = numpyro.sample("msq", dist.InverseGamma(hypers['alpha1'], hypers['beta1']))
    xisq = numpyro.sample("xisq", dist.InverseGamma(hypers['alpha2'], hypers['beta2']))

    eta2 = np.square(eta1) * np.sqrt(xisq) / msq

    lam = numpyro.sample("lambda", dist.HalfCauchy(np.ones(P)))
    kappa = np.sqrt(msq) * lam / np.sqrt(msq + np.square(eta1 * lam))

    # sample observation noise
    var_obs = numpyro.sample(
        "var_obs", dist.InverseGamma(hypers['alpha_obs'], hypers['beta_obs']))

    # compute kernel
    kX = kappa * X
    k = kernel(kX, kX, eta1, eta2, hypers['c']) + var_obs * np.eye(N)
    assert k.shape == (N, N)

    # sample Y according to the standard gaussian process formula
    numpyro.sample("Y",
                   dist.MultivariateNormal(loc=np.zeros(X.shape[0]),
                                           covariance_matrix=k),
                   obs=Y)
def model(X, Y, hypers): S, P, N = hypers["expected_sparsity"], X.shape[1], X.shape[0] sigma = numpyro.sample("sigma", dist.HalfNormal(hypers["alpha3"])) phi = sigma * (S / jnp.sqrt(N)) / (P - S) eta1 = numpyro.sample("eta1", dist.HalfCauchy(phi)) msq = numpyro.sample("msq", dist.InverseGamma(hypers["alpha1"], hypers["beta1"])) xisq = numpyro.sample("xisq", dist.InverseGamma(hypers["alpha2"], hypers["beta2"])) eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq lam = numpyro.sample("lambda", dist.HalfCauchy(jnp.ones(P))) kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam)) # compute kernel kX = kappa * X k = kernel(kX, kX, eta1, eta2, hypers["c"]) + sigma**2 * jnp.eye(N) assert k.shape == (N, N) # sample Y according to the standard gaussian process formula numpyro.sample( "Y", dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=k), obs=Y, )
def nondynamic_mixture_model(beliefs, y, mask):
    M, T, _ = beliefs[0].shape

    tau = .5
    weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))
    assert weights.shape == (M,)

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 5.))

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2,
            p[..., :1] / 2 + p[..., 1:2],
            p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))

    U = jnp.log(p0)

    def transition_fn(carry, t):
        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1),
                      jnp.expand_dims(U, -2))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs).mask(mask[t])
        npyro.sample('y', dist.MixtureSameFamily(mixing_dist, component_dist))

        return None, None

    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, None, jnp.arange(T))
def nondyn_single_model(beliefs, y, mask):
    T, _ = beliefs[0].shape

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 3.))

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2,
            p[..., :1] / 2 + p[..., 1:2],
            p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))

    U = jnp.log(p0)

    def transition_fn(carry, t):
        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1),
                      jnp.expand_dims(U, -2))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))

        return None, None

    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, None, jnp.arange(T))
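# Note on the scan-based models above and below: `scan` is NumPyro's effectful
# scan (not jax.lax.scan), so the 'y' sites sampled inside transition_fn are
# tracked per time step and can be conditioned on via handlers.condition;
# `logits` is a project helper assumed to be defined elsewhere. A minimal
# sketch of the imports these snippets rely on:
import jax.numpy as jnp
from jax import nn
import numpyro as npyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan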
def model(z=None) -> None:
    batch_size = 1
    if z is not None:
        batch_size = z.shape[0]
    mu = sample('mu', dists.Normal().expand_by((2,)).to_event(1))
    sigma = sample('sigma', dists.InverseGamma(1.).expand_by((2,)).to_event(1))
    with plate('batch', batch_size, batch_size):
        sample('x', dists.Normal(mu, sigma).to_event(1), obs=z)
def guide(z=None, num_obs_total=None) -> None:
    batch_size = 1
    if z is not None:
        batch_size = z.shape[0]
    if num_obs_total is None:
        num_obs_total = batch_size

    # note: `d` (the event dimension) is not defined in this function; it is
    # assumed to be a constant from the enclosing script, matching the model
    mu_param = param('mu_param', 0.)
    sample('mu', dists.Normal(mu_param, 1.).expand_by((d,)).to_event(1))
    sample('sigma', dists.InverseGamma(1.).expand_by((d,)).to_event(1))
def model(z=None, num_obs_total=None) -> None:
    batch_size = 1
    if z is not None:
        batch_size = z.shape[0]
    if num_obs_total is None:
        num_obs_total = batch_size

    # `d` and `args` are assumed to come from the enclosing script
    mu = sample('mu', dists.Normal(args.prior_mu).expand_by((d,)).to_event(1))
    sigma = sample('sigma', dists.InverseGamma(1.).expand_by((d,)).to_event(1))
    with plate('batch', num_obs_total, batch_size):
        sample('x', dists.Normal(mu, sigma).to_event(1), obs=z)
def model(x_first=None, x_second=None, num_obs_total=None) -> None:
    batch_size = 1
    if x_first is not None:
        batch_size = x_first.shape[0]
    if num_obs_total is None:
        num_obs_total = batch_size
    mu = sample('mu', dists.Normal())
    sigma = sample('sigma', dists.InverseGamma(1.))
    with plate('batch', num_obs_total, batch_size):
        sample('x_first', dists.Normal(mu, sigma), obs=x_first)
        sample('x_second', dists.Normal(mu, sigma), obs=x_second)
def model(z=None, z2=None, num_obs_total=None) -> None:
    batch_size = 1
    if z is not None:
        # z2 is only used for this consistency check
        assert z2 is not None
        assert z.shape[0] == z2.shape[0]
        batch_size = z.shape[0]
    if num_obs_total is None:
        num_obs_total = batch_size
    mu = sample('mu', dists.Normal().expand_by((2,)).to_event(1))
    sigma = sample('sigma', dists.InverseGamma(1.).expand_by((2,)).to_event(1))
    with plate('batch', num_obs_total, batch_size):
        sample('x', dists.Normal(mu, sigma).to_event(1), obs=z)
def gammadyn_mixture_model(beliefs, y, mask):
    M, T, _ = beliefs[0].shape

    tau = .5
    weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))
    assert weights.shape == (M,)

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 2.))
    mu = jnp.log(jnp.exp(gamma) - 1)  # softplus^{-1}(gamma)

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2,
            p[..., :1] / 2 + p[..., 1:2],
            p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))

    scale = npyro.sample('scale', dist.Gamma(1., 1.))
    rho = npyro.sample('rho', dist.Beta(1., 2.))
    sigma = jnp.sqrt(-(1 - rho**2) / (2 * jnp.log(rho))) * scale

    U = jnp.log(p0)

    def transition_fn(carry, t):
        x_prev = carry

        gamma_dyn = npyro.deterministic('gamma_dyn', nn.softplus(mu + x_prev))

        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma_dyn, -1),
                      jnp.expand_dims(U, -2))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs).mask(mask[t])
        npyro.sample('y', dist.MixtureSameFamily(mixing_dist, component_dist))

        with npyro.handlers.reparam(
                config={"x_next": npyro.infer.reparam.TransformReparam()}):
            affine = dist.transforms.AffineTransform(rho * x_prev, sigma)
            x_next = npyro.sample(
                'x_next',
                dist.TransformedDistribution(dist.Normal(0., 1.), affine))

        return x_next, None

    x0 = jnp.zeros(1)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, x0, jnp.arange(T))
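# Note on the reparam block above: wrapping 'x_next' in TransformReparam makes
# the sampler draw a standard-normal base variable and apply the affine map
# deterministically, i.e. the non-centered form of the AR(1) step
#     x_next = rho * x_prev + sigma * eps,   eps ~ Normal(0, 1).
# gammadyn_single_model below does the same thing by hand via the explicit
# 'dw' noise site. The choice of sigma matches the exact one-step
# discretization of an Ornstein-Uhlenbeck process with diffusion `scale` and
# decay rate -log(rho), whose transition variance is
#     scale**2 * (1 - rho**2) / (-2 * log(rho)).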
def guide(k, obs=None, num_obs_total=None, d=None):
    # guide for the Gaussian mixture; learns the variational parameters
    if obs is not None:
        assert jnp.ndim(obs) == 2
        _, d = jnp.shape(obs)
    else:
        assert num_obs_total is not None
        assert d is not None

    alpha_log = param('alpha_log', jnp.zeros(k))
    alpha = jnp.exp(alpha_log)
    pis = sample('pis', dist.Dirichlet(alpha))

    mus_loc = param('mus_loc', jnp.zeros((k, d)))
    mus = sample('mus', dist.Normal(mus_loc, 1.))
    # sigs is pinned to ones in this guide via obs
    sigs = sample('sigs', dist.InverseGamma(1., 1.), obs=jnp.ones_like(mus))
    return pis, mus, sigs
def model(k, obs=None, num_obs_total=None, d=None):
    # model function using the GaussianMixture distribution (a custom
    # distribution assumed to be defined in the surrounding project) with
    # a prior belief over its parameters
    if obs is not None:
        assert jnp.ndim(obs) == 2
        batch_size, d = jnp.shape(obs)
    else:
        assert num_obs_total is not None
        batch_size = num_obs_total
        assert d is not None
    num_obs_total = batch_size if num_obs_total is None else num_obs_total

    pis = sample('pis', dist.Dirichlet(jnp.ones(k)))
    mus = sample('mus', dist.Normal(jnp.zeros((k, d)), 10.))
    sigs = sample('sigs', dist.InverseGamma(1., 1.), sample_shape=jnp.shape(mus))
    with plate('batch', num_obs_total, batch_size):
        return sample('obs', GaussianMixture(mus, sigs, pis), obs=obs,
                      sample_shape=(batch_size,))
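# Hedged usage sketch: pairing the mixture model above with the guide defined
# earlier via SVI. `GaussianMixture` and the data array are assumptions from
# the surrounding project.
import jax.random as random
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

def fit_mixture(model, guide, k, data, num_steps=1000):
    svi = SVI(model, guide, Adam(1e-2), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), num_steps, k, obs=data)
    return svi_result.params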
def yearday_effect(day_of_year):
    slab_df = 50  # 100 in original case study
    slab_scale = 2
    scale_global = 0.1

    tau = sample("tau", dist.HalfNormal(
        2 * scale_global))  # original uses half-t with 100 df
    c_aux = sample("c_aux", dist.InverseGamma(0.5 * slab_df, 0.5 * slab_df))
    c = slab_scale * jnp.sqrt(c_aux)

    # Jan 1st:  Day 0
    # Feb 29th: Day 59
    # Dec 31st: Day 365
    with plate("plate_day_of_year", 366):
        lam = sample("lam", dist.HalfCauchy(scale=1))
        lam_tilde = jnp.sqrt(c) * lam / jnp.sqrt(c + (tau * lam)**2)
        beta = sample("beta", dist.Normal(loc=0, scale=tau * lam_tilde))

    return beta[day_of_year]
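# The scales above follow the regularized ("Finnish") horseshoe of Piironen &
# Vehtari: with local scales lam and global scale tau, the effective scale is
#     lam_tilde = sqrt(c^2 * lam^2 / (c^2 + tau^2 * lam^2)),
# which shrinks like a plain horseshoe when tau * lam << c and caps large
# effects near the slab scale c. In the snippet, `c` enters in place of c^2,
# hence the jnp.sqrt(c) in the numerator.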
def gammadyn_single_model(beliefs, y, mask):
    T, _ = beliefs[0].shape

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 2.))
    mu = jnp.log(jnp.exp(gamma) - 1)

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2,
            p[..., :1] / 2 + p[..., 1:2],
            p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))

    scale = npyro.sample('scale', dist.Gamma(1., 1.))
    rho = npyro.sample('rho', dist.Beta(1., 2.))
    sigma = jnp.sqrt(-(1 - rho**2) / (2 * jnp.log(rho))) * scale

    U = jnp.log(p0)

    def transition_fn(carry, t):
        x_prev = carry

        dyn_gamma = npyro.deterministic('dyn_gamma', nn.softplus(mu + x_prev))

        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(dyn_gamma, -1),
                      jnp.expand_dims(U, -2))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))

        noise = npyro.sample('dw', dist.Normal(0., 1.))
        x_next = rho * x_prev + sigma * noise

        return x_next, None

    x0 = jnp.zeros(1)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, x0, jnp.arange(T))
def prefdyn_single_model(beliefs, y, mask):
    T, _ = beliefs[0].shape
    c0 = beliefs[-1]

    lam12 = npyro.sample('lam12', dist.HalfCauchy(1.).expand([2]).to_event(1))
    lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))
    _lam34 = jnp.expand_dims(lam34, -1)
    lam0 = npyro.deterministic(
        'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

    eta = npyro.sample('eta', dist.Beta(1, 10))
    gamma = npyro.sample('gamma', dist.InverseGamma(2., 2.))

    def transition_fn(carry, t):
        lam_prev = carry

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1),
                      jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams',
            lam_prev + nn.one_hot(beliefs[2][t], 4) *
            jnp.expand_dims(mask[t] * eta, -1))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))

        return lam_next, None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, lam_start, jnp.arange(T))
def trend_gp(x, L, M):
    alpha = sample("alpha", dist.HalfNormal(1.0))
    length = sample("length", dist.InverseGamma(10.0, 2.0))
    f = approx_se_ncp(x, alpha, length, L, M)
    return f
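# approx_se_ncp is referenced above but not defined in this snippet. A hedged
# sketch of the usual Hilbert-space low-rank squared-exponential approximation
# with a non-centered parameterization (after Solin & Sarkka); the basis and
# constants here are assumptions, not the author's code:
def approx_se_ncp(x, alpha, length, L, M):
    m = jnp.arange(1, M + 1)
    w = m * jnp.pi / (2 * L)  # square roots of the Laplacian eigenvalues
    phi = jnp.sqrt(1 / L) * jnp.sin(w * (x[:, None] + L))  # eigenfunctions
    # square root of the SE spectral density evaluated at w
    spd = alpha * (2 * jnp.pi)**0.25 * jnp.sqrt(length) * jnp.exp(
        -0.25 * (length * w)**2)
    beta = sample("beta", dist.Normal(0., 1.).expand([M]))
    return phi @ (spd * beta)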
def model(self, t, X):
    noise = sample('noise', dist.LogNormal(0.0, 1.0), sample_shape=(self.D,))
    hyp = sample('hyp', dist.Gamma(1.0, 0.5), sample_shape=(self.D,))
    W = sample('W', dist.LogNormal(0.0, 1.0), sample_shape=(self.D,))

    # Finnish horseshoe prior on the ODE parameters
    m0 = self.M - 1
    sigma = 1
    tau0 = (m0 / (self.M - m0) * (sigma / np.sqrt(1.0 * sum(self.N))))
    tau_tilde = sample('tau_tilde', dist.HalfCauchy(1.), sample_shape=(self.D,))
    tau = np.repeat(tau0 * tau_tilde, self.M // self.D)

    slab_scale = 1
    slab_scale2 = slab_scale**2
    slab_df = 1
    half_slab_df = slab_df / 2
    c2_tilde = sample('c2_tilde', dist.InverseGamma(half_slab_df, half_slab_df))
    c2 = slab_scale2 * c2_tilde

    lambd = sample('lambd', dist.HalfCauchy(1.), sample_shape=(self.M,))
    lambd_tilde = tau**2 * c2 * lambd**2 / (c2 + tau**2 * lambd**2)
    par = sample(
        'par',
        dist.MultivariateNormal(np.zeros((self.M,)), np.diag(lambd_tilde)))

    # compute block-diagonal kernel, one RBF block per output dimension
    K_11 = W[0] * self.RBF(self.t_t[0], self.t_t[0], hyp[0]) + np.eye(
        self.N[0]) * (noise[0] + self.jitter)
    K_22 = W[1] * self.RBF(self.t_t[1], self.t_t[1], hyp[1]) + np.eye(
        self.N[1]) * (noise[1] + self.jitter)
    K_33 = W[2] * self.RBF(self.t_t[2], self.t_t[2], hyp[2]) + np.eye(
        self.N[2]) * (noise[2] + self.jitter)
    K = np.concatenate([
        np.concatenate([
            K_11,
            np.zeros((self.N[0], self.N[1])),
            np.zeros((self.N[0], self.N[2]))
        ], axis=1),
        np.concatenate([
            np.zeros((self.N[1], self.N[0])),
            K_22,
            np.zeros((self.N[1], self.N[2]))
        ], axis=1),
        np.concatenate([
            np.zeros((self.N[2], self.N[0])),
            np.zeros((self.N[2], self.N[1])),
            K_33
        ], axis=1)
    ], axis=0)

    # compute mean by integrating the ODE; `ind` is not defined in this
    # function and is assumed to come from the enclosing scope (it selects
    # the observed state components)
    mut = odeint(self.dxdt, self.x0, self.t.flatten(), par)
    mu1 = mut[self.i_t[0], ind[0]] / self.max_X[0]
    mu2 = mut[self.i_t[1], ind[1]] / self.max_X[1]
    mu3 = mut[self.i_t[2], ind[2]] / self.max_X[2]
    mu = np.concatenate((mu1, mu2, mu3), axis=0)
    mu = mu.flatten('F')

    X = np.concatenate((self.X[0], self.X[1], self.X[2]), axis=0)
    X = X.flatten('F')

    # sample X according to the standard gaussian process formula
    sample("X", dist.MultivariateNormal(loc=mu, covariance_matrix=K), obs=X)
def horseshoe_model(
        y_vals,
        gid,
        cid,
        N,  # array with the number of y_vals in each gene
        slab_df=1,
        slab_scale=1,
        expected_large_covar_num=5,  # prior on the number of conditions expected to affect expression of a given gene
        condition_intercept=False):
    gene_count = gid.max() + 1
    condition_count = cid.max() + 1

    # separate regularizing prior on the intercept for each gene
    a_prior = dist.Normal(10., 10.)
    a = numpyro.sample("alpha", a_prior, sample_shape=(gene_count,))

    # implement the Finnish horseshoe
    half_slab_df = slab_df / 2
    variance = y_vals.var()
    slab_scale2 = slab_scale**2
    hs_shape = (gene_count, condition_count)

    # set up "local" horseshoe priors for each gene and condition:
    # beta_tilde holds the betas and lambd the local scales for all
    # horseshoe covariates
    beta_tilde = numpyro.sample('beta_tilde', dist.Normal(0., 1.),
                                sample_shape=hs_shape)
    lambd = numpyro.sample('lambd', dist.HalfCauchy(1.),
                           sample_shape=hs_shape)

    # set up global hyperpriors: each gene gets its own hyperprior, so that
    # regularization of large effects keeps the sampler from wandering
    # unfettered from 0
    tau_tilde = numpyro.sample('tau_tilde', dist.HalfCauchy(1.),
                               sample_shape=(gene_count, 1))
    c2_tilde = numpyro.sample('c2_tilde',
                              dist.InverseGamma(half_slab_df, half_slab_df),
                              sample_shape=(gene_count, 1))

    bC = finnish_horseshoe(
        M=hs_shape[1],                # total number of conditions
        m0=expected_large_covar_num,  # number of conditions expected to affect a given gene
        N=N,                          # number of observations for the gene
        var=variance,
        half_slab_df=half_slab_df,
        slab_scale2=slab_scale2,
        tau_tilde=tau_tilde,
        c2_tilde=c2_tilde,
        lambd=lambd,
        beta_tilde=beta_tilde)
    numpyro.sample("b_condition", dist.Delta(bC), obs=bC)

    if condition_intercept:
        a_C_prior = dist.Normal(0., 1.)
        a_C = numpyro.sample('a_condition', a_C_prior,
                             sample_shape=(condition_count,))
        mu = a[gid] + a_C[cid] + bC[gid, cid]
    else:
        # calculate the implied log2(signal) for each gene/condition by adding
        # each gene's intercept (a) to each of that gene's condition effects (bC)
        mu = a[gid] + bC[gid, cid]

    sig_prior = dist.Exponential(1.)
    sigma = numpyro.sample('sigma', sig_prior)

    return numpyro.sample('obs', dist.Normal(mu, sigma), obs=y_vals)
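# finnish_horseshoe is referenced above but not shown. A plausible sketch that
# mirrors the inline tau0/lambd_tilde computation in the ODE model earlier in
# this section; this is an assumption about the helper, not the author's
# exact code:
def finnish_horseshoe(M, m0, N, var, half_slab_df, slab_scale2,
                      tau_tilde, c2_tilde, lambd, beta_tilde):
    # per-gene global scale; N is the per-gene observation count array
    tau0 = (m0 / (M - m0)) * (jnp.sqrt(var) / jnp.sqrt(1.0 * N))[:, None]
    tau = tau0 * tau_tilde           # shape (gene_count, 1)
    c2 = slab_scale2 * c2_tilde      # slab variance; half_slab_df only shapes
                                     # the c2_tilde prior at the call site
    lambd_tilde = jnp.sqrt(c2 * lambd**2 / (c2 + tau**2 * lambd**2))
    return tau * lambd_tilde * beta_tilde  # regularized condition effects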