def unroll_indexsum(expressions, max_extent):
    """Unroll IndexSums whose extent is within a given maximum.

    :arg expressions: list of expression DAGs
    :arg max_extent: maximum extent for which IndexSums are unrolled
    :returns: list of expression DAGs with some unrolled IndexSums
    """
    unroller = Memoizer(_unroll_indexsum)
    # The limit is attached to the memoizing mapper so that the
    # ``_unroll_indexsum`` rule can reach it during traversal.
    unroller.max_extent = max_extent
    return [unroller(root) for root in expressions]
def ffc_rounding(expression, epsilon):
    """Apply FFC-style rounding of FIAT tabulation matrices to the literals
    of a GEM expression.

    :arg expression: GEM expression
    :arg epsilon: tolerance limit for rounding
    :returns: the result of applying ``literal_rounding`` over ``expression``
    """
    rounder = Memoizer(literal_rounding)
    # Tolerance made available to the rule function through the mapper.
    rounder.epsilon = epsilon
    rounded = rounder(expression)
    return rounded
def unroll_indexsum(expressions, predicate):
    """Unrolls the IndexSums whose indices are selected by a predicate.

    :arg expressions: list of expression DAGs
    :arg predicate: a predicate function on :py:class:`Index` objects that
        tells whether to unroll a particular index
    :returns: list of expression DAGs with some unrolled IndexSums
    """
    mapper = Memoizer(_unroll_indexsum)
    # The predicate is attached to the memoizing mapper so that the
    # ``_unroll_indexsum`` rule can consult it while traversing.
    mapper.predicate = predicate
    return list(map(mapper, expressions))
def replace_node(expression, mapping, cut=None):
    """Substitute subexpressions of a GEM expression according to a mapping.

    :param expression: a GEM expression
    :param mapping: a :py:class:`dict` containing the substitutions
    :param cut: cutting predicate; if returns true, it is assumed that no
        replacements would take place in the subexpression. When omitted
        (or falsy), a predicate that never cuts is used.
    :returns: the result of applying ``_replace_node`` over ``expression``
    """
    def never_cut(node):
        return False

    replacer = Memoizer(_replace_node)
    replacer.mapping = mapping
    replacer.cut = cut if cut else never_cut
    return replacer(expression)
def expand_conditional(expressions, predicate):
    """Expand selected :py:class:`Conditional` nodes.

    Each selected node is rewritten with the substitution rule

        Conditional(a, b, c) => Conditional(a, 1, 0)*b + Conditional(a, 0, 1)*c

    :arg expressions: expression DAG roots
    :arg predicate: a predicate function on :py:class:`Conditional`s that
        decides whether the substitution rule is applied to a node
    :returns: expression DAG roots with some :py:class:`Conditional` nodes
        expanded
    """
    expander = Memoizer(_expand_conditional)
    expander.predicate = predicate
    expanded = [expander(root) for root in expressions]
    return expanded
def remove_if(expr: Expr, predicate: Callable[[Expr], bool]) -> Expr:
    """Remove terms from an expression that match a predicate.

    This is done by replacing matching terms by an appropriately-shaped
    :class:`~.Zero`, so only works to remove terms that are linear in the
    expression.

    :arg expr: the expression to remove terms from.
    :arg predicate: a function that indicates whether an expression should
        be removed (returns ``True``, i.e. the term "matches") or kept
        (returns ``False``).
    :returns: A potentially new expression with terms matching the
        predicate removed."""
    mapper = Memoizer(_filter)
    # Made available to the ``_filter`` rule through the memoizing mapper.
    mapper.predicate = predicate
    return mapper(expr)
def assign_dtypes(expressions, scalar_type):
    """Assign numpy data types to expressions.

    Used for declaring temporaries when converting from Impero to lower
    level code.

    :arg expressions: List of GEM expressions.
    :arg scalar_type: Default scalar type.
    :returns: list of tuples (expression, dtype)."""
    dtype_of = Memoizer(_assign_dtype)
    dtype_of.scalar_type = scalar_type
    # For a complex scalar type, ``numpy.finfo(...).dtype`` gives the
    # corresponding real dtype; otherwise the scalar type is already real.
    is_complex = scalar_type.kind == "c"
    dtype_of.real_type = numpy.finfo(scalar_type).dtype if is_complex else scalar_type
    return [(expr, dtype_of(expr)) for expr in expressions]
def flatten(expressions):
    """Flatten Concatenate nodes, discarding the structure they express.

    :arg expressions: a multi-root expression DAG
    :returns: list of the roots after applying ``_flatten``
    """
    flattener = Memoizer(_flatten)
    return [flattener(root) for root in expressions]
def collect_monomials(expressions, classifier):
    """Refactorise expressions into a sum-of-products form.

    Distributivity rules (i.e. a*(b + c) -> a*b + a*c) are applied until
    every "compound" expression has been broken up.

    :arg expressions: GEM expressions to refactorise
    :arg classifier: a function that can classify any GEM expression
        as ``ATOMIC``, ``COMPOUND``, or ``OTHER``. This
        classification drives the factorisation.

    :returns: list of :py:class:`MonomialSum`s

    :raises FactorisationError: Failed to break up some "compound"
        expressions with expansion.
    """
    # Phase 1: remove ComponentTensor nodes.
    expressions = remove_componenttensors(expressions)

    # Phase 2: find the indices of COMPOUND ``Indexed(ListTensor)`` nodes and
    # unroll them so that those ListTensors disappear.
    must_unroll = set()
    for node in traversal(expressions):
        if not isinstance(node, Indexed):
            continue
        child, = node.children
        if isinstance(child, ListTensor) and classifier(node) == COMPOUND:
            must_unroll.update(node.multiindex)
    if must_unroll:
        expressions = unroll_indexsum(expressions,
                                      predicate=lambda i: i in must_unroll)
        expressions = remove_componenttensors(expressions)

    # Phase 3: expand the Conditional nodes that classify as COMPOUND.
    def is_compound(node):
        return classifier(node) == COMPOUND

    expressions = expand_conditional(expressions, is_compound)

    # Phase 4: the actual refactorisation.
    collector = Memoizer(_collect_monomials)
    collector.classifier = classifier
    collector.rename_map = make_rename_map()
    return [collector(root) for root in expressions]
def __init__(self, lvalue, rvalue):
    """
    :arg lvalue: The coefficient to assign into.
    :arg rvalue: The pointwise expression.
    :raises ValueError: if ``lvalue`` is not a ``ufl.Coefficient``.
    """
    if not isinstance(lvalue, ufl.Coefficient):
        raise ValueError("lvalue for pointwise assignment must be a coefficient")
    self.lvalue = lvalue
    self.rvalue = ufl.as_ufl(rvalue)
    # NOTE(review): presumably the number of sub-spaces of the target's
    # function space. ``self.splitter`` is only set when it exceeds one.
    n = len(self.lvalue.function_space())
    if n > 1:
        splitter = Memoizer(_split)
        self.splitter = splitter
        splitter.n = n
def strip_dt_form(F):
    """Return a new form whose integrands have had ``strip_dt`` applied.

    A :py:class:`Zero` form is returned unchanged.

    :arg F: the form to strip time derivatives from.
    :returns: a new :py:class:`Form` built from the stripped integrals, or
        ``F`` itself if it is a :py:class:`Zero`.
    """
    # The time derivative stripper is never applied to zero forms.
    if isinstance(F, Zero):
        return F
    stripper = Memoizer(strip_dt)
    stripped_integrals = [
        integral.reconstruct(integrand=stripper(integral.integrand()))
        for integral in F.integrals()
    ]
    return Form(stripped_integrals)
def check_integrals(integrals: List[Integral], expect_time_derivative: bool = True) -> List[Integral]:
    """Check a list of integrals for linearity in the time derivative.

    :arg integrals: list of integrals.
    :arg expect_time_derivative: Are we expecting to see a time derivative?
    :raises ValueError: if we are expecting a time derivative and don't see one,
        or time derivatives are applied nonlinearly, to more than one coefficient,
        or more than first order.
    :returns: ``integrals``, unchanged."""
    mapper = Memoizer(_check_time_terms)
    time_derivatives = set()
    for integral in integrals:
        time_derivatives.update(mapper(integral.integrand()))
    # ``()`` is excluded from the count. It is presumably the entry produced
    # for terms without a time derivative (TODO confirm in _check_time_terms).
    found = time_derivatives - {()}
    howmany = int(expect_time_derivative)
    if len(found) != howmany:
        # Trailing space keeps the two f-string pieces from running together.
        raise ValueError(f"Expecting time derivative applied to {howmany} "
                         f"coefficients, not {len(found)}")
    return integrals
def optimise_expressions(expressions, argument_indices):
    """Perform loop optimisations on GEM DAGs.

    :arg expressions: list of GEM DAGs
    :arg argument_indices: tuple of argument indices
    :returns: list of optimised GEM DAGs, or ``expressions`` unchanged if
        any of them contains a :py:class:`Failure` node
    """
    # No optimisation is attempted when a Failure node is present.
    if any(isinstance(node, Failure) for node in traversal(expressions)):
        return expressions

    # Argument factorisation is applied unconditionally.
    classifier = partial(spectral.classify, set(argument_indices),
                         delta_inside=Memoizer(spectral._delta_inside))
    monomial_sums = collect_monomials(expressions, classifier)
    return [optimise_monomial_sum(monomial_sum, argument_indices)
            for monomial_sum in monomial_sums]
def replace_delta(expressions):
    """Lower every Delta in a multi-root expression DAG.

    :arg expressions: iterable of expression roots
    :returns: list of roots after applying ``_replace_delta``
    """
    lowerer = Memoizer(_replace_delta)
    return [lowerer(root) for root in expressions]
def slate2gem(expression, options):
    """Apply the ``_slate2gem`` mapper to a Slate expression.

    :arg expression: the Slate expression to translate
    :arg options: mapping of options; its ``"replace_mul"`` entry is stored
        on the mapper as ``matfree``
    :returns: a pair of the mapped expression and the mapper's
        ``var2terminal`` ``OrderedDict`` (presumably filled in by
        ``_slate2gem`` during the traversal)
    """
    translator = Memoizer(_slate2gem)
    translator.var2terminal = OrderedDict()
    translator.matfree = options["replace_mul"]
    translated = translator(expression)
    return translated, translator.var2terminal
def drop_double_transpose(expr):
    """Remove double transposes from an optimised Slate expression.

    :arg expr: the Slate expression to simplify
    :returns: the result of applying ``_drop_double_transpose`` over ``expr``
    """
    from gem.node import Memoizer
    return Memoizer(_drop_double_transpose)(expr)
def flatten(var_reps, index_cache):
    """Turn ``(variable, reps)`` groups into factorised assignment pairs.

    For each variable, the expressions of its reps are summed. Concatenate
    nodes are split, the pairs are argument-factorised with
    ``collect_monomials``, delta cancellation is applied, and each resulting
    variable's monomial sum is sum-factorised.

    :arg var_reps: iterable of ``(variable, reps)`` pairs; each rep is read
        for ``argument_indices``, ``expression`` and ``quadrature_multiindex``.
    :arg index_cache: passed through to ``unconcatenate`` as ``cache``.
    :returns: a generator of ``(variable, expression)`` pairs, one per
        variable produced by delta elimination, in first-seen order.
    """
    # Insertion-ordered "set" of quadrature indices (values are None).
    quadrature_indices = OrderedDict()
    pairs = []  # assignment pairs
    for variable, reps in var_reps:
        # Extract argument indices
        # (the single-element unpacking requires all reps to agree on them)
        argument_indices, = set(r.argument_indices for r in reps)
        assert set(variable.free_indices) == set(argument_indices)

        # Extract and verify expressions
        expressions = [r.expression for r in reps]
        assert all(
            set(e.free_indices) <= set(argument_indices)
            for e in expressions)

        # Save assignment pair
        pairs.append((variable, reduce(Sum, expressions)))

        # Collect quadrature_indices
        # (zip_longest pairs each index with None to form dict items)
        for r in reps:
            quadrature_indices.update(zip_longest(r.quadrature_multiindex, ()))

    # Split Concatenate nodes
    pairs = unconcatenate(pairs, cache=index_cache)

    def group_key(pair):
        variable, expression = pair
        return frozenset(variable.free_indices)

    delta_inside = Memoizer(_delta_inside)
    # Variable ordering after delta cancellation
    narrow_variables = OrderedDict()
    # Assignments are variable -> MonomialSum map
    delta_simplified = defaultdict(MonomialSum)
    # Group assignment pairs by argument indices
    # NOTE: itertools.groupby only merges *consecutive* pairs with equal
    # keys and ``pairs`` is not sorted here, so one key may produce several
    # groups; results still accumulate per variable in delta_simplified.
    for free_indices, pair_group in groupby(pairs, group_key):
        variables, expressions = zip(*pair_group)
        # Argument factorise expressions
        classifier = partial(classify, set(free_indices), delta_inside=delta_inside)
        monomial_sums = collect_monomials(expressions, classifier)
        # For each monomial, apply delta cancellation and insert
        # result into delta_simplified.
        for variable, monomial_sum in zip(variables, monomial_sums):
            for monomial in monomial_sum:
                var, s, a, r = delta_elimination(variable, *monomial)
                narrow_variables.setdefault(var)
                delta_simplified[var].add(s, a, r)

    # Final factorisation
    for variable in narrow_variables:
        monomial_sum = delta_simplified[variable]
        # Collect sum indices applicable to the current MonomialSum
        sum_indices = set().union(*[m.sum_indices for m in monomial_sum])
        # Put them in a deterministic order
        sum_indices = [i for i in quadrature_indices if i in sum_indices]
        # Sort for increasing index extent, this obtains the good
        # factorisation for triangle x interval cells. Python sort is
        # stable, so in the common case when index extents are equal,
        # the previous deterministic ordering applies which is good
        # for getting smaller temporaries.
        sum_indices = sorted(sum_indices, key=lambda index: index.extent)
        # Apply sum factorisation combined with COFFEE technology
        expression = sum_factorise(variable, sum_indices, monomial_sum)
        yield (variable, expression)
def replace_division(expressions):
    """Rewrite divisions as multiplications in the given expressions.

    :arg expressions: iterable of expression roots
    :returns: list of roots after applying ``_replace_division``
    """
    rewriter = Memoizer(_replace_division)
    return [rewriter(root) for root in expressions]
def slate2gem(expression):
    """Apply the ``_slate2gem`` mapper to a Slate expression.

    :arg expression: the Slate expression to translate
    :returns: a pair of the mapped expression and the mapper's
        ``var2terminal`` ``OrderedDict`` (presumably filled in by
        ``_slate2gem`` during the traversal)
    """
    translator = Memoizer(_slate2gem)
    translator.var2terminal = OrderedDict()
    translated = translator(expression)
    return translated, translator.var2terminal