Пример #1
0
 def pack_tensors(self, plate_to_symbol=None):
     """
     Computes packed representations of tensors in the trace.
     This should be called after :meth:`compute_log_prob` or :meth:`compute_score_parts`.

     For every ``"sample"`` site, packs ``site["mask"]`` and, when present,
     the ``score_parts`` (or ``log_prob``) together with ``unscaled_log_prob``
     using the site's ``infer["_dim_to_symbol"]`` mapping. The results are
     stored in ``site["packed"]``, which is created if missing.

     :param plate_to_symbol: optional mapping forwarded to :meth:`symbolize_dims`.
     :raises ValueError: if packing fails at a site. The re-raised error names
         the site, includes the original message and the trace's formatted
         shapes, and chains the original exception as its ``__cause__``.
     """
     self.symbolize_dims(plate_to_symbol)
     for site in self.nodes.values():
         if site["type"] != "sample":
             continue
         dim_to_symbol = site["infer"]["_dim_to_symbol"]
         packed = site.setdefault("packed", {})
         try:
             packed["mask"] = pack(site["mask"], dim_to_symbol)
             if "score_parts" in site:
                 log_prob, score_function, entropy_term = site["score_parts"]
                 log_prob = pack(log_prob, dim_to_symbol)
                 score_function = pack(score_function, dim_to_symbol)
                 entropy_term = pack(entropy_term, dim_to_symbol)
                 packed["score_parts"] = ScoreParts(log_prob, score_function, entropy_term)
                 packed["log_prob"] = log_prob
                 packed["unscaled_log_prob"] = pack(site["unscaled_log_prob"], dim_to_symbol)
             elif "log_prob" in site:
                 packed["log_prob"] = pack(site["log_prob"], dim_to_symbol)
                 packed["unscaled_log_prob"] = pack(site["unscaled_log_prob"], dim_to_symbol)
         except ValueError as e:
             # Re-raise with site context; keep the original traceback and chain
             # the cause explicitly instead of fishing it out of sys.exc_info().
             shapes = self.format_shapes(last_site=site["name"])
             msg = "Error while packing tensors at site '{}':\n  {}\n{}".format(site["name"], e, shapes)
             raise ValueError(msg).with_traceback(e.__traceback__) from e
Пример #2
0
def test_broadcast_all(shapes):
    """Check that packed.broadcast_all agrees with broadcast_all once unpacked."""
    inputs, dim_to_symbol, symbol_to_dim = make_inputs(shapes)
    packed_outputs = packed.broadcast_all(
        *[packed.pack(tensor, dim_to_symbol) for tensor in inputs])
    actual = tuple(packed.unpack(out, symbol_to_dim) for out in packed_outputs)
    # broadcast_all is only invoked when there is something to broadcast.
    if inputs:
        expected = broadcast_all(*inputs)
    else:
        expected = []
    assert len(actual) == len(expected)
    for got, want in zip(actual, expected):
        assert_equal(got, want)
Пример #3
0
 def _pyro_post_sample(self, msg):
     enum_msg = self.enum_trace.nodes.get(msg["name"])
     if enum_msg is None:
         return
     enum_symbol = enum_msg["infer"].get("_enumerate_symbol")
     if enum_symbol is None:
         return
     value = packed.pack(msg["value"].long(), enum_msg["infer"]["_dim_to_symbol"])
     assert enum_symbol not in value._pyro_dims
     for t, terms in self.log_factors.items():
         for i, term in enumerate(terms):
             if enum_symbol in term._pyro_dims:
                 terms[i] = packed.gather(term, value, enum_symbol)
     self.sum_dims.remove(enum_symbol)
Пример #4
0
def test_unpack_pack(dims):
    """Round-trip pack/unpack, also after permuting the packed dims into sorted order."""
    # Map the rightmost four dims (-1..-4) to the symbols 'a'..'d' and back.
    dim_to_symbol = dict(zip(range(-1, -5, -1), 'abcd'))
    symbol_to_dim = {sym: d for d, sym in dim_to_symbol.items()}
    x = torch.randn(tuple(range(2, 2 + len(dims))))

    pack_x = packed.pack(x, dim_to_symbol)
    assert_equal(packed.unpack(pack_x, symbol_to_dim), x)

    sort_dims = ''.join(sorted(dims))
    if sort_dims == pack_x._pyro_dims:
        return
    perm = [pack_x._pyro_dims.index(sym) for sym in sort_dims]
    sort_pack_x = pack_x.permute(*perm)
    sort_pack_x._pyro_dims = sort_dims
    assert_equal(packed.unpack(sort_pack_x, symbol_to_dim), x)
Пример #5
0
def _sample_posterior_from_trace(model, enum_trace, temperature, *args,
                                 **kwargs):
    """
    Replay ``model`` against a collapsed trace derived from ``enum_trace``.

    Builds per-ordinal log-prob factors from the packed sample sites of
    ``enum_trace``, runs the forward algorithm with ``contract_tensor_tree``
    and then the backward algorithm on every site marked with
    ``require_backward``. It then builds a collapsed trace. In that trace, each
    unobserved site's value is gathered along the sampled enumeration dims,
    and its ``cond_indep_stack`` keeps only the plates in the ordinal assigned
    to that site.

    Side effects: caches ``node["packed"]["masked_log_prob"]`` on
    ``enum_trace`` nodes if missing, and sets ``_pyro_backward_result`` on the
    query log-prob tensors.

    :param model: callable to replay.
    :param enum_trace: an enumerated trace whose sample sites already have
        ``node["packed"]["unscaled_log_prob"]`` and ``node["packed"]["mask"]``
        (presumably populated by ``pack_tensors``; confirm at call sites).
    :param temperature: forwarded to ``_make_ring``; see that helper for its meaning.
    :param args: positional arguments for ``model``.
    :param kwargs: keyword arguments for ``model``.
    :returns: the return value of ``model(*args, **kwargs)`` run under
        ``SamplePosteriorMessenger(trace=collapsed_trace)``.
    """
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    sum_dims = set()
    queries = []
    dim_to_size = {}
    cost_terms = OrderedDict()
    enum_terms = OrderedDict()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            # Ordinal: symbols of this site's vectorized plates with size > 1.
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized and f.size > 1)
            # For sites that depend on an enumerated variable, we need to apply
            # the mask but not the scale when sampling.
            # The result is cached on node["packed"] and reused if already present.
            if "masked_log_prob" not in node["packed"]:
                node["packed"]["masked_log_prob"] = packed.scale_and_mask(
                    node["packed"]["unscaled_log_prob"],
                    mask=node["packed"]["mask"])
            log_prob = node["packed"]["masked_log_prob"]
            # Every dim of log_prob outside this site's ordinal is summed out.
            sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal)
            # Skip sites whose log_prob shares no dim with the accumulated
            # sum_dims: they become neither factors nor queries below.
            if sum_dims.isdisjoint(log_prob._pyro_dims):
                continue
            dim_to_size.update(zip(log_prob._pyro_dims, log_prob.shape))
            # Non-enumerated sites are cost terms; enumerated sites are enum terms.
            if node["infer"].get("_enumerate_dim") is None:
                cost_terms.setdefault(ordinal, []).append(log_prob)
            else:
                enum_terms.setdefault(ordinal, []).append(log_prob)
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # We take special care to match the term ordering in
    # pyro.infer.traceenum_elbo._compute_model_factors() to allow
    # contract_tensor_tree() to use shared_intermediates() inside
    # TraceEnumSample_ELBO. The special ordering is: first all cost terms in
    # order of model_trace, then all enum_terms in order of model trace.
    # Note: log_probs aliases cost_terms, which is extended in place.
    log_probs = cost_terms
    for ordinal, terms in enum_terms.items():
        log_probs.setdefault(ordinal, []).extend(terms)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    # Reuse a sharing cache attached to the trace if there is one, else a fresh dict.
    cache = getattr(enum_trace, "_sharing_cache", {})
    ring = _make_ring(temperature, cache, dim_to_size)
    with shared_intermediates(cache):
        log_probs = contract_tensor_tree(log_probs, sum_dims,
                                         ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    # Mark every query as pending, so we can detect when backward has filled it in.
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        # Each query is assigned the first ordinal after whose backward pass
        # its pending marker has been replaced.
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["masked_log_prob"]
            # Presumably only queries carry this attribute; sites skipped above
            # keep their original value and cond_indep_stack.
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack.
                # NOTE(review): raises KeyError if a query's result stayed
                # `pending` after all backward passes; confirm that cannot happen.
                ordinal = query_to_ordinal[log_prob]
                # Keep frames that are not (vectorized and size > 1), plus those
                # whose plate symbol is in the ordinal assigned to this query.
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not (f.vectorized and f.size > 1)
                    or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"],
                                            node["infer"]["_dim_to_symbol"])
                    # Presumably each item of jit_iter(sample) is an index tensor
                    # for the matching symbol in sample._pyro_sample_dims;
                    # confirm in the backward implementation.
                    for index, dim in zip(jit_iter(sample),
                                          sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value,
                                                      enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)
Пример #6
0
def _sample_posterior(model, first_available_dim, temperature, *args,
                      **kwargs):
    """
    Replay ``model`` with its sample sites replaced by values computed from a
    forward-backward pass over an enumerated trace of the same model.

    For internal use by infer_discrete.

    Steps: trace ``model`` under ``EnumerateMessenger(first_available_dim)``
    inside ``poutine.block()``, prune subsample sites, compute and pack
    log-probs, run the forward algorithm (``contract_tensor_tree``) and the
    backward algorithm on every unobserved sample site, and finally re-run
    ``model`` under ``SamplePosteriorMessenger`` against the collapsed trace.

    :param model: callable to trace and replay.
    :param first_available_dim: forwarded to ``EnumerateMessenger``.
    :param temperature: forwarded to ``_make_ring``; see that helper for its meaning.
    :param args: positional arguments for ``model``.
    :param kwargs: keyword arguments for ``model``.
    :returns: the return value of the replayed ``model(*args, **kwargs)``.
    """
    # Create an enumerated trace.
    with poutine.block(), EnumerateMessenger(first_available_dim):
        enum_trace = poutine.trace(model).get_trace(*args, **kwargs)
    enum_trace = prune_subsample_sites(enum_trace)
    enum_trace.compute_log_prob()
    enum_trace.pack_tensors()
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    log_probs = OrderedDict()
    sum_dims = set()
    queries = []
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            # Ordinal: symbols of all vectorized plates enclosing this site.
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized)
            log_prob = node["packed"]["log_prob"]
            log_probs.setdefault(ordinal, []).append(log_prob)
            # Every dim of log_prob outside this site's plates is summed out.
            # Subtracting the ordinal, rather than adding all dims and then
            # calling set.remove() for each plate, avoids a KeyError when a
            # plate's symbol is absent from the packed log_prob (e.g. a size-1
            # plate, assuming packing drops size-1 dims; confirm in packed.pack).
            sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal)
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    ring = _make_ring(temperature)
    log_probs = contract_tensor_tree(log_probs, sum_dims,
                                     ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        # Each query is assigned the first ordinal after whose backward pass
        # its pending marker has been replaced.
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack: keep non-vectorized frames and
                # vectorized frames whose plate is in this query's ordinal.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not f.vectorized or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"],
                                            node["infer"]["_dim_to_symbol"])
                    for index, dim in zip(jit_iter(sample),
                                          sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value,
                                                      enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)