def main(_args):
    data = generate_data()
    init_rng_key = PRNGKey(1273)
    # nuts = NUTS(gmm)
    # mcmc = MCMC(nuts, 100, 1000)
    # mcmc.print_summary()
    seeded_gmm = seed(gmm, init_rng_key)
    model_trace = trace(seeded_gmm).get_trace(data)
    max_plate_nesting = _guess_max_plate_nesting(model_trace)
    enum_gmm = enum(config_enumerate(gmm), -max_plate_nesting - 1)
    svi = SVI(enum_gmm, gmm_guide, Adam(0.1), RenyiELBO(-10.0))
    svi_state = svi.init(init_rng_key, data)
    upd_fun = jax.jit(svi.update)

    with tqdm.trange(100_000) as pbar:
        for i in pbar:
            svi_state, loss = upd_fun(svi_state, data)
            pbar.set_description(f"SVI {loss}", True)
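
# A hedged follow-up sketch (not part of the example above): after the SVI loop
# finishes, the fitted guide parameters can be read out of the final state and
# used to draw approximate posterior samples. `sample_guide_posterior` is a
# hypothetical helper name; `svi.get_params` and `Predictive` are standard
# numpyro APIs.
from numpyro.infer import Predictive


def sample_guide_posterior(svi, svi_state, guide, data, rng_key, num_samples=1000):
    params = svi.get_params(svi_state)  # optimizer state -> constrained param values
    predictive = Predictive(guide, params=params, num_samples=num_samples)
    return predictive(rng_key, data)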
def run(self, rng_key, *args, **kwargs):
    """
    Run the nested samplers and collect weighted samples.

    :param random.PRNGKey rng_key: Random number generator key to be used for the sampling.
    :param args: The arguments needed by the `model`.
    :param kwargs: The keyword arguments needed by the `model`.
    """
    rng_sampling, rng_predictive = random.split(rng_key)
    # reparam the model so that latent sites have Uniform(0, 1) priors
    prototype_trace = trace(seed(self.model, rng_key)).get_trace(*args, **kwargs)
    param_names = [
        site["name"]
        for site in prototype_trace.values()
        if site["type"] == "sample"
        and not site["is_observed"]
        and site["infer"].get("enumerate", "") != "parallel"
    ]
    deterministics = [
        site["name"]
        for site in prototype_trace.values()
        if site["type"] == "deterministic"
    ]
    reparam_model = reparam(
        self.model, config={k: UniformReparam() for k in param_names})

    # enable enumerate if needed
    has_enum = any(
        site["type"] == "sample"
        and site["infer"].get("enumerate", "") == "parallel"
        for site in prototype_trace.values()
    )
    if has_enum:
        from numpyro.contrib.funsor import enum, log_density as log_density_

        max_plate_nesting = _guess_max_plate_nesting(prototype_trace)
        _validate_model(prototype_trace)
        reparam_model = enum(reparam_model, -max_plate_nesting - 1)
    else:
        log_density_ = log_density

    def loglik_fn(**params):
        return log_density_(reparam_model, args, kwargs, params)[0]

    # use NestedSampler with identity prior chain
    prior_chain = PriorChain()
    for name in param_names:
        prior = UniformPrior(name + "_base", prototype_trace[name]["fn"].shape())
        prior_chain.push(prior)

    # XXX: the `marginalised` keyword in jaxns can be used to get the expectation of some
    # quantity over posterior samples; it could be helpful to expose it in this wrapper
    ns = OrigNestedSampler(
        loglik_fn,
        prior_chain,
        sampler_name=self.sampler_name,
        sampler_kwargs={"depth": self.depth, "num_slices": self.num_slices},
        max_samples=self.max_samples,
        num_live_points=self.num_live_points,
        collect_samples=True,
    )
    # some places in jaxns use float64 and raise warnings if the default dtype is
    # float32, so we suppress them here to avoid confusion
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message=".*will be truncated to dtype float32.*")
        results = ns(rng_sampling, termination_frac=self.termination_frac)

    # transform base samples back to the original domains;
    # here we only transform the first `num_samples` valid samples
    # NB: the number of weighted samples obtained from jaxns is results.num_samples
    # and only the first num_samples values of results.samples are valid.
    num_samples = results.num_samples
    samples = tree_util.tree_map(lambda x: x[:num_samples], results.samples)
    predictive = Predictive(
        reparam_model, samples, return_sites=param_names + deterministics)
    samples = predictive(rng_predictive, *args, **kwargs)
    # replace base samples in jaxns results by transformed samples
    self._results = results._replace(samples=samples)
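
# A minimal usage sketch for the wrapper above, under the assumption that the
# surrounding class is numpyro's `NestedSampler` from
# `numpyro.contrib.nested_sampling` (which requires jaxns). `get_samples`
# resamples the weighted nested-sampling output into equally weighted
# posterior samples. The toy model and data are illustrative only.
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.nested_sampling import NestedSampler


def toy_model(data):
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=data)


ns = NestedSampler(toy_model)
ns.run(random.PRNGKey(0), jnp.array([1.2, 0.8, 1.1]))
posterior_samples = ns.get_samples(random.PRNGKey(1), num_samples=1000)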
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs):
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with block(), enum(first_available_dim=first_available_dim):
        with plate_to_enum_plate():
            model_tr = packed_trace(model).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            list(terms["log_factors"].values()) + list(terms["log_measures"].values()),
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(
                node["value"], output, dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[terms["log_measures"][name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    with replay(guide_trace=sample_tr):
        return model(*args, **kwargs)
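
# `_sample_posterior` backs the public `infer_discrete` handler. A small hedged
# sketch of the typical entry point, assuming a toy mixture model: enumerated
# discrete sites are replaced by posterior samples (temperature=1) or MAP
# values (temperature=0) conditioned on the observed sites.
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate, infer_discrete


@config_enumerate
def mixture(data):
    c = numpyro.sample("c", dist.Categorical(jnp.array([0.3, 0.7])))
    numpyro.sample("x", dist.Normal(jnp.array([-1.0, 1.0])[c], 0.5), obs=data)


sampled_mixture = infer_discrete(mixture, temperature=1, rng_key=random.PRNGKey(0))
sampled_mixture(jnp.array(0.9))  # runs `mixture` with "c" drawn from its conditional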
def discrete_gibbs_fn(model, model_args=(), model_kwargs={}, *, random_walk=False, modified=False):
    """
    [EXPERIMENTAL INTERFACE] Returns a gibbs_fn to be used in :class:`HMCGibbs`, which works
    for discrete latent sites with enumerate support. The site update order is randomly
    permuted at each step.

    Note that those discrete latent sites that are not specified in the constructor of
    :class:`HMCGibbs` will be marginalized out by default (if they have enumerate supports).

    :param callable model: a callable with NumPyro primitives. This should be the same model
        as the one used in the `inner_kernel` of :class:`HMCGibbs`.
    :param tuple model_args: Arguments provided to the model.
    :param dict model_kwargs: Keyword arguments provided to the model.
    :param bool random_walk: If False, Gibbs sampling will be used to draw a sample from the
        conditional `p(gibbs_site | remaining sites)`. Otherwise, a sample will be drawn
        uniformly from the domain of `gibbs_site`.
    :param bool modified: whether to use a modified proposal, as suggested in reference [1],
        which always proposes a new state for the current Gibbs site. The modified scheme
        appears in the literature under the name "modified Gibbs sampler" or "Metropolised
        Gibbs sampler".
    :return: a callable `gibbs_fn` to be used in :class:`HMCGibbs`

    **References:**

    1. *Peskun's theorem and a modified discrete-state Gibbs sampler*,
       Liu, J. S. (1996)

    **Example**

    .. doctest::

        >>> from jax import random
        >>> import jax.numpy as jnp
        >>> import numpyro
        >>> import numpyro.distributions as dist
        >>> from numpyro.infer import MCMC, NUTS, HMCGibbs, discrete_gibbs_fn
        ...
        >>> def model(probs, locs):
        ...     c = numpyro.sample("c", dist.Categorical(probs))
        ...     numpyro.sample("x", dist.Normal(locs[c], 0.5))
        ...
        >>> probs = jnp.array([0.15, 0.3, 0.3, 0.25])
        >>> locs = jnp.array([-2, 0, 2, 4])
        >>> gibbs_fn = discrete_gibbs_fn(model, (probs, locs))
        >>> kernel = HMCGibbs(NUTS(model), gibbs_fn, gibbs_sites=["c"])
        >>> mcmc = MCMC(kernel, 1000, 100000, progress_bar=False)
        >>> mcmc.run(random.PRNGKey(0), probs, locs)
        >>> mcmc.print_summary()  # doctest: +SKIP
    """
    # NB: all of the information such as `model`, `model_args`, `model_kwargs`
    # can be accessed from HMCGibbs.sample, but we require them here to
    # simplify the api of `gibbs_fn`
    prototype_trace = trace(seed(model, rng_seed=0)).get_trace(*model_args, **model_kwargs)
    support_sizes = {
        name: jnp.broadcast_to(
            site["fn"].enumerate_support(False).shape[0], jnp.shape(site["value"]))
        for name, site in prototype_trace.items()
        if site["type"] == "sample"
        and site["fn"].has_enumerate_support
        and not site["is_observed"]
    }
    max_plate_nesting = _guess_max_plate_nesting(prototype_trace)
    if random_walk:
        if modified:
            proposal_fn = partial(_discrete_modified_rw_proposal, stay_prob=0.0)
        else:
            proposal_fn = _discrete_rw_proposal
    else:
        if modified:
            proposal_fn = partial(_discrete_modified_gibbs_proposal, stay_prob=0.0)
        else:
            proposal_fn = _discrete_gibbs_proposal

    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        # convert to unconstrained values
        z_hmc = {
            k: biject_to(prototype_trace[k]["fn"].support).inv(v)
            for k, v in hmc_sites.items()
            if k in prototype_trace and prototype_trace[k]["type"] == "sample"
        }
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        wrapped_model = _wrap_model(model)
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model = enum(config_enumerate(wrapped_model), -max_plate_nesting - 1)

        def potential_fn(z_discrete):
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(
                wrapped_model, model_args, model_kwargs_, z_hmc, enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        rng_key, rng_permute = random.split(rng_key)
        # NB: use the dedicated `rng_permute` key here; reusing `rng_key` for the
        # permutation would correlate it with the keys split inside `body_fn`
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key, z, pe, potential_fn=potential_fn, idx=idx, support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(random.exponential(rng_accept) > -log_accept_ratio,
                         (z_new, pe_new), identity,
                         (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites

    return gibbs_fn
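
# A quick numerical check of the acceptance trick used in `body_fn` above:
# accepting when `Exponential(1) > -log_accept_ratio` is equivalent to accepting
# when `Uniform(0, 1) < exp(log_accept_ratio)`, since -log(u) ~ Exponential(1)
# for u ~ Uniform(0, 1). The 0.3 target ratio is an arbitrary illustration.
import jax.numpy as jnp
from jax import random

key_u, key_e = random.split(random.PRNGKey(0))
log_accept_ratio = jnp.log(0.3)
p_uniform = jnp.mean(random.uniform(key_u, (100_000,)) < jnp.exp(log_accept_ratio))
p_exponential = jnp.mean(random.exponential(key_e, (100_000,)) > -log_accept_ratio)
# both estimates are close to 0.3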
def _discrete_gibbs_fn(wrapped_model, model_args, model_kwargs, prototype_trace,
                       random_walk=False, modified=False):
    support_sizes = {
        name: jnp.broadcast_to(
            site["fn"].enumerate_support(False).shape[0], jnp.shape(site["value"]))
        for name, site in prototype_trace.items()
        if site["type"] == "sample"
        and site["fn"].has_enumerate_support
        and not site["is_observed"]
    }
    max_plate_nesting = _guess_max_plate_nesting(prototype_trace)
    if random_walk:
        if modified:
            proposal_fn = partial(_discrete_modified_rw_proposal, stay_prob=0.0)
        else:
            proposal_fn = _discrete_rw_proposal
    else:
        if modified:
            proposal_fn = partial(_discrete_modified_gibbs_proposal, stay_prob=0.0)
        else:
            proposal_fn = _discrete_gibbs_proposal

    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        z_hmc = hmc_sites
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model_ = enum(config_enumerate(wrapped_model), -max_plate_nesting - 1)
        else:
            wrapped_model_ = wrapped_model

        def potential_fn(z_discrete):
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(
                wrapped_model_, model_args, model_kwargs_, z_hmc, enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        rng_key, rng_permute = random.split(rng_key)
        # NB: use the dedicated `rng_permute` key, as in `discrete_gibbs_fn` above
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key, z, pe, potential_fn=potential_fn, idx=idx, support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(random.exponential(rng_accept) > -log_accept_ratio,
                         (z_new, pe_new), identity,
                         (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites

    return gibbs_fn
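
# `_discrete_gibbs_fn` is the internal counterpart of `discrete_gibbs_fn` above;
# in current numpyro it is driven through the `DiscreteHMCGibbs` kernel rather
# than called directly. A hedged usage sketch with the mixture model from the
# `discrete_gibbs_fn` docstring:
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs


def mix_model(probs, locs):
    c = numpyro.sample("c", dist.Categorical(probs))
    numpyro.sample("x", dist.Normal(locs[c], 0.5))


kernel = DiscreteHMCGibbs(NUTS(mix_model), modified=True)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10_000, progress_bar=False)
mcmc.run(random.PRNGKey(0), jnp.array([0.15, 0.3, 0.3, 0.25]),
         jnp.array([-2.0, 0.0, 2.0, 4.0]))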
def init(self, rng_key, *args, **kwargs):
    """
    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :return: initial :data:`SteinVIState`
    """
    rng_key, kernel_seed, model_seed, guide_seed = jax.random.split(rng_key, 4)
    model_init = handlers.seed(self.model, model_seed)
    guide_init = handlers.seed(self.guide, guide_seed)
    guide_trace = handlers.trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    model_trace = handlers.trace(model_init).get_trace(*args, **kwargs, **self.static_kwargs)
    rng_key, particle_seed = jax.random.split(rng_key)
    guide_init_params = self._find_init_params(particle_seed, self.guide, guide_trace)
    params = {}
    transforms = {}
    inv_transforms = {}
    particle_transforms = {}
    guide_param_names = set()
    should_enum = False
    for site in model_trace.values():
        if (
            "fn" in site
            and site["type"] == "sample"
            and not site["is_observed"]
            and isinstance(site["fn"], Distribution)
            and site["fn"].is_discrete
        ):
            if site["fn"].has_enumerate_support and self.enum:
                should_enum = True
            else:
                raise Exception(
                    "Cannot enumerate model with discrete variables without enumerate support"
                )
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in chain(model_trace.values(), guide_trace.values()):
        if site["type"] == "param":
            transform = get_parameter_transform(site)
            inv_transforms[site["name"]] = transform
            transforms[site["name"]] = transform.inv
            particle_transforms[site["name"]] = site.get(
                "particle_transform", IdentityTransform())
            if site["name"] in guide_init_params:
                pval, _ = guide_init_params[site["name"]]
                if self.classic_guide_params_fn(site["name"]):
                    pval = tree_map(lambda x: x[0], pval)
            else:
                pval = site["value"]
            params[site["name"]] = transform.inv(pval)
            if site["name"] in guide_trace:
                guide_param_names.add(site["name"])

    if should_enum:
        mpn = _guess_max_plate_nesting(model_trace)
        self._inference_model = enum(config_enumerate(self.model), -mpn - 1)
    self.guide_param_names = guide_param_names
    self.constrain_fn = partial(transform_fn, inv_transforms)
    self.uconstrain_fn = partial(transform_fn, transforms)
    self.particle_transforms = particle_transforms
    self.particle_transform_fn = partial(transform_fn, particle_transforms)
    stein_particles, _, _ = batch_ravel_pytree(
        {
            k: params[k]
            for k, site in guide_trace.items()
            if site["type"] == "param" and site["name"] in guide_init_params
        },
        nbatch_dims=1,
    )
    self.kernel_fn.init(kernel_seed, stein_particles.shape)
    return SteinVIState(self.optim.init(params), rng_key)
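
# A hedged sketch of driving `init` above through the public SteinVI interface.
# The constructor signature has shifted across numpyro releases; this assumes a
# version of `numpyro.contrib.einstein.SteinVI` accepting
# (model, guide, optim, loss, kernel_fn), with an AutoDelta guide supplying the
# particles. The toy model and data are illustrative only.
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.einstein import SteinVI, RBFKernel
from numpyro.infer import Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro.optim import Adam


def loc_model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)


stein = SteinVI(loc_model, AutoDelta(loc_model), Adam(0.1), Trace_ELBO(), RBFKernel())
state = stein.init(random.PRNGKey(0), jnp.array([0.9, 1.1]))  # the `init` shown above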
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs):
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with funsor.adjoint.AdjointTape() as tape:
        with block(), enum(first_available_dim=first_available_dim):
            log_prob, model_tr, log_measures = _enum_log_density(
                model, args, kwargs, {}, sum_op, prod_op)

    with approx:
        approx_factors = tape.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(
                node["value"], output, dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[log_measures[name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    data = {
        name: site["value"]
        for name, site in sample_tr.items()
        if site["type"] == "sample"
    }

    # concatenate _PREV_foo to foo
    time_vars = defaultdict(list)
    for name in data:
        if name.startswith("_PREV_"):
            root_name = _shift_name(name, -_get_shift(name))
            time_vars[root_name].append(name)
    for name in time_vars:
        if name in data:
            time_vars[name].append(name)
        time_vars[name] = sorted(time_vars[name], key=len, reverse=True)

    for root_name, vars in time_vars.items():
        prototype_shape = model_trace[root_name]["value"].shape
        values = [data.pop(name) for name in vars]
        if len(values) == 1:
            data[root_name] = values[0].reshape(prototype_shape)
        else:
            assert len(prototype_shape) >= 1
            values = [v.reshape((-1,) + prototype_shape[1:]) for v in values]
            data[root_name] = jnp.concatenate(values)

    return data
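
# A standalone illustration of the time-variable merge at the end of
# `_sample_posterior` above: `_PREV_`-prefixed sites (produced when markov models
# are unrolled under scan) are sorted longest-prefix-first, i.e. earliest chunk
# first, and concatenated back onto the root site along the time axis. Shapes
# here are purely illustrative.
import jax.numpy as jnp

chunks = {
    "_PREV__PREV_x": jnp.zeros((1, 3)),   # earliest time step(s)
    "_PREV_x": jnp.ones((1, 3)),
    "x": 2 * jnp.ones((4, 3)),            # remaining time steps
}
ordered = sorted(chunks, key=len, reverse=True)  # mirrors the sort in `time_vars`
merged_x = jnp.concatenate([chunks[name] for name in ordered])
assert merged_x.shape == (6, 3)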