def model(data=None):
    with numpyro.plate("dim", 2):
        beta = numpyro.sample("beta", dist.Beta(1.0, 1.0))
    with numpyro.plate("plate", N, dim=-2):
        numpyro.deterministic("beta_sq", beta**2)
        with numpyro.plate("dim", 2):
            numpyro.sample("obs", dist.Bernoulli(beta), obs=data)
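# A minimal usage sketch for the model above (not part of the original
# snippet). It assumes `import jax.numpy as jnp`, the numpyro imports used
# throughout this collection, a toy binary dataset, and that the free
# variable N is defined in the enclosing scope. Note that MCMC's
# get_samples() records the "beta_sq" deterministic site alongside "beta".
from jax import random
from numpyro.infer import MCMC, NUTS

N = 10
data = jnp.ones((N, 2))
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), data=data)
samples = mcmc.get_samples()  # contains both "beta" and "beta_sq"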
def GP(X, y):
    X = numpyro.deterministic("X", X)

    # Set informative priors on kernel hyperparameters.
    η = numpyro.sample("variance", dist.HalfCauchy(scale=5.0))
    ℓ = numpyro.sample("length_scale", dist.Gamma(2.0, 1.0))
    σ = numpyro.sample("obs_noise", dist.HalfCauchy(scale=5.0))

    # Compute the kernel, adding observation noise and jitter to the diagonal.
    K = rbf_kernel(X, X, η, ℓ)
    K = add_to_diagonal(K, σ)
    K = add_to_diagonal(K, wandb.config.jitter)

    # Cholesky decomposition
    Lff = numpyro.deterministic("Lff", cholesky(K, lower=True))

    # Sample y according to the standard Gaussian process formula.
    return numpyro.sample(
        "y",
        dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), scale_tril=Lff)
        .expand_by(y.shape[:-1])  # for multi-output scenarios
        .to_event(y.ndim - 1),
        obs=y,
    )
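# The GP model above relies on helpers that are not shown. A hypothetical
# sketch of what they might look like: rbf_kernel and add_to_diagonal are
# assumptions inferred from their call sites, `cholesky` would come from
# jax.scipy.linalg, and wandb.config.jitter is read from an active
# Weights & Biases run.
from jax.scipy.linalg import cholesky

def rbf_kernel(X1, X2, variance, length_scale):
    # Squared-exponential kernel: k(x, x') = variance * exp(-||x - x'||^2 / (2 l^2)).
    sq_dists = jnp.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dists / length_scale**2)

def add_to_diagonal(K, value):
    # Add a scalar (noise or jitter) to the diagonal of a square matrix.
    return K + value * jnp.eye(K.shape[0])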
def SEIR_dynamics(T, params, x0, obs=None, death=None, suffix=""):
    '''Run SEIR dynamics for T time steps

    Uses SEIRModel.run to run dynamics with pre-determined parameters.
    '''
    beta0, sigma, gamma, rw_scale, drift, \
        det_prob, det_noise_scale, death_prob, death_rate, det_prob_d = params

    beta = numpyro.sample(
        "beta" + suffix,
        ExponentialRandomWalk(loc=beta0,
                              scale=rw_scale,
                              drift=drift,
                              num_steps=T - 1))

    # Run ODE
    x = SEIRModel.run(T, x0, (beta, sigma, gamma, death_prob, death_rate))
    x = x[1:]  # first entry duplicates x0
    numpyro.deterministic("x" + suffix, x)

    # Noisy observations
    with numpyro.handlers.scale(scale=0.5):
        y = observe("y" + suffix, x[:, 6], det_prob, det_noise_scale, obs=obs)

    with numpyro.handlers.scale(scale=2.0):
        z = observe("z" + suffix, x[:, 5], det_prob_d, det_noise_scale, obs=death)

    return beta, x, y, z
def dynamics(self, T, params, x0, confirmed=None, death=None,
             det_rate=None, suffix=""):
    '''Run SEIRD dynamics for T time steps'''

    beta0, sigma, gamma, rw_scale, drift, \
        det_prob, det_noise_scale, death_prob, death_rate, \
        det_prob_d, det_prob_future = params

    beta = numpyro.sample(
        "beta" + suffix,
        ExponentialRandomWalk(loc=beta0,
                              scale=rw_scale,
                              drift=drift,
                              num_steps=T - 1))

    if suffix != "_future":
        det_prob_rw = numpyro.sample(
            "det_rate_rw" + suffix,
            LogisticRandomWalk(loc=det_prob,
                               scale=rw_scale,
                               drift=0,
                               num_steps=T - 1))
    else:
        det_prob_rw = det_prob_future

    # Run ODE
    x = SEIRDModel.run(T, x0, (beta, sigma, gamma, death_prob, death_rate))
    x = x[1:]  # first entry duplicates x0
    numpyro.deterministic("x" + suffix, x)

    # Noisy observations
    with numpyro.handlers.scale(scale=0.5):
        y = observe("y" + suffix, x[:, 6], det_prob_rw, det_noise_scale,
                    obs=confirmed)

    with numpyro.handlers.scale(scale=2.0):
        z = observe("z" + suffix, x[:, 5], det_prob_d, det_noise_scale,
                    obs=death)

    return beta, det_prob_rw, x, y, z
def fulldyn_mixture_model(beliefs, y, mask):
    M, T, N, _ = beliefs[0].shape
    c0 = beliefs[-1]

    tau = .5
    with npyro.plate('N', N):
        weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))

        assert weights.shape == (N, M)

        mu = npyro.sample('mu', dist.Normal(5., 5.))

        lam12 = npyro.sample('lam12',
                             dist.HalfCauchy(1.).expand([2]).to_event(1))
        lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))

        _lam34 = jnp.expand_dims(lam34, -1)
        lam0 = npyro.deterministic(
            'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

        eta = npyro.sample('eta', dist.Beta(1, 10))

        scale = npyro.sample('scale', dist.HalfNormal(1.))
        theta = npyro.sample('theta', dist.HalfCauchy(5.))

    rho = jnp.exp(-theta)
    sigma = jnp.sqrt((1 - rho**2) / (2 * theta)) * scale

    x0 = jnp.zeros(N)

    def transition_fn(carry, t):
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))

        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1),
                      jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams',
            lam_prev + nn.one_hot(beliefs[2][t], 4) *
            jnp.expand_dims(mask[t] * eta, -1))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(
            mask[t][..., None])
        with npyro.plate('subjects', N):
            y = npyro.sample(
                'y', dist.MixtureSameFamily(mixing_dist, component_dist))

            noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise

        return (lam_next, x_next), None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (lam_start, x0), jnp.arange(T))
def transition(state, i):
    x0, mu0 = state
    x1 = numpyro.sample('x', dist.Normal(phi * x0, q))
    mu1 = beta * mu0 + x1
    y1 = numpyro.sample('y', dist.Normal(mu1, r))
    numpyro.deterministic('y2', y1 * 2)
    return (x1, mu1), (x1, y1)
def model_nested_plates_3():
    outer = numpyro.plate("outer", 10, dim=-1)
    inner = numpyro.plate("inner", 5, dim=-2)
    numpyro.deterministic("z", 1.0)

    with inner, outer:
        xy = numpyro.sample("xy", dist.Normal(jnp.zeros((5, 10)), 1.0))
        assert xy.shape == (5, 10)
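# A quick shape check for the nested-plate model above (a sketch, assuming
# the standard numpyro imports): tracing under a seed handler exposes the
# plate-broadcast value of the "xy" site.
from numpyro import handlers

with handlers.seed(rng_seed=0):
    tr = handlers.trace(model_nested_plates_3).get_trace()
print(tr["xy"]["value"].shape)  # (5, 10)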
def model(y=None):
    n = y.size if y is not None else 1

    mu = numpyro.sample("mu", dist.Normal(0, 5))
    sigma = numpyro.param("sigma", 1, constraint=constraints.positive)
    y = numpyro.sample("y", dist.Normal(mu, sigma).expand((n,)), obs=y)
    numpyro.deterministic("z", (y - mu) / sigma)
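# Because "sigma" is declared with numpyro.param, this model is intended for
# SVI, where params are optimized rather than sampled. A minimal sketch with
# an AutoNormal guide (the toy data and optimizer settings are assumptions):
from jax import random
from numpyro import optim
from numpyro.infer import SVI, Trace_ELBO, autoguide

guide = autoguide.AutoNormal(model)
svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, y=jnp.ones(5))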
def dynamics(self, T, params, x0, num_frozen=0, confirmed=None, death=None,
             suffix=""):
    '''Run SEIRD dynamics for T time steps'''

    (beta0, sigma, gamma, rw_scale, drift, det_prob0,
     confirmed_dispersion, death_dispersion,
     death_prob, death_rate, det_prob_d) = params

    rw = frozen_random_walk("rw" + suffix,
                            num_steps=T - 1,
                            num_frozen=num_frozen)

    beta = numpyro.deterministic("beta", beta0 * np.exp(rw_scale * rw))

    det_prob = numpyro.sample(
        "det_prob" + suffix,
        LogisticRandomWalk(loc=det_prob0,
                           scale=rw_scale,
                           drift=0,
                           num_steps=T - 1))

    # Run ODE
    x = SEIRDModel.run(T, x0, (beta, sigma, gamma, death_prob, death_rate))

    numpyro.deterministic("x" + suffix, x[1:])

    x_diff = np.diff(x, axis=0)

    # Noisy observations
    with numpyro.handlers.scale(scale=0.5):
        y = observe_nb2("dy" + suffix, x_diff[:, 6], det_prob,
                        confirmed_dispersion, obs=confirmed)

    with numpyro.handlers.scale(scale=2.0):
        z = observe_nb2("dz" + suffix, x_diff[:, 5], det_prob_d,
                        death_dispersion, obs=death)

    return beta, det_prob, x, y, z
def model():
    transform = hk.transform_with_state if batchnorm else hk.transform
    nn = haiku_module("nn", transform(fn), apply_rng=dropout, input_shape=(4, 3))
    x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
    if dropout:
        y = nn(numpyro.prng_key(), x)
    else:
        y = nn(x)
    numpyro.deterministic("y", y)
def SEIR_dynamics_hierarchical(T, params, x0, obs=None, death=None,
                               use_rw=True, suffix=""):
    '''Run SEIR dynamics for T time steps

    Uses SEIRModel.run to run dynamics with pre-determined parameters.
    '''
    beta, sigma, gamma, det_rate, det_noise_scale, rw_loc, rw_scale, drift, \
        death_prob, death_rate, det_prob_d = params

    num_places, T_minus_1 = beta.shape
    assert T_minus_1 == T - 1

    # prep for broadcasting over time
    sigma = sigma[:, None]
    gamma = gamma[:, None]
    det_rate = det_rate[:, None]
    death_prob = death_prob[:, None]
    death_rate = death_rate[:, None]
    det_prob_d = det_prob_d[:, None]

    if use_rw:
        with numpyro.plate("places", num_places):
            rw = numpyro.sample(
                "rw" + suffix,
                ExponentialRandomWalk(loc=rw_loc,
                                      scale=rw_scale,
                                      drift=drift,
                                      num_steps=T - 1))
    else:
        rw = rw_loc

    beta *= rw

    # Run ODE
    apply_model = lambda x0, beta, sigma, gamma, death_prob, death_rate: \
        SEIRDModel.run(T, x0, (beta, sigma, gamma, death_prob, death_rate))
    x = jax.vmap(apply_model)(x0, beta, sigma, gamma, death_prob, death_rate)

    x = x[:, 1:, :]  # drop first time step from result (duplicates initial value)
    numpyro.deterministic("x" + suffix, x)

    # Noisy observations
    y = observe("y" + suffix, x[:, :, 6], det_rate, det_noise_scale, obs=obs)
    z = observe("z" + suffix, x[:, :, 5], det_prob_d, det_noise_scale, obs=death)

    return rw, x, y, z
def sgt(y: jnp.ndarray, seasonality: int, future: int = 0) -> None:
    cauchy_sd = jnp.max(y) / 150

    nu = numpyro.sample("nu", dist.Uniform(2, 20))
    powx = numpyro.sample("powx", dist.Uniform(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
    offset_sigma = numpyro.sample(
        "offset_sigma",
        dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd))

    coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
    pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
    pow_trend = 1.5 * pow_trend_beta - 0.5
    pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))

    level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
    s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))
    init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality] * 0.3))

    num_lim = y.shape[0]

    def transition_fn(
        carry: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], t: jnp.ndarray
    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:
        level, s, moving_sum = carry
        season = s[0] * level**pow_season
        exp_val = level + coef_trend * level**pow_trend + season
        exp_val = jnp.clip(exp_val, a_min=0)
        y_t = jnp.where(t >= num_lim, exp_val, y[t])

        moving_sum = moving_sum + y[t] - jnp.where(t >= seasonality,
                                                   y[t - seasonality], 0.0)
        level_p = jnp.where(t >= seasonality, moving_sum / seasonality,
                            y_t - season)
        level = level_sm * level_p + (1 - level_sm) * level
        level = jnp.clip(level, a_min=0)

        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        new_s = jnp.where(t >= num_lim, s[0], new_s)
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)

        omega = sigma * exp_val**powx + offset_sigma
        y_ = numpyro.sample("y", dist.StudentT(nu, exp_val, omega))

        return (level, s, moving_sum), y_

    level_init = y[0]
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    with numpyro.handlers.condition(data={"y": y[1:]}):
        _, ys = scan(transition_fn, (level_init, s_init, moving_sum),
                     jnp.arange(1, num_lim + future))

    numpyro.deterministic("y_forecast", ys)
def model(X, Y):
    N, P = X.shape

    log_sigma = numpyro.sample("log_sigma", dist.Normal(1.0))
    sigma = jnp.exp(log_sigma)
    beta = numpyro.sample("beta", dist.Normal(jnp.zeros(P), jnp.ones(P)))
    mean = jnp.sum(beta * X, axis=-1)

    numpyro.deterministic("mean", mean)
    numpyro.sample("obs", dist.Normal(mean, sigma), obs=Y)
def dynamics(self, T, params, x0, num_frozen=0, confirmed=None, death=None,
             suffix=""):
    '''Run SEIRD dynamics for T time steps'''

    (beta0, sigma, gamma, rw_scale, drift, det_prob0, det_noise_scale,
     death_prob, death_rate, det_prob_d) = params

    rw = frozen_random_walk("rw" + suffix,
                            num_steps=T - 1,
                            num_frozen=num_frozen)

    beta = numpyro.deterministic("beta", beta0 * np.exp(rw_scale * rw))

    det_prob = numpyro.sample(
        "det_prob" + suffix,
        LogisticRandomWalk(loc=det_prob0,
                           scale=rw_scale,
                           drift=0,
                           num_steps=T - 1))

    # Run ODE
    x = SEIRDModel.run(T, x0, (beta, sigma, gamma, death_prob, death_rate))
    x = x[1:]  # first entry duplicates x0
    numpyro.deterministic("x" + suffix, x)

    # Noisy observations
    with numpyro.handlers.scale(scale=0.5):
        y = observe("y" + suffix, x[:, 6], det_prob, det_noise_scale,
                    obs=confirmed)

    with numpyro.handlers.scale(scale=2.0):
        z = observe("z" + suffix, x[:, 5], det_prob_d, det_noise_scale,
                    obs=death)

    return beta, det_prob, x, y, z
def holt_winters(y, n_seasons, future=0):
    T = y.shape[0]

    level_smoothing = numpyro.sample("level_smoothing", dist.Beta(1, 1))
    trend_smoothing = numpyro.sample("trend_smoothing", dist.Beta(1, 1))
    seasonality_smoothing = numpyro.sample("seasonality_smoothing",
                                           dist.Beta(1, 1))
    adj_seasonality_smoothing = seasonality_smoothing * (1 - level_smoothing)

    noise = numpyro.sample("noise", dist.HalfNormal(1))
    level_init = numpyro.sample("level_init", dist.Normal(0, 1))
    trend_init = numpyro.sample("trend_init", dist.Normal(0, 1))

    with numpyro.plate("n_seasons", n_seasons):
        seasonality_init = numpyro.sample("seasonality_init", dist.Normal(0, 1))

    def transition_fn(carry, t):
        previous_level, previous_trend, previous_seasonality = carry
        level = jnp.where(
            t < T,
            level_smoothing * (y[t] - previous_seasonality[0])
            + (1 - level_smoothing) * (previous_level + previous_trend),
            previous_level,
        )
        trend = jnp.where(
            t < T,
            trend_smoothing * (level - previous_level)
            + (1 - trend_smoothing) * previous_trend,
            previous_trend,
        )
        new_season = jnp.where(
            t < T,
            adj_seasonality_smoothing * (y[t] - (previous_level + previous_trend))
            + (1 - adj_seasonality_smoothing) * previous_seasonality[0],
            previous_seasonality[0],
        )
        step = jnp.where(t < T, 1, t - T + 1)
        mu = previous_level + step * previous_trend + previous_seasonality[0]
        pred = numpyro.sample("pred", dist.Normal(mu, noise))
        seasonality = jnp.concatenate(
            [previous_seasonality[1:], new_season[None]], axis=0)
        return (level, trend, seasonality), pred

    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            (level_init, trend_init, seasonality_init),
            jnp.arange(T + future),
        )

    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])
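# Forecasting sketch for holt_winters (the toy monthly series is an
# assumption): fit with MCMC, then draw the "y_forecast" deterministic site
# via Predictive, following standard numpyro usage for scan-based models.
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

y_train = 1.0 + jnp.sin(jnp.arange(48.0) * 2 * jnp.pi / 12)
mcmc = MCMC(NUTS(holt_winters), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), y_train, n_seasons=12)
predictive = Predictive(holt_winters, mcmc.get_samples(),
                        return_sites=["y_forecast"])
forecast = predictive(random.PRNGKey(1), y_train, n_seasons=12,
                      future=12)["y_forecast"]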
def model():
    net = flax_module(
        "nn",
        Net(),
        apply_rng=["dropout"] if dropout else None,
        mutable=["batch_stats"] if batchnorm else None,
        input_shape=(4, 3),
    )
    x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
    if dropout:
        y = net(x, rngs={"dropout": numpyro.prng_key()})
    else:
        y = net(x)
    numpyro.deterministic("y", y)
def _build_conditional(self, Xs=None, Xconds=None, **kwargs):
    # Sets deterministic sites to sample from the conditional.
    Xs = cartesian(Xs)
    Xconds = cartesian(Xconds)

    npy.deterministic(f"{self.name}_mean", self.mean_func(Xs))
    npy.deterministic(f"{self.name}_cond", self.mean_func(Xconds))
    npy.deterministic(f"{self.name}_Kss", self.cov_func(Xconds))
    npy.deterministic(f"{self.name}_Ksx", self.cov_func(Xconds, Xs))
def fulldyn_single_model(beliefs, y, mask):
    T, _ = beliefs[0].shape
    c0 = beliefs[-1]

    mu = npyro.sample('mu', dist.Normal(5., 5.))

    lam12 = npyro.sample('lam12', dist.HalfCauchy(1.).expand([2]).to_event(1))
    lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))

    _lam34 = jnp.expand_dims(lam34, -1)
    lam0 = npyro.deterministic(
        'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

    eta = npyro.sample('eta', dist.Beta(1, 10))

    scale = npyro.sample('scale', dist.HalfNormal(1.))
    theta = npyro.sample('theta', dist.HalfCauchy(5.))

    rho = jnp.exp(-theta)
    sigma = jnp.sqrt((1 - rho**2) / (2 * theta)) * scale

    x0 = jnp.zeros(1)

    def transition_fn(carry, t):
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))

        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1),
                      jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams',
            lam_prev + nn.one_hot(beliefs[2][t], 4) *
            jnp.expand_dims(mask[t] * eta, -1))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))
        noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise

        return (lam_next, x_next), None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (lam_start, x0), jnp.arange(T))
def _glitch_amplitudes(self, nu):
    # nu_max = numpyro.sample("nu_max", self._nu_max)
    # numpyro.deterministic("he_nu_max", self.he_glitch.amplitude(nu_max))
    # numpyro.deterministic("cz_nu_max", self.cz_glitch.amplitude(nu_max))
    if self.window_width == "full":
        low, high = nu.min(), nu.max()
    else:
        low = self._nu_max.mean - self.window_width * self.background._delta_nu
        high = self._nu_max.mean + self.window_width * self.background._delta_nu

    he_amp = numpyro.deterministic(
        "he_amplitude", self.he_glitch._average_amplitude(low, high))
    cz_amp = numpyro.deterministic(
        "cz_amplitude", self.cz_glitch._average_amplitude(low, high))
    return he_amp, cz_amp
def weekday_effect(day_of_week):
    with plate("plate_day_of_week", 6):
        weekday = sample("_beta", dist.Normal(0, 1))

    monday = jnp.array([-jnp.sum(weekday)])  # Monday = 0 in original
    beta = deterministic("beta", jnp.concatenate((monday, weekday)))
    return beta[day_of_week]
def birthdays_model(
    x,
    day_of_week,
    day_of_year,
    memorial_days_indicator,
    labour_days_indicator,
    thanksgiving_days_indicator,
    w0,
    L,
    M1,
    M2,
    M3,
    y=None,
):
    intercept = sample("intercept", dist.Normal(0, 1))
    f1 = scope(trend_gp, "trend")(x, L, M1)
    f2 = scope(year_gp, "year")(x, w0, M2)
    g3 = scope(trend_gp, "week-trend")(x, L, M3)  # length ~ lognormal(-1, 1) in original
    weekday = scope(weekday_effect, "week")(day_of_week)
    yearday = scope(yearday_effect, "day")(day_of_year)

    # --- special days
    memorial = scope(special_effect, "memorial")(memorial_days_indicator)
    labour = scope(special_effect, "labour")(labour_days_indicator)
    thanksgiving = scope(special_effect, "thanksgiving")(thanksgiving_days_indicator)

    day = yearday + memorial + labour + thanksgiving

    # --- combine components
    f = deterministic("f", intercept + f1 + f2 + jnp.exp(g3) * weekday + day)
    sigma = sample("sigma", dist.HalfNormal(0.5))
    with plate("obs", x.shape[0]):
        sample("y", dist.Normal(f, sigma), obs=y)
def nondynamic_mixture_model(beliefs, y, mask):
    M, T, _ = beliefs[0].shape

    tau = .5
    weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))

    assert weights.shape == (M,)

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 5.))

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2,
            p[..., :1] / 2 + p[..., 1:2],
            p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))

    U = jnp.log(p0)

    def transition_fn(carry, t):
        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1),
                      jnp.expand_dims(U, -2))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs).mask(mask[t])
        npyro.sample('y', dist.MixtureSameFamily(mixing_dist, component_dist))

        return None, None

    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, None, jnp.arange(T))
def model(M, A, D=None):
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)
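# This looks like the divorce-rate regression from Statistical Rethinking
# (D regressed on marriage rate M and median age A). A minimal inference
# sketch with placeholder standardized data (assumptions, for illustration):
from jax import random
from numpyro.infer import MCMC, NUTS

M_data = jnp.zeros(50)
A_data = jnp.zeros(50)
D_data = jnp.zeros(50)
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), M_data, A_data, D=D_data)
mu_post = mcmc.get_samples()["mu"]  # deterministic site, shape (1000, 50)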
def nondyn_single_model(beliefs, y, mask):
    T, _ = beliefs[0].shape

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 3.))

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2,
            p[..., :1] / 2 + p[..., 1:2],
            p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))

    U = jnp.log(p0)

    def transition_fn(carry, t):
        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1),
                      jnp.expand_dims(U, -2))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))

        return None, None

    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, None, jnp.arange(T))
def observe_poisson(name, latent, det_prob, obs=None):
    mask = True
    if obs is not None:
        # Mask out missing or invalid observations.
        mask = np.isfinite(obs) & (obs >= 0)
        obs = np.where(mask, obs, 0.0)

    det_prob = np.broadcast_to(det_prob, latent.shape)

    mean = det_prob * latent
    d = dist.Poisson(mean)
    numpyro.deterministic("mean_" + name, mean)

    with numpyro.handlers.mask(mask=mask):
        y = numpyro.sample(name, d, obs=obs)

    return y
def observe_nonrandom(name, latent, det_noise_scale, obs=None):
    mask = True
    if obs is not None:
        # Mask out missing or invalid observations.
        mask = np.isfinite(obs) & (obs >= 0)
        obs = np.where(mask, obs, 0.0)

    mean = latent
    scale = det_noise_scale * mean + 1
    d = dist.TruncatedNormal(mean, scale, low=0.)  # normal truncated below at zero
    numpyro.deterministic("mean_" + name, mean)

    with numpyro.handlers.mask(mask=mask):
        y = numpyro.sample(name, d, obs=obs)

    return y
def __call__(self) -> Callable:
    """Samples the convective zone glitch function.

    Returns:
        function: The function :math:`f`.
    """
    log_a = numpyro.sample("log_a_cz", self.log_a)
    log_tau = numpyro.sample("log_tau_cz", self.log_tau)
    self._a = numpyro.deterministic("a_cz", 10**log_a)
    self._tau = numpyro.deterministic("tau_cz", 10**log_tau)
    self._phi = numpyro.sample("phi_cz", self.phi)

    def fn(nu):
        return self.amplitude(nu) * self.oscillation(nu)

    return fn
def model_PMD(z, N, y=None, phi_prior=1 / 1000):
    z = jnp.abs(z)

    q = numpyro.sample("q", dist.Beta(2, 3))  # mean = 0.4, shape = 5
    A = numpyro.sample("A", dist.Beta(2, 3))  # mean = 0.4, shape = 5
    c = numpyro.sample("c", dist.Beta(1, 9))  # mean = 0.1, shape = 10

    # Dz = numpyro.deterministic("Dz", A * (1 - q) ** (z - 1) + c)
    Dz = jnp.clip(numpyro.deterministic("Dz", A * (1 - q) ** (z - 1) + c), 0, 1)
    D_max = numpyro.deterministic("D_max", A + c)  # pylint: disable=unused-variable

    delta = numpyro.sample("delta", dist.Exponential(phi_prior))
    phi = numpyro.deterministic("phi", delta + 2)

    alpha = numpyro.deterministic("alpha", Dz * phi)
    beta = numpyro.deterministic("beta", (1 - Dz) * phi)

    numpyro.sample("obs", dist.BetaBinomial(alpha, beta, N), obs=y)
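# Sketch of fitting the damage model above. The toy arrays are assumptions:
# z would be a position index, N the per-position trial count, and y the
# observed success counts.
from jax import random
from numpyro.infer import MCMC, NUTS

z = jnp.arange(1, 16)
N_trials = jnp.full(15, 100)
y_obs = jnp.zeros(15, dtype=jnp.int32)
mcmc = MCMC(NUTS(model_PMD), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), z, N_trials, y=y_obs)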