def _create_traces_check_error_string(model, guide, expected_string):
    """Trace ``model`` and ``guide`` and assert that their mismatch is reported.

    Each callable is wrapped in ``numpyro.handlers.seed`` with the fixed
    ``rng_seed=42`` and traced with no arguments. The (model, guide) traces
    are passed to ``check_model_guide_match``, which must raise a
    ``ValueError`` whose message matches the regex ``expected_string``;
    otherwise ``pytest.raises`` fails the calling test.
    """

    def _traced(fn):
        # Seed deterministically so both traces are reproducible, then record.
        seeded_fn = numpyro.handlers.seed(fn, rng_seed=42)
        return numpyro.handlers.trace(seeded_fn).get_trace()

    model_and_guide_traces = [_traced(fn) for fn in (model, guide)]
    with pytest.raises(ValueError, match=expected_string):
        check_model_guide_match(*model_and_guide_traces)
def single_particle_elbo(rng_key):
    """Single-particle mean-field ELBO estimate, with optional mutable state.

    Closes over ``model``, ``guide``, ``param_map``, ``args``, ``kwargs``
    and ``self`` from the enclosing scope.

    ``rng_key`` is split into separate seeds for the model and the guide.
    The guide is traced with ``param_map`` substituted. The model is then
    replayed against that guide trace and traced with ``param_map`` plus
    the guide's ``mutable`` site values substituted.

    Contributions to the estimate:

    * observed model sample sites add ``log p(x)``;
    * latent model sample sites subtract ``sum(KL(q || p))``, scaled by the
      guide site's ``scale``. If ``kl_divergence`` raises
      ``NotImplementedError`` for the pair, they instead add the
      single-sample term ``log p(z) - log q(z)``;
    * guide-only sample sites subtract ``log q``. Such sites must be
      auxiliary or observed; this is checked with ``assert``, so the check
      is skipped under ``python -O``.

    Returns:
        ``(elbo_particle, mutable_params)`` if any ``mutable`` site appeared
        in either trace, else ``(elbo_particle, None)``.

    Raises:
        ValueError: if mutable state is present while
            ``self.num_particles`` is not 1.
    """
    params = param_map.copy()
    model_key, guide_key = random.split(rng_key)
    seeded_model = seed(model, model_key)
    seeded_guide = seed(guide, guide_key)

    def _collect_mutable(tr):
        # Values recorded at the `mutable` sites of a trace, keyed by site name.
        return {
            site_name: site["value"]
            for site_name, site in tr.items()
            if site["type"] == "mutable"
        }

    guide_tr = trace(substitute(seeded_guide, data=param_map)).get_trace(
        *args, **kwargs
    )
    mutable_state = _collect_mutable(guide_tr)
    # The model additionally sees whatever state the guide mutated.
    params.update(mutable_state)
    model_tr = trace(
        substitute(replay(seeded_model, guide_tr), data=params)
    ).get_trace(*args, **kwargs)
    mutable_state.update(_collect_mutable(model_tr))

    check_model_guide_match(model_tr, guide_tr)
    _validate_model(model_tr, plate_warning="loose")
    _check_mean_field_requirement(model_tr, guide_tr)

    elbo_particle = 0
    for site_name, model_site in model_tr.items():
        if model_site["type"] != "sample":
            continue
        if model_site["is_observed"]:
            elbo_particle = elbo_particle + _get_log_prob_sum(model_site)
            continue
        guide_site = guide_tr[site_name]
        try:
            kl_term = kl_divergence(guide_site["fn"], model_site["fn"])
            kl_term = scale_and_mask(kl_term, scale=guide_site["scale"])
            elbo_particle = elbo_particle - jnp.sum(kl_term)
        except NotImplementedError:
            # No registered KL for this pair: use the single-sample estimate.
            elbo_particle = (
                elbo_particle
                + _get_log_prob_sum(model_site)
                - _get_log_prob_sum(guide_site)
            )

    # Guide-only sample sites (auxiliary or observed) contribute -log q.
    for site_name, site in guide_tr.items():
        if site["type"] != "sample" or site_name in model_tr:
            continue
        assert site["infer"].get("is_auxiliary") or site["is_observed"]
        elbo_particle = elbo_particle - _get_log_prob_sum(site)

    if not mutable_state:
        return elbo_particle, None
    if self.num_particles != 1:
        raise ValueError(
            "Currently, we only support mutable states with num_particles=1."
        )
    return elbo_particle, mutable_state
def single_particle_elbo(rng_key):
    """Single-particle ELBO estimate ``log p(z) - log q(z)`` with mutable state.

    Closes over ``model``, ``guide``, ``param_map``, ``args``, ``kwargs``
    and ``self`` from the enclosing scope.

    ``rng_key`` is split into separate seeds for the model and the guide.
    ``log_density`` (unpacked as ``(log density, trace)``) is evaluated for
    the guide with ``param_map``. It is then evaluated for the model,
    replayed against the guide trace, with ``param_map`` plus the guide's
    ``mutable`` site values.

    Returns:
        ``(elbo_particle, mutable_params)`` if any ``mutable`` site appeared
        in either trace, else ``(elbo_particle, None)``.

    Raises:
        ValueError: if mutable state is present while
            ``self.num_particles`` is not 1.
    """
    params = param_map.copy()
    model_key, guide_key = random.split(rng_key)
    seeded_model = seed(model, model_key)
    seeded_guide = seed(guide, guide_key)

    log_q, guide_tr = log_density(seeded_guide, args, kwargs, param_map)
    mutable_state = {
        site_name: site["value"]
        for site_name, site in guide_tr.items()
        if site["type"] == "mutable"
    }
    # The model additionally sees whatever state the guide mutated.
    params.update(mutable_state)
    log_p, model_tr = log_density(
        replay(seeded_model, guide_tr), args, kwargs, params
    )

    check_model_guide_match(model_tr, guide_tr)
    _validate_model(model_tr, plate_warning="loose")
    for site_name, site in model_tr.items():
        if site["type"] == "mutable":
            mutable_state[site_name] = site["value"]

    elbo_particle = log_p - log_q  # log p(z) - log q(z)
    if not mutable_state:
        return elbo_particle, None
    if self.num_particles != 1:
        raise ValueError(
            "Currently, we only support mutable states with num_particles=1."
        )
    return elbo_particle, mutable_state
def single_particle_elbo(rng_key):
    """Single-particle ELBO estimate ``log p(z) - log q(z)``.

    Closes over ``model``, ``guide``, ``param_map``, ``args`` and
    ``kwargs`` from the enclosing scope.

    ``rng_key`` is split into separate seeds for the model and the guide.
    ``log_density`` (unpacked as ``(log density, trace)``) is evaluated for
    the guide with the full ``param_map``. It is then evaluated for the
    model, replayed against the guide trace, with only those entries of
    ``param_map`` whose names do not appear in the guide trace.

    Returns:
        The model log density minus the guide log density.
    """
    model_key, guide_key = random.split(rng_key)
    seeded_model = seed(model, model_key)
    seeded_guide = seed(guide, guide_key)

    log_q, guide_tr = log_density(seeded_guide, args, kwargs, param_map)
    # Only substitute params that are not already present in the guide trace.
    model_params = {
        site_name: value
        for site_name, value in param_map.items()
        if site_name not in guide_tr
    }
    log_p, model_tr = log_density(
        replay(seeded_model, guide_tr), args, kwargs, model_params
    )

    check_model_guide_match(model_tr, guide_tr)
    _validate_model(model_tr, plate_warning="loose")
    return log_p - log_q  # log p(z) - log q(z)
def single_particle_elbo(rng_key):
    """Single-particle TraceGraph ELBO surrogate.

    Closes over ``model``, ``guide``, ``param_map``, ``args`` and
    ``kwargs`` from the enclosing scope.

    ``rng_key`` is split into separate seeds for the model and the guide.
    Both traces come from ``get_importance_trace``, and every sample site's
    ``log_prob`` entry is read from them.

    The model's sample sites add ``sum(log_prob)`` and the guide's sample
    sites subtract ``sum(log_prob)``. Guide sites that are latent and whose
    distribution reports ``has_rsample`` False get special treatment: their
    term is replaced by
    ``stop_gradient(log_q + surrogate) - surrogate``, where
    ``surrogate = sum(log_prob * stop_gradient(downstream_cost))``. The
    value is therefore ``log_q`` up to floating-point rounding, while the
    gradient flows through the surrogate. Unlike Pyro, no ``baseline_loss``
    is supported.

    Returns:
        The scalar surrogate ELBO for this particle.
    """
    model_key, guide_key = random.split(rng_key)
    model_tr, guide_tr = get_importance_trace(
        seed(model, model_key), seed(guide, guide_key), args, kwargs, param_map
    )
    check_model_guide_match(model_tr, guide_tr)
    _validate_model(model_tr, plate_warning="strict")

    # Latent guide sites that cannot be reparameterized.
    score_fn_sites = set()
    for site_name, site in guide_tr.items():
        if (
            site["type"] == "sample"
            and not site["is_observed"]
            and not site["fn"].has_rsample
        ):
            score_fn_sites.add(site_name)

    # Only looked up for names in `score_fn_sites`, so it stays empty otherwise.
    downstream_costs = {}
    if score_fn_sites:
        downstream_costs, _ = _compute_downstream_costs(
            model_tr, guide_tr, score_fn_sites
        )

    elbo = 0.0
    for site in model_tr.values():
        if site["type"] == "sample":
            elbo = elbo + jnp.sum(site["log_prob"])
    for site_name, site in guide_tr.items():
        if site["type"] != "sample":
            continue
        log_q = jnp.sum(site["log_prob"])
        if site_name in score_fn_sites:
            surrogate = jnp.sum(
                site["log_prob"] * stop_gradient(downstream_costs[site_name])
            )
            log_q = stop_gradient(log_q + surrogate) - surrogate
        elbo = elbo - log_q
    return elbo