Пример #1
0
def _identify_dense_edges(trace):
    succ = {}
    for name, node in trace.items():
        if node["type"] == "sample":
            succ[name] = set()
    for name, node in trace.items():
        if node["type"] == "sample":
            for past_name, past_node in trace.items():
                if past_node["type"] == "sample":
                    if past_name == name:
                        break
                    # XXX: different from Pyro, we always add edge past_name -> name
                    succ[past_name].add(name)
    return succ
Пример #2
0
def _get_plate_stacks(trace):
    """
    This builds a dict mapping site name to a set of plate stacks. Each
    plate stack is a list of :class:`CondIndepStackFrame`s corresponding to
    a :class:`plate`. This information is used by :class:`Trace_ELBO` and
    :class:`TraceGraph_ELBO`.
    """
    return {
        name: [f for f in node["cond_indep_stack"]]
        for name, node in trace.items() if node["type"] == "sample"
    }
Пример #3
0
def get_samples_from_trace(trace, with_intermediates=False):
    """ Collect the values of every sample site in a numpyro trace.

    Sites whose ``'type'`` is not ``'sample'`` are skipped. Result keys keep
    the trace's iteration order.

    :param trace: trace object obtained from `numpyro.handlers.trace().get_trace()`
    :param with_intermediates: If True, each entry also carries the site's
        ``'intermediates'`` (intermediate/latent samples produced by the
        sample site distribution).
    :return: Dictionary keyed by the site names given via `sample()` in the
        model. Each value is the site's ``'value'``. When with_intermediates
        is True, each value is instead a ``(value, intermediates)`` tuple.
    """
    collected = {}
    for site_name, site in trace.items():
        if site['type'] != 'sample':
            continue
        value = site['value']
        # 'intermediates' is only read when requested, so sites lacking it
        # are fine in the default mode.
        collected[site_name] = (value, site['intermediates']) if with_intermediates else value
    return collected