Ejemplo n.º 1
0
def test_loop_optimise():
    """Check that ``optimise_expressions`` factorises sums of products
    with respect to the given argument indices ``(j, k)``."""
    I = 20
    J = K = 10
    i = Index('i', I)
    j = Index('j', J)
    k = Index('k', K)

    A1 = Variable('a1', (I, ))
    A2 = Variable('a2', (I, ))
    A3 = Variable('a3', (I, ))
    A1i = Indexed(A1, (i, ))
    A2i = Indexed(A2, (i, ))
    A3i = Indexed(A3, (i, ))

    B = Variable('b', (J, ))
    C = Variable('c', (J, ))
    Bj = Indexed(B, (j, ))
    Cj = Indexed(C, (j, ))

    E = Variable('e', (K, ))
    F = Variable('f', (K, ))
    G = Variable('g', (K, ))
    Ek = Indexed(E, (k, ))
    Fk = Indexed(F, (k, ))
    Gk = Indexed(G, (k, ))

    Z = Variable('z', ())

    # Bj*Ek + Bj*Fk => (Ek + Fk)*Bj
    expr = Sum(Product(Bj, Ek), Product(Bj, Fk))
    result, = optimise_expressions([expr], (j, k))
    expected = Product(Sum(Ek, Fk), Bj)
    assert result == expected

    # Bj*Ek + Bj*Fk + Bj*Gk + Cj*Ek + Cj*Fk =>
    # (Ek + Fk + Gk)*Bj + (Ek + Fk)*Cj
    expr = Sum(
        Sum(Sum(Sum(Product(Bj, Ek), Product(Bj, Fk)), Product(Bj, Gk)),
            Product(Cj, Ek)), Product(Cj, Fk))
    result, = optimise_expressions([expr], (j, k))
    expected = Sum(Product(Sum(Sum(Ek, Fk), Gk), Bj), Product(Sum(Ek, Fk), Cj))
    assert result == expected

    # Z*A1i*Bj*Ek + Z*A2i*Bj*Ek + A3i*Bj*Ek + Z*A1i*Bj*Fk =>
    # (Ek*(Z*A1i + Z*A2i + A3i) + Fk*(Z*A1i))*Bj
    expr = Sum(
        Sum(
            Sum(Product(Z, Product(A1i, Product(Bj, Ek))),
                Product(Z, Product(A2i, Product(Bj, Ek)))),
            Product(A3i, Product(Bj, Ek))),
        Product(Z, Product(A1i, Product(Bj, Fk))))
    result, = optimise_expressions([expr], (j, k))
    expected = Product(
        Sum(Product(Ek, Sum(Sum(Product(Z, A1i), Product(Z, A2i)), A3i)),
            Product(Fk, Product(Z, A1i))), Bj)
    assert result == expected
Ejemplo n.º 2
0
def test_delta_elimination():
    """Deltas over summed indices are eliminated: both summation indices
    disappear and the identity ends up indexed on the remaining index."""
    first, second, third = Index(), Index(), Index()
    eye = Identity(3)

    terms = [Delta(first, second),
             Delta(first, third),
             Indexed(eye, (second, third))]
    remaining, terms = delta_elimination((first, second), terms)
    terms = remove_componenttensors(terms)

    assert remaining == []
    assert terms == [one, one, Indexed(eye, (third, third))]
Ejemplo n.º 3
0
def test_conditional_zero_folding():
    """A Conditional whose branches both reduce to zero folds to Zero()."""
    flag = Variable("B", ())
    vector = Variable("A", (3, ))
    idx = Index()

    condition = LogicalAnd(flag, flag)
    then_branch = Product(Indexed(vector, (idx, )), Zero())
    folded = Conditional(condition, then_branch, Zero())

    assert folded == Zero()
Ejemplo n.º 4
0
def test_loop_fusion():
    """Using the outer index for the inner IndexSum must not change the
    number of top-level statements produced by ``compile_gem``."""
    i = Index()
    j = Index()
    target = Indexed(Variable('R', (6, )), (i, ))

    def build(outer, inner):
        vec = Variable('A', (6, ))
        reduction = IndexSum(Indexed(vec, (inner, )), (inner, ))
        return Product(Indexed(vec, (outer, )), reduction)

    distinct = build(i, j)
    shared = build(i, i)

    def statements(expr):
        compiled = impero_utils.compile_gem([(target, expr)], (i, j))
        return compiled.tree.children

    assert len(statements(distinct)) == len(statements(shared))
Ejemplo n.º 5
0
def test_replace_div():
    """``replace_division`` rewrites x / a as x * (1.0 / a)."""
    idx = Index()
    denominator = Variable('A', ())
    numerator = Indexed(Variable('B', (6,)), (idx,))
    quotient = Division(numerator, denominator)

    rewritten, = replace_division([quotient])

    assert rewritten == Product(numerator,
                                Division(Literal(1.0), denominator))
Ejemplo n.º 6
0
def _select_expression(expressions, index):
    """Helper function to select an expression from a list of
    expressions with an index.  This function expects sanitised input;
    one should normally call :py:func:`select_expression` instead.

    Common structure shared by all candidates is factored out, and the
    selection is pushed down into the children where possible.

    :arg expressions: a list of expressions
    :arg index: an index (free, fixed or variable)
    :returns: an expression
    :raises NotImplementedError: if no factorisation rule applies
    """
    expr = expressions[0]
    # All candidates identical: nothing to select.
    if all(e == expr for e in expressions):
        return expr

    types = set(map(type, expressions))
    if types <= {Indexed, Zero}:
        # Factor out the common multiindex and select among aggregates.
        multiindex, = set(e.multiindex for e in expressions
                          if isinstance(e, Indexed))
        # A Zero stands in for an aggregate here, so it must have the
        # aggregate's full shape.  Using only the free-index extents
        # would drop the dimensions addressed by fixed/variable indices
        # and give the Zero a lower rank than its sibling aggregates.
        shape = next(e.children[0].shape for e in expressions
                     if isinstance(e, Indexed))

        def child(expression):
            if isinstance(expression, Indexed):
                return expression.children[0]
            # Only Indexed and Zero can reach this point.
            return Zero(shape)

        return Indexed(
            _select_expression([child(e) for e in expressions], index),
            multiindex)

    if types <= {Literal, Zero, Failure}:
        return partial_indexed(ListTensor(expressions), (index, ))

    if types <= {ComponentTensor, Zero}:
        shape, = set(e.shape for e in expressions)
        multiindex = tuple(Index(extent=d) for d in shape)
        children = remove_componenttensors(
            [Indexed(e, multiindex) for e in expressions])
        return ComponentTensor(_select_expression(children, index), multiindex)

    if len(types) == 1:
        cls, = types
        # Nodes with non-child (front/back) data cannot be factorised
        # child-wise.
        if cls.__front__ or cls.__back__:
            raise NotImplementedError(
                "How to factorise {} expressions?".format(cls.__name__))
        assert all(len(e.children) == len(expr.children) for e in expressions)
        assert len(expr.children) > 0

        # Same node type everywhere: select child-wise and rebuild.
        return expr.reconstruct(*[
            _select_expression(nth_children, index)
            for nth_children in zip(*[e.children for e in expressions])
        ])

    raise NotImplementedError(
        "No rule for factorising expressions of this kind.")
Ejemplo n.º 7
0
def _unconcatenate(cache, pairs):
    """Core of unconcatenate: repeatedly split assignment pairs along
    Concatenate nodes until no Concatenate group remains.

    Assumes that input has already been sanitised.

    :arg cache: dict mapping a concatenation index to the tuple of
        multiindices (one per concatenated child) used to split it;
        entries are added here on first use
    :arg pairs: list of ``(variable, expression)`` assignment pairs
    :returns: the list of split assignment pairs
    """
    # This used to be a tail-recursive function.  Python does not
    # eliminate tail calls, so every Concatenate group consumed a stack
    # frame.  A loop performs the same iterations without growing the
    # stack.
    while True:
        concat_group = find_group([e for _, e in pairs])
        if concat_group is None:
            return pairs

        # Get the index split
        concat_ref = next(iter(concat_group))
        assert isinstance(concat_ref, Indexed)
        concat_expr, = concat_ref.children
        index, = concat_ref.multiindex
        assert isinstance(concat_expr, Concatenate)
        try:
            multiindices = cache[index]
        except KeyError:
            multiindices = tuple(
                tuple(Index(extent=d) for d in child.shape)
                for child in concat_expr.children)
            cache[index] = multiindices

        # Note: this closure reads ``index`` from the current iteration.
        # It is only used by the replace_node calls below, within this
        # same iteration.
        def cut(node):
            """No need to rebuild expressions that are independent of
            the relevant concatenation index."""
            return index not in node.free_indices

        # Build Concatenate node replacement mappings, one per child
        mappings = [{} for _ in multiindices]
        for ref in concat_group:
            ref_expr, = ref.children
            for n, multiindex in enumerate(multiindices):
                sub_ref = Indexed(ref_expr.children[n], multiindex)
                sub_ref, = remove_componenttensors((sub_ref, ))
                mappings[n][ref] = sub_ref

        # Split assignment pairs whose variable depends on the index
        split_pairs = []
        for var, expr in pairs:
            if index not in var.free_indices:
                split_pairs.append((var, expr))
            else:
                for v, m in zip(split_variable(var, index, multiindices),
                                mappings):
                    split_pairs.append((v, replace_node(expr, m, cut)))

        # Go round again: there may be other Concatenate groups
        pairs = split_pairs
Ejemplo n.º 8
0
def select_expression(expressions, index):
    """Pick one of several same-shaped expressions according to *index*.

    The result means the same as

        partial_indexed(ListTensor(expressions), (index,))

    but has a much more optimised implementation, since the structure
    shared by the candidates is factored out rather than duplicated.

    :arg expressions: a list of expressions of the same shape
    :arg index: an index (free, fixed or variable)
    :returns: an expression of the same shape as the given expressions
    """
    common_shape = expressions[0].shape
    assert all(candidate.shape == common_shape for candidate in expressions)

    # Work on scalar views: index every candidate with the same fresh
    # free indices, then strip the ComponentTensor wrappers.
    free = tuple(Index() for _ in common_shape)
    scalar_views = [Indexed(candidate, free) for candidate in expressions]
    sanitised = remove_componenttensors(scalar_views)

    # Select on the scalar level, then restore the original shape.
    return ComponentTensor(_select_expression(sanitised, index), free)