def unroll_indexsum(expressions, predicate):
    """Unroll selected IndexSums in a collection of expression DAGs.

    Each expression is rewritten by one shared ``Memoizer`` wrapping
    ``_unroll_indexsum``, so the same mapper instance (and the predicate
    attached to it) is used for every root.

    :arg expressions: list of expression DAGs
    :arg predicate: a predicate function on :py:class:`Index` objects
                    that decides whether a particular index is unrolled
    :returns: list of expression DAGs in which some IndexSums have
              been unrolled
    """
    unroller = Memoizer(_unroll_indexsum)
    # The rewrite rule reads the predicate off the mapper object.
    unroller.predicate = predicate
    return [unroller(expression) for expression in expressions]
def expand_conditional(expressions, predicate):
    """Expand selected :py:class:`Conditional` nodes into arithmetic form.

    The substitution rule applied to each selected node is::

        Conditional(a, b, c) => Conditional(a, 1, 0)*b + Conditional(a, 0, 1)*c

    Each root is rewritten by one shared ``Memoizer`` wrapping
    ``_expand_conditional``.

    :arg expressions: expression DAG roots
    :arg predicate: a predicate function on :py:class:`Conditional`
                    nodes that determines whether the substitution
                    rule is applied to a given node
    :returns: expression DAG roots with some :py:class:`Conditional`
              nodes expanded
    """
    expander = Memoizer(_expand_conditional)
    # The rewrite rule reads the predicate off the mapper object.
    expander.predicate = predicate
    return [expander(root) for root in expressions]
def remove_if(expr: Expr, predicate: Callable[[Expr], bool]) -> Expr:
    """Strip the terms of an expression that satisfy a predicate.

    Matching terms are not deleted structurally; they are replaced by
    an appropriately-shaped :class:`~.Zero`.  As a consequence this
    only works for removing terms that are linear in the expression.

    :arg expr: the expression to remove terms from.
    :arg predicate: a function called on expressions to decide whether
        a term is removed (terms matching the predicate) or kept.
    :returns: A potentially new expression with the terms matching
        the predicate removed.
    """
    term_filter = Memoizer(_filter)
    # The filtering rule reads the predicate off the mapper object.
    term_filter.predicate = predicate
    filtered = term_filter(expr)
    return filtered