def _compute_marginals(model_trace, guide_trace):
    """
    Build a marginal distribution for every enumerated model site that is
    not present in the guide, by contracting the model's log factors down
    to each site's enumeration dimension.

    :returns: an ``OrderedDict`` mapping site name to a distribution
        constructed by ``_make_dist`` from the contracted logits.
    """
    (marginal_costs, log_factors, ordering,
     sum_dims, scale) = _compute_model_factors(model_trace, guide_trace)

    marginals = OrderedDict()
    with shared_intermediates() as cache:
        for name, site in model_trace.nodes.items():
            # Only enumerated sample sites that the guide does not handle.
            if site["type"] != "sample":
                continue
            if name in guide_trace.nodes:
                continue
            dim = site["infer"].get("_enumerate_dim")
            if dim is None:
                continue
            symbol = site["infer"]["_enumerate_symbol"]

            # Contract all factors down to this site's ordinal, keeping
            # only the enumeration symbol free.
            ordinal = _find_ordinal(model_trace, site)
            logits = contract_to_tensor(
                log_factors,
                sum_dims,
                target_ordinal=ordinal,
                target_dims={symbol},
                cache=cache,
            )
            logits = packed.unpack(logits, model_trace.symbol_to_dim)
            # Move the enumeration dim to the rightmost position, then
            # strip leading singleton dims.
            logits = logits.unsqueeze(-1).transpose(-1, dim - 1)
            while logits.shape[0] == 1:
                logits = logits.squeeze(0)
            marginals[name] = _make_dist(site["fn"], logits)
    return marginals
def log_prob(self, model_trace):
    """
    Returns the log pdf of `model_trace` by appropriately handling
    enumerated log prob factors.

    :return: log pdf of the trace.
    """
    if self.has_enumerable_sites:
        # Enumerated sites require an einsum-style contraction over the
        # enumeration dimensions.
        factors = self._get_log_factors(model_trace)
        with shared_intermediates() as cache:
            return contract_to_tensor(factors, self._enum_dims, cache=cache)
    # Without enumeration the trace's own sum is already the exact log pdf.
    return model_trace.log_prob_sum()
def log_prob(self, model_trace):
    """
    almost identical to that of TraceEinsumEvaluator but uses log_prob
    instead of log_prob_sum
    """
    if self.has_enumerable_sites:
        # Enumerated sites: contract the packed log factors over the
        # enumeration dimensions.
        factors = self._get_log_factors(model_trace)
        with shared_intermediates() as cache:
            return contract_to_tensor(factors, self._enum_dims, cache=cache)
    # No enumeration: accumulate per-site log_prob (not log_prob_sum),
    # preserving each site's batch shape in the running total.
    total = 0
    for site_name in model_trace.stochastic_nodes:
        node = model_trace.nodes[site_name]
        total = total + node['fn'].log_prob(node['value'])
    return total
def _pyro_sample(self, msg):
    """
    Replace ``msg["fn"]`` with a distribution whose logits are obtained by
    contracting the stored log factors down to this site's enumeration
    dimension. Sites absent from the enum trace, or not enumerated, are
    left untouched.
    """
    enum_site = self.enum_trace.nodes.get(msg["name"])
    if enum_site is None:
        return
    symbol = enum_site["infer"].get("_enumerate_symbol")
    if symbol is None:
        return
    dim = enum_site["infer"]["_enumerate_dim"]

    with shared_intermediates(self.cache):
        # Contract all factors to this site's ordinal, keeping only the
        # enumeration symbol free.
        ordinal = _find_ordinal(self.enum_trace, msg)
        logits = contract_to_tensor(
            self.log_factors,
            self.sum_dims,
            target_ordinal=ordinal,
            target_dims={symbol},
            cache=self.cache,
        )
        logits = packed.unpack(logits, self.enum_trace.symbol_to_dim)
        # Move the enumeration dim to the rightmost position, then strip
        # leading singleton dims.
        logits = logits.unsqueeze(-1).transpose(-1, dim - 1)
        while logits.shape[0] == 1:
            logits = logits.squeeze(0)
        msg["fn"] = _make_dist(msg["fn"], logits)