def freeze_expression(expr):
    """
    Rebuild ``expr`` bottom-up so that SymPy ``Add``, ``Mul`` and ``Pow`` nodes
    become the module-level ``Add``, ``Mul`` and ``Pow`` (presumably devito's
    frozen counterparts), each built with ``evaluate=False``.

    Atoms and ``Indexed`` objects are leaves and are returned unchanged.
    Equalities become an unevaluated ``Eq``, except ``FrozenExpr`` instances,
    which are rebuilt through their own ``func`` so that metadata attached to
    them is not lost. Any other node is rebuilt via ``expr.func`` over its
    recursively frozen arguments.
    """
    # Leaves: nothing to rebuild.
    if expr.is_Atom or expr.is_Indexed:
        return expr

    # Every non-leaf case rebuilds from the frozen operands.
    frozen_args = [freeze_expression(arg) for arg in expr.args]

    if expr.is_Add:
        return Add(*frozen_args, evaluate=False)
    if expr.is_Mul:
        return Mul(*frozen_args, evaluate=False)
    if expr.is_Pow:
        return Pow(*frozen_args, evaluate=False)
    if expr.is_Equality:
        if isinstance(expr, FrozenExpr):
            # Rebuilding through expr.func keeps the metadata carried by `expr`
            return expr.func(*frozen_args)
        return Eq(*frozen_args, evaluate=False)
    return expr.func(*frozen_args)
def freeze(expr):
    """
    Return a copy of ``expr`` in which every SymPy ``Add``, ``Mul`` and ``Pow``
    has been replaced by its ``FrozenExpr`` equivalent (the module-level
    ``Add``/``Mul``/``Pow``, constructed with ``evaluate=False``).

    Atoms and ``Indexed`` objects are returned as they are. An equality is
    rebuilt as an unevaluated ``Eq`` unless it is already a ``FrozenExpr``,
    in which case its own ``func`` is used to preserve its metadata. All other
    nodes are reconstructed with ``expr.func``.
    """
    if expr.is_Atom or expr.is_Indexed:
        return expr

    new_args = [freeze(sub) for sub in expr.args]

    if expr.is_Add:
        result = Add(*new_args, evaluate=False)
    elif expr.is_Mul:
        result = Mul(*new_args, evaluate=False)
    elif expr.is_Pow:
        result = Pow(*new_args, evaluate=False)
    elif expr.is_Equality and isinstance(expr, FrozenExpr):
        # Keep whatever metadata is attached to the frozen equation
        result = expr.func(*new_args)
    elif expr.is_Equality:
        result = Eq(*new_args, evaluate=False)
    else:
        result = expr.func(*new_args)
    return result
def freeze_expression(expr):
    """
    Reconstruct ``expr`` turning all :class:`sympy.Add`, :class:`sympy.Mul`
    and :class:`sympy.Pow` into, respectively, the module-level ``Add``,
    ``Mul`` and ``Pow`` (devito's frozen versions), built with
    ``evaluate=False``.

    Atoms and ``Indexed`` objects are returned unchanged. Equalities are
    rebuilt as an unevaluated ``Eq``, except ``FrozenExpr`` instances, which
    are rebuilt through their own ``func`` so their metadata survives. Any
    other node is rebuilt with ``expr.func`` over its frozen arguments.
    """
    if expr.is_Atom or expr.is_Indexed:
        return expr

    rebuilt_args = [freeze_expression(e) for e in expr.args]

    if expr.is_Add:
        return Add(*rebuilt_args, evaluate=False)
    elif expr.is_Mul:
        return Mul(*rebuilt_args, evaluate=False)
    elif expr.is_Pow:
        # Without this branch powers would be rebuilt via expr.func, i.e.
        # evaluated, defeating the purpose of freezing
        return Pow(*rebuilt_args, evaluate=False)
    elif expr.is_Equality:
        if isinstance(expr, FrozenExpr):
            # Avoid dropping metadata associated with `expr`
            return expr.func(*rebuilt_args)
        return Eq(*rebuilt_args, evaluate=False)
    else:
        return expr.func(*rebuilt_args)
def xreplace_constrained(exprs, make, rule=None, costmodel=lambda e: True, repeat=False):
    """
    Unlike ``xreplace``, which replaces all objects specified in a mapper,
    this function replaces all objects satisfying two criteria: ::

        * The "matching rule" -- a function returning True if a node within
          ``expr`` satisfies a given property, and as such should be replaced;
        * A "cost model" -- a function triggering replacement only if a certain
          cost (e.g., operation count) is exceeded. This function is optional.

    Note that there is not necessarily a relationship between the set of nodes
    for which the matching rule returns True and those nodes passing the cost
    model check. It might happen for example that, given the expression
    ``a + b``, all of ``a``, ``b``, and ``a + b`` satisfy the matching rule,
    but only ``a + b`` satisfies the cost model.

    :param exprs: The target SymPy expression, or a collection of SymPy
                  expressions. Each must be an equality; only its right-hand
                  side is searched for replacements.
    :param make: Either a mapper M: K -> V, indicating how to replace an
                 expression in K with a symbol in V, or a function with
                 internal state that, when called, returns unique symbols.
    :param rule: The matching rule (a lambda function). May be left unspecified
                 if ``make`` is a mapper, in which case it matches the keys of
                 the mapper.
    :param costmodel: The cost model (a lambda function, optional).
    :param repeat: Repeatedly apply ``xreplace`` until no more replacements are
                   possible (optional, defaults to False).
    :returns: A 2-tuple ``(found + rebuilt, found)``. ``found`` is the list of
              ``Eq(temporary, subexpression)`` for each temporary obtained from
              ``make()`` (always empty when ``make`` is a mapper); ``rebuilt``
              holds the input equations with the replacements applied.
    """
    found = OrderedDict()
    rebuilt = []

    # Define `replace()` based on the user-provided `make`
    if isinstance(make, dict):
        rule = rule if rule is not None else (lambda i: i in make)

        def replace(expr):
            return make[expr]
    else:
        assert callable(make) and callable(rule)

        def replace(expr):
            # Reuse the temporary already assigned to `expr`, if any. Compare
            # against None (not truthiness) so a falsy temporary is never
            # silently regenerated.
            temporary = found.get(expr)
            if temporary is None:
                temporary = make()
                found[expr] = temporary
            return temporary

    def run(expr):
        # Returns `(new_expr, flag)`. A True `flag` means `new_expr` wholly
        # satisfies the matching rule and was left untouched, so the caller
        # may try to capture an even larger expression containing it.
        if expr.is_Atom or expr.is_Indexed:
            return expr, rule(expr)
        elif expr.is_Pow:
            base, flag = run(expr.base)
            if flag and costmodel(base):
                return expr.func(replace(base), expr.exp, evaluate=False), False
            else:
                return expr.func(base, expr.exp, evaluate=False), flag
        else:
            children = [run(a) for a in expr.args]
            # Partition on the flags themselves: cheaper and more precise than
            # structural-equality membership tests against `matching`
            matching = [a for a, flag in children if flag]
            other = [a for a, flag in children if not flag]
            if matching:
                matched = expr.func(*matching, evaluate=False)
                if len(matching) == len(children) and rule(expr):
                    # Go look for longer expressions first
                    return matched, True
                elif rule(matched) and costmodel(matched):
                    # Replace what I can replace, then give up
                    new_expr = expr.func(*(other + [replace(matched)]), evaluate=False)
                    return new_expr, False
                else:
                    # Replace flagged children, then give up
                    costly = [costmodel(e) for e in matching]
                    replaced = [replace(e) for e, c in zip(matching, costly) if c]
                    unreplaced = [e for e, c in zip(matching, costly) if not c]
                    new_expr = expr.func(*(other + replaced + unreplaced),
                                         evaluate=False)
                    return new_expr, False
            return expr.func(*other, evaluate=False), False

    # Process the provided expressions
    for expr in as_tuple(exprs):
        assert expr.is_Equality
        root = expr.rhs

        while True:
            ret, flag = run(root)
            if isinstance(make, dict) and root.is_Atom and flag:
                # The whole rhs is a single mapped atom: substitute it directly
                rebuilt.append(expr.func(expr.lhs, replace(root), evaluate=False))
                break
            elif repeat and ret != root:
                # The rhs changed; search the new rhs again
                root = ret
            else:
                rebuilt.append(expr.func(expr.lhs, ret, evaluate=False))
                break

    # Post-process the output: one definition per temporary
    found = [Eq(v, k) for k, v in found.items()]

    return found + rebuilt, found
def yreplace(exprs, make, rule=None, costmodel=lambda e: True, repeat=False,
             eager=False):
    """
    Unlike SymPy's ``xreplace``, which performs structural replacement based
    on a mapper, ``yreplace`` applies replacements using two callbacks:

        * The "matching rule" -- a boolean function telling whether an
          expression honors a certain property.
        * The "cost model" -- a boolean function telling whether an expression
          exceeds a certain (e.g., operation count) cost.

    Parameters
    ----------
    exprs : expr-like or list of expr-like
        One or more equalities; their right-hand sides are searched for
        replacements.
    make : dict or callable
        Either a mapper of substitution rules (just like in ``xreplace``), or
        a callable returning unique symbols each time it is called.
    rule : callable, optional
        The matching rule (see above). Unnecessary if ``make`` is a dict, in
        which case it defaults to matching the keys of ``make``.
    costmodel : callable, optional
        The cost model (see above).
    repeat : bool, optional
        If True, repeatedly apply ``xreplace`` until no more replacements are
        possible. Defaults to False.
    eager : bool, optional
        If True, replaces an expression ``e`` as soon as the condition
        ``rule(e) and costmodel(e)`` is True. Otherwise, the search continues
        for larger, more expensive expressions. Defaults to False.

    Returns
    -------
    tuple
        ``(found + rebuilt, found)``, where ``found`` is the list of
        ``Eq(temporary, subexpression)`` for each temporary obtained from
        ``make()`` (always empty when ``make`` is a dict) and ``rebuilt`` holds
        the input equations with the replacements applied.

    Notes
    -----
    In general, there is no relationship between the set of expressions for
    which the matching rule gives True and the set of expressions passing the
    cost test. For example, in the expression `a + b` all of `a`, `b` and
    `a+b` may satisfy the matching rule, whereas only `a+b` satisfies the cost
    test. Likewise, an expression may pass the cost test, but not satisfy the
    matching rule.
    """
    found = OrderedDict()
    rebuilt = []

    # Define `replace()` based on the user-provided `make`
    if isinstance(make, dict):
        rule = rule if rule is not None else (lambda i: i in make)

        def replace(expr):
            return make[expr]
    else:
        assert callable(make) and callable(rule)

        def replace(expr):
            # Reuse the temporary already assigned to `expr`, if any. Compare
            # against None (not truthiness) so a falsy temporary is never
            # silently regenerated.
            temporary = found.get(expr)
            if temporary is None:
                temporary = make()
                found[expr] = temporary
            return temporary

    def run(expr):
        # Returns `(new_expr, flag)`. A True `flag` means `new_expr` satisfies
        # the matching rule and no replacement was made at this node, so the
        # caller may still try to capture a larger expression containing it.
        if expr.is_Atom or expr.is_Indexed:
            return expr, rule(expr)
        elif expr.is_Pow:
            base, flag = run(expr.base)
            if flag and costmodel(base):
                return expr.func(replace(base), expr.exp, evaluate=False), False
            elif flag and costmodel(expr):
                return replace(expr), False
            else:
                return expr.func(base, expr.exp, evaluate=False), rule(expr)

        children = [run(a) for a in expr.args]
        # Partition on the flags themselves: cheaper and more precise than
        # structural-equality membership tests against `matching`
        matching = [a for a, flag in children if flag]
        other = [a for a, flag in children if not flag]

        if not matching:
            return expr.func(*other, evaluate=False), False

        if not eager:
            matched = expr.func(*matching, evaluate=False)
            if len(matching) == len(children) and rule(expr):
                # Go look for larger expressions first
                return matched, True
            elif rule(matched) and costmodel(matched):
                # E.g.: a*b*c*d -> a*r0
                new_expr = expr.func(*(other + [replace(matched)]), evaluate=False)
                return new_expr, False
            else:
                # E.g.: a*b*c*d -> a*r0*r1*r2
                costly = [costmodel(e) for e in matching]
                replaced = [replace(e) for e, c in zip(matching, costly) if c]
                unreplaced = [e for e, c in zip(matching, costly) if not c]
                new_expr = expr.func(*(other + replaced + unreplaced), evaluate=False)
                return new_expr, False

        replaceable, unreplaced = split(matching, costmodel)
        if replaceable:
            # E.g.: a*b*c*d -> a*r0*r1*r2
            replaced = [replace(e) for e in replaceable]
            new_expr = expr.func(*(other + replaced + unreplaced), evaluate=False)
            return new_expr, False
        matched = expr.func(*matching, evaluate=False)
        if rule(matched) and costmodel(matched):
            if len(matching) == len(children):
                # E.g.: a*b*c*d -> r0
                return replace(matched), False
            else:
                # E.g.: a*b*c*d -> a*r0
                new_expr = expr.func(*(other + [replace(matched)]), evaluate=False)
                return new_expr, False
        elif len(matching) == len(children) and rule(expr):
            # Go look for larger expressions
            return matched, True
        else:
            # E.g.: a*b*c*d; a,b,a*b replaceable but not satisfying the cost
            # model, hence giving up as c,d,c*d aren't replaceable
            return expr.func(*(matching + other), evaluate=False), False

    # Process the provided expressions
    for expr in as_tuple(exprs):
        assert expr.is_Equality
        root = expr.rhs

        while True:
            ret, flag = run(root)
            if flag and costmodel(ret):
                if expr.lhs.function.is_Array:
                    # NOTE(review): with an Array lhs the equation is kept
                    # as is; presumably the Array already acts as the
                    # temporary -- confirm
                    rebuilt.append(expr)
                else:
                    # Have to replace the whole expr.rhs
                    rebuilt.append(expr.func(expr.lhs, replace(root), evaluate=False))
                break
            elif repeat and ret != root:
                # The rhs changed; search the new rhs again
                root = ret
            else:
                rebuilt.append(expr.func(expr.lhs, ret, evaluate=False))
                break

    # Post-process the output: one definition per temporary
    found = [Eq(v, k) for k, v in found.items()]

    return found + rebuilt, found