Example 1
import torch

from pyro.ops.contract import _partition_terms
from pyro.ops.rings import LogRing


def test_partition_terms(inputs, dims, expected_num_components):
    ring = LogRing()
    symbol_to_size = dict(zip('abc', [2, 3, 4]))
    shapes = [tuple(symbol_to_size[s] for s in input_) for input_ in inputs]
    tensors = [torch.randn(shape) for shape in shapes]
    for input_, tensor in zip(inputs, tensors):
        tensor._pyro_dims = input_
    components = list(_partition_terms(ring, tensors, dims))

    # Check that result is a partition.
    expected_terms = sorted(tensors, key=id)
    actual_terms = sorted((x for c in components for x in c[0]), key=id)
    assert actual_terms == expected_terms
    assert dims == set.union(set(), *(c[1] for c in components))

    # Check that the partition is not too coarse.
    assert len(components) == expected_num_components

    # Check that partition is not too fine.
    component_dict = {x: i for i, (terms, _) in enumerate(components) for x in terms}
    for x in tensors:
        for y in tensors:
            if x is not y:
                if dims.intersection(x._pyro_dims, y._pyro_dims):
                    assert component_dict[x] == component_dict[y]
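A minimal sketch of how this test might be driven; the case table below is an illustrative assumption, not the library's actual test matrix. The decorator would sit directly above the ``def`` line:

import pytest

@pytest.mark.parametrize('inputs,dims,expected_num_components', [
    (['a'], set('a'), 1),          # a single term forms a single component
    (['a', 'b'], set('ab'), 2),    # no shared sum dims: two components
    (['ab', 'bc'], set('b'), 1),   # shared sum dim 'b' merges both terms
])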
Example 2
from collections import OrderedDict

from pyro.ops.rings import LogRing

# ``_partition_terms`` and ``_contract_component`` are private helpers
# defined alongside this function in ``pyro.ops.contract``.


def contract_tensor_tree(tensor_tree, sum_dims, cache=None, ring=None):
    """
    Contract out ``sum_dims`` in a tree of tensors via message passing.
    This partially contracts out plate dimensions.

    This function should be deterministic and free of side effects.

    :param OrderedDict tensor_tree: a dictionary mapping ordinals to lists of
        tensors. An ordinal is a frozenset of ``CondIndepStack`` frames.
    :param set sum_dims: the complete set of sum-contraction dimensions
        (indexed from the right). This is needed to distinguish sum-contraction
        dimensions from product-contraction dimensions.
    :param dict cache: an optional :func:`~opt_einsum.shared_intermediates`
        cache.
    :param pyro.ops.rings.Ring ring: an optional algebraic ring defining tensor
        operations.
    :returns: a contracted version of ``tensor_tree``
    :rtype: OrderedDict
    """
    assert isinstance(tensor_tree, OrderedDict)
    assert isinstance(sum_dims, set)

    if ring is None:
        ring = LogRing(cache)

    # Map each term back to its ordinal and flatten the tree into one list.
    ordinals = {term: t for t, terms in tensor_tree.items() for term in terms}
    all_terms = [term for terms in tensor_tree.values() for term in terms]
    contracted_tree = OrderedDict()

    # Split this tensor tree into connected components.
    for terms, dims in _partition_terms(ring, all_terms, sum_dims):
        component = OrderedDict()
        for term in terms:
            component.setdefault(ordinals[term], []).append(term)

        # Contract this connected component down to a single tensor.
        ordinal, term = _contract_component(ring, component, dims, set())
        contracted_tree.setdefault(ordinal, []).append(term)

    return contracted_tree
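A minimal usage sketch under assumed inputs: two log-space factors at the root (empty) ordinal sharing sum dimension 'a'; every name below is illustrative, not part of the library's API surface.

import torch
from collections import OrderedDict

x = torch.randn(2)       # log factor over sum dim 'a' (size 2)
x._pyro_dims = 'a'
y = torch.randn(2, 3)    # log factor over sum dims 'a' and 'b'
y._pyro_dims = 'ab'

tree = OrderedDict([(frozenset(), [x, y])])
result = contract_tensor_tree(tree, sum_dims={'a', 'b'})
# Under the default LogRing, result maps the root ordinal to a single
# tensor with no remaining sum dims: logsumexp_{a,b}(x[a] + y[a, b]).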
Example 3
from collections import OrderedDict

from pyro.ops.rings import LogRing

# ``_partition_terms``, ``_contract_component`` and
# ``_check_batch_dims_are_sensible`` are private helpers defined alongside
# this function in ``pyro.ops.contract``.


def contract_to_tensor(tensor_tree,
                       sum_dims,
                       target_ordinal=None,
                       target_dims=None,
                       cache=None,
                       ring=None):
    """
    Contract out ``sum_dims`` in a tree of tensors via message
    passing. This reduces all terms down to a single tensor in the plate
    context specified by ``target_ordinal``, optionally preserving sum
    dimensions ``target_dims``.

    This function should be deterministic and free of side effects.

    :param OrderedDict tensor_tree: a dictionary mapping ordinals to lists of
        tensors. An ordinal is a frozenset of ``CondIndepStack`` frames.
    :param set sum_dims: the complete set of sum-contraction dimensions
        (indexed from the right). This is needed to distinguish sum-contraction
        dimensions from product-contraction dimensions.
    :param frozenset target_ordinal: an optional ordinal to which the result
        will be contracted or broadcasted.
    :param set target_dims: an optional subset of ``sum_dims`` that should be
        preserved in the result.
    :param dict cache: an optional :func:`~opt_einsum.shared_intermediates`
        cache.
    :param pyro.ops.rings.Ring ring: an optional algebraic ring defining tensor
        operations.
    :returns: a single tensor
    :rtype: torch.Tensor
    """
    if target_ordinal is None:
        target_ordinal = frozenset()
    if target_dims is None:
        target_dims = set()
    assert isinstance(tensor_tree, OrderedDict)
    assert isinstance(sum_dims, set)
    assert isinstance(target_ordinal, frozenset)
    assert isinstance(target_dims, set) and target_dims <= sum_dims
    if ring is None:
        ring = LogRing(cache)

    # Map each term back to its ordinal and flatten the tree into one list.
    ordinals = {term: t for t, terms in tensor_tree.items() for term in terms}
    all_terms = [term for terms in tensor_tree.values() for term in terms]
    contracted_terms = []

    # Split this tensor tree into connected components.
    # With a nonempty query (target_dims), components that do not touch the
    # queried dims contribute only a constant total and can be skipped.
    modulo_total = bool(target_dims)
    for terms, dims in _partition_terms(ring, all_terms, sum_dims):
        if modulo_total and dims.isdisjoint(target_dims):
            continue
        component = OrderedDict()
        for term in terms:
            component.setdefault(ordinals[term], []).append(term)

        # Contract this connected component down to a single tensor.
        ordinal, term = _contract_component(ring, component, dims,
                                            target_dims & dims)
        _check_batch_dims_are_sensible(
            target_dims.intersection(term._pyro_dims),
            ordinal - target_ordinal)

        # Eliminate extra plate dims via product contractions.
        contract_frames = ordinal - target_ordinal
        if contract_frames:
            assert not sum_dims.intersection(term._pyro_dims)
            term = ring.product(term, contract_frames)

        contracted_terms.append(term)

    # Combine contracted tensors via product, then broadcast.
    term = ring.sumproduct(contracted_terms, set())
    assert sum_dims.intersection(term._pyro_dims) <= target_dims
    return ring.broadcast(term, target_ordinal)
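A follow-on sketch reusing ``tree``, ``x`` and ``y`` from the sketch after Example 2 (all assumed names): query a marginal over sum dimension 'a' while contracting out 'b'.

marginal = contract_to_tensor(tree, sum_dims={'a', 'b'}, target_dims={'a'})
# marginal keeps only sum dim 'a'; in log space
# marginal[a] = log sum_b exp(x[a] + y[a, b]) = (x + y.logsumexp(-1))[a]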