Esempio n. 1
0
def _compute_dice_elbo(model_trace, guide_trace):
    """
    Combine model and guide costs and return their expectation under ``Dice``.

    Model-side costs and log-factors come from ``_compute_model_factors``.
    Any returned log-factors are contracted with ``contract_tensor_tree``.
    Each contracted term is passed through ``packed.scale_and_mask`` and
    appended to the cost list under the same key. Every guide ``"sample"``
    site then contributes its negated packed ``log_prob``, keyed by
    ``ordering[name]``. Costs are keyed by ``ordering`` values (presumably
    plate ordinals; confirm against ``_compute_model_factors``).

    :param model_trace: trace of the model.
    :param guide_trace: trace of the guide; its ``nodes`` mapping is scanned
        for sample sites.
    :returns: the result of ``Dice(guide_trace, ordering).compute_expectation``
        on the merged costs.
    """
    costs, log_factors, ordering, sum_dims, scale = _compute_model_factors(
        model_trace, guide_trace)

    if log_factors:
        # Most tensor-message-passing call sites use contract_to_tensor() and
        # could easily move to ubersum(). This one needs contract_tensor_tree()
        # because it exposes the dependency structure among log_prob terms,
        # which Dice uses to eliminate zero-expectation terms. A possible
        # refactor: replace contract_to_tensor() with a RaggedTensor -> Tensor
        # contraction, and replace contract_tensor_tree() with a
        # RaggedTensor -> RaggedTensor contraction that preserves some of that
        # dependency structure.
        with shared_intermediates() as cache:
            contracted = contract_tensor_tree(log_factors, sum_dims, cache=cache)
        for key, terms in contracted.items():
            bucket = costs.setdefault(key, [])
            bucket.extend(packed.scale_and_mask(term, scale=scale) for term in terms)

    # Each guide sample site contributes the negation of its packed log-prob.
    for name, site in guide_trace.nodes.items():
        if site["type"] != "sample":
            continue
        neg_log_q = packed.neg(site["packed"]["log_prob"])
        costs.setdefault(ordering[name], []).append(neg_log_q)

    return Dice(guide_trace, ordering).compute_expectation(costs)
Esempio n. 2
0
def _compute_tmc_factors(model_trace, guide_trace):
    """
    compute per-site log-factors for all observed and unobserved variables
    log-factors are log(p / q) for unobserved sites and log(p) for observed sites
    """
    log_factors = []
    for name, site in guide_trace.nodes.items():
        if site["type"] != "sample" or site["is_observed"]:
            continue
        log_proposal = site["packed"]["log_prob"]
        log_factors.append(packed.neg(log_proposal))
    for name, site in model_trace.nodes.items():
        if site["type"] != "sample":
            continue
        if (site["name"] not in guide_trace and not site["is_observed"]
                and site["infer"].get("enumerate", None) == "parallel"
                and site["infer"].get("num_samples", -1) > 0):
            # site was sampled from the prior
            log_proposal = packed.neg(site["packed"]["log_prob"])
            log_factors.append(log_proposal)
        log_factors.append(site["packed"]["log_prob"])
    return log_factors