Ejemplo n.º 1
0
def test_partial_sharing(backend):
    """Check that one shared cache records fewer einsum ops than two caches.

    The same contraction expression is evaluated on ``(x, y, z1)`` and on
    ``(x, y, z2)``, where ``z2`` is an affine transform of ``z1``.  The op
    counts gathered from two independent caches are accumulated and compared
    against the counts gathered from a single cache covering both calls.

    :param backend: forwarded unchanged as the ``backend`` keyword of the
        contraction expression.
    """
    eq = 'ab,bc,de->'
    x, y, z1 = helpers.build_views(eq)
    z2 = 2.0 * z1 - 1.0
    expr = contract_expression(eq, x.shape, y.shape, z1.shape)
    divider = '-' * 40

    print(divider)
    print('Without sharing:')
    # Each variant of the last operand is evaluated under its own cache and
    # the per-cache op counts are summed into one Counter.
    separate_counts = Counter()
    for last_operand in (z1, z2):
        with shared_intermediates() as cache:
            expr(x, y, last_operand, backend=backend)
            separate_counts.update(count_cached_ops(cache))

    print(divider)
    print('With sharing:')
    # Both variants are evaluated under one and the same cache.
    with shared_intermediates() as cache:
        for last_operand in (z1, z2):
            expr(x, y, last_operand, backend=backend)
        shared_counts = count_cached_ops(cache)

    print(divider)
    print('Without sharing: {} expressions'.format(separate_counts))
    print('With sharing: {} expressions'.format(shared_counts))
    assert separate_counts['einsum'] > shared_counts['einsum']
Ejemplo n.º 2
0
def test_no_sharing_separate_cache(backend):
    """Check that two separate caches together record double the ops of one.

    The expected counts are built by evaluating the expression once and
    adding that cache's op counts to themselves.  The actual counts come
    from evaluating the expression twice, each time under a fresh cache,
    and merging the two per-cache counts.

    :param backend: forwarded unchanged as the ``backend`` keyword of the
        contraction expression.
    """
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    shapes = (v.shape for v in views)
    expr = contract_expression(eq, *shapes)
    divider = '-' * 40

    print(divider)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        doubled = count_cached_ops(cache)
        # Merge the same cache's counts in a second time: we expect double.
        doubled.update(count_cached_ops(cache))

    print(divider)
    print('With sharing:')
    # Two evaluations, each one under its own freshly created cache.
    per_cache_counts = []
    for _ in range(2):
        with shared_intermediates() as fresh_cache:
            expr(*views, backend=backend)
            per_cache_counts.append(count_cached_ops(fresh_cache))
    combined, second_run = per_cache_counts
    combined.update(second_run)

    print(divider)
    print('Without sharing: {} expressions'.format(doubled))
    print('With sharing: {} expressions'.format(combined))
    assert combined == doubled
Ejemplo n.º 3
0
def test_sharing_modulo_commutativity(eq, backend):
    """Check that permuting the operands adds no new ops to a shared cache.

    The equation is first evaluated once under its own cache to obtain the
    reference op counts.  Then every permutation of the (subscript, operand)
    pairs is evaluated under a single cache, and the resulting op counts
    must equal the reference.

    :param eq: einsum equation used to build the operands.
    :param backend: key into ``to_backend`` and the ``backend`` keyword of
        every ``_einsum`` call.
    """
    ops = [to_backend[backend](view) for view in helpers.build_views(eq)]
    inputs, output, _ = parse_einsum_input([eq] + ops)
    # Keep each subscript term paired with its operand so they move together.
    pairs = list(zip(inputs.split(','), ops))
    divider = '-' * 40

    print(divider)
    print('Without sharing:')
    with shared_intermediates() as cache:
        _einsum(eq, *ops, backend=backend)
        reference = count_cached_ops(cache)

    print(divider)
    print('With sharing:')
    with shared_intermediates() as cache:
        for arrangement in itertools.permutations(pairs):
            subscripts = ','.join(term for term, _ in arrangement)
            tensors = [tensor for _, tensor in arrangement]
            permuted_eq = '{}->{}'.format(subscripts, output)
            _einsum(permuted_eq, *tensors, backend=backend)
        observed = count_cached_ops(cache)

    print(divider)
    print('Without sharing: {} expressions'.format(reference))
    print('With sharing: {} expressions'.format(observed))
    assert observed == reference
Ejemplo n.º 4
0
def test_partial_sharing(backend):
    """Compare einsum op counts with and without a cache spanning two calls.

    Two operand sets that differ only in the last operand (``z1`` versus
    ``z2 = 2.0 * z1 - 1.0``) are contracted with the same expression.  The
    test asserts that the accumulated 'einsum' count from two independent
    caches is strictly larger than the count from one common cache.

    :param backend: passed through as the ``backend`` keyword when the
        expression is called.
    """
    eq = 'ab,bc,de->'
    x, y, z1 = helpers.build_views(eq)
    z2 = 2.0 * z1 - 1.0
    expr = contract_expression(eq, x.shape, y.shape, z1.shape)
    rule = '-' * 40
    operand_sets = ((x, y, z1), (x, y, z2))

    print(rule)
    print('Without sharing:')
    # A new cache per operand set; counts are accumulated across caches.
    unshared_total = Counter()
    for operands in operand_sets:
        with shared_intermediates() as cache:
            expr(*operands, backend=backend)
            unshared_total.update(count_cached_ops(cache))

    print(rule)
    print('With sharing:')
    # A single cache that stays active for both operand sets.
    with shared_intermediates() as cache:
        for operands in operand_sets:
            expr(*operands, backend=backend)
        shared_total = count_cached_ops(cache)

    print(rule)
    print('Without sharing: {} expressions'.format(unshared_total))
    print('With sharing: {} expressions'.format(shared_total))
    assert unshared_total['einsum'] > shared_total['einsum']
Ejemplo n.º 5
0
def test_sharing_modulo_commutativity(eq, backend):
    """Verify that operand order does not change the ops held in one cache.

    A single evaluation of ``eq`` under its own cache gives the baseline op
    counts.  All permutations of the inputs (subscripts and operands permuted
    together) are then evaluated under one common cache, whose op counts are
    asserted to equal the baseline.

    :param eq: einsum equation from which the operands are built.
    :param backend: selects the converter in ``to_backend`` and is passed as
        the ``backend`` keyword to ``_einsum``.
    """
    raw_views = helpers.build_views(eq)
    ops = [to_backend[backend](view) for view in raw_views]
    inputs, output, _ = parse_einsum_input([eq] + ops)
    terms = inputs.split(',')
    rule = '-' * 40

    print(rule)
    print('Without sharing:')
    with shared_intermediates() as cache:
        _einsum(eq, *ops, backend=backend)
        baseline = count_cached_ops(cache)

    print(rule)
    print('With sharing:')
    with shared_intermediates() as cache:
        # Permute (subscript, operand) pairs so each operand keeps its term.
        for ordering in itertools.permutations(zip(terms, ops)):
            ordered_terms = [pair[0] for pair in ordering]
            ordered_ops = [pair[1] for pair in ordering]
            lhs = ','.join(ordered_terms)
            _einsum('{}->{}'.format(lhs, output), *ordered_ops, backend=backend)
        measured = count_cached_ops(cache)

    print(rule)
    print('Without sharing: {} expressions'.format(baseline))
    print('With sharing: {} expressions'.format(measured))
    assert measured == baseline
Ejemplo n.º 6
0
def test_no_sharing_separate_cache(backend):
    """Verify that distinct caches do not share work between evaluations.

    Expected counts: one evaluation's op counts merged with themselves.
    Actual counts: the merged op counts of two evaluations, each performed
    under a separately created cache.  The two must be equal.

    :param backend: passed through as the ``backend`` keyword when the
        expression is called.
    """
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(view.shape for view in views))
    rule = '-' * 40

    print(rule)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        single_run = count_cached_ops(cache)
        repeat_of_single_run = count_cached_ops(cache)
        single_run.update(repeat_of_single_run)  # we expect double
    expected = single_run

    print(rule)
    print('With sharing:')
    with shared_intermediates() as first_cache:
        expr(*views, backend=backend)
        actual = count_cached_ops(first_cache)
    with shared_intermediates() as second_cache:
        expr(*views, backend=backend)
        second_counts = count_cached_ops(second_cache)
        actual.update(second_counts)

    print(rule)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected
Ejemplo n.º 7
0
    def compute_expectation(self, costs):
        """
        Returns a differentiable expected cost, summing over costs at given ordinals.

        For every ordinal, the log factors from ``self._get_log_factors(ordinal)``
        are contracted with a single :class:`MarginalRing`, one probability
        tensor is read back per distinct set of cost dims, and
        ``scale * tensordot(prob, cost)`` is accumulated over all cost terms.
        All ordinals are processed under one ``shared_intermediates`` cache.

        As a side effect, ``LAST_CACHE_SIZE[0]`` is overwritten with the
        result of ``count_cached_ops`` on that cache.

        :param dict costs: A dict mapping ordinals to lists of cost tensors.
            Each cost tensor must carry a ``_pyro_dims`` attribute; each
            ordinal is subtracted from a set of dims, so it must be set-like.
        :returns: a scalar expected cost
        :rtype: torch.Tensor or float
        """
        # Share computation across all cost terms.
        with shared_intermediates() as cache:
            ring = MarginalRing(cache=cache)
            # Starts as a Python float; stays a float if there are no cost terms.
            expected_cost = 0.
            for ordinal, cost_terms in costs.items():
                log_factors = self._get_log_factors(ordinal)
                # Non-tensor log factors are summed and exponentiated into a
                # single scalar multiplier; only tensor factors are contracted.
                scale = math.exp(sum(x for x in log_factors if not isinstance(x, torch.Tensor)))
                log_factors = [x for x in log_factors if isinstance(x, torch.Tensor)]

                # Collect log_prob terms to query for marginal probability.
                # Keys are the (unordered) dim sets of the cost terms; None
                # marks a key that has no query tensor assigned yet.
                queries = {frozenset(cost._pyro_dims): None for cost in cost_terms}
                for log_factor in log_factors:
                    key = frozenset(log_factor._pyro_dims)
                    # The False default makes this true only for keys that are
                    # present and still unassigned, so the first matching
                    # log factor wins and unrelated dim sets are ignored.
                    if queries.get(key, False) is None:
                        queries[key] = log_factor
                # Ensure a query exists for each cost term.
                for cost in cost_terms:
                    key = frozenset(cost._pyro_dims)
                    if queries[key] is None:
                        # No existing factor has these dims: add an all-zeros
                        # log factor shaped like the cost and query that.
                        query = cost.new_zeros(cost.shape)
                        query._pyro_dims = cost._pyro_dims
                        log_factors.append(query)
                        queries[key] = query

                # Perform sum-product contraction. Note that plates never need to be
                # product-contracted due to our plate-based dependency ordering.
                # Every dim that appears in a factor but not in the ordinal is summed out.
                sum_dims = set().union(*(x._pyro_dims for x in log_factors)) - ordinal
                for query in queries.values():
                    require_backward(query)
                root = ring.sumproduct(log_factors, sum_dims)
                root._pyro_backward()
                # NOTE(review): presumably _pyro_backward_result holds the log
                # marginal of each query after the backward pass; confirm
                # against require_backward / MarginalRing.
                probs = {key: query._pyro_backward_result.exp() for key, query in queries.items()}

                # Aggregate prob * cost terms.
                for cost in cost_terms:
                    key = frozenset(cost._pyro_dims)
                    prob = probs[key]
                    prob._pyro_dims = queries[key]._pyro_dims
                    mask = prob > 0
                    # When tracing, or when some probabilities are not
                    # positive, keep only the entries where prob > 0.
                    if torch._C._get_tracing_state() or not mask.all():
                        mask._pyro_dims = prob._pyro_dims
                        cost, prob, mask = packed.broadcast_all(cost, prob, mask)
                        prob = prob[mask]
                        cost = cost[mask]
                    else:
                        cost, prob = packed.broadcast_all(cost, prob)
                    # Contract over all of prob's dimensions and weight by scale.
                    expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())

        # Recorded after the shared-intermediates context has exited.
        LAST_CACHE_SIZE[0] = count_cached_ops(cache)
        return expected_cost
Ejemplo n.º 8
0
def test_complete_sharing(backend):
    """Check that repeating an identical contraction adds no ops to the cache.

    The op counts after one evaluation under a cache are compared with the
    op counts after two identical evaluations under another single cache;
    they must be equal.

    :param backend: forwarded unchanged as the ``backend`` keyword of the
        contraction expression.
    """
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    shapes = (v.shape for v in views)
    expr = contract_expression(eq, *shapes)
    divider = '-' * 40

    print(divider)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        single_run = count_cached_ops(cache)

    print(divider)
    print('With sharing:')
    # Evaluate the very same operands twice while one cache is active.
    with shared_intermediates() as cache:
        for _ in range(2):
            expr(*views, backend=backend)
        double_run = count_cached_ops(cache)

    print(divider)
    print('Without sharing: {} expressions'.format(single_run))
    print('With sharing: {} expressions'.format(double_run))
    assert double_run == single_run
Ejemplo n.º 9
0
def test_complete_sharing(backend):
    """Verify that a second identical evaluation leaves the op counts as is.

    Baseline: op counts of a cache after one evaluation.  Measured: op counts
    of a cache after the same evaluation is run twice inside it.  The test
    asserts the two are equal.

    :param backend: passed through as the ``backend`` keyword when the
        expression is called.
    """
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(view.shape for view in views))
    rule = '-' * 40

    def evaluate():
        # One evaluation of the expression on the unchanged operands.
        expr(*views, backend=backend)

    print(rule)
    print('Without sharing:')
    with shared_intermediates() as cache:
        evaluate()
        baseline = count_cached_ops(cache)

    print(rule)
    print('With sharing:')
    with shared_intermediates() as cache:
        evaluate()
        evaluate()
        measured = count_cached_ops(cache)

    print(rule)
    print('Without sharing: {} expressions'.format(baseline))
    print('With sharing: {} expressions'.format(measured))
    assert measured == baseline
Ejemplo n.º 10
0
def test_sharing_reused_cache(backend):
    """Check that re-entering a context with an existing cache adds no ops.

    The reference op counts come from one evaluation under a fresh cache.
    A second cache is then filled by one evaluation, handed back to
    ``shared_intermediates`` for another identical evaluation, and its op
    counts are asserted to equal the reference.

    :param backend: forwarded unchanged as the ``backend`` keyword of the
        contraction expression.
    """
    eq = "ab,bc,cd->"
    views = helpers.build_views(eq)
    shapes = (v.shape for v in views)
    expr = contract_expression(eq, *shapes)
    divider = "-" * 40

    print(divider)
    print("Without sharing:")
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        reference = count_cached_ops(cache)

    print(divider)
    print("With sharing:")
    # First pass creates and fills a new cache ...
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
    # ... second pass passes that same cache object back in explicitly.
    with shared_intermediates(cache):
        expr(*views, backend=backend)
        reused = count_cached_ops(cache)

    print(divider)
    print("Without sharing: {} expressions".format(reference))
    print("With sharing: {} expressions".format(reused))
    assert reused == reference
Ejemplo n.º 11
0
def _compute_cost(cache):
    """Return the combined count of cached 'einsum' and 'tensordot' ops.

    :param cache: passed straight to ``count_cached_ops``.
    :returns: the 'einsum' entry plus the 'tensordot' entry of its result.
    """
    op_counts = count_cached_ops(cache)
    einsum_count = op_counts['einsum']
    tensordot_count = op_counts['tensordot']
    return einsum_count + tensordot_count
Ejemplo n.º 12
0
def _compute_cost(cache):
    """Add up the 'einsum' and 'tensordot' entries of the cache's op counts.

    :param cache: forwarded unchanged to ``count_cached_ops``.
    :returns: ``counts['einsum'] + counts['tensordot']``.
    """
    tally = count_cached_ops(cache)
    cost = tally['einsum']
    cost = cost + tally['tensordot']
    return cost
Ejemplo n.º 13
0
def _compute_cost(cache):
    """Return how many 'einsum' plus 'tensordot' ops the cache accounts for.

    :param cache: handed to ``count_cached_ops`` to obtain per-op counts.
    :returns: the sum of the "einsum" and "tensordot" entries.
    """
    per_op = count_cached_ops(cache)
    n_einsum, n_tensordot = per_op["einsum"], per_op["tensordot"]
    return n_einsum + n_tensordot