def single_particle_elbo(rng_key):
    """Evaluate the ELBO contribution of one particle drawn with ``rng_key``.

    The guide is seeded, has ``param_map`` substituted and is traced first;
    the model is then seeded, replayed against that guide trace and traced
    with ``param_map`` plus any ``"mutable"`` site values found in the guide.
    Values of ``"mutable"`` sites from both traces are collected.

    Args:
        rng_key: key that is split into one seed for the model and one for
            the guide.

    Returns:
        A pair ``(elbo, mutable_state)``. ``mutable_state`` is the dict of
        collected mutable-site values, or ``None`` when none were found.

    Raises:
        ValueError: if mutable-site values were collected and
            ``self.num_particles`` is not 1.
    """
    key_for_model, key_for_guide = random.split(rng_key)
    # Work on a copy so the mutable values merged in below never touch
    # the enclosing ``param_map``.
    merged_params = param_map.copy()
    model_with_seed = seed(model, key_for_model)
    guide_with_seed = seed(guide, key_for_guide)

    # Run the guide with the current parameters substituted.
    guide_trace = trace(
        substitute(guide_with_seed, data=param_map)
    ).get_trace(*args, **kwargs)

    mutable_state = {}
    for site_name, guide_site in guide_trace.items():
        if guide_site["type"] == "mutable":
            mutable_state[site_name] = guide_site["value"]
    merged_params.update(mutable_state)

    # Run the model replayed against the guide trace.
    replayed_model = replay(model_with_seed, guide_trace)
    model_trace = trace(
        substitute(replayed_model, data=merged_params)
    ).get_trace(*args, **kwargs)

    for site_name, traced_site in model_trace.items():
        if traced_site["type"] == "mutable":
            mutable_state[site_name] = traced_site["value"]

    check_model_guide_match(model_trace, guide_trace)
    _validate_model(model_trace, plate_warning="loose")
    _check_mean_field_requirement(model_trace, guide_trace)

    elbo = 0
    for site_name, traced_site in model_trace.items():
        if traced_site["type"] != "sample":
            continue
        if traced_site["is_observed"]:
            elbo = elbo + _get_log_prob_sum(traced_site)
            continue
        guide_site = guide_trace[site_name]
        try:
            divergence = scale_and_mask(
                kl_divergence(guide_site["fn"], traced_site["fn"]),
                scale=guide_site["scale"],
            )
            elbo = elbo - jnp.sum(divergence)
        except NotImplementedError:
            # Fall back to the difference of summed log-probabilities
            # when the KL computation is not implemented for this pair.
            model_log_prob = _get_log_prob_sum(traced_site)
            guide_log_prob = _get_log_prob_sum(guide_site)
            elbo = elbo + model_log_prob - guide_log_prob

    # Guide-only sample sites must be flagged auxiliary or be observed;
    # their summed log-probability is subtracted.
    for site_name, guide_site in guide_trace.items():
        if guide_site["type"] != "sample" or site_name in model_trace:
            continue
        assert guide_site["infer"].get("is_auxiliary") or guide_site["is_observed"]
        elbo = elbo - _get_log_prob_sum(guide_site)

    if not mutable_state:
        return elbo, None
    if self.num_particles == 1:
        return elbo, mutable_state
    raise ValueError(
        "Currently, we only support mutable states with num_particles=1."
    )
def single_particle_elbo(rng_key):
    """Evaluate the ELBO contribution of one particle drawn with ``rng_key``.

    The guide is seeded, has ``param_map`` substituted and is traced; the
    model is then seeded, replayed against that guide trace and traced with
    the same ``param_map``.

    Args:
        rng_key: key that is split into one seed for the model and one for
            the guide.

    Returns:
        The accumulated ELBO value for this particle.
    """
    key_for_model, key_for_guide = random.split(rng_key)
    model_with_seed = seed(model, key_for_model)
    guide_with_seed = seed(guide, key_for_guide)

    # Run the guide with the current parameters substituted.
    guide_trace = trace(
        substitute(guide_with_seed, data=param_map)
    ).get_trace(*args, **kwargs)

    # Run the model replayed against the guide trace.
    replayed_model = replay(model_with_seed, guide_trace)
    model_trace = trace(
        substitute(replayed_model, data=param_map)
    ).get_trace(*args, **kwargs)

    _check_mean_field_requirement(model_trace, guide_trace)

    elbo = 0
    for site_name, traced_site in model_trace.items():
        if traced_site["type"] != "sample":
            continue
        if traced_site["is_observed"]:
            elbo = elbo + _get_log_prob_sum(traced_site)
            continue
        guide_site = guide_trace[site_name]
        try:
            divergence = scale_and_mask(
                kl_divergence(guide_site["fn"], traced_site["fn"]),
                scale=guide_site["scale"],
            )
            elbo = elbo - jnp.sum(divergence)
        except NotImplementedError:
            # Fall back to the difference of summed log-probabilities
            # when the KL computation is not implemented for this pair.
            model_log_prob = _get_log_prob_sum(traced_site)
            guide_log_prob = _get_log_prob_sum(guide_site)
            elbo = elbo + model_log_prob - guide_log_prob

    # Guide-only sample sites must be flagged auxiliary; their summed
    # log-probability is subtracted.
    for site_name, guide_site in guide_trace.items():
        if guide_site["type"] != "sample" or site_name in model_trace:
            continue
        assert guide_site["infer"].get("is_auxiliary")
        elbo = elbo - _get_log_prob_sum(guide_site)

    return elbo