def _identify_dense_edges(trace): succ = {} for name, node in trace.items(): if node["type"] == "sample": succ[name] = set() for name, node in trace.items(): if node["type"] == "sample": for past_name, past_node in trace.items(): if past_node["type"] == "sample": if past_name == name: break # XXX: different from Pyro, we always add edge past_name -> name succ[past_name].add(name) return succ
def _get_plate_stacks(trace): """ This builds a dict mapping site name to a set of plate stacks. Each plate stack is a list of :class:`CondIndepStackFrame`s corresponding to a :class:`plate`. This information is used by :class:`Trace_ELBO` and :class:`TraceGraph_ELBO`. """ return { name: [f for f in node["cond_indep_stack"]] for name, node in trace.items() if node["type"] == "sample" }
def get_samples_from_trace(trace, with_intermediates=False):
    """
    Collect the sampled values stored in a numpyro trace.

    :param trace: trace object obtained from
        `numpyro.handlers.trace().get_trace()`; only entries whose ``'type'``
        is ``'sample'`` are considered.
    :param with_intermediates: If True, also return each sample site's
        ``'intermediates'`` entry (the intermediate/latent values recorded
        for that site) alongside its final value.
    :return: Dictionary keyed by the site names given via `sample()` in the
        model. Each value is the site's ``'value'``, or, when
        ``with_intermediates`` is True, a ``(value, intermediates)`` tuple.
    """
    samples = {}
    for site_name, site in trace.items():
        if site['type'] != 'sample':
            continue
        if with_intermediates:
            samples[site_name] = (site['value'], site['intermediates'])
        else:
            samples[site_name] = site['value']
    return samples