def test_partition_terms(inputs, dims, expected_num_components):
    """
    Check that ``_partition_terms`` splits tensors into the expected components.

    A random tensor is created for each entry of ``inputs`` and tagged with
    that entry as its ``_pyro_dims``; dimension sizes are ``a=2``, ``b=3``,
    ``c=4``.

    :param inputs: a sequence whose entries are iterables of symbols drawn
        from ``'abc'`` (e.g. strings), one entry per test tensor.
    :param set dims: the dimension symbols passed to ``_partition_terms``.
    :param int expected_num_components: the number of components that
        ``_partition_terms`` is expected to yield.
    """
    ring = LogRing()
    symbol_to_size = dict(zip('abc', [2, 3, 4]))
    shapes = [tuple(symbol_to_size[s] for s in input_) for input_ in inputs]
    tensors = [torch.randn(shape) for shape in shapes]
    for input_, tensor in zip(inputs, tensors):
        tensor._pyro_dims = input_

    components = list(_partition_terms(ring, tensors, dims))

    # Check that result is a partition.
    # Compare object identities rather than the tensors themselves: list
    # equality on non-identical tensors falls back to elementwise
    # ``Tensor.__eq__``, whose truth value raises instead of yielding a
    # clean assertion failure.
    expected_ids = sorted(id(x) for x in tensors)
    actual_ids = sorted(id(x) for terms, _ in components for x in terms)
    assert actual_ids == expected_ids
    assert dims == set.union(set(), *(c[1] for c in components))

    # Check that the partition is not too coarse.
    assert len(components) == expected_num_components

    # Check that partition is not too fine: any two distinct tensors that
    # share one of ``dims`` must land in the same component.
    component_index = {id(x): i
                       for i, (terms, _) in enumerate(components)
                       for x in terms}
    for x in tensors:
        for y in tensors:
            if x is not y and dims.intersection(x._pyro_dims, y._pyro_dims):
                assert component_index[id(x)] == component_index[id(y)]
def contract_tensor_tree(tensor_tree, sum_dims, cache=None, ring=None):
    """
    Sum out ``sum_dims`` from a tree of tensors using message passing.

    Plate dimensions are only partially contracted. The terms are split into
    connected components, and each component is contracted down to a single
    tensor that is filed under the ordinal returned for that component.

    This function should be deterministic and free of side effects.

    :param OrderedDict tensor_tree: a dictionary mapping ordinals to lists of
        tensors. An ordinal is a frozenset of ``CondIndepStack`` frames.
    :param set sum_dims: the complete set of sum-contractions dimensions
        (indexed from the right). This is needed to distinguish
        sum-contraction dimensions from product-contraction dimensions.
    :param dict cache: an optional :func:`~opt_einsum.shared_intermediates`
        cache. Only used to build the default ring when ``ring`` is None.
    :param pyro.ops.rings.Ring ring: an optional algebraic ring defining tensor
        operations. Defaults to ``LogRing(cache)``.
    :returns: A contracted version of ``tensor_tree``
    :rtype: OrderedDict
    """
    assert isinstance(tensor_tree, OrderedDict)
    assert isinstance(sum_dims, set)
    if ring is None:
        ring = LogRing(cache)

    # Flatten the tree while remembering which ordinal each tensor came from.
    term_to_ordinal = {}
    flat_terms = []
    for tree_ordinal, tree_terms in tensor_tree.items():
        for term in tree_terms:
            term_to_ordinal[term] = tree_ordinal
            flat_terms.append(term)

    result = OrderedDict()
    # Handle each connected component of the tree independently.
    for group, group_dims in _partition_terms(ring, flat_terms, sum_dims):
        # Rebuild a per-ordinal subtree holding only this component's terms.
        subtree = OrderedDict()
        for term in group:
            subtree.setdefault(term_to_ordinal[term], []).append(term)

        # Collapse the component to one tensor; no sum dims are preserved.
        root_ordinal, contracted = _contract_component(
            ring, subtree, group_dims, set())
        result.setdefault(root_ordinal, []).append(contracted)

    return result
def contract_to_tensor(tensor_tree, sum_dims, target_ordinal=None, target_dims=None,
                       cache=None, ring=None):
    """
    Sum out ``sum_dims`` from a tree of tensors using message passing,
    reducing every term to one tensor in the plate context given by
    ``target_ordinal``. Sum dimensions listed in ``target_dims`` are
    optionally preserved in the result.

    When ``target_dims`` is non-empty, connected components that involve none
    of the ``target_dims`` are skipped and do not contribute to the result.

    This function should be deterministic and free of side effects.

    :param OrderedDict tensor_tree: a dictionary mapping ordinals to lists of
        tensors. An ordinal is a frozenset of ``CondIndepStack`` frames.
    :param set sum_dims: the complete set of sum-contractions dimensions
        (indexed from the right). This is needed to distinguish
        sum-contraction dimensions from product-contraction dimensions.
    :param frozenset target_ordinal: An optional ordinal to which the result
        will be contracted or broadcasted. Defaults to the empty frozenset.
    :param set target_dims: An optional subset of ``sum_dims`` that should be
        preserved in the result. Defaults to the empty set.
    :param dict cache: an optional :func:`~opt_einsum.shared_intermediates`
        cache. Only used to build the default ring when ``ring`` is None.
    :param pyro.ops.rings.Ring ring: an optional algebraic ring defining tensor
        operations. Defaults to ``LogRing(cache)``.
    :returns: a single tensor
    :rtype: torch.Tensor
    """
    target_ordinal = frozenset() if target_ordinal is None else target_ordinal
    target_dims = set() if target_dims is None else target_dims
    assert isinstance(tensor_tree, OrderedDict)
    assert isinstance(sum_dims, set)
    assert isinstance(target_ordinal, frozenset)
    assert isinstance(target_dims, set) and target_dims <= sum_dims
    if ring is None:
        ring = LogRing(cache)

    # Flatten the tree while remembering which ordinal each tensor came from.
    term_to_ordinal = {}
    flat_terms = []
    for tree_ordinal, tree_terms in tensor_tree.items():
        for term in tree_terms:
            term_to_ordinal[term] = tree_ordinal
            flat_terms.append(term)

    # Components unrelated to the requested dims are dropped, but only when
    # some dims were actually requested.
    restrict_to_targets = bool(target_dims)
    factors = []

    # Handle each connected component of the tree independently.
    for group, group_dims in _partition_terms(ring, flat_terms, sum_dims):
        if restrict_to_targets and group_dims.isdisjoint(target_dims):
            continue

        # Rebuild a per-ordinal subtree holding only this component's terms.
        subtree = OrderedDict()
        for term in group:
            subtree.setdefault(term_to_ordinal[term], []).append(term)

        # Collapse the component to one tensor, keeping its target dims.
        root_ordinal, contracted = _contract_component(
            ring, subtree, group_dims, target_dims & group_dims)
        kept_dims = target_dims.intersection(contracted._pyro_dims)
        extra_frames = root_ordinal - target_ordinal
        _check_batch_dims_are_sensible(kept_dims, extra_frames)

        # Remove plate frames outside the target ordinal by taking products.
        if extra_frames:
            assert not sum_dims.intersection(contracted._pyro_dims)
            contracted = ring.product(contracted, extra_frames)

        factors.append(contracted)

    # Multiply the per-component results together, then broadcast.
    result = ring.sumproduct(factors, set())
    assert sum_dims.intersection(result._pyro_dims) <= target_dims
    return ring.broadcast(result, target_ordinal)