Example #1
    def compute_expectation(self, costs):
        """
        Returns a differentiable expected cost, summing over costs at given ordinals.

        :param dict costs: A dict mapping ordinals to lists of cost tensors
        :returns: a scalar expected cost
        :rtype: torch.Tensor or float
        """
        # Share computation across all cost terms.
        with shared_intermediates() as cache:
            ring = MarginalRing(cache=cache)
            expected_cost = 0.
            for ordinal, cost_terms in costs.items():
                log_factors = self._get_log_factors(ordinal)
                scale = math.exp(sum(x for x in log_factors if not isinstance(x, torch.Tensor)))
                log_factors = [x for x in log_factors if isinstance(x, torch.Tensor)]

                # Collect log_prob terms to query for marginal probability.
                queries = {frozenset(cost._pyro_dims): None for cost in cost_terms}
                for log_factor in log_factors:
                    key = frozenset(log_factor._pyro_dims)
                    if queries.get(key, False) is None:
                        queries[key] = log_factor
                # Ensure a query exists for each cost term.
                for cost in cost_terms:
                    key = frozenset(cost._pyro_dims)
                    if queries[key] is None:
                        query = cost.new_zeros(cost.shape)
                        query._pyro_dims = cost._pyro_dims
                        log_factors.append(query)
                        queries[key] = query

                # Perform sum-product contraction. Note that plates never need to be
                # product-contracted due to our plate-based dependency ordering.
                sum_dims = set().union(*(x._pyro_dims for x in log_factors)) - ordinal
                for query in queries.values():
                    require_backward(query)
                root = ring.sumproduct(log_factors, sum_dims)
                root._pyro_backward()
                probs = {key: query._pyro_backward_result.exp() for key, query in queries.items()}

                # Aggregate prob * cost terms.
                for cost in cost_terms:
                    key = frozenset(cost._pyro_dims)
                    prob = probs[key]
                    prob._pyro_dims = queries[key]._pyro_dims
                    mask = prob > 0
                    if torch._C._get_tracing_state() or not mask.all():
                        mask._pyro_dims = prob._pyro_dims
                        cost, prob, mask = packed.broadcast_all(cost, prob, mask)
                        prob = prob[mask]
                        cost = cost[mask]
                    else:
                        cost, prob = packed.broadcast_all(cost, prob)
                    expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())

        LAST_CACHE_SIZE[0] = count_cached_ops(cache)
        return expected_cost
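The final reduction above hinges on ``torch.tensordot(prob, cost, prob.dim())``: contracting over all of ``prob``'s dimensions collapses the product to the scalar sum of ``prob * cost``. A minimal standalone sketch of that pattern with toy tensors (an illustration, not Pyro code):

import torch

# Contracting all dimensions with tensordot gives the scalar sum(prob * cost),
# which is how each prob/cost pair above contributes to the expected cost.
prob = torch.tensor([[0.1, 0.4], [0.3, 0.2]])
cost = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
expected = torch.tensordot(prob, cost, prob.dim())
assert torch.allclose(expected, (prob * cost).sum())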
Example #2
def einsum(equation, *operands):
    """
    Forward-max-sum backward-argmax implementation of einsum.
    This assumes all operands have a ``._pyro_dims`` attribute set.
    """
    equation = packed.rename_equation(equation, *operands)
    inputs, output = equation.split('->')
    any_requires_backward = any(hasattr(x, '_pyro_backward') for x in operands)

    contract_dims = ''.join(
        sorted(set().union(*(x._pyro_dims for x in operands)) - set(output)))
    dims = output + contract_dims
    result = reduce(operator.add, packed.broadcast_all(*operands, dims=dims))
    argmax = None  # work around lack of pytorch support for zero-sized tensors
    if contract_dims:
        output_shape = result.shape[:len(output)]
        contract_shape = result.shape[len(output):]
        result, argmax = result.reshape(output_shape + (-1, )).max(-1)
        if any_requires_backward:
            argmax = unflatten(argmax, output, contract_dims, contract_shape)
    elif result is operands[0]:
        result = result[...]  # create a new object
    result._pyro_dims = output
    assert result.dim() == len(result._pyro_dims)

    if any_requires_backward:
        result._pyro_backward = _EinsumBackward(operands, argmax)
    return result
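The core of this max-sum einsum is the flatten-then-max step: the contracted dimensions are reshaped into a single trailing axis, ``max(-1)`` returns both the values and the flat argmax, and the flat indices are then unraveled back into one index per contracted dimension (the job of ``unflatten`` above). A toy sketch of the pattern in plain PyTorch, with the unraveling written out by hand:

import torch

# One output dim of size 2, two contracted dims of sizes 3 and 4.
x = torch.randn(2, 3, 4)
flat = x.reshape(2, -1)                 # flatten the contracted dims
values, flat_argmax = flat.max(-1)      # max and flat argmax per output element
idx0 = flat_argmax // 4                 # unravel: index into the size-3 dim
idx1 = flat_argmax % 4                  # unravel: index into the size-4 dim
assert torch.equal(values, x[torch.arange(2), idx0, idx1])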
Example #3
def compute_site_dice_factor(site):
    log_denom = 0
    log_prob = site["packed"][
        "score_parts"].score_function  # not scaled by subsampling
    dims = getattr(log_prob, "_pyro_dims", "")
    if site["infer"].get("enumerate"):
        num_samples = site["infer"].get("num_samples")
        if num_samples is not None:  # site was multiply sampled
            if not is_identically_zero(log_prob):
                log_prob = log_prob - log_prob.detach()
            log_prob = log_prob - math.log(num_samples)
            if not isinstance(log_prob, torch.Tensor):
                log_prob = torch.tensor(float(log_prob),
                                        device=site["value"].device)
            log_prob._pyro_dims = dims
            # I don't know why the following broadcast is needed, but it makes tests pass:
            log_prob, _ = packed.broadcast_all(log_prob,
                                               site["packed"]["log_prob"])
        elif site["infer"]["enumerate"] == "sequential":
            log_denom = math.log(site["infer"].get("_enum_total", num_samples))
    else:  # site was monte carlo sampled
        if not is_identically_zero(log_prob):
            log_prob = log_prob - log_prob.detach()
            log_prob._pyro_dims = dims

    return log_prob, log_denom
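The ``log_prob - log_prob.detach()`` line is the score-function (DiCE) trick: the result is identically zero in the forward pass but still carries the gradient of ``log_prob``. A minimal standalone illustration:

import torch

# Zero-valued forward, but gradients flow through log_prob, as a
# score-function estimator needs.
log_prob = torch.tensor(-1.3, requires_grad=True)
zeroed = log_prob - log_prob.detach()
assert zeroed.item() == 0.0
(zeroed * 5.0).backward()               # stand-in for a downstream cost of 5.0
assert log_prob.grad.item() == 5.0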
Example #4
    def __init__(self, guide_trace, ordering):
        log_denom = defaultdict(float)  # avoids double-counting when sequentially enumerating
        log_probs = defaultdict(list)  # accounts for upstream probabilities

        for name, site in guide_trace.nodes.items():
            if site["type"] != "sample":
                continue

            log_prob = site["packed"]["score_parts"].score_function  # not scaled by subsampling
            dims = getattr(log_prob, "_pyro_dims", "")
            ordinal = ordering[name]
            if site["infer"].get("enumerate"):
                num_samples = site["infer"].get("num_samples")
                if num_samples is not None:  # site was multiply sampled
                    if not is_identically_zero(log_prob):
                        log_prob = log_prob - log_prob.detach()
                    log_prob = log_prob - math.log(num_samples)
                    if not isinstance(log_prob, torch.Tensor):
                        log_prob = site["value"].new_tensor(log_prob)
                    log_prob._pyro_dims = dims
                    # I don't know why the following broadcast is needed, but it makes tests pass:
                    log_prob, _ = packed.broadcast_all(log_prob, site["packed"]["log_prob"])
                elif site["infer"]["enumerate"] == "sequential":
                    log_denom[ordinal] += math.log(site["infer"]["_enum_total"])
            else:  # site was monte carlo sampled
                if is_identically_zero(log_prob):
                    continue
                log_prob = log_prob - log_prob.detach()
                log_prob._pyro_dims = dims
            log_probs[ordinal].append(log_prob)

        self.log_denom = log_denom
        self.log_probs = log_probs
Example #5
    def process(self, message):
        output = self.output
        operands = list(self.operands)
        contract_dims = ''.join(sorted(set().union(*(x._pyro_dims for x in operands)) - set(output)))
        batch_dims = output

        # Slice down operands before combining terms.
        sample2 = message
        if sample2 is not None:
            for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
                batch_dims = batch_dims.replace(dim, '')
                for i, x in enumerate(operands):
                    if dim in x._pyro_dims:
                        index._pyro_dims = sample2._pyro_dims[1:]
                        x = packed.gather(x, index, dim)
                    operands[i] = x

        # Combine terms.
        dims = batch_dims + contract_dims
        logits = reduce(operator.add, packed.broadcast_all(*operands, dims=dims))

        # Sample.
        sample1 = None  # work around lack of pytorch support for zero-sized tensors
        if contract_dims:
            output_shape = logits.shape[:len(batch_dims)]
            contract_shape = logits.shape[len(batch_dims):]
            flat_logits = logits.reshape(output_shape + (-1,))
            flat_sample = dist.Categorical(logits=flat_logits).sample()
            sample1 = unflatten(flat_sample, batch_dims, contract_dims, contract_shape)

        # Cut down samples to pass on to subsequent steps.
        return einsum_backward_sample(self.operands, sample1, sample2)
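The sampling step here mirrors the max step in Example #2: flattening the contracted dims lets a single ``Categorical`` draw act as a joint sample over all of them at once, with ``unflatten`` recovering per-dimension indices afterwards. A toy sketch of just the joint draw (assumed shapes, not the code above):

import torch
import torch.distributions as dist

# Batch dim of size 2, two contracted dims of sizes 3 and 4.
logits = torch.randn(2, 3, 4)
flat_logits = logits.reshape(2, -1)                           # shape (2, 12)
flat_sample = dist.Categorical(logits=flat_logits).sample()
assert flat_sample.shape == (2,)                              # one joint draw per batch element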
Example #6
def test_broadcast_all(shapes):
    inputs, dim_to_symbol, symbol_to_dim = make_inputs(shapes)
    packed_inputs = [packed.pack(x, dim_to_symbol) for x in inputs]
    packed_outputs = packed.broadcast_all(*packed_inputs)
    actual = tuple(packed.unpack(x, symbol_to_dim) for x in packed_outputs)
    expected = broadcast_all(*inputs) if inputs else []
    assert len(actual) == len(expected)
    for a, e in zip(actual, expected):
        assert_equal(a, e)
Example #7
def einsum_backward_sample(operands, sample1, sample2):
    """
    Cuts down samples to pass on to subsequent steps.
    This is used in various ``_EinsumBackward.__call__()`` methods.
    This assumes all operands have a ``._pyro_dims`` attribute set.
    """
    # Combine upstream sample with sample at this site.
    if sample1 is None:
        sample = sample2
    elif sample2 is None:
        sample = sample1
    else:
        # Slice sample1 down based on choices in sample2.
        assert set(sample1._pyro_sample_dims).isdisjoint(
            sample2._pyro_sample_dims)
        sample_dims = sample1._pyro_sample_dims + sample2._pyro_sample_dims
        for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
            if dim in sample1._pyro_dims:
                index._pyro_dims = sample2._pyro_dims[1:]
                sample1 = packed.gather(sample1, index, dim)

        # Concatenate the two samples.
        parts = packed.broadcast_all(sample1, sample2)
        sample = torch.cat(parts)
        sample._pyro_dims = parts[0]._pyro_dims
        sample._pyro_sample_dims = sample_dims
        assert sample.dim() == len(sample._pyro_dims)
        if not torch._C._get_tracing_state():
            assert sample.size(0) == len(sample._pyro_sample_dims)

    # Select sample dimensions to pass on to downstream sites.
    for x in operands:
        if not hasattr(x, '_pyro_backward'):
            continue
        if sample is None:
            yield x._pyro_backward, None
            continue
        x_sample_dims = set(x._pyro_dims) & set(sample._pyro_sample_dims)
        if not x_sample_dims:
            yield x._pyro_backward, None
            continue
        if x_sample_dims == set(sample._pyro_sample_dims):
            yield x._pyro_backward, sample
            continue
        x_sample_dims = ''.join(sorted(x_sample_dims))
        x_sample = sample[[
            sample._pyro_sample_dims.index(dim) for dim in x_sample_dims
        ]]
        x_sample._pyro_dims = sample._pyro_dims
        x_sample._pyro_sample_dims = x_sample_dims
        assert x_sample.dim() == len(x_sample._pyro_dims)
        if not torch._C._get_tracing_state():
            assert x_sample.size(0) == len(x_sample._pyro_sample_dims)
        yield x._pyro_backward, x_sample
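In the final loop, the leading dimension of ``sample`` enumerates the sampled dims, so passing only the dims an operand actually uses is a plain fancy-indexing pick of the matching rows. A toy sketch with hypothetical dim names:

import torch

# Hypothetical packed sample: one row of indices per sampled dim 'a', 'b', 'c'.
sample_dims = "abc"
sample = torch.tensor([[0, 1], [2, 3], [4, 5]])
keep = "ac"                                   # dims the downstream operand uses
rows = [sample_dims.index(d) for d in keep]
sub_sample = sample[rows]                     # rows for 'a' and 'c' only
assert torch.equal(sub_sample, torch.tensor([[0, 1], [4, 5]]))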