def generate_sin(n_obs, noise_level, seed):
    """Generate noisy samples of ``0.5*sin(3*x)`` plus the noiseless curve.

    :param n_obs: number of noisy observations to draw.
    :param noise_level: half-width of the additive noise, which is drawn from
        ``Uniform(-noise_level, noise_level)``.
    :param seed: integer seed for the JAX PRNG.
    :return: tuple ``(noisy_data, data)``; each is a dict with keys ``"X"`` and
        ``"y"``. ``noisy_data["X"]`` has shape ``(n_obs, 1)`` with inputs drawn
        from ``Uniform(0, 5)``. ``data`` holds 400 evenly spaced points on
        ``[0, 5]`` with the exact curve values.
    """
    key_x, key_noise = random.split(random.PRNGKey(seed))

    # Noisy observations: inputs as a column vector, noise per row.
    inputs = dist.Uniform(0.0, 5.0).sample(key_x, sample_shape=(n_obs, 1))
    n_rows = inputs.shape[0]
    eps = dist.Uniform(-noise_level, noise_level).sample(
        key_noise, sample_shape=(n_rows,)
    )
    noisy_data = {
        "X": inputs,
        "y": 0.5 * np.sin(3 * inputs[:, 0]) + eps,
    }

    # Dense, noise-free reference curve for comparison.
    grid = np.linspace(0.0, 5.0, 400)
    data = {
        "X": grid,
        "y": 0.5 * np.sin(3 * grid),
    }
    return noisy_data, data
def model(data):
    """Observe ``data`` ~ Normal(loc, 0.1), where ``loc`` ranges over (0, alpha).

    ``loc`` is an ``AffineTransform(0, alpha)`` of a Uniform(0, 1) whose own
    log-density is masked out via ``.mask(False)``. The site is sampled under
    ``TransformReparam``, so it is expressed through its base distribution.
    """
    scale = numpyro.sample('alpha', dist.Uniform(0, 1))
    base = dist.Uniform(0, 1).mask(False)
    loc_dist = dist.TransformedDistribution(base, AffineTransform(0, scale))
    with handlers.reparam(config={'loc': TransformReparam()}):
        loc = numpyro.sample('loc', loc_dist)
    numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)
def model_3(capture_history, sex):
    """Capture-recapture model with a time-varying survival probability.

    At each time step a survival logit ``phi_logit`` is drawn from
    ``Normal(logit(phi_mean), phi_sigma)``. The recapture probability ``rho``
    is shared across animals and time.

    :param capture_history: array unpacked as ``(N, T)``; column 0 seeds the
        first-capture mask and columns ``1:`` are scanned over time.
        Presumably 0/1 capture indicators — TODO confirm.
    :param sex: unused in this variant.
    """
    N, T = capture_history.shape
    phi_mean = numpyro.sample("phi_mean", dist.Uniform(0.0, 1.0))  # mean survival probability
    phi_logit_mean = logit(phi_mean)
    # controls temporal variability of survival probability
    phi_sigma = numpyro.sample("phi_sigma", dist.Uniform(0.0, 10.0))
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        # carry = (animals captured at least once so far, latent alive state z)
        first_capture_mask, z = carry
        # One survival logit per time step, sampled under LocScaleReparam(0)
        # (a non-centered parameterization).
        with handlers.reparam(config={"phi_logit": LocScaleReparam(0)}):
            phi_logit_t = numpyro.sample("phi_logit", dist.Normal(phi_logit_mean, phi_sigma))
        phi_t = expit(phi_logit_t)
        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                # Not-yet-captured animals get probability 1 (their terms are
                # masked out); captured ones survive w.p. phi_t if alive.
                mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
                # Recapture is only possible when alive (z == 1).
                mu_y_t = rho * z
                numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y)
        # Once observed, an animal stays in the mask for all later steps.
        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    # Every animal starts alive.
    z = jnp.ones(N, dtype=jnp.int32)
    # we use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
    scan(transition_fn, (first_capture_mask, z), jnp.swapaxes(capture_history[:, 1:], 0, 1))
def _init_to_uniform(site, radius=2, skip_param=False):
    """Draw an initial value uniformly from ``(-radius, radius)`` in an
    unconstrained space and map it back through a bijection.

    - Latent (unobserved) ``sample`` site: the draw lives in the unconstrained
      space of the site's distribution. If that distribution is a
      ``TransformedDistribution``, its ``base_dist`` is used instead, so the
      returned value is in the base space.
    - ``param`` site (unless ``skip_param``): the draw lives in the
      unconstrained space of the param's ``constraint`` (``real`` when absent).
      For a ``ComposeTransform`` only ``parts[0]`` is applied.
    - Anything else: returns ``None``.

    The draws are made through ``numpyro.sample`` sites named ``'_init'`` and
    ``'_unconstrained_init'``.

    NOTE(review): ``site['kwargs'].pop('constraint', real)`` removes the
    ``constraint`` entry from the site as a side effect. Confirm this is
    intended rather than a ``.get``.
    """
    if site['type'] == 'sample' and not site['is_observed']:
        fn = site['fn']
        if isinstance(fn, dist.TransformedDistribution):
            fn = fn.base_dist
        # Prototype draw, used only to discover the unconstrained shape.
        prototype = numpyro.sample('_init', fn, sample_shape=site['kwargs']['sample_shape'])
        to_support = biject_to(fn.support)
        draw = numpyro.sample('_unconstrained_init', dist.Uniform(-radius, radius),
                              sample_shape=np.shape(to_support.inv(prototype)))
        return to_support(draw)

    if site['type'] != 'param' or skip_param:
        return None

    constraint = site['kwargs'].pop('constraint', real)
    full_transform = biject_to(constraint)
    init_value = site['args'][0]
    draw = numpyro.sample('_unconstrained_init', dist.Uniform(-radius, radius),
                          sample_shape=np.shape(full_transform.inv(init_value)))
    # "Base" value of the param: only the first part of a composed bijection.
    if isinstance(full_transform, ComposeTransform):
        head = full_transform.parts[0]
    else:
        head = full_transform
    return head(draw)
def model_1(capture_history, sex):
    """Capture-recapture model with one survival probability ``phi`` and one
    recapture probability ``rho``, both shared across animals and time.

    :param capture_history: array unpacked as ``(N, T)``; column 0 seeds the
        first-capture mask and columns ``1:`` are scanned over time.
        Presumably 0/1 capture indicators — TODO confirm.
    :param sex: unused in this variant.
    """
    N, T = capture_history.shape
    phi = numpyro.sample("phi", dist.Uniform(0.0, 1.0))  # survival probability
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        # carry = (animals captured at least once so far, latent alive state z)
        first_capture_mask, z = carry
        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                # Not-yet-captured animals get probability 1 (their terms are
                # masked out); captured ones survive w.p. phi if alive.
                mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
                # Recapture is only possible when alive (z == 1).
                mu_y_t = rho * z
                numpyro.sample(
                    "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
                )
        # Once observed, an animal stays in the mask for all later steps.
        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    # Every animal starts alive.
    z = jnp.ones(N, dtype=jnp.int32)
    # we use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
    scan(
        transition_fn,
        (first_capture_mask, z),
        jnp.swapaxes(capture_history[:, 1:], 0, 1),
    )
def model(data):
    """Observe ``data`` ~ Normal(loc, 0.1) with ``loc`` supported on (0, alpha).

    ``loc`` is an affine rescaling (``AffineTransform(0, alpha)``) of a
    Uniform(0, 1) base whose log-density is masked out. ``TransformReparam``
    samples the site in terms of that base.
    """
    alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
    loc_prior = dist.TransformedDistribution(
        dist.Uniform(0, 1).mask(False), AffineTransform(0, alpha)
    )
    reparam_config = {"loc": TransformReparam()}
    with numpyro.handlers.reparam(config=reparam_config):
        loc = numpyro.sample("loc", loc_prior)
    numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
def model(x):
    """Normal likelihood for ``x`` with Uniform(0, 20) priors on loc and scale."""
    # Sites are registered in the same order as before: "loc", then "scale", then "x".
    loc = numpyro.sample("loc", dist.Uniform(0, 20))
    scale = numpyro.sample("scale", dist.Uniform(0, 20))
    numpyro.sample("x", dist.Normal(loc, scale), obs=x)
def model(X, y=None):
    """Regression of ``y`` on ``X`` with a single random offset ``mu``.

    ``y ~ Normal(X + mu, sd)`` with ``mu ~ Normal(0, sd_randwalk)``. Both
    standard deviations have Uniform(0, 100) priors.

    NOTE(review): despite the name, ``randwalk`` is a single scalar draw
    (no plate or scan), not a random walk over observations.

    :param X: predictor values, added elementwise to the offset.
    :param y: observed targets; ``None`` lets the ``'pred_y'`` site be sampled.
    """
    sd_randwalk = numpyro.sample('sd_randwalk', dist.Uniform(low=0, high=100.0))
    randwalk = numpyro.sample('mu', dist.Normal(loc=0, scale=sd_randwalk))
    value = X + randwalk
    sd_value = numpyro.sample('sd', dist.Uniform(low=0.0, high=100.0))
    # Registered for its effect on the trace; the returned value was never used.
    numpyro.sample('pred_y', dist.Normal(loc=value, scale=sd_value), obs=y)
def sgt(y: jnp.ndarray, seasonality: int, future: int = 0) -> None:
    """Seasonal global-trend time-series model with Student-t observations.

    A latent ``level`` and a rolling buffer of ``seasonality`` seasonal
    coefficients ``s`` are updated by exponential smoothing inside a ``scan``.
    The expected value at each step is
    ``level + coef_trend * level**pow_trend + s[0] * level**pow_season``,
    clipped at 0.

    :param y: observed series; indexed as ``y[t]``, so presumably 1-D — confirm.
    :param seasonality: length of the seasonal cycle, i.e. the number of
        entries in ``init_s`` (sampled with the shape of ``y[:seasonality]``).
    :param future: number of extra steps past ``len(y)`` to simulate.

    Registers a deterministic site ``"y_forecast"`` containing every scanned
    output for ``t = 1 .. len(y) + future - 1``: in-sample steps followed by
    the ``future`` forecast steps.
    NOTE(review): the name suggests forecast-only. Callers wanting just the
    forecast must slice the last ``future`` entries.
    """
    # heuristic scale for the Cauchy priors, tied to the data's magnitude
    cauchy_sd = jnp.max(y) / 150
    nu = numpyro.sample("nu", dist.Uniform(2, 20))
    powx = numpyro.sample("powx", dist.Uniform(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
    offset_sigma = numpyro.sample(
        "offset_sigma", dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd))
    coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
    pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
    # maps (0, 1) onto (-0.5, 1)
    pow_trend = 1.5 * pow_trend_beta - 0.5
    pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))
    level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
    s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))
    init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality] * 0.3))
    num_lim = y.shape[0]

    def transition_fn(
        carry: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], t: jnp.ndarray
    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:
        level, s, moving_sum = carry
        season = s[0] * level**pow_season
        exp_val = level + coef_trend * level**pow_trend + season
        exp_val = jnp.clip(exp_val, a_min=0)
        # past the end of the data, use the expected value in place of y[t]
        y_t = jnp.where(t >= num_lim, exp_val, y[t])
        # sum of the last `seasonality` observations once t >= seasonality
        # NOTE(review): for t >= num_lim this still reads y[t] out of range;
        # JAX clamps such gather indices, so the last observation is reused — confirm.
        moving_sum = moving_sum + y[t] - jnp.where(t >= seasonality, y[t - seasonality], 0.0)
        level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)
        level = level_sm * level_p + (1 - level_sm) * level
        level = jnp.clip(level, a_min=0)
        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        # keep the seasonal coefficient unchanged when forecasting
        new_s = jnp.where(t >= num_lim, s[0], new_s)
        # rotate the buffer: drop the head, append the updated coefficient
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)
        # observation scale grows with the expected value
        omega = sigma * exp_val**powx + offset_sigma
        y_ = numpyro.sample("y", dist.StudentT(nu, exp_val, omega))
        return (level, s, moving_sum), y_

    level_init = y[0]
    # rotate init_s so that its first entry becomes the last
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    # "y" is observed for the in-sample steps (y[1:]); later steps are forecasts
    with numpyro.handlers.condition(data={"y": y[1:]}):
        _, ys = scan(transition_fn, (level_init, s_init, moving_sum), jnp.arange(1, num_lim + future))
    numpyro.deterministic("y_forecast", ys)
def actual_model(data):
    """Observe each element of ``data`` ~ Normal(loc, 0.1), i.i.d. over plate "N".

    ``loc`` is ``AffineTransform(0, alpha)`` applied to Uniform(0, 1), so it is
    supported on (0, alpha). It is sampled under ``TransformReparam``.
    """
    alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
    scaled_uniform = dist.TransformedDistribution(
        dist.Uniform(0, 1), transforms.AffineTransform(0, alpha)
    )
    with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
        loc = numpyro.sample("loc", scaled_uniform)
    num_obs = len(data)
    with numpyro.plate("N", num_obs):
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
def model_c(nu1, y1):
    """Forward-model a spectrum and compare it to the observed fluxes ``y1``.

    Samples ``Rp``, ``RV``, ``MMR_CO``, ``vsini``, ``T0`` and a profile
    ``Tarr`` (multivariate normal around ``ONEARR``, shifted by ``T0``). It
    then builds CO line and H2-H2 CIA optical depths, runs ``rtrun``, applies
    ``response.rigidrot`` and ``response.ipgauss_sampling``, and scores the
    result with ``Normal(mu, sigmain)`` at site ``'y1'``.

    Relies on many module-level names not visible here, e.g. ``Mp``,
    ``lnParr``, ``Parr``, ``dParr``, ``ONEARR``, ``nus``, ``numatrix_CO``,
    ``mdbCO``, ``cdbH2H2``, ``molmassCO``, ``vmrH2``, ``mmw``, ``norm``,
    ``beta`` and ``sigmain``.

    :param nu1: grid passed to ``ipgauss_sampling`` as ``nusd`` (presumably the
        wavenumbers of the observed data — confirm).
    :param y1: observed fluxes.
    """
    Rp = numpyro.sample('Rp', dist.Uniform(0.4, 1.2))
    RV = numpyro.sample('RV', dist.Uniform(5.0, 15.0))
    MMR_CO = numpyro.sample('MMR_CO', dist.Uniform(0.0, 0.015))
    vsini = numpyro.sample('vsini', dist.Uniform(15.0, 25.0))
    g = 2478.57730044555 * Mp / Rp**2  # gravity (units set by the constant and Mp/Rp — TODO confirm)
    # limb-darkening coefficients passed to rigidrot, fixed at zero
    u1 = 0.0
    u2 = 0.0
    # Layer-by-layer T-P model; the covariance hyperparameters are fixed here
    # (the commented-out lines show the previously sampled alternatives).
    lnsT = 6.0
    # lnsT = numpyro.sample('lnsT', dist.Uniform(3.0,5.0))
    sT = 10**lnsT
    lntaup = 0.5
    # lntaup = numpyro.sample('lntaup', dist.Uniform(0,1))
    taup = 10**lntaup
    cov = modelcov(lnParr, taup, sT)
    # T0=numpyro.sample('T0', dist.Uniform(1000.0,1100.0))
    T0 = numpyro.sample('T0', dist.Uniform(1000, 2000))
    Tarr = numpyro.sample(
        'Tarr', dist.MultivariateNormal(loc=ONEARR, covariance_matrix=cov)) + T0
    # line computation CO
    qt_CO = vmap(mdbCO.qr_interp)(Tarr)

    def obyo(y, tag, nusd, nus, numatrix_CO, mdbCO, cdbH2H2):
        # Registers observation site `tag` for data `y`. The arguments shadow
        # the module-level names of the same spelling; Tarr, qt_CO, MMR_CO, g,
        # vsini, u1, u2 and RV come from the enclosing scope.
        # CO
        SijM_CO = jit(vmap(SijT, (0, None, None, None, 0)))(Tarr, mdbCO.logsij0, mdbCO.dev_nu_lines, mdbCO.elower, qt_CO)
        gammaLMP_CO = jit(vmap(gamma_exomol, (0, 0, None, None)))(Parr, Tarr, mdbCO.n_Texp, mdbCO.alpha_ref)
        gammaLMN_CO = gamma_natural(mdbCO.A)
        gammaLM_CO = gammaLMP_CO + gammaLMN_CO[None, :]
        sigmaDM_CO = jit(vmap(doppler_sigma, (None, 0, None)))(mdbCO.dev_nu_lines, Tarr, molmassCO)
        xsm_CO = xsmatrix(numatrix_CO, sigmaDM_CO, gammaLM_CO, SijM_CO)
        dtaumCO = dtauM(dParr, xsm_CO, MMR_CO * ONEARR, molmassCO, g)
        # CIA
        dtaucH2H2 = dtauCIA(nus, Tarr, Parr, dParr, vmrH2, vmrH2, mmw, g, cdbH2H2.nucia, cdbH2H2.tcia, cdbH2H2.logac)
        # total optical depth = CO lines + H2-H2 CIA
        dtau = dtaumCO + dtaucH2H2
        sourcef = planck.piBarr(Tarr, nus)
        F0 = rtrun(dtau, sourcef) / norm
        Frot = response.rigidrot(nus, F0, vsini, u1, u2)
        mu = response.ipgauss_sampling(nusd, nus, Frot, beta, RV)
        numpyro.sample(tag, dist.Normal(mu, sigmain), obs=y)

    obyo(y1, 'y1', nu1, nus, numatrix_CO, mdbCO, cdbH2H2)
def test_elbo_dynamic_support():
    """SVI loss with a pinned guide base value matches the closed-form ELBO term.

    The prior is Normal pushed through Affine(0, 2), Sigmoid and Affine(0, 3),
    so it lives on (0, 3), the same support as the Uniform(0, 3) guide.
    """
    prior = dist.TransformedDistribution(
        dist.Normal(),
        [AffineTransform(0, 2), SigmoidTransform(), AffineTransform(0, 3)])
    guide_dist = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', prior)

    def guide():
        numpyro.sample('x', guide_dist)

    # Pin the guide's base (pre-transform) value of x to 0.9.
    base_value = 0.9
    pinned_guide = substitute(guide, base_param_map={'x': base_value})
    svi = SVI(model, pinned_guide, elbo, optim.Adam(0.01))
    state = svi.init(random.PRNGKey(0), (), ())
    loss = svi.evaluate(state)
    assert np.isfinite(loss)

    x, _ = guide_dist.transform_with_intermediates(base_value)
    expected = guide_dist.log_prob(x) - prior.log_prob(x)
    assert_allclose(loss, expected)
def model(data):
    """Poisson change-point model over the index of ``data``.

    Entries with index below ``tau * len(data)`` have rate ``lambda1``; the
    rest have rate ``lambda2``. Both rates get an Exponential prior whose rate
    is the reciprocal of the (float32) mean of ``data``.
    """
    n = len(data)
    rate = 1 / jnp.mean(data.astype(np.float32))
    lambda1 = numpyro.sample("lambda1", dist.Exponential(rate))
    lambda2 = numpyro.sample("lambda2", dist.Exponential(rate))
    tau = numpyro.sample("tau", dist.Uniform(0, 1))
    before_switch = jnp.arange(n) < tau * n
    rates = jnp.where(before_switch, lambda1, lambda2)
    numpyro.sample("obs", dist.Poisson(rates), obs=data)
def model(data):
    """Poisson switch-point model: rate ``lambda1`` before index
    ``tau * len(data)``, ``lambda2`` from there on.

    The Exponential priors on both rates use ``1 / mean(data)`` as their rate.
    """
    prior_rate = 1 / jnp.mean(data)
    early_rate = numpyro.sample('lambda1', dist.Exponential(prior_rate))
    late_rate = numpyro.sample('lambda2', dist.Exponential(prior_rate))
    tau = numpyro.sample('tau', dist.Uniform(0, 1))
    size = len(data)
    per_point_rate = jnp.where(jnp.arange(size) < tau * size, early_rate, late_rate)
    numpyro.sample('obs', dist.Poisson(per_point_rate), obs=data)
def init_agents(self, num_steps=1000, pop_size=(1e2, 1e2), initial_infections=2):
    """Sample initial agent attributes and initialise the per-agent date arrays.

    :param num_steps: currently unused.
    :param pop_size: 2-D shape of the agent population. Entries may be floats
        (the default is ``(1e2, 1e2)``); they are converted to ``int`` because
        numpy rejects float array dimensions.
    :param initial_infections: expected number of initially infected agents.
        Each agent starts infected with probability
        ``initial_infections / (pop_size[0] * pop_size[1])``.
    """
    shape = tuple(int(dim) for dim in pop_size)
    n_agents = shape[0] * shape[1]

    # NOTE(review): none of these sites has a sample_shape, so each is a single
    # scalar shared by every agent (the date arrays below broadcast it).
    # Confirm whether per-agent draws were intended.
    self.age = ny.sample('age', dist.Uniform(0, 100))
    self.sex = ny.sample('sex', dist.Binomial(1, .5))
    self.risk_tolerance = ny.sample('risk', dist.Beta(2, 5))
    # self.risk_factors = ny.sample('health', dist.Binomial(5, .3))
    self.hygiene = ny.sample('hygiene', dist.Beta(2, 5))
    # self.worker_type = ny.sample('worker_type', dist.Categorical((.6, .1, .2, .1)))
    # The old expression `initial_infections / pop_size[0] * pop_size[1]` was
    # evaluated as `(initial_infections / pop_size[0]) * pop_size[1]`, which is
    # 2 for the defaults. That is not a valid probability.
    self.epidemic_state = ny.sample(
        'state', dist.Binomial(1, initial_infections / n_agents))
    self.social_radius = ny.sample('radius', dist.Binomial(10, .2))
    self.base_isolation = ny.sample('base_isolation', dist.Beta(2, 2))
    # TODO: make these depend on risk factors as well
    # NOTE(review): `self.age` is a continuous Uniform draw; confirm the maps
    # accept it as a key/index.
    self.will_be_hospitalized = ny.sample(
        'hosp', dist.Binomial(1, self.params.HOSP_AGE_MAP[self.age]))
    self.will_die = ny.sample(
        'die', dist.Binomial(1, self.params.DEATH_MAP[self.age]))
    # The lengths of the infection are handled on a per agent basis via scenarios, these are just placeholders.
    self.date_infected = np.where(
        self.epidemic_state > 0,
        np.zeros(shape=shape),
        np.full(shape=shape, fill_value=np.inf))
    self.date_contagious = np.where(
        self.epidemic_state > 0,
        np.ceil(self.params.EXPOSED_PERIOD),
        np.full(shape=shape, fill_value=np.inf))
    self.date_symptomatic = np.full(shape=shape, fill_value=np.inf)
    self.date_recovered = np.full(shape=shape, fill_value=np.inf)
    self.date_hospitalized = np.full(shape=shape, fill_value=np.inf)
    self.date_died = np.full(shape=shape, fill_value=np.inf)
def sample_y(dist_y, theta, y, sigma_obs=None):
    """Register observation site ``'y'`` under the likelihood named ``dist_y``.

    :param dist_y: one of ``'student'``, ``'normal'``, ``'lognormal'``,
        ``'gamma'``, ``'gamma_raw'``, ``'poisson'``, ``'exponential'``,
        ``'exponential_raw'``, ``'uniform'``.
    :param theta: linear predictor. ``'gamma'`` and ``'exponential'`` use
        ``exp(theta)``; the ``*_raw`` variants use ``theta`` as-is.
        ``'uniform'`` ignores it.
    :param y: observed values (passed as ``obs``).
    :param sigma_obs: observation noise. When ``None``, a ``'sigma_obs'`` site
        is sampled: Exponential(1) for ``'gamma'``, HalfNormal(1) otherwise.
        NOTE(review): it is sampled even for likelihoods that do not use it
        (poisson, exponential*, uniform).
    :raises NotImplementedError: for an unrecognised ``dist_y``.
    """
    # `is None` instead of truthiness: `not array` raises for non-scalar
    # arrays, and a scalar 0 would previously have been treated as "missing".
    if sigma_obs is None:
        if dist_y == 'gamma':
            sigma_obs = numpyro.sample('sigma_obs', dist.Exponential(1))
        else:
            sigma_obs = numpyro.sample('sigma_obs', dist.HalfNormal(1))

    if dist_y == 'student':
        nu_y = numpyro.sample('nu_y', dist.Gamma(1, .1))
        numpyro.sample('y', dist.StudentT(nu_y, theta, sigma_obs), obs=y)
    elif dist_y == 'normal':
        numpyro.sample('y', dist.Normal(theta, sigma_obs), obs=y)
    elif dist_y == 'lognormal':
        numpyro.sample('y', dist.LogNormal(theta, sigma_obs), obs=y)
    elif dist_y == 'gamma':
        numpyro.sample('y', dist.Gamma(jnp.exp(theta), sigma_obs), obs=y)
    elif dist_y == 'gamma_raw':
        numpyro.sample('y', dist.Gamma(theta, sigma_obs), obs=y)
    elif dist_y == 'poisson':
        numpyro.sample('y', dist.Poisson(theta), obs=y)
    elif dist_y == 'exponential':
        numpyro.sample('y', dist.Exponential(jnp.exp(theta)), obs=y)
    elif dist_y == 'exponential_raw':
        numpyro.sample('y', dist.Exponential(theta), obs=y)
    elif dist_y == 'uniform':
        numpyro.sample('y', dist.Uniform(0, 1), obs=y)
    else:
        raise NotImplementedError(f"Unsupported dist_y: {dist_y!r}")
def init_to_uniform(site=None, radius=2):
    """
    Initialization strategy that draws each continuous latent site uniformly
    from `(-radius, radius)` in its unconstrained space and maps the draw
    back onto the site's support.

    Called with no ``site``, it returns itself with ``radius`` bound, so it can
    be used as ``init_to_uniform(radius=...)``. A site that already stores a
    value keeps it (with a warning). Observed, non-sample and discrete sites
    get ``None``.

    :param float radius: half-width of the uniform box in the unconstrained domain.
    """
    if site is None:
        return partial(init_to_uniform, radius=radius)

    is_continuous_latent = (
        site["type"] == "sample"
        and not site["is_observed"]
        and not site["fn"].support.is_discrete
    )
    if not is_continuous_latent:
        return None

    if site["value"] is not None:
        warnings.warn(
            f"init_to_uniform() skipping initialization of site '{site['name']}'"
            " which already stores a value.",
            stacklevel=find_stack_level(),
        )
        return site["value"]

    # Deferred import: importing at module level would be circular.
    from numpyro.infer.util import helpful_support_errors

    kwargs = site["kwargs"]
    rng_key = kwargs.get("rng_key")
    sample_shape = kwargs.get("sample_shape")
    with helpful_support_errors(site):
        transform = biject_to(site["fn"].support)
    unconstrained_shape = transform.inverse_shape(site["fn"].shape())
    draws = dist.Uniform(-radius, radius)(
        rng_key=rng_key, sample_shape=sample_shape + unconstrained_shape
    )
    return transform(draws)
def pacs_model(priors):
    """Joint model for two bands that share per-source fluxes.

    Per band (plate ``'bands'`` of size ``len(priors)``): a confusion-noise
    scale ``sigma_conf`` and a background ``bkg``. Per source and band:
    ``src_f ~ Uniform(prior_flux_lower, prior_flux_upper)``. Each band's
    predicted map is ``sp_matmul(pointing_matrix, flux, snpix) + bkg``. It is
    compared with ``prior.sim`` under per-pixel sigma
    ``sqrt(snim**2 + sigma_conf**2)``.

    Only ``priors[0]`` and ``priors[1]`` enter the likelihood, so exactly two
    bands are presumably expected — confirm.

    NOTE(review): numpyro's ``HalfCauchy`` has a single ``scale`` parameter.
    Depending on the numpyro version, the second positional argument ``0.5``
    is either consumed as ``validate_args`` or rejected. Confirm the intended
    prior.
    """
    pointing_matrices = [([p.amat_row, p.amat_col], p.amat_data) for p in priors]

    def _per_band(get):
        # Stack one attribute across bands and transpose to (item, band).
        return np.asarray([get(p) for p in priors]).T

    flux_lower = _per_band(lambda p: p.prior_flux_lower)
    flux_upper = _per_band(lambda p: p.prior_flux_upper)
    bkg_mu = _per_band(lambda p: p.bkg[0])
    bkg_sig = _per_band(lambda p: p.bkg[1])

    with numpyro.plate('bands', len(priors)):
        sigma_conf = numpyro.sample('sigma_conf', dist.HalfCauchy(1.0, 0.5))
        bkg = numpyro.sample('bkg', dist.Normal(bkg_mu, bkg_sig))
        with numpyro.plate('nsrc', priors[0].nsrc):
            src_f = numpyro.sample('src_f', dist.Uniform(flux_lower, flux_upper))

    # One likelihood per band, registered in the order psw then pmw.
    band_sites = ((0, 'psw_pixels', "obs_psw"), (1, 'pmw_pixels', "obs_pmw"))
    for band, plate_name, site_name in band_sites:
        prior = priors[band]
        predicted = sp_matmul(pointing_matrices[band], src_f[:, band][:, None],
                              prior.snpix).reshape(-1) + bkg[band]
        total_sigma = jnp.sqrt(
            jnp.power(prior.snim, 2) + jnp.power(sigma_conf[band], 2))
        with numpyro.plate(plate_name, prior.sim.size):
            numpyro.sample(site_name, dist.Normal(predicted, total_sigma), obs=prior.sim)
def model(center1, center2, radius, width, enum=False):
    """Two-component Gaussian-shell mixture over a 2-D point ``x``.

    ``z ~ Bernoulli(0.5)`` picks ``center1`` or ``center2``. ``x`` has a
    Uniform(-6, 6) prior on each of its two coordinates. When ``enum`` is
    true, ``z`` is marked for parallel enumeration.
    """
    z_infer = {"enumerate": "parallel"} if enum else {}
    z = numpyro.sample("z", dist.Bernoulli(0.5), infer=z_infer)
    box_prior = dist.Uniform(-6.0, 6.0).expand([2]).to_event(1)
    x = numpyro.sample("x", box_prior)
    centers = jnp.stack([center1, center2])
    numpyro.sample("shell", GaussianShell(centers[z], radius, width), obs=x)
def model_c(t1, y1, e1):
    """Fit ``rvf`` (presumably a radial-velocity curve) to data ``(t1, y1)``
    with per-point errors ``e1`` plus a jitter term ``sigmajit`` added in
    quadrature.

    Eccentricity and argument of periastron come from ``(sesinw, secosw)``:
    ``e = sesinw**2 + secosw**2`` (capped at 1) and
    ``omegaA = arctan2(sesinw, secosw)``.
    """
    # TODO: switch P to a modified Jeffreys prior.
    P = numpyro.sample('P', dist.Uniform(8.0, 12.0))
    Ksini = numpyro.sample('Ksini', dist.Exponential(0.1))
    T0 = numpyro.sample('T0', dist.Uniform(-6.0, 6.0))
    sesinw = numpyro.sample('sesinw', dist.Uniform(-1.0, 1.0))
    secosw = numpyro.sample('secosw', dist.Uniform(-1.0, 1.0))
    e_raw = sesinw**2 + secosw**2
    e = jnp.where(e_raw > 1.0, 1.0, e_raw)
    omegaA = jnp.arctan2(sesinw, secosw)
    sigmajit = numpyro.sample('sigmajit', dist.Exponential(1.0))
    Vsys = numpyro.sample('Vsys', dist.Uniform(-10, 10.0))
    model_rv = rvf(t1, T0, P, e, omegaA, Ksini, Vsys)
    total_err = jnp.sqrt(e1**2 + sigmajit**2)
    numpyro.sample('y1', dist.Normal(model_rv, total_err), obs=y1)
def init_to_uniform(site=None, radius=2):
    """
    Initialization strategy: pick a random point in `(-radius, radius)` of the
    unconstrained domain for each continuous latent sample site, then map it
    onto the site's support.

    Called with no ``site``, it returns itself with ``radius`` bound. Sites
    that are not continuous latent samples get ``None``.

    :param float radius: half-width of the uniform box in the unconstrained domain.
    """
    if site is None:
        return partial(init_to_uniform, radius=radius)

    if site['type'] != 'sample' or site['is_observed'] or site['fn'].is_discrete:
        return None

    kwargs = site['kwargs']
    sample_shape = kwargs.get('sample_shape')
    uniform_key, prototype_key = random.split(kwargs.get('rng_key'))

    # A single prototype draw reveals how the event shape changes between the
    # constrained and unconstrained spaces.
    try:
        prototype = site['fn'].sample(prototype_key, sample_shape=())
    except NotImplementedError:
        # Samplerless priors such as ImproperUniform fall back to a NaN
        # placeholder of the site's full shape. This is not used for every
        # prior, because e.g. a TransformedDistribution may report the wrong
        # event_shape.
        prototype = jnp.full(site['fn'].shape(), jnp.nan)

    transform = biject_to(site['fn'].support)
    unconstrained_shape = jnp.shape(transform.inv(prototype))
    draws = dist.Uniform(-radius, radius).sample(
        uniform_key, sample_shape=sample_shape + unconstrained_shape)
    return transform(draws)
def test_elbo_dynamic_support():
    """SVI loss with the AutoDiagonalNormal latent pinned to 2.0 equals
    ``guide log prob - (prior log prob + log|det J|)`` for the (0, 5) bijection.
    """
    prior = dist.Uniform(0, 5)
    z_unconstrained = 2.

    def model():
        numpyro.sample('x', prior)

    class _AutoGuide(AutoDiagonalNormal):
        # Fix the flattened latent so the loss has a closed form.
        def __call__(self, *args, **kwargs):
            pinned = substitute(super(_AutoGuide, self).__call__,
                                {'_auto_latent': z_unconstrained})
            return pinned(*args, **kwargs)

    guide = _AutoGuide(model)
    svi = SVI(model, guide, optim.Adam(0.01), AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(0))
    loss = svi.evaluate(svi_state)
    assert np.isfinite(loss)

    guide_lp = dist.Normal(guide._init_latent, guide._init_scale).log_prob(z_unconstrained).sum()
    to_interval = transforms.biject_to(constraints.interval(0, 5))
    x = to_interval(z_unconstrained)
    log_det = to_interval.log_abs_det_jacobian(z_unconstrained, x)
    model_lp = prior.log_prob(x) + log_det
    assert_allclose(loss, guide_lp - model_lp, rtol=1e-6)
def model(X, y=None):
    """Linear regression whose mean is scaled by a global factor ``f``.

    ``y ~ Normal(f * (X @ betas + b), sigma)``. Returns the value of site ``"y"``.

    NOTE(review): ``f`` and ``(betas, b)`` are only identified up to a common
    rescaling, since ``f * (X @ w + b) == (c*f) * (X @ (w/c) + b/c)``.
    """
    n_features = np.shape(X)[-1]
    betas = numpyro.sample('betas', dist.Normal(0.0, 10.0 * np.ones(n_features)))
    intercept = numpyro.sample('b', dist.Normal(0.0, 10.0))
    sigma = numpyro.sample('sigma', dist.Uniform(0.0, 10.0))
    scale_factor = numpyro.sample('f', dist.Normal(0.0, 2.5))
    linear = X @ betas + intercept
    return numpyro.sample("y", dist.Normal(scale_factor * linear, sigma), obs=y)
def partially_pooled(at_bats: jnp.ndarray, hits: Optional[jnp.ndarray] = None) -> None:
    """Partial pooling: each player's hit rate ``phi`` ~ Beta with mean ``m``
    and concentration ``kappa`` shared across players.

    :param at_bats: number of at-bats per player; its length sets the plate size.
    :param hits: observed hits per player, or ``None`` to sample them.
    """
    m = numpyro.sample("m", dist.Uniform(0, 1))
    kappa = numpyro.sample("kappa", dist.Pareto(1, 1.5))
    # Beta(m*kappa, (1-m)*kappa) has mean m.
    conc1 = m * kappa
    conc0 = (1 - m) * kappa
    with numpyro.plate("num_players", at_bats.shape[0]):
        phi = numpyro.sample("phi", dist.Beta(conc1, conc0))
        numpyro.sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)
def model(data):
    """Fit ``forward(xc, yc, w, h, phi)`` to ``data`` with isotropic Normal noise.

    The two outputs of ``forward`` are stacked as rows and compared with
    ``data``. Returns the value of site ``'obs'``.
    """
    xc, yc = (numpyro.sample(name, dist.Normal(0, 10)) for name in ('xc', 'yc'))
    w, h = (numpyro.sample(name, dist.LogNormal(0, 10)) for name in ('w', 'h'))
    phi = numpyro.sample('phi', dist.Uniform(0, np.pi))
    px, py = forward(xc, yc, w, h, phi)
    sigma = numpyro.sample('sigma', dist.Exponential(1.))
    predicted = np.vstack([px, py])
    return numpyro.sample('obs', dist.Normal(predicted, sigma), obs=data)
def comp_Tarr(okey):
    """Sample a profile ``Tarr`` = MultivariateNormal(ONEARR, cov) + T0.

    Every site draws with an explicit subkey split off the running key
    ``okey``. ``cov`` comes from ``modelcov(lnParr, 10**lntaup, 10**lnsT)``.
    ``lnParr``, ``ONEARR`` and ``modelcov`` are module-level names.
    """

    def _draw(name, prior):
        # Advance the running key and sample `name` with the fresh subkey.
        nonlocal okey
        okey, subkey = random.split(okey)
        return numpyro.sample(name, prior, rng_key=subkey)

    sT = 10**_draw('lnsT', dist.Uniform(3.0, 5.0))
    taup = 10**_draw('lntaup', dist.Uniform(0, 1))
    cov = modelcov(lnParr, taup, sT)
    T0 = _draw('T0', dist.Uniform(800, 1500))
    offsets = _draw('Tarr', dist.MultivariateNormal(loc=ONEARR, covariance_matrix=cov))
    return offsets + T0
def fully_pooled(at_bats, hits=None):
    r"""
    Full pooling: every player shares one success probability $\phi$, and the
    number of hits in $K$ at bats is Binomial($K$, $\phi$).

    $\phi$ has a Uniform(0, 1) prior built from length-1 arrays, so it has
    shape ``(1,)`` and broadcasts across players.

    :param (np.DeviceArray) at_bats: Number of at bats for each player.
    :param (np.DeviceArray) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    lower, upper = np.array([0.]), np.array([1.])
    phi = sample("phi", dist.Uniform(lower, upper))
    likelihood = dist.Binomial(at_bats, probs=phi)
    return sample("obs", likelihood, obs=hits)
def sample_data2():
    """Simulate 100 individuals whose left and right leg lengths are nearly collinear.

    Height ~ Normal(10, 2). Both legs are the same height proportion
    (Uniform(0.4, 0.5)) plus independent Normal(0, 0.02) noise. Fixed PRNG
    keys 0-3 make the result deterministic.

    :return: DataFrame with columns ``height``, ``leg_left``, ``leg_right``.
    """
    N = 100
    shape = (N,)
    height = dist.Normal(10, 2).sample(random.PRNGKey(0), shape)
    leg_prop = dist.Uniform(0.4, 0.5).sample(random.PRNGKey(1), shape)
    left_noise, right_noise = (
        dist.Normal(0, 0.02).sample(random.PRNGKey(k), shape) for k in (2, 3)
    )
    leg_left = leg_prop * height + left_noise
    leg_right = leg_prop * height + right_noise
    return pd.DataFrame({"height": height, "leg_left": leg_left, "leg_right": leg_right})
def model_4(capture_history, sex): N, T = capture_history.shape # survival probabilities for males/females phi_male = numpyro.sample("phi_male", dist.Uniform(0.0, 1.0)) phi_female = numpyro.sample("phi_female", dist.Uniform(0.0, 1.0)) # we construct a N-dimensional vector that contains the appropriate # phi for each individual given its sex (female = 0, male = 1) phi = sex * phi_male + (1.0 - sex) * phi_female rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0)) # recapture probability def transition_fn(carry, y): first_capture_mask, z = carry with numpyro.plate("animals", N, dim=-1): with handlers.mask(mask=first_capture_mask): mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask) # NumPyro exactly sums out the discrete states z_t. z = numpyro.sample( "z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)), infer={"enumerate": "parallel"}, ) mu_y_t = rho * z numpyro.sample( "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y ) first_capture_mask = first_capture_mask | y.astype(bool) return (first_capture_mask, z), None z = jnp.ones(N, dtype=jnp.int32) # we use this mask to eliminate extraneous log probabilities # that arise for a given individual before its first capture. first_capture_mask = capture_history[:, 0].astype(bool) # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it scan( transition_fn, (first_capture_mask, z), jnp.swapaxes(capture_history[:, 1:], 0, 1), )
def sample(self):
    """Run one Metropolis-Hastings sweep over every site in the current params.

    For each site in turn:
    1. record its current value(s) in ``self.acc_trace``;
    2. build a proposal with ``self.proposal`` and retrace the model with it;
    3. accept or reject using a uniform draw at site ``"u"``.

    :return: ``NMC_STATUS(iteration, params, log_likelihood, mean_accept_prob,
        rng_key)`` for the next call.
    """
    rng_key = self.nmc_status.rng_key
    params = self.nmc_status.params
    # Carry the current log density across sites. Previously it was reset to
    # the start-of-sweep value for every site, so after an accepted update the
    # next site's Hastings ratio compared against a stale density.
    ll_current = self.nmc_status.log_likelihood
    # With no sites the running mean stays unchanged (previously: unbound name).
    accept_prob = self.nmc_status.accept_prob

    def log_den_fun(var):
        # Log joint density of the model at the parameter dict `var`.
        return log_density(self.model, self.model_args, self.model_kwargs, var)[0]

    # Snapshot the site names, since `params` is rebound inside the loop.
    for site in list(params):
        # Fresh keys per site. Previously one pair was reused for every site,
        # so all sites shared the same proposal key and the same `u`.
        rng_key, rng_key_sample, rng_key_accept = split(rng_key, 3)

        # Collect accepted trace
        for i, value in enumerate(params[site]):
            self.acc_trace[site + str(i)].append(value)

        tr_current = trace(substitute(self.model, params)).get_trace(
            *self.model_args, **self.model_kwargs)
        val_current = tr_current[site]["value"]
        dist_curr = tr_current[site]["fn"]

        val_proposal, dist_proposal = self.proposal(
            site, log_den_fun, self.get_params(tr_current), dist_curr, rng_key_sample)
        tr_proposal = self.retrace(site, tr_current, dist_proposal, val_proposal,
                                   self.model_args, self.model_kwargs)
        ll_proposal = log_density(self.model, self.model_args, self.model_kwargs,
                                  self.get_params(tr_proposal))[0]
        # NOTE(review): a textbook MH correction uses q(current | proposed) and
        # q(proposed | current). Here the first term evaluates `dist_proposal`
        # (built from the current state) at the current value, and the second
        # uses the site's model distribution. Confirm against the NMC reference.
        ll_proposal_val = dist_proposal.log_prob(val_current).sum()
        ll_current_val = dist_curr.log_prob(val_proposal).sum()
        hastings_ratio = (ll_proposal + ll_proposal_val) - \
            (ll_current + ll_current_val)
        accept_prob = np.minimum(1, np.exp(hastings_ratio))
        u = sample("u", dist.Uniform(0, 1), rng_key=rng_key_accept)
        if u <= accept_prob:
            params, ll_current = self.get_params(tr_proposal), ll_proposal
        else:
            params = self.get_params(tr_current)

    n_iter = self.nmc_status.i + 1
    # NOTE: only the last site's acceptance probability enters the running mean.
    mean_accept_prob = self.nmc_status.accept_prob + \
        (accept_prob - self.nmc_status.accept_prob) / n_iter
    return NMC_STATUS(n_iter, params, ll_current, mean_accept_prob, rng_key)