Ejemplo n.º 1
0
def get_importance_trace(graph_type, max_plate_nesting, model, guide, *args,
                         **kwargs):
    """
    Trace the guide once, then trace the model replayed against that guide
    trace, and return both traces.

    Both traces have their subsample sites pruned. ``compute_log_prob`` is
    called on the model trace and ``compute_score_parts`` on the guide trace
    before returning. When validation is enabled, the model/guide pair and
    every sample site of both traces are checked as well.

    :param graph_type: passed as ``graph_type`` to both ``poutine.trace`` calls.
    :param max_plate_nesting: passed to ``check_model_guide_match`` and
        ``check_site_shape``.
    :param model: callable traced under ``poutine.replay`` of the guide trace.
    :param guide: callable traced first.
    :param args: positional arguments forwarded to ``get_trace``.
    :param kwargs: keyword arguments forwarded to ``get_trace``.
    :returns: the tuple ``(model_trace, guide_trace)``.
    """
    guide_tracer = poutine.trace(guide, graph_type=graph_type)
    guide_trace = guide_tracer.get_trace(*args, **kwargs)

    # The model is run with its sample sites replayed from the guide trace.
    replayed_model = poutine.replay(model, trace=guide_trace)
    model_tracer = poutine.trace(replayed_model, graph_type=graph_type)
    model_trace = model_tracer.get_trace(*args, **kwargs)

    if is_validation_enabled():
        check_model_guide_match(model_trace, guide_trace, max_plate_nesting)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)

    model_trace.compute_log_prob()
    guide_trace.compute_score_parts()

    if not is_validation_enabled():
        return model_trace, guide_trace

    # Shape-check the model's sample sites first, then the guide's.
    for trace in (model_trace, guide_trace):
        for site in trace.nodes.values():
            if site["type"] == "sample":
                check_site_shape(site, max_plate_nesting)

    return model_trace, guide_trace
Ejemplo n.º 2
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Run the guide, replay the model against it, and yield one
        ``(model_trace, guide_trace)`` pair per particle.

        For each of ``self.num_particles`` iterations the guide is traced, the
        model is traced under ``poutine.replay`` of that guide trace, both
        traces have their subsample sites pruned, and
        ``guide_trace.compute_score_parts()`` is called. When validation is
        enabled, the model/guide pair is checked, a warning is issued if any
        guide sample site is configured for enumeration, and every sample
        site is shape-checked against ``self.max_iarange_nesting``.

        :param model: callable traced under replay of the guide trace.
        :param guide: callable traced first.
        :param args: positional arguments forwarded to ``get_trace``.
        :param kwargs: keyword arguments forwarded to ``get_trace``.
        :returns: a generator of ``(model_trace, guide_trace)`` tuples.
        """
        for _ in range(self.num_particles):
            guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace)
                enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                    if site["type"] == "sample" and site["infer"].get("enumerate")]
                if enumerated_sites:
                    # Three separate message lines: header, site names, advice.
                    # The comma after the header matters: without it the header
                    # is implicitly concatenated with ', ' and becomes the
                    # join separator, garbling the warning.
                    warnings.warn('\n'.join([
                        'Trace_ELBO found sample sites configured for enumeration:',
                        ', '.join(enumerated_sites),
                        'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))
            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            # model_trace.compute_log_prob() # TODO: does not work because the decoder has no parameters
            guide_trace.compute_score_parts()

            if is_validation_enabled():
                # NOTE(review): model sites are still shape-checked although
                # compute_log_prob() is disabled above -- confirm that
                # check_site_shape tolerates sites without a computed log_prob.
                for site in model_trace.nodes.values():
                    if site["type"] == "sample":
                        check_site_shape(site, self.max_iarange_nesting)
                for site in guide_trace.nodes.values():
                    if site["type"] == "sample":
                        check_site_shape(site, self.max_iarange_nesting)

            yield model_trace, guide_trace
Ejemplo n.º 3
0
 def _populate_cache(self, model_trace):
     """
     Fill ``self.ordering`` and ``self._enum_dims`` from ``model_trace``.

     Does nothing when ``self.has_enumerable_sites`` is falsy. Otherwise the
     trace's log-probs are computed and its tensors packed; each sample site
     whose ``fn`` is not a ``_Subsample`` gets an entry in ``self.ordering``
     holding the frozenset of plate symbols of its vectorized
     ``cond_indep_stack`` frames. ``self._enum_dims`` is set to the keys of
     ``symbol_to_dim`` that are not plate symbols.

     :param model_trace: trace to inspect; it is modified by
         ``compute_log_prob`` and ``pack_tensors``.
     :raises ValueError: if enumerable sites exist but
         ``self.max_plate_nesting`` is ``None``.
     """
     if not self.has_enumerable_sites:
         return
     if self.max_plate_nesting is None:
         message = (
             "Finite value required for `max_plate_nesting` when model "
             "has discrete (enumerable) sites."
         )
         raise ValueError(message)

     model_trace.compute_log_prob()
     model_trace.pack_tensors()

     plate_symbols = model_trace.plate_to_symbol
     for name, site in model_trace.nodes.items():
         # Only real sample sites contribute; subsample sites are skipped.
         if site["type"] != "sample" or isinstance(site["fn"], _Subsample):
             continue
         if is_validation_enabled():
             check_site_shape(site, self.max_plate_nesting)
         vectorized_frames = (
             frame for frame in site["cond_indep_stack"] if frame.vectorized
         )
         self.ordering[name] = frozenset(
             plate_symbols[frame.name] for frame in vectorized_frames
         )

     # Every packed symbol that is not a plate symbol.
     all_symbols = set(model_trace.symbol_to_dim)
     self._enum_dims = all_symbols.difference(plate_symbols.values())
Ejemplo n.º 4
0
            def loss_and_surrogate_loss(*args):
                """
                Accumulate the weighted negative ELBO and negative surrogate
                ELBO over all trace pairs from ``self._get_traces``.

                ``model``, ``guide``, ``kwargs`` and ``weakself`` come from
                the enclosing scope.

                :returns: the tuple ``(loss, surrogate_loss)``.
                """
                self = weakself()
                total_loss = 0.0
                total_surrogate_loss = 0.0
                weighted_traces = self._get_traces(model, guide, *args, **kwargs)
                for weight, model_trace, guide_trace in weighted_traces:
                    model_trace.compute_log_prob()
                    guide_trace.compute_score_parts()
                    if is_validation_enabled():
                        # Model sites are checked first, then guide sites.
                        for trace in (model_trace, guide_trace):
                            for site in trace.nodes.values():
                                if site["type"] == "sample":
                                    check_site_shape(site, self.max_iarange_nesting)

                    # compute elbo for reparameterized nodes
                    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
                    elbo, surrogate_elbo = _compute_elbo_reparam(
                        model_trace, guide_trace, non_reparam_nodes)

                    # the following computations are only necessary if we
                    # have non-reparameterizable nodes
                    # NOTE(review): baseline_loss is assigned but not used in
                    # the returned values -- confirm this is intended.
                    baseline_loss = 0.0
                    if non_reparam_nodes:
                        downstream_costs, _ = _compute_downstream_costs(
                            model_trace, guide_trace, non_reparam_nodes)
                        surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
                            guide_trace, non_reparam_nodes, downstream_costs)
                        surrogate_elbo += surrogate_elbo_term

                    total_loss = total_loss - weight * elbo
                    total_surrogate_loss = total_surrogate_loss - weight * surrogate_elbo

                return total_loss, total_surrogate_loss
Ejemplo n.º 5
0
    def _get_traces(self, model, guide, args, kwargs):
        """
        Trace the guide, then trace the unconditioned model replayed against
        the guide, both under a shared particle plate.

        If ``self.max_plate_nesting`` is infinite it is first guessed via
        ``ELBO._guess_max_plate_nesting`` with validation disabled. Model
        sample sites whose ``infer`` dict has a truthy ``"was_observed"``
        entry are re-marked as observed. Both traces have their subsample
        sites pruned. When ``self.prior_scale > 0`` the log-probs of the
        non-observed model sites are computed.

        :param model: callable traced under ``poutine.uncondition`` and
            ``poutine.replay`` of the guide trace.
        :param guide: callable traced first.
        :param args: sequence of positional arguments for model and guide.
        :param kwargs: mapping of keyword arguments for model and guide.
        :returns: the tuple ``(guide_trace, model_trace)`` -- guide first.
        :raises ValueError: when validation is enabled and a guide sample
            site, or an observed model sample site, has a ``fn`` without a
            truthy ``has_rsample`` attribute.
        """
        if self.max_plate_nesting == float("inf"):
            with validation_enabled(
                    False):  # Avoid calling .log_prob() when undefined.
                # TODO factor this out as a stand-alone helper.
                ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs)
        vectorize = pyro.plate("num_particles_vectorized",
                               self.num_particles,
                               dim=-self.max_plate_nesting)

        # Trace the guide as in ELBO.
        with poutine.trace() as tr, vectorize:
            guide(*args, **kwargs)
        guide_trace = tr.trace

        # Trace the model, drawing posterior predictive samples.
        with poutine.trace() as tr, poutine.uncondition():
            with poutine.replay(trace=guide_trace), vectorize:
                model(*args, **kwargs)
        model_trace = tr.trace
        # Restore the observed flag on sites recorded as "was_observed".
        for site in model_trace.nodes.values():
            if site["type"] == "sample" and site["infer"].get(
                    "was_observed", False):
                site["is_observed"] = True
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace,
                                    self.max_plate_nesting)

        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)
        if is_validation_enabled():
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    warn_if_nan(site["value"], site["name"])
                    if not getattr(site["fn"], "has_rsample", False):
                        raise ValueError(
                            "EnergyDistance requires fully reparametrized guides"
                        )
            # The loop variable must be ``site``: binding a different name
            # here would make the body re-check a stale site left over from
            # the guide loop instead of the model's own sites.
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    if site["is_observed"]:
                        warn_if_nan(site["value"], site["name"])
                        if not getattr(site["fn"], "has_rsample", False):
                            raise ValueError(
                                "EnergyDistance requires reparametrized likelihoods"
                            )

        if self.prior_scale > 0:
            # Only non-observed sites get a log-prob computed.
            model_trace.compute_log_prob(
                site_filter=lambda name, site: not site["is_observed"])
            if is_validation_enabled():
                for site in model_trace.nodes.values():
                    if site["type"] == "sample":
                        if not site["is_observed"]:
                            check_site_shape(site, self.max_plate_nesting)

        return guide_trace, model_trace
Ejemplo n.º 6
0
 def _get_log_factors(self, model_trace):
     """
     Group the packed ``log_prob`` tensors of ``model_trace`` by ordinal.

     The trace's log-probs are computed and its tensors packed first. Each
     sample site whose ``fn`` is not a ``_Subsample`` contributes its
     ``site["packed"]["log_prob"]`` to the list keyed by
     ``self.ordering[name]``.

     :param model_trace: trace to read; it is modified by
         ``compute_log_prob`` and ``pack_tensors``.
     :returns: an ``OrderedDict`` mapping each ordinal to a list of packed
         log-prob terms, in trace iteration order.
     """
     model_trace.compute_log_prob()
     model_trace.pack_tensors()

     factors_by_ordinal = OrderedDict()
     # Collect log prob terms per independence context.
     for name, site in model_trace.nodes.items():
         if site["type"] != "sample" or isinstance(site["fn"], _Subsample):
             continue
         if is_validation_enabled():
             check_site_shape(site, self.max_plate_nesting)
         ordinal = self.ordering[name]
         if ordinal not in factors_by_ordinal:
             factors_by_ordinal[ordinal] = []
         factors_by_ordinal[ordinal].append(site["packed"]["log_prob"])
     return factors_by_ordinal
Ejemplo n.º 7
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Generate ``(model_trace, guide_trace)`` pairs.

        The guide is wrapped with ``poutine.enum``. For each of
        ``self.num_particles`` iterations, every guide trace produced by
        ``iter_discrete_traces`` is replayed through the model. Both traces
        are pruned of subsample sites, then ``compute_log_prob`` is called on
        the model trace and ``compute_score_parts`` on the guide trace. When
        validation is enabled, the traces are checked, and a warning may be
        issued if no guide sample site is configured for enumeration.

        :param model: callable traced under replay of each guide trace.
        :param guide: callable to enumerate and trace.
        :param args: positional arguments forwarded to model and guide.
        :param kwargs: keyword arguments forwarded to model and guide.
        :returns: a generator of ``(model_trace, guide_trace)`` tuples.
        """
        # enable parallel enumeration
        enum_guide = poutine.enum(
            guide, first_available_dim=self.max_iarange_nesting)

        for _ in range(self.num_particles):
            discrete_traces = iter_discrete_traces(
                "flat", enum_guide, *args, **kwargs)
            for guide_trace in discrete_traces:
                replayed_model = poutine.replay(model, trace=guide_trace)
                model_tracer = poutine.trace(replayed_model, graph_type="flat")
                model_trace = model_tracer.get_trace(*args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(
                        model_trace, guide_trace, self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()
                if is_validation_enabled():
                    for site in model_trace.nodes.values():
                        if site["type"] != "sample":
                            continue
                        check_site_shape(site, self.max_iarange_nesting)

                    # Check guide sites and note whether any is enumerated.
                    any_enumerated = False
                    for site in guide_trace.nodes.values():
                        if site["type"] != "sample":
                            continue
                        check_site_shape(site, self.max_iarange_nesting)
                        if site["infer"].get("enumerate"):
                            any_enumerated = True

                    if self.strict_enumeration_warning and not any_enumerated:
                        no_enumeration_message = (
                            'TraceEnum_ELBO found no sample sites configured for enumeration. '
                            'If you want to enumerate sites, you need to @config_enumerate or set '
                            'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                            'If you do not want to enumerate, consider using Trace_ELBO instead.'
                        )
                        warnings.warn(no_enumeration_message)

                yield model_trace, guide_trace
Ejemplo n.º 8
0
 def _compute_log_prob_terms(self, model_trace):
     """
     Compute the log-prob of every site in ``model_trace`` and group the
     sample sites' ``log_prob`` terms in ``self._log_probs``.

     Terms are keyed by the frozenset of the site's ``cond_indep_stack``.
     If ``self._log_prob_shapes`` is empty, it is filled with the broadcast
     shape of the terms collected for each key.

     :param model_trace: trace to read; it is modified by
         ``compute_log_prob``.
     """
     model_trace.compute_log_prob()
     grouped_terms = defaultdict(list)
     self._log_probs = grouped_terms

     sample_sites = [
         (name, site)
         for name, site in model_trace.nodes.items()
         if site["type"] == "sample"
     ]
     # Resolve every site's independence context before collecting terms.
     ordering = {
         name: frozenset(site["cond_indep_stack"])
         for name, site in sample_sites
     }

     # Collect log prob terms per independence context.
     for name, site in sample_sites:
         if is_validation_enabled():
             check_site_shape(site, self.max_plate_nesting)
         grouped_terms[ordering[name]].append(site["log_prob"])

     if self._log_prob_shapes:
         return
     for ordinal, terms in grouped_terms.items():
         term_shapes = (term.shape for term in terms)
         self._log_prob_shapes[ordinal] = broadcast_shape(*term_shapes)
Ejemplo n.º 9
0
    def _guess_max_plate_nesting(self, model, guide, args, kwargs):
        """
        Set ``self.max_plate_nesting`` from one un-enumerated run of the
        (model, guide) pair. This optimistically assumes static model
        structure.

        The value is the negated minimum ``dim`` over all vectorized
        ``cond_indep_stack`` frames of the sample sites in both traces, or 0
        when there are none. It is incremented by one when
        ``self.vectorize_particles`` is set and ``self.num_particles > 1``.
        The result is logged.

        :param model: callable traced under replay of the guide trace.
        :param guide: callable traced first.
        :param args: sequence of positional arguments for model and guide.
        :param kwargs: mapping of keyword arguments for model and guide.
        """
        # Ignore validation to allow model-enumerated sites absent from the guide.
        with poutine.block():
            guide_tracer = poutine.trace(guide)
            guide_trace = guide_tracer.get_trace(*args, **kwargs)
            replayed_model = poutine.replay(model, trace=guide_trace)
            model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)

        # Sample sites of the model first, then those of the guide.
        sites = []
        for trace in (model_trace, guide_trace):
            for site in trace.nodes.values():
                if site["type"] == "sample":
                    sites.append(site)

        # Validate shapes now, since shape constraints will be weaker once
        # max_plate_nesting is changed from float('inf') to some finite value.
        # Here we know the traces are not enumerated, but later we'll need to
        # allow broadcasting of dims to the left of max_plate_nesting.
        if is_validation_enabled():
            guide_trace.compute_log_prob()
            model_trace.compute_log_prob()
            for site in sites:
                check_site_shape(site, max_plate_nesting=float("inf"))

        dims = []
        for site in sites:
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    dims.append(frame.dim)

        if dims:
            self.max_plate_nesting = -min(dims)
        else:
            self.max_plate_nesting = 0
        if self.vectorize_particles and self.num_particles > 1:
            self.max_plate_nesting += 1
        message = "Guessed max_plate_nesting = {}".format(self.max_plate_nesting)
        logging.info(message)
Ejemplo n.º 10
0
    def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
        """
        Compute the weighted loss for one particle and backpropagate the
        surrogate loss when either trace contains a ``"param"`` site.

        :param weight: multiplier applied to the backpropagated quantity and
            to the returned loss.
        :param model_trace: model trace; ``compute_log_prob`` is called on it.
        :param guide_trace: guide trace; ``compute_score_parts`` is called
            on it.
        :returns: ``weight * loss`` where ``loss`` is ``-torch_item(elbo)``.
        """
        # have the trace compute all the individual (batch) log pdf terms
        # and score function terms (if present) so that they are available below
        model_trace.compute_log_prob()
        guide_trace.compute_score_parts()
        if is_validation_enabled():
            # Model sites are checked first, then guide sites.
            for trace in (model_trace, guide_trace):
                for site in trace.nodes.values():
                    if site["type"] == "sample":
                        check_site_shape(site, self.max_iarange_nesting)

        # compute elbo for reparameterized nodes
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        elbo, surrogate_elbo = _compute_elbo_reparam(
            model_trace, guide_trace, non_reparam_nodes)

        # the following computations are only necessary if we have non-reparameterizable nodes
        baseline_loss = 0.0
        if non_reparam_nodes:
            downstream_costs, _ = _compute_downstream_costs(
                model_trace, guide_trace, non_reparam_nodes)
            surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
                guide_trace, non_reparam_nodes, downstream_costs)
            surrogate_elbo += surrogate_elbo_term

        def _has_param_site(trace):
            # True when the trace holds at least one "param" site.
            return any(site["type"] == "param" for site in trace.nodes.values())

        # collect parameters to train from model and guide
        trainable_params = _has_param_site(model_trace) or _has_param_site(guide_trace)

        if trainable_params:
            surrogate_loss = -surrogate_elbo
            weighted_objective = weight * (surrogate_loss + baseline_loss)
            torch_backward(weighted_objective)

        loss = -torch_item(elbo)
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return weight * loss
Ejemplo n.º 11
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Yield a ``(model_trace, guide_trace)`` pair for every discrete guide
        trace of every particle.

        The guide is wrapped with ``poutine.enum`` and enumerated through
        ``iter_discrete_traces``; the model is traced under replay of each
        guide trace. Subsample sites are pruned from both traces before
        ``compute_log_prob`` (model) and ``compute_score_parts`` (guide) are
        called. With validation enabled, the traces are checked, and a
        warning may be issued if no guide sample site is configured for
        enumeration.

        :param model: callable traced under replay of each guide trace.
        :param guide: callable to enumerate and trace.
        :param args: positional arguments forwarded to model and guide.
        :param kwargs: keyword arguments forwarded to model and guide.
        :returns: a generator of ``(model_trace, guide_trace)`` tuples.
        """
        # enable parallel enumeration
        enumerated_guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

        for _ in range(self.num_particles):
            for guide_trace in iter_discrete_traces("flat", enumerated_guide, *args, **kwargs):
                model_against_guide = poutine.replay(model, trace=guide_trace)
                model_trace = poutine.trace(
                    model_against_guide, graph_type="flat"
                ).get_trace(*args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()
                if is_validation_enabled():
                    for model_site in model_trace.nodes.values():
                        if model_site["type"] == "sample":
                            check_site_shape(model_site, self.max_iarange_nesting)

                    # Check guide sites and note whether any is enumerated.
                    any_enumerated = False
                    for guide_site in guide_trace.nodes.values():
                        if guide_site["type"] != "sample":
                            continue
                        check_site_shape(guide_site, self.max_iarange_nesting)
                        if guide_site["infer"].get("enumerate"):
                            any_enumerated = True

                    if self.strict_enumeration_warning and not any_enumerated:
                        warning_text = (
                            'TraceEnum_ELBO found no sample sites configured for enumeration. '
                            'If you want to enumerate sites, you need to @config_enumerate or set '
                            'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                            'If you do not want to enumerate, consider using Trace_ELBO instead.'
                        )
                        warnings.warn(warning_text)

                yield model_trace, guide_trace
Ejemplo n.º 12
0
    def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
        """
        Evaluate one particle: compute its ELBO terms, backpropagate the
        weighted surrogate loss if any ``"param"`` site exists in either
        trace, and return the weighted loss.

        :param weight: multiplier applied to the backpropagated quantity and
            to the returned loss.
        :param model_trace: model trace; ``compute_log_prob`` is called on it.
        :param guide_trace: guide trace; ``compute_score_parts`` is called
            on it.
        :returns: ``weight * loss`` where ``loss`` is ``-torch_item(elbo)``.
        """
        # have the trace compute all the individual (batch) log pdf terms
        # and score function terms (if present) so that they are available below
        model_trace.compute_log_prob()
        guide_trace.compute_score_parts()
        if is_validation_enabled():
            # Model sites are checked first, then guide sites.
            for current_trace in (model_trace, guide_trace):
                for site in current_trace.nodes.values():
                    if site["type"] != "sample":
                        continue
                    check_site_shape(site, self.max_iarange_nesting)

        # compute elbo for reparameterized nodes
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes)

        # the following computations are only necessary if we have non-reparameterizable nodes
        baseline_loss = 0.0
        if non_reparam_nodes:
            downstream_costs, _ = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
            surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
                guide_trace, non_reparam_nodes, downstream_costs)
            surrogate_elbo += surrogate_elbo_term

        # collect parameters to train from model and guide
        model_has_params = any(
            site["type"] == "param" for site in model_trace.nodes.values())
        trainable_params = model_has_params or any(
            site["type"] == "param" for site in guide_trace.nodes.values())

        if trainable_params:
            surrogate_loss = -surrogate_elbo
            torch_backward(weight * (surrogate_loss + baseline_loss))

        loss = -torch_item(elbo)
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        weighted_loss = weight * loss
        return weighted_loss
Ejemplo n.º 13
0
            def loss_and_surrogate_loss(*args):
                """
                Sum the weighted negative ELBO and negative surrogate ELBO
                over every ``(weight, model_trace, guide_trace)`` triple
                yielded by ``self._get_traces``.

                ``model``, ``guide``, ``kwargs`` and ``weakself`` are taken
                from the enclosing scope.

                :returns: the tuple ``(loss, surrogate_loss)``.
                """
                self = weakself()
                loss_sum = 0.0
                surrogate_sum = 0.0
                for weight, model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
                    model_trace.compute_log_prob()
                    guide_trace.compute_score_parts()
                    if is_validation_enabled():
                        # Model sites are checked first, then guide sites.
                        for current_trace in (model_trace, guide_trace):
                            for site in current_trace.nodes.values():
                                if site["type"] != "sample":
                                    continue
                                check_site_shape(site, self.max_iarange_nesting)

                    # compute elbo for reparameterized nodes
                    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
                    elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes)

                    # the following computations are only necessary if we have non-reparameterizable nodes
                    # NOTE(review): baseline_loss is assigned but never
                    # contributes to the returned values -- confirm intended.
                    baseline_loss = 0.0
                    if non_reparam_nodes:
                        downstream_costs, _ = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
                        surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
                            guide_trace, non_reparam_nodes, downstream_costs)
                        surrogate_elbo += surrogate_elbo_term

                    loss_sum = loss_sum - weight * elbo
                    surrogate_sum = surrogate_sum - weight * surrogate_elbo

                return loss_sum, surrogate_sum