def __exit__(self, *args, **kwargs):
    """Exit the collapse context: pop the funsor coercion pushed on entry,
    then convert the delayed sample statements recorded in ``self.trace``
    into a single ``numpyro.factor`` whose value is the sum-product
    elimination of their joint log-density.
    """
    # Deferred import: funsor is an optional dependency of this handler.
    import funsor
    _coerce = COERCIONS.pop()
    # The coercion stack must be balanced with what __enter__ pushed.
    assert _coerce is self._coerce
    super().__exit__(*args, **kwargs)

    # Convert delayed statements to pyro.factor()
    reduced_vars = []       # names of latent (non-observed) sites to eliminate
    log_prob_terms = []     # one funsor log-density term per traced site
    plates = frozenset()    # union of plate (cond_indep_stack) names seen
    for name, site in self.trace.items():
        if not site["is_observed"]:
            reduced_vars.append(name)
        dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]}
        fn = funsor.to_funsor(site["fn"], funsor.Real, dim_to_name)
        value = site["value"]
        # A str value is presumably a delayed/symbolic name — only convert
        # concrete values to funsors (TODO confirm against the tracer).
        if not isinstance(value, str):
            value = funsor.to_funsor(site["value"], fn.inputs["value"], dim_to_name)
        log_prob_terms.append(fn(value=value))
        plates |= frozenset(f.name for f in site["cond_indep_stack"])
    assert log_prob_terms, "nothing to collapse"
    # Only plates the caller did not ask to preserve are summed out.
    reduced_plates = plates - self.preserved_plates
    log_prob = funsor.sum_product.sum_product(
        funsor.ops.logaddexp,
        funsor.ops.add,
        log_prob_terms,
        eliminate=frozenset(reduced_vars) | reduced_plates,
        plates=plates,
    )
    # Reuse the first collapsed site's name for the resulting factor site.
    name = reduced_vars[0]
    numpyro.factor(name, log_prob.data)
def guide():
    """Toy guide: draws ``x``, records which branches ran in ``called``, and
    registers the factor ``g`` only when computation is not masked out."""
    with numpyro.handlers.seed(rng_seed=1):
        latent = numpyro.sample("x", dist.Normal(0, 1))
        called.add("guide-always")
        mask_active = numpyro.get_mask() is not False
        if mask_active:
            called.add("guide-sometimes")
            numpyro.factor("g", 2 - latent)
def _unconstrain_reparam(params, site):
    """Substitute the unconstrained parameter for ``site`` and return the
    constrained value, registering the transform's log-abs-det Jacobian as a
    ``numpyro.factor`` so the joint density stays correct.

    Returns ``None`` (implicitly) when ``site['name']`` is not in ``params``.
    """
    name = site["name"]
    if name in params:
        p = params[name]
        support = site["fn"].support
        t = biject_to(support)
        # in scan, we might only want to substitute an item at index i,
        # rather than the whole sequence
        i = site["infer"].get("_scan_current_index", None)
        if i is not None:
            # Event-dim change introduced by the constraining transform.
            event_dim_shift = t.codomain.event_dim - t.domain.event_dim
            expected_unconstrained_dim = len(
                site["fn"].shape()) - event_dim_shift
            # check if p has additional time dimension
            if jnp.ndim(p) > expected_unconstrained_dim:
                p = p[i]

        # Identity transform for (vector of) reals: no Jacobian correction.
        if support in [constraints.real, constraints.real_vector]:
            return p

        value = t(p)
        log_det = t.log_abs_det_jacobian(p, value)
        # Sum the Jacobian over event dims so it matches the batch shape.
        log_det = sum_rightmost(
            log_det,
            jnp.ndim(log_det) - jnp.ndim(value) + len(site["fn"].event_shape))
        if site["scale"] is not None:
            log_det = site["scale"] * log_det
        numpyro.factor("_{}_log_det".format(name), log_det)
        return value
def __call__(self, name, fn, obs):
    """NeuTra reparameterizer: replace guide-covered sites with slices of one
    shared latent pushed through the neural transport map.

    Returns ``(None, value)`` for reparameterized sites, or ``(fn, obs)``
    unchanged for sites the guide's prototype trace does not cover.
    """
    if name not in self.guide.prototype_trace:
        return fn, obs
    assert obs is None, "NeuTraReparam does not support observe statements"
    log_density = 0.
    if not self._x_unconstrained:  # On first sample site.
        # Sample a shared latent. mask(False) keeps its own log-density out
        # of the model; the density is accounted for via the factor below.
        z_unconstrained = numpyro.sample(
            "{}_shared_latent".format(self.guide.prefix),
            self.guide.get_base_dist().mask(False))

        # Differentiably transform.
        x_unconstrained = self.transform(z_unconstrained)
        # TODO: find a way to only compute those log_prob terms when needed
        log_density = self.transform.log_abs_det_jacobian(
            z_unconstrained, x_unconstrained)
        # Unpack into a per-site dict; later calls pop from this cache.
        self._x_unconstrained = self.guide._unpack_latent(x_unconstrained)

    # Extract a single site's value from the shared latent.
    unconstrained_value = self._x_unconstrained.pop(name)
    transform = biject_to(fn.support)
    value = transform(unconstrained_value)
    logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
    # Collapse the Jacobian over the site's event dimensions.
    logdet = sum_rightmost(
        logdet, jnp.ndim(logdet) - jnp.ndim(value) + len(fn.event_shape))
    log_density = log_density + fn.log_prob(value) + logdet
    numpyro.factor("_{}_log_prob".format(name), log_density)
    return None, value
def semi_supervised_hmm(transition_prior, emission_prior, supervised_categories,
                        supervised_words, unsupervised_words):
    """Semi-supervised hidden Markov model.

    Supervised (category, word) pairs are scored directly as Categorical
    observations; the unsupervised word sequence is scored by marginalizing
    the hidden categories with the forward algorithm and injecting the
    resulting log-probability as a factor.

    Args:
        transition_prior: Dirichlet concentration for each transition row;
            its leading dim defines the number of categories.
        emission_prior: Dirichlet concentration for each emission row; its
            leading dim defines the vocabulary size.
        supervised_categories: observed category sequence (aligned with
            ``supervised_words``).
        supervised_words: observed word sequence for the supervised data.
        unsupervised_words: word sequence with unobserved categories.
    """
    num_categories, num_words = transition_prior.shape[
        0], emission_prior.shape[0]
    transition_prob = numpyro.sample(
        'transition_prob',
        dist.Dirichlet(
            np.broadcast_to(transition_prior,
                            (num_categories, num_categories))))
    emission_prob = numpyro.sample(
        'emission_prob',
        dist.Dirichlet(
            np.broadcast_to(emission_prior, (num_categories, num_words))))

    # models supervised data;
    # here we don't make any assumption about the first supervised category,
    # in other words, we place a flat/uniform prior on it.
    numpyro.sample('supervised_categories',
                   dist.Categorical(
                       transition_prob[supervised_categories[:-1]]),
                   obs=supervised_categories[1:])
    numpyro.sample('supervised_words',
                   dist.Categorical(emission_prob[supervised_categories]),
                   obs=supervised_words)

    # computes log prob of unsupervised data via the forward algorithm
    transition_log_prob = np.log(transition_prob)
    emission_log_prob = np.log(emission_prob)
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(init_log_prob, unsupervised_words[1:],
                                transition_log_prob, emission_log_prob)
    # marginalize over the final hidden state
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    # inject log_prob to potential function
    numpyro.factor('forward_log_prob', log_prob)
def model():
    """Toy model: draws ``x``, observes ``y`` at 0, records which branches ran
    in ``called``, and adds the factor ``f`` only when not masked out."""
    with numpyro.handlers.seed(rng_seed=0):
        latent = numpyro.sample("x", dist.Normal(0, 1))
        numpyro.sample("y", dist.Normal(latent, 1), obs=0.)
        called.add("model-always")
        mask_active = numpyro.get_mask() is not False
        if mask_active:
            called.add("model-sometimes")
            numpyro.factor("f", latent + 1)
def _unconstrain_reparam(params, site):
    """Map an unconstrained parameter back to the site's support and register
    the transform's log-abs-det Jacobian as a factor so densities stay correct.

    Returns the constrained value, or ``None`` when the site has no entry in
    ``params``.
    """
    name = site['name']
    if name not in params:
        return None  # nothing to substitute for this site
    unconstrained = params[name]
    transform = biject_to(site['fn'].support)
    constrained = transform(unconstrained)
    jac = transform.log_abs_det_jacobian(unconstrained, constrained)
    # Collapse the Jacobian over event dims so it matches the batch shape.
    extra_dims = jnp.ndim(jac) - jnp.ndim(constrained) + len(site['fn'].event_shape)
    jac = sum_rightmost(jac, extra_dims)
    scale = site['scale']
    if scale is not None:
        jac = scale * jac
    numpyro.factor('_{}_log_det'.format(name), jac)
    return constrained
def __exit__(self, exc_type, exc_value, traceback):
    """Exit the handler: emit the bias-corrected log-likelihood factor (unless
    masked out) and reset the per-run state."""
    # Let the parent messenger unwind first so tracebacks stay clean on error.
    super().__exit__(exc_type, exc_value, traceback)
    if exc_type is not None or self.params is None:
        return
    if numpyro.get_mask() is not False:
        correction = self.method(self.likelihoods, self.params, self.gibbs_state)
        numpyro.factor("_biased_corrected_log_likelihood", correction)
    # Reset accumulated state for the next invocation.
    self.params = None
    self.likelihoods = {}
    self.subsample_plates = {}
    self.gibbs_state = None
def __exit__(self, exc_type, exc_value, traceback):
    """Exit the handler: add the bias-corrected log-likelihood factor and
    clean up accumulated state.

    The parent ``__exit__`` runs first so that tracebacks stay readable when
    an exception is propagating; in that case (or when no params were
    collected) nothing else is done.
    """
    # make sure exit trackback is nice if an error happens
    super().__exit__(exc_type, exc_value, traceback)
    if exc_type is not None:
        return
    if self.params is None:
        return
    # Skip the correction factor when computation is masked out (e.g. when
    # making predictions); previously it was computed unconditionally.
    # see: https://github.com/pyro-ppl/pyro/issues/2744
    if numpyro.get_mask() is not False:
        numpyro.factor(
            "_biased_corrected_log_likelihood",
            self.method(self.likelihoods, self.params, self.gibbs_state))
    # clean up
    self.params = None
    self.likelihoods = {}
    self.subsample_plates = {}
    self.gibbs_state = None
def semi_supervised_hmm(
    num_categories: int,
    num_words: int,
    supervised_categories: jnp.ndarray,
    supervised_words: jnp.ndarray,
    unsupervised_words: jnp.ndarray,
) -> None:
    """Semi-supervised HMM: Categorical observations for the supervised data,
    plus a forward-algorithm factor marginalizing the unsupervised chain.

    Args:
        num_categories: number of hidden categories.
        num_words: vocabulary size.
        supervised_categories: observed category sequence.
        supervised_words: observed word sequence (aligned with categories).
        unsupervised_words: word sequence whose categories are latent.
    """
    # Flat Dirichlet priors over each row of the transition/emission matrices.
    trans_rows = numpyro.sample(
        "transition_prob",
        dist.Dirichlet(
            jnp.broadcast_to(jnp.ones(num_categories),
                             (num_categories, num_categories))),
    )
    emit_rows = numpyro.sample(
        "emission_prob",
        dist.Dirichlet(
            jnp.broadcast_to(jnp.repeat(0.1, num_words),
                             (num_categories, num_words))),
    )
    # Supervised part: each category conditions on its predecessor, each word
    # on its category.
    numpyro.sample(
        "supervised_categories",
        dist.Categorical(trans_rows[supervised_categories[:-1]]),
        obs=supervised_categories[1:],
    )
    numpyro.sample(
        "supervised_words",
        dist.Categorical(emit_rows[supervised_categories]),
        obs=supervised_words,
    )
    # Unsupervised part: marginalize the hidden categories with the forward
    # algorithm and inject the marginal log-probability as a factor.
    log_trans = jnp.log(trans_rows)
    log_emit = jnp.log(emit_rows)
    alpha0 = log_emit[:, unsupervised_words[0]]
    alpha = forward_log_prob(alpha0, unsupervised_words[1:], log_trans, log_emit)
    marginal = logsumexp(alpha, axis=0, keepdims=True)
    numpyro.factor("forward_log_prob", marginal)
def __call__(self, name, fn, obs):
    """Reparameterize a circular-support site: sample unconstrained noise,
    wrap it into [-pi, pi), and add the original log-density as a factor."""
    base = fn.support
    if isinstance(base, constraints.independent):
        base = fn.support.base_constraint
    # Support must be circular; anything else is a usage error.
    assert base is constraints.circular
    # Parameter-free improper prior over the reals with matching shapes.
    unconstrained_dist = dist.ImproperUniform(
        constraints.real, fn.batch_shape, fn.event_shape)
    raw = numpyro.sample(f"{name}_unwrapped", unconstrained_dist, obs=obs)
    # Differentiably wrap onto the circle.
    wrapped = jnp.remainder(raw + math.pi, 2 * math.pi) - math.pi
    # Simulate a pyro.deterministic() site by scoring under the original fn.
    numpyro.factor(f"{name}_factor", fn.log_prob(wrapped))
    return None, wrapped
def to_numpyro(self, y=None):
    """Build the sparse-GP numpyro model: add the VFE trace-term factor, then
    sample ``y`` from a low-rank Gaussian likelihood.

    NOTE(review): ``y`` only toggles whether the site is observed; the
    observed values come from ``self.y`` — confirm that is intended.
    """
    loc = self.mean(self.X)
    _, W, D = vfe_precompute(
        self.X, self.X_u, self.obs_noise, self.kernel, jitter=self.jitter
    )
    numpyro.factor("trace_term", self.trace_term(self.X, W, self.obs_noise))
    likelihood = dist.LowRankMultivariateNormal(loc=loc, cov_factor=W, cov_diag=D)
    if y is None:
        # Unconditioned draw from the likelihood.
        return numpyro.sample("y", likelihood)
    # Observed: broadcast over the leading batch dims of self.y.
    return numpyro.sample(
        "y",
        likelihood.expand_by(self.y.shape[:-1]).to_event(self.y.ndim - 1),
        obs=self.y,
    )
def model():
    """Fully-masked model: the -inf factor is inert because every element of
    the mask is False."""
    all_masked = jnp.zeros(10, dtype=bool)
    with handlers.mask(mask=all_masked):
        numpyro.factor('inf', -jnp.inf)
def _amplitude_prior(self, he_amp, cz_amp):
    """Soft prior tying the He amplitude to ~4x the BCZ amplitude.

    Centered at log10(he/cz) = 0.6 (a factor of ~4, since log10(4) ~ 0.6) and
    scaled so that equal amplitudes sit two sigma from the mode.
    """
    log_ratio = jnp.log10(he_amp) - jnp.log10(cz_amp)
    z_score = 2.0 * (log_ratio - 0.6) / 0.6
    numpyro.factor("amp", numpyro.distributions.Normal().log_prob(z_score))
def __call__(
    self,
    n: ArrayLike,
    nu: Optional[ArrayLike] = None,
    nu_err: Optional[ArrayLike] = None,
    n_pred: Optional[ArrayLike] = None,
):
    """Sample the two-model (null background vs. background + glitches) GP
    comparison for given observables.

    Args:
        n (:term:`array_like`): Radial orders at which nu is given.
        nu (:term:`array_like`, optional): Observed radial mode frequencies.
        nu_err (:term:`array_like`, optional): Gaussian observational
            uncertainties (sigma) for nu.
        n_pred (:term:`array_like`, optional): If given, make conditioned
            predictions ``nu`` and ``nu_pred`` for both models.
    """
    # Same kernel function for both models
    var = numpyro.param("kernel_var", self._kernel_var)
    length = numpyro.param("kernel_length", self._kernel_length)
    kernel = var * kernels.ExpSquared(length)

    diag = 1e-6 if nu_err is None else nu_err**2  # No need for jitter

    args = ("models", 2)
    with dimension(*args):
        with numpyro.plate(*args):
            # Broadcast background function to both models
            bkg_func = self.background()

    # MODEL 0 — background only.
    with numpyro.handlers.scope(prefix=self._prefix, divider="."):
        # Contain null model parameters in the null scope

        def mean0(n):
            # Element [0] of the broadcast background is the null model's.
            return jnp.squeeze(bkg_func(n)[0])

        # gp0 = GP(kernel, mean=mean0)
        # dist0 = gp0.distribution(n, noise=nu_err)
        gp0 = GaussianProcess(kernel, n, mean=mean0, diag=diag)
        dist0 = gp0.numpyro_dist()
        with dimension("n", n.shape[-1], coords=n):
            nu0 = numpyro.sample("nu_obs", dist0, obs=nu)
            numpyro.deterministic("nu_bkg", bkg_func(n)[0])

        if n_pred is not None:
            with dimension("n", n.shape[-1], coords=n):
                # gp0.predict("nu", n)
                numpyro.sample(
                    "nu",
                    gp0.condition(nu0, n).gp.numpyro_dist(),
                )
            with dimension("n_pred", n_pred.shape[-1], coords=n_pred):
                # gp0.predict("nu_pred", n_pred)
                numpyro.sample(
                    "nu_pred",
                    gp0.condition(nu0, n_pred).gp.numpyro_dist())
                numpyro.deterministic("nu_bkg_pred", bkg_func(n_pred)[0])

    # MODEL 1 — background plus He and BCZ glitch components.
    he_glitch_func = self.he_glitch()
    cz_glitch_func = self.cz_glitch()

    def mean(n):
        # Element [1] of the broadcast background is this model's.
        nu_bkg = jnp.squeeze(bkg_func(n)[1])
        return nu_bkg + he_glitch_func(nu_bkg) + cz_glitch_func(nu_bkg)

    # gp = GP(kernel, mean=mean)
    # dist = gp.distribution(n, noise=nu_err)
    gp = GaussianProcess(kernel, n, mean=mean, diag=diag)
    dist = gp.numpyro_dist()
    with dimension("n", n.shape[-1], coords=n):
        nu = numpyro.sample("nu_obs", dist, obs=nu)  # redefines nu!
        nu_bkg = numpyro.deterministic("nu_bkg", bkg_func(n)[1])
        numpyro.deterministic("dnu_he", he_glitch_func(nu_bkg))
        numpyro.deterministic("dnu_cz", cz_glitch_func(nu_bkg))

    if n_pred is not None:
        with dimension("n", n.shape[-1], coords=n):
            # gp.predict("nu", n)
            numpyro.sample("nu", gp.condition(nu, n).gp.numpyro_dist())
        with dimension("n_pred", n_pred.shape[-1], coords=n_pred):
            # gp.predict("nu_pred", n_pred)
            numpyro.sample("nu_pred",
                           gp.condition(nu, n_pred).gp.numpyro_dist())
            nu_bkg = numpyro.deterministic("nu_bkg_pred", bkg_func(n_pred)[1])
            numpyro.deterministic("dnu_he_pred", he_glitch_func(nu_bkg))
            numpyro.deterministic("dnu_cz_pred", cz_glitch_func(nu_bkg))

    # Other deterministics and priors
    self._amplitude_prior(*self._glitch_amplitudes(nu))

    # LIKELIHOOD
    # Model comparison - if nu is not None, then nu0 == nu
    # NOTE(review): both "nu_obs" sites are sampled with obs=nu AND the same
    # log-probs are added again via this factor — confirm the double
    # contribution is intended (e.g. sampling may be masked elsewhere).
    logL0 = dist0.log_prob(nu0)
    logL = dist.log_prob(nu)
    numpyro.factor("obs", (logL0 + logL).sum())
    # Log10 Bayes factor
    numpyro.deterministic("log_k", (logL - logL0).sum() / np.log(10.0))
def SparseGP(X, y):
    """Sparse (VFE) Gaussian process regression model.

    Samples kernel hyperparameters, builds the low-rank approximation
    Qff = W @ W.T from inducing points, adds the VFE trace-term factor, and
    observes ``y`` under a low-rank multivariate normal likelihood.

    Relies on module-level names not visible here: ``X_u_init``, ``jitter``,
    ``rbf_kernel``, ``add_to_diagonal``, ``cholesky``, ``solve_triangular``.
    """
    n_samples = X.shape[0]
    X = numpyro.deterministic("X", X)

    # Set priors on kernel hyperparameters.
    η = numpyro.sample("variance", dist.HalfCauchy(scale=5.0))
    ℓ = numpyro.sample("length_scale", dist.Gamma(2.0, 1.0))
    σ = numpyro.sample("obs_noise", dist.HalfCauchy(scale=5.0))
    x_u = numpyro.param("x_u", init_value=X_u_init)
    # η = numpyro.param("kernel_var", init_value=1.0, constraints=dist.constraints.positive)
    # ℓ = numpyro.param("kernel_length", init_value=0.1, constraints=dist.constraints.positive)
    # σ = numpyro.param("sigma", init_value=0.1, constraints=dist.constraints.positive)

    # ================================
    # Mean Function
    # ================================
    f_loc = np.zeros(n_samples)

    # ================================
    # Qff Term
    # ================================
    # W = (inv(Luu) @ Kuf).T
    # Qff = Kfu @ inv(Kuu) @ Kuf
    # Qff = W @ W.T
    # ================================
    Kuu = rbf_kernel(x_u, x_u, η, ℓ)
    Kuf = rbf_kernel(x_u, X, η, ℓ)
    # Kuu += jnp.eye(Ninducing) * jitter  # add jitter
    Kuu = add_to_diagonal(Kuu, jitter)
    # cholesky factorization
    Luu = cholesky(Kuu, lower=True)
    Luu = numpyro.deterministic("Luu", Luu)
    # W matrix
    W = solve_triangular(Luu, Kuf, lower=True)
    W = numpyro.deterministic("W", W).T

    # ================================
    # Likelihood Noise Term
    # ================================
    # D = noise
    # ================================
    # NOTE(review): deterministic site is named "G" but plays the role of the
    # noise diagonal D — confirm the site name is intended.
    D = numpyro.deterministic("G", jnp.ones(n_samples) * σ)

    # ================================
    # trace term
    # ================================
    # t = tr(Kff - Qff) / noise
    # t /= - 2.0
    # ================================
    Kffdiag = jnp.diag(rbf_kernel(X, X, η, ℓ))
    Qffdiag = jnp.power(W, 2).sum(axis=1)
    trace_term = (Kffdiag - Qffdiag).sum() / σ
    # NOTE(review): jnp.clip's ``a_min`` keyword is removed in recent JAX in
    # favor of ``min`` — confirm the pinned jax version supports it.
    trace_term = jnp.clip(trace_term, a_min=0.0)  # numerical errors

    # add trace term to the log probability loss
    numpyro.factor("trace_term", -trace_term / 2.0)

    # Sample y according to the sparse GP likelihood.
    return numpyro.sample(
        "y",
        dist.LowRankMultivariateNormal(loc=f_loc, cov_factor=W,
                                       cov_diag=D).expand_by(
                                           y.shape[:-1]).to_event(y.ndim - 1),
        obs=y,
    )
def model():
    """Minimal model: one latent site plus a constant factor and a
    latent-valued factor."""
    latent = numpyro.sample("a", dist.Normal(0, 1))
    numpyro.factor("b", 0.0)
    numpyro.factor("c", latent)