Example #1
0
    def _get_matched_trace(self, model_trace, *args, **kwargs):
        """
        Replay ``self.guide`` against ``model_trace`` and return its trace.

        :param model_trace: a trace from the model
        :type model_trace: pyro.poutine.trace_struct.Trace
        :returns: guide trace with sampled values matched to model_trace
        :rtype: pyro.poutine.trace_struct.Trace

        Each stochastic or observation node of ``model_trace`` whose ``infer``
        dict contains a ``"was_observed"`` key is flagged
        ``is_observed = True`` in place. Its value is also handed to the guide
        through the ``observations`` keyword argument. Subsample sites are
        pruned from the returned guide trace.

        `args` and `kwargs` are passed to the guide.
        """
        observed_values = {}
        kwargs["observations"] = observed_values
        candidate_names = itertools.chain(model_trace.stochastic_nodes,
                                          model_trace.observation_nodes)
        for name in candidate_names:
            site = model_trace.nodes[name]
            if "was_observed" not in site["infer"]:
                continue
            site["is_observed"] = True
            observed_values[name] = site["value"]

        replayed_guide = poutine.replay(self.guide, model_trace)
        guide_trace = poutine.trace(replayed_guide).get_trace(*args, **kwargs)

        check_model_guide_match(model_trace, guide_trace)
        return prune_subsample_sites(guide_trace)
Example #2
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Yield one ``(weight, model_trace, guide_trace)`` triple per particle.

        For each of ``self.num_particles`` particles, the guide is traced with
        a dense graph. The model is then traced, also densely, while replaying
        that guide trace. Subsample sites are pruned from both traces, and
        each triple carries the uniform weight ``1 / num_particles``.

        :raises NotImplementedError: if ``self.enum_discrete`` is set (raised
            on the first particle, i.e. when the generator is first advanced)

        XXX support for automatically settings args/kwargs to volatile?
        """
        for _ in range(self.num_particles):
            # Discrete enumeration is not supported by this estimator.
            if self.enum_discrete:
                raise NotImplementedError(
                    "https://github.com/uber/pyro/issues/220")

            guide_tracer = poutine.trace(guide, graph_type="dense")
            raw_guide_trace = guide_tracer.get_trace(*args, **kwargs)
            model_tracer = poutine.trace(
                poutine.replay(model, raw_guide_trace), graph_type="dense")
            raw_model_trace = model_tracer.get_trace(*args, **kwargs)

            check_model_guide_match(raw_model_trace, raw_guide_trace)
            pruned_guide = prune_subsample_sites(raw_guide_trace)
            pruned_model = prune_subsample_sites(raw_model_trace)

            yield 1.0 / self.num_particles, pruned_model, pruned_guide
Example #3
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a trace generator

        For each of ``self.num_particles`` particles, the guide is traced and
        the model is traced while replaying the guide trace. Subsample sites
        are pruned from both traces and the guide's score parts are computed.
        The ``(model_trace, guide_trace)`` pair is then yielded.

        When validation is enabled:

        * model and guide are checked for agreement;
        * a warning lists any guide sample sites configured for enumeration,
          which Trace_ELBO does not handle;
        * every sample site's shape is checked against
          ``self.max_iarange_nesting``.
        """
        for _ in range(self.num_particles):
            guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace)
                enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                    if site["type"] == "sample" and site["infer"].get("enumerate")]
                if enumerated_sites:
                    # One list item per output line. The comma after the
                    # header is required: without it Python concatenates the
                    # adjacent string literals, and the header becomes the
                    # join separator.
                    warnings.warn('\n'.join([
                        'Trace_ELBO found sample sites configured for enumeration:',
                        ', '.join(enumerated_sites),
                        'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))
            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            # model_trace.compute_log_prob()
            # TODO: disabled because it fails -- there are no decoder parameters.
            guide_trace.compute_score_parts()

            if is_validation_enabled():
                # NOTE(review): model-site log-probs are not computed above;
                # confirm check_site_shape does not require site["log_prob"]
                # for model sites when validation is enabled.
                for trace in (model_trace, guide_trace):
                    for site in trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)

            yield model_trace, guide_trace
Example #4
0
def get_importance_trace(graph_type, max_plate_nesting, model, guide, *args,
                         **kwargs):
    """
    Return a ``(model_trace, guide_trace)`` pair for one importance sample.

    The guide is traced first. The model is then traced while replaying that
    guide trace. Both traces use ``graph_type``. After subsample sites are
    pruned, the model's log-probabilities and the guide's score parts are
    computed.

    When validation is enabled, model/guide agreement is checked before
    pruning, and every sample site's shape is checked afterwards.

    :param graph_type: graph type passed to ``poutine.trace``
    :param max_plate_nesting: plate nesting depth used by the validation checks
    :returns: ``(model_trace, guide_trace)``
    """
    guide_tracer = poutine.trace(guide, graph_type=graph_type)
    raw_guide = guide_tracer.get_trace(*args, **kwargs)
    model_tracer = poutine.trace(poutine.replay(model, trace=raw_guide),
                                 graph_type=graph_type)
    raw_model = model_tracer.get_trace(*args, **kwargs)
    if is_validation_enabled():
        check_model_guide_match(raw_model, raw_guide, max_plate_nesting)

    pruned_guide = prune_subsample_sites(raw_guide)
    pruned_model = prune_subsample_sites(raw_model)

    pruned_model.compute_log_prob()
    pruned_guide.compute_score_parts()
    if is_validation_enabled():
        # Model sites are checked first, then guide sites.
        for trace in (pruned_model, pruned_guide):
            for site in trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, max_plate_nesting)

    return pruned_model, pruned_guide
Example #5
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a tracegraph generator

        Both traces use ``graph_type="dense"``. For each of
        ``self.num_particles`` particles, a ``(weight, model_trace,
        guide_trace)`` triple is yielded with uniform weight
        ``1 / num_particles``. Subsample sites are pruned from both traces.

        When validation is enabled, model/guide agreement is checked, and a
        warning lists any guide sample sites configured for enumeration.
        """
        for _ in range(self.num_particles):
            guide_trace = poutine.trace(guide, graph_type="dense").get_trace(
                *args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model,
                                                       trace=guide_trace),
                                        graph_type="dense").get_trace(
                                            *args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace)
                enumerated_sites = [
                    name for name, site in guide_trace.nodes.items() if
                    site["type"] == "sample" and site["infer"].get("enumerate")
                ]
                if enumerated_sites:
                    # One list item per output line. The comma after the
                    # header is required: without it the adjacent string
                    # literals concatenate into the join separator.
                    warnings.warn('\n'.join([
                        'TraceGraph_ELBO found sample sites configured for enumeration:',
                        ', '.join(enumerated_sites),
                        'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.'
                    ]))

            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            weight = 1.0 / self.num_particles
            yield weight, model_trace, guide_trace
Example #6
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a tracegraph generator

        Both traces use ``graph_type="dense"``. For each of
        ``self.num_particles`` particles, a ``(weight, model_trace,
        guide_trace)`` triple is yielded with uniform weight
        ``1 / num_particles``. Subsample sites are pruned from both traces.

        When validation is enabled, model/guide agreement is checked, and a
        warning lists any guide sample sites configured for enumeration.
        """
        for _ in range(self.num_particles):
            guide_trace = poutine.trace(guide,
                                        graph_type="dense").get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                        graph_type="dense").get_trace(*args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace)
                enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                    if site["type"] == "sample" and site["infer"].get("enumerate")]
                if enumerated_sites:
                    # One list item per output line. The comma after the
                    # header is required: without it the adjacent string
                    # literals concatenate into the join separator.
                    warnings.warn('\n'.join([
                        'TraceGraph_ELBO found sample sites configured for enumeration:',
                        ', '.join(enumerated_sites),
                        'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))

            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            weight = 1.0 / self.num_particles
            yield weight, model_trace, guide_trace
Example #7
0
    def _get_traces(self, model, guide, args, kwargs):
        """
        Trace the guide and the posterior-predictive model.

        The guide is traced under a vectorized ``num_particles`` plate. The
        model is then traced under ``poutine.uncondition`` while replaying the
        guide trace, which draws posterior predictive samples. Model sample
        sites whose ``infer`` dict has a truthy ``"was_observed"`` are
        re-flagged with ``is_observed = True``.

        When validation is enabled, two groups of sites are checked for NaNs:
        every guide sample site, and every observed model sample site. Each
        of them must also come from a distribution with ``has_rsample``.

        If ``self.prior_scale > 0``, log-probabilities are computed for the
        unobserved model sites (and their shapes are validated).

        :param args: positional arguments for ``model`` and ``guide``
        :param kwargs: keyword arguments for ``model`` and ``guide``
        :returns: ``(guide_trace, model_trace)`` -- note the guide comes first
        :raises ValueError: if validation is enabled and a guide sample site
            or an observed model sample site is not reparametrized
        """
        if self.max_plate_nesting == float("inf"):
            # Avoid calling .log_prob() when undefined.
            with validation_enabled(False):
                # TODO factor this out as a stand-alone helper.
                ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs)
        vectorize = pyro.plate("num_particles_vectorized",
                               self.num_particles,
                               dim=-self.max_plate_nesting)

        # Trace the guide as in ELBO.
        with poutine.trace() as tr, vectorize:
            guide(*args, **kwargs)
        guide_trace = tr.trace

        # Trace the model, drawing posterior predictive samples.
        with poutine.trace() as tr, poutine.uncondition():
            with poutine.replay(trace=guide_trace), vectorize:
                model(*args, **kwargs)
        model_trace = tr.trace
        # Restore the observed flag on sites marked as originally observed.
        for site in model_trace.nodes.values():
            if site["type"] == "sample" and site["infer"].get(
                    "was_observed", False):
                site["is_observed"] = True
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace,
                                    self.max_plate_nesting)

        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)
        if is_validation_enabled():
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    warn_if_nan(site["value"], site["name"])
                    if not getattr(site["fn"], "has_rsample", False):
                        raise ValueError(
                            "EnergyDistance requires fully reparametrized guides"
                        )
            # Iterate the model's own sites. The previous loop variable was
            # misnamed, so it kept re-checking the last guide site instead.
            for site in model_trace.nodes.values():
                if site["type"] == "sample" and site["is_observed"]:
                    warn_if_nan(site["value"], site["name"])
                    if not getattr(site["fn"], "has_rsample", False):
                        raise ValueError(
                            "EnergyDistance requires reparametrized likelihoods"
                        )

        if self.prior_scale > 0:
            # Only latent (unobserved) model sites contribute prior terms.
            model_trace.compute_log_prob(
                site_filter=lambda name, site: not site["is_observed"])
            if is_validation_enabled():
                for site in model_trace.nodes.values():
                    if site["type"] == "sample":
                        if not site["is_observed"]:
                            check_site_shape(site, self.max_plate_nesting)

        return guide_trace, model_trace
Example #8
0
    def _get_matched_trace(model_trace, guide, args, kwargs):
        """
        Replay ``guide`` against ``model_trace`` and return the pruned guide trace.

        Some nodes of ``model_trace`` have an ``infer`` dict containing
        ``"was_observed"``; these are the nodes taken from
        ``stochastic_nodes + observation_nodes``. Each such node is marked
        ``is_observed = True`` in place, and its value is collected into
        ``kwargs["observations"]``.

        Note: ``kwargs`` is the caller's dict and is modified in place; any
        existing ``"observations"`` entry is replaced.

        :param model_trace: trace of a model run
        :param guide: guide callable replayed against ``model_trace``
        :param args: positional arguments for the guide
        :param kwargs: keyword arguments for the guide (mutated, see above)
        :returns: the guide trace with subsample sites pruned
        """
        observations = kwargs["observations"] = {}
        candidates = model_trace.stochastic_nodes + model_trace.observation_nodes
        for name in candidates:
            site = model_trace.nodes[name]
            if "was_observed" in site["infer"]:
                site["is_observed"] = True
                observations[name] = site["value"]

        replayed_guide = poutine.replay(guide, model_trace)
        guide_trace = poutine.trace(replayed_guide).get_trace(*args, **kwargs)
        check_model_guide_match(model_trace, guide_trace)
        return prune_subsample_sites(guide_trace)
Example #9
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Yield ``(model_trace, guide_trace)`` pairs for enumerated inference.

        The guide is wrapped in ``poutine.enum``, with enumeration dims
        starting at ``self.max_iarange_nesting``. For each particle, every
        guide trace produced by ``iter_discrete_traces("flat", ...)`` is paired
        with a model trace obtained by replaying the model against it. Both
        traces are pruned of subsample sites. The model's log-probabilities
        and the guide's score parts are then computed.

        When validation is enabled, the following are checked: model/guide
        agreement, the TraceEnum requirements, and every sample site's shape.
        If ``self.strict_enumeration_warning`` is set and no guide sample site
        is configured for enumeration, a warning is issued.
        """
        # enable parallel enumeration
        enum_guide = poutine.enum(guide,
                                  first_available_dim=self.max_iarange_nesting)

        for _ in range(self.num_particles):
            discrete_traces = iter_discrete_traces("flat", enum_guide,
                                                   *args, **kwargs)
            for raw_guide in discrete_traces:
                replayed_model = poutine.replay(model, trace=raw_guide)
                model_tracer = poutine.trace(replayed_model, graph_type="flat")
                raw_model = model_tracer.get_trace(*args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(raw_model, raw_guide,
                                            self.max_iarange_nesting)
                pruned_guide = prune_subsample_sites(raw_guide)
                pruned_model = prune_subsample_sites(raw_model)
                if is_validation_enabled():
                    check_traceenum_requirements(pruned_model, pruned_guide)

                pruned_model.compute_log_prob()
                pruned_guide.compute_score_parts()
                if is_validation_enabled():
                    for site in pruned_model.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                    saw_enumeration = False
                    for site in pruned_guide.nodes.values():
                        if site["type"] != "sample":
                            continue
                        check_site_shape(site, self.max_iarange_nesting)
                        if site["infer"].get("enumerate"):
                            saw_enumeration = True
                    if self.strict_enumeration_warning and not saw_enumeration:
                        warnings.warn(
                            'TraceEnum_ELBO found no sample sites configured for enumeration. '
                            'If you want to enumerate sites, you need to @config_enumerate or set '
                            'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                            'If you do not want to enumerate, consider using Trace_ELBO instead.'
                        )

                yield pruned_model, pruned_guide
Example #10
0
    def _get_matched_cross_trace(self, model_x_trace, model_z_trace, *args, **kwargs):
        """
        Replay ``self.guide`` against ``model_z_trace`` using data from ``model_x_trace``.

        The stochastic and observation nodes of ``model_x_trace`` are split
        into two groups:

        * nodes whose ``infer`` dict contains ``"was_observed"`` are marked
          ``is_observed = True`` in both traces, and their ``model_x_trace``
          values go into ``kwargs["observations"]``;
        * every other node's ``model_x_trace`` value goes into
          ``kwargs["truth"]``.

        The guide is traced while replaying ``model_z_trace``. Its trace is
        checked against both model traces, and it is returned with subsample
        sites pruned.
        """
        observations = {}
        truth = {}
        kwargs["observations"] = observations
        kwargs["truth"] = truth
        node_names = itertools.chain(model_x_trace.stochastic_nodes,
                                     model_x_trace.observation_nodes)
        for name in node_names:
            x_site = model_x_trace.nodes[name]
            if "was_observed" not in x_site["infer"]:
                truth[name] = x_site["value"]
                continue
            x_site["is_observed"] = True
            model_z_trace.nodes[name]["is_observed"] = True
            observations[name] = x_site["value"]

        replayed_guide = poutine.replay(self.guide, model_z_trace)
        guide_trace = poutine.trace(replayed_guide).get_trace(*args, **kwargs)

        for model_trace in (model_x_trace, model_z_trace):
            check_model_guide_match(model_trace, guide_trace)
        return prune_subsample_sites(guide_trace)
Example #11
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Yield ``(weight, model_trace, guide_trace, log_r)`` tuples.

        For each of ``self.num_particles`` particles:

        * If ``self.enum_discrete`` is set, the guide's discrete traces are
          enumerated with ``iter_discrete_traces("flat", ...)``. One tuple is
          yielded per enumerated trace, with weight ``scale / num_particles``
          and ``log_r`` taken as the difference of the batched log-pdfs.
        * Otherwise a single guide trace is sampled, and one tuple is yielded
          with weight ``1 / num_particles`` and
          ``log_r = model.log_pdf() - guide.log_pdf()``.

        In both cases the model is replayed against the guide trace, and
        subsample sites are pruned before ``log_r`` is computed.

        XXX support for automatically settings args/kwargs to volatile?
        """
        for _ in range(self.num_particles):
            if self.enum_discrete:
                # This iterates over a bag of traces, for each particle.
                for scale, raw_guide in iter_discrete_traces(
                        "flat", guide, *args, **kwargs):
                    raw_model = poutine.trace(
                        poutine.replay(model, raw_guide),
                        graph_type="flat").get_trace(*args, **kwargs)

                    check_model_guide_match(raw_model, raw_guide)
                    pruned_guide = prune_subsample_sites(raw_guide)
                    pruned_model = prune_subsample_sites(raw_model)
                    check_enum_discrete_can_run(pruned_model, pruned_guide)

                    log_r = (pruned_model.batch_log_pdf() -
                             pruned_guide.batch_log_pdf())
                    yield (scale / self.num_particles,
                           pruned_model, pruned_guide, log_r)
            else:
                raw_guide = poutine.trace(guide).get_trace(*args, **kwargs)
                replayed_model = poutine.replay(model, raw_guide)
                raw_model = poutine.trace(replayed_model).get_trace(
                    *args, **kwargs)

                check_model_guide_match(raw_model, raw_guide)
                pruned_guide = prune_subsample_sites(raw_guide)
                pruned_model = prune_subsample_sites(raw_model)

                log_r = pruned_model.log_pdf() - pruned_guide.log_pdf()
                yield (1.0 / self.num_particles,
                       pruned_model, pruned_guide, log_r)
Example #12
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Yield ``(model_trace, guide_trace)`` pairs for enumerated inference.

        The guide is wrapped in ``poutine.enum``, with enumeration dims
        starting at ``self.max_iarange_nesting``. For each particle, every
        guide trace from ``iter_discrete_traces("flat", ...)`` is paired with a
        model trace obtained by replaying the model against it. After
        subsample pruning, model log-probs and guide score parts are computed.

        When validation is enabled, the following are checked: model/guide
        agreement, the TraceEnum requirements, and site shapes. If
        ``self.strict_enumeration_warning`` is set and nothing in the guide is
        configured for enumeration, a warning is issued.
        """
        # enable parallel enumeration
        enum_guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

        for _ in range(self.num_particles):
            for raw_guide_trace in iter_discrete_traces("flat", enum_guide, *args, **kwargs):
                model_tracer = poutine.trace(poutine.replay(model, trace=raw_guide_trace),
                                             graph_type="flat")
                raw_model_trace = model_tracer.get_trace(*args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(raw_model_trace, raw_guide_trace, self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(raw_guide_trace)
                model_trace = prune_subsample_sites(raw_model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()
                if is_validation_enabled():
                    for site in model_trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                    found_enumerated = False
                    for site in guide_trace.nodes.values():
                        if site["type"] != "sample":
                            continue
                        check_site_shape(site, self.max_iarange_nesting)
                        if site["infer"].get("enumerate"):
                            found_enumerated = True
                    if self.strict_enumeration_warning and not found_enumerated:
                        warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. '
                                      'If you want to enumerate sites, you need to @config_enumerate or set '
                                      'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                                      'If you do not want to enumerate, consider using Trace_ELBO instead.')

                yield model_trace, guide_trace
Example #13
0
def test_nested_autoguide(Elbo):
    """
    Check that a nested AutoGuideList trains and produces the expected sites.

    The test proceeds in three steps:

    * every nested sub-module of the guide must point back to the outer guide
      through its ``master`` weak reference;
    * 20 SVI steps run with the given ``Elbo`` class;
    * the guide's parameter and sample site names are checked.
    """
    class Model(PyroModule):
        def __init__(self):
            super().__init__()
            self.x_loc = nn.Parameter(torch.tensor(1.))
            self.x_scale = PyroParam(torch.tensor(0.1), constraints.positive)

        def forward(self):
            pyro.sample("x", dist.Normal(self.x_loc, self.x_scale))
            with pyro.plate("plate", 2):
                pyro.sample("y", dist.Normal(2., 0.1))

    model = Model()
    guide = nested_auto_guide_callable(model)

    # Every nested component must hold a back-reference to the outer guide.
    nested_modules = (mod for _, mod in guide.named_modules() if mod is not guide)
    for mod in nested_modules:
        assert mod.master is not None and mod.master() is guide, \
            "master ref wrong for {}".format(mod._pyro_name)

    svi = SVI(model, guide, Adam({'lr': 0.005}),
              Elbo(strict_enumeration_warning=False))
    for _ in range(20):
        svi.step()

    guide_trace = poutine.trace(guide).get_trace()
    model_trace = poutine.trace(model).get_trace()
    check_model_guide_match(model_trace, guide_trace)
    allowed_prefixes = ("AutoGuideList.0", "AutoGuideList.1.z")
    assert all(name.startswith(allowed_prefixes)
               for name in guide_trace.param_nodes)
    sampled_names = set(guide_trace.stochastic_nodes)
    assert "x" in sampled_names
    assert "y" in sampled_names
    # Only latent sampled is for the IAF.
    assert "_AutoGuideList.1.z_latent" in sampled_names
Example #14
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Yield one ``(weight, model_trace, guide_trace)`` triple per particle.

        The guide is traced with ``graph_type="dense"``. The model is then
        traced (also dense) while replaying that guide trace. After a
        model/guide agreement check, subsample sites are pruned from both
        traces. Each triple carries the uniform weight ``1 / num_particles``.

        :raises NotImplementedError: if ``self.enum_discrete`` is set (raised
            when the generator is first advanced, if there is any particle)

        XXX support for automatically settings args/kwargs to volatile?
        """
        for _ in range(self.num_particles):
            if self.enum_discrete:
                # Discrete enumeration is not implemented for this estimator.
                raise NotImplementedError("https://github.com/uber/pyro/issues/220")

            sampled_guide = poutine.trace(
                guide, graph_type="dense").get_trace(*args, **kwargs)
            replayed_model = poutine.replay(model, sampled_guide)
            sampled_model = poutine.trace(
                replayed_model, graph_type="dense").get_trace(*args, **kwargs)

            check_model_guide_match(sampled_model, sampled_guide)
            guide_trace = prune_subsample_sites(sampled_guide)
            model_trace = prune_subsample_sites(sampled_model)

            particle_weight = 1.0 / self.num_particles
            yield particle_weight, model_trace, guide_trace