def test_partial_sharing(backend):
    eq = 'ab,bc,de->'
    x, y, z1 = helpers.build_views(eq)
    z2 = 2.0 * z1 - 1.0
    expr = contract_expression(eq, x.shape, y.shape, z1.shape)

    print('-' * 40)
    print('Without sharing:')
    num_exprs_nosharing = Counter()
    with shared_intermediates() as cache:
        expr(x, y, z1, backend=backend)
        num_exprs_nosharing.update(count_cached_ops(cache))
    with shared_intermediates() as cache:
        expr(x, y, z2, backend=backend)
        num_exprs_nosharing.update(count_cached_ops(cache))

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        expr(x, y, z1, backend=backend)
        expr(x, y, z2, backend=backend)
        num_exprs_sharing = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(num_exprs_nosharing))
    print('With sharing: {} expressions'.format(num_exprs_sharing))
    assert num_exprs_nosharing['einsum'] > num_exprs_sharing['einsum']

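# A minimal standalone sketch (illustrative, not part of the suite) of the
# partial-sharing behaviour exercised above, assuming the default NumPy backend
# and that `count_cached_ops` is importable from `opt_einsum.sharing` as in
# these tests. Only z changes between the two calls, so the x/y intermediate
# should be computed once and reused from the cache.
def _partial_sharing_demo():
    import numpy as np
    from opt_einsum import contract_expression, shared_intermediates
    from opt_einsum.sharing import count_cached_ops

    x, y = np.random.rand(4, 5), np.random.rand(5, 6)
    z1 = np.random.rand(2, 3)
    z2 = 2.0 * z1 - 1.0
    expr = contract_expression('ab,bc,de->', x.shape, y.shape, z1.shape)
    with shared_intermediates() as cache:
        expr(x, y, z1)
        expr(x, y, z2)  # should reuse the cached ab,bc intermediate
    return count_cached_ops(cache)
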
def test_no_sharing_separate_cache(backend):
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)
        expected.update(count_cached_ops(cache))  # we expect double

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache1:
        expr(*views, backend=backend)
        actual = count_cached_ops(cache1)
    with shared_intermediates() as cache2:
        expr(*views, backend=backend)
        actual.update(count_cached_ops(cache2))

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected

def test_sharing_modulo_commutativity(eq, backend):
    ops = helpers.build_views(eq)
    ops = [to_backend[backend](x) for x in ops]
    inputs, output, _ = parse_einsum_input([eq] + ops)
    inputs = inputs.split(',')

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        _einsum(eq, *ops, backend=backend)
        expected = count_cached_ops(cache)

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        for permuted in itertools.permutations(zip(inputs, ops)):
            permuted_inputs = [p[0] for p in permuted]
            permuted_ops = [p[1] for p in permuted]
            permuted_eq = '{}->{}'.format(','.join(permuted_inputs), output)
            _einsum(permuted_eq, *permuted_ops, backend=backend)
        actual = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected

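# A hedged standalone sketch of the property tested above: the sharing cache
# canonicalizes einsum calls modulo operand order, so the same contraction
# expressed with permuted inputs should hit a single cache entry. This assumes
# a NumPy backend and that `use_blas=False` keeps the contraction on the
# einsum path rather than dispatching to tensordot.
def _commutativity_demo():
    import numpy as np
    from opt_einsum import contract, shared_intermediates
    from opt_einsum.sharing import count_cached_ops

    a = np.random.rand(3, 4)
    b = np.random.rand(4, 5)
    with shared_intermediates() as cache:
        contract('ab,bc->ac', a, b, use_blas=False)
        contract('bc,ab->ac', b, a, use_blas=False)  # permuted operand order
    return count_cached_ops(cache)  # expect counts for a single einsum
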
def compute_expectation(self, costs):
    """
    Returns a differentiable expected cost, summing over costs at given ordinals.

    :param dict costs: A dict mapping ordinals to lists of cost tensors
    :returns: a scalar expected cost
    :rtype: torch.Tensor or float
    """
    # Share computation across all cost terms.
    with shared_intermediates() as cache:
        ring = MarginalRing(cache=cache)
        expected_cost = 0.
        for ordinal, cost_terms in costs.items():
            log_factors = self._get_log_factors(ordinal)
            scale = math.exp(sum(x for x in log_factors
                                 if not isinstance(x, torch.Tensor)))
            log_factors = [x for x in log_factors if isinstance(x, torch.Tensor)]

            # Collect log_prob terms to query for marginal probability.
            queries = {frozenset(cost._pyro_dims): None for cost in cost_terms}
            for log_factor in log_factors:
                key = frozenset(log_factor._pyro_dims)
                if queries.get(key, False) is None:
                    queries[key] = log_factor
            # Ensure a query exists for each cost term.
            for cost in cost_terms:
                key = frozenset(cost._pyro_dims)
                if queries[key] is None:
                    query = cost.new_zeros(cost.shape)
                    query._pyro_dims = cost._pyro_dims
                    log_factors.append(query)
                    queries[key] = query

            # Perform sum-product contraction. Note that plates never need to be
            # product-contracted due to our plate-based dependency ordering.
            sum_dims = set().union(*(x._pyro_dims for x in log_factors)) - ordinal
            for query in queries.values():
                require_backward(query)
            root = ring.sumproduct(log_factors, sum_dims)
            root._pyro_backward()
            probs = {key: query._pyro_backward_result.exp()
                     for key, query in queries.items()}

            # Aggregate prob * cost terms.
            for cost in cost_terms:
                key = frozenset(cost._pyro_dims)
                prob = probs[key]
                prob._pyro_dims = queries[key]._pyro_dims
                mask = prob > 0
                if torch._C._get_tracing_state() or not mask.all():
                    mask._pyro_dims = prob._pyro_dims
                    cost, prob, mask = packed.broadcast_all(cost, prob, mask)
                    prob = prob[mask]
                    cost = cost[mask]
                else:
                    cost, prob = packed.broadcast_all(cost, prob)
                expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())

    LAST_CACHE_SIZE[0] = count_cached_ops(cache)
    return expected_cost

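# A hedged illustration (not part of the method above) of its final reduction
# step: `torch.tensordot(prob, cost, prob.dim())` contracts over *all*
# dimensions of both tensors, so it equals the full inner product
# sum(prob * cost), i.e. the expected cost under the marginal probabilities.
def _tensordot_full_contraction_demo():
    import torch

    prob = torch.rand(2, 3)
    cost = torch.rand(2, 3)
    # Full contraction over all dims == elementwise product, then sum.
    assert torch.allclose(torch.tensordot(prob, cost, prob.dim()),
                          (prob * cost).sum())
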
def test_complete_sharing(backend):
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expr(*views, backend=backend)
        actual = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected

def test_sharing_reused_cache(backend):
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
    with shared_intermediates(cache):
        expr(*views, backend=backend)
    actual = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected

def _compute_cost(cache):
    counts = count_cached_ops(cache)
    return counts['einsum'] + counts['tensordot']

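# A hedged, self-contained usage sketch for `_compute_cost` above, assuming
# the default NumPy backend; `_demo_compute_cost` is illustrative and not part
# of the suite. The cost of a contraction is the number of cached einsum calls
# plus the number of cached tensordot calls it produced.
def _demo_compute_cost():
    import numpy as np
    from opt_einsum import contract, shared_intermediates

    a, b, c = (np.random.rand(4, 4) for _ in range(3))
    with shared_intermediates() as cache:
        contract('ab,bc,cd->ad', a, b, c)
    return _compute_cost(cache)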