def _compute_dice_elbo(model_trace, guide_trace):
    """
    Assemble per-ordinal cost terms from the model and guide traces and
    hand them to ``Dice`` to compute the expectation.

    Model-side terms come from ``_compute_model_factors``. Any remaining
    log-factors are contracted over ``sum_dims``, passed through
    ``packed.scale_and_mask`` with ``scale``, and merged into the costs.
    Every guide sample site then contributes its negated packed log_prob,
    filed under ``ordering[name]``.

    :param model_trace: trace of the model, forwarded to ``_compute_model_factors``
    :param guide_trace: trace of the guide; its sample sites supply the negative costs
    :return: the result of ``Dice(guide_trace, ordering).compute_expectation(costs)``
    """
    # Model side: marginal costs keyed by ordinal, plus factors still to contract.
    costs, log_factors, ordering, sum_dims, scale = _compute_model_factors(
        model_trace, guide_trace)

    if log_factors:
        # NOTE: most users of tensor message passing only need
        # contract_to_tensor() and could be refactored onto ubersum(). This
        # call site is different: it needs contract_tensor_tree() because the
        # dependency structure of the individual log_prob terms must survive
        # the contraction, so that Dice can eliminate zero-expectation terms.
        # A possible refactoring is to make contract_to_tensor() a
        # RaggedTensor -> Tensor contraction and contract_tensor_tree() a
        # RaggedTensor -> RaggedTensor contraction that preserves some of
        # that dependency structure.
        with shared_intermediates() as cache:
            contracted = contract_tensor_tree(log_factors, sum_dims, cache=cache)
        for ordinal, terms in contracted.items():
            costs.setdefault(ordinal, []).extend(
                packed.scale_and_mask(term, scale=scale) for term in terms)

    # Guide side: subtract log q for every sample site.
    for name, site in guide_trace.nodes.items():
        if site["type"] != "sample":
            continue
        neg_log_q = packed.neg(site["packed"]["log_prob"])
        costs.setdefault(ordering[name], []).append(neg_log_q)

    return Dice(guide_trace, ordering).compute_expectation(costs)
def _compute_tmc_factors(model_trace, guide_trace): """ compute per-site log-factors for all observed and unobserved variables log-factors are log(p / q) for unobserved sites and log(p) for observed sites """ log_factors = [] for name, site in guide_trace.nodes.items(): if site["type"] != "sample" or site["is_observed"]: continue log_proposal = site["packed"]["log_prob"] log_factors.append(packed.neg(log_proposal)) for name, site in model_trace.nodes.items(): if site["type"] != "sample": continue if (site["name"] not in guide_trace and not site["is_observed"] and site["infer"].get("enumerate", None) == "parallel" and site["infer"].get("num_samples", -1) > 0): # site was sampled from the prior log_proposal = packed.neg(site["packed"]["log_prob"]) log_factors.append(log_proposal) log_factors.append(site["packed"]["log_prob"]) return log_factors