def inline_temporaries(expressions, ops):
    """Inline temporaries which could be inlined without blowing up the code.

    A scalar :class:`impero.Evaluate` whose expression is referenced
    exactly once is a candidate for inlining, unless inlining it would
    pull it into a context with strictly more free indices.

    :arg expressions: a multi-root GEM expression DAG, used for
                      reference counting
    :arg ops: ordered list of Impero terminals
    :returns: a filtered ``ops``, without the unnecessary
              :class:`impero.Evaluate`s
    """
    refcount = collect_refcount(expressions)

    # Scalar, singly-referenced evaluations may be inlined
    candidates = {op.expression
                  for op in ops
                  if isinstance(op, imp.Evaluate)
                  and op.expression.shape == ()
                  and refcount[op.expression] == 1}

    # Keep a candidate as a temporary when its parent has strictly more
    # free indices: inlining would move it into inner loops.
    for node in traversal(expressions):
        for child in node.children:
            if child not in candidates:
                continue
            if set(child.free_indices) < set(node.free_indices):
                candidates.remove(child)

    # Drop the Evaluate ops whose expressions will be inlined
    return [op for op in ops
            if not isinstance(op, imp.Evaluate) or op.expression not in candidates]
def needs_cell_orientations(ir):
    """Does a multi-root GEM expression DAG reference cell orientations?

    :arg ir: multi-root GEM expression DAG
    :returns: ``True`` if some :class:`gem.Variable` in ``ir`` is named
              ``"cell_orientations"``, otherwise ``False``
    """
    return any(isinstance(node, gem.Variable) and node.name == "cell_orientations"
               for node in traversal(ir))
def translate_coefficient(terminal, mt, ctx):
    """Translate a (modified) coefficient terminal into a GEM expression.

    :arg terminal: the coefficient terminal; its UFL element selects the
                   FInAT element and the cache key for basis indices
    :arg mt: modified-terminal data; ``mt.restriction``,
             ``mt.local_derivatives`` and ``mt.expr`` are read here
    :arg ctx: translation context supplying the coefficient variable,
              elements, tabulations, entity ids/number, index cache,
              point indices and rounding epsilon
    :returns: GEM expression with shape ``mt.expr.ufl_shape`` whose free
              indices are a subset of ``ctx.point_indices``
    """
    vec = ctx.coefficient(terminal, mt.restriction)

    # 'Real' coefficients are returned as-is; no derivatives are allowed.
    if terminal.ufl_element().family() == 'Real':
        assert mt.local_derivatives == 0
        return vec

    element = ctx.create_element(terminal.ufl_element())

    # Collect FInAT tabulation for all entities
    per_derivative = collections.defaultdict(list)
    for entity_id in ctx.entity_ids:
        finat_dict = ctx.basis_evaluation(element, mt.local_derivatives, entity_id)
        for alpha, table in finat_dict.items():
            # Filter out irrelevant derivatives
            if sum(alpha) == mt.local_derivatives:
                # A numerical hack that FFC used to apply on FIAT
                # tables still lives on after ditching FFC and
                # switching to FInAT.
                table = ffc_rounding(table, ctx.epsilon)
                per_derivative[alpha].append(table)

    # Merge entity tabulations for each derivative
    if len(ctx.entity_ids) == 1:
        def take_singleton(xs):
            x, = xs  # asserts singleton
            return x
        per_derivative = {alpha: take_singleton(tables)
                          for alpha, tables in per_derivative.items()}
    else:
        # Several entities: select the table at run time by entity number
        f = ctx.entity_number(mt.restriction)
        per_derivative = {alpha: gem.select_expression(tables, f)
                          for alpha, tables in per_derivative.items()}

    # Coefficient evaluation
    # NOTE(review): element.get_indices() is evaluated even when the cache
    # already holds an entry for this element; only the first result is kept.
    ctx.index_cache.setdefault(terminal.ufl_element(), element.get_indices())
    beta = ctx.index_cache[terminal.ufl_element()]
    zeta = element.get_value_indices()
    vec_beta, = gem.optimise.remove_componenttensors([gem.Indexed(vec, beta)])
    value_dict = {}
    for alpha, table in per_derivative.items():
        table_qi = gem.Indexed(table, beta + zeta)
        summands = []
        # Contract basis-function coefficients with the tabulation; one
        # summand per piece returned by unconcatenate.
        for var, expr in unconcatenate([(vec_beta, table_qi)], ctx.index_cache):
            value = gem.IndexSum(gem.Product(expr, var), var.index_ordering())
            summands.append(gem.optimise.contraction(value))
        optimised_value = gem.optimise.make_sum(summands)
        value_dict[alpha] = gem.ComponentTensor(optimised_value, zeta)

    # Change from FIAT to UFL arrangement
    result = fiat_to_ufl(value_dict, mt.local_derivatives)
    assert result.shape == mt.expr.ufl_shape
    assert set(result.free_indices) <= set(ctx.point_indices)

    # Detect Jacobian of affine cells
    # (no free indices and every Literal has at most two nonzero entries:
    # the expression is unrolled aggressively)
    if not result.free_indices and all(numpy.count_nonzero(node.array) <= 2
                                       for node in traversal((result,))
                                       if isinstance(node, gem.Literal)):
        result = gem.optimise.aggressive_unroll(result)
    return result
def optimise_expressions(expressions, argument_indices):
    """Perform loop optimisations on GEM DAGs

    :arg expressions: list of GEM DAGs
    :arg argument_indices: tuple of argument indices

    :returns: list of optimised GEM DAGs (``expressions`` unchanged if
              any :class:`Failure` node is present)
    """
    # Skip optimisation if a Failure node is present
    if any(isinstance(n, Failure) for n in traversal(expressions)):
        return expressions

    def classify(argument_indices, expression):
        # Count the argument indices free in ``expression``:
        # none -> OTHER; exactly one on an Indexed -> ATOMIC;
        # anything else -> COMPOUND.
        n = len(argument_indices.intersection(expression.free_indices))
        if n == 0:
            return OTHER
        if n == 1 and isinstance(expression, Indexed):
            return ATOMIC
        return COMPOUND

    # Apply argument factorisation unconditionally
    classifier = partial(classify, set(argument_indices))
    return [optimise_monomial_sum(monomial_sum, argument_indices)
            for monomial_sum in collect_monomials(expressions, classifier)]
def compile_gem(return_variables, expressions, prefix_ordering, remove_zeros=False):
    """Compiles GEM to Impero.

    :arg return_variables: return variables for each root (type: GEM expressions)
    :arg expressions: multi-root expression DAG (type: GEM expressions)
    :arg prefix_ordering: outermost loop indices
    :arg remove_zeros: remove zero assignment to return variables
    :returns: :class:`ImperoC` built from the Impero AST, the
              temporaries, their declaration placement and indices
    :raises NoopError: if scheduling emits no operations (empty kernel)
    """
    expressions = optimise.remove_componenttensors(expressions)

    # Remove zeros
    if remove_zeros:
        rv = []
        es = []
        for var, expr in zip(return_variables, expressions):
            if not isinstance(expr, gem.Zero):
                rv.append(var)
                es.append(expr)
        return_variables, expressions = rv, es

    # Collect indices in a deterministic order
    indices = OrderedSet()
    for node in traversal(expressions):
        if isinstance(node, gem.Indexed):
            for index in node.multiindex:
                if isinstance(index, gem.Index):
                    indices.add(index)
        elif isinstance(node, gem.FlexiblyIndexed):
            for offset, idxs in node.dim2idxs:
                for index, stride in idxs:
                    if isinstance(index, gem.Index):
                        indices.add(index)

    # Build ordered index map
    index_ordering = make_prefix_ordering(indices, prefix_ordering)
    apply_ordering = make_index_orderer(index_ordering)

    get_indices = lambda expr: apply_ordering(expr.free_indices)

    # Build operation ordering
    ops = scheduling.emit_operations(list(zip(return_variables, expressions)), get_indices)

    # Empty kernel
    if len(ops) == 0:
        raise NoopError()

    # Drop unnecessary temporaries
    ops = inline_temporaries(expressions, ops)

    # Build Impero AST
    tree = make_loop_tree(ops, get_indices)

    # Collect temporaries.  collect_temporaries walks an Impero *tree*
    # (not the flat ops list), so pass the AST built above.
    temporaries = collect_temporaries(tree)

    # Determine declarations.  place_declarations takes
    # (tree, temporaries, get_indices); it no longer receives ``ops``.
    declare, indices = place_declarations(tree, temporaries, get_indices)

    # Prepare ImperoC (Impero AST + other data for code generation)
    return ImperoC(tree, temporaries, declare, indices)
def compile_gem(assignments, prefix_ordering, remove_zeros=False):
    """Compile a list of GEM assignments into Impero.

    :arg assignments: list of (return variable, expression DAG root) pairs
    :arg prefix_ordering: outermost loop indices
    :arg remove_zeros: drop assignments whose expression is a
                       :class:`gem.Zero`
    :returns: :class:`ImperoC` built from the Impero AST, the
              temporaries, their declaration placement and indices
    :raises NoopError: if scheduling emits no operations (empty kernel)
    """
    if remove_zeros:
        def keep(assignment):
            _, expression = assignment
            return not isinstance(expression, gem.Zero)
        assignments = [assignment for assignment in assignments if keep(assignment)]

    # The expression roots alone
    expressions = [expression for _, expression in assignments]

    # Gather every gem.Index used in an indexing node, first-seen order
    indices = OrderedSet()
    for node in traversal(expressions):
        if isinstance(node, gem.Indexed):
            referenced = node.multiindex
        elif isinstance(node, gem.FlexiblyIndexed):
            referenced = [index
                          for _, idxs in node.dim2idxs
                          for index, _ in idxs]
        else:
            continue
        for index in referenced:
            if isinstance(index, gem.Index):
                indices.add(index)

    # Every node's free indices, arranged in loop-nest order
    apply_ordering = make_index_orderer(make_prefix_ordering(indices, prefix_ordering))

    def get_indices(expr):
        return apply_ordering(expr.free_indices)

    ops = scheduling.emit_operations(assignments, get_indices)
    if not ops:
        # Nothing to compute: empty kernel
        raise NoopError()

    # Inline cheap temporaries, then build the loop tree
    ops = inline_temporaries(expressions, ops)
    tree = make_loop_tree(ops, get_indices)

    temporaries = collect_temporaries(tree)
    declare, temp_indices = place_declarations(tree, temporaries, get_indices)

    # Impero AST plus what code generation needs
    return ImperoC(tree, temporaries, declare, temp_indices)
def compile_gem(return_variables, expressions, prefix_ordering, remove_zeros=False):
    """Compile GEM expression roots, paired with return variables, to Impero.

    :arg return_variables: return variables for each root (type: GEM expressions)
    :arg expressions: multi-root expression DAG (type: GEM expressions)
    :arg prefix_ordering: outermost loop indices
    :arg remove_zeros: drop roots that are :class:`gem.Zero` together
                       with their return variables
    :returns: :class:`ImperoC` built from the Impero AST, the
              temporaries, their declaration placement and indices
    :raises NoopError: if scheduling emits no operations (empty kernel)
    """
    if remove_zeros:
        kept = [(var, expr)
                for var, expr in zip(return_variables, expressions)
                if not isinstance(expr, gem.Zero)]
        return_variables = [var for var, _ in kept]
        expressions = [expr for _, expr in kept]

    def referenced_indices(node):
        # Yield the index objects an indexing node refers to
        if isinstance(node, gem.Indexed):
            for index in node.multiindex:
                yield index
        elif isinstance(node, gem.FlexiblyIndexed):
            for _, idxs in node.dim2idxs:
                for index, _ in idxs:
                    yield index

    # First-seen order keeps the result deterministic
    indices = OrderedSet()
    for node in traversal(expressions):
        for index in referenced_indices(node):
            if isinstance(index, gem.Index):
                indices.add(index)

    ordering = make_prefix_ordering(indices, prefix_ordering)
    apply_ordering = make_index_orderer(ordering)

    def get_indices(expr):
        return apply_ordering(expr.free_indices)

    assignments = list(zip(return_variables, expressions))
    ops = scheduling.emit_operations(assignments, get_indices)
    if not ops:
        # Empty kernel
        raise NoopError()

    # Inline cheap temporaries, build the loop tree, place declarations
    ops = inline_temporaries(expressions, ops)
    tree = make_loop_tree(ops, get_indices)
    temporaries = collect_temporaries(tree)
    declare, temp_indices = place_declarations(tree, temporaries, get_indices)

    return ImperoC(tree, temporaries, declare, temp_indices)
def classify(atomics_set, expression):
    """Classify ``expression`` relative to a set of atomic nodes.

    :arg atomics_set: set of GEM nodes regarded as atomic
    :arg expression: GEM node to classify
    :returns: ``ATOMIC`` if ``expression`` itself is in ``atomics_set``;
              ``COMPOUND`` if some node reached by traversing it is;
              ``OTHER`` otherwise
    """
    if expression in atomics_set:
        return ATOMIC
    contains_atomic = any(node in atomics_set for node in traversal([expression]))
    return COMPOUND if contains_atomic else OTHER
def collect_temporaries(tree):
    """Collect the GEM expressions to assign to temporaries in an Impero AST.

    Every node reachable from ``tree`` is visited in traversal order. The
    index sum of each :class:`imp.Accumulate` and the expression of each
    :class:`imp.Evaluate` are recorded.

    :arg tree: root of an Impero AST
    :returns: list of GEM expressions, in traversal order
    """
    skip = object()

    def temporary_of(node):
        # IndexSum temporaries could be registered at either Initialise
        # or Accumulate; only the ordering (numbering) differs. We use
        # Accumulate.
        if isinstance(node, imp.Accumulate):
            return node.indexsum
        if isinstance(node, imp.Evaluate):
            return node.expression
        return skip

    return [temporary
            for temporary in map(temporary_of, traversal((tree,)))
            if temporary is not skip]
def compile_gem(assignments, prefix_ordering, remove_zeros=False):
    """Compile a list of GEM assignments into Impero.

    :arg assignments: list of (return variable, expression DAG root) pairs
    :arg prefix_ordering: outermost loop indices
    :arg remove_zeros: drop assignments whose expression is a
                       :class:`gem.Zero`
    :returns: :class:`ImperoC` built from the Impero AST, the
              temporaries, their declaration placement and indices
    :raises NoopError: if scheduling emits no operations (empty kernel)
    """
    if remove_zeros:
        def is_nonzero(pair):
            _, expression = pair
            return not isinstance(expression, gem.Zero)
        assignments = [pair for pair in assignments if is_nonzero(pair)]

    expressions = [expression for _, expression in assignments]

    # Deduplicate indices while keeping first-seen order (deterministic)
    seen = collections.OrderedDict()
    for node in traversal(expressions):
        if isinstance(node, (gem.Indexed, gem.FlexiblyIndexed)):
            for index in node.index_ordering():
                seen[index] = None
    indices = list(seen)

    apply_ordering = make_index_orderer(make_prefix_ordering(indices, prefix_ordering))

    def get_indices(expr):
        return apply_ordering(expr.free_indices)

    ops = scheduling.emit_operations(assignments, get_indices)
    if not ops:
        # Empty kernel
        raise NoopError()

    # Inline cheap temporaries, build the loop tree, place declarations
    ops = inline_temporaries(expressions, ops)
    tree = make_loop_tree(ops, get_indices)
    temporaries = collect_temporaries(tree)
    declare, temp_indices = place_declarations(tree, temporaries, get_indices)

    return ImperoC(tree, temporaries, declare, temp_indices)
def check_requirements(ir):
    """Scan a GEM DAG once for cell orientations, cell sizes and
    ``rt_``-prefixed variables.

    :arg ir: multi-root GEM expression DAG
    :returns: ``(cell_orientations, cell_sizes, rt_tabs)``. The first two
              are booleans telling whether a :class:`gem.Variable` with that
              name occurs. ``rt_tabs`` is a sorted tuple of ``(name, shape)``
              pairs for variables whose name starts with ``"rt_"``.
    """
    cell_orientations = cell_sizes = False
    rt_tabs = {}
    for node in traversal(ir):
        if not isinstance(node, gem.Variable):
            continue
        name = node.name
        if name == "cell_orientations":
            cell_orientations = True
        elif name == "cell_sizes":
            cell_sizes = True
        elif name.startswith("rt_"):
            rt_tabs[name] = node.shape
    return cell_orientations, cell_sizes, tuple(sorted(rt_tabs.items()))
def contraction(expression): """Optimise the contractions of the tensor product at the root of the expression, including: - IndexSum-Delta cancellation - Sum factorisation This routine was designed with finite element coefficient evaluation in mind. """ # Eliminate annoying ComponentTensors expression, = remove_componenttensors([expression]) # Flatten product tree, eliminate deltas, sum factorise def rebuild(expression): sum_indices, factors = delta_elimination(*traverse_product(expression)) factors = remove_componenttensors(factors) return sum_factorise(sum_indices, factors) # Sometimes the value shape is composed as a ListTensor, which # could get in the way of decomposing factors. In particular, # this is the case for H(div) and H(curl) conforming tensor # product elements. So if ListTensors are used, they are pulled # out to be outermost, so we can straightforwardly factorise each # of its entries. lt_fis = OrderedDict() # ListTensor free indices for node in traversal((expression, )): if isinstance(node, Indexed): child, = node.children if isinstance(child, ListTensor): lt_fis.update(zip_longest(node.multiindex, ())) lt_fis = tuple(index for index in lt_fis if index in expression.free_indices) if lt_fis: # Rebuild each split component tensor = ComponentTensor(expression, lt_fis) entries = [ Indexed(tensor, zeta) for zeta in numpy.ndindex(tensor.shape) ] entries = remove_componenttensors(entries) return Indexed( ListTensor( numpy.array(list(map(rebuild, entries))).reshape(tensor.shape)), lt_fis) else: # Rebuild whole expression at once return rebuild(expression)
def optimise_expressions(expressions, argument_indices):
    """Perform loop optimisations on GEM DAGs

    :arg expressions: list of GEM DAGs
    :arg argument_indices: tuple of argument indices

    :returns: list of optimised GEM DAGs (``expressions`` unchanged if
              any :class:`Failure` node is present)
    """
    # Skip optimisation if a Failure node is present
    if any(isinstance(n, Failure) for n in traversal(expressions)):
        return expressions

    # Apply argument factorisation unconditionally
    classifier = partial(spectral.classify, set(argument_indices))
    return [optimise_monomial_sum(monomial_sum, argument_indices)
            for monomial_sum in collect_monomials(expressions, classifier)]
def optimise_expressions(expressions, argument_indices):
    """Perform loop optimisations on GEM DAGs

    :arg expressions: list of GEM DAGs
    :arg argument_indices: tuple of argument indices

    :returns: list of optimised GEM DAGs (``expressions`` unchanged if
              any :class:`Failure` node is present)
    """
    has_failure = any(isinstance(node, Failure) for node in traversal(expressions))
    if has_failure:
        # Skip optimisation if a Failure node is present
        return expressions

    # Apply argument factorisation unconditionally
    argument_set = set(argument_indices)
    monomial_sums = collect_monomials(expressions, partial(spectral.classify, argument_set))
    return [optimise_monomial_sum(monomial_sum, argument_indices)
            for monomial_sum in monomial_sums]
def collect_monomials(expressions, classifier):
    """Refactorises expressions into a sum-of-products form, using
    distributivity rules (i.e. a*(b + c) -> a*b + a*c).  Expansion
    proceeds until all "compound" expressions are broken up.

    :arg expressions: GEM expressions to refactorise
    :arg classifier: a function that can classify any GEM expression
                     as ``ATOMIC``, ``COMPOUND``, or ``OTHER``.  This
                     classification drives the factorisation.

    :returns: list of :py:class:`MonomialSum`s

    :raises FactorisationError: Failed to break up some "compound"
                                expressions with expansion.
    """
    # Get ComponentTensors out of the way
    expressions = remove_componenttensors(expressions)

    # Get ListTensors out of the way: indices of COMPOUND-classified
    # Indexed(ListTensor) nodes are unrolled.
    unroll_set = set()
    for node in traversal(expressions):
        if not isinstance(node, Indexed):
            continue
        child, = node.children
        if isinstance(child, ListTensor) and classifier(node) == COMPOUND:
            unroll_set.update(node.multiindex)
    if unroll_set:
        expressions = unroll_indexsum(expressions,
                                      predicate=lambda i: i in unroll_set)
        expressions = remove_componenttensors(expressions)

    # Expand Conditional nodes which are COMPOUND
    def is_compound(node):
        return classifier(node) == COMPOUND
    expressions = expand_conditional(expressions, is_compound)

    # Finally, refactorise expressions
    mapper = Memoizer(_collect_monomials)
    mapper.classifier = classifier
    mapper.rename_map = make_rename_map()
    return [mapper(expression) for expression in expressions]
def place_declarations(tree, temporaries, get_indices):
    """Determines where and how to declare temporaries for an Impero AST.

    :arg tree: Impero AST to determine the declarations for
    :arg temporaries: list of GEM expressions which are assigned to
                      temporaries
    :arg get_indices: callable mapping from GEM nodes to an ordering
                      of free indices
    :returns: a pair ``(declare, indices)``.
              ``declare`` maps each :class:`imp.Block` to the list of
              temporaries declared at its start. It also maps each
              :class:`imp.Terminal` to a bool saying whether that
              terminal declares the temporary it writes (only when the
              temporary has no remaining indices).
              ``indices`` maps each temporary to its free indices
              beyond the loops enclosing the declaring block.
    :raises AssertionError: on an Impero node type with no handler
    """
    numbering = {t: n for n, t in enumerate(temporaries)}
    assert len(numbering) == len(temporaries)

    # Collect the total number of temporary references
    total_refcount = collections.Counter()
    for node in traversal((tree,)):
        if isinstance(node, imp.Terminal):
            total_refcount.update(temp_refcount(numbering, node))
    assert set(total_refcount) == set(temporaries)

    # Result
    declare = {}
    indices = {}

    @singledispatch
    def recurse(expr, loop_indices):
        """Visit an Impero AST to collect declarations.

        :arg expr: Impero tree node
        :arg loop_indices: loop indices (in order) from the outer
                           loops surrounding ``expr``
        :returns: :class:`collections.Counter` with the reference
                  counts for each temporary in the subtree whose root
                  is ``expr``
        """
        # Raise (not return) so an unsupported node fails loudly here
        # instead of as a confusing error in Counter.update downstream.
        raise AssertionError("unsupported expression type %s" % type(expr))

    @recurse.register(imp.Terminal)
    def recurse_terminal(expr, loop_indices):
        return temp_refcount(numbering, expr)

    @recurse.register(imp.For)
    def recurse_for(expr, loop_indices):
        return recurse(expr.children[0], loop_indices + (expr.index,))

    @recurse.register(imp.Block)
    def recurse_block(expr, loop_indices):
        # Temporaries declared at the beginning of the block are
        # collected here
        declare[expr] = []

        # Collect reference counts for the block
        refcount = collections.Counter()
        for statement in expr.children:
            refcount.update(recurse(statement, loop_indices))

        # Visit :class:`collections.Counter` in deterministic order
        for e in sorted(refcount.keys(), key=lambda t: numbering[t]):
            if refcount[e] == total_refcount[e]:
                # If all references are within this block, then this
                # block is the right place to declare the temporary.
                assert loop_indices == get_indices(e)[:len(loop_indices)]
                indices[e] = get_indices(e)[len(loop_indices):]
                if indices[e]:
                    # Scalar-valued temporaries are not declared until
                    # their value is assigned.  This does not really
                    # matter, but produces a more compact and nicer to
                    # read C code.
                    declare[expr].append(e)
                # Remove expression from the ``refcount`` so it will
                # not be declared again.
                del refcount[e]
        return refcount

    # Populate result
    remainder = recurse(tree, ())
    assert not remainder

    # Set in ``declare`` for Impero terminals whether they should
    # declare the temporary that they are writing to.
    for node in traversal((tree,)):
        if isinstance(node, imp.Terminal):
            declare[node] = False
            if isinstance(node, imp.Evaluate):
                e = node.expression
            elif isinstance(node, imp.Initialise):
                e = node.indexsum
            else:
                continue

            if len(indices[e]) == 0:
                declare[node] = True

    return declare, indices