Example 1
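A test that the pyro.ops.einsum.torch_marginal backend agrees with a reference pyro.ops.einsum.torch_log contraction: the forward results must match, and after _pyro_backward() each operand's _pyro_backward_result must equal the log contraction onto that operand's own dimensions (for an equation like "ab,bc->c", the first operand is checked against contract("ab,bc->ab", ...)).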
def test_marginal(equation):
    inputs, output = equation.split("->")
    inputs = inputs.split(",")
    operands = [
        torch.randn(torch.Size((2, ) * len(input_))) for input_ in inputs
    ]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # check forward pass
    for x in operands:
        require_backward(x)
    actual = contract(equation,
                      *operands,
                      backend="pyro.ops.einsum.torch_marginal")
    expected = contract(equation,
                        *operands,
                        backend="pyro.ops.einsum.torch_log")
    assert_equal(expected, actual)

    # check backward pass
    actual._pyro_backward()
    for input_, operand in zip(inputs, operands):
        marginal_equation = ",".join(inputs) + "->" + input_
        expected = contract(marginal_equation,
                            *operands,
                            backend="pyro.ops.einsum.torch_log")
        actual = operand._pyro_backward_result
        assert_equal(expected, actual)
Example 2
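The same marginal-versus-log comparison as Example 1, but run through ubersum with plate dimensions and modulo_total=True.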
def test_adjoint_marginal(equation, plates):
    inputs, output = equation.split('->')
    inputs = inputs.split(',')
    operands = [torch.randn(torch.Size((2,) * len(input_)))
                for input_ in inputs]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # check forward pass
    for x in operands:
        require_backward(x)
    actual, = ubersum(equation, *operands, plates=plates, modulo_total=True,
                      backend='pyro.ops.einsum.torch_marginal')
    expected, = ubersum(equation, *operands, plates=plates, modulo_total=True,
                        backend='pyro.ops.einsum.torch_log')
    assert_equal(expected, actual)

    # check backward pass
    actual._pyro_backward()
    for input_, operand in zip(inputs, operands):
        marginal_equation = ','.join(inputs) + '->' + input_
        expected, = ubersum(marginal_equation, *operands, plates=plates, modulo_total=True,
                            backend='pyro.ops.einsum.torch_log')
        actual = operand._pyro_backward_result
        assert_equal(expected, actual)
Example 3
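A shape test for the adjoint backends through ubersum: after the forward-backward pass, an operand carries a _pyro_backward_result exactly when it has dimensions that are neither output dimensions nor plates.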
def test_adjoint_shape(backend, equation, plates):
    backend = "pyro.ops.einsum.torch_{}".format(backend)
    inputs, output = equation.split("->")
    inputs = inputs.split(",")
    operands = [
        torch.randn(torch.Size((2, ) * len(input_))) for input_ in inputs
    ]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # run forward-backward algorithm
    for x in operands:
        require_backward(x)
    (result, ) = ubersum(equation,
                         *operands,
                         plates=plates,
                         modulo_total=True,
                         backend=backend)
    result._pyro_backward()

    for input_, x in zip(inputs, operands):
        backward_result = x._pyro_backward_result
        contract_dims = set(input_) - set(output) - set(plates)
        if contract_dims:
            assert backward_result is not None
        else:
            assert backward_result is None
Example 4
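A compute_expectation method that uses the adjoint machinery to obtain marginal probabilities: per ordinal, it builds zero-valued query tensors matching each cost term's dimensions, marks them with require_backward, runs a single sum-product contraction with MarginalRing, reads the marginals back from _pyro_backward_result, and accumulates the scaled, probability-weighted costs.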
    def compute_expectation(self, costs):
        """
        Returns a differentiable expected cost, summing over costs at given ordinals.

        :param dict costs: A dict mapping ordinals to lists of cost tensors
        :returns: a scalar expected cost
        :rtype: torch.Tensor or float
        """
        # Share computation across all cost terms.
        with shared_intermediates() as cache:
            ring = MarginalRing(cache=cache)
            expected_cost = 0.
            for ordinal, cost_terms in costs.items():
                log_factors = self._get_log_factors(ordinal)
                scale = math.exp(sum(x for x in log_factors if not isinstance(x, torch.Tensor)))
                log_factors = [x for x in log_factors if isinstance(x, torch.Tensor)]

                # Collect log_prob terms to query for marginal probability.
                queries = {frozenset(cost._pyro_dims): None for cost in cost_terms}
                for log_factor in log_factors:
                    key = frozenset(log_factor._pyro_dims)
                    if queries.get(key, False) is None:
                        queries[key] = log_factor
                # Ensure a query exists for each cost term.
                for cost in cost_terms:
                    key = frozenset(cost._pyro_dims)
                    if queries[key] is None:
                        query = cost.new_zeros(cost.shape)
                        query._pyro_dims = cost._pyro_dims
                        log_factors.append(query)
                        queries[key] = query

                # Perform sum-product contraction. Note that plates never need to be
                # product-contracted due to our plate-based dependency ordering.
                sum_dims = set().union(*(x._pyro_dims for x in log_factors)) - ordinal
                for query in queries.values():
                    require_backward(query)
                root = ring.sumproduct(log_factors, sum_dims)
                root._pyro_backward()
                probs = {key: query._pyro_backward_result.exp() for key, query in queries.items()}

                # Aggregate prob * cost terms.
                for cost in cost_terms:
                    key = frozenset(cost._pyro_dims)
                    prob = probs[key]
                    prob._pyro_dims = queries[key]._pyro_dims
                    mask = prob > 0
                    if torch._C._get_tracing_state() or not mask.all():
                        mask._pyro_dims = prob._pyro_dims
                        cost, prob, mask = packed.broadcast_all(cost, prob, mask)
                        prob = prob[mask]
                        cost = cost[mask]
                    else:
                        cost, prob = packed.broadcast_all(cost, prob)
                    expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())

        LAST_CACHE_SIZE[0] = count_cached_ops(cache)
        return expected_cost
Example 5
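A _get_trace override that calls require_backward on the packed unscaled log-probability of every non-observed sample site, so that enumerated sites can later be gathered and the cond_indep_stack adjusted; the model, traces, and arguments are saved for later use.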
    def _get_trace(self, model, guide, args, kwargs):
        model_trace, guide_trace = super()._get_trace(model, guide, args,
                                                      kwargs)

        # Mark all sample sites with require_backward to gather enumerated
        # sites and adjust cond_indep_stack of all sample sites.
        for node in model_trace.nodes.values():
            if node["type"] == "sample" and not node["is_observed"]:
                log_prob = node["packed"]["unscaled_log_prob"]
                require_backward(log_prob)

        self._saved_state = model, model_trace, guide_trace, args, kwargs
        return model_trace, guide_trace
Example 6
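A forward/backward shape test across einsum backends: the forward result is compared against the torch_log reference (shape and dtype only for the map backend), and each operand's backward result is checked structurally. The marginal backend yields a tensor with the operand's shape, while the other backends yield a stack of sample indices, one slice per contracted dimension with values bounded by that dimension's size, or None when nothing is contracted.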
def test_shape(backend, equation):
    backend = "pyro.ops.einsum.torch_{}".format(backend)
    inputs, output = equation.split("->")
    inputs = inputs.split(",")
    symbols = sorted(set(equation) - set(",->"))
    sizes = dict(zip(symbols, itertools.count(2)))
    input_shapes = [torch.Size(sizes[dim] for dim in dims) for dims in inputs]
    operands = [torch.randn(shape) for shape in input_shapes]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # check forward pass
    for x in operands:
        require_backward(x)
    expected = contract(equation,
                        *operands,
                        backend="pyro.ops.einsum.torch_log")
    actual = contract(equation, *operands, backend=backend)
    if backend.endswith("map"):
        assert actual.dtype == expected.dtype
        assert actual.shape == expected.shape
    else:
        assert_equal(actual, expected)

    # check backward pass
    actual._pyro_backward()
    for input_, x in zip(inputs, operands):
        backward_result = x._pyro_backward_result
        if backend.endswith("marginal"):
            assert backward_result.shape == x.shape
        else:
            contract_dims = set(input_) - set(output)
            if contract_dims:
                assert backward_result.size(0) == len(contract_dims)
                assert set(backward_result._pyro_dims[1:]) == set(output)
                for sample, dim in zip(backward_result,
                                       backward_result._pyro_sample_dims):
                    assert sample.min() >= 0
                    assert sample.max() < sizes[dim]
            else:
                assert backward_result is None
Example 7
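A helper spelling out the full adjoint workflow in isolation: request results with require_backward on each operand, run the einsum forward pass, trigger _pyro_backward() on each output, then read the per-operand results.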
        def _forward_backward(*operands):
            # First we request backward results on each input operand.
            # This is the pyro.ops.adjoint equivalent of torch's .requires_grad_().
            for operand in operands:
                require_backward(operand)

            # Next we run the forward pass.
            results = einsum(equation, *operands, backend=backend, **kwargs)

            # Then we run a backward pass.
            for result in results:
                result._pyro_backward()

            # Finally we retrieve results from the ._pyro_backward_result attribute
            # that has been set on each input operand. If you only want results on a
            # subset of operands, you can call require_backward() on only those.
            results = []
            for x in operands:
                results.append(x._pyro_backward_result)
                x._pyro_backward_result = None

            return tuple(results)
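For concreteness, here is a minimal standalone driver for the same pattern, built on the ubersum call shown in Examples 2 and 3. The import paths (pyro.ops.contract.ubersum and pyro.ops.einsum.adjoint.require_backward) are assumptions, since the examples above do not show their imports, and the dimension names are illustrative only.

import torch

# Assumed import paths, inferred from how these utilities are used in the
# examples above; treat this as a sketch rather than an authoritative recipe.
from pyro.ops.contract import ubersum
from pyro.ops.einsum.adjoint import require_backward

# Two operands for the equation "ab,bc->c", with packed dimension names
# attached the same way the tests above attach them.
x = torch.randn(2, 2)
y = torch.randn(2, 2)
x._pyro_dims = "ab"
y._pyro_dims = "bc"

# Request adjoint results on both operands (the analogue of .requires_grad_()).
for operand in (x, y):
    require_backward(operand)

# Forward pass with the marginal backend, mirroring the keyword arguments used
# in the ubersum tests above (no plate dimensions here), then the backward pass.
result, = ubersum("ab,bc->c", x, y, plates="", modulo_total=True,
                  backend="pyro.ops.einsum.torch_marginal")
result._pyro_backward()

x_marginal = x._pyro_backward_result  # log-marginal over dims "ab"
y_marginal = y._pyro_backward_result  # log-marginal over dims "bc"

Note that the helper in Example 7 also resets each operand's _pyro_backward_result to None after reading it, presumably so that the same operands can be reused in a later forward-backward pass.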
Example 8
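_sample_posterior_from_trace marks each latent site's packed, masked log-probability with require_backward, contracts the log-probability tree as the forward pass (sharing work via shared_intermediates), runs _pyro_backward() as the backward pass, and then uses each site's _pyro_backward_result to gather enumerated values and prune the cond_indep_stack before replaying the model against the collapsed trace.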
def _sample_posterior_from_trace(model, enum_trace, temperature, *args,
                                 **kwargs):
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    sum_dims = set()
    queries = []
    dim_to_size = {}
    cost_terms = OrderedDict()
    enum_terms = OrderedDict()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized and f.size > 1)
            # For sites that depend on an enumerated variable, we need to apply
            # the mask but not the scale when sampling.
            if "masked_log_prob" not in node["packed"]:
                node["packed"]["masked_log_prob"] = packed.scale_and_mask(
                    node["packed"]["unscaled_log_prob"],
                    mask=node["packed"]["mask"])
            log_prob = node["packed"]["masked_log_prob"]
            sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal)
            if sum_dims.isdisjoint(log_prob._pyro_dims):
                continue
            dim_to_size.update(zip(log_prob._pyro_dims, log_prob.shape))
            if node["infer"].get("_enumerate_dim") is None:
                cost_terms.setdefault(ordinal, []).append(log_prob)
            else:
                enum_terms.setdefault(ordinal, []).append(log_prob)
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # We take special care to match the term ordering in
    # pyro.infer.traceenum_elbo._compute_model_factors() to allow
    # contract_tensor_tree() to use shared_intermediates() inside
    # TraceEnumSample_ELBO. The special ordering is: first all cost terms in
# order of model_trace, then all enum_terms in order of model_trace.
    log_probs = cost_terms
    for ordinal, terms in enum_terms.items():
        log_probs.setdefault(ordinal, []).extend(terms)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    cache = getattr(enum_trace, "_sharing_cache", {})
    ring = _make_ring(temperature, cache, dim_to_size)
    with shared_intermediates(cache):
        log_probs = contract_tensor_tree(log_probs, sum_dims,
                                         ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["masked_log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not (f.vectorized and f.size > 1)
                    or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"],
                                            node["infer"]["_dim_to_symbol"])
                    for index, dim in zip(jit_iter(sample),
                                          sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value,
                                                      enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)
Example 9
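A variant of the posterior-sampling routine, _sample_posterior, that builds the enumerated trace itself and then performs the same forward-backward, gather, and replay steps as Example 8, without the shared cache or the masked log-probabilities.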
def _sample_posterior(model, first_available_dim, temperature, *args,
                      **kwargs):
    # For internal use by infer_discrete.

    # Create an enumerated trace.
    with poutine.block(), EnumerateMessenger(first_available_dim):
        enum_trace = poutine.trace(model).get_trace(*args, **kwargs)
    enum_trace = prune_subsample_sites(enum_trace)
    enum_trace.compute_log_prob()
    enum_trace.pack_tensors()
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    log_probs = OrderedDict()
    sum_dims = set()
    queries = []
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized)
            log_prob = node["packed"]["log_prob"]
            log_probs.setdefault(ordinal, []).append(log_prob)
            sum_dims.update(log_prob._pyro_dims)
            for frame in node["cond_indep_stack"]:
                if frame.vectorized:
                    sum_dims.remove(plate_to_symbol[frame.name])
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    ring = _make_ring(temperature)
    log_probs = contract_tensor_tree(log_probs, sum_dims,
                                     ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not f.vectorized or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"],
                                            node["infer"]["_dim_to_symbol"])
                    for index, dim in zip(jit_iter(sample),
                                          sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value,
                                                      enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)