def factorise_atomics(monomials, optimal_atomics, linear_indices):
    """Group and factorise monomials using a list of atomics as common
    subexpressions. Create new monomials for each group and optimise them
    recursively.

    :arg monomials: a sized iterable of :class:`Monomial`s, all of which
        should have the same sum indices
    :arg optimal_atomics: list of tuples of atomics to be used as common
        subexpression; the first entry found among a monomial's atomics
        becomes the common factor of that monomial
    :arg linear_indices: tuple of linear indices

    :returns: an iterable of :class:`Monomial`s after factorisation
    :raises AssertionError: if a monomial contains none of the optimal
        atomics
    """
    # Nothing to factorise out of zero or one monomial, or without atomics
    if not optimal_atomics or len(monomials) <= 1:
        return monomials

    # Group monomials with respect to each optimal atomic
    def group_key(monomial):
        for oa in optimal_atomics:
            if oa in monomial.atomics:
                return oa
        # Raise explicitly rather than ``assert False``: assert statements
        # are removed under ``python -O``, which would let this function
        # fall through and group the monomial under the key ``None``.
        raise AssertionError("Expect at least one optimal atomic per monomial.")

    factor_group = groupby(monomials, key=group_key)
    # We should not drop monomials
    assert sum(len(ms) for _, ms in factor_group) == len(monomials)

    # All monomials share their sum indices, so take them from the first one
    sum_indices = next(iter(monomials)).sum_indices
    new_monomials = []
    for oa, group in factor_group:
        # Create new MonomialSum for the factorised out terms
        sub_monomials = []
        for monomial in group:
            atomics = list(monomial.atomics)
            atomics.remove(oa)  # remove common factor (first occurrence only)
            sub_monomials.append(Monomial((), tuple(atomics), monomial.rest))
        # Continue to factorise the remaining expression
        sub_monomials = optimise_monomials(sub_monomials, linear_indices)
        if len(sub_monomials) == 1:
            # Factorised part is a product, we add back the common atomics
            # then add to new MonomialSum directly rather than forming a
            # product node.  Retaining the monomial structure enables
            # applying associativity when forming GEM nodes later.
            sub_monomial, = sub_monomials
            new_monomials.append(
                Monomial(sum_indices,
                         (oa,) + sub_monomial.atomics,
                         sub_monomial.rest))
        else:
            # Factorised part is a summation, we need to create a new GEM
            # node and multiply with the common factor
            node = monomial_sum_to_expression(sub_monomials)
            # If the free indices of the new node intersect with linear
            # indices, add to the new monomial as `atomic`, otherwise add
            # as `rest`.
            # Note: we might want to continue to factorise with the new
            # atomics by running optimise_monomials twice.
            if set(linear_indices) & set(node.free_indices):
                new_monomials.append(Monomial(sum_indices, (oa, node), one))
            else:
                new_monomials.append(Monomial(sum_indices, (oa,), node))
    return new_monomials
def factorise_atomics(monomials, optimal_atomics, linear_indices):
    """Group and factorise monomials using a list of atomics as common
    subexpressions. Create new monomials for each group and optimise them
    recursively.

    :arg monomials: a sized iterable of :class:`Monomial`s, all of which
        should have the same sum indices
    :arg optimal_atomics: list of tuples of atomics to be used as common
        subexpression; the first entry found among a monomial's atomics
        becomes the common factor of that monomial
    :arg linear_indices: tuple of linear indices

    :returns: an iterable of :class:`Monomial`s after factorisation
    :raises AssertionError: if a monomial contains none of the optimal
        atomics
    """
    # Nothing to factorise out of zero or one monomial, or without atomics
    if not optimal_atomics or len(monomials) <= 1:
        return monomials

    # Group monomials with respect to each optimal atomic
    def group_key(monomial):
        for oa in optimal_atomics:
            if oa in monomial.atomics:
                return oa
        # Raise explicitly rather than ``assert False``: assert statements
        # are removed under ``python -O``, which would let this function
        # fall through and group the monomial under the key ``None``.
        raise AssertionError("Expect at least one optimal atomic per monomial.")

    factor_group = groupby(monomials, key=group_key)
    # We should not drop monomials
    assert sum(len(ms) for _, ms in factor_group) == len(monomials)

    # All monomials share their sum indices, so take them from the first one
    sum_indices = next(iter(monomials)).sum_indices
    new_monomials = []
    for oa, group in factor_group:
        # Create new MonomialSum for the factorised out terms
        sub_monomials = []
        for monomial in group:
            atomics = list(monomial.atomics)
            atomics.remove(oa)  # remove common factor (first occurrence only)
            sub_monomials.append(Monomial((), tuple(atomics), monomial.rest))
        # Continue to factorise the remaining expression
        sub_monomials = optimise_monomials(sub_monomials, linear_indices)
        if len(sub_monomials) == 1:
            # Factorised part is a product, we add back the common atomics
            # then add to new MonomialSum directly rather than forming a
            # product node.  Retaining the monomial structure enables
            # applying associativity when forming GEM nodes later.
            sub_monomial, = sub_monomials
            new_monomials.append(
                Monomial(sum_indices,
                         (oa,) + sub_monomial.atomics,
                         sub_monomial.rest))
        else:
            # Factorised part is a summation, we need to create a new GEM
            # node and multiply with the common factor
            node = monomial_sum_to_expression(sub_monomials)
            # If the free indices of the new node intersect with linear
            # indices, add to the new monomial as `atomic`, otherwise add
            # as `rest`.
            # Note: we might want to continue to factorise with the new
            # atomics by running optimise_monomials twice.
            if set(linear_indices) & set(node.free_indices):
                new_monomials.append(Monomial(sum_indices, (oa, node), one))
            else:
                new_monomials.append(Monomial(sum_indices, (oa,), node))
    return new_monomials
def sum_factorise(sum_indices, factors):
    """Optimise a tensor product through sum factorisation.

    Every permutation of the contraction indices is tried; the expression
    with the smallest accumulated flop count is returned.

    :arg sum_indices: free indices for contractions
    :arg factors: product factors
    :returns: optimised GEM expression
    :raises NotImplementedError: if more than six sum indices are given
    """
    if not (len(factors) or len(sum_indices)):
        # Empty product
        return one

    if len(sum_indices) > 6:
        raise NotImplementedError("Too many indices for sum factorisation!")

    # Multiply together the factors that share the same free indices
    grouped = [reduce(Product, members)
               for _, members in groupby(factors,
                                         key=lambda f: f.free_indices)]

    best_expression = None
    lowest_cost = numpy.inf

    # Exhaustive search over contraction orders
    for index_order in permutations(sum_indices):
        pending = list(grouped)
        cost = 0

        # Contract one index at a time
        for index in index_order:
            # Split the pending terms into those that carry the index and
            # those that can be left alone for now (relative order is kept)
            involved = []
            untouched = []
            for candidate in pending:
                if index in candidate.free_indices:
                    involved.append(candidate)
                else:
                    untouched.append(candidate)

            # Optimise associativity of the product to be contracted
            contracted, product_cost = associate(Product, involved)
            summed = IndexSum(contracted, (index, ))
            extents = [i.extent for i in contracted.free_indices]
            cost += product_cost + numpy.prod(extents, dtype=int)

            # The contraction result replaces the terms it consumed
            pending = untouched + [summed]

        # Independent contraction indices may leave several terms behind;
        # multiply whatever remains.
        result, product_cost = associate(Product, pending)
        cost += product_cost

        if cost < lowest_cost:
            best_expression = result
            lowest_cost = cost

    return best_expression
def optimise_monomial_sum(monomial_sum, linear_indices):
    """Choose optimal common atomic subexpressions and factorise a
    :class:`MonomialSum` object to create a GEM expression.

    Monomials are first grouped by their (unordered) sum indices; each
    group is optimised on its own and the results are combined.

    :arg monomial_sum: a :class:`MonomialSum` object
    :arg linear_indices: tuple of linear indices

    :returns: factorised GEM expression
    """
    by_sum_indices = groupby(monomial_sum,
                             key=lambda m: frozenset(m.sum_indices))
    optimised = [
        monomial
        for _, members in by_sum_indices
        for monomial in optimise_monomials(members, linear_indices)
    ]
    return monomial_sum_to_expression(optimised)
def monomial_sum_to_expression(monomial_sum):
    """Convert a monomial sum to a GEM expression.

    :arg monomial_sum: an iterable of :class:`Monomial`s

    :returns: GEM expression
    """
    def as_index_sum(members):
        # The sum indices of the first monomial are used for the whole
        # group; members of a group share the same set of sum indices.
        sum_indices = members[0].sum_indices
        products = [make_product(m.atomics + (m.rest,)) for m in members]
        return IndexSum(make_sum(products), sum_indices)

    # Group monomials according to their sum indices, build one IndexSum
    # per group, then add the IndexSums together.
    groups = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices))
    return make_sum([as_index_sum(members) for _, members in groups])
def flatten(var_reps, index_cache):
    """Flatten mode-specific intermediate representation to a series of
    assignments.

    :arg var_reps: series of (return variable, [integral representation]) pairs
    :arg index_cache: cache :py:class:`dict` for :py:func:`unconcatenate`

    :returns: series of (return variable, GEM expression root) pairs
    """
    # Add up the representations belonging to each return variable
    summed = [(variable, reduce(Sum, reps)) for variable, reps in var_reps]
    assignments = unconcatenate(summed, cache=index_cache)

    def free_indices_of(assignment):
        variable, _ = assignment
        return variable.free_indices

    # Assignments sharing the same free indices are optimised together
    for argument_indices, members in groupby(assignments, free_indices_of):
        variables, expressions = zip(*members)
        optimised = optimise_expressions(expressions, argument_indices)
        for pair in zip(variables, optimised):
            yield pair
def flatten(var_reps, index_cache):
    """Flatten the intermediate representation to a series of assignments,
    applying argument factorisation, delta cancellation and sum
    factorisation.

    :arg var_reps: series of (return variable, [integral representation])
        pairs; each representation provides ``argument_indices``,
        ``expression`` and ``quadrature_multiindex``
    :arg index_cache: cache passed on to ``unconcatenate``

    :returns: generator of (variable, GEM expression) pairs
    """
    # An OrderedDict with None values serves as an ordered set
    quadrature_indices = OrderedDict()

    pairs = []  # assignment pairs
    for variable, reps in var_reps:
        # All representations must agree on the argument indices
        argument_indices, = {r.argument_indices for r in reps}
        allowed = set(argument_indices)
        assert set(variable.free_indices) == allowed

        # Extract and verify expressions
        expressions = [r.expression for r in reps]
        assert all(set(e.free_indices) <= allowed for e in expressions)

        # Save assignment pair
        pairs.append((variable, reduce(Sum, expressions)))

        # Collect quadrature indices in order of first appearance
        for r in reps:
            for q in r.quadrature_multiindex:
                quadrature_indices[q] = None

    # Split Concatenate nodes
    pairs = unconcatenate(pairs, cache=index_cache)

    def group_key(pair):
        variable, _ = pair
        return frozenset(variable.free_indices)

    def ordered_sum_indices(monomial_sum):
        # Collect sum indices applicable to the given MonomialSum
        used = set()
        for m in monomial_sum:
            used.update(m.sum_indices)
        # Put them in a deterministic order
        ordered = [q for q in quadrature_indices if q in used]
        # Sort for increasing index extent, this obtains the good
        # factorisation for triangle x interval cells.  The sort is
        # stable, so in the common case when index extents are equal,
        # the previous deterministic ordering applies which is good
        # for getting smaller temporaries.
        ordered.sort(key=lambda index: index.extent)
        return ordered

    # Variable ordering after delta cancellation
    narrow_variables = OrderedDict()
    # Assignments are variable -> MonomialSum map
    delta_simplified = defaultdict(MonomialSum)

    # Group assignment pairs by argument indices
    for free_indices, pair_group in groupby(pairs, group_key):
        variables, expressions = zip(*pair_group)
        # Argument factorise expressions
        classifier = partial(classify, set(free_indices))
        monomial_sums = collect_monomials(expressions, classifier)
        # For each monomial, apply delta cancellation and insert
        # result into delta_simplified.
        for variable, monomial_sum in zip(variables, monomial_sums):
            for monomial in monomial_sum:
                var, s, a, r = delta_elimination(variable, *monomial)
                narrow_variables.setdefault(var)
                delta_simplified[var].add(s, a, r)

    # Final factorisation
    for variable in narrow_variables:
        monomial_sum = delta_simplified[variable]
        sum_indices = ordered_sum_indices(monomial_sum)
        # Apply sum factorisation combined with COFFEE technology
        expression = sum_factorise(variable, sum_indices, monomial_sum)
        yield (variable, expression)
def tensor_assembly_calls(builder):
    """Generates a block of statements for assembling the local
    finite element tensors.

    :arg builder: The :class:`LocalKernelBuilder` containing
        all relevant expression information and assembly calls.

    :returns: list of statements, starting with a header comment block and
        the cell integral calls, followed by the facet loop and mesh-layer
        conditionals when the builder requests them.
    """
    assembly_calls = builder.assembly_calls
    statements = [ast.FlatBlock("/* Assemble local tensors */\n")]

    # Cell integrals need no guard; emit them directly.
    statements.extend(assembly_calls["cell"])

    if builder.needs_cell_facets:
        # The generated loop has the general structure:
        #
        #    FOR (facet=0; facet<num_facets; facet++):
        #        IF (facet is exterior):
        #            *exterior calls
        #        IF (facet is interior):
        #            *interior calls
        #
        # A branch is only emitted when it has calls.  In column 0 of the
        # cell facet array, `0` marks an exterior facet and `1` an
        # interior one.
        statements.append(ast.FlatBlock("/* Loop over cell facets */\n"))

        def gather(*integral_types):
            return list(chain.from_iterable(
                [assembly_calls[it_type] for it_type in integral_types]))

        int_calls = gather("interior_facet", "interior_facet_vert")
        ext_calls = gather("exterior_facet", "exterior_facet_vert")

        # Generate logical statements for handling exterior/interior facet
        # integrals on subdomains.  Currently only facet integrals are
        # supported.  Each subdomain's calls are wrapped in an IF comparing
        # column 1 of the cell facet array against the subdomain.
        subdomain_targets = (("subdomains_exterior_facet", ext_calls),
                             ("subdomains_interior_facet", int_calls))
        for sd_type, target in subdomain_targets:
            guarded = []
            for sd, sd_calls in groupby(assembly_calls[sd_type],
                                        lambda x: x[0]):
                _, calls = zip(*sd_calls)
                facet_sd = ast.Symbol(builder.cell_facet_sym,
                                      rank=(builder.it_sym, 1))
                guarded.append(
                    ast.If(ast.Eq(facet_sd, sd),
                           (ast.Block(calls, open_scope=True), )))
            target.extend(guarded)

        # Compute the number of facets to loop over
        domain = builder.expression.ufl_domain()
        num_facets = (domain.ufl_cell()._cells[0].num_facets()
                      if domain.cell_set._extruded
                      else domain.ufl_cell().num_facets())

        def facet_kind_is(value):
            return ast.Eq(ast.Symbol(builder.cell_facet_sym,
                                     rank=(builder.it_sym, 0)),
                          value)

        branches = ((facet_kind_is(0), ext_calls),
                    (facet_kind_is(1), int_calls))
        body = [ast.If(condition, (ast.Block(calls, open_scope=True), ))
                for condition, calls in branches if calls]

        loop_init = ast.Decl("unsigned int", builder.it_sym, init=0)
        loop_cond = ast.Less(builder.it_sym, num_facets)
        loop_incr = ast.Incr(builder.it_sym, 1)
        statements.append(ast.For(loop_init, loop_cond, loop_incr, body))

    if builder.needs_mesh_layers:
        # One independent IF-block is emitted per horizontal facet type,
        # each guarded by a condition on the mesh layer:
        #
        #    IF (layer < num_layers):   *interior horizontal top calls
        #    IF (layer > 0):            *interior horizontal bottom calls
        #    IF (layer == num_layers):  *exterior top calls
        #    IF (layer == 0):           *exterior bottom calls
        #
        # FIXME: No variable layers assumption
        statements.append(ast.FlatBlock("/* Mesh levels: */\n"))

        num_layers = ast.Symbol(builder.mesh_layer_count_sym, rank=(0, ))
        layer = builder.mesh_layer_sym
        guards = (
            ("interior_facet_horiz_top", ast.Less(layer, num_layers)),
            ("interior_facet_horiz_bottom", ast.Greater(layer, 0)),
            ("exterior_facet_top", ast.Eq(layer, num_layers)),
            ("exterior_facet_bottom", ast.Eq(layer, 0)),
        )
        for integral_type, condition in guards:
            block = ast.Block(assembly_calls[integral_type], open_scope=True)
            statements.append(ast.If(condition, (block, )))

    return statements
def make_sum(summands):
    """Constructs an operation-minimal sum of GEM expressions.

    Summands with identical free indices are added together first; the
    resulting partial sums are then combined via ``associate``.
    """
    by_free_indices = groupby(summands, key=lambda f: f.free_indices)
    partial_sums = [reduce(Sum, members) for _, members in by_free_indices]
    # associate returns (expression, flop count); only the expression is used
    total, _ = associate(Sum, partial_sums)
    return total
def flatten(var_reps, index_cache):
    """Flatten the intermediate representation to a series of assignments,
    applying argument factorisation, delta cancellation and sum
    factorisation.

    :arg var_reps: series of (return variable, [integral representation])
        pairs; each representation provides ``argument_indices``,
        ``expression`` and ``quadrature_multiindex``
    :arg index_cache: cache passed on to ``unconcatenate``

    :returns: generator of (variable, GEM expression) pairs
    """
    # An OrderedDict with None values serves as an ordered set
    quadrature_indices = OrderedDict()

    pairs = []  # assignment pairs
    for variable, reps in var_reps:
        # All representations must agree on the argument indices
        argument_indices, = {r.argument_indices for r in reps}
        allowed = set(argument_indices)
        assert set(variable.free_indices) == allowed

        # Extract and verify expressions
        expressions = [r.expression for r in reps]
        assert all(set(e.free_indices) <= allowed for e in expressions)

        # Save assignment pair
        pairs.append((variable, reduce(Sum, expressions)))

        # Collect quadrature indices in order of first appearance
        for r in reps:
            for q in r.quadrature_multiindex:
                quadrature_indices[q] = None

    # Split Concatenate nodes
    pairs = unconcatenate(pairs, cache=index_cache)

    def group_key(pair):
        variable, _ = pair
        return frozenset(variable.free_indices)

    def ordered_sum_indices(monomial_sum):
        # Collect sum indices applicable to the given MonomialSum
        used = set()
        for m in monomial_sum:
            used.update(m.sum_indices)
        # Put them in a deterministic order
        ordered = [q for q in quadrature_indices if q in used]
        # Sort for increasing index extent, this obtains the good
        # factorisation for triangle x interval cells.  The sort is
        # stable, so in the common case when index extents are equal,
        # the previous deterministic ordering applies which is good
        # for getting smaller temporaries.
        ordered.sort(key=lambda index: index.extent)
        return ordered

    # A single memoised delta_inside is shared by every classifier below
    delta_inside = Memoizer(_delta_inside)

    # Variable ordering after delta cancellation
    narrow_variables = OrderedDict()
    # Assignments are variable -> MonomialSum map
    delta_simplified = defaultdict(MonomialSum)

    # Group assignment pairs by argument indices
    for free_indices, pair_group in groupby(pairs, group_key):
        variables, expressions = zip(*pair_group)
        # Argument factorise expressions
        classifier = partial(classify, set(free_indices),
                             delta_inside=delta_inside)
        monomial_sums = collect_monomials(expressions, classifier)
        # For each monomial, apply delta cancellation and insert
        # result into delta_simplified.
        for variable, monomial_sum in zip(variables, monomial_sums):
            for monomial in monomial_sum:
                var, s, a, r = delta_elimination(variable, *monomial)
                narrow_variables.setdefault(var)
                delta_simplified[var].add(s, a, r)

    # Final factorisation
    for variable in narrow_variables:
        monomial_sum = delta_simplified[variable]
        sum_indices = ordered_sum_indices(monomial_sum)
        # Apply sum factorisation combined with COFFEE technology
        expression = sum_factorise(variable, sum_indices, monomial_sum)
        yield (variable, expression)