Exemplo n.º 1
0
def _compute_marginals(model_trace, guide_trace):
    """
    Build one distribution per enumerated, model-only sample site.

    A site is processed only if it is a ``"sample"`` site, is absent from
    ``guide_trace.nodes`` and has a non-``None`` ``"_enumerate_dim"`` entry
    in its ``"infer"`` dict; every other site is skipped.

    :param model_trace: trace whose ``nodes`` are scanned; its
        ``symbol_to_dim`` mapping is used to unpack the contracted logits.
    :param guide_trace: trace whose ``nodes`` mark the sites to skip.
    :return: an ``OrderedDict`` mapping each processed site name to
        ``_make_dist(site["fn"], logits)``, in ``model_trace.nodes`` order.
    """
    # Only the log factors and the dims to sum out are used below; the other
    # three values returned by _compute_model_factors are not needed here.
    _, log_factors, _, sum_dims, _ = _compute_model_factors(model_trace, guide_trace)

    marginal_dists = OrderedDict()
    # A single cache is shared across all sites' contractions.
    with shared_intermediates() as cache:
        for name, site in model_trace.nodes.items():
            if site["type"] != "sample" or name in guide_trace.nodes:
                continue
            enum_dim = site["infer"].get("_enumerate_dim")
            if enum_dim is None:
                continue

            enum_symbol = site["infer"]["_enumerate_symbol"]
            ordinal = _find_ordinal(model_trace, site)
            logits = contract_to_tensor(log_factors,
                                        sum_dims,
                                        target_ordinal=ordinal,
                                        target_dims={enum_symbol},
                                        cache=cache)
            logits = packed.unpack(logits, model_trace.symbol_to_dim)
            # Append a trailing dim and swap it with the enumeration dim
            # (shifted by one because of the unsqueeze), so the enumerated
            # values end up on the rightmost axis.
            logits = logits.unsqueeze(-1).transpose(-1, enum_dim - 1)
            # Strip leading singleton dims but always keep the final axis.
            # Without the rank guard, a site with a single enumerated value
            # would be squeezed to a 0-d tensor and ``shape[0]`` would raise
            # IndexError.
            while logits.dim() > 1 and logits.shape[0] == 1:
                logits = logits.squeeze(0)
            marginal_dists[name] = _make_dist(site["fn"], logits)
    return marginal_dists
Exemplo n.º 2
0
 def log_prob(self, model_trace):
     """
     Compute the log pdf of ``model_trace``.

     When ``self.has_enumerable_sites`` is false this is simply
     ``model_trace.log_prob_sum()``. Otherwise the log factors obtained
     from ``self._get_log_factors`` are contracted over
     ``self._enum_dims``.

     :param model_trace: the trace to score.
     :return: log pdf of the trace.
     """
     if self.has_enumerable_sites:
         factors = self._get_log_factors(model_trace)
         with shared_intermediates() as shared:
             return contract_to_tensor(factors, self._enum_dims, cache=shared)
     # Nothing was enumerated: the trace can score itself directly.
     return model_trace.log_prob_sum()
Exemplo n.º 3
0
 def log_prob(self, model_trace):
     """
     Compute the log pdf of ``model_trace`` from per-site ``log_prob`` values.

     When ``self.has_enumerable_sites`` is false, each name in
     ``model_trace.stochastic_nodes`` contributes
     ``fn.log_prob(value)`` for that node, and the contributions are added
     together starting from ``0`` (so a trace with no stochastic nodes
     yields ``0``). Otherwise the log factors obtained from
     ``self._get_log_factors`` are contracted over ``self._enum_dims``.

     :param model_trace: the trace to score.
     :return: the accumulated log probability.
     """
     if self.has_enumerable_sites:
         factors = self._get_log_factors(model_trace)
         with shared_intermediates() as shared:
             return contract_to_tensor(factors, self._enum_dims, cache=shared)

     # sum() starts at 0 and adds left to right, matching a manual
     # ``total = total + term`` accumulation.
     return sum(
         model_trace.nodes[name]['fn'].log_prob(model_trace.nodes[name]['value'])
         for name in model_trace.stochastic_nodes
     )
Exemplo n.º 4
0
 def _pyro_sample(self, msg):
     """
     Replace ``msg["fn"]`` with a distribution built from contracted logits.

     The message is left untouched when ``self.enum_trace`` has no node
     named ``msg["name"]`` or when that node's ``"infer"`` dict has no
     ``"_enumerate_symbol"``. Otherwise logits are obtained by contracting
     ``self.log_factors`` over ``self.sum_dims`` and ``msg["fn"]`` is
     overwritten with ``_make_dist(msg["fn"], logits)``.

     :param msg: the sample-site message; mutated in place.
     :return: ``None``.
     """
     enum_msg = self.enum_trace.nodes.get(msg["name"])
     if enum_msg is None:
         return
     enum_symbol = enum_msg["infer"].get("_enumerate_symbol")
     if enum_symbol is None:
         return
     enum_dim = enum_msg["infer"]["_enumerate_dim"]
     with shared_intermediates(self.cache):
         ordinal = _find_ordinal(self.enum_trace, msg)
         logits = contract_to_tensor(self.log_factors, self.sum_dims,
                                     target_ordinal=ordinal, target_dims={enum_symbol},
                                     cache=self.cache)
         logits = packed.unpack(logits, self.enum_trace.symbol_to_dim)
         # Append a trailing dim and swap it with the enumeration dim
         # (shifted by one because of the unsqueeze), so the enumerated
         # values end up on the rightmost axis.
         logits = logits.unsqueeze(-1).transpose(-1, enum_dim - 1)
         # Strip leading singleton dims but always keep the final axis.
         # Without the rank guard, a site with a single enumerated value
         # would be squeezed to a 0-d tensor and ``shape[0]`` would raise
         # IndexError.
         while logits.dim() > 1 and logits.shape[0] == 1:
             logits = logits.squeeze(0)
     msg["fn"] = _make_dist(msg["fn"], logits)