Example #1
0
 def test_uncondition(self):
     """Check that ``poutine.uncondition`` releases the ``obs`` site.

     The conditioned model is expected to report ``obs == torch.ones(2)``,
     while the unconditioned model is expected to produce a different value.
     """
     expected_obs = torch.ones(2)
     free_trace = poutine.trace(poutine.uncondition(self.model)).get_trace()
     pinned_trace = poutine.trace(self.model).get_trace()
     assert_equal(pinned_trace.nodes["obs"]["value"], expected_obs)
     assert_not_equal(free_trace.nodes["obs"]["value"], expected_obs)
Example #2
0
 def test_undo_uncondition(self):
     """Check that ``pyro.condition`` can re-pin ``obs`` after unconditioning."""
     observations = {"obs": torch.ones(2)}
     model = pyro.condition(poutine.uncondition(self.model), observations)
     tr = poutine.trace(model).get_trace()
     assert_equal(tr.nodes["obs"]["value"], torch.ones(2))
Example #3
0
    def _get_traces(self, model, guide, args, kwargs):
        """Trace ``guide`` and ``model`` for the energy-distance objective.

        The guide is run inside a vectorizing ``pyro.plate`` of size
        ``self.num_particles``. The model is then replayed against the guide
        trace under ``poutine.uncondition()``, drawing posterior predictive
        samples at its observed sites.

        :param callable model: the model function.
        :param callable guide: the guide function.
        :param tuple args: positional arguments passed to ``model`` and ``guide``.
        :param dict kwargs: keyword arguments passed to ``model`` and ``guide``.
        :returns: ``(guide_trace, model_trace)`` with subsample sites pruned.
        :raises ValueError: when validation is enabled and a guide sample site,
            or an observed model sample site, has a ``fn`` without a truthy
            ``has_rsample`` attribute.
        """
        if self.max_plate_nesting == float("inf"):
            # Avoid calling .log_prob() when undefined.
            with validation_enabled(False):
                # TODO factor this out as a stand-alone helper.
                ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs)
        vectorize = pyro.plate("num_particles_vectorized",
                               self.num_particles,
                               dim=-self.max_plate_nesting)

        # Trace the guide as in ELBO.
        with poutine.trace() as tr, vectorize:
            guide(*args, **kwargs)
        guide_trace = tr.trace

        # Trace the model, drawing posterior predictive samples.
        with poutine.trace() as tr, poutine.uncondition():
            with poutine.replay(trace=guide_trace), vectorize:
                model(*args, **kwargs)
        model_trace = tr.trace
        # Sites flagged infer["was_observed"] (presumably by uncondition) were
        # observations in the original model; mark them observed again so the
        # checks and log-prob filtering below treat them as likelihood sites.
        for site in model_trace.nodes.values():
            if site["type"] == "sample" and site["infer"].get(
                    "was_observed", False):
                site["is_observed"] = True
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace,
                                    self.max_plate_nesting)

        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)
        if is_validation_enabled():
            # Every guide sample site must be reparametrizable.
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    warn_if_nan(site["value"], site["name"])
                    if not getattr(site["fn"], "has_rsample", False):
                        raise ValueError(
                            "EnergyDistance requires fully reparametrized guides"
                        )
            # Every observed (likelihood) sample site of the model must be
            # reparametrizable too. The loop variable must be ``site``: the
            # body inspects it.
            for site in model_trace.nodes.values():
                if site["type"] == "sample" and site["is_observed"]:
                    warn_if_nan(site["value"], site["name"])
                    if not getattr(site["fn"], "has_rsample", False):
                        raise ValueError(
                            "EnergyDistance requires reparametrized likelihoods"
                        )

        if self.prior_scale > 0:
            # Compute log-probs for latent (unobserved) sites only.
            model_trace.compute_log_prob(
                site_filter=lambda name, site: not site["is_observed"])
            if is_validation_enabled():
                for site in model_trace.nodes.values():
                    if site["type"] == "sample" and not site["is_observed"]:
                        check_site_shape(site, self.max_plate_nesting)

        return guide_trace, model_trace