def _compute_dice_elbo(model_trace, guide_trace):
    """Assemble per-ordinal cost terms and hand them to ``Dice``.

    Model-side costs and log-factors are obtained from
    ``_compute_model_factors``. If any log-factors exist they are contracted
    with ``contract_tensor_tree``, rescaled via ``packed.scale_and_mask`` and
    merged into the cost lists. The negated ``log_prob`` of every guide
    sample site is then appended under that site's ordinal.

    :param model_trace: model trace, forwarded to ``_compute_model_factors``.
    :param guide_trace: guide trace; each sample site is expected to carry
        ``site["packed"]["log_prob"]``.
    :returns: the result of
        ``Dice(guide_trace, ordering).compute_expectation(costs)``.
    """
    # Accumulate marginal model costs.
    costs, log_factors, ordering, sum_dims, scale = _compute_model_factors(
        model_trace, guide_trace)

    if log_factors:
        # Note that while most applications of tensor message passing use the
        # contract_to_tensor() interface and can be easily refactored to use ubersum(),
        # the application here relies on contract_tensor_tree() to extract the dependency
        # structure of different log_prob terms, which is used by Dice to eliminate
        # zero-expectation terms. One possible refactoring would be to replace
        # contract_to_tensor() with a RaggedTensor -> Tensor contraction operation, but
        # replace contract_tensor_tree() with a RaggedTensor -> RaggedTensor contraction
        # that preserves some dependency structure.
        with shared_intermediates() as cache:
            log_factors = contract_tensor_tree(log_factors, sum_dims, cache=cache)
        for ordinal, factors in log_factors.items():
            # Contracted factors join any costs already recorded for this ordinal.
            costs.setdefault(ordinal, []).extend(
                packed.scale_and_mask(factor, scale=scale) for factor in factors)

    # Accumulate negative guide costs.
    for name, site in guide_trace.nodes.items():
        if site["type"] != "sample":
            continue
        guide_cost = packed.neg(site["packed"]["log_prob"])
        costs.setdefault(ordering[name], []).append(guide_cost)

    return Dice(guide_trace, ordering).compute_expectation(costs)
def _sample_posterior_from_trace(model, enum_trace, temperature, *args, **kwargs):
    """Replay ``model`` against a trace collapsed from an enumerated trace.

    The function gathers the masked (but unscaled) log-probabilities of the
    sample sites in ``enum_trace``, contracts them with
    ``contract_tensor_tree`` (forward pass), invokes each contracted term's
    ``_pyro_backward`` hook (backward pass), and builds a new trace whose
    latent sites have their values gathered along the sampled indices and
    their ``cond_indep_stack`` trimmed to the ordinal of the component they
    ended up in.

    Side effects visible here: ``node["packed"]["masked_log_prob"]`` is cached
    on each visited sample node, and ``_pyro_backward_result`` is set on every
    query tensor.

    :param model: callable invoked as ``model(*args, **kwargs)`` during replay.
    :param enum_trace: trace providing ``plate_to_symbol``, ``symbol_to_dim``,
        ``nodes`` and, optionally, a ``_sharing_cache`` attribute.
    :param temperature: forwarded to ``_make_ring``.
    :returns: whatever ``model(*args, **kwargs)`` returns when run under
        ``SamplePosteriorMessenger(trace=collapsed_trace)``.
    """
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    sum_dims = set()
    queries = []
    dim_to_size = {}
    cost_terms = OrderedDict()
    enum_terms = OrderedDict()
    for node in enum_trace.nodes.values():
        if node["type"] != "sample":
            continue
        ordinal = frozenset(
            plate_to_symbol[frame.name]
            for frame in node["cond_indep_stack"]
            if frame.vectorized and frame.size > 1)

        # For sites that depend on an enumerated variable, we need to apply
        # the mask but not the scale when sampling.
        site_tensors = node["packed"]
        if "masked_log_prob" not in site_tensors:
            site_tensors["masked_log_prob"] = packed.scale_and_mask(
                site_tensors["unscaled_log_prob"], mask=site_tensors["mask"])
        log_prob = site_tensors["masked_log_prob"]

        sum_dims |= frozenset(log_prob._pyro_dims) - ordinal
        if sum_dims.isdisjoint(log_prob._pyro_dims):
            # This term shares no dim with anything to be summed out so far.
            continue
        dim_to_size.update(zip(log_prob._pyro_dims, log_prob.shape))

        is_cost = node["infer"].get("_enumerate_dim") is None
        bucket = cost_terms if is_cost else enum_terms
        bucket.setdefault(ordinal, []).append(log_prob)

        # Note we mark all sample sites with require_backward to gather
        # enumerated sites and adjust cond_indep_stack of all sample sites.
        if not node["is_observed"]:
            queries.append(log_prob)
            require_backward(log_prob)

    # We take special care to match the term ordering in
    # pyro.infer.traceenum_elbo._compute_model_factors() to allow
    # contract_tensor_tree() to use shared_intermediates() inside
    # TraceEnumSample_ELBO. The special ordering is: first all cost terms in
    # order of model_trace, then all enum_terms in order of model trace.
    log_probs = cost_terms
    for ordinal, terms in enum_terms.items():
        log_probs.setdefault(ordinal, []).extend(terms)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    cache = getattr(enum_trace, "_sharing_cache", {})
    ring = _make_ring(temperature, cache, dim_to_size)
    with shared_intermediates(cache):
        # run forward algorithm
        log_probs = contract_tensor_tree(log_probs, sum_dims, ring=ring)

    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending

    query_to_ordinal = {}
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query in query_to_ordinal or query._pyro_backward_result is pending:
                continue
            # First ordinal at which this query's backward result materialized.
            query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] != "sample" or node["is_observed"]:
            continue
        # TODO move this into a Leaf implementation somehow
        new_node = {
            "type": "sample",
            "name": node["name"],
            "is_observed": False,
            "infer": node["infer"].copy(),
            "cond_indep_stack": node["cond_indep_stack"],
            "value": node["value"],
        }
        log_prob = node["packed"]["masked_log_prob"]
        if hasattr(log_prob, "_pyro_backward_result"):
            # Adjust the cond_indep_stack.
            ordinal = query_to_ordinal[log_prob]
            new_node["cond_indep_stack"] = tuple(
                frame for frame in node["cond_indep_stack"]
                if not (frame.vectorized and frame.size > 1)
                or plate_to_symbol[frame.name] in ordinal)

            # Gather if node depended on an enumerated value.
            sample = log_prob._pyro_backward_result
            if sample is not None:
                gathered = packed.pack(node["value"], node["infer"]["_dim_to_symbol"])
                for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
                    if dim in gathered._pyro_dims:
                        index._pyro_dims = sample._pyro_dims[1:]
                        gathered = packed.gather(gathered, index, dim)
                new_node["value"] = packed.unpack(gathered, enum_trace.symbol_to_dim)

        collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)
def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
    """Trace ``model`` with enumeration, then replay it on a collapsed trace.

    For internal use by infer_discrete.

    The model is traced under ``EnumerateMessenger``; the resulting trace is
    pruned with ``prune_subsample_sites`` and packed. Packed ``log_prob``
    tensors of all sample sites are contracted with ``contract_tensor_tree``
    (forward pass) and each contracted term's ``_pyro_backward`` hook is then
    invoked (backward pass). Finally a collapsed trace is built in which
    latent sites have their values gathered along the sampled indices and
    their ``cond_indep_stack`` trimmed, and the model is replayed against it.

    :param model: callable invoked as ``model(*args, **kwargs)``.
    :param first_available_dim: forwarded to ``EnumerateMessenger``.
    :param temperature: forwarded to ``_make_ring``.
    :returns: whatever ``model(*args, **kwargs)`` returns when run under
        ``SamplePosteriorMessenger(trace=collapsed_trace)``.
    """
    # Create an enumerated trace.
    with poutine.block(), EnumerateMessenger(first_available_dim):
        enum_trace = poutine.trace(model).get_trace(*args, **kwargs)
    enum_trace = prune_subsample_sites(enum_trace)
    enum_trace.compute_log_prob()
    enum_trace.pack_tensors()
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    log_probs = OrderedDict()
    sum_dims = set()
    queries = []
    for node in enum_trace.nodes.values():
        if node["type"] != "sample":
            continue
        # Only plates of size > 1 contribute to the ordinal, matching
        # _sample_posterior_from_trace().
        ordinal = frozenset(plate_to_symbol[f.name]
                            for f in node["cond_indep_stack"]
                            if f.vectorized and f.size > 1)
        log_prob = node["packed"]["log_prob"]
        log_probs.setdefault(ordinal, []).append(log_prob)
        # Every non-plate dim of this term must be summed out. A plate's
        # symbol is not guaranteed to appear in log_prob._pyro_dims
        # (presumably e.g. for a plate of size 1), so use a set difference
        # instead of set.remove(), which raises KeyError on a missing symbol.
        sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal)
        # Note we mark all sample sites with require_backward to gather
        # enumerated sites and adjust cond_indep_stack of all sample sites.
        if not node["is_observed"]:
            queries.append(log_prob)
            require_backward(log_prob)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    ring = _make_ring(temperature)
    log_probs = contract_tensor_tree(log_probs, sum_dims, ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] != "sample" or node["is_observed"]:
            continue
        # TODO move this into a Leaf implementation somehow
        new_node = {
            "type": "sample",
            "name": node["name"],
            "is_observed": False,
            "infer": node["infer"].copy(),
            "cond_indep_stack": node["cond_indep_stack"],
            "value": node["value"],
        }
        log_prob = node["packed"]["log_prob"]
        if hasattr(log_prob, "_pyro_backward_result"):
            # Adjust the cond_indep_stack. Frames that never entered the
            # ordinal (non-vectorized or size-1 plates) are kept as-is; other
            # plates survive only if present in the component's ordinal.
            ordinal = query_to_ordinal[log_prob]
            new_node["cond_indep_stack"] = tuple(
                f for f in node["cond_indep_stack"]
                if not (f.vectorized and f.size > 1) or plate_to_symbol[f.name] in ordinal)

            # Gather if node depended on an enumerated value.
            sample = log_prob._pyro_backward_result
            if sample is not None:
                new_value = packed.pack(node["value"], node["infer"]["_dim_to_symbol"])
                for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
                    if dim in new_value._pyro_dims:
                        index._pyro_dims = sample._pyro_dims[1:]
                        new_value = packed.gather(new_value, index, dim)
                new_node["value"] = packed.unpack(new_value, enum_trace.symbol_to_dim)

        collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)