def optimise_expressions(expressions, argument_indices):
    """Perform loop optimisations on GEM DAGs.

    :arg expressions: list of GEM DAGs
    :arg argument_indices: tuple of argument indices

    :returns: list of optimised GEM DAGs
    """
    # Bail out without optimising if any Failure node is present.
    if any(isinstance(node, Failure) for node in traversal(expressions)):
        return expressions

    def classify(arg_indices, expression):
        # Label an expression by how it involves the argument indices:
        # no argument index -> OTHER; exactly one, on an Indexed node
        # -> ATOMIC; anything else -> COMPOUND.
        overlap = arg_indices.intersection(expression.free_indices)
        if not overlap:
            return OTHER
        if len(overlap) == 1 and isinstance(expression, Indexed):
            return ATOMIC
        return COMPOUND

    # Apply argument factorisation unconditionally.
    classifier = partial(classify, set(argument_indices))
    return [optimise_monomial_sum(monomial_sum, argument_indices)
            for monomial_sum in collect_monomials(expressions, classifier)]
def flatten(var_reps):
    """Argument factorise and delta-cancel integral representations,
    yielding (variable, expression) assignment pairs."""

    def classify(argument_indices, expression):
        # Classifier for argument factorisation: count the free
        # argument indices on the expression.
        count = len(argument_indices.intersection(expression.free_indices))
        if count == 0:
            return OTHER
        return ATOMIC if count == 1 else COMPOUND

    for variable, reps in var_reps:
        # Destructure representation
        all_argument_indices, expressions = zip(*reps)
        # All integrals must agree on a single set of argument indices.
        argument_indices, = set(map(frozenset, all_argument_indices))
        # Argument factorise
        classifier = partial(classify, argument_indices)
        for monomial_sum in collect_monomials(expressions, classifier):
            # Compact MonomialSum after IndexSum-Delta cancellation
            compacted = MonomialSum()
            for monomial in monomial_sum:
                compacted.add(*delta_elimination(*monomial))
            # Yield assignments
            for monomial in compacted:
                yield variable, sum_factorise(*monomial)
def test_refactorise():
    """Check collect_monomials against a hand-worked factorisation."""
    f = gem.Variable('f', (3,))
    u = gem.Variable('u', (3,))
    v = gem.Variable('v', ())
    i = gem.Index()
    f_i = gem.Indexed(f, (i,))
    u_i = gem.Indexed(u, (i,))

    def classify(atomics_set, expression):
        # ATOMIC: the node itself is a designated atomic; COMPOUND: it
        # contains one somewhere beneath it; OTHER: neither.
        if expression in atomics_set:
            return ATOMIC
        if any(node in atomics_set for node in traversal([expression])):
            return COMPOUND
        return OTHER
    classifier = partial(classify, {u_i, v})

    # \sum_i 5*(2*u_i + -1*v)*(u_i + v*f)
    two_u_i = gem.Product(gem.Literal(2), u_i)
    minus_v = gem.Product(gem.Literal(-1), v)
    expr = gem.IndexSum(
        gem.Product(
            gem.Literal(5),
            gem.Product(gem.Sum(two_u_i, minus_v),
                        gem.Sum(u_i, gem.Product(v, f_i)))
        ),
        (i,))

    expected = [
        Monomial((i,), (u_i, u_i), gem.Literal(10)),
        Monomial((i,), (u_i, v),
                 gem.Product(gem.Literal(5),
                             gem.Sum(gem.Product(f_i, gem.Literal(2)),
                                     gem.Literal(-1)))),
        Monomial((), (v, v),
                 gem.Product(gem.Literal(5),
                             gem.IndexSum(gem.Product(f_i, gem.Literal(-1)),
                                          (i,)))),
    ]
    actual, = collect_monomials([expr], classifier)
    assert expected == list(actual)
def test_refactorise():
    """Check collect_monomials against a hand-worked factorisation."""
    # Fixtures: vector variables f, u and a scalar v, plus one index.
    f = gem.Variable('f', (3,))
    u = gem.Variable('u', (3,))
    v = gem.Variable('v', ())
    i = gem.Index()
    f_i = gem.Indexed(f, (i,))
    u_i = gem.Indexed(u, (i,))

    def classify(atomics_set, expression):
        # ATOMIC: the node itself is a designated atomic;
        # COMPOUND: it contains a designated atomic beneath it;
        # OTHER: no designated atomic anywhere in its DAG.
        if expression in atomics_set:
            return ATOMIC
        for node in traversal([expression]):
            if node in atomics_set:
                return COMPOUND
        return OTHER
    classifier = partial(classify, {u_i, v})

    # \sum_i 5*(2*u_i + -1*v)*(u_i + v*f)
    expr = gem.IndexSum(
        gem.Product(
            gem.Literal(5),
            gem.Product(
                gem.Sum(gem.Product(gem.Literal(2), u_i),
                        gem.Product(gem.Literal(-1), v)),
                gem.Sum(u_i, gem.Product(v, f_i))
            )
        ),
        (i,)
    )

    # Monomial(sum_indices, atomics, rest)
    expected = [
        Monomial((i,), (u_i, u_i), gem.Literal(10)),
        Monomial((i,), (u_i, v),
                 gem.Product(gem.Literal(5),
                             gem.Sum(gem.Product(f_i, gem.Literal(2)),
                                     gem.Literal(-1)))),
        Monomial((), (v, v),
                 gem.Product(gem.Literal(5),
                             gem.IndexSum(gem.Product(f_i, gem.Literal(-1)),
                                          (i,)))),
    ]
    actual, = collect_monomials([expr], classifier)
    assert expected == list(actual)
def Integrals(expressions, quadrature_multiindex, argument_multiindices, parameters):
    """Concatenate the integral expressions, optionally unroll short
    IndexSums, then refactorise each expression and collapse the literal
    quadrature factors of every monomial via :func:`einsum`."""
    # Concatenate
    expressions = concatenate(expressions)

    # Unroll IndexSums whose extent does not exceed the configured bound.
    max_extent = parameters["unroll_indexsum"]
    if max_extent:
        expressions = unroll_indexsum(
            expressions, predicate=lambda index: index.extent <= max_extent)

    # Refactorise: literal tables indexed by quadrature indices are the
    # atomics; any other node carrying quadrature indices is compound.
    def classify(quadrature_indices, expression):
        if not quadrature_indices.intersection(expression.free_indices):
            return OTHER
        if isinstance(expression, gem.Indexed) \
                and isinstance(expression.children[0], gem.Literal):
            return ATOMIC
        return COMPOUND
    classifier = partial(classify, set(quadrature_multiindex))

    result = []
    for expr, monomial_sum in zip(expressions,
                                  collect_monomials(expressions, classifier)):
        # Quadrature indices actually free in this expression
        quadrature_indices = {index for index in quadrature_multiindex
                              if index in expr.free_indices}

        products = []
        for sum_indices, factors, rest in monomial_sum:
            # Collapse quadrature literals for each monomial
            if factors or quadrature_indices:
                replacement = einsum(remove_componenttensors(factors),
                                     quadrature_indices)
            else:
                replacement = gem.Literal(1)
            # Rebuild expression
            products.append(gem.IndexSum(gem.Product(replacement, rest),
                                         sum_indices))
        result.append(reduce(gem.Sum, products, gem.Zero()))
    return result
def optimise_expressions(expressions, argument_indices):
    """Perform loop optimisations on GEM DAGs.

    :arg expressions: list of GEM DAGs
    :arg argument_indices: tuple of argument indices

    :returns: list of optimised GEM DAGs
    """
    # Bail out without optimising if any Failure node is present.
    if any(isinstance(node, Failure) for node in traversal(expressions)):
        return expressions

    # Apply argument factorisation unconditionally.
    classifier = partial(spectral.classify, set(argument_indices))
    return [optimise_monomial_sum(monomial_sum, argument_indices)
            for monomial_sum in collect_monomials(expressions, classifier)]
def optimise_expressions(expressions, argument_indices): """Perform loop optimisations on GEM DAGs :arg expressions: list of GEM DAGs :arg argument_indices: tuple of argument indices :returns: list of optimised GEM DAGs """ # Skip optimisation for if Failure node is present for n in traversal(expressions): if isinstance(n, Failure): return expressions # Apply argument factorisation unconditionally classifier = partial(spectral.classify, set(argument_indices)) monomial_sums = collect_monomials(expressions, classifier) return [ optimise_monomial_sum(ms, argument_indices) for ms in monomial_sums ]
def Integrals(expressions, quadrature_multiindex, argument_multiindices, parameters):
    """Concatenate the integral expressions, optionally unroll short
    IndexSums, then refactorise and collapse the literal quadrature
    factors of each monomial.

    :arg expressions: list of GEM expressions
    :arg quadrature_multiindex: quadrature indices
    :arg argument_multiindices: argument indices (not referenced in this body)
    :arg parameters: parameter dict; reads "unroll_indexsum"
    :returns: list of rebuilt GEM expressions
    """
    # Concatenate
    expressions = concatenate(expressions)

    # Unroll
    max_extent = parameters["unroll_indexsum"]
    if max_extent:
        def predicate(index):
            # Only unroll IndexSums over sufficiently short indices.
            return index.extent <= max_extent
        expressions = unroll_indexsum(expressions, predicate=predicate)

    # Refactorise
    def classify(quadrature_indices, expression):
        # OTHER: independent of quadrature indices;
        # ATOMIC: literal table indexed by quadrature indices;
        # COMPOUND: anything else carrying quadrature indices.
        if not quadrature_indices.intersection(expression.free_indices):
            return OTHER
        elif isinstance(expression, gem.Indexed) and isinstance(expression.children[0], gem.Literal):
            return ATOMIC
        else:
            return COMPOUND
    classifier = partial(classify, set(quadrature_multiindex))

    result = []
    for expr, monomial_sum in zip(expressions, collect_monomials(expressions, classifier)):
        # Select quadrature indices that are present
        quadrature_indices = set(index for index in quadrature_multiindex
                                 if index in expr.free_indices)

        products = []
        for sum_indices, factors, rest in monomial_sum:
            # Collapse quadrature literals for each monomial
            if factors or quadrature_indices:
                replacement = einsum(remove_componenttensors(factors), quadrature_indices)
            else:
                # No factors and no quadrature indices: neutral element.
                replacement = gem.Literal(1)
            # Rebuild expression
            products.append(gem.IndexSum(gem.Product(replacement, rest), sum_indices))
        result.append(reduce(gem.Sum, products, gem.Zero()))
    return result
def flatten(var_reps, index_cache):
    """Flatten mode-specific intermediate representation to a series of
    assignments.

    :arg var_reps: series of (return variable, [integral representation]) pairs
    :arg index_cache: cache passed through to unconcatenate
    :returns: iterator of (variable, expression) assignment pairs
    """
    quadrature_indices = OrderedDict()  # used as an ordered set

    pairs = []  # assignment pairs
    for variable, reps in var_reps:
        # Extract argument indices; all reps of a variable must agree.
        argument_indices, = set(r.argument_indices for r in reps)
        assert set(variable.free_indices) == set(argument_indices)

        # Extract and verify expressions
        expressions = [r.expression for r in reps]
        assert all(set(e.free_indices) <= set(argument_indices)
                   for e in expressions)

        # Save assignment pair
        pairs.append((variable, reduce(Sum, expressions)))

        # Collect quadrature_indices in first-seen order
        for r in reps:
            quadrature_indices.update(zip_longest(r.quadrature_multiindex, ()))

    # Split Concatenate nodes
    pairs = unconcatenate(pairs, cache=index_cache)

    def group_key(pair):
        # Group assignments by the free (argument) indices of the variable.
        variable, expression = pair
        return frozenset(variable.free_indices)

    # Variable ordering after delta cancellation
    narrow_variables = OrderedDict()
    # Assignments are variable -> MonomialSum map
    delta_simplified = defaultdict(MonomialSum)
    # Group assignment pairs by argument indices
    for free_indices, pair_group in groupby(pairs, group_key):
        variables, expressions = zip(*pair_group)
        # Argument factorise expressions
        classifier = partial(classify, set(free_indices))
        monomial_sums = collect_monomials(expressions, classifier)
        # For each monomial, apply delta cancellation and insert
        # result into delta_simplified.
        for variable, monomial_sum in zip(variables, monomial_sums):
            for monomial in monomial_sum:
                var, s, a, r = delta_elimination(variable, *monomial)
                # setdefault records first-seen order of narrowed variables.
                narrow_variables.setdefault(var)
                delta_simplified[var].add(s, a, r)

    # Final factorisation
    for variable in narrow_variables:
        monomial_sum = delta_simplified[variable]
        # Collect sum indices applicable to the current MonomialSum
        sum_indices = set().union(*[m.sum_indices for m in monomial_sum])
        # Put them in a deterministic order
        sum_indices = [i for i in quadrature_indices if i in sum_indices]
        # Sort for increasing index extent, this obtains the good
        # factorisation for triangle x interval cells.  Python sort is
        # stable, so in the common case when index extents are equal,
        # the previous deterministic ordering applies which is good
        # for getting smaller temporaries.
        sum_indices = sorted(sum_indices, key=lambda index: index.extent)
        # Apply sum factorisation combined with COFFEE technology
        expression = sum_factorise(variable, sum_indices, monomial_sum)
        yield (variable, expression)
def flatten(var_reps, index_cache):
    """Flatten mode-specific intermediate representation to a series of
    assignments.

    :arg var_reps: series of (return variable, [integral representation]) pairs
    :arg index_cache: cache passed through to unconcatenate
    :returns: iterator of (variable, expression) assignment pairs
    """
    quadrature_indices = OrderedDict()  # used as an ordered set

    # Build one assignment pair per return variable.
    assignments = []
    for variable, reps in var_reps:
        # Every representation of a variable must carry the same
        # argument indices, and they must match its free indices.
        argument_indices, = set(rep.argument_indices for rep in reps)
        assert set(variable.free_indices) == set(argument_indices)

        summands = [rep.expression for rep in reps]
        assert all(set(summand.free_indices) <= set(argument_indices)
                   for summand in summands)
        assignments.append((variable, reduce(Sum, summands)))

        # Remember quadrature indices in first-seen order.
        for rep in reps:
            quadrature_indices.update(
                zip_longest(rep.quadrature_multiindex, ()))

    # Split Concatenate nodes
    assignments = unconcatenate(assignments, cache=index_cache)

    def group_key(assignment):
        # Group assignments by the argument indices of the variable.
        variable, _ = assignment
        return frozenset(variable.free_indices)

    delta_inside = Memoizer(_delta_inside)
    # Variable ordering after delta cancellation
    narrow_variables = OrderedDict()
    # variable -> MonomialSum of its delta-simplified monomials
    delta_simplified = defaultdict(MonomialSum)
    # Process the assignments grouped by argument indices.
    for free_indices, group in groupby(assignments, group_key):
        variables, expressions = zip(*group)
        # Argument factorise expressions
        classifier = partial(classify, set(free_indices),
                             delta_inside=delta_inside)
        monomial_sums = collect_monomials(expressions, classifier)
        # Cancel deltas monomial by monomial, accumulating the results.
        for variable, monomial_sum in zip(variables, monomial_sums):
            for monomial in monomial_sum:
                narrowed, sum_idx, atomics, rest = \
                    delta_elimination(variable, *monomial)
                narrow_variables.setdefault(narrowed)
                delta_simplified[narrowed].add(sum_idx, atomics, rest)

    # Final factorisation
    for variable in narrow_variables:
        monomial_sum = delta_simplified[variable]
        # Sum indices applicable to the current MonomialSum, put in the
        # deterministic quadrature-index order established above.
        applicable = set().union(*(m.sum_indices for m in monomial_sum))
        sum_indices = [index for index in quadrature_indices
                       if index in applicable]
        # Sort for increasing index extent, this obtains the good
        # factorisation for triangle x interval cells.  Python sort is
        # stable, so in the common case when index extents are equal,
        # the previous deterministic ordering applies which is good
        # for getting smaller temporaries.
        sum_indices.sort(key=lambda index: index.extent)
        # Apply sum factorisation combined with COFFEE technology
        yield (variable, sum_factorise(variable, sum_indices, monomial_sum))