Example #1
def get_importance_trace(graph_type, max_plate_nesting, model, guide, *args,
                         **kwargs):
    """
    Returns a single trace from the guide, and the model that is run
    against it.
    """
    guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(
        *args, **kwargs)
    model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                graph_type=graph_type).get_trace(
                                    *args, **kwargs)
    if is_validation_enabled():
        check_model_guide_match(model_trace, guide_trace, max_plate_nesting)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)

    model_trace.compute_log_prob()
    guide_trace.compute_score_parts()
    if is_validation_enabled():
        for site in model_trace.nodes.values():
            if site["type"] == "sample":
                check_site_shape(site, max_plate_nesting)
        for site in guide_trace.nodes.values():
            if site["type"] == "sample":
                check_site_shape(site, max_plate_nesting)

    return model_trace, guide_trace
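A minimal usage sketch (not part of the original listing) may help orient readers; the toy model/guide below are hypothetical, and the import path is an assumption based on where recent Pyro releases keep this helper:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.enum import get_importance_trace  # assumed location

def toy_model():
    z = pyro.sample("z", dist.Normal(0., 1.))
    pyro.sample("x", dist.Normal(z, 1.), obs=torch.tensor(0.5))

def toy_guide():
    loc = pyro.param("loc", torch.tensor(0.))
    pyro.sample("z", dist.Normal(loc, 1.))

# Matches the signature shown above: graph_type, max_plate_nesting, model, guide.
model_trace, guide_trace = get_importance_trace(
    "flat", float("inf"), toy_model, toy_guide)
print(model_trace.nodes["x"]["log_prob_sum"])  # populated by compute_log_prob()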
Example #2
    def _differentiable_loss_particle(self, model_trace, guide_trace):
        elbo_particle = 0

        for name, model_site in model_trace.nodes.items():
            if model_site["type"] == "sample":
                if model_site["is_observed"]:
                    elbo_particle = elbo_particle + model_site["log_prob_sum"]
                else:
                    guide_site = guide_trace.nodes[name]
                    if is_validation_enabled():
                        check_fully_reparametrized(guide_site)

                    # use kl divergence if available, else fall back on sampling
                    try:
                        kl_qp = kl_divergence(guide_site["fn"], model_site["fn"])
                        kl_qp = scale_and_mask(kl_qp, scale=guide_site["scale"], mask=guide_site["mask"])
                        assert kl_qp.shape == guide_site["fn"].batch_shape
                        elbo_particle = elbo_particle - kl_qp.sum()
                    except NotImplementedError:
                        entropy_term = guide_site["score_parts"].entropy_term
                        elbo_particle = elbo_particle + model_site["log_prob_sum"] - entropy_term.sum()

        # handle auxiliary sites in the guide
        for name, guide_site in guide_trace.nodes.items():
            if guide_site["type"] == "sample" and name not in model_trace.nodes:
                assert guide_site["infer"].get("is_auxiliary")
                if is_validation_enabled():
                    check_fully_reparametrized(guide_site)
                entropy_term = guide_site["score_parts"].entropy_term
                elbo_particle = elbo_particle - entropy_term.sum()

        loss = -(elbo_particle.detach() if torch._C._get_tracing_state() else torch_item(elbo_particle))
        surrogate_loss = -elbo_particle
        return loss, surrogate_loss
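The try/except above prefers an analytic KL whenever torch.distributions registers one for the (guide, model) distribution pair; a standalone illustration of that fast path:

import torch
from torch.distributions import Normal, kl_divergence

q = Normal(torch.zeros(3), torch.ones(3))        # stands in for guide_site["fn"]
p = Normal(torch.zeros(3), 2.0 * torch.ones(3))  # stands in for model_site["fn"]
kl = kl_divergence(q, p)                          # analytic, no sampling needed
assert kl.shape == q.batch_shape                  # same check as the assert above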
Example #3
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a trace generator
        """
        for i in range(self.num_particles):
            guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace)
                enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                    if site["type"] == "sample" and site["infer"].get("enumerate")]
                if enumerated_sites:
                    warnings.warn('\n'.join([
                        'Trace_ELBO found sample sites configured for enumeration:',
                        ', '.join(enumerated_sites),
                        'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))
            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            # model_trace.compute_log_prob()  # TODO: doesn't work because there are no decoder parameters
            guide_trace.compute_score_parts()

            if is_validation_enabled():
                for site in model_trace.nodes.values():
                    if site["type"] == "sample":
                        check_site_shape(site, self.max_iarange_nesting)
                for site in guide_trace.nodes.values():
                    if site["type"] == "sample":
                        check_site_shape(site, self.max_iarange_nesting)

            yield model_trace, guide_trace
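For context, user code normally drives a _get_traces generator indirectly through SVI; a hedged end-to-end sketch (toy model/guide, hypothetical learning rate):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def toy_model():
    z = pyro.sample("z", dist.Normal(0., 1.))
    pyro.sample("x", dist.Normal(z, 1.), obs=torch.tensor(0.5))

def toy_guide():
    loc = pyro.param("loc", torch.tensor(0.))
    pyro.sample("z", dist.Normal(loc, 1.))

svi = SVI(toy_model, toy_guide, Adam({"lr": 1e-2}), loss=Trace_ELBO(num_particles=4))
for _ in range(3):
    print(svi.step())  # each step consumes num_particles traces from _get_traces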
Example #4
    def _get_traces(self, model, guide, args, kwargs):
        if self.max_plate_nesting == float("inf"):
            with validation_enabled(False):  # Avoid calling .log_prob() when undefined.
                # TODO factor this out as a stand-alone helper.
                ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs)
        vectorize = pyro.plate("num_particles_vectorized",
                               self.num_particles,
                               dim=-self.max_plate_nesting)

        # Trace the guide as in ELBO.
        with poutine.trace() as tr, vectorize:
            guide(*args, **kwargs)
        guide_trace = tr.trace

        # Trace the model, drawing posterior predictive samples.
        with poutine.trace() as tr, poutine.uncondition():
            with poutine.replay(trace=guide_trace), vectorize:
                model(*args, **kwargs)
        model_trace = tr.trace
        for site in model_trace.nodes.values():
            if site["type"] == "sample" and site["infer"].get(
                    "was_observed", False):
                site["is_observed"] = True
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace,
                                    self.max_plate_nesting)

        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)
        if is_validation_enabled():
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    warn_if_nan(site["value"], site["name"])
                    if not getattr(site["fn"], "has_rsample", False):
                        raise ValueError(
                            "EnergyDistance requires fully reparametrized guides"
                        )
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    if site["is_observed"]:
                        warn_if_nan(site["value"], site["name"])
                        if not getattr(site["fn"], "has_rsample", False):
                            raise ValueError(
                                "EnergyDistance requires reparametrized likelihoods"
                            )

        if self.prior_scale > 0:
            model_trace.compute_log_prob(
                site_filter=lambda name, site: not site["is_observed"])
            if is_validation_enabled():
                for site in model_trace.nodes.values():
                    if site["type"] == "sample":
                        if not site["is_observed"]:
                            check_site_shape(site, self.max_plate_nesting)

        return guide_trace, model_trace
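The has_rsample checks above gate on whether a distribution supports reparameterized (pathwise) sampling; the flag comes straight from torch.distributions:

import torch.distributions as tdist

print(tdist.Normal(0., 1.).has_rsample)  # True: passes the guide/likelihood checks
print(tdist.Poisson(3.).has_rsample)     # False: would raise ValueError above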
Example #5
    def _populate_cache(self, model_trace):
        """
        Populate the ordinals (set of ``CondIndepStack`` frames)
        and enum_dims for each sample site.
        """
        if not self.has_enumerable_sites:
            return
        if self.max_plate_nesting is None:
            raise ValueError(
                "Finite value required for `max_plate_nesting` when model "
                "has discrete (enumerable) sites."
            )
        model_trace.compute_log_prob()
        model_trace.pack_tensors()
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample" and not isinstance(site["fn"], _Subsample):
                if is_validation_enabled():
                    check_site_shape(site, self.max_plate_nesting)
                self.ordering[name] = frozenset(
                    model_trace.plate_to_symbol[f.name]
                    for f in site["cond_indep_stack"]
                    if f.vectorized
                )
        self._enum_dims = set(model_trace.symbol_to_dim) - set(
            model_trace.plate_to_symbol.values()
        )
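The ordinals built above come from the CondIndepStack frames that pyro.plate attaches to each sample site; a small illustration of where those frames live (toy model, hypothetical names):

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def model():
    with pyro.plate("data", 3):
        pyro.sample("x", dist.Normal(0., 1.))

tr = poutine.trace(model).get_trace()
frames = tr.nodes["x"]["cond_indep_stack"]
print([(f.name, f.vectorized) for f in frames])  # [('data', True)]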
Example #6
    def _differentiable_loss_particle(self, model_trace, guide_trace):
        if not self.vectorize_particles:
            raise NotImplementedError("TraceTailAdaptive_ELBO only implemented for vectorize_particles==True")

        if self.num_particles == 1:
            warnings.warn("For num_particles==1 TraceTailAdaptive_ELBO uses the same loss function as Trace_ELBO. " +
                          "Increase num_particles to get an adaptive f-divergence.")

        log_p, log_q = 0, 0

        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                site_log_p = site["log_prob"].reshape(self.num_particles, -1).sum(-1)
                log_p = log_p + site_log_p

        for name, site in guide_trace.nodes.items():
            if site["type"] == "sample":
                site_log_q = site["log_prob"].reshape(self.num_particles, -1).sum(-1)
                log_q = log_q + site_log_q
                if is_validation_enabled():
                    check_fully_reparametrized(site)

        # rank the particles according to p/q
        log_pq = log_p - log_q
        rank = torch.argsort(log_pq, descending=False)
        rank = torch.index_select(torch.arange(self.num_particles, device=log_pq.device) + 1, -1, rank).type_as(log_pq)

        # compute the particle-specific weights used to construct the surrogate loss
        gamma = torch.pow(rank, self.tail_adaptive_beta).detach()
        surrogate_loss = -(log_pq * gamma).sum() / gamma.sum()

        # we do not compute the loss, so return `inf`
        return float('inf'), surrogate_loss
Example #7
            def loss_and_surrogate_loss(*args):
                self = weakself()
                loss = 0.0
                surrogate_loss = 0.0
                for weight, model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
                    model_trace.compute_log_prob()
                    guide_trace.compute_score_parts()
                    if is_validation_enabled():
                        for site in model_trace.nodes.values():
                            if site["type"] == "sample":
                                check_site_shape(site, self.max_iarange_nesting)
                        for site in guide_trace.nodes.values():
                            if site["type"] == "sample":
                                check_site_shape(site, self.max_iarange_nesting)

                    # compute elbo for reparameterized nodes
                    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
                    elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes)

                    # the following computations are only necessary if we have non-reparameterizable nodes
                    baseline_loss = 0.0
                    if non_reparam_nodes:
                        downstream_costs, _ = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
                        surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(guide_trace,
                                                                                       non_reparam_nodes,
                                                                                       downstream_costs)
                        surrogate_elbo += surrogate_elbo_term

                    loss = loss - weight * elbo
                    surrogate_loss = surrogate_loss - weight * surrogate_elbo

                return loss, surrogate_loss
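weakself() on the first line dereferences a weak reference that the enclosing scope is assumed to hold; Pyro's JIT wrappers use this pattern so the traced closure does not keep the ELBO object alive. A minimal demonstration of the mechanism:

import weakref

class Elbo:
    pass

elbo = Elbo()
weakself = weakref.ref(elbo)  # what the enclosing scope presumably created
assert weakself() is elbo     # calling the ref returns the object, or None once collected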
Example #8
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a tracegraph generator
        """

        for i in range(self.num_particles):
            guide_trace = poutine.trace(guide,
                                        graph_type="dense").get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                        graph_type="dense").get_trace(*args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace)
                enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                    if site["type"] == "sample" and site["infer"].get("enumerate")]
                if enumerated_sites:
                    warnings.warn('\n'.join([
                        'TraceGraph_ELBO found sample sites configured for enumeration:',
                        ', '.join(enumerated_sites),
                        'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))

            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            weight = 1.0 / self.num_particles
            yield weight, model_trace, guide_trace
Example #9
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a tracegraph generator
        """

        for i in range(self.num_particles):
            guide_trace = poutine.trace(guide, graph_type="dense").get_trace(
                *args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model,
                                                       trace=guide_trace),
                                        graph_type="dense").get_trace(
                                            *args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace)
                enumerated_sites = [
                    name for name, site in guide_trace.nodes.items() if
                    site["type"] == "sample" and site["infer"].get("enumerate")
                ]
                if enumerated_sites:
                    warnings.warn('\n'.join([
                        'TraceGraph_ELBO found sample sites configured for enumeration:',
                        ', '.join(enumerated_sites),
                        'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.'
                    ]))

            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            weight = 1.0 / self.num_particles
            yield weight, model_trace, guide_trace
Example #10
    def _get_trace(self, model, guide, args, kwargs):
        """
        Returns a single trace from the guide, and the model that is run
        against it.
        """
        model_trace, guide_trace = get_importance_trace(
            "flat", self.max_plate_nesting, model, guide, args, kwargs)

        if is_validation_enabled():
            check_traceenum_requirements(model_trace, guide_trace)
            _check_tmc_elbo_constraint(model_trace, guide_trace)

            has_enumerated_sites = any(site["infer"].get("enumerate")
                                       for trace in (guide_trace, model_trace)
                                       for name, site in trace.nodes.items()
                                       if site["type"] == "sample")

            if self.strict_enumeration_warning and not has_enumerated_sites:
                warnings.warn(
                    'TraceEnum_ELBO found no sample sites configured for enumeration. '
                    'If you want to enumerate sites, you need to @config_enumerate or set '
                    'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                    'If you do not want to enumerate, consider using Trace_ELBO instead.'
                )

        guide_trace.pack_tensors()
        model_trace.pack_tensors(guide_trace.plate_to_symbol)
        return model_trace, guide_trace
Example #11
    def __call__(self, name, fn, obs):
        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, dist.Stable) and fn.coords == "S0"
        if is_validation_enabled():
            if not (fn.skew == 0).all():
                raise ValueError("SymmetricStableReparam found nonzero skew")
            if not (fn.stability < 2).all():
                raise ValueError("SymmetricStableReparam found stability >= 2")

        # Draw parameter-free noise.
        proto = fn.stability
        half_pi = proto.new_full(proto.shape, math.pi / 2)
        one = proto.new_ones(proto.shape)
        u = pyro.sample("{}_uniform".format(name),
                        self._wrap(dist.Uniform(-half_pi, half_pi), event_dim))
        e = pyro.sample("{}_exponential".format(name),
                        self._wrap(dist.Exponential(one), event_dim))

        # Differentiably transform to scale drawn from a totally-skewed stable variable.
        a = fn.stability
        z = _unsafe_standard_stable(a / 2, 1, u, e, coords="S")
        assert (z >= 0).all()
        scale = fn.scale * (math.pi / 4 * a).cos().pow(a.reciprocal()) * z.sqrt()
        scale = scale.clamp(min=torch.finfo(scale.dtype).tiny)

        # Construct a scaled Gaussian, using Stable(2,0,s,m) == Normal(m,s*sqrt(2)).
        new_fn = self._wrap(dist.Normal(fn.loc, scale * (2 ** 0.5)), event_dim)
        return new_fn, obs
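The last step uses the identity Stable(2, 0, s, m) == Normal(m, s·sqrt(2)). An empirical sanity check, assuming pyro.distributions.Stable accepts the boundary value stability=2:

import torch
import pyro.distributions as dist

s = dist.Stable(stability=torch.tensor(2.0), skew=torch.tensor(0.0),
                scale=torch.tensor(1.0), loc=torch.tensor(0.0))
x = s.sample((100000,))
print(x.std())  # ≈ sqrt(2) ≈ 1.414, i.e. Normal(0, sqrt(2))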
Example #12
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a trace generator
        """
        # enable parallel enumeration
        guide = poutine.enum(guide,
                             first_available_dim=self.max_iarange_nesting)

        for i in range(self.num_particles):
            for guide_trace in iter_discrete_traces("flat", guide, *args,
                                                    **kwargs):
                model_trace = poutine.trace(poutine.replay(model,
                                                           trace=guide_trace),
                                            graph_type="flat").get_trace(
                                                *args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(model_trace, guide_trace,
                                            self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()
                if is_validation_enabled():
                    for site in model_trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                    any_enumerated = False
                    for site in guide_trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                            if site["infer"].get("enumerate"):
                                any_enumerated = True
                    if self.strict_enumeration_warning and not any_enumerated:
                        warnings.warn(
                            'TraceEnum_ELBO found no sample sites configured for enumeration. '
                            'If you want to enumerate sites, you need to @config_enumerate or set '
                            'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                            'If you do not want to enumerate, consider using Trace_ELBO instead.'
                        )

                yield model_trace, guide_trace
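Sites are marked for enumeration per-site via the infer dict; the usual entry point is @config_enumerate, as in this hedged sketch (guide is hypothetical):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate

@config_enumerate(default="parallel")
def toy_guide():
    # every discrete site is now tagged infer={"enumerate": "parallel"}
    pyro.sample("z", dist.Categorical(torch.ones(3) / 3))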
Example #13
    def _get_trace(self, model, guide, args, kwargs):
        """
        Returns a single trace from the guide, and the model that is run
        against it.
        """
        model_trace, guide_trace = get_importance_trace(
            "flat", self.max_plate_nesting, model, guide, args, kwargs)
        if is_validation_enabled():
            check_if_enumerated(guide_trace)
        return model_trace, guide_trace
Example #14
@contextlib.contextmanager  # assumed: this generator is used as a `with`-statement helper
def _disallow_latent_variables(section_name):
    if not is_validation_enabled():
        yield
        return

    with poutine.trace() as tr:
        yield
    for name, site in tr.trace.nodes.items():
        if site["type"] == "sample" and not site["is_observed"]:
            raise NotImplementedError("{} contained latent variable {}"
                                      .format(section_name, name))
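A hedged usage sketch of the helper above; the section name and the wrapped callable are hypothetical:

with _disallow_latent_variables("initialization section"):
    run_initialization()  # hypothetical; any unobserved pyro.sample inside raises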
Example #15
    def __setitem__(self, key, value):
        if self._locked:
            raise RuntimeError("Guide cannot write to SMCState")
        if is_validation_enabled():
            if not isinstance(value, torch.Tensor):
                raise TypeError(
                    "Only Tensors can be stored in an SMCState, but got {}".format(
                        type(value).__name__))
            if value.dim() == 0 or value.size(0) != self._num_particles:
                raise ValueError(
                    "Expected leading dim of size {} but got shape {}".format(
                        self._num_particles, value.shape))
        super().__setitem__(key, value)
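A hedged sketch of the contract this enforces, assuming SMCState lives in pyro.infer.smcfilter and takes the particle count in its constructor:

import torch
import pyro
from pyro.infer.smcfilter import SMCState  # assumed location

pyro.enable_validation(True)
state = SMCState(num_particles=4)
state["z"] = torch.zeros(4, 2)     # ok: leading dim matches num_particles
try:
    state["bad"] = torch.zeros(3)  # wrong leading dim
except ValueError as e:
    print(e)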
Example #16
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a trace generator
        """
        # enable parallel enumeration
        guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

        for i in range(self.num_particles):
            for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
                model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                            graph_type="flat").get_trace(*args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()
                if is_validation_enabled():
                    for site in model_trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                    any_enumerated = False
                    for site in guide_trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                            if site["infer"].get("enumerate"):
                                any_enumerated = True
                    if self.strict_enumeration_warning and not any_enumerated:
                        warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. '
                                      'If you want to enumerate sites, you need to @config_enumerate or set '
                                      'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                                      'If you do not want to enumerate, consider using Trace_ELBO instead.')

                yield model_trace, guide_trace
Example #17
    def _get_log_factors(self, model_trace):
        """
        Aggregates the `log_prob` terms into a list for each
        ordinal.
        """
        model_trace.compute_log_prob()
        model_trace.pack_tensors()
        log_probs = OrderedDict()
        # Collect log prob terms per independence context.
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample" and not isinstance(site["fn"], _Subsample):
                if is_validation_enabled():
                    check_site_shape(site, self.max_plate_nesting)
                log_probs.setdefault(self.ordering[name], []).append(site["packed"]["log_prob"])
        return log_probs
Example #18
    def _pyro_sample(self, msg):
        if msg["done"] or msg["is_observed"] or type(msg["fn"]).__name__ == "_Subsample":
            return
        with torch.no_grad():
            value = self.init_fn(msg)
        if is_validation_enabled() and msg["value"] is not None:
            if not isinstance(value, type(msg["value"])):
                raise ValueError(
                    "{} provided invalid type for site {}:\nexpected {}\nactual {}"
                    .format(self.init_fn, msg["name"], type(msg["value"]), type(value)))
            if value.shape != msg["value"].shape:
                raise ValueError(
                    "{} provided invalid shape for site {}:\nexpected {}\nactual {}"
                    .format(self.init_fn, msg["name"], msg["value"].shape, value.shape))
        msg["value"] = value
        msg["done"] = True
Example #19
    def _transition_bwd(self, params, prev, curr, t):
        """
        Helper to collect probability factors from .transition() conditioned on
        previous and current enumerated states.
        """
        # Run .transition() conditioned on computed flows.
        cond_data = {"{}_{}".format(k, t): v for k, v in curr.items()}
        cond_data.update(self.compute_flows(prev, curr, t))
        with poutine.condition(data=cond_data):
            state = prev.copy()
            self.transition(params, state, t)  # Mutates state.

        # Validate that .transition() matches .compute_flows().
        if is_validation_enabled():
            for key in self.compartments:
                if not torch.allclose(state[key], curr[key]):
                    raise ValueError("Incorrect state['{}'] update in .transition(), "
                                     "check that .transition() matches .compute_flows()."
                                     .format(key))
Example #20
    def _compute_log_prob_terms(self, model_trace):
        """
        Computes the conditional probabilities for each of the sites
        in the model trace, and stores the result in `self._log_probs`.
        """
        model_trace.compute_log_prob()
        self._log_probs = defaultdict(list)
        ordering = {name: frozenset(site["cond_indep_stack"])
                    for name, site in model_trace.nodes.items()
                    if site["type"] == "sample"}
        # Collect log prob terms per independence context.
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                if is_validation_enabled():
                    check_site_shape(site, self.max_plate_nesting)
                self._log_probs[ordering[name]].append(site["log_prob"])
        if not self._log_prob_shapes:
            for ordinal, log_prob in self._log_probs.items():
                self._log_prob_shapes[ordinal] = broadcast_shape(
                    *(t.shape for t in self._log_probs[ordinal]))
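broadcast_shape merges the per-site log_prob shapes within one independence context; a quick standalone check of the helper (from pyro.distributions.util):

from pyro.distributions.util import broadcast_shape

print(broadcast_shape((3, 1), (1, 4)))  # (3, 4)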
Example #21
    def _guess_max_plate_nesting(self, model, guide, args, kwargs):
        """
        Guesses max_plate_nesting by running the (model,guide) pair once
        without enumeration. This optimistically assumes static model
        structure.
        """
        # Ignore validation to allow model-enumerated sites absent from the guide.
        with poutine.block():
            guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
            model_trace = poutine.trace(
                poutine.replay(model, trace=guide_trace)
            ).get_trace(*args, **kwargs)
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)
        sites = [
            site
            for trace in (model_trace, guide_trace)
            for site in trace.nodes.values()
            if site["type"] == "sample"
        ]

        # Validate shapes now, since shape constraints will be weaker once
        # max_plate_nesting is changed from float('inf') to some finite value.
        # Here we know the traces are not enumerated, but later we'll need to
        # allow broadcasting of dims to the left of max_plate_nesting.
        if is_validation_enabled():
            guide_trace.compute_log_prob()
            model_trace.compute_log_prob()
            for site in sites:
                check_site_shape(site, max_plate_nesting=float("inf"))

        dims = [
            frame.dim
            for site in sites
            for frame in site["cond_indep_stack"]
            if frame.vectorized
        ]
        self.max_plate_nesting = -min(dims) if dims else 0
        if self.vectorize_particles and self.num_particles > 1:
            self.max_plate_nesting += 1
        logging.info("Guessed max_plate_nesting = {}".format(self.max_plate_nesting))
Example #22
    def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
        # have the trace compute all the individual (batch) log pdf terms
        # and score function terms (if present) so that they are available below
        model_trace.compute_log_prob()
        guide_trace.compute_score_parts()
        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)

        # compute elbo for reparameterized nodes
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace,
                                                     non_reparam_nodes)

        # the following computations are only necessary if we have non-reparameterizable nodes
        baseline_loss = 0.0
        if non_reparam_nodes:
            downstream_costs, _ = _compute_downstream_costs(
                model_trace, guide_trace, non_reparam_nodes)
            surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
                guide_trace, non_reparam_nodes, downstream_costs)
            surrogate_elbo += surrogate_elbo_term

        # collect parameters to train from model and guide
        trainable_params = any(site["type"] == "param"
                               for trace in (model_trace, guide_trace)
                               for site in trace.nodes.values())

        if trainable_params:
            surrogate_loss = -surrogate_elbo
            torch_backward(weight * (surrogate_loss + baseline_loss))

        loss = -torch_item(elbo)
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return weight * loss
Example #23
    def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
        # have the trace compute all the individual (batch) log pdf terms
        # and score function terms (if present) so that they are available below
        model_trace.compute_log_prob()
        guide_trace.compute_score_parts()
        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)

        # compute elbo for reparameterized nodes
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes)

        # the following computations are only necessary if we have non-reparameterizable nodes
        baseline_loss = 0.0
        if non_reparam_nodes:
            downstream_costs, _ = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
            surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(guide_trace,
                                                                           non_reparam_nodes, downstream_costs)
            surrogate_elbo += surrogate_elbo_term

        # collect parameters to train from model and guide
        trainable_params = any(site["type"] == "param"
                               for trace in (model_trace, guide_trace)
                               for site in trace.nodes.values())

        if trainable_params:
            surrogate_loss = -surrogate_elbo
            torch_backward(weight * (surrogate_loss + baseline_loss))

        loss = -torch_item(elbo)
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return weight * loss
Example #24
            def loss_and_surrogate_loss(*args):
                self = weakself()
                loss = 0.0
                surrogate_loss = 0.0
                for weight, model_trace, guide_trace in self._get_traces(
                        model, guide, *args, **kwargs):
                    model_trace.compute_log_prob()
                    guide_trace.compute_score_parts()
                    if is_validation_enabled():
                        for site in model_trace.nodes.values():
                            if site["type"] == "sample":
                                check_site_shape(site,
                                                 self.max_iarange_nesting)
                        for site in guide_trace.nodes.values():
                            if site["type"] == "sample":
                                check_site_shape(site,
                                                 self.max_iarange_nesting)

                    # compute elbo for reparameterized nodes
                    non_reparam_nodes = set(
                        guide_trace.nonreparam_stochastic_nodes)
                    elbo, surrogate_elbo = _compute_elbo_reparam(
                        model_trace, guide_trace, non_reparam_nodes)

                    # the following computations are only necessary if we have non-reparameterizable nodes
                    baseline_loss = 0.0
                    if non_reparam_nodes:
                        downstream_costs, _ = _compute_downstream_costs(
                            model_trace, guide_trace, non_reparam_nodes)
                        surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
                            guide_trace, non_reparam_nodes, downstream_costs)
                        surrogate_elbo += surrogate_elbo_term

                    loss = loss - weight * elbo
                    surrogate_loss = surrogate_loss - weight * surrogate_elbo

                return loss, surrogate_loss
Example #25
    def _get_trace(self, model, guide, *args, **kwargs):
        model_trace, guide_trace = super(TraceMeanField_ELBO, self)._get_trace(
            model, guide, *args, **kwargs)
        if is_validation_enabled():
            _check_mean_field_requirement(model_trace, guide_trace)
        return model_trace, guide_trace
Example #26
    def _get_trace(self, model, guide, args, kwargs):
        model_trace, guide_trace = super()._get_trace(model, guide, args, kwargs)
        if is_validation_enabled():
            _check_mean_field_requirement(model_trace, guide_trace)
        return model_trace, guide_trace