def single_particle_elbo(rng_key):
    """Estimate the ELBO for one particle, using analytic KL where available.

    ``model``, ``guide``, ``args``, ``kwargs``, ``param_map`` and ``self`` are
    read from the enclosing scope, which is not part of this block.

    For every unobserved model sample site the KL divergence between the
    guide and model distributions is tried first; if ``kl_divergence`` (or
    the scaling/summing that follows it) raises ``NotImplementedError`` the
    term falls back to the difference of summed log-probabilities.

    :param rng_key: random key, split once to seed the model and the guide.
    :return: ``(elbo_particle, mutable_params)`` when either trace holds
        sites of type ``"mutable"``, otherwise ``(elbo_particle, None)``.
    :raises ValueError: if mutable sites exist and ``self.num_particles``
        is not 1.
    :raises AssertionError: if a sample site that only exists in the guide
        is neither flagged ``is_auxiliary`` nor observed.
    """
    merged_params = param_map.copy()
    key_model, key_guide = random.split(rng_key)
    model_seeded = seed(model, key_model)
    guide_seeded = seed(guide, key_guide)

    # Run the guide with the incoming parameter values and record its sites.
    guide_trace = trace(substitute(guide_seeded, data=param_map)).get_trace(
        *args, **kwargs
    )
    mutable_params = {}
    for site_name, guide_entry in guide_trace.items():
        if guide_entry["type"] == "mutable":
            mutable_params[site_name] = guide_entry["value"]
    merged_params.update(mutable_params)

    # Replay the model against the guide trace; the guide's mutable values
    # are substituted alongside the ordinary parameters.
    replayed_model = replay(model_seeded, guide_trace)
    model_trace = trace(substitute(replayed_model, data=merged_params)).get_trace(
        *args, **kwargs
    )
    for site_name, model_entry in model_trace.items():
        if model_entry["type"] == "mutable":
            mutable_params[site_name] = model_entry["value"]

    check_model_guide_match(model_trace, guide_trace)
    _validate_model(model_trace, plate_warning="loose")
    _check_mean_field_requirement(model_trace, guide_trace)

    elbo_particle = 0
    for name, model_site in model_trace.items():
        if model_site["type"] != "sample":
            continue
        if model_site["is_observed"]:
            elbo_particle = elbo_particle + _get_log_prob_sum(model_site)
            continue
        guide_site = guide_trace[name]
        try:
            kl_qp = kl_divergence(guide_site["fn"], model_site["fn"])
            kl_qp = scale_and_mask(kl_qp, scale=guide_site["scale"])
            elbo_particle = elbo_particle - jnp.sum(kl_qp)
        except NotImplementedError:
            # No analytic KL for this pair: use log p - log q instead.
            elbo_particle = (
                elbo_particle
                + _get_log_prob_sum(model_site)
                - _get_log_prob_sum(guide_site)
            )

    # handle auxiliary sites in the guide
    for name, site in guide_trace.items():
        if site["type"] != "sample" or name in model_trace:
            continue
        assert site["infer"].get("is_auxiliary") or site["is_observed"]
        elbo_particle = elbo_particle - _get_log_prob_sum(site)

    if not mutable_params:
        return elbo_particle, None
    if self.num_particles == 1:
        return elbo_particle, mutable_params
    raise ValueError(
        "Currently, we only support mutable states with num_particles=1."
    )
def single_particle_elbo(rng_key):
    """Estimate the ELBO for one particle as a difference of log densities.

    ``model``, ``guide``, ``args``, ``kwargs``, ``param_map`` and ``self`` are
    read from the enclosing scope, which is not part of this block.

    :param rng_key: random key, split once to seed the model and the guide.
    :return: ``(elbo_particle, mutable_params)`` when either trace holds
        sites of type ``"mutable"``, otherwise ``(elbo_particle, None)``.
    :raises ValueError: if mutable sites exist and ``self.num_particles``
        is not 1.
    """
    merged_params = param_map.copy()
    key_model, key_guide = random.split(rng_key)
    model_seeded = seed(model, key_model)
    guide_seeded = seed(guide, key_guide)

    # Guide pass: log density and trace under the incoming parameters.
    guide_log_density, guide_trace = log_density(
        guide_seeded, args, kwargs, param_map
    )
    mutable_params = {}
    for site_name, guide_entry in guide_trace.items():
        if guide_entry["type"] == "mutable":
            mutable_params[site_name] = guide_entry["value"]
    merged_params.update(mutable_params)

    # Model pass: replayed against the guide trace, with the guide's mutable
    # values added to the substituted parameters.
    replayed_model = replay(model_seeded, guide_trace)
    model_log_density, model_trace = log_density(
        replayed_model, args, kwargs, merged_params
    )
    check_model_guide_match(model_trace, guide_trace)
    _validate_model(model_trace, plate_warning="loose")
    for site_name, model_entry in model_trace.items():
        if model_entry["type"] == "mutable":
            mutable_params[site_name] = model_entry["value"]

    # log p(z) - log q(z)
    elbo_particle = model_log_density - guide_log_density

    if not mutable_params:
        return elbo_particle, None
    if self.num_particles == 1:
        return elbo_particle, mutable_params
    raise ValueError(
        "Currently, we only support mutable states with num_particles=1."
    )
def single_particle_elbo(rng_key):
    """Return ``log p - log q`` for one particle.

    ``model``, ``guide``, ``args``, ``kwargs`` and ``param_map`` are read
    from the enclosing scope, which is not part of this block.

    :param rng_key: random key, split once to seed the model and the guide.
    :return: the model log density minus the guide log density.
    """
    key_model, key_guide = random.split(rng_key)
    model_seeded = seed(model, key_model)
    guide_seeded = seed(guide, key_guide)

    guide_log_density, guide_trace = log_density(
        guide_seeded, args, kwargs, param_map
    )

    # NB: we only want to substitute params not available in guide_trace
    model_param_map = {}
    for key, value in param_map.items():
        if key not in guide_trace:
            model_param_map[key] = value

    replayed_model = replay(model_seeded, guide_trace)
    model_log_density, model_trace = log_density(
        replayed_model, args, kwargs, model_param_map
    )
    check_model_guide_match(model_trace, guide_trace)
    _validate_model(model_trace, plate_warning="loose")

    # log p(z) - log q(z)
    return model_log_density - guide_log_density
def single_particle_elbo(rng_key):
    """Estimate the ELBO for one particle with a surrogate for some guide sites.

    ``model``, ``guide``, ``args``, ``kwargs`` and ``param_map`` are read
    from the enclosing scope, which is not part of this block.

    Guide sample sites that are unobserved and whose distribution reports
    ``has_rsample`` as false get a surrogate term built from their
    downstream costs; every other sample site contributes its summed
    ``log_prob`` directly.

    :param rng_key: random key, split once to seed the model and the guide.
    :return: the accumulated ELBO value for this particle.
    """
    key_model, key_guide = random.split(rng_key)
    model_seeded = seed(model, key_model)
    guide_seeded = seed(guide, key_guide)

    model_trace, guide_trace = get_importance_trace(
        model_seeded, guide_seeded, args, kwargs, param_map
    )
    check_model_guide_match(model_trace, guide_trace)
    _validate_model(model_trace, plate_warning="strict")

    # XXX: different from Pyro, we don't support baseline_loss here
    non_reparam_nodes = set()
    for site_name, guide_entry in guide_trace.items():
        if guide_entry["type"] != "sample":
            continue
        if not (guide_entry["is_observed"] or guide_entry["fn"].has_rsample):
            non_reparam_nodes.add(site_name)

    # Only computed when needed; it is read below solely for names that are
    # members of ``non_reparam_nodes``.
    if non_reparam_nodes:
        downstream_costs, _ = _compute_downstream_costs(
            model_trace, guide_trace, non_reparam_nodes
        )

    elbo = 0.0
    for model_entry in model_trace.values():
        if model_entry["type"] != "sample":
            continue
        elbo = elbo + jnp.sum(model_entry["log_prob"])

    for site_name, guide_entry in guide_trace.items():
        if guide_entry["type"] != "sample":
            continue
        log_prob_sum = jnp.sum(guide_entry["log_prob"])
        if site_name in non_reparam_nodes:
            cost = stop_gradient(downstream_costs[site_name])
            surrogate = jnp.sum(guide_entry["log_prob"] * cost)
            # Adding and subtracting ``surrogate`` leaves the forward value
            # at ``log_prob_sum``; assuming ``stop_gradient`` is JAX's, the
            # gradient then flows only through the subtracted surrogate.
            log_prob_sum = stop_gradient(log_prob_sum + surrogate) - surrogate
        elbo = elbo - log_prob_sum
    return elbo
def initialize_model(
    rng_key,
    model,
    *,
    init_strategy=init_to_uniform,
    dynamic_args=False,
    model_args=(),
    model_kwargs=None,
    forward_mode_differentiation=False,
    validate_grad=True,
):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model. ``None`` is
        treated as an empty dict.
    :param bool forward_mode_differentiation: whether to use forward-mode differentiation
        or reverse-mode differentiation. By default, we use reverse mode but the forward
        mode can be useful in some cases to improve the performance. In addition, some
        control flow utility on JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop`
        only supports forward-mode differentiation. See
        `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
        for more information.
    :param bool validate_grad: whether to validate gradient of the initial params.
        Defaults to True.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    :raises RuntimeError: if `is_valid` is not a JAX tracer and not every
        entry of it is true, i.e. no valid initial parameters were found.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    # When `rng_key` is not one-dimensional (a batch of keys), only the first
    # key seeds the model used to build the prototype trace.
    substituted_model = substitute(
        seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
        substitute_fn=init_strategy,
    )
    (
        inv_transforms,
        replay_model,
        has_enumerate_support,
        model_trace,
    ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
    # substitute param sites from model_trace to model so
    # we don't need to generate again parameters of `numpyro.module`
    model = substitute(
        model,
        data={
            k: site["value"]
            for k, site in model_trace.items()
            if site["type"] in ["param"]
        },
    )
    # Keep only unobserved sample sites whose support is not discrete.
    constrained_values = {
        k: v["value"]
        for k, v in model_trace.items()
        if v["type"] == "sample"
        and not v["is_observed"]
        and not v["fn"].support.is_discrete
    }

    if has_enumerate_support:
        # Imported only on the enumeration path.
        from numpyro.contrib.funsor import config_enumerate, enum

        # Wrap the model for enumeration unless it is already an `enum`.
        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            _validate_model(model_trace, plate_warning="error")
            model = enum(config_enumerate(model), -max_plate_nesting - 1)
    else:
        _validate_model(model_trace, plate_warning="loose")

    potential_fn, postprocess_fn = get_potential_fn(
        model,
        inv_transforms,
        replay_model=replay_model,
        enum=has_enumerate_support,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )

    # A strategy that is not already a `partial` is called with no arguments;
    # the result is expected to expose `.func` and `.keywords`, which are
    # read just below (presumably it is a `functools.partial` - confirm).
    init_strategy = (
        init_strategy if isinstance(init_strategy, partial) else init_strategy()
    )
    # For `init_to_value` without a replay model, the user-supplied values
    # are mapped through `inv_transforms` and the strategy is swapped for one
    # that takes those transformed values.
    if (init_strategy.func is init_to_value) and not replay_model:
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms, init_values, invert=True)
        init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
    # The model handed to the search has the `plate` site values from the
    # prototype trace substituted in.
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        substitute(
            model,
            data={
                k: site["value"]
                for k, site in model_trace.items()
                if site["type"] in ["plate"]
            },
        ),
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params,
        forward_mode_differentiation=forward_mode_differentiation,
        validate_grad=validate_grad,
    )

    # The validity check only runs when `is_valid` is not a JAX tracer
    # (presumably because its concrete value is unavailable while tracing).
    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            # Re-run the model with validation enabled so that per-site
            # warnings are shown before the error is raised.
            with numpyro.validation_enabled(), trace() as tr:
                # validate parameters
                substituted_model(*model_args, **model_kwargs)
                # validate values
                for site in tr.values():
                    if site["type"] == "sample":
                        with warnings.catch_warnings(record=True) as ws:
                            site["fn"]._validate_sample(site["value"])
                            if len(ws) > 0:
                                for w in ws:
                                    # add site information to the warning message
                                    w.message.args = (
                                        "Site {}: {}".format(
                                            site["name"], w.message.args[0]
                                        ),
                                    ) + w.message.args[1:]
                                    warnings.showwarning(
                                        w.message,
                                        w.category,
                                        w.filename,
                                        w.lineno,
                                        file=w.file,
                                        line=w.line,
                                    )
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return ModelInfo(
        ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace
    )