Example #1
    def process(self, message):
        """
        Process one backward message: slice this step's operands down by the
        upstream sample, combine them into logits, draw a categorical sample
        over the contracted dims, and delegate routing of samples to
        downstream operands to ``einsum_backward_sample``.

        NOTE(review): ``message`` is assumed to be either ``None`` or a packed
        sample tensor carrying ``._pyro_sample_dims`` / ``._pyro_dims``
        attributes — confirm against callers.
        """
        out_dims = self.output
        terms = list(self.operands)
        # Dims appearing in some operand but not in the output get contracted.
        term_dims = set()
        for term in terms:
            term_dims.update(term._pyro_dims)
        contract_dims = ''.join(sorted(term_dims - set(out_dims)))
        batch_dims = out_dims

        # Slice down operands before combining terms.
        sample2 = message
        if sample2 is not None:
            for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
                # A sampled dim is no longer a batch dim of the output.
                batch_dims = batch_dims.replace(dim, '')
                for i, term in enumerate(terms):
                    if dim in term._pyro_dims:
                        index._pyro_dims = sample2._pyro_dims[1:]
                        terms[i] = packed.gather(term, index, dim)

        # Combine terms into a single packed logits tensor.
        dims = batch_dims + contract_dims
        logits = reduce(operator.add, packed.broadcast_all(*terms, dims=dims))

        # Sample over the flattened contraction dims, then unflatten.
        sample1 = None  # work around lack of pytorch support for zero-sized tensors
        if contract_dims:
            batch_shape = logits.shape[:len(batch_dims)]
            contract_shape = logits.shape[len(batch_dims):]
            flat_logits = logits.reshape(batch_shape + (-1,))
            flat_sample = dist.Categorical(logits=flat_logits).sample()
            sample1 = unflatten(flat_sample, batch_dims, contract_dims, contract_shape)

        # Cut down samples to pass on to subsequent steps.
        return einsum_backward_sample(self.operands, sample1, sample2)
Example #2
def einsum_backward_sample(operands, sample1, sample2):
    """
    Cuts down samples to pass on to subsequent steps.
    This is used in various ``_EinsumBackward.__call__()`` methods.
    This assumes all operands have a ``._pyro_dims`` attribute set.

    Yields ``(backward, sample_or_None)`` pairs, one per operand that has a
    ``._pyro_backward`` attribute.
    """
    # Merge the upstream sample with the sample drawn at this site.
    if sample1 is None:
        sample = sample2
    elif sample2 is None:
        sample = sample1
    else:
        # Slice sample1 down according to the choices recorded in sample2.
        assert set(sample1._pyro_sample_dims).isdisjoint(
            sample2._pyro_sample_dims)
        sample_dims = sample1._pyro_sample_dims + sample2._pyro_sample_dims
        for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
            if dim in sample1._pyro_dims:
                index._pyro_dims = sample2._pyro_dims[1:]
                sample1 = packed.gather(sample1, index, dim)

        # Concatenate the two samples along the leading sample dimension.
        broadcasted = packed.broadcast_all(sample1, sample2)
        sample = torch.cat(broadcasted)
        sample._pyro_dims = broadcasted[0]._pyro_dims
        sample._pyro_sample_dims = sample_dims
        assert sample.dim() == len(sample._pyro_dims)
        if not torch._C._get_tracing_state():
            assert sample.size(0) == len(sample._pyro_sample_dims)

    # Route the relevant sample dimensions to each downstream operand.
    for operand in operands:
        if not hasattr(operand, '_pyro_backward'):
            continue
        if sample is None:
            yield operand._pyro_backward, None
            continue
        shared_dims = set(operand._pyro_dims) & set(sample._pyro_sample_dims)
        if not shared_dims:
            # This operand depends on none of the sampled dims.
            yield operand._pyro_backward, None
            continue
        if shared_dims == set(sample._pyro_sample_dims):
            # This operand depends on all sampled dims; pass sample through.
            yield operand._pyro_backward, sample
            continue
        # Select only the rows of ``sample`` this operand depends on.
        part_dims = ''.join(sorted(shared_dims))
        rows = [sample._pyro_sample_dims.index(dim) for dim in part_dims]
        part = sample[rows]
        part._pyro_dims = sample._pyro_dims
        part._pyro_sample_dims = part_dims
        assert part.dim() == len(part._pyro_dims)
        if not torch._C._get_tracing_state():
            assert part.size(0) == len(part._pyro_sample_dims)
        yield operand._pyro_backward, part
 def _pyro_post_sample(self, msg):
     """
     After a site is sampled, substitute its sampled value into all stored
     log factors that mention the site's enumeration symbol, and drop that
     symbol from the set of dims still to be summed out.

     Sites absent from ``self.enum_trace`` or lacking an
     ``_enumerate_symbol`` are ignored.
     """
     enum_node = self.enum_trace.nodes.get(msg["name"])
     if enum_node is None:
         return
     symbol = enum_node["infer"].get("_enumerate_symbol")
     if symbol is None:
         return
     # Pack the sampled value so its dims align with the packed factors.
     value = packed.pack(msg["value"].long(), enum_node["infer"]["_dim_to_symbol"])
     assert symbol not in value._pyro_dims
     # Gather each affected factor at the sampled value, in place.
     for factors in self.log_factors.values():
         for i, factor in enumerate(factors):
             if symbol in factor._pyro_dims:
                 factors[i] = packed.gather(factor, value, symbol)
     self.sum_dims.remove(symbol)
Example #4
def _sample_posterior_from_trace(model, enum_trace, temperature, *args,
                                 **kwargs):
    """
    Draw a posterior sample of the enumerated discrete sites in ``enum_trace``
    and replay ``model`` against the resulting collapsed trace, returning the
    model's return value.

    This runs a forward contraction over the packed log-prob terms
    (``contract_tensor_tree``), then a backward pass (``_pyro_backward``)
    that propagates sample results to the marked query sites.

    NOTE(review): ``temperature`` is only forwarded to ``_make_ring``; its
    exact semantics (e.g. sampling vs. MAP) are not visible here — confirm
    against ``_make_ring``.
    """
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    sum_dims = set()
    queries = []
    dim_to_size = {}
    cost_terms = OrderedDict()
    enum_terms = OrderedDict()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            # The ordinal is the set of plate symbols this site lives in
            # (only vectorized plates of size > 1 count).
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized and f.size > 1)
            # For sites that depend on an enumerated variable, we need to apply
            # the mask but not the scale when sampling.
            if "masked_log_prob" not in node["packed"]:
                node["packed"]["masked_log_prob"] = packed.scale_and_mask(
                    node["packed"]["unscaled_log_prob"],
                    mask=node["packed"]["mask"])
            log_prob = node["packed"]["masked_log_prob"]
            sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal)
            # Skip terms with no dims to sum out.
            if sum_dims.isdisjoint(log_prob._pyro_dims):
                continue
            dim_to_size.update(zip(log_prob._pyro_dims, log_prob.shape))
            if node["infer"].get("_enumerate_dim") is None:
                cost_terms.setdefault(ordinal, []).append(log_prob)
            else:
                enum_terms.setdefault(ordinal, []).append(log_prob)
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # We take special care to match the term ordering in
    # pyro.infer.traceenum_elbo._compute_model_factors() to allow
    # contract_tensor_tree() to use shared_intermediates() inside
    # TraceEnumSample_ELBO. The special ordering is: first all cost terms in
    # order of model_trace, then all enum_terms in order of model trace.
    log_probs = cost_terms
    for ordinal, terms in enum_terms.items():
        log_probs.setdefault(ordinal, []).extend(terms)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    cache = getattr(enum_trace, "_sharing_cache", {})
    ring = _make_ring(temperature, cache, dim_to_size)
    with shared_intermediates(cache):
        log_probs = contract_tensor_tree(log_probs, sum_dims,
                                         ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["masked_log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack: keep only plates that are
                # either non-vectorized (or trivial) or in this site's ordinal.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not (f.vectorized and f.size > 1)
                    or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"],
                                            node["infer"]["_dim_to_symbol"])
                    for index, dim in zip(jit_iter(sample),
                                          sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value,
                                                      enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)
Example #5
def _sample_posterior(model, first_available_dim, temperature, *args,
                      **kwargs):
    # For internal use by infer_discrete.
    #
    # Creates an enumerated trace of ``model``, runs a forward contraction
    # over the packed log-prob terms followed by a backward sampling pass,
    # builds a collapsed trace with the sampled values gathered in, and
    # finally replays ``model`` against that collapsed trace, returning the
    # model's return value.
    #
    # NOTE(review): ``temperature`` is only forwarded to ``_make_ring``; its
    # exact semantics (e.g. sampling vs. MAP) are not visible here — confirm
    # against ``_make_ring``.

    # Create an enumerated trace.
    with poutine.block(), EnumerateMessenger(first_available_dim):
        enum_trace = poutine.trace(model).get_trace(*args, **kwargs)
    enum_trace = prune_subsample_sites(enum_trace)
    enum_trace.compute_log_prob()
    enum_trace.pack_tensors()
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    log_probs = OrderedDict()
    sum_dims = set()
    queries = []
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            # The ordinal is the set of vectorized plate symbols this site lives in.
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized)
            log_prob = node["packed"]["log_prob"]
            log_probs.setdefault(ordinal, []).append(log_prob)
            sum_dims.update(log_prob._pyro_dims)
            # Plate dims are batch dims, not to be summed out.
            for frame in node["cond_indep_stack"]:
                if frame.vectorized:
                    sum_dims.remove(plate_to_symbol[frame.name])
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    ring = _make_ring(temperature)
    log_probs = contract_tensor_tree(log_probs, sum_dims,
                                     ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack: keep only plates that are
                # either non-vectorized or in this site's ordinal.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not f.vectorized or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"],
                                            node["infer"]["_dim_to_symbol"])
                    for index, dim in zip(jit_iter(sample),
                                          sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value,
                                                      enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)