def _aresame(a, b):
    """Return True if ``a`` and ``b`` are structurally identical, else False.

    Examples
    ========

    SymPy's ``==`` is value based, so a Float and an Integer holding the
    same value compare equal:

    >>> from sympy import S
    >>> 2.0 == S(2)
    True

    This helper is stricter: every pair of nodes visited by a preorder
    walk of the two objects must agree in value *and* in exact type,
    which gives a simple "same or not" answer:

    >>> from sympy.core.basic import _aresame
    >>> _aresame(S(2.0), S(2))
    False
    """
    from sympy.utilities.iterables import preorder_traversal
    from itertools import izip

    pairs = izip(preorder_traversal(a), preorder_traversal(b))
    # A pair that differs in value or in exact type is a mismatch; any()
    # stops at the first one, just like an early return would.
    mismatch = any(left != right or type(left) != type(right)
                   for left, right in pairs)
    return not mismatch
def test_preorder_traversal():
    # Add/Mul argument order is not fixed, so several orders are accepted.
    expr = z + w*(x + y)
    candidates = [
        [z + w*(x + y), z, w*(x + y), w, x + y, y, x],
        [z + w*(x + y), z, w*(x + y), w, x + y, x, y],
        [z + w*(x + y), w*(x + y), w, x + y, y, x, z],
    ]
    assert list(preorder_traversal(expr)) in candidates

    pw = Piecewise((x, x < 1), (x**2, True))
    assert list(preorder_traversal(pw)) == [
        pw,
        ExprCondPair(x, x < 1), x, x < 1, x, 1,
        ExprCondPair(x**2, True), x**2, x, 2, True,
    ]

    integral = Integral(x**2, (x, 0, 1))
    assert list(postorder_traversal(integral)) == [
        x, 2, x**2, x, 0, 1, Tuple(x, 0, 1), Integral(x**2, Tuple(x, 0, 1)),
    ]

    nested = ('abc', ('d', 'ef'))
    assert list(postorder_traversal(nested)) == [
        'abc', 'd', 'ef', ('d', 'ef'), nested,
    ]

    # skip() must prune the children of the node that was just yielded.
    inner = x**(y**z)
    visited = []
    walker = preorder_traversal(inner**inner)
    for node in walker:
        visited.append(node)
        if node == inner:
            walker.skip()
    assert visited == [inner**inner, inner, inner]
def _aresame(a, b):
    """Return True if ``a`` and ``b`` are structurally identical, else False.

    Examples
    ========

    SymPy's ``==`` is value based, so a Float and an Integer holding the
    same value compare equal:

    >>> from sympy import S
    >>> 2.0 == S(2)
    True

    This helper is stricter: every pair of nodes visited by a preorder
    walk of the two objects must agree in value *and* in exact type,
    which gives a simple "same or not" answer:

    >>> from sympy.core.basic import _aresame
    >>> _aresame(S(2.0), S(2))
    False
    """
    from sympy.utilities.iterables import preorder_traversal
    from itertools import izip

    pairs = izip(preorder_traversal(a), preorder_traversal(b))
    # A pair that differs in value or in exact type is a mismatch; any()
    # stops at the first one, just like an early return would.
    mismatch = any(left != right or type(left) != type(right)
                   for left, right in pairs)
    return not mismatch
def test_preorder_traversal():
    # Add/Mul argument order is not fixed, so several orders are accepted.
    expr = z + w*(x + y)
    candidates = [
        [z + w*(x + y), z, w*(x + y), w, x + y, y, x],
        [z + w*(x + y), z, w*(x + y), w, x + y, x, y],
        [z + w*(x + y), w*(x + y), w, x + y, y, x, z],
    ]
    assert list(preorder_traversal(expr)) in candidates

    pw = Piecewise((x, x < 1), (x**2, True))
    assert list(preorder_traversal(pw)) == [
        pw,
        ExprCondPair(x, x < 1), x, x < 1, x, 1,
        ExprCondPair(x**2, True), x**2, x, 2, True,
    ]

    integral = Integral(x**2, (x, 0, 1))
    assert list(postorder_traversal(integral)) == [
        x, 2, x**2, x, 0, 1, Tuple(x, 0, 1), Integral(x**2, Tuple(x, 0, 1)),
    ]

    nested = ('abc', ('d', 'ef'))
    assert list(postorder_traversal(nested)) == [
        'abc', 'd', 'ef', ('d', 'ef'), nested,
    ]

    # skip() must prune the children of the node that was just yielded.
    inner = x**(y**z)
    visited = []
    walker = preorder_traversal(inner**inner)
    for node in walker:
        visited.append(node)
        if node == inner:
            walker.skip()
    assert visited == [inner**inner, inner, inner]
def _aresame(a, b):
    """Return True if ``a`` and ``b`` are structurally the same, else False.

    Examples
    ========

    SymPy's ``==`` treats 2.0 and 2 as equal:

    >>> from sympy import S, Symbol, cos, sin
    >>> 2.0 == S(2)
    True

    ``Basic.compare`` tells those two apart, but it reports symbols that
    differ only in their assumptions as the same:

    >>> S(2).compare(2.0)
    -1
    >>> Symbol('x').compare(Symbol('x', positive=True))
    0

    ``Basic.compare`` also cannot be used with instances of FunctionClass:

    >>> sin.compare(cos)
    Traceback (most recent call last):
    ...
    TypeError: unbound method compare() must be called with sin instance as first argument (got FunctionClass instance instead)

    This routine answers the plain "same or not" question. Symbols that
    ``compare`` reports as equal are further distinguished by their
    assumptions. Everything else must agree node by node, in both value
    and exact type, over a preorder traversal.
    """
    from sympy.utilities.iterables import preorder_traversal
    from itertools import izip

    # compare() ignores assumptions (see the Symbol example above), so two
    # Symbols it reports as equal are decided by their assumption dicts.
    try:
        compare_equal_symbols = (a.compare(b) == 0 and
                                 a.is_Symbol and b.is_Symbol)
        if compare_equal_symbols:
            return a.assumptions0 == b.assumptions0
    except (TypeError, AttributeError):
        # compare()/is_Symbol not usable here (e.g. FunctionClass);
        # fall through to the node-by-node walk.
        pass

    for left, right in izip(preorder_traversal(a), preorder_traversal(b)):
        if not (left == right and type(left) == type(right)):
            return False
    return True
def test_postorder_traversal():
    # Add/Mul argument order is not fixed, so several orders are accepted.
    expr = z + w*(x + y)
    orderings = [
        [z, w, y, x, x + y, w*(x + y), z + w*(x + y)],
        [z, w, x, y, x + y, w*(x + y), z + w*(x + y)],
        [w, y, x, x + y, w*(x + y), z, z + w*(x + y)],
    ]
    assert list(postorder_traversal(expr)) in orderings

    pw = Piecewise((x, x < 1), (x**2, True))
    assert list(postorder_traversal(pw)) == [
        x, x, 1, x < 1, ExprCondPair(x, x < 1),
        x, 2, x**2, True, ExprCondPair(x**2, True),
        pw,
    ]

    integral = Integral(x**2, (x, 0, 1))
    assert list(preorder_traversal(integral)) == [
        integral, x**2, x, 2, Tuple(x, 0, 1), x, 0, 1,
    ]

    nested = ('abc', ('d', 'ef'))
    assert list(preorder_traversal(nested)) == [
        nested, 'abc', ('d', 'ef'), 'd', 'ef',
    ]
def test_postorder_traversal():
    # Add/Mul argument order is not fixed, so several orders are accepted.
    expr = z + w*(x + y)
    orderings = [
        [z, w, y, x, x + y, w*(x + y), z + w*(x + y)],
        [z, w, x, y, x + y, w*(x + y), z + w*(x + y)],
        [w, y, x, x + y, w*(x + y), z, z + w*(x + y)],
    ]
    assert list(postorder_traversal(expr)) in orderings

    pw = Piecewise((x, x < 1), (x**2, True))
    assert list(postorder_traversal(pw)) == [
        x, x, 1, x < 1, ExprCondPair(x, x < 1),
        x, 2, x**2, True, ExprCondPair(x**2, True),
        pw,
    ]

    integral = Integral(x**2, (x, 0, 1))
    assert list(preorder_traversal(integral)) == [
        integral, x**2, x, 2, Tuple(x, 0, 1), x, 0, 1,
    ]

    nested = ('abc', ('d', 'ef'))
    assert list(preorder_traversal(nested)) == [
        nested, 'abc', ('d', 'ef'), 'd', 'ef',
    ]
def test_preorder_traversal():
    # Add/Mul argument order is not fixed, so several orders are accepted.
    expr = z + w*(x + y)
    candidates = [
        [z + w*(x + y), z, w*(x + y), w, x + y, y, x],
        [z + w*(x + y), z, w*(x + y), w, x + y, x, y],
        [z + w*(x + y), w*(x + y), w, x + y, y, x, z],
    ]
    assert list(preorder_traversal(expr)) in candidates

    pw = Piecewise((x, x < 1), (x**2, True))
    assert list(preorder_traversal(pw)) == [
        pw,
        ExprCondPair(x, x < 1), x, x < 1, x, 1,
        ExprCondPair(x**2, True), x**2, x, 2, True,
    ]

    # Here the integration limits are visited as plain nested tuples.
    integral = Integral(x**2, (x, 0, 1))
    assert list(postorder_traversal(integral)) == [
        x, 2, x**2, x, 0, 1, (0, 1), (x, (0, 1)), ((x, (0, 1)),), integral,
    ]

    nested = ('abc', ('d', 'ef'))
    assert list(postorder_traversal(nested)) == [
        'abc', 'd', 'ef', ('d', 'ef'), nested,
    ]
def sub_post(e):
    """Rewrite every ``Sub(x, y)`` node of ``e`` in canonical form.

    Each Sub node found in a preorder walk is paired with
    ``Add(x, Mul(NegativeOne(-1), y))``, built from its two arguments. The
    pairs are then applied to ``e`` one at a time with ``subs``.
    """
    rewrites = [(node, Add(node.args[0], Mul(-1, node.args[1])))
                for node in preorder_traversal(e)
                if assumed(node, 'is_Sub')]
    for old, new in rewrites:
        e = e.subs(old, new)
    return e
def mycollectsimp(expr):
    """Apply ``collect`` to ``expr`` once for every Symbol it contains.

    Occurrences of each Symbol are counted over a preorder walk of
    ``expr``. The symbols are then collected in ascending order of their
    ``(count, symbol)`` pairs, so the rarest symbols are collected first.
    """
    from sympy import collect
    tally = {}
    for node in preorder_traversal(expr):
        if isinstance(node, Symbol):
            tally[node] = tally.get(node, 0) + 1
    for _, sym in sorted((n, sym) for sym, n in tally.iteritems()):
        expr = collect(expr, sym)
    return expr
def sub_post(e):
    """Rewrite every ``Sub(x, y)`` node of ``e`` in canonical form.

    Each Sub node found in a preorder walk is paired with
    ``Add(x, Mul(NegativeOne(-1), y))``, built from its two arguments. The
    pairs are then applied to ``e`` one at a time with ``subs``.
    """
    rewrites = [(node, Add(node.args[0], Mul(-1, node.args[1])))
                for node in preorder_traversal(e)
                if assumed(node, 'is_Sub')]
    for old, new in rewrites:
        e = e.subs(old, new)
    return e
def sub_post(e):
    """Rewrite every ``Neg(x)`` node of ``e`` as ``-x``.

    Neg nodes are gathered in a preorder walk, each paired with the
    negation of its single argument. The pairs are then applied to ``e``
    one at a time with ``xreplace``.
    """
    found = [(node, -node.args[0])
             for node in preorder_traversal(e)
             if isinstance(node, Neg)]
    for old, new in found:
        e = e.xreplace({old: new})
    return e
def sub_post(e):
    """Rewrite every ``Neg(x)`` node of ``e`` as ``-x``.

    Neg nodes are gathered in a preorder walk, each paired with the
    negation of its single argument. The pairs are then applied to ``e``
    one at a time with ``xreplace``.
    """
    found = [(node, -node.args[0])
             for node in preorder_traversal(e)
             if isinstance(node, Neg)]
    for old, new in found:
        e = e.xreplace({old: new})
    return e
def test_preorder_traversal():
    # Add/Mul argument order is not fixed, so several orders are accepted.
    expr = z + w*(x + y)
    candidates = [
        [z + w*(x + y), z, w*(x + y), w, x + y, y, x],
        [z + w*(x + y), z, w*(x + y), w, x + y, x, y],
        [z + w*(x + y), w*(x + y), w, x + y, y, x, z],
    ]
    assert list(preorder_traversal(expr)) in candidates

    pw = Piecewise((x, x < 1), (x**2, True))
    assert list(preorder_traversal(pw)) == [
        pw,
        ExprCondPair(x, x < 1), x, x < 1, x, 1,
        ExprCondPair(x**2, True), x**2, x, 2, True,
    ]

    # Here the integration limits are visited as plain nested tuples.
    integral = Integral(x**2, (x, 0, 1))
    assert list(postorder_traversal(integral)) == [
        x, 2, x**2, x, 0, 1, (0, 1), (x, (0, 1)), ((x, (0, 1)),), integral,
    ]

    nested = ('abc', ('d', 'ef'))
    assert list(postorder_traversal(nested)) == [
        'abc', 'd', 'ef', ('d', 'ef'), nested,
    ]
def denoms(eq, x=None):
    """Return the set of denominators found anywhere in ``eq``.

    Every Pow or ``exp`` node reached by a preorder walk is split with
    ``as_numer_denom``. Its denominator is kept only if it contains at
    least one symbol from ``x``; when ``x`` is None, all free symbols of
    ``eq`` are used. The value stored is the base of the denominator
    (``d.as_base_exp()[0]``), so ``1/x**2`` contributes ``x``.
    """
    from sympy.utilities.iterables import preorder_traversal
    targets = eq.free_symbols if x is None else x
    found = set()
    walker = preorder_traversal(eq)
    for node in walker:
        if not (node.is_Pow or node.func is exp):
            continue
        den = node.as_numer_denom()[1]
        if den in found:
            # This denominator is already in the result; don't descend
            # into the node again.
            walker.skip()
        elif den.has(*targets):
            found.add(den.as_base_exp()[0])
    return found
def denoms(eq, x=None):
    """Return the set of denominators found anywhere in ``eq``.

    Every Pow or ``exp`` node reached by a preorder walk is split with
    ``as_numer_denom``. Its denominator is kept only if it contains at
    least one symbol from ``x``; when ``x`` is None, all free symbols of
    ``eq`` are used. The value stored is the base of the denominator
    (``d.as_base_exp()[0]``), so ``1/x**2`` contributes ``x``.
    """
    from sympy.utilities.iterables import preorder_traversal
    targets = eq.free_symbols if x is None else x
    found = set()
    walker = preorder_traversal(eq)
    for node in walker:
        if not (node.is_Pow or node.func is exp):
            continue
        den = node.as_numer_denom()[1]
        if den in found:
            # This denominator is already in the result; don't descend
            # into the node again.
            walker.skip()
        elif den.has(*targets):
            found.add(den.as_base_exp()[0])
    return found
def peel_terms(s):
    """Peel terms off the Sums in ``s`` using their common limit range.

    The first limit tuple ``(var, lower, upper)`` of every ``Sum`` found in
    a preorder walk of ``s`` is examined. The tightest common range is kept:
    the maximum of the lower limits and the minimum of the upper limits.
    That range is handed to ``peel_sum_terms``, and the resulting rule is
    applied to ``s`` through ``rewrite``.
    """
    lim_min = None
    lim_max = None
    for node in preorder_traversal(s):
        if not isinstance(node, Sum):
            continue
        lmin = node.limits[0][1]
        lmax = node.limits[0][2]
        # Test against None explicitly. A limit of 0 is falsy (both the int
        # and S.Zero), so a truthiness test would treat it as "no limit
        # seen yet" and overwrite it instead of combining with it.
        lim_min = lmin if lim_min is None else max(lmin, lim_min)
        lim_max = lmax if lim_max is None else min(lmax, lim_max)
    rule = peel_sum_terms(lim_min, lim_max)
    new_s = rewrite(s, [(Sum, rule)])
    return new_s
def peel_terms(s):
    """Peel terms off the Sums in ``s`` using their common limit range.

    The first limit tuple ``(var, lower, upper)`` of every ``Sum`` found in
    a preorder walk of ``s`` is examined. The tightest common range is kept:
    the maximum of the lower limits and the minimum of the upper limits.
    That range is handed to ``peel_sum_terms``, and the resulting rule is
    applied to ``s`` through ``rewrite``.
    """
    lim_min = None
    lim_max = None
    for node in preorder_traversal(s):
        if not isinstance(node, Sum):
            continue
        lmin = node.limits[0][1]
        lmax = node.limits[0][2]
        # Test against None explicitly. A limit of 0 is falsy (both the int
        # and S.Zero), so a truthiness test would treat it as "no limit
        # seen yet" and overwrite it instead of combining with it.
        lim_min = lmin if lim_min is None else max(lmin, lim_min)
        lim_max = lmax if lim_max is None else min(lmax, lim_max)
    rule = peel_sum_terms(lim_min, lim_max)
    new_s = rewrite(s, [(Sum, rule)])
    return new_s
def _atomic(e):
    """Return the atom-like parts of ``e`` as far as substitution goes.

    Atom-like parts are Derivatives, Functions and the free Symbols of
    ``e``. Nothing nested inside a Derivative or Function is reported
    unless it also appears outside of one. If ``e`` has no
    ``free_symbols`` attribute, ``set([e])`` is returned.

    Examples
    ========

    >>> from sympy import Derivative, Function, cos
    >>> from sympy.abc import x, y
    >>> from sympy.core.basic import _atomic
    >>> f = Function('f')
    >>> _atomic(x + y)
    set([x, y])
    >>> _atomic(x + f(y))
    set([x, f(y)])
    >>> _atomic(Derivative(f(x), x) + cos(x) + y)
    set([y, cos(x), Derivative(f(x), x)])
    """
    from sympy import Derivative, Function, Symbol
    from sympy.utilities.iterables import preorder_traversal

    walker = preorder_traversal(e)
    visited = set()
    try:
        free = e.free_symbols
    except AttributeError:
        # Not an expression with free symbols: it is its own atom.
        return set([e])

    found = set()
    for node in walker:
        if node in visited:
            # An identical subtree was already handled; don't descend again.
            walker.skip()
            continue
        visited.add(node)
        if isinstance(node, Symbol) and node in free:
            found.add(node)
        elif isinstance(node, (Derivative, Function)):
            # Treat as indivisible: report it and ignore its arguments.
            walker.skip()
            found.add(node)
    return found
def _atomic(e):
    """Return the atom-like parts of ``e`` as far as substitution goes.

    Atom-like parts are Derivatives, Functions and the free Symbols of
    ``e``. Nothing nested inside a Derivative or Function is reported
    unless it also appears outside of one. If ``e`` has no
    ``free_symbols`` attribute, ``set([e])`` is returned.

    Examples
    ========

    >>> from sympy import Derivative, Function, cos
    >>> from sympy.abc import x, y
    >>> from sympy.core.basic import _atomic
    >>> f = Function('f')
    >>> _atomic(x + y)
    set([x, y])
    >>> _atomic(x + f(y))
    set([x, f(y)])
    >>> _atomic(Derivative(f(x), x) + cos(x) + y)
    set([y, cos(x), Derivative(f(x), x)])
    """
    from sympy import Derivative, Function, Symbol
    from sympy.utilities.iterables import preorder_traversal

    walker = preorder_traversal(e)
    visited = set()
    try:
        free = e.free_symbols
    except AttributeError:
        # Not an expression with free symbols: it is its own atom.
        return set([e])

    found = set()
    for node in walker:
        if node in visited:
            # An identical subtree was already handled; don't descend again.
            walker.skip()
            continue
        visited.add(node)
        if isinstance(node, Symbol) and node in free:
            found.add(node)
        elif isinstance(node, (Derivative, Function)):
            # Treat as indivisible: report it and ignore its arguments.
            walker.skip()
            found.add(node)
    return found
def sub_pre(e):
    """Rewrite Adds that contain negative terms as ``Sub`` nodes.

    Every Add found in a preorder walk of ``e`` has its terms split into
    two groups:

    * Muls whose first argument is a negative number; these are collected
      with that number's sign flipped.
    * Everything else.

    When the first group is non-empty the Add is replaced by
    ``Sub(Add(*others), Add(*flipped))``. For example,
    ``Add(x, Mul(NegativeOne(-1), y))`` becomes ``Sub(x, y)``.
    """
    rewrites = []
    for node in preorder_traversal(e):
        if not assumed(node, 'is_Add'):
            continue
        kept, subtracted = [], []
        for term in node.args:
            leads_negative = (assumed(term, 'is_Mul')
                              and assumed(term.args[0], 'is_number')
                              and assumed(term.args[0], 'is_negative'))
            if leads_negative:
                subtracted.append(Mul(-term.args[0], *term.args[1:]))
            else:
                kept.append(term)
        if subtracted:
            rewrites.append((node, Sub(Add(*kept), Add(*subtracted))))
    for old, new in rewrites:
        e = e.subs(old, new)
    return e
def sub_pre(e):
    """Rewrite Adds that contain negative terms as ``Sub`` nodes.

    Every Add found in a preorder walk of ``e`` is inspected term by term.
    A Mul whose ``as_two_terms()`` leading part is a negative number is
    moved to the subtracted group, with that number's sign flipped. If
    any term was moved, the Add is replaced by
    ``Sub(Add(*others), Add(*subtracted))``. For example,
    ``Add(x, Mul(NegativeOne(-1), y))`` becomes ``Sub(x, y)``.
    """
    rewrites = []
    for node in preorder_traversal(e):
        if not assumed(node, 'is_Add'):
            continue
        kept, subtracted = [], []
        for term in node.args:
            flipped = None
            if assumed(term, 'is_Mul'):
                coeff, rest = term.as_two_terms()
                if assumed(coeff, 'is_number') and assumed(coeff, 'is_negative'):
                    flipped = Mul(-coeff, rest)
            if flipped is None:
                kept.append(term)
            else:
                subtracted.append(flipped)
        if subtracted:
            rewrites.append((node, Sub(Add(*kept), Add(*subtracted))))
    for old, new in rewrites:
        e = e.subs(old, new)
    return e
def sub_pre(e):
    """Rewrite Adds that contain negative terms as ``Sub`` nodes.

    Every Add found in a preorder walk of ``e`` is inspected term by term.
    A Mul whose ``as_two_terms()`` leading part is a negative number is
    moved to the subtracted group, with that number's sign flipped. If
    any term was moved, the Add is replaced by
    ``Sub(Add(*others), Add(*subtracted))``. For example,
    ``Add(x, Mul(NegativeOne(-1), y))`` becomes ``Sub(x, y)``.
    """
    rewrites = []
    for node in preorder_traversal(e):
        if not node.is_Add:
            continue
        kept, subtracted = [], []
        for term in node.args:
            flipped = None
            if term.is_Mul:
                coeff, rest = term.as_two_terms()
                if coeff.is_number and coeff.is_negative:
                    flipped = Mul(-coeff, rest)
            if flipped is None:
                kept.append(term)
            else:
                subtracted.append(flipped)
        if subtracted:
            rewrites.append((node, Sub(Add(*kept), Add(*subtracted))))
    for old, new in rewrites:
        e = e.subs(old, new)
    return e
def singles(self):
    """Inline commands whose symbol is used exactly once by later commands.

    Walks ``self.commands`` backwards, from index ``self.size - 1`` down to
    0. For each command not tagged ``"final"``, uses of its ``symbol`` are
    counted in the ``expr`` of every command that comes after it. If the
    count is exactly 1, the command's ``expr`` is substituted into the one
    later command that uses it, and the command is removed from
    ``self.commands``.

    NOTE(review): assumes ``self.size`` equals ``len(self.commands)`` on
    entry; it is not updated after deletions, which is harmless only
    because the walk moves toward lower indices.
    """
    counter = self.size-1
    while counter >= 0:
        command = self.commands[counter]
        if command.tag != "final": # do not consider end results
            # check for usage
            used = 0
            tmp = None  # the (last) later command found using the symbol
            for later_command in self.commands[counter+1:]:
                # cheap containment pre-check before the full walk
                if command.symbol in later_command.expr:
                    for subtree in preorder_traversal(later_command.expr):
                        if subtree == command.symbol:
                            used += 1
                            tmp = later_command
                        # A ``symbol**2`` node adds one more count on top of
                        # the bare symbol presumably visited beneath it, so
                        # squared uses are presumably never inlined.
                        if isinstance(subtree, C.Pow) and subtree.exp == 2 and subtree.base == command.symbol:
                            used += 1
                            tmp = later_command
            if used == 1:
                # exactly one use: inline the definition and drop the command
                tmp.expr = tmp.expr.subs(command.symbol, command.expr)
                print "SINGLE", command
                del self.commands[counter]
        counter -= 1
def test_preorder_traversal():
    # Add/Mul argument order is not fixed, so any of these orders is valid.
    expr = z + w*(x + y)
    valid_orders = [
        [z + w*(x + y), z, w*(x + y), w, x + y, y, x],
        [z + w*(x + y), z, w*(x + y), w, x + y, x, y],
        [z + w*(x + y), w*(x + y), w, x + y, y, x, z],
    ]
    assert list(preorder_traversal(expr)) in valid_orders
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters:

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The `numbered_symbols` generator is useful. The default is a stream
        of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        `sympy.simplify.cse.cse_optimizations` is used.

    Returns:

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.

    The candidates for elimination are gathered in three passes:

    1. Repeated non-atomic subtrees found by a preorder walk of every
       expression.
    2. Groups of two or more arguments shared by two Adds.
    3. Groups of two or more factors shared by two Muls. The commutative
       part is matched by set intersection; the non-commutative part by
       the longest contiguous run of factors.

    Candidates are kept ordered by ``count_ops()``, so cheaper
    subexpressions are assigned symbols (and substituted) first.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()   # every non-atomic subtree seen so far
    muls = set()          # Mul subtrees, recorded on first sight
    adds = set()          # Add subtrees, recorded on first sight
    # to_eliminate and to_eliminate_ops_count are parallel lists kept
    # sorted by op count (see insert below).
    to_eliminate = []
    to_eliminate_ops_count = []

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]
    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.

    def insert(subtree):
        '''This helper will insert the subtree into to_eliminate while
        maintaining the ordering by op count and will skip the insertion
        if subtree is already present.'''
        ops_count = subtree.count_ops()
        index_to_insert = bisect.bisect(to_eliminate_ops_count, ops_count)
        # all i up to this index have op count <= the current op count
        # so check that subtree is not yet present from this index down
        # (if necessary) to zero.
        for i in xrange(index_to_insert - 1, -1, -1):
            if to_eliminate_ops_count[i] == ops_count and \
               subtree == to_eliminate[i]:
                return # already have it
        to_eliminate_ops_count.insert(index_to_insert, ops_count)
        to_eliminate.insert(index_to_insert, subtree)

    for expr in exprs:
        pt = preorder_traversal(expr)
        for subtree in pt:
            if subtree.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue

            if subtree in seen_subexp:
                insert(subtree)
                # The children were already walked the first time this
                # subtree was seen, so don't walk them again.
                pt.skip()
                continue

            if subtree.is_Mul:
                muls.add(subtree)
            elif subtree.is_Add:
                adds.add(subtree)

            seen_subexp.add(subtree)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = [set(a.args) for a in adds]
    for i in xrange(len(adds)):
        for j in xrange(i + 1, len(adds)):
            com = adds[i].intersection(adds[j])
            if len(com) > 1:
                insert(Add(*com))

                # remove this set of symbols so it doesn't appear again
                adds[i] = adds[i].difference(com)
                adds[j] = adds[j].difference(com)
                for k in xrange(j + 1, len(adds)):
                    if not com.difference(adds[k]):
                        adds[k] = adds[k].difference(com)

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common

    # use SequenceMatcher on the nc part to find the longest common expression
    # in common between the two nc parts
    sm = difflib.SequenceMatcher()

    # NOTE(review): each entry is indexed as [commutative set, nc list] and
    # element 0 is reassigned below, so this relies on args_cnc() returning
    # a mutable list whose first item supports set operations -- confirm.
    muls = [a.args_cnc() for a in muls]
    for i in xrange(len(muls)):
        if muls[i][1]:
            sm.set_seq1(muls[i][1])
        for j in xrange(i + 1, len(muls)):
            # the commutative part in common
            ccom = muls[i][0].intersection(muls[j][0])

            # the non-commutative part in common
            if muls[i][1] and muls[j][1]:
                # see if there is any chance of an nc match
                ncom = set(muls[i][1]).intersection(set(muls[j][1]))
                if len(ccom) + len(ncom) < 2:
                    continue

                # now work harder to find the match
                sm.set_seq2(muls[j][1])
                i1, _, n = sm.find_longest_match(0, len(muls[i][1]),
                                                 0, len(muls[j][1]))
                ncom = muls[i][1][i1:i1 + n]
            else:
                ncom = []

            com = list(ccom) + ncom
            if len(com) < 2:
                continue

            insert(Mul(*com))

            # remove ccom from all if there was no ncom; to update the nc part
            # would require finding the subexpr and then replacing it with a
            # dummy to keep bounding nc symbols from being identified as a
            # subexpr, e.g. removing B*C from A*B*C*D might allow A*D to be
            # identified as a subexpr which would not be right.
            if not ncom:
                muls[i][0] = muls[i][0].difference(ccom)
                for k in xrange(j, len(muls)):
                    if not ccom.difference(muls[k][0]):
                        muls[k][0] = muls[k][0].difference(ccom)

    # Substitute symbols for all of the repeated subexpressions.
    # to_eliminate is ordered by ascending op count, so smaller pieces are
    # replaced first and later (larger) candidates are rewritten in terms
    # of the symbols already assigned.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        # NOTE(review): Python 2 iterator protocol (.next()).
        sym = symbols.next()
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        for j in range(i+1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs]

    return replacements, reduced_exprs
def get_index(term):
    """Return the first index of the first ``Indexed`` node in ``term``.

    Nodes are visited in preorder. Only objects whose type is exactly
    ``Indexed`` match; subclasses do not. Returns None when no such node
    exists.
    """
    first_indices = (node.indices[0]
                     for node in preorder_traversal(term)
                     if type(node) == Indexed)
    return next(first_indices, None)
def get_index(term):
    """Return the first index of the first ``Indexed`` node in ``term``.

    Nodes are visited in preorder. Only objects whose type is exactly
    ``Indexed`` match; subclasses do not. Returns None when no such node
    exists.
    """
    first_indices = (node.indices[0]
                     for node in preorder_traversal(term)
                     if type(node) == Indexed)
    return next(first_indices, None)
def test_preorder_traversal():
    # Add/Mul argument order is not fixed, so any of these orders is valid.
    expr = z + w*(x + y)
    valid_orders = [
        [z + w*(x + y), z, w*(x + y), w, x + y, y, x],
        [z + w*(x + y), z, w*(x + y), w, x + y, x, y],
        [z + w*(x + y), w*(x + y), w, x + y, y, x, z],
    ]
    assert list(preorder_traversal(expr)) in valid_orders
def _mask_nc(eq):
    """Return ``eq`` with non-commutative objects replaced with dummy
    symbols. A dictionary that can be used to restore the original
    values is returned: if it is None, the expression is
    noncommutative and cannot be made commutative. The third value
    returned is a list of any non-commutative symbols that appeared
    in the equation (sorted with ``default_sort_key``; it is empty when
    the single nc-symbol was replaced).

    Notes
    =====
    Non-commutative objects other than Symbol, Add, Mul and Pow are
    replaced with (commutative) Dummy symbols. Non-commutative Symbols
    are handled separately. If there is only 1 such Symbol, it is
    replaced with a Dummy. If there are more than one, they are not
    replaced; the calling routine should handle replacements in this
    case, since some care must be taken to keep track of the ordering
    of symbols when they occur within Muls.

    Examples
    ========
    >>> from sympy.physics.secondquant import Commutator, NO, F, Fd
    >>> from sympy import Dummy, symbols
    >>> from sympy.abc import x, y
    >>> from sympy.core.exprtools import _mask_nc
    >>> A, B, C = symbols('A,B,C', commutative=False)
    >>> Dummy._count = 0 # reset for doctest purposes
    >>> _mask_nc(A**2 - x**2)
    (_0**2 - x**2, {_0: A}, [])
    >>> _mask_nc(A**2 - B**2)
    (A**2 - B**2, None, [A, B])
    >>> _mask_nc(1 + x*Commutator(A, B))
    (_1*x + 1, {_1: Commutator(A, B)}, [A, B])
    >>> _mask_nc(NO(Fd(x)*F(y)))
    (_2, {_2: NO(CreateFermion(x)*AnnihilateFermion(y))}, [])
    """
    expr = eq
    if expr.is_commutative:
        return eq, {}, []
    # if there is only one nc symbol, it can be factored regularly but
    # polys is going to complain, so replace it with a dummy
    rep = []
    nc_syms = [s for s in expr.free_symbols if not s.is_commutative]
    if len(nc_syms) == 1:
        nc = Dummy()
        rep.append((nc_syms.pop(), nc))
        expr = expr.subs(rep)
    # even though the noncommutative symbol may be gone, the expression
    # might still appear noncommutative; if it's a non-elementary object
    # we will replace it, but if it is a Symbol, Add, Mul, Pow we leave
    # it alone.
    nc_syms.sort(key=default_sort_key)
    if nc_syms or not expr.is_commutative:
        pot = preorder_traversal(expr)
        for i, a in enumerate(pot):
            # A node that is already a replacement target, or one that was
            # just recorded, has its subtree skipped (reaches pot.skip()).
            if any(a == r[0] for r in rep):
                pass
            elif (
                not a.is_commutative and
                not (a.is_Symbol or a.is_Add or a.is_Mul or a.is_Pow)
            ):
                rep.append((a, Dummy()))
            else:
                continue  # don't skip
            pot.skip()  # don't go any further
        expr = expr.subs(rep)
    # invert rep into a Dummy -> original map; an empty map becomes None
    return expr, dict([(v, k) for k, v in rep]) or None, nc_syms
def _mask_nc(eq):
    """Return ``eq`` with non-commutative objects replaced with dummy
    symbols. A dictionary that can be used to restore the original
    values is returned: if it is None, the expression is
    noncommutative and cannot be made commutative. The third value
    returned is a list of any non-commutative symbols that appear
    in the returned equation.

    Notes
    =====
    All non-commutative objects other than Symbols (and other than Add,
    Mul and Pow, which are walked into) are replaced. Identical objects
    will be identified by identical symbols.

    An object whose free symbols are all commutative is replaced with a
    commutative Dummy. If there is only 1 non-commutative object (or
    only 1 non-commutative Symbol) in an expression, it is replaced with
    a commutative Dummy as well. Otherwise, any remaining nc-objects are
    replaced with non-commutative Dummy symbols, which are reported
    alongside the nc-Symbols. The calling routine should handle
    replacements in this case, since some care must be taken to keep
    track of the ordering of symbols when they occur within Muls.

    Examples
    ========
    >>> from sympy.physics.secondquant import Commutator, NO, F, Fd
    >>> from sympy import Dummy, symbols, Sum, Mul, Basic
    >>> from sympy.abc import x, y
    >>> from sympy.core.exprtools import _mask_nc
    >>> A, B, C = symbols('A,B,C', commutative=False)
    >>> Dummy._count = 0 # reset for doctest purposes

    One nc-symbol:

    >>> _mask_nc(A**2 - x**2)
    (_0**2 - x**2, {_0: A}, [])

    Multiple nc-symbols:

    >>> _mask_nc(A**2 - B**2)
    (A**2 - B**2, None, [A, B])

    An nc-object with nc-symbols but no others outside of it:

    >>> _mask_nc(1 + x*Commutator(A, B))
    (_1*x + 1, {_1: Commutator(A, B)}, [])
    >>> _mask_nc(NO(Fd(x)*F(y)))
    (_2, {_2: NO(CreateFermion(x)*AnnihilateFermion(y))}, [])

    Multiple nc-objects:

    >>> eq = x*Commutator(A, B) + x*Commutator(A, C)*Commutator(A, B)
    >>> _mask_nc(eq)
    (x*_3*_4 + x*_4, {_3: Commutator(A, C), _4: Commutator(A, B)}, [_3, _4])

    Multiple nc-objects and nc-symbols:

    >>> eq = A*Commutator(A, B) + B*Commutator(A, C)
    >>> _mask_nc(eq)
    (A*_6 + B*_5, {_5: Commutator(A, C), _6: Commutator(A, B)}, [_5, _6, A, B])

    If there is an object that:

        - doesn't contain nc-symbols
        - but has arguments which derive from Basic, not Expr
        - and doesn't define an _eval_is_commutative routine

    then it will give False (or None?) for the is_commutative test. Such
    objects are also removed by this routine:

    >>> from sympy import Basic, Mul
    >>> eq = (1 + Mul(Basic(), Basic(), evaluate=False))
    >>> eq.is_commutative
    False
    >>> _mask_nc(eq)
    (_7**2 + 1, {_7: Basic()}, [])

    """
    expr = eq
    if expr.is_commutative:
        return eq, {}, []

    # identify nc-objects; nc-symbols and other nc-objects are collected
    # separately (objects with only commutative free symbols go straight
    # into rep with a commutative Dummy)
    rep = []
    nc_obj = set()
    nc_syms = set()
    pot = preorder_traversal(expr)
    for i, a in enumerate(pot):
        if any(a == r[0] for r in rep):
            # already scheduled for replacement: nothing inside matters
            pot.skip()
        elif not a.is_commutative:
            if a.is_Symbol:
                nc_syms.add(a)
            elif not (a.is_Add or a.is_Mul or a.is_Pow):
                if all(s.is_commutative for s in a.free_symbols):
                    rep.append((a, Dummy()))
                else:
                    nc_obj.add(a)
                # the whole object is masked, so don't descend into it
                pot.skip()

    # If there is only one nc symbol or object, it can be factored regularly
    # but polys is going to complain, so replace it with a Dummy.
    if len(nc_obj) == 1 and not nc_syms:
        rep.append((nc_obj.pop(), Dummy()))
    elif len(nc_syms) == 1 and not nc_obj:
        rep.append((nc_syms.pop(), Dummy()))

    # Any remaining nc-objects will be replaced with an nc-Dummy and
    # identified as an nc-Symbol to watch out for
    while nc_obj:
        nc = Dummy(commutative=False)
        rep.append((nc_obj.pop(), nc))
        nc_syms.add(nc)
    expr = expr.subs(rep)

    nc_syms = list(nc_syms)
    nc_syms.sort(key=default_sort_key)
    # invert rep into a Dummy -> original map; an empty map becomes None
    return expr, dict([(v, k) for k, v in rep]) or None, nc_syms
def _mask_nc(eq):
    """Return ``eq`` with non-commutative objects replaced with dummy
    symbols. A dictionary that can be used to restore the original
    values is returned: if it is None, the expression is
    noncommutative and cannot be made commutative. The third value
    returned is a list of any non-commutative symbols that appeared
    in the equation (sorted with ``default_sort_key``; it is empty when
    the single nc-symbol was replaced).

    Notes
    =====
    Non-commutative objects other than Symbol, Add, Mul and Pow are
    replaced with (commutative) Dummy symbols. Non-commutative Symbols
    are handled separately. If there is only 1 such Symbol, it is
    replaced with a Dummy. If there are more than one, they are not
    replaced; the calling routine should handle replacements in this
    case, since some care must be taken to keep track of the ordering
    of symbols when they occur within Muls.

    Examples
    ========
    >>> from sympy.physics.secondquant import Commutator, NO, F, Fd
    >>> from sympy import Dummy, symbols
    >>> from sympy.abc import x, y
    >>> from sympy.core.exprtools import _mask_nc
    >>> A, B, C = symbols('A,B,C', commutative=False)
    >>> Dummy._count = 0 # reset for doctest purposes
    >>> _mask_nc(A**2 - x**2)
    (_0**2 - x**2, {_0: A}, [])
    >>> _mask_nc(A**2 - B**2)
    (A**2 - B**2, None, [A, B])
    >>> _mask_nc(1 + x*Commutator(A, B))
    (_1*x + 1, {_1: Commutator(A, B)}, [A, B])
    >>> _mask_nc(NO(Fd(x)*F(y)))
    (_2, {_2: NO(CreateFermion(x)*AnnihilateFermion(y))}, [])
    """
    expr = eq
    if expr.is_commutative:
        return eq, {}, []
    # if there is only one nc symbol, it can be factored regularly but
    # polys is going to complain, so replace it with a dummy
    rep = []
    nc_syms = [s for s in expr.free_symbols if not s.is_commutative]
    if len(nc_syms) == 1:
        nc = Dummy()
        rep.append((nc_syms.pop(), nc))
        expr = expr.subs(rep)
    # even though the noncommutative symbol may be gone, the expression
    # might still appear noncommutative; if it's a non-elementary object
    # we will replace it, but if it is a Symbol, Add, Mul, Pow we leave
    # it alone.
    nc_syms.sort(key=default_sort_key)
    if nc_syms or not expr.is_commutative:
        pot = preorder_traversal(expr)
        for i, a in enumerate(pot):
            # A node that is already a replacement target, or one that was
            # just recorded, has its subtree skipped (reaches pot.skip()).
            if any(a == r[0] for r in rep):
                pass
            elif (not a.is_commutative and
                  not (a.is_Symbol or a.is_Add or a.is_Mul or a.is_Pow)):
                rep.append((a, Dummy()))
            else:
                continue  # don't skip
            pot.skip()  # don't go any further
        expr = expr.subs(rep)
    # invert rep into a Dummy -> original map; an empty map becomes None
    return expr, dict([(v, k) for k, v in rep]) or None, nc_syms
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters:

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        ``sympy.simplify.cse.cse_optimizations`` is used.

    Returns:

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.

    Notes
    =====

    Besides whole subtrees that occur more than once, common sub-sums of
    Adds (e.g. ``x + y`` shared by ``x + y + z`` and ``x + y``) and common
    factors of Muls are also extracted; for the non-commutative factors of a
    pair of Muls only their longest common contiguous run is considered.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    muls = set()
    adds = set()
    # Subexpressions to eliminate, kept sorted by ascending op count;
    # to_eliminate_ops_count is the parallel (sorted) list of those counts.
    # Replacing in this order lets smaller subexpressions show up inside
    # larger ones (see the substitution loop below).
    to_eliminate = []
    to_eliminate_ops_count = []
    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.

    def insert(subtree):
        '''Insert ``subtree`` into ``to_eliminate``, keeping the list ordered
        by op count; do nothing if an equal subtree is already present.'''
        ops_count = subtree.count_ops()
        # Because to_eliminate_ops_count is sorted, every entry whose op count
        # equals ops_count lies in the slice [lo:hi]; only those entries can
        # be equal to subtree, so there is no need to scan further down.
        lo = bisect.bisect_left(to_eliminate_ops_count, ops_count)
        hi = bisect.bisect_right(to_eliminate_ops_count, ops_count)
        if subtree in to_eliminate[lo:hi]:
            return  # already have it
        to_eliminate_ops_count.insert(hi, ops_count)
        to_eliminate.insert(hi, subtree)

    for expr in exprs:
        pt = preorder_traversal(expr)
        for subtree in pt:
            if subtree.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue

            if subtree in seen_subexp:
                # Second sighting: schedule it and don't descend into it.
                insert(subtree)
                pt.skip()
                continue

            if subtree.is_Mul:
                muls.add(subtree)
            elif subtree.is_Add:
                adds.add(subtree)

            seen_subexp.add(subtree)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = [set(a.args) for a in adds]
    for i in xrange(len(adds)):
        for j in xrange(i + 1, len(adds)):
            com = adds[i].intersection(adds[j])
            if len(com) > 1:
                insert(Add(*com))

                # remove this set of symbols so it doesn't appear again
                adds[i] = adds[i].difference(com)
                adds[j] = adds[j].difference(com)
                for k in xrange(j + 1, len(adds)):
                    if not com.difference(adds[k]):
                        adds[k] = adds[k].difference(com)

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common

    # use SequenceMatcher on the nc part to find the longest common expression
    # in common between the two nc parts
    sm = difflib.SequenceMatcher()

    # NOTE(review): the code below uses muls[i][0].intersection/.difference,
    # so it relies on args_cnc() returning [commutative set, nc list] --
    # confirm for the sympy version in use.
    muls = [a.args_cnc() for a in muls]
    for i in xrange(len(muls)):
        if muls[i][1]:
            sm.set_seq1(muls[i][1])
        for j in xrange(i + 1, len(muls)):
            # the commutative part in common
            ccom = muls[i][0].intersection(muls[j][0])

            # the non-commutative part in common
            if muls[i][1] and muls[j][1]:
                # see if there is any chance of an nc match
                ncom = set(muls[i][1]).intersection(set(muls[j][1]))
                if len(ccom) + len(ncom) < 2:
                    continue

                # now work harder to find the match
                sm.set_seq2(muls[j][1])
                i1, _, n = sm.find_longest_match(0, len(muls[i][1]),
                                                 0, len(muls[j][1]))
                ncom = muls[i][1][i1:i1 + n]
            else:
                ncom = []

            com = list(ccom) + ncom
            if len(com) < 2:
                continue

            insert(Mul(*com))

            # remove ccom from all if there was no ncom; to update the nc part
            # would require finding the subexpr and then replacing it with a
            # dummy to keep bounding nc symbols from being identified as a
            # subexpr, e.g. removing B*C from A*B*C*D might allow A*D to be
            # identified as a subexpr which would not be right.
            if not ncom:
                muls[i][0] = muls[i][0].difference(ccom)
                for k in xrange(j, len(muls)):
                    if not ccom.difference(muls[k][0]):
                        muls[k][0] = muls[k][0].difference(ccom)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []

    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        sym = symbols.next()
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        # Only entries after i are rewritten in place, and enumerate reads
        # the list lazily, so later iterations pick up the rewritten trees.
        for j in range(i + 1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    return replacements, reduced_exprs
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters:

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        ``sympy.simplify.cse.cse_optimizations`` is used.

    Returns:

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    # Repeated subtrees to eliminate, kept sorted by ascending op count;
    # to_eliminate_ops_count is the parallel (sorted) list of those counts.
    to_eliminate = []
    to_eliminate_ops_count = []
    # Hash-based mirror of to_eliminate (like seen_subexp) so the "already
    # scheduled?" test is constant time instead of a linear list scan.
    to_eliminate_set = set()
    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    for expr in exprs:
        pt = preorder_traversal(expr)
        for subtree in pt:
            if subtree.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue
            elif subtree in seen_subexp:
                if subtree not in to_eliminate_set:
                    to_eliminate_set.add(subtree)
                    ops_count = subtree.count_ops()
                    index_to_insert = bisect.bisect(to_eliminate_ops_count,
                                                    ops_count)
                    to_eliminate_ops_count.insert(index_to_insert, ops_count)
                    to_eliminate.insert(index_to_insert, subtree)
                # Its children are covered by replacing the whole subtree.
                pt.skip()
            else:
                seen_subexp.add(subtree)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []

    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        sym = symbols.next()
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        # This rewrites to_eliminate in place while enumerating it; that is
        # safe because only entries after i are replaced (the list length
        # never changes) and enumerate reads each entry when it reaches it,
        # so later iterations see the rewritten subtrees.
        for j in range(i + 1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    return replacements, reduced_exprs
def _mask_nc(eq): """Return ``eq`` with non-commutative objects replaced with dummy symbols. A dictionary that can be used to restore the original values is returned: if it is None, the expression is noncommutative and cannot be made commutative. The third value returned is a list of any non-commutative symbols that appear in the returned equation. Notes ===== All non-commutative objects other than Symbols are replaced with a non-commutative Symbol. Identical objects will be identified by identical symbols. If there is only 1 non-commutative object in an expression it will be replaced with a commutative symbol. Otherwise, the non-commutative entities are retained and the calling routine should handle replacements in this case since some care must be taken to keep track of the ordering of symbols when they occur within Muls. Examples ======== >>> from sympy.physics.secondquant import Commutator, NO, F, Fd >>> from sympy import Dummy, symbols, Sum >>> from sympy.abc import x, y >>> from sympy.core.exprtools import _mask_nc >>> A, B, C = symbols('A,B,C', commutative=False) >>> Dummy._count = 0 # reset for doctest purposes One nc-symbol: >>> _mask_nc(A**2 - x**2) (_0**2 - x**2, {_0: A}, []) Multiple nc-symbols: >>> _mask_nc(A**2 - B**2) (A**2 - B**2, None, [A, B]) An nc-object with nc-symbols but no others outside of it: >>> _mask_nc(1 + x*Commutator(A, B)) (_1*x + 1, {_1: Commutator(A, B)}, []) >>> _mask_nc(NO(Fd(x)*F(y))) (_2, {_2: NO(CreateFermion(x)*AnnihilateFermion(y))}, []) An nc-object without nc-symbols: >>> _mask_nc(x + x*Sum(x, (x, 1, 2))) (_3*x + x, {_3: Sum(x, (x, 1, 2))}, []) Multiple nc-objects: >>> eq = x*Commutator(A, B) + x*Commutator(A, C)*Commutator(A, B) >>> _mask_nc(eq) (x*_4*_5 + x*_5, {_4: Commutator(A, C), _5: Commutator(A, B)}, [_4, _5]) Multiple nc-objects and nc-symbols: >>> eq = A*Commutator(A, B) + B*Commutator(A, C) >>> _mask_nc(eq) (A*_7 + B*_6, {_6: Commutator(A, C), _7: Commutator(A, B)}, [_6, _7, A, B]) """ expr = eq if 
expr.is_commutative: return eq, {}, [] # identify nc-objects; symbols and other rep = [] nc_obj = set() nc_syms = set() pot = preorder_traversal(expr) for i, a in enumerate(pot): if any(a == r[0] for r in rep): pot.skip() elif not a.is_commutative: if a.is_Symbol: nc_syms.add(a) elif not (a.is_Add or a.is_Mul or a.is_Pow): if all(s.is_commutative for s in a.free_symbols): rep.append((a, Dummy())) else: nc_obj.add(a) pot.skip() # If there is only one nc symbol or object, it can be factored regularly # but polys is going to complain, so replace it with a Dummy. if len(nc_obj) == 1 and not nc_syms: rep.append((nc_obj.pop(), Dummy())) elif len(nc_syms) == 1 and not nc_obj: rep.append((nc_syms.pop(), Dummy())) # Any remaining nc-objects will be replaced with an nc-Dummy and # identified as an nc-Symbol to watch out for while nc_obj: nc = Dummy(commutative=False) rep.append((nc_obj.pop(), nc)) nc_syms.add(nc) expr = expr.subs(rep) nc_syms = list(nc_syms) nc_syms.sort(key=default_sort_key) return expr, dict([(v, k) for k, v in rep]) or None, nc_syms
def find_double_pow(expr):
    """Return the first node of ``expr``, in preorder, that is a ``Pow``
    whose base is itself a ``Pow`` (a power raised to a power).

    Returns None when ``expr`` contains no such nested power.
    """
    for node in preorder_traversal(expr):
        if not isinstance(node, C.Pow):
            continue
        if isinstance(node.base, C.Pow):
            return node
    return None