def prune(factors):
    """Prune all factors except the trailing ``rest`` term.

    The last factor (``rest``, see above) can be arbitrarily
    complicated, so pruning it may be expensive, and pruning it early
    brings no advantage; it is therefore passed through untouched.
    """
    return [*remove_componenttensors(factors[:-1]), factors[-1]]
def compile_gem(return_variables, expressions, prefix_ordering, remove_zeros=False):
    """Compiles GEM to Impero.

    :arg return_variables: return variables for each root (type: GEM expressions)
    :arg expressions: multi-root expression DAG (type: GEM expressions)
    :arg prefix_ordering: outermost loop indices
    :arg remove_zeros: remove zero assignment to return variables
    :raises NoopError: if the scheduled operation list is empty
    :returns: ImperoC named tuple (Impero AST plus code generation data)
    """
    expressions = optimise.remove_componenttensors(expressions)

    # Remove zeros: drop (return variable, expression) pairs whose
    # expression is identically zero, so no assignment is emitted.
    if remove_zeros:
        rv = []
        es = []
        for var, expr in zip(return_variables, expressions):
            if not isinstance(expr, gem.Zero):
                rv.append(var)
                es.append(expr)
        return_variables, expressions = rv, es

    # Collect indices in a deterministic order (OrderedSet preserves
    # first-seen order across the DAG traversal).
    indices = OrderedSet()
    for node in traversal(expressions):
        if isinstance(node, gem.Indexed):
            for index in node.multiindex:
                if isinstance(index, gem.Index):
                    indices.add(index)
        elif isinstance(node, gem.FlexiblyIndexed):
            for offset, idxs in node.dim2idxs:
                for index, stride in idxs:
                    if isinstance(index, gem.Index):
                        indices.add(index)

    # Build ordered index map: prefix_ordering indices come first,
    # remaining indices keep their collection order.
    index_ordering = make_prefix_ordering(indices, prefix_ordering)
    apply_ordering = make_index_orderer(index_ordering)

    get_indices = lambda expr: apply_ordering(expr.free_indices)

    # Build operation ordering
    ops = scheduling.emit_operations(list(zip(return_variables, expressions)), get_indices)

    # Empty kernel
    if len(ops) == 0:
        raise NoopError()

    # Drop unnecessary temporaries
    ops = inline_temporaries(expressions, ops)

    # Build Impero AST
    tree = make_loop_tree(ops, get_indices)

    # Collect temporaries
    temporaries = collect_temporaries(ops)

    # Determine declarations; note ``indices`` is rebound here to the
    # declaration index data, shadowing the collection set above.
    declare, indices = place_declarations(ops, tree, temporaries, get_indices)

    # Prepare ImperoC (Impero AST + other data for code generation)
    return ImperoC(tree, temporaries, declare, indices)
def Integrals(expressions, quadrature_multiindex, argument_multiindices, parameters):
    """Constructs an integral representation for each GEM integrand
    expression.

    :arg expressions: integrand multiplied with quadrature weight;
                      multi-root GEM expression DAG
    :arg quadrature_multiindex: quadrature multiindex (tuple)
    :arg argument_multiindices: tuple of argument multiindices,
                                one multiindex for each argument
    :arg parameters: parameters dictionary

    :returns: list of integral representations
    """
    # Unroll index sums whose extent does not exceed the configured bound.
    max_extent = parameters["unroll_indexsum"]
    if max_extent:
        expressions = unroll_indexsum(
            expressions, predicate=lambda index: index.extent <= max_extent)

    # Choose GEM expression as the integral representation: sum each
    # root over the quadrature multiindex, then lower/normalise.
    expressions = [index_sum(expression, quadrature_multiindex)
                   for expression in expressions]
    expressions = replace_delta(expressions)
    expressions = remove_componenttensors(expressions)
    expressions = replace_division(expressions)

    argument_indices = tuple(itertools.chain(*argument_multiindices))
    return optimise_expressions(expressions, argument_indices)
def split_variable(variable_ref, index, multiindices):
    """Splits a flexibly indexed variable along a concatenation index.

    :param variable_ref: flexibly indexed variable to split
    :param index: :py:class:`Concatenate` index to split along
    :param multiindices: one multiindex for each split variable

    :returns: generator of split indexed variables
    """
    assert isinstance(variable_ref, FlexiblyIndexed)
    # All free indices of the reference except the concatenation index.
    remaining = list(variable_ref.index_ordering())
    remaining.remove(index)
    remaining = tuple(remaining)

    # Wrap as a tensor with the concatenation index outermost, so that
    # slicing along axis 0 selects one concatenated child at a time.
    data = ComponentTensor(variable_ref, (index,) + remaining)
    full_slices = [slice(None)] * len(remaining)
    remaining_shapes = [(other.extent,) for other in remaining]

    offset = 0
    for multiindex in multiindices:
        shape = tuple(idx.extent for idx in multiindex)
        size = numpy.prod(shape, dtype=int)
        window = slice(offset, offset + size)
        offset += size
        viewed = view(data, window, *full_slices)
        sub_ref = Indexed(reshape(viewed, shape, *remaining_shapes),
                          multiindex + remaining)
        sub_ref, = remove_componenttensors((sub_ref,))
        yield sub_ref
def split_variable(variable_ref, index, multiindices):
    """Splits a flexibly indexed variable along a concatenation index.

    :param variable_ref: flexibly indexed variable to split
    :param index: :py:class:`Concatenate` index to split along
    :param multiindices: one multiindex for each split variable

    :returns: generator of split indexed variables
    """
    assert isinstance(variable_ref, FlexiblyIndexed)
    others = list(variable_ref.index_ordering())
    others.remove(index)
    others = tuple(others)

    # The concatenation index becomes the leading axis; successive
    # windows of that axis correspond to successive split variables.
    data = ComponentTensor(variable_ref, (index,) + others)
    passthrough = [slice(None)] * len(others)
    other_shapes = [(other.extent,) for other in others]

    offset = 0
    for multiindex in multiindices:
        shape = tuple(idx.extent for idx in multiindex)
        size = numpy.prod(shape, dtype=int)
        window = slice(offset, offset + size)
        offset += size
        sub_ref, = remove_componenttensors(
            (Indexed(reshape(view(data, window, *passthrough),
                             shape, *other_shapes),
                     multiindex + others),))
        yield sub_ref
def preprocess_gem(expressions, replace_delta=True, remove_componenttensors=True):
    """Lower GEM nodes that cannot be translated to C directly.

    :arg expressions: GEM expressions to lower
    :arg replace_delta: lower Delta nodes when true
    :arg remove_componenttensors: lower ComponentTensor nodes when true
    :returns: lowered GEM expressions
    """
    result = expressions
    if remove_componenttensors:
        result = optimise.remove_componenttensors(result)
    if replace_delta:
        result = optimise.replace_delta(result)
    return result
def collect_monomials(expressions, classifier): """Refactorises expressions into a sum-of-products form, using distributivity rules (i.e. a*(b + c) -> a*b + a*c). Expansion proceeds until all "compound" expressions are broken up. :arg expressions: GEM expressions to refactorise :arg classifier: a function that can classify any GEM expression as ``ATOMIC``, ``COMPOUND``, or ``OTHER``. This classification drives the factorisation. :returns: list of :py:class:`MonomialSum`s :raises FactorisationError: Failed to break up some "compound" expressions with expansion. """ # Get ComponentTensors out of the way expressions = remove_componenttensors(expressions) # Get ListTensors out of the way must_unroll = [] # indices to unroll for node in traversal(expressions): if isinstance(node, Indexed): child, = node.children if isinstance(child, ListTensor) and classifier(node) == COMPOUND: must_unroll.extend(node.multiindex) if must_unroll: must_unroll = set(must_unroll) expressions = unroll_indexsum(expressions, predicate=lambda i: i in must_unroll) expressions = remove_componenttensors(expressions) # Expand Conditional nodes which are COMPOUND conditional_predicate = lambda node: classifier(node) == COMPOUND expressions = expand_conditional(expressions, conditional_predicate) # Finally, refactorise expressions mapper = Memoizer(_collect_monomials) mapper.classifier = classifier mapper.rename_map = make_rename_map() return list(map(mapper, expressions))
def test_delta_elimination():
    """Delta factors over summed indices collapse to ``one``."""
    i, j, k = Index(), Index(), Index()
    I = Identity(3)
    factors = [Delta(i, j), Delta(i, k), Indexed(I, (j, k))]

    sum_indices, factors = delta_elimination((i, j), factors)
    factors = remove_componenttensors(factors)
    assert sum_indices == []
    assert factors == [one, one, Indexed(I, (k, k))]
def _unconcatenate(cache, pairs):
    # Tail-call recursive core of unconcatenate.
    # Assumes that input has already been sanitised.
    concat_group = find_group([e for v, e in pairs])
    if concat_group is None:
        # No Concatenate references left: splitting is complete.
        return pairs

    # Get the index split: all references in the group share the same
    # concatenation index, so inspecting any one of them suffices.
    concat_ref = next(iter(concat_group))
    assert isinstance(concat_ref, Indexed)
    concat_expr, = concat_ref.children
    index, = concat_ref.multiindex
    assert isinstance(concat_expr, Concatenate)
    try:
        multiindices = cache[index]
    except KeyError:
        # One fresh multiindex per Concatenate child, cached so repeated
        # splits of the same index reuse the same Index objects.
        multiindices = tuple(
            tuple(Index(extent=d) for d in child.shape)
            for child in concat_expr.children)
        cache[index] = multiindices

    def cut(node):
        """No need to rebuild expressions independent of the relevant
        concatenation index."""
        return index not in node.free_indices

    # Build Concatenate node replacement mappings: one mapping per
    # Concatenate child, from each group reference to its split form.
    mappings = [{} for i in range(len(multiindices))]
    for concat_ref in concat_group:
        concat_expr, = concat_ref.children
        for i in range(len(multiindices)):
            sub_ref = Indexed(concat_expr.children[i], multiindices[i])
            sub_ref, = remove_componenttensors((sub_ref, ))
            mappings[i][concat_ref] = sub_ref

    # Finally, split assignment pairs: pairs not involving the
    # concatenation index pass through unchanged.
    split_pairs = []
    for var, expr in pairs:
        if index not in var.free_indices:
            split_pairs.append((var, expr))
        else:
            for v, m in zip(split_variable(var, index, multiindices), mappings):
                split_pairs.append((v, replace_node(expr, m, cut)))

    # Run again, there may be other Concatenate groups
    return _unconcatenate(cache, split_pairs)
def Integrals(expressions, quadrature_multiindex, argument_multiindices, parameters):
    """Constructs an integral representation for each GEM integrand
    expression, collapsing quadrature-dependent literal factors of
    each monomial via ``einsum``.

    :arg expressions: integrand multiplied with quadrature weight;
                      multi-root GEM expression DAG
    :arg quadrature_multiindex: quadrature multiindex (tuple)
    :arg argument_multiindices: tuple of argument multiindices
    :arg parameters: parameters dictionary
    :returns: list of integral representations
    """
    # Concatenate
    expressions = concatenate(expressions)

    # Unroll
    max_extent = parameters["unroll_indexsum"]
    if max_extent:
        def predicate(index):
            return index.extent <= max_extent
        expressions = unroll_indexsum(expressions, predicate=predicate)

    # Refactorise: sub-expressions free of quadrature indices are OTHER;
    # quadrature-dependent literal tables are ATOMIC; all else COMPOUND.
    def classify(quadrature_indices, expression):
        if not quadrature_indices.intersection(expression.free_indices):
            return OTHER
        elif isinstance(expression, gem.Indexed) and isinstance(
                expression.children[0], gem.Literal):
            return ATOMIC
        else:
            return COMPOUND
    classifier = partial(classify, set(quadrature_multiindex))

    result = []
    for expr, monomial_sum in zip(expressions,
                                  collect_monomials(expressions, classifier)):
        # Select quadrature indices that are present
        quadrature_indices = set(index for index in quadrature_multiindex
                                 if index in expr.free_indices)

        products = []
        for sum_indices, factors, rest in monomial_sum:
            # Collapse quadrature literals for each monomial
            if factors or quadrature_indices:
                replacement = einsum(remove_componenttensors(factors),
                                     quadrature_indices)
            else:
                # Nothing to collapse; multiplicative identity.
                replacement = gem.Literal(1)
            # Rebuild expression
            products.append(
                gem.IndexSum(gem.Product(replacement, rest), sum_indices))
        result.append(reduce(gem.Sum, products, gem.Zero()))
    return result
def _unconcatenate(cache, pairs):
    # Tail-call recursive core of unconcatenate.
    # Assumes that input has already been sanitised.
    concat_group = find_group([e for v, e in pairs])
    if concat_group is None:
        # Base case: no Concatenate references remain.
        return pairs

    # Get the index split; any reference in the group determines the
    # shared concatenation index and Concatenate expression shape.
    concat_ref = next(iter(concat_group))
    assert isinstance(concat_ref, Indexed)
    concat_expr, = concat_ref.children
    index, = concat_ref.multiindex
    assert isinstance(concat_expr, Concatenate)
    try:
        multiindices = cache[index]
    except KeyError:
        # Create (and cache) one multiindex per Concatenate child so
        # that identical splits reuse identical Index objects.
        multiindices = tuple(tuple(Index(extent=d) for d in child.shape)
                             for child in concat_expr.children)
        cache[index] = multiindices

    def cut(node):
        """No need to rebuild expressions independent of the relevant
        concatenation index."""
        return index not in node.free_indices

    # Build Concatenate node replacement mappings
    mappings = [{} for i in range(len(multiindices))]
    for concat_ref in concat_group:
        concat_expr, = concat_ref.children
        for i in range(len(multiindices)):
            sub_ref = Indexed(concat_expr.children[i], multiindices[i])
            sub_ref, = remove_componenttensors((sub_ref,))
            mappings[i][concat_ref] = sub_ref

    # Finally, split assignment pairs
    split_pairs = []
    for var, expr in pairs:
        if index not in var.free_indices:
            split_pairs.append((var, expr))
        else:
            for v, m in zip(split_variable(var, index, multiindices), mappings):
                split_pairs.append((v, replace_node(expr, m, cut)))

    # Run again, there may be other Concatenate groups
    return _unconcatenate(cache, split_pairs)
def unconcatenate(pairs, cache=None):
    """Splits a list of (indexed variable, expression) pairs along
    :py:class:`Concatenate` nodes embedded in the expressions.

    :param pairs: list of (indexed variable, expression) pairs
    :param cache: index splitting cache :py:class:`dict` (optional)

    :returns: list of (indexed variable, expression) pairs
    """
    cache = {} if cache is None else cache

    # Eliminate index renaming due to ComponentTensor nodes before
    # handing over to the recursive splitting core.
    variables = [v for v, _ in pairs]
    exprs = remove_componenttensors([e for _, e in pairs])
    return _unconcatenate(cache, list(zip(variables, exprs)))
def Integrals(expressions, quadrature_multiindex, argument_multiindices, parameters):
    """Constructs an integral representation for each GEM integrand
    expression, collapsing quadrature-dependent literal factors of
    each monomial via ``einsum``.

    :arg expressions: integrand multiplied with quadrature weight;
                      multi-root GEM expression DAG
    :arg quadrature_multiindex: quadrature multiindex (tuple)
    :arg argument_multiindices: tuple of argument multiindices
    :arg parameters: parameters dictionary
    :returns: list of integral representations
    """
    # Concatenate
    expressions = concatenate(expressions)

    # Unroll
    max_extent = parameters["unroll_indexsum"]
    if max_extent:
        def predicate(index):
            return index.extent <= max_extent
        expressions = unroll_indexsum(expressions, predicate=predicate)

    # Refactorise: sub-expressions free of quadrature indices are OTHER;
    # quadrature-dependent literal tables are ATOMIC; all else COMPOUND.
    def classify(quadrature_indices, expression):
        if not quadrature_indices.intersection(expression.free_indices):
            return OTHER
        elif isinstance(expression, gem.Indexed) and isinstance(expression.children[0], gem.Literal):
            return ATOMIC
        else:
            return COMPOUND
    classifier = partial(classify, set(quadrature_multiindex))

    result = []
    for expr, monomial_sum in zip(expressions, collect_monomials(expressions, classifier)):
        # Select quadrature indices that are present
        quadrature_indices = set(index for index in quadrature_multiindex
                                 if index in expr.free_indices)

        products = []
        for sum_indices, factors, rest in monomial_sum:
            # Collapse quadrature literals for each monomial
            if factors or quadrature_indices:
                replacement = einsum(remove_componenttensors(factors), quadrature_indices)
            else:
                # Nothing to collapse; multiplicative identity.
                replacement = gem.Literal(1)
            # Rebuild expression
            products.append(gem.IndexSum(gem.Product(replacement, rest), sum_indices))
        result.append(reduce(gem.Sum, products, gem.Zero()))
    return result
def compile_integral(integral_data, form_data, prefix, parameters, interface=firedrake_interface):
    """Compiles a UFL integral into an assembly kernel.

    :arg integral_data: UFL integral data
    :arg form_data: UFL form data
    :arg prefix: kernel name will start with this string
    :arg parameters: parameters object
    :arg interface: backend module for the kernel interface

    :returns: a kernel constructed by the kernel interface
    """
    # Merge caller parameters over the defaults (defaults first, then
    # the caller's overrides).
    if parameters is None:
        parameters = default_parameters()
    else:
        _ = default_parameters()
        _.update(parameters)
        parameters = _

    # Remove these here, they're handled below.
    if parameters.get("quadrature_degree") in ["auto", "default", None, -1, "-1"]:
        del parameters["quadrature_degree"]
    if parameters.get("quadrature_rule") in ["auto", "default", None]:
        del parameters["quadrature_rule"]

    integral_type = integral_data.integral_type
    interior_facet = integral_type.startswith("interior_facet")
    mesh = integral_data.domain
    cell = integral_data.domain.ufl_cell()
    arguments = form_data.preprocessed_form.arguments()

    fiat_cell = as_fiat_cell(cell)
    integration_dim, entity_ids = lower_integral_type(fiat_cell, integral_type)

    # One index per argument, conventionally named 'j' and 'k' (at most
    # two arguments are paired by the zip here).
    argument_indices = tuple(gem.Index(name=name) for arg, name in zip(arguments, ['j', 'k']))
    quadrature_indices = []

    # Dict mapping domains to index in original_form.ufl_domains()
    domain_numbering = form_data.original_form.domain_numbering()
    builder = interface.KernelBuilder(integral_type, integral_data.subdomain_id,
                                      domain_numbering[integral_data.domain])
    return_variables = builder.set_arguments(arguments, argument_indices)

    coordinates = ufl_utils.coordinate_coefficient(mesh)
    if ufl_utils.is_element_affine(mesh.ufl_coordinate_element()):
        # For affine mesh geometries we prefer code generation that
        # composes well with optimisations.
        builder.set_coordinates(coordinates, mode='list_tensor')
    else:
        # Otherwise we use the approach that might be faster (?)
        builder.set_coordinates(coordinates)

    builder.set_coefficients(integral_data, form_data)

    # Map from UFL FiniteElement objects to Index instances.  This is
    # so we reuse Index instances when evaluating the same coefficient
    # multiple times with the same table.  Occurs, for example, if we
    # have multiple integrals here (and the affine coordinate
    # evaluation can be hoisted).
    index_cache = collections.defaultdict(gem.Index)

    kernel_cfg = dict(interface=builder,
                      ufl_cell=cell,
                      precision=parameters["precision"],
                      integration_dim=integration_dim,
                      entity_ids=entity_ids,
                      argument_indices=argument_indices,
                      index_cache=index_cache)

    # Generators capture kernel_cfg by reference, so later per-integral
    # updates are visible to them.
    kernel_cfg["facetarea"] = facetarea_generator(mesh, coordinates, kernel_cfg, integral_type)
    kernel_cfg["cellvolume"] = cellvolume_generator(mesh, coordinates, kernel_cfg)

    irs = []
    for integral in integral_data.integrals:
        params = {}
        # Record per-integral parameters
        params.update(integral.metadata())
        if params.get("quadrature_rule") == "default":
            del params["quadrature_rule"]
        # parameters override per-integral metadata
        params.update(parameters)

        # Check if the integral has a quad degree attached, otherwise use
        # the estimated polynomial degree attached by compute_form_data
        quadrature_degree = params.get("quadrature_degree",
                                       params["estimated_polynomial_degree"])
        integration_cell = fiat_cell.construct_subelement(integration_dim)
        # NOTE(review): the default argument is evaluated eagerly, so
        # create_quadrature runs even when "quadrature_rule" is present
        # in params — harmless but wasteful; confirm before changing.
        quad_rule = params.get("quadrature_rule",
                               create_quadrature(integration_cell,
                                                 quadrature_degree))

        if not isinstance(quad_rule, QuadratureRule):
            raise ValueError("Expected to find a QuadratureRule object, not a %s" %
                             type(quad_rule))

        integrand = ufl_utils.replace_coordinates(integral.integrand(), coordinates)
        integrand = ufl_utils.split_coefficients(integrand, builder.coefficient_split)
        quadrature_index = gem.Index(name='ip')
        quadrature_indices.append(quadrature_index)
        config = kernel_cfg.copy()
        config.update(quadrature_rule=quad_rule, point_index=quadrature_index)
        ir = fem.compile_ufl(integrand,
                             interior_facet=interior_facet,
                             **config)
        if parameters["unroll_indexsum"]:
            ir = opt.unroll_indexsum(ir, max_extent=parameters["unroll_indexsum"])
        # Wrap each root in an IndexSum over its quadrature index, but
        # only when that index actually appears free in the root.
        irs.append([(gem.IndexSum(expr, quadrature_index)
                     if quadrature_index in expr.free_indices
                     else expr)
                    for expr in ir])

    # Sum the expressions that are part of the same restriction
    ir = list(reduce(gem.Sum, e, gem.Zero()) for e in zip(*irs))

    # Need optimised roots for COFFEE
    ir = opt.remove_componenttensors(ir)

    # Look for cell orientations in the IR
    if builder.needs_cell_orientations(ir):
        builder.require_cell_orientations()

    impero_c = impero_utils.compile_gem(return_variables, ir,
                                        tuple(quadrature_indices) + argument_indices,
                                        remove_zeros=True)

    # Generate COFFEE
    index_names = [(index, index.name) for index in argument_indices]
    if len(quadrature_indices) == 1:
        index_names.append((quadrature_indices[0], 'ip'))
    else:
        # Multiple quadrature loops get numbered names to stay unique.
        for i, quadrature_index in enumerate(quadrature_indices):
            index_names.append((quadrature_index, 'ip_%d' % i))

    body = generate_coffee(impero_c, index_names, parameters["precision"], ir, argument_indices)

    kernel_name = "%s_%s_integral_%s" % (prefix, integral_type, integral_data.subdomain_id)
    return builder.construct_kernel(kernel_name, body)