def init(self, rng_key, *args, **kwargs):
    """
    Gets the initial SVI state.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly
        vary during the course of fitting).
    :return: the initial :data:`SVIState`
    """
    rng_key, model_seed, guide_seed = random.split(rng_key, 3)
    model_init = seed(self.model, model_seed)
    guide_init = seed(self.guide, guide_seed)
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    model_trace = trace(replay(model_init, guide_trace)).get_trace(
        *args, **kwargs, **self.static_kwargs)
    params = {}
    inv_transforms = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site['type'] == 'param':
            constraint = site['kwargs'].pop('constraint', constraints.real)
            transform = biject_to(constraint)
            inv_transforms[site['name']] = transform
            params[site['name']] = transform.inv(site['value'])
    self.constrain_fn = partial(transform_fn, inv_transforms)
    # we convert weak types like float to float32/float64
    # to avoid recompiling body_fn in svi.run
    params = tree_map(
        lambda x: lax.convert_element_type(x, jnp.result_type(x)), params)
    return SVIState(self.optim.init(params), rng_key)
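# A minimal usage sketch for the `init` method above, driven through NumPyro's
# public SVI API (numpyro.infer.SVI, Trace_ELBO, numpyro.optim.Adam). The
# one-parameter model/guide pair below is hypothetical, chosen only to show
# where init and update fit in a fitting loop.
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO

def model(data):
    loc = numpyro.sample("loc", dist.Normal(0., 10.))
    numpyro.sample("obs", dist.Normal(loc, 1.), obs=data)

def guide(data):
    q_loc = numpyro.param("q_loc", 0.)
    numpyro.sample("loc", dist.Normal(q_loc, 1.))

svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
state = svi.init(random.PRNGKey(0), jnp.ones(5))  # calls the method above
state, loss = svi.update(state, jnp.ones(5))      # one gradient step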
def single_particle_elbo(rng_key):
    model_seed, guide_seed = random.split(rng_key)
    seeded_model = seed(model, model_seed)
    seeded_guide = seed(guide, guide_seed)
    subs_guide = substitute(seeded_guide, data=param_map)
    guide_trace = trace(subs_guide).get_trace(*args, **kwargs)
    subs_model = substitute(replay(seeded_model, guide_trace), data=param_map)
    model_trace = trace(subs_model).get_trace(*args, **kwargs)
    _check_mean_field_requirement(model_trace, guide_trace)

    elbo_particle = 0
    for name, model_site in model_trace.items():
        if model_site["type"] == "sample":
            if model_site["is_observed"]:
                elbo_particle = elbo_particle + _get_log_prob_sum(model_site)
            else:
                guide_site = guide_trace[name]
                try:
                    kl_qp = kl_divergence(guide_site["fn"], model_site["fn"])
                    kl_qp = scale_and_mask(kl_qp, scale=guide_site["scale"])
                    elbo_particle = elbo_particle - jnp.sum(kl_qp)
                except NotImplementedError:
                    elbo_particle = elbo_particle + _get_log_prob_sum(model_site) \
                        - _get_log_prob_sum(guide_site)

    # handle auxiliary sites in the guide
    for name, site in guide_trace.items():
        if site["type"] == "sample" and name not in model_trace:
            assert site["infer"].get("is_auxiliary")
            elbo_particle = elbo_particle - _get_log_prob_sum(site)

    return elbo_particle
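# Why the try/except above falls back to log-prob terms (sketch): NumPyro
# registers closed-form KL divergences for some distribution pairs and raises
# NotImplementedError for the rest, in which case a single-sample estimate of
# the same KL is used. The point x below is illustrative, not an actual draw.
from numpyro.distributions import Normal
from numpyro.distributions.kl import kl_divergence

kl = kl_divergence(Normal(0.0, 1.0), Normal(1.0, 2.0))  # analytic Normal-Normal KL
x = 0.3  # stand-in for a sample x ~ q
one_sample_kl_estimate = Normal(0.0, 1.0).log_prob(x) - Normal(1.0, 2.0).log_prob(x)  # log q - log p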
def get_importance_trace(model, guide, args, kwargs, params):
    """
    (EXPERIMENTAL) Returns traces from the guide and the model that is run
    against it. The returned traces also store the log probability at each
    site.

    .. note:: Gradients are blocked at latent sites which do not have
        reparametrized samplers.
    """
    guide = substitute(guide, data=params)
    with _without_rsample_stop_gradient():
        guide_trace = trace(guide).get_trace(*args, **kwargs)
    model = substitute(replay(model, guide_trace), data=params)
    model_trace = trace(model).get_trace(*args, **kwargs)
    for tr in (guide_trace, model_trace):
        for site in tr.values():
            if site["type"] == "sample":
                if "log_prob" not in site:
                    value = site["value"]
                    intermediates = site["intermediates"]
                    scale = site["scale"]
                    if intermediates:
                        log_prob = site["fn"].log_prob(value, intermediates)
                    else:
                        log_prob = site["fn"].log_prob(value)
                    if (scale is not None) and (not is_identically_one(scale)):
                        log_prob = scale * log_prob
                    site["log_prob"] = log_prob
    return model_trace, guide_trace
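# Sketch of how the per-site "log_prob" entries stored above can be folded
# into a single log importance weight, log p(x, z) - log q(z). The helper name
# is hypothetical, not part of the library.
def log_importance_weight(model_trace, guide_trace):
    model_lp = sum(site["log_prob"].sum()
                   for site in model_trace.values() if site["type"] == "sample")
    guide_lp = sum(site["log_prob"].sum()
                   for site in guide_trace.values() if site["type"] == "sample")
    return model_lp - guide_lp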
def init(self, rng_key, *args, **kwargs):
    """
    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly
        vary during the course of fitting).
    :return: the initial :data:`SVIState`; as a side effect, this sets
        `self.constrain_fn`, a callable that transforms unconstrained
        parameter values from the optimizer to the specified constrained
        domain
    """
    rng_key, model_seed, guide_seed = random.split(rng_key, 3)
    model_init = seed(self.model, model_seed)
    guide_init = seed(self.guide, guide_seed)
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    model_trace = trace(replay(model_init, guide_trace)).get_trace(
        *args, **kwargs, **self.static_kwargs)
    params = {}
    inv_transforms = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site['type'] == 'param':
            constraint = site['kwargs'].pop('constraint', constraints.real)
            transform = biject_to(constraint)
            inv_transforms[site['name']] = transform
            params[site['name']] = transform.inv(site['value'])
    self.constrain_fn = partial(transform_fn, inv_transforms)
    return SVIState(self.optim.init(params), rng_key)
def single_particle_elbo(rng_key):
    params = param_map.copy()
    model_seed, guide_seed = random.split(rng_key)
    seeded_model = seed(model, model_seed)
    seeded_guide = seed(guide, guide_seed)
    guide_log_density, guide_trace = log_density(
        seeded_guide, args, kwargs, param_map)
    mutable_params = {
        name: site["value"]
        for name, site in guide_trace.items()
        if site["type"] == "mutable"
    }
    params.update(mutable_params)
    seeded_model = replay(seeded_model, guide_trace)
    model_log_density, model_trace = log_density(
        seeded_model, args, kwargs, params)
    check_model_guide_match(model_trace, guide_trace)
    _validate_model(model_trace, plate_warning="loose")
    mutable_params.update({
        name: site["value"]
        for name, site in model_trace.items()
        if site["type"] == "mutable"
    })

    # log p(z) - log q(z)
    elbo_particle = model_log_density - guide_log_density
    if mutable_params:
        if self.num_particles == 1:
            return elbo_particle, mutable_params
        else:
            raise ValueError(
                "Currently, we only support mutable states with num_particles=1."
            )
    else:
        return elbo_particle, None
def single_particle_elbo(rng_key):
    params = param_map.copy()
    model_seed, guide_seed = random.split(rng_key)
    seeded_model = seed(model, model_seed)
    seeded_guide = seed(guide, guide_seed)
    subs_guide = substitute(seeded_guide, data=param_map)
    guide_trace = trace(subs_guide).get_trace(*args, **kwargs)
    mutable_params = {
        name: site["value"]
        for name, site in guide_trace.items()
        if site["type"] == "mutable"
    }
    params.update(mutable_params)
    subs_model = substitute(replay(seeded_model, guide_trace), data=params)
    model_trace = trace(subs_model).get_trace(*args, **kwargs)
    mutable_params.update({
        name: site["value"]
        for name, site in model_trace.items()
        if site["type"] == "mutable"
    })
    check_model_guide_match(model_trace, guide_trace)
    _validate_model(model_trace, plate_warning="loose")
    _check_mean_field_requirement(model_trace, guide_trace)

    elbo_particle = 0
    for name, model_site in model_trace.items():
        if model_site["type"] == "sample":
            if model_site["is_observed"]:
                elbo_particle = elbo_particle + _get_log_prob_sum(model_site)
            else:
                guide_site = guide_trace[name]
                try:
                    kl_qp = kl_divergence(guide_site["fn"], model_site["fn"])
                    kl_qp = scale_and_mask(kl_qp, scale=guide_site["scale"])
                    elbo_particle = elbo_particle - jnp.sum(kl_qp)
                except NotImplementedError:
                    elbo_particle = (elbo_particle
                                     + _get_log_prob_sum(model_site)
                                     - _get_log_prob_sum(guide_site))

    # handle auxiliary sites in the guide
    for name, site in guide_trace.items():
        if site["type"] == "sample" and name not in model_trace:
            assert site["infer"].get("is_auxiliary") or site["is_observed"]
            elbo_particle = elbo_particle - _get_log_prob_sum(site)

    if mutable_params:
        if self.num_particles == 1:
            return elbo_particle, mutable_params
        else:
            raise ValueError(
                "Currently, we only support mutable states with num_particles=1."
            )
    else:
        return elbo_particle, None
def elbo(param_map, model, guide, model_args, guide_args, kwargs):
    """
    This is the most basic implementation of the Evidence Lower Bound, which is
    the fundamental objective in Variational Inference. This implementation has
    various limitations (for example it only supports random variables with
    reparameterized samplers) but can be used as a template to build more
    sophisticated loss objectives.

    For more details, refer to http://pyro.ai/examples/svi_part_i.html.

    :param dict param_map: dictionary of current parameter values keyed by site
        name.
    :param model: Python callable with Pyro primitives for the model.
    :param guide: Python callable with Pyro primitives for the guide
        (recognition network).
    :param tuple model_args: arguments to the model (these can possibly vary
        during the course of fitting).
    :param tuple guide_args: arguments to the guide (these can possibly vary
        during the course of fitting).
    :param dict kwargs: static keyword arguments to the model / guide.
    :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
    """
    guide_log_density, guide_trace = log_density(guide, guide_args, kwargs, param_map)
    model_log_density, _ = log_density(replay(model, guide_trace), model_args,
                                       kwargs, param_map)
    # log p(z) - log q(z)
    elbo = model_log_density - guide_log_density
    # Return (-elbo) since by convention we do gradient descent on a loss and
    # the ELBO is a lower bound that needs to be maximized.
    return -elbo
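# Sketch of how this loss would be driven by an optimizer. The names
# seeded_model, seeded_guide, data, and param_map are assumptions for the
# example; the model and guide must already be seeded (e.g. with
# numpyro.handlers.seed), since log_density executes their sample sites.
from jax import value_and_grad

def loss_fn(param_map):
    return elbo(param_map, seeded_model, seeded_guide, (data,), (data,), {})

loss, grads = value_and_grad(loss_fn)(param_map)  # minimize -ELBO by gradient descent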
def get_param(opt_state, model, guide, get_params, constrain_fn, rng,
              model_args=None, guide_args=None, **kwargs):
    params = constrain_fn(get_params(opt_state))
    model, guide = _seed(model, guide, rng)
    if guide_args is not None:
        guide = substitute(guide, base_param_map=params)
        guide_trace = trace(guide).get_trace(*guide_args, **kwargs)
        model_params = {k: v for k, v in params.items() if k not in guide_trace}
        params = {
            k: guide_trace[k]['value'] if k in guide_trace else v
            for k, v in params.items()
        }
        if model_args is not None:
            model = substitute(replay(model, guide_trace),
                               base_param_map=model_params)
            model_trace = trace(model).get_trace(*model_args, **kwargs)
            params = {
                k: model_trace[k]['value'] if k in model_params else v
                for k, v in params.items()
            }
    return params
def elbo(param_map, model, guide, model_args, guide_args, kwargs):
    guide_log_density, guide_trace = log_density(guide, guide_args, kwargs, param_map)
    model_log_density, _ = log_density(replay(model, guide_trace), model_args,
                                       kwargs, param_map)
    # log p(z) - log q(z)
    elbo = model_log_density - guide_log_density
    # Return (-elbo) since by convention we do gradient descent on a loss and
    # the ELBO is a lower bound that needs to be maximized.
    return -elbo
def elbo(param_map, model, guide, model_args, guide_args, kwargs,
         constrain_fn, is_autoguide=False):
    """
    This is the most basic implementation of the Evidence Lower Bound, which is
    the fundamental objective in Variational Inference. This implementation has
    various limitations (for example it only supports random variables with
    reparameterized samplers) but can be used as a template to build more
    sophisticated loss objectives.

    For more details, refer to http://pyro.ai/examples/svi_part_i.html.

    :param dict param_map: dictionary of current parameter values keyed by site
        name.
    :param model: Python callable with Pyro primitives for the model.
    :param guide: Python callable with Pyro primitives for the guide
        (recognition network).
    :param tuple model_args: arguments to the model (these can possibly vary
        during the course of fitting).
    :param tuple guide_args: arguments to the guide (these can possibly vary
        during the course of fitting).
    :param dict kwargs: static keyword arguments to the model / guide.
    :param constrain_fn: a callable that transforms unconstrained parameter
        values from the optimizer to the specified constrained domain.
    :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
    """
    param_map = constrain_fn(param_map)
    guide_log_density, guide_trace = log_density(guide, guide_args, kwargs, param_map)
    if is_autoguide:
        # in autoguide, a site's value holds intermediate value
        for name, site in guide_trace.items():
            if site['type'] == 'sample':
                param_map[name] = site['value']
    else:
        # NB: we only want to substitute params not available in guide_trace
        param_map = {k: v for k, v in param_map.items() if k not in guide_trace}
        model = replay(model, guide_trace)
    model_log_density, _ = log_density(model, model_args, kwargs, param_map,
                                       skip_dist_transforms=is_autoguide)
    # log p(z) - log q(z)
    elbo = model_log_density - guide_log_density
    # Return (-elbo) since by convention we do gradient descent on a loss and
    # the ELBO is a lower bound that needs to be maximized.
    return -elbo
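# Sketch of the constrain/unconstrain round trip behind `constrain_fn`: a
# positively-constrained param is stored unconstrained for the optimizer and
# mapped back through a bijection before entering the model. This uses
# numpyro.distributions.constraints and biject_to, the same pieces the SVI
# init methods above use to build inv_transforms.
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to

t = biject_to(constraints.positive)   # here: an exp transform
unconstrained = t.inv(2.0)            # optimizer-side value, log(2.0)
constrained = t(unconstrained)        # model-side value, 2.0 again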
def test_get_mask_optimization():
    def model():
        with numpyro.handlers.seed(rng_seed=0):
            x = numpyro.sample("x", dist.Normal(0, 1))
            numpyro.sample("y", dist.Normal(x, 1), obs=0.)
            called.add("model-always")
            if numpyro.get_mask() is not False:
                called.add("model-sometimes")
                numpyro.factor("f", x + 1)

    def guide():
        with numpyro.handlers.seed(rng_seed=1):
            x = numpyro.sample("x", dist.Normal(0, 1))
            called.add("guide-always")
            if numpyro.get_mask() is not False:
                called.add("guide-sometimes")
                numpyro.factor("g", 2 - x)

    called = set()
    trace = handlers.trace(guide).get_trace()
    handlers.replay(model, trace)()
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" in called
    assert "guide-sometimes" in called

    called = set()
    with handlers.mask(mask=False):
        trace = handlers.trace(guide).get_trace()
        handlers.replay(model, trace)()
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" not in called
    assert "guide-sometimes" not in called

    called = set()
    Predictive(model, guide=guide, num_samples=2, parallel=True)(random.PRNGKey(2))
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" not in called
    assert "guide-sometimes" not in called
def single_particle_elbo(rng_key):
    model_seed, guide_seed = random.split(rng_key)
    seeded_model = seed(model, model_seed)
    seeded_guide = seed(guide, guide_seed)
    guide_log_density, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
    seeded_model = replay(seeded_model, guide_trace)
    model_log_density, _ = log_density(seeded_model, args, kwargs, param_map)

    # log p(z) - log q(z)
    elbo = model_log_density - guide_log_density
    return elbo
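# Sketch of how a multi-particle ELBO estimate is typically formed from the
# single-particle function above: vmap over split PRNG keys and average.
# rng_key and num_particles are assumed to be in scope.
from jax import random, vmap
import jax.numpy as jnp

rng_keys = random.split(rng_key, num_particles)
elbo_estimate = jnp.mean(vmap(single_particle_elbo)(rng_keys))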
def test_subsample_replay():
    data = jnp.arange(100.)
    subsample_size = 7

    with handlers.trace() as guide_trace, handlers.seed(rng_seed=0):
        with numpyro.plate("a", len(data), subsample_size=subsample_size):
            pass

    with handlers.seed(rng_seed=1), handlers.replay(guide_trace=guide_trace):
        with numpyro.plate("a", len(data)):
            subsample_data = numpyro.subsample(data, event_dim=0)
            assert subsample_data.shape == (subsample_size,)
def init(self, rng_key, *args, **kwargs):
    """
    Gets the initial SVI state.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly
        vary during the course of fitting).
    :return: the initial :data:`SVIState`
    """
    rng_key, model_seed, guide_seed = random.split(rng_key, 3)
    model_init = seed(self.model, model_seed)
    guide_init = seed(self.guide, guide_seed)
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    model_trace = trace(replay(model_init, guide_trace)).get_trace(
        *args, **kwargs, **self.static_kwargs
    )
    params = {}
    inv_transforms = {}
    mutable_state = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site["type"] == "param":
            constraint = site["kwargs"].pop("constraint", constraints.real)
            with helpful_support_errors(site):
                transform = biject_to(constraint)
            inv_transforms[site["name"]] = transform
            params[site["name"]] = transform.inv(site["value"])
        elif site["type"] == "mutable":
            mutable_state[site["name"]] = site["value"]
        elif (
            site["type"] == "sample"
            and (not site["is_observed"])
            and site["fn"].support.is_discrete
            and not self.loss.can_infer_discrete
        ):
            s_name = type(self.loss).__name__
            warnings.warn(
                f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
            )

    if not mutable_state:
        mutable_state = None
    self.constrain_fn = partial(transform_fn, inv_transforms)
    # we convert weak types like float to float32/float64
    # to avoid recompiling body_fn in svi.run
    params, mutable_state = tree_map(
        lambda x: lax.convert_element_type(x, jnp.result_type(x)),
        (params, mutable_state),
    )
    return SVIState(self.optim.init(params), mutable_state, rng_key)
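# Why the weak-type conversion above matters (sketch): a plain Python float is
# a JAX weak type, and passing it where a committed float32/float64 array was
# traced forces a recompile of the jitted step; convert_element_type pins the
# dtype once at init time.
import jax.numpy as jnp
from jax import lax

x = 1.0                                              # weak-typed scalar
y = lax.convert_element_type(x, jnp.result_type(x))  # committed float32 (or float64 with x64 enabled)
print(y.dtype)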
def single_particle_elbo(rng_key):
    model_seed, guide_seed = random.split(rng_key)
    seeded_model = seed(model, model_seed)
    seeded_guide = seed(guide, guide_seed)
    guide_log_density, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
    # NB: we only want to substitute params not available in guide_trace
    model_param_map = {k: v for k, v in param_map.items() if k not in guide_trace}
    seeded_model = replay(seeded_model, guide_trace)
    model_log_density, _ = log_density(seeded_model, args, kwargs, model_param_map)

    # log p(z) - log q(z)
    elbo = model_log_density - guide_log_density
    return elbo
def retrace(self, name, tr, dist_proposal, val_proposal, model_args, model_kwargs):
    fn_current = tr[name]["fn"]
    val_current = tr[name]["value"]
    tr[name]["fn"] = dist_proposal
    tr[name]["value"] = val_proposal
    second_trace = trace(replay(self.model, tr)).get_trace(*model_args, **model_kwargs)
    tr[name]["fn"] = fn_current
    tr[name]["value"] = val_current
    return second_trace
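# Sketch of how `retrace` can drive a Metropolis-Hastings-style update: swap a
# proposal into one site, replay the remaining sites through the model, and
# compare joint log densities. `sampler`, `current_trace`, `proposal_dist`,
# `proposal_value`, and the `trace_log_joint` helper are all hypothetical
# names for this example.
def trace_log_joint(tr):
    return sum(site["fn"].log_prob(site["value"]).sum()
               for site in tr.values() if site["type"] == "sample")

second_trace = sampler.retrace("x", current_trace, proposal_dist, proposal_value,
                               model_args, model_kwargs)
log_accept_ratio = trace_log_joint(second_trace) - trace_log_joint(current_trace)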
def test_compute_downstream_costs_plate_reuse(dim1, dim2):
    seeded_guide = handlers.seed(plate_reuse_model_guide, rng_seed=0)
    guide_trace = handlers.trace(seeded_guide).get_trace(
        include_obs=False, dim1=dim1, dim2=dim2)
    model_trace = handlers.trace(handlers.replay(seeded_guide, guide_trace)).get_trace(
        include_obs=True, dim1=dim1, dim2=dim2)

    for trace in (model_trace, guide_trace):
        for site in trace.values():
            if site["type"] == "sample":
                site["log_prob"] = site["fn"].log_prob(site["value"])

    non_reparam_nodes = set(
        name for name, site in guide_trace.items()
        if site["type"] == "sample"
        and (site["is_observed"] or not site["fn"].has_rsample))

    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    for k in dc:
        assert guide_trace[k]["log_prob"].shape == dc[k].shape
        assert_allclose(dc[k], dc_brute[k], rtol=1e-6)

    expected_c1 = model_trace["c1"]["log_prob"] - guide_trace["c1"]["log_prob"]
    expected_c1 += (model_trace["b1"]["log_prob"] - guide_trace["b1"]["log_prob"]).sum()
    expected_c1 += model_trace["c2"]["log_prob"] - guide_trace["c2"]["log_prob"]
    expected_c1 += model_trace["obs"]["log_prob"]
    assert_allclose(expected_c1, dc["c1"], rtol=1e-6)
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs):
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with block(), enum(first_available_dim=first_available_dim):
        with plate_to_enum_plate():
            model_tr = packed_trace(model).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]
    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            list(terms["log_factors"].values()) + list(terms["log_measures"].values()),
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(node["value"], output,
                                     dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[terms["log_measures"][name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(sample_subs[name],
                                           name_to_dim=node["infer"]["name_to_dim"])

    with replay(guide_trace=sample_tr):
        return model(*args, **kwargs)
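# How this internal sampler is usually reached (sketch): through NumPyro's
# numpyro.contrib.funsor.infer_discrete wrapper, which calls _sample_posterior
# under the hood. The mixture model below is hypothetical.
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import infer_discrete

def model(probs, locs, data):
    c = numpyro.sample("c", dist.Categorical(probs))
    numpyro.sample("x", dist.Normal(locs[c], 0.5), obs=data)

map_model = infer_discrete(model, temperature=0, rng_key=random.PRNGKey(0))
map_model(jnp.array([0.3, 0.7]), jnp.array([-1., 1.]), 0.8)  # "c" is filled in by _sample_posterior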
def test_compute_downstream_costs_big_model_guide_pair(
        include_inner_1, include_single, flip_c23, include_triple, include_z1):
    seeded_guide = handlers.seed(big_model_guide, rng_seed=0)
    guide_trace = handlers.trace(seeded_guide).get_trace(
        include_obs=False,
        include_inner_1=include_inner_1,
        include_single=include_single,
        flip_c23=flip_c23,
        include_triple=include_triple,
        include_z1=include_z1,
    )
    model_trace = handlers.trace(handlers.replay(seeded_guide, guide_trace)).get_trace(
        include_obs=True,
        include_inner_1=include_inner_1,
        include_single=include_single,
        flip_c23=flip_c23,
        include_triple=include_triple,
        include_z1=include_z1,
    )

    for trace in (model_trace, guide_trace):
        for site in trace.values():
            if site["type"] == "sample":
                site["log_prob"] = site["fn"].log_prob(site["value"])

    non_reparam_nodes = set(
        name for name, site in guide_trace.items()
        if site["type"] == "sample"
        and (site["is_observed"] or not site["fn"].has_rsample))

    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    expected_nodes_full_model = {
        "a1": {"c2", "a1", "d1", "c1", "obs", "b1", "d2", "c3", "b0"},
        "d2": {"obs", "d2"},
        "d1": {"obs", "d1", "d2"},
        "c3": {"d2", "obs", "d1", "c3"},
        "b0": {"b0", "d1", "c1", "obs", "b1", "d2", "c3", "c2"},
        "b1": {"obs", "b1", "d1", "d2", "c3", "c1", "c2"},
        "c1": {"d1", "c1", "obs", "d2", "c3", "c2"},
        "c2": {"obs", "d1", "c3", "d2", "c2"},
    }
    if not include_triple and include_inner_1 and include_single and not flip_c23:
        assert dc_nodes == expected_nodes_full_model

    expected_b1 = model_trace["b1"]["log_prob"] - guide_trace["b1"]["log_prob"]
    expected_b1 += (model_trace["d2"]["log_prob"] - guide_trace["d2"]["log_prob"]).sum(0)
    expected_b1 += (model_trace["d1"]["log_prob"] - guide_trace["d1"]["log_prob"]).sum(0)
    expected_b1 += model_trace["obs"]["log_prob"].sum(0, keepdims=False)
    if include_inner_1:
        expected_b1 += (model_trace["c1"]["log_prob"] - guide_trace["c1"]["log_prob"]).sum(0)
        expected_b1 += (model_trace["c2"]["log_prob"] - guide_trace["c2"]["log_prob"]).sum(0)
        expected_b1 += (model_trace["c3"]["log_prob"] - guide_trace["c3"]["log_prob"]).sum(0)
    assert_allclose(expected_b1, dc["b1"], atol=1.0e-6)

    if include_single:
        expected_b0 = model_trace["b0"]["log_prob"] - guide_trace["b0"]["log_prob"]
        expected_b0 += (model_trace["b1"]["log_prob"] - guide_trace["b1"]["log_prob"]).sum()
        expected_b0 += (model_trace["d2"]["log_prob"] - guide_trace["d2"]["log_prob"]).sum()
        expected_b0 += (model_trace["d1"]["log_prob"] - guide_trace["d1"]["log_prob"]).sum()
        expected_b0 += model_trace["obs"]["log_prob"].sum()
        if include_inner_1:
            expected_b0 += (model_trace["c1"]["log_prob"] - guide_trace["c1"]["log_prob"]).sum()
            expected_b0 += (model_trace["c2"]["log_prob"] - guide_trace["c2"]["log_prob"]).sum()
            expected_b0 += (model_trace["c3"]["log_prob"] - guide_trace["c3"]["log_prob"]).sum()
        assert_allclose(expected_b0, dc["b0"], atol=1.0e-6)
        assert dc["b0"].shape == (5,)

    if include_inner_1:
        expected_c3 = model_trace["c3"]["log_prob"] - guide_trace["c3"]["log_prob"]
        expected_c3 += (model_trace["d1"]["log_prob"] - guide_trace["d1"]["log_prob"]).sum(0)
        expected_c3 += (model_trace["d2"]["log_prob"] - guide_trace["d2"]["log_prob"]).sum(0)
        expected_c3 += model_trace["obs"]["log_prob"].sum(0)

        expected_c2 = model_trace["c2"]["log_prob"] - guide_trace["c2"]["log_prob"]
        expected_c2 += (model_trace["d1"]["log_prob"] - guide_trace["d1"]["log_prob"]).sum(0)
        expected_c2 += (model_trace["d2"]["log_prob"] - guide_trace["d2"]["log_prob"]).sum(0)
        expected_c2 += model_trace["obs"]["log_prob"].sum(0)

        expected_c1 = model_trace["c1"]["log_prob"] - guide_trace["c1"]["log_prob"]
        if flip_c23:
            expected_c3 += model_trace["c2"]["log_prob"] - guide_trace["c2"]["log_prob"]
            expected_c2 += model_trace["c3"]["log_prob"]
        else:
            expected_c2 += model_trace["c3"]["log_prob"] - guide_trace["c3"]["log_prob"]
            expected_c2 += model_trace["c2"]["log_prob"] - guide_trace["c2"]["log_prob"]
        expected_c1 += expected_c3
        assert_allclose(expected_c1, dc["c1"], atol=1.0e-6)
        assert_allclose(expected_c2, dc["c2"], atol=1.0e-6)
        assert_allclose(expected_c3, dc["c3"], atol=1.0e-6)

    expected_d1 = model_trace["d1"]["log_prob"] - guide_trace["d1"]["log_prob"]
    expected_d1 += model_trace["d2"]["log_prob"] - guide_trace["d2"]["log_prob"]
    expected_d1 += model_trace["obs"]["log_prob"]

    expected_d2 = model_trace["d2"]["log_prob"] - guide_trace["d2"]["log_prob"]
    expected_d2 += model_trace["obs"]["log_prob"]

    if include_triple:
        expected_z0 = (dc["a1"] + model_trace["z0"]["log_prob"]
                       - guide_trace["z0"]["log_prob"])
        assert_allclose(expected_z0, dc["z0"], atol=1.0e-6)
    assert_allclose(expected_d2, dc["d2"], atol=1.0e-6)
    assert_allclose(expected_d1, dc["d1"], atol=1.0e-6)

    assert dc["b1"].shape == (2,)
    assert dc["d2"].shape == (4, 2)

    for k in dc:
        assert guide_trace[k]["log_prob"].shape == dc[k].shape
        assert_allclose(dc[k], dc_brute[k], rtol=2e-7)