def test_copy():
    """Check that ``Trace.copy()`` leaves the number of live ``Trace`` objects unchanged.

    A baseline count is taken after two garbage collections. Then ten copies are
    made and discarded, with a recount after each one. Every recount must equal
    the baseline.
    """
    gc.collect()
    gc.collect()
    tr = Trace()
    baseline = count_objects_of_type(Trace)

    observed = []
    for _ in range(10):
        tr.copy()  # result intentionally discarded
        observed.append(count_objects_of_type(Trace))

    assert set(observed) == {baseline}, observed
def assert_ok(model, guide=None, max_plate_nesting=None, **kwargs):
    """
    Assert that enumeration runs...

    Runs ``guide`` and ``model`` side by side under the ``"pyro"`` and
    ``"contrib.funsor"`` backends. Each backend is driven by its own LIFO queue
    of partial guide traces (via ``handlers.queue`` with
    ``iter_discrete_escape``/``iter_discrete_extend``). After each step, the
    helper checks that the named-dimension stack ``_DIM_STACK`` was fully
    cleaned up. It then compares the two (subsample-pruned) model traces with
    ``_check_traces``.

    :param model: model callable; invoked with ``**kwargs``.
    :param guide: optional guide callable; invoked with ``**kwargs``. Defaults
        to a guide that does nothing.
    :param int max_plate_nesting: enumeration uses
        ``first_available_dim=-max_plate_nesting - 1``. Required despite the
        ``None`` default.
    :raises ValueError: if ``max_plate_nesting`` is None.
    :raises AssertionError: if dimensions are not cleaned up, if the traces of
        the two backends disagree, or if the backends enumerate a different
        number of guide traces.
    """
    if max_plate_nesting is None:
        # -None - 1 would otherwise fail with an unhelpful TypeError below
        raise ValueError("assert_ok requires an explicit max_plate_nesting")

    with pyro_backend("pyro"):
        pyro.clear_param_store()

    if guide is None:
        guide = lambda **kwargs: None  # noqa: E731

    def _trace_under(backend, queue):
        # Trace the guide under handlers.queue(queue), then replay the model
        # against that guide trace, all inside the given backend and enum
        # context. ``handlers`` is looked up inside the backend context on
        # purpose, exactly as in the inline original.
        with pyro_backend(backend):
            with handlers.enum(first_available_dim=-max_plate_nesting - 1):
                guide_tr = handlers.trace(
                    handlers.queue(
                        guide,
                        queue,
                        escape_fn=iter_discrete_escape,
                        extend_fn=iter_discrete_extend,
                    )).get_trace(**kwargs)
                return handlers.trace(
                    handlers.replay(model, trace=guide_tr)).get_trace(**kwargs)

    q_pyro, q_funsor = LifoQueue(), LifoQueue()
    q_pyro.put(Trace())
    q_funsor.put(Trace())

    while not q_pyro.empty() and not q_funsor.empty():
        tr_pyro = _trace_under("pyro", q_pyro)
        tr_funsor = _trace_under("contrib.funsor", q_funsor)

        # make sure all dimensions were cleaned up
        assert _DIM_STACK.local_frame is _DIM_STACK.global_frame
        assert (not _DIM_STACK.global_frame.name_to_dim and
                not _DIM_STACK.global_frame.dim_to_name)
        assert _DIM_STACK.outermost is None

        tr_pyro = prune_subsample_sites(tr_pyro.copy())
        tr_funsor = prune_subsample_sites(tr_funsor.copy())
        _check_traces(tr_pyro, tr_funsor)

    # Without this, the loop above silently stops comparing as soon as either
    # queue drains, hiding a backend that enumerates fewer guide traces.
    assert q_pyro.empty() and q_funsor.empty(), \
        "pyro and contrib.funsor backends enumerated different numbers of guide traces"
def test_topological_sort(edges):
    """Verify ``Trace.topological_sort`` on a graph built from ``edges``.

    Every node mentioned in ``edges`` must appear exactly once in the ordering.
    For every ``(src, dst)`` edge, ``src`` must come strictly before ``dst``.
    """
    tr = Trace()
    for src, dst in edges:
        tr.add_edge(src, dst)
    order = tr.topological_sort()

    # every node is present, and none is repeated
    all_nodes = {node for edge in edges for node in edge}
    assert len(order) == len(all_nodes)
    assert set(order) == all_nodes

    # the ordering respects every edge
    position = {node: idx for idx, node in enumerate(order)}
    for src, dst in edges:
        assert position[src] < position[dst]
def construct_q_dag(self):
    """
    Build a ``Trace`` whose edges connect ``loc_latent_*`` sites of a binary
    pyramid of depth ``self.N``.

    Nodes are named by their path from the root ``"1"`` (e.g. ``"1L"``,
    ``"1LR"``). Each edge ``loc_latent_<src> -> loc_latent_<dst>`` records
    ``src`` as a dependency of ``dst``. Nodes are visited root first, then
    level by level, left to right.
    """
    dag = Trace()

    def dependencies_of(node):
        # Names (without the "loc_latent_" prefix) that ``node`` depends on.
        if node == "1":
            return ["1L", "1R"]
        found = []
        if node[-1] == 'R':
            # a right child depends on its left sibling
            found.append(node[:-1] + 'L')
        if len(node) < self.N:
            # non-leaf nodes depend on both of their children
            found += [node + 'L', node + 'R']
        # for each shorter prefix node[1:end] (end from len(node)-1 down to 2)
        # ending in 'R', also depend on its left-sibling counterpart
        for end in range(len(node) - 1, 1, -1):
            if node[end - 1] == 'R':
                found.append('1' + node[1:end - 1] + 'L')
        return found

    def link(node):
        for src in dependencies_of(node):
            dag.add_edge("loc_latent_" + src, "loc_latent_" + node)

    link("1")
    frontier = ["1"]
    for _ in range(2, self.N + 1):
        frontier = [name + side for name in frontier for side in ('L', 'R')]
        for name in frontier:
            link(name)
    return dag
def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
    """
    Run ``model`` once, replayed against a trace whose latent sample sites
    hold values derived from funsor's (approximate) posterior.

    :param model: model callable; invoked as ``model(*args, **kwargs)``.
    :param first_available_dim: passed to ``enum`` while tracing the model.
    :param temperature: ``0`` selects max/add with
        ``funsor.approximations.argmax_approximate`` (MAP). ``1`` selects
        logaddexp/add with ``funsor.montecarlo.MonteCarlo()`` (sampling).
    :returns: the return value of ``model(*args, **kwargs)`` executed under
        ``replay(trace=sample_tr)``.
    :raises ValueError: if ``temperature`` is neither 0 nor 1.
    """
    # choose the semiring and the approximation used by the adjoint pass
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        approx = funsor.montecarlo.MonteCarlo()
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    with block(), enum(first_available_dim=first_available_dim):
        # XXX replay against an empty Trace to ensure densities are not double-counted
        model_tr = trace(replay(model, trace=Trace())).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    # build (lazily) the sum-product expression eliminating all measure and
    # plate variables, then optimize it before evaluation
    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            terms["log_factors"] + terms["log_measures"],
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    # backward (adjoint) pass under the chosen approximation; the result is
    # indexed below by each site's original log_measure
    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    # NOTE(review): if Trace.copy() is shallow, the node dicts mutated below are
    # shared with model_tr — confirm; model_tr is not used afterwards here.
    sample_tr = model_tr.copy()
    # site name -> chosen value for latent sites processed so far (trace order)
    sample_subs = {}
    for name, node in sample_tr.nodes.items():
        if node["type"] != "sample" or site_is_subsample(node):
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            # (only latent sites visited earlier in trace order are substituted)
            node["funsor"] = {"value": node["funsor"]["value"](**sample_subs)}
        else:
            node["funsor"]["log_measure"] = approx_factors[node["funsor"]
                                                           ["log_measure"]]
            node["funsor"]["value"] = _get_support_value(
                node["funsor"]["log_measure"], name)
            sample_subs[name] = node["funsor"]["value"]

    with replay(trace=sample_tr):
        return model(*args, **kwargs)
def test_connectivity_on_removal(edges):
    """Remove nodes in reverse topological order, checking reachability each step.

    Before every removal, the number of nodes reached by a depth-first search
    from the root (node ``1``) must equal the number of nodes not yet removed.
    In other words, the remaining DAG stays connected to the root.
    """
    root = 1
    tr = Trace()
    for src, dst in edges:
        tr.add_edge(src, dst)

    remaining = tr.topological_sort()
    while remaining:
        reachable = sum(1 for _ in tr._dfs(root, set()))
        assert_equal(reachable, len(remaining))
        # drop the last node in topological order (a node nothing else needs)
        tr.remove_node(remaining.pop())
def iter_discrete_traces(graph_type, fn, *args, **kwargs):
    """
    Iterate over all discrete choices of a stochastic function.

    When sampling continuous random variables, this behaves like `fn`.
    When sampling discrete random variables, this iterates over all choices.

    This yields traces scaled by the probability of the discrete choices made
    in the `trace`.

    `fn` is wrapped in ``poutine.queue`` (driven by a LIFO queue seeded with
    an empty ``Trace`` and using ``iter_discrete_escape`` /
    ``iter_discrete_extend``) and in ``poutine.trace``. It is re-run until that
    queue is empty.

    :param str graph_type: The type of the graph, e.g. "flat" or "dense".
    :param callable fn: A stochastic function.
    :param args: Positional arguments forwarded to `fn` on every run.
    :param kwargs: Keyword arguments forwarded to `fn` on every run.
    :returns: An iterator over traces (one ``Trace`` per run of `fn`).
    """
    queue = LifoQueue()
    queue.put(Trace())
    traced_fn = poutine.trace(
        poutine.queue(fn, queue,
                      escape_fn=iter_discrete_escape,
                      extend_fn=iter_discrete_extend),
        graph_type=graph_type)
    while not queue.empty():
        yield traced_fn.get_trace(*args, **kwargs)
def _check_traces(tr_pyro, tr_funsor):
    """
    Compare, site by site, a trace from the ``pyro`` backend with one from the
    ``contrib.funsor`` backend.

    Strictness is controlled by the module-level ``_NAMED_TEST_STRENGTH``:

    * ``>= 1``: both traces pass ``check_traceenum_requirements``. For every
      sample site, the packed Pyro log-prob and the funsor log-prob have the
      same ``numel``. The packed Pyro shape equals the squeezed funsor shape.
      The vectorized ``cond_indep_stack`` frames agree.
    * ``>= 2``: the unordered set of names of the packed Pyro log-prob dims
      (mapped through enumerate/plate symbols) equals the funsor log-prob
      input names, after removing any ``'__PARTICLES'`` substring from the
      funsor names.
    * ``>= 3``: the unpacked ``log_prob`` and ``value`` shapes match exactly.

    When a check fails, every sample site is printed, with an ``==>`` marker
    on the mismatching ones. The ``AssertionError`` is then re-raised.

    Side effects: calls ``compute_log_prob()`` on both traces and
    ``pack_tensors()`` on ``tr_pyro``.

    :param tr_pyro: trace recorded under the ``pyro`` backend.
    :param tr_funsor: trace recorded under the ``contrib.funsor`` backend.
    :raises AssertionError: on any mismatch.
    """
    assert tr_pyro.nodes.keys() == tr_funsor.nodes.keys()
    tr_pyro.compute_log_prob()
    tr_funsor.compute_log_prob()
    tr_pyro.pack_tensors()

    # map Pyro's packed dim symbols back to enumerated-site / plate names
    symbol_to_name = {
        node['infer']['_enumerate_symbol']: name
        for name, node in tr_pyro.nodes.items()
        if node['type'] == 'sample' and not node['is_observed']
        and node['infer'].get('enumerate') == 'parallel'
    }
    symbol_to_name.update({
        symbol: name for name, symbol in tr_pyro.plate_to_symbol.items()})

    def _sample_site_pairs():
        # Yield (name, pyro_node, funsor_node) for every sample site.
        for name, pyro_node in tr_pyro.nodes.items():
            if pyro_node['type'] != 'sample':
                continue
            yield name, pyro_node, tr_funsor.nodes[name]

    def _packed_names(pyro_node, funsor_node):
        # Return comparable (pyro_names, funsor_names) frozensets. Used by both
        # the assertion and the diagnostic, so they always agree.
        pyro_names = frozenset(
            symbol_to_name[d] for d in pyro_node['packed']['log_prob']._pyro_dims)
        funsor_names = frozenset(
            input_name.replace('__PARTICLES', '')
            for input_name in funsor_node['funsor']['log_prob'].inputs)
        return pyro_names, funsor_names

    if _NAMED_TEST_STRENGTH >= 1:
        # coarser check: enumeration requirements satisfied
        check_traceenum_requirements(tr_pyro, Trace())
        check_traceenum_requirements(tr_funsor, Trace())
        try:
            # coarser check: number of elements and squeezed shapes
            for _, pyro_node, funsor_node in _sample_site_pairs():
                pyro_log_prob = pyro_node['packed']['log_prob']
                funsor_log_prob = funsor_node['log_prob']
                assert pyro_log_prob.numel() == funsor_log_prob.numel()
                assert pyro_log_prob.shape == funsor_log_prob.squeeze().shape
                assert frozenset(f for f in pyro_node['cond_indep_stack'] if f.vectorized) == \
                    frozenset(f for f in funsor_node['cond_indep_stack'] if f.vectorized)
        except AssertionError:
            for name, pyro_node, funsor_node in _sample_site_pairs():
                pyro_packed_shape = pyro_node['packed']['log_prob'].shape
                funsor_packed_shape = funsor_node['log_prob'].squeeze().shape
                if pyro_packed_shape != funsor_packed_shape:
                    err_str = "==> (dep mismatch) {}".format(name)
                else:
                    err_str = name
                print(err_str, "Pyro: {} vs Funsor: {}".format(pyro_packed_shape, funsor_packed_shape))
            raise

    if _NAMED_TEST_STRENGTH >= 2:
        try:
            # medium check: unordered packed shapes match
            for _, pyro_node, funsor_node in _sample_site_pairs():
                pyro_names, funsor_names = _packed_names(pyro_node, funsor_node)
                assert pyro_names == funsor_names
        except AssertionError:
            for name, pyro_node, funsor_node in _sample_site_pairs():
                # same normalization as the assertion above, so the ==> marker
                # lands on the sites that actually failed
                pyro_names, funsor_names = _packed_names(pyro_node, funsor_node)
                if pyro_names != funsor_names:
                    err_str = "==> (packed mismatch) {}".format(name)
                else:
                    err_str = name
                print(err_str, "Pyro: {} vs Funsor: {}".format(sorted(pyro_names), sorted(funsor_names)))
            raise

    if _NAMED_TEST_STRENGTH >= 3:
        try:
            # finer check: exact match with unpacked Pyro shapes
            for _, pyro_node, funsor_node in _sample_site_pairs():
                assert pyro_node['log_prob'].shape == funsor_node['log_prob'].shape
                assert pyro_node['value'].shape == funsor_node['value'].shape
        except AssertionError:
            for name, pyro_node, funsor_node in _sample_site_pairs():
                pyro_shape = pyro_node['log_prob'].shape
                funsor_shape = funsor_node['log_prob'].shape
                if pyro_shape != funsor_shape:
                    err_str = "==> (unpacked mismatch) {}".format(name)
                else:
                    err_str = name
                print(err_str, "Pyro: {} vs Funsor: {}".format(pyro_shape, funsor_shape))
            raise