Beispiel #1
0
def assert_ok(model, guide=None, max_plate_nesting=None, **kwargs):
    """
    Trace ``model`` under both the "pyro" and the "contrib.funsor" backends
    and check that the resulting traces agree.

    Each backend owns a LIFO queue seeded with an empty ``Trace``. On every
    pass the guide is traced through ``handlers.queue`` (using
    ``iter_discrete_escape`` / ``iter_discrete_extend``), the model is
    replayed against that guide trace, and the two model traces are pruned of
    subsample sites and handed to ``_check_traces``. Iteration stops as soon
    as either queue is empty.

    :param model: callable to trace; invoked with ``**kwargs``.
    :param guide: optional callable invoked with ``**kwargs``; when omitted a
        no-op returning ``None`` is used.
    :param max_plate_nesting: negated to derive ``first_available_dim`` for
        ``handlers.enum``, so a value supporting unary minus is required
        (the ``None`` default does not).
    :param kwargs: forwarded unchanged to both guide and model.
    """
    with pyro_backend("pyro"):
        pyro.clear_param_store()

    if guide is None:
        guide = lambda **kwargs: None  # noqa: E731

    q_pyro = LifoQueue()
    q_funsor = LifoQueue()
    q_pyro.put(Trace())
    q_funsor.put(Trace())

    def replayed_model_trace(backend, pending):
        # Trace the queued guide once, then replay the model against it,
        # all inside the requested backend and enumeration context.
        with pyro_backend(backend), \
                handlers.enum(first_available_dim=-max_plate_nesting - 1):
            queued_guide = handlers.queue(
                guide,
                pending,
                escape_fn=iter_discrete_escape,
                extend_fn=iter_discrete_extend,
            )
            guide_trace = handlers.trace(queued_guide).get_trace(**kwargs)
            replayed_model = handlers.replay(model, trace=guide_trace)
            return handlers.trace(replayed_model).get_trace(**kwargs)

    while not (q_pyro.empty() or q_funsor.empty()):
        tr_pyro = replayed_model_trace("pyro", q_pyro)
        tr_funsor = replayed_model_trace("contrib.funsor", q_funsor)

        # the dimension stack must be back in its pristine state
        assert _DIM_STACK.local_frame is _DIM_STACK.global_frame
        assert not (_DIM_STACK.global_frame.name_to_dim
                    or _DIM_STACK.global_frame.dim_to_name)
        assert _DIM_STACK.outermost is None

        _check_traces(
            prune_subsample_sites(tr_pyro.copy()),
            prune_subsample_sites(tr_funsor.copy()),
        )
    def construct_q_dag(self):
        """
        Build and return a ``Trace`` graph over nodes named
        ``"loc_latent_<label>"``.

        Labels are ``"1"`` followed by a string of ``L`` / ``R`` characters;
        levels ``2 .. self.N`` are generated by appending ``L`` and ``R`` to
        every label of the previous level. For each label, edges are added
        *into* its node from:

        * ``"1L"`` and ``"1R"`` when the label is ``"1"`` (regardless of
          ``self.N``);
        * its ``L`` sibling when the label ends in ``R``;
        * both of its children when the label is shorter than ``self.N``;
        * for every slice ``label[1:end]`` (``end`` from ``len(label) - 1``
          down to 2) that ends in ``R``, the label obtained by prefixing
          ``"1"`` and swapping that final ``R`` for ``L``.
        """
        dag = Trace()
        node_prefix = "loc_latent_"

        def sources_for(label):
            # Labels whose nodes get an edge into `label`, in insertion order.
            if label == "1":
                return ["1L", "1R"]
            sources = []
            if label[-1] == 'R':
                sources.append(label[:-1] + 'L')
            if len(label) < self.N:
                sources += [label + 'L', label + 'R']
            for end in range(len(label) - 1, 1, -1):
                interior = label[1:end]
                if interior[-1] == 'R':
                    sources.append('1' + interior[:-1] + 'L')
            return sources

        def link(label):
            target = node_prefix + label
            for source in sources_for(label):
                dag.add_edge(node_prefix + source, target)

        link("1")
        level = ["1"]
        for _ in range(2, self.N + 1):
            level = [parent + side for parent in level for side in 'LR']
            for label in level:
                link(label)

        return dag
Beispiel #3
0
def _sample_posterior(model, first_available_dim, temperature, *args,
                      **kwargs):
    """
    Run ``model`` replayed against a trace whose latent sample sites carry
    values obtained from an adjoint pass over the model's log-density terms.

    :param model: callable invoked with ``*args, **kwargs``.
    :param first_available_dim: forwarded to ``enum``.
    :param temperature: ``0`` selects ``funsor.ops.max`` with the argmax
        approximation; ``1`` selects ``funsor.ops.logaddexp`` with
        ``funsor.montecarlo.MonteCarlo()``.
    :returns: whatever ``model(*args, **kwargs)`` returns under the replay.
    :raises ValueError: for any other ``temperature``.
    """
    if temperature == 0:
        sum_op = funsor.ops.max
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op = funsor.ops.logaddexp
        approx = funsor.montecarlo.MonteCarlo()
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")
    prod_op = funsor.ops.add

    # Replaying against an empty Trace keeps densities from being counted twice.
    with block(), enum(first_available_dim=first_available_dim):
        enumerated_model = replay(model, trace=Trace())
        model_tr = trace(enumerated_model).get_trace(*args, **kwargs)

    # terms["log_factors"]: log p(x) for each observed or latent sample site x
    # terms["log_measures"]: log p(z) (or other Dice factor) for each latent
    #                        sample site z
    terms = terms_from_trace(model_tr)

    with funsor.interpretations.lazy:
        lazy_log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            terms["log_factors"] + terms["log_measures"],
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(lazy_log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # Build the trace that the model will finally be replayed against.
    sample_tr = model_tr.copy()
    chosen_values = {}
    for name, node in sample_tr.nodes.items():
        if node["type"] == "sample" and not site_is_subsample(node):
            if node["is_observed"]:
                # "observed" values may be collapsed samples that depend on
                # enumerated values, so slice them down with what has been
                # chosen so far.
                # TODO this should really be handled entirely under the hood by adjoint
                sliced = node["funsor"]["value"](**chosen_values)
                node["funsor"] = {"value": sliced}
            else:
                site = node["funsor"]
                measure = approx_factors[site["log_measure"]]
                site["log_measure"] = measure
                value = _get_support_value(measure, name)
                site["value"] = value
                chosen_values[name] = value

    with replay(trace=sample_tr):
        result = model(*args, **kwargs)
    return result
Beispiel #4
0
def test_copy():
    """
    Check that repeatedly copying a ``Trace`` and discarding the copy leaves
    the number of live ``Trace`` objects unchanged.
    """
    for _ in range(2):
        gc.collect()

    original = Trace()
    expected = count_objects_of_type(Trace)

    counts = []
    for _ in range(10):
        # the copy is deliberately not kept, so it must not be counted below
        original.copy()
        counts.append(count_objects_of_type(Trace))

    assert set(counts) == {expected}, counts
Beispiel #5
0
def test_connectivity_on_removal(edges):
    """
    Remove nodes one at a time from the end of the topological ordering and
    check, before each removal, that the number of nodes visited by a DFS
    from the root (node ``1``) equals the number of nodes still to be removed.
    """
    graph = Trace()
    for src, dst in edges:
        graph.add_edge(src, dst)

    root = 1
    remaining = graph.topological_sort()
    while remaining:
        reachable = list(graph._dfs(root, set()))
        assert_equal(len(reachable), len(remaining))
        # drop the last node of the ordering, i.e. reverse topological order
        graph.remove_node(remaining.pop())
Beispiel #6
0
def test_topological_sort(edges):
    """
    Check that ``Trace.topological_sort`` lists every node exactly once and
    places the source of each edge before its destination.
    """
    graph = Trace()
    for parent, child in edges:
        graph.add_edge(parent, child)
    ordering = graph.topological_sort()

    # every endpoint mentioned by an edge appears once and only once
    endpoints = {node for edge in edges for node in edge}
    assert len(ordering) == len(endpoints)
    assert set(ordering) == endpoints

    # the source of an edge must be ranked before its destination
    position = dict(zip(ordering, range(len(ordering))))
    for parent, child in edges:
        assert position[parent] < position[child]
Beispiel #7
0
def iter_discrete_traces(graph_type, fn, *args, **kwargs):
    """
    Lazily yield traces of ``fn`` until the work queue is exhausted.

    ``fn`` is wrapped in ``poutine.queue`` (with ``iter_discrete_escape`` and
    ``iter_discrete_extend``) around a LIFO queue seeded with one empty
    ``Trace``, and then in ``poutine.trace``. As long as the queue is not
    empty, one more trace is produced by calling ``get_trace`` with the given
    arguments.

    :param str graph_type: forwarded to ``poutine.trace``,
        e.g. "flat" or "dense".
    :param callable fn: A stochastic function.
    :param args: positional arguments passed to ``fn`` via ``get_trace``.
    :param kwargs: keyword arguments passed to ``fn`` via ``get_trace``.
    :returns: An iterator over traces (one trace per iteration).
    """
    pending = LifoQueue()
    pending.put(Trace())

    queued_fn = poutine.queue(
        fn,
        pending,
        escape_fn=iter_discrete_escape,
        extend_fn=iter_discrete_extend,
    )
    tracer = poutine.trace(queued_fn, graph_type=graph_type)

    while not pending.empty():
        yield tracer.get_trace(*args, **kwargs)
def _check_traces(tr_pyro, tr_funsor):
    """
    Assert that a trace recorded under the "pyro" backend agrees with the
    trace of the same program recorded under the funsor backend.

    Both traces must contain exactly the same site names. Log-probabilities
    are computed on both traces and the Pyro trace is additionally packed.
    Sample sites are then compared in tiers gated by the module-level
    ``_NAMED_TEST_STRENGTH``:

    * ``>= 1`` -- both traces satisfy ``check_traceenum_requirements``; the
      packed Pyro ``log_prob`` and the funsor-backend ``log_prob`` have the
      same number of elements and the same shape after squeezing; the
      vectorized frames of ``cond_indep_stack`` coincide.
    * ``>= 2`` -- the names of the packed Pyro dimensions equal the input
      names of the funsor ``log_prob`` once ``'__PARTICLES'`` has been
      removed from the latter.
    * ``>= 3`` -- unpacked ``log_prob`` and ``value`` shapes match exactly.

    When a tier fails, one line per sample site is printed (sites that
    disagree are prefixed with ``==>``) and the ``AssertionError`` is
    re-raised.

    :param tr_pyro: trace from the "pyro" backend; mutated by
        ``compute_log_prob`` and ``pack_tensors``.
    :param tr_funsor: trace from the funsor backend; mutated by
        ``compute_log_prob``.
    :raises AssertionError: if the traces disagree at any enabled tier.
    """
    assert tr_pyro.nodes.keys() == tr_funsor.nodes.keys()
    tr_pyro.compute_log_prob()
    tr_funsor.compute_log_prob()
    tr_pyro.pack_tensors()

    # Translate packed-dimension symbols back to readable names: the site
    # name for parallel-enumerated latent sites, the plate name for plates.
    symbol_to_name = {
        node['infer']['_enumerate_symbol']: name
        for name, node in tr_pyro.nodes.items()
        if node['type'] == 'sample' and not node['is_observed']
        and node['infer'].get('enumerate') == 'parallel'
    }
    symbol_to_name.update({
        symbol: name for name, symbol in tr_pyro.plate_to_symbol.items()})

    def sample_sites():
        # Yield (name, pyro_node, funsor_node) for every sample site.
        for name, pyro_node in tr_pyro.nodes.items():
            if pyro_node['type'] == 'sample':
                yield name, pyro_node, tr_funsor.nodes[name]

    def packed_names(pyro_node, funsor_node):
        # Return (pyro names, raw funsor names, funsor names without the
        # '__PARTICLES' marker) for one pair of sample-site nodes.
        pyro_names = frozenset(
            symbol_to_name[d] for d in pyro_node['packed']['log_prob']._pyro_dims)
        funsor_names = frozenset(funsor_node['funsor']['log_prob'].inputs)
        normalized = frozenset(n.replace('__PARTICLES', '') for n in funsor_names)
        return pyro_names, funsor_names, normalized

    def report(kind, name, mismatched, pyro_desc, funsor_desc):
        # Print one diagnostic line, highlighting sites that disagree.
        label = "==> ({} mismatch) {}".format(kind, name) if mismatched else name
        print(label, "Pyro: {} vs Funsor: {}".format(pyro_desc, funsor_desc))

    if _NAMED_TEST_STRENGTH >= 1:
        # coarser check: enumeration requirements satisfied
        check_traceenum_requirements(tr_pyro, Trace())
        check_traceenum_requirements(tr_funsor, Trace())
        try:
            # coarser check: number of elements and squeezed shapes
            for name, pyro_node, funsor_node in sample_sites():
                pyro_log_prob = pyro_node['packed']['log_prob']
                funsor_log_prob = funsor_node['log_prob']
                assert pyro_log_prob.numel() == funsor_log_prob.numel()
                assert pyro_log_prob.shape == funsor_log_prob.squeeze().shape
                assert frozenset(f for f in pyro_node['cond_indep_stack'] if f.vectorized) == \
                    frozenset(f for f in funsor_node['cond_indep_stack'] if f.vectorized)
        except AssertionError:
            for name, pyro_node, funsor_node in sample_sites():
                pyro_packed_shape = pyro_node['packed']['log_prob'].shape
                funsor_packed_shape = funsor_node['log_prob'].squeeze().shape
                report("dep", name, pyro_packed_shape != funsor_packed_shape,
                       pyro_packed_shape, funsor_packed_shape)
            raise

    if _NAMED_TEST_STRENGTH >= 2:
        try:
            # medium check: unordered packed shapes match
            for name, pyro_node, funsor_node in sample_sites():
                pyro_names, _, normalized = packed_names(pyro_node, funsor_node)
                assert pyro_names == normalized
        except AssertionError:
            for name, pyro_node, funsor_node in sample_sites():
                pyro_names, funsor_names, normalized = packed_names(pyro_node, funsor_node)
                # Flag a site using the same '__PARTICLES'-stripped comparison
                # as the assertion above, so sites that passed are not marked;
                # the raw funsor names are still what gets printed.
                report("packed", name, pyro_names != normalized,
                       sorted(pyro_names), sorted(funsor_names))
            raise

    if _NAMED_TEST_STRENGTH >= 3:
        try:
            # finer check: exact match with unpacked Pyro shapes
            for name, pyro_node, funsor_node in sample_sites():
                assert pyro_node['log_prob'].shape == funsor_node['log_prob'].shape
                assert pyro_node['value'].shape == funsor_node['value'].shape
        except AssertionError:
            for name, pyro_node, funsor_node in sample_sites():
                pyro_shape = pyro_node['log_prob'].shape
                funsor_shape = funsor_node['log_prob'].shape
                report("unpacked", name, pyro_shape != funsor_shape,
                       pyro_shape, funsor_shape)
            raise