Ejemplo n.º 1
0
def optimise_expressions(expressions, argument_indices):
    """Argument-factorise GEM DAGs and optimise each resulting monomial sum.

    If any node reachable from *expressions* is a ``Failure``, nothing is
    optimised and *expressions* is returned as-is.

    :arg expressions: list of GEM DAGs
    :arg argument_indices: tuple of argument indices

    :returns: list of optimised GEM DAGs
    """
    # A Failure node anywhere disables optimisation altogether.
    if any(isinstance(node, Failure) for node in traversal(expressions)):
        return expressions

    def _classify(argument_index_set, expr):
        # How many argument indices are free in this sub-expression?
        hits = len(argument_index_set.intersection(expr.free_indices))
        if hits == 0:
            return OTHER
        # Only an Indexed node depending on exactly one argument index is atomic.
        if hits == 1 and isinstance(expr, Indexed):
            return ATOMIC
        return COMPOUND

    # Argument factorisation is applied unconditionally.
    classifier = partial(_classify, set(argument_indices))
    return [optimise_monomial_sum(monomial_sum, argument_indices)
            for monomial_sum in collect_monomials(expressions, classifier)]
Ejemplo n.º 2
0
def flatten(var_reps):
    """Argument-factorise, delta-simplify and sum-factorise assignments.

    :arg var_reps: iterable of ``(variable, reps)`` pairs; each entry of
        ``reps`` is destructured into an argument-index collection and an
        expression.
    :returns: generator of ``(variable, expression)`` assignment pairs.
    """
    def _classify(argument_index_set, expr):
        # Classify by how many argument indices are free in ``expr``.
        free_count = len(argument_index_set.intersection(expr.free_indices))
        if free_count > 1:
            return COMPOUND
        return ATOMIC if free_count == 1 else OTHER

    for variable, reps in var_reps:
        index_groups, expressions = zip(*reps)
        # All integrals of one variable must share a single argument index
        # set; the single-element unpacking raises ValueError otherwise.
        (argument_indices,) = {frozenset(group) for group in index_groups}
        classifier = partial(_classify, argument_indices)

        for monomial_sum in collect_monomials(expressions, classifier):
            # Re-collect monomials after IndexSum/Delta cancellation.
            compacted = MonomialSum()
            for monomial in monomial_sum:
                simplified = delta_elimination(*monomial)
                compacted.add(*simplified)

            for monomial in compacted:
                yield variable, sum_factorise(*monomial)
Ejemplo n.º 3
0
def test_refactorise():
    """collect_monomials argument-factorises a product of sums under an IndexSum.

    With ``u_i`` and ``v`` treated as atomics, the expression
    ``sum_i 5*(2*u_i + -1*v)*(u_i + v*f_i)`` must split into three monomials
    keyed by the atomic pairs ``(u_i, u_i)``, ``(u_i, v)`` and ``(v, v)``.
    """
    f = gem.Variable('f', (3,))
    u = gem.Variable('u', (3,))
    v = gem.Variable('v', ())

    i = gem.Index()
    f_i = gem.Indexed(f, (i,))
    u_i = gem.Indexed(u, (i,))

    def classify(atomics, expression):
        # Atomics are ATOMIC; anything that contains one is COMPOUND.
        if expression in atomics:
            return ATOMIC
        if any(node in atomics for node in traversal([expression])):
            return COMPOUND
        return OTHER

    classifier = partial(classify, {u_i, v})

    # \sum_i 5*(2*u_i + -1*v)*(u_i + v*f_i)
    first_factor = gem.Sum(gem.Product(gem.Literal(2), u_i),
                           gem.Product(gem.Literal(-1), v))
    second_factor = gem.Sum(u_i, gem.Product(v, f_i))
    integrand = gem.Product(gem.Literal(5),
                            gem.Product(first_factor, second_factor))
    expr = gem.IndexSum(integrand, (i,))

    expected = [
        Monomial((i,), (u_i, u_i), gem.Literal(10)),
        Monomial((i,), (u_i, v),
                 gem.Product(gem.Literal(5),
                             gem.Sum(gem.Product(f_i, gem.Literal(2)),
                                     gem.Literal(-1)))),
        Monomial((), (v, v),
                 gem.Product(gem.Literal(5),
                             gem.IndexSum(gem.Product(f_i, gem.Literal(-1)),
                                          (i,)))),
    ]

    (monomial_sum,) = collect_monomials([expr], classifier)
    assert expected == list(monomial_sum)
Ejemplo n.º 4
0
def test_refactorise():
    """Check monomial collection for sum_i 5*(2*u_i + -1*v)*(u_i + v*f_i).

    ``u_i`` and ``v`` are the atomics; ``f_i`` and literals end up in the
    monomial "rest" coefficients.
    """
    f = gem.Variable('f', (3,))
    u = gem.Variable('u', (3,))
    v = gem.Variable('v', ())

    i = gem.Index()
    f_i = gem.Indexed(f, (i,))
    u_i = gem.Indexed(u, (i,))

    atomic_nodes = {u_i, v}

    def classify(atomic_set, node):
        if node in atomic_set:
            return ATOMIC
        contains_atomic = any(sub in atomic_set for sub in traversal([node]))
        return COMPOUND if contains_atomic else OTHER

    classifier = partial(classify, atomic_nodes)

    # \sum_i 5*(2*u_i + -1*v)*(u_i + v*f_i)
    expr = gem.IndexSum(
        gem.Product(
            gem.Literal(5),
            gem.Product(
                gem.Sum(gem.Product(gem.Literal(2), u_i),
                        gem.Product(gem.Literal(-1), v)),
                gem.Sum(u_i, gem.Product(v, f_i)),
            ),
        ),
        (i,),
    )

    # Coefficient of u_i*v: 5*(2*f_i - 1), still inside the sum over i.
    rest_uv = gem.Product(gem.Literal(5),
                          gem.Sum(gem.Product(f_i, gem.Literal(2)),
                                  gem.Literal(-1)))
    # Coefficient of v*v: 5*sum_i(-f_i); i is summed into the rest.
    rest_vv = gem.Product(gem.Literal(5),
                          gem.IndexSum(gem.Product(f_i, gem.Literal(-1)),
                                       (i,)))
    expected = [
        Monomial((i,), (u_i, u_i), gem.Literal(10)),
        Monomial((i,), (u_i, v), rest_uv),
        Monomial((), (v, v), rest_vv),
    ]

    results = collect_monomials([expr], classifier)
    (only_sum,) = results
    assert expected == list(only_sum)
Ejemplo n.º 5
0
def Integrals(expressions, quadrature_multiindex, argument_multiindices,
              parameters):
    """Collapse quadrature-dependent literal tables in integrand expressions.

    :arg expressions: expressions, passed through ``concatenate`` first
    :arg quadrature_multiindex: quadrature indices to factorise over
    :arg argument_multiindices: not referenced by this function
    :arg parameters: mapping; ``parameters["unroll_indexsum"]`` bounds the
        extent of IndexSums that get unrolled (falsy disables unrolling)
    :returns: list with one rebuilt expression per input expression
    """
    expressions = concatenate(expressions)

    max_extent = parameters["unroll_indexsum"]
    if max_extent:
        expressions = unroll_indexsum(
            expressions, predicate=lambda index: index.extent <= max_extent)

    def classify(quadrature_index_set, expression):
        if not quadrature_index_set.intersection(expression.free_indices):
            return OTHER
        # Indexed literal tables are the atoms that einsum can fold.
        is_literal_table = (isinstance(expression, gem.Indexed)
                            and isinstance(expression.children[0], gem.Literal))
        return ATOMIC if is_literal_table else COMPOUND

    def collapse(factors, present):
        # Fold the quadrature literal factors of one monomial.
        if not factors and not present:
            return gem.Literal(1)
        return einsum(remove_componenttensors(factors), present)

    classifier = partial(classify, set(quadrature_multiindex))
    monomial_sums = collect_monomials(expressions, classifier)

    result = []
    for expr, monomial_sum in zip(expressions, monomial_sums):
        # Only the quadrature indices actually free in this expression.
        present = {index for index in quadrature_multiindex
                   if index in expr.free_indices}
        products = [
            gem.IndexSum(gem.Product(collapse(factors, present), rest),
                         sum_indices)
            for sum_indices, factors, rest in monomial_sum
        ]
        result.append(reduce(gem.Sum, products, gem.Zero()))
    return result
Ejemplo n.º 6
0
def optimise_expressions(expressions, argument_indices):
    """Argument-factorise GEM DAGs and optimise each monomial sum.

    Returns *expressions* untouched if any reachable node is a ``Failure``.

    :arg expressions: list of GEM DAGs
    :arg argument_indices: tuple of argument indices

    :returns: list of optimised GEM DAGs
    """
    has_failure = any(isinstance(node, Failure)
                      for node in traversal(expressions))
    if has_failure:
        return expressions

    # Argument factorisation is applied unconditionally.
    classifier = partial(spectral.classify, set(argument_indices))
    optimised = []
    for monomial_sum in collect_monomials(expressions, classifier):
        optimised.append(optimise_monomial_sum(monomial_sum, argument_indices))
    return optimised
Ejemplo n.º 7
0
def optimise_expressions(expressions, argument_indices):
    """Perform loop optimisations on GEM DAGs.

    The DAGs are argument-factorised with ``spectral.classify`` and every
    resulting monomial sum is passed to ``optimise_monomial_sum``.  A
    ``Failure`` node anywhere short-circuits and returns the input unchanged.

    :arg expressions: list of GEM DAGs
    :arg argument_indices: tuple of argument indices

    :returns: list of optimised GEM DAGs
    """
    for node in traversal(expressions):
        if isinstance(node, Failure):
            # Optimisation is skipped entirely in the presence of a Failure.
            return expressions

    argument_set = set(argument_indices)
    monomial_sums = collect_monomials(
        expressions, partial(spectral.classify, argument_set))
    return [optimise_monomial_sum(ms, argument_indices)
            for ms in monomial_sums]
Ejemplo n.º 8
0
def Integrals(expressions, quadrature_multiindex, argument_multiindices, parameters):
    """Rebuild integrand expressions with quadrature literals contracted.

    :arg expressions: expressions, passed through ``concatenate`` first
    :arg quadrature_multiindex: quadrature indices to factorise over
    :arg argument_multiindices: not referenced by this function
    :arg parameters: mapping read for the ``"unroll_indexsum"`` extent limit
    :returns: list with one rebuilt expression per input expression
    """
    expressions = concatenate(expressions)

    unroll_limit = parameters["unroll_indexsum"]
    if unroll_limit:
        def predicate(index):
            return index.extent <= unroll_limit
        expressions = unroll_indexsum(expressions, predicate=predicate)

    def classify(quadrature_index_set, expression):
        if quadrature_index_set.intersection(expression.free_indices):
            node_is_indexed = isinstance(expression, gem.Indexed)
            if node_is_indexed and isinstance(expression.children[0], gem.Literal):
                return ATOMIC
            return COMPOUND
        return OTHER

    def rebuild(expr, monomial_sum):
        # Restrict to quadrature indices that are free in this expression.
        present = set(index for index in quadrature_multiindex
                      if index in expr.free_indices)
        terms = []
        for sum_indices, factors, rest in monomial_sum:
            if factors or present:
                replacement = einsum(remove_componenttensors(factors), present)
            else:
                replacement = gem.Literal(1)
            terms.append(gem.IndexSum(gem.Product(replacement, rest), sum_indices))
        return reduce(gem.Sum, terms, gem.Zero())

    classifier = partial(classify, set(quadrature_multiindex))
    monomial_sums = collect_monomials(expressions, classifier)
    return [rebuild(expr, monomial_sum)
            for expr, monomial_sum in zip(expressions, monomial_sums)]
Ejemplo n.º 9
0
def flatten(var_reps, index_cache):
    """Argument-factorise, delta-simplify and sum-factorise assignments.

    :arg var_reps: iterable of ``(variable, reps)`` pairs; each rep is read
        for its ``argument_indices``, ``expression`` and
        ``quadrature_multiindex`` attributes.
    :arg index_cache: forwarded to ``unconcatenate`` as ``cache``.
    :returns: generator of ``(variable, expression)`` assignment pairs.
    """
    # Insertion-ordered set of every quadrature index seen (values unused).
    quadrature_indices = OrderedDict()

    pairs = []
    for variable, reps in var_reps:
        # All reps of one variable must agree on their argument indices.
        (argument_indices,) = {rep.argument_indices for rep in reps}
        assert set(variable.free_indices) == set(argument_indices)

        expressions = [rep.expression for rep in reps]
        assert all(set(e.free_indices) <= set(argument_indices)
                   for e in expressions)

        pairs.append((variable, reduce(Sum, expressions)))

        for rep in reps:
            quadrature_indices.update(dict.fromkeys(rep.quadrature_multiindex))

    # Split Concatenate nodes.
    pairs = unconcatenate(pairs, cache=index_cache)

    def argument_key(pair):
        variable, _ = pair
        return frozenset(variable.free_indices)

    # Insertion-ordered set of variables produced by delta cancellation.
    narrow_variables = OrderedDict()
    # Narrowed variable -> accumulated MonomialSum.
    delta_simplified = defaultdict(MonomialSum)
    # groupby only merges *consecutive* pairs sharing argument indices.
    for free_indices, group in groupby(pairs, argument_key):
        variables, expressions = zip(*group)
        classifier = partial(classify, set(free_indices))
        monomial_sums = collect_monomials(expressions, classifier)
        for variable, monomial_sum in zip(variables, monomial_sums):
            for monomial in monomial_sum:
                narrow, sum_indices, atomics, rest = delta_elimination(variable, *monomial)
                narrow_variables.setdefault(narrow)
                delta_simplified[narrow].add(sum_indices, atomics, rest)

    for variable in narrow_variables:
        monomial_sum = delta_simplified[variable]
        used = set().union(*[m.sum_indices for m in monomial_sum])
        # Deterministic (first-seen) order, then a stable sort by extent:
        # increasing extent factorises triangle x interval cells well, and
        # ties keep the deterministic order, which gives smaller temporaries.
        ordered = [index for index in quadrature_indices if index in used]
        ordered.sort(key=lambda index: index.extent)
        yield variable, sum_factorise(variable, ordered, monomial_sum)
Ejemplo n.º 10
0
def flatten(var_reps, index_cache):
    """Turn per-variable integral representations into factorised assignments.

    :arg var_reps: iterable of ``(variable, reps)`` pairs; each rep is read
        for its ``argument_indices``, ``expression`` and
        ``quadrature_multiindex`` attributes.
    :arg index_cache: forwarded to ``unconcatenate`` as ``cache``.
    :returns: generator of ``(variable, expression)`` assignment pairs.
    """
    # OrderedDict used as an insertion-ordered set (values are None).
    quadrature_indices = OrderedDict()

    assignment_pairs = []
    for variable, reps in var_reps:
        shared_indices, = set(r.argument_indices for r in reps)
        assert set(variable.free_indices) == set(shared_indices)

        rep_expressions = [r.expression for r in reps]
        assert all(set(expr.free_indices) <= set(shared_indices)
                   for expr in rep_expressions)

        assignment_pairs.append((variable, reduce(Sum, rep_expressions)))

        for r in reps:
            quadrature_indices.update(zip_longest(r.quadrature_multiindex, ()))

    # Split Concatenate nodes.
    assignment_pairs = unconcatenate(assignment_pairs, cache=index_cache)

    def by_free_indices(pair):
        var, _expr = pair
        return frozenset(var.free_indices)

    def ordered_sum_indices(monomial_sum):
        # Sum indices used by the monomials, in first-seen quadrature order,
        # stably sorted by increasing extent (good for triangle x interval
        # cells; equal extents keep the deterministic order).
        used = set().union(*[m.sum_indices for m in monomial_sum])
        candidates = [index for index in quadrature_indices if index in used]
        return sorted(candidates, key=lambda index: index.extent)

    delta_inside = Memoizer(_delta_inside)
    narrow_variables = OrderedDict()
    delta_simplified = defaultdict(MonomialSum)
    # groupby merges only consecutive pairs with the same argument indices.
    for free_indices, pair_group in groupby(assignment_pairs, by_free_indices):
        group_variables, group_expressions = zip(*pair_group)
        classifier = partial(classify,
                             set(free_indices),
                             delta_inside=delta_inside)
        for variable, monomial_sum in zip(group_variables,
                                          collect_monomials(group_expressions,
                                                            classifier)):
            for monomial in monomial_sum:
                narrowed = delta_elimination(variable, *monomial)
                target, sum_indices, atomics, rest = narrowed
                narrow_variables.setdefault(target)
                delta_simplified[target].add(sum_indices, atomics, rest)

    for variable in narrow_variables:
        monomial_sum = delta_simplified[variable]
        expression = sum_factorise(variable,
                                   ordered_sum_indices(monomial_sum),
                                   monomial_sum)
        yield (variable, expression)