Exemple #1
    def __init__(self, guide_trace, ordering):
        """
        Gather the DiCE terms of every sample site in ``guide_trace``.

        :param guide_trace: trace whose ``nodes`` mapping is scanned; nodes
            whose ``"type"`` is not ``"sample"`` are skipped.
        :param ordering: mapping from site name to the ordinal under which
            that site's terms are accumulated.

        After construction ``self.log_denom`` maps ordinal -> summed log of
        the sequential enumeration totals, and ``self.log_probs`` maps
        ordinal -> list of per-site log-probability terms.
        """
        log_denom = defaultdict(float)  # avoids double-counting when sequentially enumerating
        log_probs = defaultdict(list)  # accounts for upstream probabilities

        for name, site in guide_trace.nodes.items():
            if site["type"] != "sample":
                continue

            log_prob = site["packed"]["score_parts"].score_function  # not scaled by subsampling
            dims = getattr(log_prob, "_pyro_dims", "")
            ordinal = ordering[name]
            if site["infer"].get("enumerate"):
                num_samples = site["infer"].get("num_samples")
                if num_samples is not None:  # site was multiply sampled
                    if not is_identically_zero(log_prob):
                        log_prob = log_prob - log_prob.detach()
                    log_prob = log_prob - math.log(num_samples)
                    if not isinstance(log_prob, torch.Tensor):
                        log_prob = site["value"].new_tensor(log_prob)
                    log_prob._pyro_dims = dims
                    # I don't know why the following broadcast is needed, but it makes tests pass:
                    log_prob, _ = packed.broadcast_all(log_prob, site["packed"]["log_prob"])
                elif site["infer"]["enumerate"] == "sequential":
                    log_denom[ordinal] += math.log(site["infer"]["_enum_total"])
            else:  # site was monte carlo sampled
                if is_identically_zero(log_prob):
                    # a zero term has no ``detach`` and contributes nothing
                    continue
                log_prob = log_prob - log_prob.detach()
                log_prob._pyro_dims = dims
            # record the term; without this ``log_probs`` would stay empty
            log_probs[ordinal].append(log_prob)

        self.log_denom = log_denom
        self.log_probs = log_probs
    def _differentiable_loss_particle(self, model_trace, guide_trace):
        """
        Compute the loss and the surrogate loss of a single particle.

        :param model_trace: model trace; every sample site contributes
            ``site["log_prob_sum"]``.
        :param guide_trace: guide trace; every sample site contributes its
            ``log_prob_sum`` and the entries of ``site["score_parts"]``.
        :returns: tuple ``(-elbo_particle, -surrogate_elbo_particle)``; the
            first accumulates ``torch_item`` values, the second keeps the
            original terms.
        """
        elbo_particle = 0
        surrogate_elbo_particle = 0
        log_r = None

        # compute elbo and surrogate elbo
        for site in model_trace.nodes.values():
            if site["type"] == "sample":
                elbo_particle = elbo_particle + torch_item(site["log_prob_sum"])
                surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"]

        for site in guide_trace.nodes.values():
            if site["type"] == "sample":
                log_prob, score_function_term, entropy_term = site["score_parts"]

                elbo_particle = elbo_particle - torch_item(site["log_prob_sum"])

                if not is_identically_zero(entropy_term):
                    surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum()

                if not is_identically_zero(score_function_term):
                    if log_r is None:
                        # computed lazily: only needed when some site has a score-function term
                        log_r = _compute_log_r(model_trace, guide_trace)
                    # kept in its own name so the loop variable ``site`` is not clobbered
                    site_log_r = log_r.sum_to(site["cond_indep_stack"])
                    surrogate_elbo_particle = surrogate_elbo_particle + (
                        site_log_r * score_function_term).sum()

        return -elbo_particle, -surrogate_elbo_particle
Exemple #3
def compute_site_dice_factor(site):
    """
    Compute the DiCE log-factor and the enumeration log-denominator of a site.

    :param site: trace node; reads ``site["packed"]["score_parts"].score_function``
        and ``site["infer"]``, plus ``site["packed"]["log_prob"]`` when the
        site was multiply sampled.
    :returns: tuple ``(log_prob, log_denom)``. ``log_denom`` stays ``0``
        unless ``site["infer"]["enumerate"] == "sequential"``.
    """
    log_denom = 0
    log_prob = site["packed"]["score_parts"].score_function  # not scaled by subsampling
    dims = getattr(log_prob, "_pyro_dims", "")
    infer = site["infer"]
    if infer.get("enumerate"):
        num_samples = infer.get("num_samples")
        if num_samples is not None:  # site was multiply sampled
            packed_log_prob = site["packed"]["log_prob"]
            if not is_identically_zero(log_prob):
                log_prob = log_prob - log_prob.detach()
            log_prob = log_prob - math.log(num_samples)
            if not isinstance(log_prob, torch.Tensor):
                # NOTE(review): placing the scalar on the packed log_prob's
                # device is a reconstruction of a truncated call -- confirm.
                log_prob = torch.tensor(float(log_prob),
                                        device=packed_log_prob.device)
            log_prob._pyro_dims = dims
            # I don't know why the following broadcast is needed, but it makes tests pass:
            log_prob, _ = packed.broadcast_all(log_prob, packed_log_prob)
        elif infer["enumerate"] == "sequential":
            log_denom = math.log(infer.get("_enum_total", num_samples))
    else:  # site was monte carlo sampled
        if not is_identically_zero(log_prob):
            log_prob = log_prob - log_prob.detach()
            log_prob._pyro_dims = dims

    return log_prob, log_denom
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator.
        Performs backward on the latter. Num_particle many samples are used to form the estimators.
        """
        elbo = 0.0
        # grab a trace from the generator
        for model_trace, guide_trace in self._get_traces(
                model, guide, *args, **kwargs):
            elbo_particle = 0
            surrogate_elbo_particle = 0
            log_r = None

            # compute elbo and surrogate elbo
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    elbo_particle = elbo_particle + torch_item(site["log_prob_sum"])
                    surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"]

            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    log_prob, score_function_term, entropy_term = site["score_parts"]

                    elbo_particle = elbo_particle - torch_item(site["log_prob_sum"])

                    if not is_identically_zero(entropy_term):
                        surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum()

                    if not is_identically_zero(score_function_term):
                        if log_r is None:
                            log_r = _compute_log_r(model_trace, guide_trace)
                        # separate name so the loop variable ``site`` is not clobbered
                        site_log_r = log_r.sum_to(site["cond_indep_stack"])
                        surrogate_elbo_particle = surrogate_elbo_particle + (
                            site_log_r * score_function_term).sum()

            elbo += elbo_particle / self.num_particles

            # collect parameters to train from model and guide
            trainable_params = any(site["type"] == "param"
                                   for trace in (model_trace, guide_trace)
                                   for site in trace.nodes.values())

            if trainable_params and getattr(surrogate_elbo_particle,
                                            'requires_grad', False):
                surrogate_loss_particle = -surrogate_elbo_particle / self.num_particles
                # The docstring promises a backward pass on the surrogate.
                # NOTE(review): retain_graph=True is a conservative choice so
                # graph pieces shared between particles survive -- confirm.
                surrogate_loss_particle.backward(retain_graph=True)

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Exemple #5
def check_fully_reparametrized(guide_site):
    """
    Raise unless ``guide_site`` is a fully reparameterized site.

    A site passes when its distribution reports ``has_rsample``, its entropy
    term is not identically zero and its score-function term is identically
    zero.

    :param guide_site: mapping with a ``"score_parts"`` triple and an ``"fn"``
        distribution object.
    :raises NotImplementedError: if any of the three conditions fails.
    """
    _, score_term, entropy = guide_site["score_parts"]
    # the checks run in this order and short-circuit on the first failure
    if (guide_site["fn"].has_rsample
            and not is_identically_zero(entropy)
            and is_identically_zero(score_term)):
        return
    raise NotImplementedError(
        "All distributions in the guide must be fully reparameterized.")
Exemple #6
def _compute_elbo_reparam(model_trace, guide_trace):
    """
    Accumulate the ELBO and a surrogate ELBO over the sample sites of both traces.

    In ref [1], section 3.2, the part of the surrogate loss computed here is
    \\sum{cost}, which in this case is the ELBO. Rather than the ELBO itself,
    a surrogate is built that swaps some entropy terms depending on the
    parameterization, which reduces gradient variance under some conditions.

    :returns: tuple ``(elbo, surrogate_elbo)``.
    """
    total = 0.0
    surrogate_total = 0.0

    # log p(x, z|...) enters both the ELBO and the surrogate unchanged
    for node in model_trace.nodes.values():
        if node["type"] != "sample":
            continue
        total += node["log_prob_sum"]
        surrogate_total += node["log_prob_sum"]

    # log q(z|...) enters the ELBO; the surrogate only receives the entropy
    # term from ``score_parts``. For fully reparameterized sites that term is
    # log q(z|...); for fully non-reparameterized sites it is zero.
    for node in guide_trace.nodes.values():
        if node["type"] != "sample":
            continue
        total -= node["log_prob_sum"]
        entropy = node["score_parts"].entropy_term
        if is_identically_zero(entropy):
            continue
        surrogate_total -= entropy.sum()

    return total, surrogate_total
Exemple #7
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """
        :returns: a differentiable estimate of the ELBO
        :rtype: torch.Tensor
        :raises ValueError: if the ELBO is not differentiable (e.g. is
            identically zero)

        Estimates a differentiable ELBO using ``num_particles`` many samples
        (particles).  The result should be infinitely differentiable (as long
        as underlying derivatives have been implemented).
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(
                model, guide, args, kwargs):
            elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
            if is_identically_zero(elbo_particle):
                # a zero particle adds nothing to the sum
                continue

            elbo = elbo + elbo_particle
        elbo = elbo / self.num_particles

        # also covers the case where every particle was skipped (elbo is 0.0)
        if not torch.is_tensor(elbo) or not elbo.requires_grad:
            raise ValueError(
                'ELBO is cannot be differentiated: {}'.format(elbo))

        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss
Exemple #8
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Estimates the ELBO using ``num_particles`` many samples (particles).
        Performs backward on the ELBO of each particle.
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
            if is_identically_zero(elbo_particle):
                # nothing to add and nothing to differentiate
                continue

            elbo += elbo_particle.item() / self.num_particles

            # collect parameters to train from model and guide
            trainable_params = any(site["type"] == "param"
                                   for trace in (model_trace, guide_trace)
                                   for site in trace.nodes.values())

            if trainable_params and elbo_particle.requires_grad:
                loss_particle = -elbo_particle
                (loss_particle / self.num_particles).backward(retain_graph=True)

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Exemple #9
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: an estimate of the ELBO
        :rtype: float

        Estimates the ELBO using ``num_particles`` many samples (particles).
        Performs backward on the ELBO of each particle.
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(
                model, guide, args, kwargs):
            elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
            if is_identically_zero(elbo_particle):
                # nothing to add and nothing to differentiate
                continue

            elbo += elbo_particle.item() / self.num_particles

            # collect parameters to train from model and guide
            trainable_params = any(site["type"] == "param"
                                   for trace in (model_trace, guide_trace)
                                   for site in trace.nodes.values())

            if trainable_params and elbo_particle.requires_grad:
                loss_particle = -elbo_particle
                # each particle is back-propagated with its 1/num_particles weight
                (loss_particle / self.num_particles).backward(retain_graph=True)

        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss
Exemple #10
            def loss_and_surrogate_loss(*args, **kwargs):
                """
                Compute the loss and surrogate loss averaged over all particles.

                ``weakself``, ``model`` and ``guide`` are taken from the
                enclosing scope; ``args`` and ``kwargs`` are forwarded to
                ``self._get_traces``.

                :returns: tuple ``(loss, surrogate_loss)``.
                """
                self = weakself()
                loss = 0.0
                surrogate_loss = 0.0
                for model_trace, guide_trace in self._get_traces(
                        model, guide, args, kwargs):
                    elbo_particle = 0
                    surrogate_elbo_particle = 0
                    log_r = None

                    # compute elbo and surrogate elbo
                    for site in model_trace.nodes.values():
                        if site["type"] == "sample":
                            elbo_particle = elbo_particle + site["log_prob_sum"]
                            surrogate_elbo_particle = (
                                surrogate_elbo_particle + site["log_prob_sum"])

                    for site in guide_trace.nodes.values():
                        if site["type"] == "sample":
                            log_prob, score_function_term, entropy_term = site[
                                "score_parts"]

                            elbo_particle = elbo_particle - site["log_prob_sum"]

                            if not is_identically_zero(entropy_term):
                                surrogate_elbo_particle = (
                                    surrogate_elbo_particle -
                                    entropy_term.sum())

                            if not is_identically_zero(score_function_term):
                                if log_r is None:
                                    log_r = _compute_log_r(
                                        model_trace, guide_trace)
                                # separate name so ``site`` is not clobbered
                                site_log_r = log_r.sum_to(site["cond_indep_stack"])
                                surrogate_elbo_particle = (
                                    surrogate_elbo_particle +
                                    (site_log_r * score_function_term).sum())

                    loss = loss - elbo_particle / self.num_particles
                    surrogate_loss = (
                        surrogate_loss -
                        surrogate_elbo_particle / self.num_particles)

                return loss, surrogate_loss
Exemple #11
    def __init__(self, guide_trace, ordering):
        """
        Gather the DiCE terms of every sample site in ``guide_trace``.

        :param guide_trace: trace whose ``nodes`` mapping is scanned; nodes
            whose ``"type"`` is not ``"sample"`` are skipped.
        :param ordering: mapping from site name to the ordinal under which
            that site's terms are accumulated.
        """
        log_denoms = defaultdict(float)  # avoids double-counting when sequentially enumerating
        log_probs = defaultdict(list)  # accounts for upstream probabilities

        for name, site in guide_trace.nodes.items():
            if site["type"] != "sample":
                continue

            ordinal = ordering[name]
            log_prob, log_denom = compute_site_dice_factor(site)
            # only non-trivial terms are recorded
            if not is_identically_zero(log_prob):
                log_probs[ordinal].append(log_prob)
            if not is_identically_zero(log_denom):
                log_denoms[ordinal] += log_denom

        self.log_denom = log_denoms
        self.log_probs = log_probs
Exemple #12
def _compute_dice_factors(model_trace, guide_trace):
    """
    compute per-site DiCE log-factors for non-reparameterized proposal sites
    this logic is adapted from pyro.infer.util.Dice.__init__

    :returns: list of the non-zero per-site log-factors, model sites first.
    """
    log_probs = []
    for role, trace in zip(("model", "guide"), (model_trace, guide_trace)):
        for name, site in trace.nodes.items():
            if site["type"] != "sample" or site["is_observed"]:
                continue
            # a model site that also appears in the guide is handled in the guide pass
            if role == "model" and name in guide_trace:
                continue

            log_prob, log_denom = compute_site_dice_factor(site)
            if not is_identically_zero(log_denom):
                # the subtraction creates a new tensor, so carry the dims tag over
                dims = log_prob._pyro_dims
                log_prob = log_prob - log_denom
                log_prob._pyro_dims = dims
            if not is_identically_zero(log_prob):
                log_probs.append(log_prob)

    return log_probs
Exemple #13
    def _get_log_factors(self, target_ordinal):
        """
        Returns a list of DiCE factors at a given ordinal.

        :param target_ordinal: ordinal compared against the keys of
            ``self.log_denom`` and ``self.log_probs`` with ``<=``.
        :returns: list holding ``-log_denom`` (when it is not identically
            zero) followed by the terms of every upstream ordinal.
        """
        log_denom = 0
        for ordinal, term in self.log_denom.items():
            if not ordinal <= target_ordinal:  # not downstream
                log_denom += term  # term = log(# times this ordinal is counted)

        log_factors = [] if is_identically_zero(log_denom) else [-log_denom]
        for ordinal, terms in self.log_probs.items():
            if ordinal <= target_ordinal:  # upstream
                log_factors.extend(terms)  # terms = [log(dice weight of this ordinal)]

        return log_factors
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: an estimate of the ELBO
        :rtype: float

        Estimates the ELBO using ``num_particles`` many samples (particles).
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
            if is_identically_zero(elbo_particle):
                # a zero particle has no ``.item()`` to add
                continue

            elbo += elbo_particle.item() / self.num_particles

        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss
Exemple #15
def _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes):
    """
    Sum the ELBO and a surrogate ELBO over the sample sites of both traces.

    ``non_reparam_nodes`` is accepted but not read by this function.

    :returns: tuple ``(elbo, surrogate_elbo)``.
    """
    elbo_total = 0.0
    surrogate_total = 0.0

    # log p(z|...) terms go into both accumulators
    model_samples = (s for s in model_trace.nodes.values() if s["type"] == "sample")
    for s in model_samples:
        elbo_total += s["log_prob_sum"]
        surrogate_total += s["log_prob_sum"]

    # log q(z|...) terms: full log-prob for the ELBO, entropy term for the surrogate
    guide_samples = (s for s in guide_trace.nodes.values() if s["type"] == "sample")
    for s in guide_samples:
        elbo_total -= s["log_prob_sum"]
        entropy = s["score_parts"].entropy_term
        if is_identically_zero(entropy):
            continue
        surrogate_total -= entropy.sum()

    return elbo_total, surrogate_total
Exemple #16
def _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes):
    """
    Accumulate ``(elbo, surrogate_elbo)`` from the sample sites of the two traces.

    Model sites add their ``log_prob_sum`` to both values. Guide sites subtract
    their ``log_prob_sum`` from the ELBO and, when it is not identically zero,
    the summed entropy term from the surrogate. ``non_reparam_nodes`` is unused.
    """
    running_elbo = 0.0
    running_surrogate = 0.0

    # deal with log p(z|...) terms
    for node in model_trace.nodes.values():
        if node["type"] != "sample":
            continue
        running_elbo += node["log_prob_sum"]
        running_surrogate += node["log_prob_sum"]

    # deal with log q(z|...) terms
    for node in guide_trace.nodes.values():
        if node["type"] != "sample":
            continue
        running_elbo -= node["log_prob_sum"]
        term = node["score_parts"].entropy_term
        if not is_identically_zero(term):
            running_surrogate -= term.sum()

    return running_elbo, running_surrogate
Exemple #17
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Estimates the ELBO using ``num_particles`` many samples (particles).
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
            if is_identically_zero(elbo_particle):
                # a zero particle has no ``.item()`` to add
                continue

            elbo += elbo_particle.item() / self.num_particles

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Exemple #18
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Estimates the ELBO using ``num_particles`` many samples (particles).
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(
                model, guide, *args, **kwargs):
            elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
            if is_identically_zero(elbo_particle):
                # skip: a zero particle contributes nothing to the average
                continue

            elbo += elbo_particle.item() / self.num_particles

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
Exemple #19
    def _get_log_factors(self, target_ordinal):
        """
        Returns a list of DiCE factors at a given ordinal.

        Results are memoized in ``self._log_factors_cache`` keyed by
        ``target_ordinal``; the cached list itself is returned on a hit.
        """
        # memoize
        try:
            return self._log_factors_cache[target_ordinal]
        except KeyError:
            pass  # first request for this ordinal: compute below

        log_denom = 0
        for ordinal, term in self.log_denom.items():
            if not ordinal <= target_ordinal:  # not downstream
                log_denom += term  # term = log(# times this ordinal is counted)

        log_factors = [] if is_identically_zero(log_denom) else [-log_denom]
        for ordinal, term in self.log_probs.items():
            if ordinal <= target_ordinal:  # upstream
                log_factors += term  # term = [log(dice weight of this ordinal)]

        self._log_factors_cache[target_ordinal] = log_factors
        return log_factors
Exemple #20
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """
        :returns: a differentiable estimate of the marginal log-likelihood
        :rtype: torch.Tensor
        :raises ValueError: if the ELBO is not differentiable (e.g. is
            identically zero)

        Computes a differentiable TMC estimate using ``num_particles`` many samples
        (particles).  The result should be infinitely differentiable (as long
        as underlying derivatives have been implemented).

        NOTE(review): no ``ValueError`` is raised in the code visible here --
        confirm whether the ``:raises`` entry still applies.
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            elbo_particle = _compute_tmc_estimate(model_trace, guide_trace)
            if is_identically_zero(elbo_particle):
                # a zero particle adds nothing to the sum
                continue

            elbo = elbo + elbo_particle
        elbo = elbo / self.num_particles

        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss
Exemple #21
    def __init__(self, guide_trace, ordering):
        """
        Gather the DiCE terms of every sample site in ``guide_trace``.

        :param guide_trace: trace whose ``nodes`` mapping is scanned; non-sample
            nodes and sites whose score function is identically zero are skipped.
        :param ordering: mapping from site name to the ordinal under which
            that site's terms are accumulated.
        """
        log_denom = defaultdict(float)  # avoids double-counting when sequentially enumerating
        log_probs = defaultdict(list)  # accounts for upstream probabilities

        for name, site in guide_trace.nodes.items():
            if site["type"] != "sample":
                continue
            log_prob = site['score_parts'].score_function  # not scaled by subsampling
            if is_identically_zero(log_prob):
                continue

            ordinal = ordering[name]
            if site["infer"].get("enumerate"):
                if site["infer"]["enumerate"] == "sequential":
                    log_denom[ordinal] += math.log(site["infer"]["_enum_total"])
            else:  # site was monte carlo sampled
                log_prob = log_prob - log_prob.detach()
            # record the term; without this ``log_probs`` would stay empty
            log_probs[ordinal].append(log_prob)

        self.log_denom = log_denom
        self.log_probs = log_probs
        self._log_factors_cache = {}
        self._prob_cache = {}
Exemple #22
    def get_deltas(self, save_params=None):
        """
        Build one ``dist.Delta`` per site in ``self._sorted_sites``.

        NOTE(review): several statements in this listing are cut off (an empty
        ``if`` body, unclosed calls, a missing ``else:``); each spot is marked
        below. The code is left untouched because the missing tails cannot be
        recovered from what is visible here.

        :param save_params: optional container of site names, tested with
            ``name not in save_params``.
        :returns: dict mapping site name to a ``dist.Delta``.
        """
        deltas = {}
        aux_values = {}
        compute_density = poutine.get_mask() is not False
        for name, site in self._sorted_sites:
            # NOTE(review): the body of this ``if`` is missing -- presumably
            # ``continue``; confirm against the upstream source.
            if save_params is not None and name not in save_params:

            # Sample zero-mean blockwise independent Delta/Normal/MVN.
            log_density = 0.0
            loc = deep_getattr(self.locs, name)
            zero = torch.zeros_like(loc)
            conditional = self.conditionals[name]
            if callable(conditional):
                aux_value = deep_getattr(self.conds, name)()
            elif conditional == "delta":
                aux_value = zero
            elif conditional == "normal":
                # NOTE(review): the ``pyro.sample(`` call below is never closed
                # in this listing.
                aux_value = pyro.sample(
                    name + "_aux",
                    dist.Normal(zero, 1).to_event(1),
                    infer={"is_auxiliary": True},
                scale = deep_getattr(self.scales, name)
                aux_value = aux_value * scale
                if compute_density:
                    log_density = (-scale.log()).expand_as(aux_value)
            elif conditional == "mvn":
                # This overparametrizes by learning (scale,scale_tril),
                # enabling faster learning of the more-global scale parameter.
                # NOTE(review): this ``pyro.sample(`` call is also never closed.
                aux_value = pyro.sample(
                    name + "_aux",
                    dist.Normal(zero, 1).to_event(1),
                    infer={"is_auxiliary": True},
                scale = deep_getattr(self.scales, name)
                scale_tril = deep_getattr(self.scale_trils, name)
                aux_value = aux_value @ scale_tril.T * scale
                if compute_density:
                    # NOTE(review): the expression below stops after the
                    # trailing ``-``, and the ``else:`` that should guard the
                    # ``raise`` is missing.
                    log_density = (
                        -scale_tril.diagonal(dim1=-2, dim2=-1).log() -
                raise ValueError(
                    f"Unsupported conditional type: {conditional}")

            # Accumulate upstream dependencies.
            # Note: by accumulating upstream dependencies before updating the
            # aux_values dict, we encode a block-sparse structure of the
            # precision matrix; if we had instead accumulated after updating
            # aux_values, we would encode a block-sparse structure of the
            # covariance matrix.
            # Note: these shear transforms have no effect on the Jacobian
            # determinant, and can therefore be excluded from the log_density
            # computation below, even for nonlinear dep().
            deps = deep_getattr(self.deps, name)
            for upstream in self.dependencies.get(name, {}):
                dep = deep_getattr(deps, upstream)
                aux_value = aux_value + dep(aux_values[upstream])
            aux_values[name] = aux_value

            # Shift by loc and reshape.
            # NOTE(review): both statements below are cut off before their
            # final argument; the intended shapes are not visible here.
            batch_shape = torch.broadcast_shapes(aux_value.shape[:-1],
            unconstrained = (
                aux_value +
                loc).reshape(batch_shape +
            if not is_identically_zero(log_density):
                log_density = log_density.reshape(batch_shape + (-1, )).sum(-1)

            # Transform to constrained space.
            transform = biject_to(site["fn"].support)
            value = transform(unconstrained)
            if compute_density and conditional != "delta":
                assert transform.codomain.event_dim == site["fn"].event_dim
                log_density = log_density + transform.inv.log_abs_det_jacobian(
                    value, unconstrained)

            # Create a reparametrized Delta distribution.
            deltas[name] = dist.Delta(value, log_density, site["fn"].event_dim)

        return deltas
Exemple #23
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator.
        Performs backward on the latter. Num_particle many samples are used to form the estimators.
        """
        elbo_particles = []
        surrogate_elbo_particles = []
        is_vectorized = self.vectorize_particles and self.num_particles > 1
        tensor_holder = None

        # grab a vectorized trace from the generator
        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            elbo_particle = 0
            surrogate_elbo_particle = 0

            # compute elbo and surrogate elbo
            for name, site in model_trace.nodes.items():
                if site["type"] == "sample":
                    if is_vectorized:
                        # one row per particle, everything else summed out
                        log_prob_sum = site["log_prob"].reshape(self.num_particles, -1).sum(-1)
                    else:
                        log_prob_sum = site["log_prob_sum"]
                    elbo_particle = elbo_particle + log_prob_sum.detach()
                    surrogate_elbo_particle = surrogate_elbo_particle + log_prob_sum

            for name, site in guide_trace.nodes.items():
                if site["type"] == "sample":
                    log_prob, score_function_term, entropy_term = site["score_parts"]
                    if is_vectorized:
                        log_prob_sum = log_prob.reshape(self.num_particles, -1).sum(-1)
                    else:
                        log_prob_sum = site["log_prob_sum"]

                    elbo_particle = elbo_particle - log_prob_sum.detach()

                    if not is_identically_zero(entropy_term):
                        surrogate_elbo_particle = surrogate_elbo_particle - log_prob_sum

                        if not is_identically_zero(score_function_term):
                            # link to the issue: https://github.com/pyro-ppl/pyro/issues/1222
                            raise NotImplementedError

                    if not is_identically_zero(score_function_term):
                        surrogate_elbo_particle = (surrogate_elbo_particle +
                                                   (self.alpha / (1. - self.alpha)) * log_prob_sum)

            if is_identically_zero(elbo_particle):
                if tensor_holder is not None:
                    elbo_particle = torch.zeros_like(tensor_holder)
                    surrogate_elbo_particle = torch.zeros_like(tensor_holder)
            else:  # elbo_particle is not None
                if tensor_holder is None:
                    tensor_holder = torch.zeros_like(elbo_particle)
                    # change types of previous `elbo_particle`s
                    elbo_particles = [torch.zeros_like(tensor_holder)
                                      for _ in elbo_particles]
                    surrogate_elbo_particles = [torch.zeros_like(tensor_holder)
                                                for _ in surrogate_elbo_particles]

            # the lists are indexed and stacked below, so every particle must be recorded
            elbo_particles.append(elbo_particle)
            surrogate_elbo_particles.append(surrogate_elbo_particle)

        if tensor_holder is None:
            return 0.

        if is_vectorized:
            elbo_particles = elbo_particles[0]
            surrogate_elbo_particles = surrogate_elbo_particles[0]
        else:
            elbo_particles = torch.stack(elbo_particles)
            surrogate_elbo_particles = torch.stack(surrogate_elbo_particles)

        log_weights = (1. - self.alpha) * elbo_particles
        log_mean_weight = torch.logsumexp(log_weights, dim=0) - math.log(self.num_particles)
        elbo = log_mean_weight.sum().item() / (1. - self.alpha)

        # collect parameters to train from model and guide
        trainable_params = any(site["type"] == "param"
                               for trace in (model_trace, guide_trace)
                               for site in trace.nodes.values())

        if trainable_params and getattr(surrogate_elbo_particles, 'requires_grad', False):
            normalized_weights = (log_weights - log_mean_weight).exp()
            surrogate_elbo = (normalized_weights * surrogate_elbo_particles).sum() / self.num_particles
            surrogate_loss = -surrogate_elbo
            # the docstring promises a backward pass on the surrogate
            surrogate_loss.backward()

        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss
Exemple #24
    def loss_and_grads(self, model, guide, *args, **kwargs):
        # TODO: add argument lambda --> assigns weights to losses
        # TODO: Normalize loss elbo value if not done
        # NOTE(review): the five lines below read like docstring text whose
        # triple quotes are missing from this listing; as written they are not
        # valid Python.
        :returns: returns an estimate of the ELBO
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator.
        Performs backward on the latter. Num_particle many samples are used to form the estimators.

        elbo = 0.0
        dyn_loss = 0.0
        dim_loss = 0.0

        # grab a trace from the generator
        for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            elbo_particle = 0
            surrogate_elbo_particle = 0
            log_r = None

            # collected inside the guide loop and stacked into ``man`` below
            ys = []
            # compute elbo and surrogate elbo
            for name, site in model_trace.nodes.items():
                if site["type"] == "sample":
                    elbo_particle = elbo_particle + torch_item(site["log_prob_sum"])
                    surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"]

            for name, site in guide_trace.nodes.items():
                if site["type"] == "sample":
                    log_prob, score_function_term, entropy_term = site["score_parts"]

                    elbo_particle = elbo_particle - torch_item(site["log_prob_sum"])

                    if not is_identically_zero(entropy_term):
                        surrogate_elbo_particle = surrogate_elbo_particle - entropy_term.sum()

                    if not is_identically_zero(score_function_term):
                        if log_r is None:
                            log_r = _compute_log_r(model_trace, guide_trace)
                        # NOTE(review): this rebinds the loop variable ``site``,
                        # which is indexed with ``"name"`` a few lines below.
                        site = log_r.sum_to(site["cond_indep_stack"])
                        surrogate_elbo_particle = surrogate_elbo_particle + (site * score_function_term).sum()

                    # NOTE(review): the body of this ``if`` is missing from the
                    # listing; presumably it appends something to ``ys`` --
                    # what exactly cannot be told from here.
                    if site["name"].startswith("y_"):
                        # TODO: check order of y
            # centre the stacked ``ys`` along dim 1 before the two extra losses
            man = torch.stack(ys, dim=1)
            mean_man = man.mean(dim=1, keepdims=True)
            man = man - mean_man
            dyn_loss += self._get_logdet_loss(man, delta=self.delta)  # TODO: Normalize
            dim_loss += self._get_traceK_loss(man)
            elbo += elbo_particle / self.num_particles

            # collect parameters to train from model and guide
            trainable_params = any(site["type"] == "param"
                                   for trace in (model_trace, guide_trace)
                                   for site in trace.nodes.values())

            # NOTE(review): ``surrogate_loss_particle`` is computed but never
            # used in the visible code; a ``.backward()`` call may have been
            # lost -- confirm.
            if trainable_params and getattr(surrogate_elbo_particle, 'requires_grad', False):
                surrogate_loss_particle = -surrogate_elbo_particle / self.num_particles \
                                          +self.lam * dyn_loss \
                                          +self.gam * dim_loss

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        # ``man`` is the value from the last iteration of the trace loop
        return loss, dyn_loss.item(), dim_loss.item(), man
Exemple #25
 def loss_and_grads(self, model, guide, *args, **kwargs):
     """
     Compute the loss via ``self.differentiable_loss`` and back-propagate it.

     :returns: the loss as a plain number. When the loss is identically
         zero or does not require grad, no backward pass is run.
     """
     loss = self.differentiable_loss(model, guide, *args, **kwargs)
     if is_identically_zero(loss) or not loss.requires_grad:
         return torch_item(loss)
     # without this call the method would never produce the gradients it is named for
     loss.backward()
     return loss.item()