def test_uncondition(self):
    """Check that ``poutine.uncondition`` changes the value at the "obs" site.

    Tracing ``self.model`` directly must give ``torch.ones(2)`` at "obs";
    tracing it through ``poutine.uncondition`` must give something else.
    """
    expected_obs = torch.ones(2)

    # Trace the unconditioned model first, then the plain model, so the
    # order of any random draws is the same as it has always been.
    free_trace = poutine.trace(poutine.uncondition(self.model)).get_trace()
    observed_trace = poutine.trace(self.model).get_trace()

    assert_equal(observed_trace.nodes["obs"]["value"], expected_obs)
    assert_not_equal(free_trace.nodes["obs"]["value"], expected_obs)
def test_undo_uncondition(self):
    """Check that conditioning again after ``uncondition`` restores "obs".

    ``self.model`` is wrapped in ``poutine.uncondition`` and then in
    ``pyro.condition`` with ``{"obs": torch.ones(2)}``; the traced "obs"
    value must equal ``torch.ones(2)``.
    """
    observations = {"obs": torch.ones(2)}
    model_without_obs = poutine.uncondition(self.model)
    model_with_obs = pyro.condition(model_without_obs, observations)

    trace = poutine.trace(model_with_obs).get_trace()
    assert_equal(trace.nodes["obs"]["value"], torch.ones(2))
def _get_traces(self, model, guide, args, kwargs):
    """
    Trace the guide and the unconditioned model once and return both traces.

    The guide is traced inside a ``num_particles_vectorized`` plate. The
    model is then replayed against the guide trace under
    ``poutine.uncondition()``, and every model sample site whose ``infer``
    dict carries ``was_observed`` is re-flagged as observed. Both traces are
    passed through ``prune_subsample_sites`` before being returned.

    :param model: callable invoked as ``model(*args, **kwargs)``.
    :param guide: callable invoked as ``guide(*args, **kwargs)``.
    :param args: positional arguments forwarded to both callables.
    :param kwargs: keyword arguments forwarded to both callables.
    :returns: the pair ``(guide_trace, model_trace)``.
    :raises ValueError: when validation is enabled and a guide sample site,
        or an observed model sample site, has a ``fn`` without a truthy
        ``has_rsample`` attribute.
    """
    if self.max_plate_nesting == float("inf"):
        # Avoid calling .log_prob() when undefined.
        with validation_enabled(False):
            # TODO factor this out as a stand-alone helper.
            # NOTE(review): presumably this sets self.max_plate_nesting to a
            # finite value; the plate below relies on that -- confirm.
            ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs)
    vectorize = pyro.plate(
        "num_particles_vectorized",
        self.num_particles,
        dim=-self.max_plate_nesting,
    )

    # Trace the guide as in ELBO.
    with poutine.trace() as tr, vectorize:
        guide(*args, **kwargs)
    guide_trace = tr.trace

    # Trace the model, drawing posterior predictive samples.
    with poutine.trace() as tr, poutine.uncondition():
        with poutine.replay(trace=guide_trace), vectorize:
            model(*args, **kwargs)
    model_trace = tr.trace

    # Restore the observed flag on sites marked with infer["was_observed"].
    for site in model_trace.nodes.values():
        if site["type"] == "sample" and site["infer"].get("was_observed", False):
            site["is_observed"] = True
    if is_validation_enabled():
        check_model_guide_match(model_trace, guide_trace, self.max_plate_nesting)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)

    if is_validation_enabled():
        for site in guide_trace.nodes.values():
            if site["type"] == "sample":
                warn_if_nan(site["value"], site["name"])
                if not getattr(site["fn"], "has_rsample", False):
                    raise ValueError(
                        "EnergyDistance requires fully reparametrized guides"
                    )
        # The loop variable must be `site`: binding a different name here
        # would leave the body checking a stale site from an earlier loop
        # and never looking at the model's own observed sites.
        for site in model_trace.nodes.values():
            if site["type"] == "sample" and site["is_observed"]:
                warn_if_nan(site["value"], site["name"])
                if not getattr(site["fn"], "has_rsample", False):
                    raise ValueError(
                        "EnergyDistance requires reparametrized likelihoods"
                    )

    if self.prior_scale > 0:
        # Only non-observed sites get a log-probability computed.
        model_trace.compute_log_prob(
            site_filter=lambda name, site: not site["is_observed"]
        )
        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample" and not site["is_observed"]:
                    check_site_shape(site, self.max_plate_nesting)

    return guide_trace, model_trace