示例#1
0
文件: basic.py 项目: BDGLunde/sympy
def _aresame(a, b):
    """Tell whether ``a`` and ``b`` are structurally identical.

    Examples
    ========

    SymPy's ``==`` treats numerically equal numbers of different types as
    equal, so 2.0 == 2:

    >>> from sympy import S, Symbol, cos, sin
    >>> 2.0 == S(2)
    True

    This helper is stricter: both arguments are walked in preorder in
    lockstep, and every pair of corresponding nodes must compare equal *and*
    have exactly the same type. The lockstep walk ends with the shorter of
    the two traversals.

    >>> from sympy.core.basic import _aresame
    >>> _aresame(S(2.0), S(2))
    False

    """
    from sympy.utilities.iterables import preorder_traversal
    from itertools import izip

    node_pairs = izip(preorder_traversal(a), preorder_traversal(b))
    # a single pair that differs in value or in type is enough to answer "no"
    return not any(left != right or type(left) != type(right)
                   for left, right in node_pairs)
示例#2
0
def test_preorder_traversal():
    """Check preorder/postorder walks, including ``skip`` on a preorder walk."""
    # The order of Add/Mul args is not fixed, so any of these walks is valid.
    expr = z + w * (x + y)
    product = w * (x + y)
    valid_walks = (
        [expr, z, product, w, x + y, y, x],
        [expr, z, product, w, x + y, x, y],
        [expr, product, w, x + y, y, x, z],
    )
    assert list(preorder_traversal(expr)) in valid_walks

    pw = Piecewise((x, x < 1), (x**2, True))
    first_pair = ExprCondPair(x, x < 1)
    second_pair = ExprCondPair(x**2, True)
    assert list(preorder_traversal(pw)) == [
        pw, first_pair, x, x < 1, x, 1, second_pair, x**2, x, 2, True]

    limits = Tuple(x, 0, 1)
    assert list(postorder_traversal(Integral(x**2, (x, 0, 1)))) == [
        x, 2, x**2, x, 0, 1, limits, Integral(x**2, limits)]

    nested = ('abc', ('d', 'ef'))
    assert list(postorder_traversal(nested)) == [
        'abc', 'd', 'ef', ('d', 'ef'), nested]

    # skip() prevents descending into the subtree just yielded
    inner = x**(y**z)
    tower = inner**inner
    visited = []
    walk = preorder_traversal(tower)
    for node in walk:
        visited.append(node)
        if node == inner:
            walk.skip()
    assert visited == [tower, inner, inner]
示例#3
0
文件: basic.py 项目: BDGLunde/sympy
def _aresame(a, b):
    """Report True only when ``a`` and ``b`` match node-for-node, type included.

    Examples
    ========

    Ordinary SymPy equality considers 2.0 == 2:

    >>> from sympy import S, Symbol, cos, sin
    >>> 2.0 == S(2)
    True

    Here the two expressions are traversed in preorder side by side; the
    first pair of nodes that is unequal, or equal but of different types,
    makes the answer False:

    >>> from sympy.core.basic import _aresame
    >>> _aresame(S(2.0), S(2))
    False

    """
    from sympy.utilities.iterables import preorder_traversal
    from itertools import izip

    walk_a = preorder_traversal(a)
    walk_b = preorder_traversal(b)
    for node_a, node_b in izip(walk_a, walk_b):
        if node_a != node_b:
            return False
        if type(node_a) != type(node_b):
            return False
    return True
示例#4
0
def test_preorder_traversal():
    """Preorder walks list a node before its args; ``skip`` prunes a subtree."""
    tree = z+w*(x+y)
    walk = list(preorder_traversal(tree))
    # Add/Mul args may come out in any of these orders
    assert walk in [
        [z + w*(x + y), z, w*(x + y), w, x + y, y, x],
        [z + w*(x + y), z, w*(x + y), w, x + y, x, y],
        [z + w*(x + y), w*(x + y), w, x + y, y, x, z],
    ]

    piecewise = Piecewise((x, x < 1), (x**2, True))
    expected_pw = [piecewise,
                   ExprCondPair(x, x < 1), x, x < 1, x, 1,
                   ExprCondPair(x**2, True), x**2, x, 2, True]
    assert list(preorder_traversal(piecewise)) == expected_pw

    bounds = Tuple(x, 0, 1)
    expected_int = [x, 2, x**2, x, 0, 1, bounds, Integral(x**2, bounds)]
    assert list(postorder_traversal(Integral(x**2, (x, 0, 1)))) == expected_int

    pair = ('d', 'ef')
    assert list(postorder_traversal(('abc', pair))) == [
        'abc', 'd', 'ef', pair, ('abc', pair)]

    base = x**(y**z)
    power = base ** base
    seen = []
    traversal = preorder_traversal(power)
    for item in traversal:
        seen.append(item)
        if item == base:
            traversal.skip()
    assert seen == [power, base, base]
示例#5
0
def _aresame(a, b):
    """Return True if a and b are structurally the same, else False.

    Examples
    ========

    To SymPy, 2.0 == 2:

    >>> from sympy import S, Symbol, cos, sin
    >>> 2.0 == S(2)
    True

    The Basic.compare method will indicate that these are not the same, but
    the same method allows symbols with different assumptions to compare the
    same:

    >>> S(2).compare(2.0)
    -1
    >>> Symbol('x').compare(Symbol('x', positive=True))
    0

    The Basic.compare method will not work with instances of FunctionClass:

    >>> sin.compare(cos)
    Traceback (most recent call last):
     File "<stdin>", line 1, in <module>
    TypeError: unbound method compare() must be called with sin instance as first argument (got FunctionClass instance instead)

    Since a simple 'same or not' result is sometimes useful, this routine was
    written to provide that query. Both arguments are walked in preorder in
    lockstep; every pair of corresponding nodes must compare equal and have
    the same type, and corresponding Symbols must also carry the same
    assumptions, wherever they occur in the expressions:

    >>> from sympy.core.basic import _aresame
    >>> _aresame(S(2.0), S(2))
    False
    >>> _aresame(Symbol('x') + 1, Symbol('x', positive=True) + 1)
    False

    """
    from sympy.utilities.iterables import preorder_traversal
    from itertools import izip

    # Fast path for two Symbols: compare() ignores assumptions (see above),
    # so decide on the assumptions directly. compare() raises TypeError for
    # FunctionClass instances, and objects lacking compare/is_Symbol raise
    # AttributeError; both simply fall through to the traversal below.
    try:
        if a.compare(b) == 0 and a.is_Symbol and b.is_Symbol:
            return a.assumptions0 == b.assumptions0
    except (TypeError, AttributeError):
        pass

    for i, j in izip(preorder_traversal(a), preorder_traversal(b)):
        if i != j or type(i) != type(j):
            return False
        # Symbols nested inside larger expressions need the same assumption
        # check as the top-level fast path; i and j share a type here.
        # getattr guards against non-Basic nodes (e.g. items of plain tuples).
        if getattr(i, 'is_Symbol', False) and \
                i.assumptions0 != j.assumptions0:
            return False
    return True
示例#6
0
def test_postorder_traversal():
    """Check postorder walks (args first) plus preorder on Integral/tuples."""
    expr = z+w*(x+y)
    # the order of Add/Mul args is not fixed; any of these walks is accepted
    accepted = [
        [z, w, y, x, x + y, w*(x + y), expr],
        [z, w, x, y, x + y, w*(x + y), expr],
        [w, y, x, x + y, w*(x + y), z, expr],
    ]
    assert list(postorder_traversal(expr)) in accepted

    cond1 = ExprCondPair(x, x < 1)
    cond2 = ExprCondPair(x**2, True)
    pw = Piecewise((x, x < 1), (x**2, True))
    assert list(postorder_traversal(pw)) == [
        x, x, 1, x < 1, cond1, x, 2, x**2, True, cond2, pw]

    integral = Integral(x**2, (x, 0, 1))
    assert list(preorder_traversal(integral)) == [
        integral, x**2, x, 2, Tuple(x, 0, 1), x, 0, 1]

    nested = ('abc', ('d', 'ef'))
    assert list(preorder_traversal(nested)) == [
        nested, 'abc', ('d', 'ef'), 'd', 'ef']
示例#7
0
def test_postorder_traversal():
    """Postorder yields children before parents; preorder the reverse."""
    tree = z+w*(x+y)
    product = w*(x + y)
    orders = (
        [z, w, y, x, x + y, product, tree],
        [z, w, x, y, x + y, product, tree],
        [w, y, x, x + y, product, z, tree],
    )
    assert list(postorder_traversal(tree)) in orders

    piecewise = Piecewise((x, x < 1), (x**2, True))
    walked = list(postorder_traversal(piecewise))
    assert walked == [
        x, x, 1, x < 1, ExprCondPair(x, x < 1),
        x, 2, x**2, True, ExprCondPair(x**2, True),
        piecewise]

    walked = list(preorder_traversal(Integral(x**2, (x, 0, 1))))
    assert walked == [
        Integral(x**2, (x, 0, 1)), x**2, x, 2, Tuple(x, 0, 1), x, 0, 1]

    inner = ('d', 'ef')
    walked = list(preorder_traversal(('abc', inner)))
    assert walked == [('abc', inner), 'abc', inner, 'd', 'ef']
示例#8
0
def test_preorder_traversal():
    """Check the node orders produced by preorder and postorder traversal."""
    sum_expr = z+w*(x+y)
    product_term = w*(x + y)
    options = [
        [sum_expr, z, product_term, w, x + y, y, x],
        [sum_expr, z, product_term, w, x + y, x, y],
        [sum_expr, product_term, w, x + y, y, x, z],
    ]
    assert list(preorder_traversal(sum_expr)) in options

    first, second = ExprCondPair(x, x < 1), ExprCondPair(x**2, True)
    pw = Piecewise((x, x < 1), (x**2, True))
    assert list(preorder_traversal(pw)) == [
        pw, first, x, x < 1, x, 1, second, x**2, x, 2, True]

    # this version expects the Integral limits to appear as nested tuples
    integral = Integral(x**2, (x, 0, 1))
    assert list(postorder_traversal(integral)) == [
        x, 2, x**2, x, 0, 1, (0, 1), (x, (0, 1)), ((x, (0, 1)),), integral]

    nested = ('abc', ('d', 'ef'))
    assert list(postorder_traversal(nested)) == [
        'abc', 'd', 'ef', ('d', 'ef'), nested]
示例#9
0
def sub_post(e):
    """ Replace Sub(x,y) with the canonical form Add(x, Mul(NegativeOne(-1), y)).

    Every node accepted by ``assumed(node, 'is_Sub')`` is collected, together
    with its canonical Add form, during one preorder walk of ``e``; the pairs
    are then applied with ``subs`` in that order.
    """
    canonical = [(node, Add(node.args[0], Mul(-1, node.args[1])))
                 for node in preorder_traversal(e)
                 if assumed(node, 'is_Sub')]
    for sub_node, add_form in canonical:
        e = e.subs(sub_node, add_form)
    return e
示例#10
0
def mycollectsimp(expr):
    """Apply ``collect`` to ``expr`` once per distinct Symbol it contains.

    Every Symbol node met in a preorder walk of ``expr`` is counted (repeated
    occurrences count repeatedly). ``collect`` is then applied for each
    distinct Symbol in ascending order of that count; ties are broken by the
    symbol's printed name.

    Returns the collected expression.
    """
    from sympy import collect
    counts = {}
    for subtree in preorder_traversal(expr):
        if isinstance(subtree, Symbol):
            counts[subtree] = counts.get(subtree, 0) + 1
    # Sorting (count, Symbol) tuples would compare Symbols with ``<`` on
    # ties, which in SymPy builds a Relational rather than giving a reliable
    # ordering; use the name as an explicit tie-breaker instead.
    ordered = sorted(counts.items(), key=lambda item: (item[1], str(item[0])))
    for symbol, _count in ordered:
        expr = collect(expr, symbol)
    return expr
示例#11
0
文件: cse_opts.py 项目: Aang/sympy
def sub_post(e):
    """ Replace Sub(x,y) with the canonical form Add(x, Mul(NegativeOne(-1), y)).

    Sub nodes (as judged by ``assumed``) are found in preorder first; each is
    then swapped for ``Add(first_arg, Mul(-1, second_arg))`` via ``subs``.
    """
    pairs = []
    for node in preorder_traversal(e):
        if not assumed(node, 'is_Sub'):
            continue
        minuend, subtrahend = node.args[0], node.args[1]
        pairs.append((node, Add(minuend, Mul(-1, subtrahend))))
    result = e
    for old, new in pairs:
        result = result.subs(old, new)
    return result
示例#12
0
def sub_post(e):
    """ Replace Neg(x) with -x.

    The (Neg node, negated argument) pairs are computed up front from a
    preorder walk of ``e`` and then applied one at a time with ``xreplace``,
    in walk order.
    """
    swaps = [(node, -node.args[0])
             for node in preorder_traversal(e)
             if isinstance(node, Neg)]
    for old, new in swaps:
        e = e.xreplace({old: new})
    return e
示例#13
0
def sub_post(e):
    """ Replace Neg(x) with -x.

    Each Neg found while walking ``e`` in preorder is recorded with ``-arg``;
    the recorded pairs are then substituted sequentially with ``xreplace``.
    """
    walker = preorder_traversal(e)
    pending = [(neg, -neg.args[0]) for neg in walker if isinstance(neg, Neg)]
    result = e
    for target, replacement in pending:
        result = result.xreplace({target: replacement})
    return result
示例#14
0
def test_preorder_traversal():
    """Walk orders of preorder/postorder traversal on assorted inputs."""
    root = z + w * (x + y)
    mul_part = w * (x + y)
    acceptable = [
        [root, z, mul_part, w, x + y, y, x],
        [root, z, mul_part, w, x + y, x, y],
        [root, mul_part, w, x + y, y, x, z],
    ]
    assert list(preorder_traversal(root)) in acceptable

    piecewise = Piecewise((x, x < 1), (x**2, True))
    expected = [piecewise]
    expected += [ExprCondPair(x, x < 1), x, x < 1, x, 1]
    expected += [ExprCondPair(x**2, True), x**2, x, 2, True]
    assert list(preorder_traversal(piecewise)) == expected

    # limits appear as plain nested tuples in this version's postorder walk
    integral = Integral(x**2, (x, 0, 1))
    assert list(postorder_traversal(integral)) == [
        x, 2, x**2, x, 0, 1, (0, 1), (x, (0, 1)), ((x, (0, 1)), ),
        integral]

    tail = ('d', 'ef')
    assert list(postorder_traversal(('abc', tail))) == [
        'abc', 'd', 'ef', tail, ('abc', tail)]
示例#15
0
文件: solvers.py 项目: qmattpap/sympy
def denoms(eq, x=None):
    """Return (recursively) set of all denominators that appear in eq
    that contain any symbol in x; if x is None (default) then all
    denominators with symbols will be returned.

    Only Pow nodes and ``exp`` applications are inspected. What is stored
    for each qualifying denominator ``d`` is ``d.as_base_exp()[0]`` (its
    base). When a node's denominator is already a member of the result set,
    that node's args are not walked.
    """
    from sympy.utilities.iterables import preorder_traversal

    symbols = eq.free_symbols if x is None else x
    found = set()
    walker = preorder_traversal(eq)
    for node in walker:
        if not (node.is_Pow or node.func is exp):
            continue
        _, denom = node.as_numer_denom()
        if denom in found:
            walker.skip()
        elif denom.has(*symbols):
            found.add(denom.as_base_exp()[0])
    return found
示例#16
0
文件: solvers.py 项目: Jerryy/sympy
def denoms(eq, x=None):
    """Return (recursively) set of all denominators that appear in eq
    that contain any symbol in x; if x is None (default) then all
    denominators with symbols will be returned.

    The returned set holds the base (``as_base_exp()[0]``) of each
    denominator found under a Pow or ``exp`` node.
    """
    from sympy.utilities.iterables import preorder_traversal

    if x is None:
        x = eq.free_symbols
    bases = set()
    walk = preorder_traversal(eq)
    for node in walk:
        is_candidate = node.is_Pow or node.func is exp
        if not is_candidate:
            continue
        _, den = node.as_numer_denom()
        if den in bases:
            # the set holds bases, so this only fires when ``den`` equals a
            # base recorded earlier; its args are then not walked
            walk.skip()
        elif den.has(*x):
            bases.add(den.as_base_exp()[0])
    return bases
示例#17
0
def peel_terms(s):
    """Rewrite the Sums in ``s`` using ``peel_sum_terms`` over shared limits.

    Every ``Sum`` found in a preorder walk of ``s`` contributes the lower and
    upper bound of its first limit (``limits[0][1]`` and ``limits[0][2]``).
    The largest lower bound and the smallest upper bound seen are passed to
    ``peel_sum_terms``, and the rule it returns is applied to ``s`` with
    ``rewrite(s, [(Sum, rule)])``. If ``s`` contains no Sum, both bounds are
    passed as None.
    """
    lim_min = None
    lim_max = None
    for node in preorder_traversal(s):
        if isinstance(node, Sum):
            lmin = node.limits[0][1]
            lmax = node.limits[0][2]
            # Compare against None rather than testing truthiness: a bound of
            # 0 is falsy and would otherwise be overwritten, not combined.
            if lim_min is not None:
                lim_min = max(lmin, lim_min)
            else:
                lim_min = lmin
            if lim_max is not None:
                lim_max = min(lmax, lim_max)
            else:
                lim_max = lmax

    rule = peel_sum_terms(lim_min, lim_max)
    new_s = rewrite(s, [(Sum, rule)])
    return new_s
示例#18
0
def peel_terms(s):
    """Rewrite the Sums in ``s`` using ``peel_sum_terms`` over shared limits.

    Every ``Sum`` found in a preorder walk of ``s`` contributes the lower and
    upper bound of its first limit (``limits[0][1]`` and ``limits[0][2]``).
    The largest lower bound and the smallest upper bound seen are passed to
    ``peel_sum_terms``, and the rule it returns is applied to ``s`` with
    ``rewrite(s, [(Sum, rule)])``. If ``s`` contains no Sum, both bounds are
    passed as None.
    """
    lim_min = None
    lim_max = None
    for node in preorder_traversal(s):
        if isinstance(node, Sum):
            lmin = node.limits[0][1]
            lmax = node.limits[0][2]
            # Compare against None rather than testing truthiness: a bound of
            # 0 is falsy and would otherwise be overwritten, not combined.
            if lim_min is not None:
                lim_min = max(lmin, lim_min)
            else:
                lim_min = lmin
            if lim_max is not None:
                lim_max = min(lmax, lim_max)
            else:
                lim_max = lmax

    rule = peel_sum_terms(lim_min, lim_max)
    new_s = rewrite(s, [(Sum, rule)])
    return new_s
示例#19
0
文件: basic.py 项目: goodok/sympy
def _atomic(e):
    """Collect the atom-like pieces of ``e`` as far as substitution goes.

    Derivatives, Functions and free Symbols are returned. Once a Derivative
    or Function is found its args are not searched, so anything occurring
    only inside such an object is not reported separately. A Symbol is
    reported only if it belongs to ``e.free_symbols``. A subtree already
    visited is not walked again. If ``e`` has no ``free_symbols`` attribute,
    ``set([e])`` is returned.

    Examples
    ========
    >>> from sympy import Derivative, Function, cos
    >>> from sympy.abc import x, y
    >>> from sympy.core.basic import _atomic
    >>> f = Function('f')
    >>> _atomic(x + y)
    set([x, y])
    >>> _atomic(x + f(y))
    set([x, f(y)])
    >>> _atomic(Derivative(f(x), x) + cos(x) + y)
    set([y, cos(x), Derivative(f(x), x)])

    """
    from sympy import Derivative, Function, Symbol
    from sympy.utilities.iterables import preorder_traversal
    pot = preorder_traversal(e)
    visited = set()
    try:
        free = e.free_symbols
    except AttributeError:
        return set([e])
    found = set()
    for node in pot:
        if node in visited:
            pot.skip()
            continue
        visited.add(node)
        if isinstance(node, Symbol) and node in free:
            found.add(node)
        elif isinstance(node, (Derivative, Function)):
            # treat the whole object as one atom; don't look inside it
            pot.skip()
            found.add(node)
    return found
示例#20
0
文件: basic.py 项目: BDGLunde/sympy
def _atomic(e):
    """Return the Derivatives, Functions and free Symbols making up ``e``.

    Objects nested inside a Derivative or Function are not reported unless
    they also occur outside it, because the walk does not enter those
    objects. Non-Basic input (no ``free_symbols``) gives ``set([e])``.

    Examples
    ========
    >>> from sympy import Derivative, Function, cos
    >>> from sympy.abc import x, y
    >>> from sympy.core.basic import _atomic
    >>> f = Function('f')
    >>> _atomic(x + y)
    set([x, y])
    >>> _atomic(x + f(y))
    set([x, f(y)])
    >>> _atomic(Derivative(f(x), x) + cos(x) + y)
    set([y, cos(x), Derivative(f(x), x)])

    """
    from sympy import Derivative, Function, Symbol
    from sympy.utilities.iterables import preorder_traversal
    walker = preorder_traversal(e)
    already = set()
    try:
        free = e.free_symbols
    except AttributeError:
        return set([e])
    result = set()
    for node in walker:
        repeated = node in already
        if repeated:
            # identical subtree handled before: prune it
            walker.skip()
            continue
        already.add(node)
        is_free_symbol = isinstance(node, Symbol) and node in free
        if is_free_symbol:
            result.add(node)
        elif isinstance(node, (Derivative, Function)):
            walker.skip()
            result.add(node)
    return result
示例#21
0
def sub_pre(e):
    """ Replace Add(x, Mul(NegativeOne(-1), y)) with Sub(x, y).

    For each Add (as judged by ``assumed``) in a preorder walk of ``e``, an
    argument that is a Mul whose first factor is a negative number is moved
    to the subtracted side, with that factor negated; other arguments stay
    on the positive side. Adds with at least one such argument become
    ``Sub(Add(*positives), Add(*negatives))``; the rewrites are applied in
    walk order with ``subs``.
    """
    def _is_negative_term(arg):
        # a Mul whose leading factor is a negative number
        return (assumed(arg, 'is_Mul') and
                assumed(arg.args[0], 'is_number') and
                assumed(arg.args[0], 'is_negative'))

    rewrites = []
    for node in preorder_traversal(e):
        if not assumed(node, 'is_Add'):
            continue
        plus_terms = []
        minus_terms = []
        for term in node.args:
            if _is_negative_term(term):
                minus_terms.append(Mul(-term.args[0], *term.args[1:]))
            else:
                plus_terms.append(term)
        if minus_terms:
            rewrites.append((node, Sub(Add(*plus_terms), Add(*minus_terms))))
    for old, new in rewrites:
        e = e.subs(old, new)
    return e
示例#22
0
def sub_pre(e):
    """ Replace Add(x, Mul(NegativeOne(-1), y)) with Sub(x, y).

    Each Mul argument of an Add is split with ``as_two_terms``; when the
    leading part is a negative number, ``Mul(-leading, rest)`` goes to the
    subtracted side. Adds that gain a subtracted side are replaced, in
    preorder, by ``Sub(Add(*kept), Add(*flipped))`` using ``subs``.
    """
    rewrites = []
    for node in preorder_traversal(e):
        if not assumed(node, 'is_Add'):
            continue
        kept, flipped = [], []
        for term in node.args:
            moved = False
            if assumed(term, 'is_Mul'):
                coeff, rest = term.as_two_terms()
                if assumed(coeff, 'is_number') and assumed(coeff, 'is_negative'):
                    flipped.append(Mul(-coeff, rest))
                    moved = True
            if not moved:
                kept.append(term)
        if flipped:
            rewrites.append((node, Sub(Add(*kept), Add(*flipped))))
    result = e
    for old, new in rewrites:
        result = result.subs(old, new)
    return result
示例#23
0
文件: cse_opts.py 项目: 101man/sympy
def sub_pre(e):
    """ Replace Add(x, Mul(NegativeOne(-1), y)) with Sub(x, y).

    Uses the nodes' own ``is_Add``/``is_Mul``/``is_number``/``is_negative``
    flags. Mul arguments whose leading part (from ``as_two_terms``) is a
    negative number are negated onto the subtracted side; each Add that has
    any is replaced, in preorder, with ``subs``.
    """
    replacements = []
    for node in preorder_traversal(e):
        if not node.is_Add:
            continue
        pos, neg = [], []
        for term in node.args:
            if term.is_Mul:
                lead, remainder = term.as_two_terms()
                if lead.is_number and lead.is_negative:
                    neg.append(Mul(-lead, remainder))
                    continue
            pos.append(term)
        if neg:
            replacements.append((node, Sub(Add(*pos), Add(*neg))))
    for old, new in replacements:
        e = e.subs(old, new)
    return e
示例#24
0
 def singles(self):
     counter = self.size-1
     while counter >= 0:
         command = self.commands[counter]
         if command.tag != "final":
             # do not consider end results
             # check for usage
             used = 0
             tmp = None
             for later_command in self.commands[counter+1:]:
                 if command.symbol in later_command.expr:
                     for subtree in preorder_traversal(later_command.expr):
                         if subtree == command.symbol:
                             used += 1
                             tmp = later_command
                         if isinstance(subtree, C.Pow) and subtree.exp == 2 and subtree.base == command.symbol:
                             used += 1
                             tmp = later_command
             if used == 1:
                 tmp.expr = tmp.expr.subs(command.symbol, command.expr)
                 print "SINGLE", command
                 del self.commands[counter]
         counter -= 1
示例#25
0
def test_preorder_traversal():
    """The root comes first, then its args; Add/Mul arg order may vary."""
    root = z+w*(x+y)
    inner = w*(x + y)
    allowed = [
        [root, z, inner, w, x + y, y, x],
        [root, z, inner, w, x + y, x, y],
        [root, inner, w, x + y, y, x, z],
    ]
    assert list(preorder_traversal(root)) in allowed
示例#26
0
文件: cse_main.py 项目: Jerryy/sympy
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters:

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The `numbered_symbols` generator is useful. The default is a stream
        of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        `sympy.simplify.cse.cse_optimizations` is used.

    Returns:

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.

    Notes:

    Candidates for elimination are gathered in three passes over the
    preprocessed expressions: (1) any non-atomic subtree met a second time
    during a preorder walk; (2) sets of two or more arguments shared by two
    Adds; (3) for pairs of Muls, the shared commutative factors plus the
    longest common contiguous run of non-commutative factors (when together
    they amount to at least two factors). Candidates are kept ordered by
    ``count_ops`` (smallest first) and replaced in that order.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    # non-atomic subtrees visited so far during the preorder walks
    seen_subexp = set()
    # every Mul / Add subtree visited; mined later for partial matches
    muls = set()
    adds = set()
    # elimination candidates sorted by op count (ascending) and the parallel
    # list of their op counts, used for bisection in insert()
    to_eliminate = []
    to_eliminate_ops_count = []

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]
    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    def insert(subtree):
        '''This helper will insert the subtree into to_eliminate while
        maintaining the ordering by op count and will skip the insertion
        if subtree is already present.'''
        ops_count = subtree.count_ops()
        index_to_insert = bisect.bisect(to_eliminate_ops_count, ops_count)
        # all i up to this index have op count <= the current op count
        # so check that subtree is not yet present from this index down
        # (if necessary) to zero.
        # NOTE(review): the loop does not stop once the op counts drop below
        # ops_count, so it always scans down to index 0.
        for i in xrange(index_to_insert - 1, -1, -1):
            if to_eliminate_ops_count[i] == ops_count and \
               subtree == to_eliminate[i]:
                return # already have it
        to_eliminate_ops_count.insert(index_to_insert, ops_count)
        to_eliminate.insert(index_to_insert, subtree)

    for expr in exprs:
        pt = preorder_traversal(expr)
        for subtree in pt:
            if subtree.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue

            if subtree in seen_subexp:
                # Repeated subtree: record it as a candidate. Its args were
                # already walked at its first occurrence, so prune them.
                insert(subtree)
                pt.skip()
                continue

            if subtree.is_Mul:
                muls.add(subtree)
            elif subtree.is_Add:
                adds.add(subtree)

            seen_subexp.add(subtree)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = [set(a.args) for a in adds]
    for i in xrange(len(adds)):
        for j in xrange(i + 1, len(adds)):
            com = adds[i].intersection(adds[j])
            if len(com) > 1:
                insert(Add(*com))

                # remove this set of symbols so it doesn't appear again
                adds[i] = adds[i].difference(com)
                adds[j] = adds[j].difference(com)
                # also strip com from every later Add that contains all of it
                for k in xrange(j + 1, len(adds)):
                    if not com.difference(adds[k]):
                        adds[k] = adds[k].difference(com)

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common

    # use SequenceMatcher on the nc part to find the longest common expression
    # in common between the two nc parts
    sm = difflib.SequenceMatcher()

    # each entry is [commutative factors, non-commutative factors]
    muls = [a.args_cnc() for a in muls]
    for i in xrange(len(muls)):
        if muls[i][1]:
            sm.set_seq1(muls[i][1])
        for j in xrange(i + 1, len(muls)):
            # the commutative part in common
            ccom = muls[i][0].intersection(muls[j][0])

            # the non-commutative part in common
            if muls[i][1] and muls[j][1]:
                # see if there is any chance of an nc match
                ncom = set(muls[i][1]).intersection(set(muls[j][1]))
                if len(ccom) + len(ncom) < 2:
                    continue

                # now work harder to find the match
                sm.set_seq2(muls[j][1])
                i1, _, n = sm.find_longest_match(0, len(muls[i][1]),
                                                 0, len(muls[j][1]))
                ncom = muls[i][1][i1:i1 + n]
            else:
                ncom = []

            com = list(ccom) + ncom
            if len(com) < 2:
                continue

            insert(Mul(*com))

            # remove ccom from all if there was no ncom; to update the nc part
            # would require finding the subexpr and then replacing it with a
            # dummy to keep bounding nc symbols from being identified as a
            # subexpr, e.g. removing B*C from A*B*C*D might allow A*D to be
            # identified as a subexpr which would not be right.
            if not ncom:
                muls[i][0] = muls[i][0].difference(ccom)
                # starts at j, so muls[j] itself is stripped as well
                for k in xrange(j, len(muls)):
                    if not ccom.difference(muls[k][0]):
                        muls[k][0] = muls[k][0].difference(ccom)

    # Substitute symbols for all of the repeated subexpressions.
    # Candidates are handled smallest-op-count first; each one is replaced
    # in the reduced expressions and in the candidates not yet processed.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        sym = symbols.next()
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        for j in range(i+1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs]

    return replacements, reduced_exprs
示例#27
0
def get_index(term):
    """Return the first index of the first ``Indexed`` node met in preorder.

    The match is an exact type comparison (``type(p) == Indexed``), not an
    ``isinstance`` test. Returns None when ``term`` contains no such node.
    """
    hits = (p for p in preorder_traversal(term) if type(p) == Indexed)
    first = next(hits, None)
    return None if first is None else first.indices[0]
示例#28
0
def get_index(term):
    """Give ``indices[0]`` of the earliest (preorder) exact ``Indexed`` node.

    None is returned if the walk finds no node whose type equals Indexed.
    """
    leading_indices = (node.indices[0] for node in preorder_traversal(term)
                       if type(node) == Indexed)
    for index in leading_indices:
        return index
    return None
示例#29
0
def test_preorder_traversal():
    """Preorder lists the whole expression first; arg order may differ."""
    whole = z + w * (x + y)
    scaled = w * (x + y)
    possibilities = (
        [whole, z, scaled, w, x + y, y, x],
        [whole, z, scaled, w, x + y, x, y],
        [whole, scaled, w, x + y, y, x, z],
    )
    walked = list(preorder_traversal(whole))
    assert walked in possibilities
示例#30
0
文件: exprtools.py 项目: ENuge/sympy
def _mask_nc(eq):
    """Return ``eq`` with non-commutative objects replaced with dummy
    symbols. A dictionary that can be used to restore the original
    values is returned: if it is None, nothing was masked and the
    expression is noncommutative and cannot be made commutative. The third
    value returned is a list of any non-commutative symbols that appeared
    in the equation, together with any non-commutative Dummy symbols that
    were introduced (see Notes).

    Notes
    =====
    All non-commutative objects other than Symbols (and other than Add,
    Mul and Pow, which are searched rather than replaced) will be
    replaced. If the only non-commutative objects are Symbols: if there is
    only 1 Symbol, it will be replaced; if there are more than one then
    they will not be replaced; the calling routine should handle
    replacements in this case since some care must be taken to keep
    track of the ordering of symbols when they occur within Muls.

    A single non-commutative object is replaced with a commutative Dummy.
    When more than one distinct object is replaced, each gets a
    non-commutative Dummy instead (and that Dummy is added to the returned
    list of symbols): replacing them with commutative Dummies would lose
    their relative order, e.g. ``X*Y - Y*X`` would collapse to 0.

    Examples
    ========
    >>> from sympy.physics.secondquant import Commutator, NO, F, Fd
    >>> from sympy import Dummy, symbols
    >>> from sympy.abc import x, y
    >>> from sympy.core.exprtools import _mask_nc
    >>> A, B, C = symbols('A,B,C', commutative=False)
    >>> Dummy._count = 0 # reset for doctest purposes
    >>> _mask_nc(A**2 - x**2)
    (_0**2 - x**2, {_0: A}, [])
    >>> _mask_nc(A**2 - B**2)
    (A**2 - B**2, None, [A, B])
    >>> _mask_nc(1 + x*Commutator(A, B))
    (_1*x + 1, {_1: Commutator(A, B)}, [A, B])
    >>> _mask_nc(NO(Fd(x)*F(y)))
    (_2, {_2: NO(CreateFermion(x)*AnnihilateFermion(y))}, [])

    """
    expr = eq
    if expr.is_commutative:
        return eq, {}, []
    # if there is only one nc symbol, it can be factored regularly but
    # polys is going to complain, so replace it with a dummy
    rep = []
    nc_syms = [s for s in expr.free_symbols if not s.is_commutative]
    if len(nc_syms) == 1:
        nc = Dummy()
        rep.append((nc_syms.pop(), nc))
        expr = expr.subs(rep)
    # even though the noncommutative symbol may be gone, the expression
    # might still appear noncommutative; if it's a non-elementary object
    # we will replace it, but if it is a Symbol, Add, Mul, Pow we leave
    # it alone.
    nc_syms.sort(key=default_sort_key)
    if nc_syms or not expr.is_commutative:
        # Collect the distinct non-elementary nc objects in preorder; don't
        # look inside anything already replaced or already collected.
        nc_obj = []
        pot = preorder_traversal(expr)
        for a in pot:
            if any(a == r[0] for r in rep) or any(a == o for o in nc_obj):
                pot.skip()
            elif (
                not a.is_commutative and
                not (a.is_Symbol or a.is_Add or a.is_Mul or a.is_Pow)
                ):
                nc_obj.append(a)
                pot.skip() # don't go any further
        if len(nc_obj) == 1:
            # NOTE(review): a lone object still becomes commutative even if
            # free nc Symbols multiply it elsewhere in the expression.
            rep.append((nc_obj[0], Dummy()))
        elif len(nc_obj) > 1:
            # several nc objects must keep their order relative to each other
            for o in nc_obj:
                nc = Dummy(commutative=False)
                rep.append((o, nc))
                nc_syms.append(nc)
            nc_syms.sort(key=default_sort_key)
        expr = expr.subs(rep)
    return expr, dict([(v, k) for k, v in rep]) or None, nc_syms
示例#31
0
def _mask_nc(eq):
    """Return ``eq`` with non-commutative objects replaced with dummy
    symbols. A dictionary that can be used to restore the original
    values is returned: if it is None, the expression is
    noncommutative and cannot be made commutative. The third value
    returned is a list of any non-commutative symbols that appear
    in the returned equation.

    Notes
    =====
    All non-commutative objects other than Symbols are replaced with
    a non-commutative Symbol. Identical objects will be identified
    by identical symbols.

    If there is only 1 non-commutative object in an expression it will
    be replaced with a commutative symbol. Otherwise, the non-commutative
    entities are retained and the calling routine should handle
    replacements in this case since some care must be taken to keep
    track of the ordering of symbols when they occur within Muls.

    Examples
    ========
    >>> from sympy.physics.secondquant import Commutator, NO, F, Fd
    >>> from sympy import Dummy, symbols, Sum, Mul, Basic
    >>> from sympy.abc import x, y
    >>> from sympy.core.exprtools import _mask_nc
    >>> A, B, C = symbols('A,B,C', commutative=False)
    >>> Dummy._count = 0 # reset for doctest purposes

    One nc-symbol:

    >>> _mask_nc(A**2 - x**2)
    (_0**2 - x**2, {_0: A}, [])

    Multiple nc-symbols:

    >>> _mask_nc(A**2 - B**2)
    (A**2 - B**2, None, [A, B])

    An nc-object with nc-symbols but no others outside of it:

    >>> _mask_nc(1 + x*Commutator(A, B))
    (_1*x + 1, {_1: Commutator(A, B)}, [])
    >>> _mask_nc(NO(Fd(x)*F(y)))
    (_2, {_2: NO(CreateFermion(x)*AnnihilateFermion(y))}, [])

    Multiple nc-objects:

    >>> eq = x*Commutator(A, B) + x*Commutator(A, C)*Commutator(A, B)
    >>> _mask_nc(eq)
    (x*_3*_4 + x*_4, {_3: Commutator(A, C), _4: Commutator(A, B)}, [_3, _4])

    Multiple nc-objects and nc-symbols:

    >>> eq = A*Commutator(A, B) + B*Commutator(A, C)
    >>> _mask_nc(eq)
    (A*_6 + B*_5, {_5: Commutator(A, C), _6: Commutator(A, B)}, [_5, _6, A, B])

    If there is an object that:

        - doesn't contain nc-symbols
        - but has arguments which derive from Basic, not Expr
        - and doesn't define an _eval_is_commutative routine

    then its is_commutative test may give False (or possibly None). Such
    objects are also removed by this routine:

    >>> from sympy import Basic, Mul
    >>> eq = (1 + Mul(Basic(), Basic(), evaluate=False))
    >>> eq.is_commutative
    False
    >>> _mask_nc(eq)
    (_7**2 + 1, {_7: Basic()}, [])

    """
    expr = eq
    if expr.is_commutative:
        return eq, {}, []

    # identify nc-objects; symbols and other
    rep = []
    nc_obj = set()
    nc_syms = set()
    pot = preorder_traversal(expr)
    for i, a in enumerate(pot):
        if any(a == r[0] for r in rep):
            # already scheduled for replacement: don't look inside it again
            pot.skip()
        elif not a.is_commutative:
            if a.is_Symbol:
                nc_syms.add(a)
            elif not (a.is_Add or a.is_Mul or a.is_Pow):
                # an nc object whose free symbols are all commutative is
                # replaced right away with a commutative Dummy; otherwise
                # it is remembered as an nc-object for the decision below
                if all(s.is_commutative for s in a.free_symbols):
                    rep.append((a, Dummy()))
                else:
                    nc_obj.add(a)
                # the object's args are not walked, so nc symbols occurring
                # only inside it are not collected into nc_syms
                pot.skip()

    # If there is only one nc symbol or object, it can be factored regularly
    # but polys is going to complain, so replace it with a Dummy.
    if len(nc_obj) == 1 and not nc_syms:
        rep.append((nc_obj.pop(), Dummy()))
    elif len(nc_syms) == 1 and not nc_obj:
        rep.append((nc_syms.pop(), Dummy()))

    # Any remaining nc-objects will be replaced with an nc-Dummy and
    # identified as an nc-Symbol to watch out for
    # NOTE(review): set.pop() order is arbitrary, so which Dummy is paired
    # with which object (and the numbering shown in the doctests) may depend
    # on hashing.
    while nc_obj:
        nc = Dummy(commutative=False)
        rep.append((nc_obj.pop(), nc))
        nc_syms.add(nc)
    expr = expr.subs(rep)

    nc_syms = list(nc_syms)
    nc_syms.sort(key=default_sort_key)
    return expr, dict([(v, k) for k, v in rep]) or None, nc_syms
示例#32
0
文件: exprtools.py 项目: ness01/sympy
def _mask_nc(eq):
    """Mask the non-commutative pieces of ``eq`` with Dummy symbols.

    Three values are returned:

    - the masked expression;
    - a dictionary mapping each Dummy back to the object it replaced, or
      None when nothing was replaced (the expression is noncommutative
      and cannot be made commutative);
    - a sorted list of the non-commutative free symbols of ``eq``. It is
      empty when there was exactly one, because that one is masked.

    If ``eq`` is already commutative it is returned unchanged as
    ``(eq, {}, [])``.

    Notes
    =====
    Non-commutative objects other than Symbol, Add, Mul and Pow (e.g. a
    Commutator) are replaced by *commutative* Dummy symbols. Identical
    objects share one Dummy, and the inside of a masked object is not
    examined. Non-commutative Symbols are handled separately:

    - a lone Symbol is replaced;
    - several Symbols are left alone. The calling routine must handle
      them, because care must be taken to keep track of their ordering
      within Muls.

    NOTE(review): because masked objects become commutative, a product of
    two different masked objects loses its ordering (``X*Y - Y*X`` would
    mask to zero). Callers should not rely on the result for such input.

    Examples
    ========
    >>> from sympy.physics.secondquant import Commutator, NO, F, Fd
    >>> from sympy import Dummy, symbols
    >>> from sympy.abc import x, y
    >>> from sympy.core.exprtools import _mask_nc
    >>> A, B, C = symbols('A,B,C', commutative=False)
    >>> Dummy._count = 0 # reset for doctest purposes
    >>> _mask_nc(A**2 - x**2)
    (_0**2 - x**2, {_0: A}, [])
    >>> _mask_nc(A**2 - B**2)
    (A**2 - B**2, None, [A, B])
    >>> _mask_nc(1 + x*Commutator(A, B))
    (_1*x + 1, {_1: Commutator(A, B)}, [A, B])
    >>> _mask_nc(NO(Fd(x)*F(y)))
    (_2, {_2: NO(CreateFermion(x)*AnnihilateFermion(y))}, [])

    """
    if eq.is_commutative:
        return eq, {}, []

    pairs = []
    nc_syms = [s for s in eq.free_symbols if not s.is_commutative]
    expr = eq
    # A lone nc symbol can be factored like a commutative one, but polys
    # would complain about it, so mask it with a (commutative) Dummy.
    if len(nc_syms) == 1:
        pairs.append((nc_syms.pop(), Dummy()))
        expr = expr.subs(pairs)
    nc_syms.sort(key=default_sort_key)

    # Even with its nc symbol gone the expression may still look
    # non-commutative. Mask every non-elementary nc object, but keep
    # descending through Symbol, Add, Mul and Pow nodes.
    if nc_syms or not expr.is_commutative:
        walker = preorder_traversal(expr)
        for node in walker:
            if not any(node == old for old, _ in pairs):
                if (node.is_commutative or node.is_Symbol or node.is_Add
                        or node.is_Mul or node.is_Pow):
                    continue  # descend into this node
                pairs.append((node, Dummy()))
            walker.skip()  # masked (now or earlier): don't look inside
        expr = expr.subs(pairs)

    return expr, dict([(dummy, old) for old, dummy in pairs]) or None, nc_syms
示例#33
0
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters:

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a stream
        of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        ``sympy.simplify.cse.cse_optimizations`` is used.

    Returns:

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.

    Notes:

    The candidates for elimination are of three kinds:

    - whole non-atomic subtrees that occur more than once;
    - a set of two or more terms shared by two Adds (e.g. ``x + y`` in
      ``x + y + z`` and ``x + y + w``);
    - for two Muls, their shared commutative factors together with the
      longest contiguous run of shared non-commutative factors.

    Candidates are replaced in order of increasing operation count.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    muls = set()
    adds = set()
    # to_eliminate is kept sorted by op count; to_eliminate_ops_count holds
    # the matching counts so bisect can find insertion points.
    to_eliminate = []
    to_eliminate_ops_count = []

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]
    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    def insert(subtree):
        '''This helper will insert the subtree into to_eliminate while
        maintaining the ordering by op count and will skip the insertion
        if subtree is already present.'''
        ops_count = subtree.count_ops()
        index_to_insert = bisect.bisect(to_eliminate_ops_count, ops_count)
        # All entries before index_to_insert have op count <= ops_count.
        # Walk down from there, but stop at the first smaller count: the
        # list is sorted, so no earlier entry can equal subtree.
        for i in xrange(index_to_insert - 1, -1, -1):
            if to_eliminate_ops_count[i] < ops_count:
                break
            if subtree == to_eliminate[i]:
                return # already have it
        to_eliminate_ops_count.insert(index_to_insert, ops_count)
        to_eliminate.insert(index_to_insert, subtree)

    for expr in exprs:
        pt = preorder_traversal(expr)
        for subtree in pt:
            if subtree.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue

            if subtree in seen_subexp:
                insert(subtree)
                # its children were already visited on the first encounter
                pt.skip()
                continue

            if subtree.is_Mul:
                muls.add(subtree)
            elif subtree.is_Add:
                adds.add(subtree)

            seen_subexp.add(subtree)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = [set(a.args) for a in adds]
    for i in xrange(len(adds)):
        for j in xrange(i + 1, len(adds)):
            com = adds[i].intersection(adds[j])
            if len(com) > 1:
                insert(Add(*com))

                # remove this set of symbols so it doesn't appear again
                adds[i] = adds[i].difference(com)
                adds[j] = adds[j].difference(com)
                for k in xrange(j + 1, len(adds)):
                    if not com.difference(adds[k]):
                        adds[k] = adds[k].difference(com)

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common

    # use SequenceMatcher on the nc part to find the longest common expression
    # in common between the two nc parts
    sm = difflib.SequenceMatcher()

    muls = [a.args_cnc() for a in muls]
    for i in xrange(len(muls)):
        if muls[i][1]:
            sm.set_seq1(muls[i][1])
        for j in xrange(i + 1, len(muls)):
            # the commutative part in common
            ccom = muls[i][0].intersection(muls[j][0])

            # the non-commutative part in common
            if muls[i][1] and muls[j][1]:
                # see if there is any chance of an nc match
                ncom = set(muls[i][1]).intersection(set(muls[j][1]))
                if len(ccom) + len(ncom) < 2:
                    continue

                # now work harder to find the match
                sm.set_seq2(muls[j][1])
                i1, _, n = sm.find_longest_match(0, len(muls[i][1]),
                                                 0, len(muls[j][1]))
                ncom = muls[i][1][i1:i1 + n]
            else:
                ncom = []

            com = list(ccom) + ncom
            if len(com) < 2:
                continue

            insert(Mul(*com))

            # remove ccom from all if there was no ncom; to update the nc part
            # would require finding the subexpr and then replacing it with a
            # dummy to keep bounding nc symbols from being identified as a
            # subexpr, e.g. removing B*C from A*B*C*D might allow A*D to be
            # identified as a subexpr which would not be right.
            if not ncom:
                muls[i][0] = muls[i][0].difference(ccom)
                for k in xrange(j, len(muls)):
                    if not ccom.difference(muls[k][0]):
                        muls[k][0] = muls[k][0].difference(ccom)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        sym = symbols.next()
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        # (Only later entries are rewritten, so the enumerate above is safe.)
        for j in range(i+1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs]

    return replacements, reduced_exprs
示例#34
0
文件: cse_main.py 项目: pyc111/sympy
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters:

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The `numbered_symbols` generator is useful. The default is a stream
        of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        `sympy.simplify.cse.cse_optimizations` is used.

    Returns:

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.

    Notes:

    Only whole non-atomic subtrees that occur more than once are
    eliminated. They are replaced in order of increasing operation count.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    # to_eliminate is kept sorted by op count; to_eliminate_ops_count holds
    # the matching counts so bisect can find insertion points.
    to_eliminate = []
    to_eliminate_ops_count = []

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]
    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    for expr in exprs:
        pt = preorder_traversal(expr)
        for subtree in pt:
            if subtree.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue
            elif subtree in seen_subexp:
                ops_count = subtree.count_ops()
                # Only entries with the same op count can equal subtree, and
                # they form one contiguous run of the sorted list. Compare
                # against that run only, instead of scanning the whole list.
                lo = bisect.bisect_left(to_eliminate_ops_count, ops_count)
                hi = bisect.bisect_right(to_eliminate_ops_count, ops_count)
                if subtree not in to_eliminate[lo:hi]:
                    # hi is the same index bisect.bisect would return
                    to_eliminate_ops_count.insert(hi, ops_count)
                    to_eliminate.insert(hi, subtree)
                # its children were already visited on the first encounter
                pt.skip()
            else:
                seen_subexp.add(subtree)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        sym = symbols.next()
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        # Only entries after i are rewritten and the list length does not
        # change, so the enumerate above still sees every entry once.
        for j in range(i + 1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs]

    return replacements, reduced_exprs
示例#35
0
def _mask_nc(eq):
    """Return ``eq`` with non-commutative objects replaced with dummy
    symbols. A dictionary that can be used to restore the original
    values is returned: if it is None, the expression is
    noncommutative and cannot be made commutative. The third value
    returned is a list of any non-commutative symbols that appear
    in the returned equation.

    Notes
    =====
    All non-commutative objects other than Symbols are replaced with
    a non-commutative Symbol. Identical objects will be identified
    by identical symbols. Add, Mul and Pow are not replaced; their
    contents are searched instead. An nc-object whose free symbols
    are all commutative is replaced with a commutative symbol.

    If there is only 1 non-commutative object in an expression it will
    be replaced with a commutative symbol. Otherwise, the non-commutative
    entities are retained and the calling routine should handle
    replacements in this case since some care must be taken to keep
    track of the ordering of symbols when they occur within Muls.

    When several nc-objects receive non-commutative Dummies, the Dummies
    are assigned in ``default_sort_key`` order of the objects. The result
    therefore does not depend on set iteration order.

    If ``eq`` is commutative, ``(eq, {}, [])`` is returned.

    Examples
    ========
    >>> from sympy.physics.secondquant import Commutator, NO, F, Fd
    >>> from sympy import Dummy, symbols, Sum
    >>> from sympy.abc import x, y
    >>> from sympy.core.exprtools import _mask_nc
    >>> A, B, C = symbols('A,B,C', commutative=False)
    >>> Dummy._count = 0 # reset for doctest purposes

    One nc-symbol:

    >>> _mask_nc(A**2 - x**2)
    (_0**2 - x**2, {_0: A}, [])

    Multiple nc-symbols:

    >>> _mask_nc(A**2 - B**2)
    (A**2 - B**2, None, [A, B])

    An nc-object with nc-symbols but no others outside of it:

    >>> _mask_nc(1 + x*Commutator(A, B))
    (_1*x + 1, {_1: Commutator(A, B)}, [])
    >>> _mask_nc(NO(Fd(x)*F(y)))
    (_2, {_2: NO(CreateFermion(x)*AnnihilateFermion(y))}, [])

    An nc-object without nc-symbols:

    >>> _mask_nc(x + x*Sum(x, (x, 1, 2)))
    (_3*x + x, {_3: Sum(x, (x, 1, 2))}, [])

    Multiple nc-objects:

    >>> eq = x*Commutator(A, B) + x*Commutator(A, C)*Commutator(A, B)
    >>> expr, rep, nc = _mask_nc(eq)
    >>> rep
    {_4: Commutator(A, B), _5: Commutator(A, C)}
    >>> nc
    [_4, _5]
    >>> expr == x*nc[0] + x*nc[1]*nc[0]
    True

    Multiple nc-objects and nc-symbols:

    >>> eq = A*Commutator(A, B) + B*Commutator(A, C)
    >>> expr, rep, nc = _mask_nc(eq)
    >>> rep
    {_6: Commutator(A, B), _7: Commutator(A, C)}
    >>> nc
    [_6, _7, A, B]
    >>> expr == A*nc[0] + B*nc[1]
    True

    """
    expr = eq
    if expr.is_commutative:
        return eq, {}, []

    # identify nc-objects; symbols and other
    rep = []
    nc_obj = set()
    nc_syms = set()
    pot = preorder_traversal(expr)
    for a in pot:
        if any(a == r[0] for r in rep):
            # already masked: don't look inside it again
            pot.skip()
        elif not a.is_commutative:
            if a.is_Symbol:
                nc_syms.add(a)
            elif not (a.is_Add or a.is_Mul or a.is_Pow):
                if all(s.is_commutative for s in a.free_symbols):
                    rep.append((a, Dummy()))
                else:
                    nc_obj.add(a)
                pot.skip()

    # If there is only one nc symbol or object, it can be factored regularly
    # but polys is going to complain, so replace it with a Dummy.
    if len(nc_obj) == 1 and not nc_syms:
        rep.append((nc_obj.pop(), Dummy()))
    elif len(nc_syms) == 1 and not nc_obj:
        rep.append((nc_syms.pop(), Dummy()))

    # Any remaining nc-objects will be replaced with an nc-Dummy and
    # identified as an nc-Symbol to watch out for. Visit them in canonical
    # order: set.pop() would make the object -> Dummy assignment depend on
    # hash/iteration order.
    for obj in sorted(nc_obj, key=default_sort_key):
        nc = Dummy(commutative=False)
        rep.append((obj, nc))
        nc_syms.add(nc)
    expr = expr.subs(rep)

    nc_syms = list(nc_syms)
    nc_syms.sort(key=default_sort_key)
    return expr, dict([(v, k) for k, v in rep]) or None, nc_syms
示例#36
0
 def find_double_pow(expr):
     for sub in preorder_traversal(expr):
         if isinstance(sub, C.Pow) and isinstance(sub.base, C.Pow):
             return sub