def test_apart_list():
    """Check ``apart_list`` on three rational functions of ``x``.

    Each expected value is a tuple of a polynomial part (a ``Poly`` in ``x``)
    and a list of ``(Poly, Lambda, Lambda, int)`` tuples whose polynomials use
    the symbols ``w0``, ``w1``, ``w2`` drawn from ``numbered_symbols("w")``.
    """
    from sympy.utilities.iterables import numbered_symbols

    w0, w1, w2 = Symbol("w0"), Symbol("w1"), Symbol("w2")
    # Shared bound variable for the expected Lambdas.
    # NOTE(review): ``_a`` is a fresh Dummy, so the plain ``==`` checks below
    # only pass if Lambda equality ignores the identity of the bound
    # variable -- confirm.
    _a = Dummy("a")

    f = (-2 * x - 2 * x ** 2) / (3 * x ** 2 - 6 * x)
    assert apart_list(f, x, dummies=numbered_symbols("w")) == (
        Poly(S(2) / 3, x, domain="QQ"),
        [(Poly(w0 - 2, w0, domain="ZZ"), Lambda(_a, 2), Lambda(_a, -_a + x), 1)],
        # NOTE(review): the ")" closing this expected tuple is missing, so the
        # next assert is swallowed into this expression (SyntaxError).

    assert apart_list(2 / (x ** 2 - 2), x, dummies=numbered_symbols("w")) == (
        Poly(0, x, domain="ZZ"),
        [(Poly(w0 ** 2 - 2, w0, domain="ZZ"), Lambda(_a, _a / 2), Lambda(_a, -_a + x), 1)],
        # NOTE(review): the closing ")" is missing here as well.

    # The denominator factors as (x - 2)*(x**2 - 1)**2; the squared factor is
    # the term carrying the trailing 2.
    f = 36 / (x ** 5 - 2 * x ** 4 - 2 * x ** 3 + 4 * x ** 2 + x - 2)
    assert apart_list(f, x, dummies=numbered_symbols("w")) == (
        Poly(0, x, domain="ZZ"),
            # NOTE(review): the "[" opening the list of term tuples below, and
            # the "])" closing it and the outer tuple, appear to be lost.
            (Poly(w0 - 2, w0, domain="ZZ"), Lambda(_a, 4), Lambda(_a, -_a + x), 1),
            (Poly(w1 ** 2 - 1, w1, domain="ZZ"), Lambda(_a, -3 * _a - 6), Lambda(_a, -_a + x), 2),
            (Poly(w2 + 1, w2, domain="ZZ"), Lambda(_a, -4), Lambda(_a, -_a + x), 1),
def test_apart_list():
    """Check ``apart_list`` on three rational functions of ``x``.

    Each expected value is a 3-tuple: an integer (-1 or 1 here), the
    polynomial part as a ``Poly`` in ``x``, and a list of
    ``(Poly, Lambda, Lambda, int)`` tuples.  The root polynomials use the
    symbols ``w0``, ``w1``, ... drawn from ``numbered_symbols("w")``; the
    trailing integer is 2 for the squared factor ``w1**2 - 1`` of the last
    denominator ``(x - 2)*(x**2 - 1)**2`` and 1 otherwise.
    """
    from sympy.utilities.iterables import numbered_symbols

    def dummy_eq(i, j):
        # Recursive equality that falls back to ``dummy_eq`` so Lambdas built
        # with ``_a`` can match Lambdas built with a different Dummy.
        if type(i) in (list, tuple):
            # zip() silently stops at the shorter input, so compare lengths
            # first; otherwise a truncated result would still "match".
            return (isinstance(j, (list, tuple)) and len(i) == len(j) and
                    all(dummy_eq(a, b) for a, b in zip(i, j)))
        return i == j or i.dummy_eq(j)

    w0, w1, w2 = Symbol("w0"), Symbol("w1"), Symbol("w2")
    _a = Dummy("a")

    f = (-2 * x - 2 * x**2) / (3 * x**2 - 6 * x)
    got = apart_list(f, x, dummies=numbered_symbols("w"))
    ans = (-1, Poly(Rational(2, 3), x, domain='QQ'), [
        (Poly(w0 - 2, w0, domain='ZZ'), Lambda(_a, 2), Lambda(_a, -_a + x), 1),
    ])
    assert dummy_eq(got, ans)

    got = apart_list(2 / (x**2 - 2), x, dummies=numbered_symbols("w"))
    ans = (1, Poly(0, x, domain='ZZ'), [
        (Poly(w0**2 - 2, w0, domain='ZZ'), Lambda(_a, _a / 2),
         Lambda(_a, -_a + x), 1),
    ])
    assert dummy_eq(got, ans)

    f = 36 / (x**5 - 2 * x**4 - 2 * x**3 + 4 * x**2 + x - 2)
    got = apart_list(f, x, dummies=numbered_symbols("w"))
    ans = (1, Poly(0, x, domain='ZZ'), [
        (Poly(w0 - 2, w0, domain='ZZ'), Lambda(_a, 4), Lambda(_a, -_a + x), 1),
        (Poly(w1**2 - 1, w1, domain='ZZ'), Lambda(_a, -3 * _a - 6),
         Lambda(_a, -_a + x), 2),
        (Poly(w2 + 1, w2, domain='ZZ'), Lambda(_a, -4), Lambda(_a, -_a + x), 1),
    ])
    assert dummy_eq(got, ans)
def test_satisfiable_all_models():
    """``satisfiable(..., all_models=True)`` yields every model, lazily."""
    from sympy.abc import A, B
    assert next(satisfiable(False, all_models=True)) is False
    assert list(satisfiable((A >> ~A) & A, all_models=True)) == [False]
    assert list(satisfiable(True, all_models=True)) == [{true: true}]

    # A ^ B has exactly two models: consume both (in whatever order they are
    # produced), then the generator must be exhausted.
    models = [{A: True, B: False}, {A: False, B: True}]
    result = satisfiable(A ^ B, all_models=True)
    models.remove(next(result))
    models.remove(next(result))
    raises(StopIteration, lambda: next(result))
    assert not models

    assert list(satisfiable(Equivalent(A, B), all_models=True)) == \
        [{A: False, B: False}, {A: True, B: True}]

    # Every model of A >> B must be one of the three expected ones;
    # list.remove raises on an unexpected or duplicated model.
    models = [{A: False, B: False}, {A: False, B: True}, {A: True, B: True}]
    for model in satisfiable(A >> B, all_models=True):
        models.remove(model)
    assert not models

    # This is a sanity test to check that only the required number
    # of solutions are generated. The expr below has 2**100 - 1 models
    # which would time out the test if all are generated at once.
    from sympy.utilities.iterables import numbered_symbols
    from sympy.logic.boolalg import Or
    sym = numbered_symbols()
    X = [next(sym) for i in range(100)]
    result = satisfiable(Or(*X), all_models=True)
    for i in range(10):
        assert next(result)
def test_take():
    # numbered_symbols() is a stateful generator: each take() call resumes
    # where the previous one stopped.
    gen = numbered_symbols()
    for expected_names in ('x0:5', 'x5:10'):
        assert take(gen, 5) == list(symbols(expected_names))

    five = [1, 2, 3, 4, 5]
    assert take(five, 5) == [1, 2, 3, 4, 5]
    def _eval_expand_trig(self, **hints):
        """Expand ``tan`` of a sum or of an integer multiple of an angle.

        For ``tan(a1 + ... + an)`` the addition formula is used: with
        ``Yk = tan(ak)`` (each expanded recursively) and ``e_i`` the i-th
        elementary symmetric polynomial of the ``Yk``, the numerator collects
        the odd-degree ``e_i`` and the denominator the even-degree ones, with
        signs following ``+, +, -, -, ...`` in ``i``.

        For ``tan(k*t)`` with integer ``k > 1`` the result is
        ``im(P)/re(P)`` where ``P = (1 + I*z)**k`` and ``z = tan(t)``.

        Anything else is returned as ``tan(arg)``.
        """
        arg = self.args[0]
        if arg.is_Add:
            from sympy import symmetric_poly
            n = len(arg.args)
            # Keep each recursively expanded summand; they replace the
            # placeholder symbols Y0, Y1, ... in the final substitution.
            TX = [tan(x, evaluate=False)._eval_expand_trig() for x in arg.args]

            Yg = numbered_symbols('Y')
            Y = [next(Yg) for i in range(n)]

            # p[0]: odd-degree terms (numerator); p[1]: even-degree terms
            # (denominator).
            p = [0, 0]
            for i in range(n + 1):
                p[1 - i % 2] += symmetric_poly(i, Y) * (-1)**((i % 4) // 2)
            return (p[0] / p[1]).subs(list(zip(Y, TX)))
        else:
            coeff, terms = arg.as_coeff_Mul(rational=True)
            if coeff.is_Integer and coeff > 1:
                I = S.ImaginaryUnit
                z = C.Symbol('dummy', real=True)
                P = ((1 + I * z)**coeff).expand()
                return (C.im(P) / C.re(P)).subs([(z, tan(terms))])
        return tan(arg)
def test_take():
    """``take`` pulls successive items from iterators and from sequences."""
    sym_stream = numbered_symbols()
    head = take(sym_stream, 5)
    tail = take(sym_stream, 5)  # continues from x5
    assert head == list(symbols('x0:5'))
    assert tail == list(symbols('x5:10'))

    assert take([1, 2, 3, 4, 5], 5) == [1, 2, 3, 4, 5]
    def _eval_expand_trig(self, **hints):
        """Expand ``cot`` of a sum or of an integer multiple of an angle.

        For ``cot(a1 + ... + an)``, with ``Yk = cot(ak)`` (each expanded
        recursively) and ``e_i`` the i-th elementary symmetric polynomial of
        the ``Yk``, terms with ``n - i`` even go to the numerator and those
        with ``n - i`` odd to the denominator, with signs following
        ``+, +, -, -, ...`` in ``n - i``.

        For ``cot(k*t)`` with integer ``k > 1`` the result is
        ``re(P)/im(P)`` where ``P = (z + I)**k`` and ``z = cot(t)``.

        Anything else is returned as ``cot(arg)``.
        """
        arg = self.args[0]
        if arg.is_Add:
            from sympy import symmetric_poly

            n = len(arg.args)
            # Keep each recursively expanded summand; they replace the
            # placeholder symbols Y0, Y1, ... in the final substitution.
            CX = [cot(x, evaluate=False)._eval_expand_trig() for x in arg.args]

            Yg = numbered_symbols("Y")
            Y = [next(Yg) for i in range(n)]

            p = [0, 0]
            for i in range(n, -1, -1):
                p[(n - i) % 2] += symmetric_poly(i, Y) * (-1) ** (((n - i) % 4) // 2)
            return (p[0] / p[1]).subs(list(zip(Y, CX)))
        else:
            coeff, terms = arg.as_coeff_Mul(rational=True)
            if coeff.is_Integer and coeff > 1:
                I = S.ImaginaryUnit
                z = C.Symbol("dummy", real=True)
                P = ((z + I) ** coeff).expand()
                return (C.re(P) / C.im(P)).subs([(z, cot(terms))])
        return cot(arg)
    def _eval_expand_trig(self, **hints):
        """Expand ``tan`` of a sum or of an integer multiple of an angle.

        For ``tan(a1 + ... + an)`` the addition formula is used: with
        ``Yk = tan(ak)`` (each expanded recursively) and ``e_i`` the i-th
        elementary symmetric polynomial of the ``Yk``, the numerator collects
        the odd-degree ``e_i`` and the denominator the even-degree ones, with
        signs following ``+, +, -, -, ...`` in ``i``.

        For ``tan(k*t)`` with integer ``k > 1`` the result is
        ``im(P)/re(P)`` where ``P = (1 + I*z)**k`` and ``z = tan(t)``.

        Anything else is returned as ``tan(arg)``.
        """
        arg = self.args[0]
        if arg.is_Add:
            from sympy import symmetric_poly
            n = len(arg.args)
            # Keep each recursively expanded summand; they replace the
            # placeholder symbols Y0, Y1, ... in the final substitution.
            TX = [tan(x, evaluate=False)._eval_expand_trig() for x in arg.args]

            Yg = numbered_symbols('Y')
            Y = [next(Yg) for i in range(n)]

            # p[0]: odd-degree terms (numerator); p[1]: even-degree terms
            # (denominator).
            p = [0, 0]
            for i in range(n + 1):
                p[1 - i % 2] += symmetric_poly(i, Y) * (-1)**((i % 4) // 2)
            return (p[0] / p[1]).subs(list(zip(Y, TX)))
        else:
            coeff, terms = arg.as_coeff_Mul(rational=True)
            if coeff.is_Integer and coeff > 1:
                I = S.ImaginaryUnit
                z = C.Symbol('dummy', real=True)
                P = ((1 + I * z)**coeff).expand()
                return (C.im(P) / C.re(P)).subs([(z, tan(terms))])
        return tan(arg)
def test_apart_list():
    """Check ``apart_list`` on three rational functions of ``x``.

    Expected values are written with the symbols ``w0``, ``w1``, ``w2`` drawn
    from ``numbered_symbols("w")`` and a shared Dummy ``_a`` for the Lambdas.
    """
    from sympy.utilities.iterables import numbered_symbols

    w0, w1, w2 = Symbol("w0"), Symbol("w1"), Symbol("w2")
    _a = Dummy("a")

    f = (-2 * x - 2 * x**2) / (3 * x**2 - 6 * x)
    assert apart_list(f, x, dummies=numbered_symbols("w")) == (
        Poly(Rational(2, 3), x, domain="QQ"),
        [(Poly(w0 - 2, w0, domain="ZZ"), Lambda(_a, 2), Lambda(_a,
                                                               -_a + x), 1)],
        # NOTE(review): the ")" closing this expected tuple is missing.

    assert apart_list(2 / (x**2 - 2), x, dummies=numbered_symbols("w")) == (
        Poly(0, x, domain="ZZ"),
        # NOTE(review): the three items below look like one term tuple whose
        # "[(" opener and trailing integer / ")])" closer were lost -- confirm
        # against the upstream test before restoring.
            Poly(w0**2 - 2, w0, domain="ZZ"),
            Lambda(_a, _a / 2),
            Lambda(_a, -_a + x),

    # The denominator factors as (x - 2)*(x**2 - 1)**2.
    f = 36 / (x**5 - 2 * x**4 - 2 * x**3 + 4 * x**2 + x - 2)
    assert apart_list(f, x, dummies=numbered_symbols("w")) == (
        Poly(0, x, domain="ZZ"),
        # NOTE(review): the list brackets around the term tuples, the "(" and
        # trailing integer of the middle term, and the final ")" appear to be
        # missing.
            (Poly(w0 - 2, w0, domain="ZZ"), Lambda(_a, 4), Lambda(_a,
                                                                  -_a + x), 1),
                Poly(w1**2 - 1, w1, domain="ZZ"),
                Lambda(_a, -3 * _a - 6),
                Lambda(_a, -_a + x),
            (Poly(w2 + 1, w2,
                  domain="ZZ"), Lambda(_a, -4), Lambda(_a, -_a + x), 1),
def cse(exprs, symbols=None, optimizations=None):
    """Perform common subexpression elimination on an expression.

    :arg exprs: A list of sympy expressions, or a single sympy expression to reduce
    :arg symbols: An iterator (or iterable) yielding unique Symbols used to
        label the common subexpressions which are pulled out. The
        ``numbered_symbols`` generator from sympy is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. It must never run out;
        a plain iterable is converted with ``iter()``.
    :arg optimizations: A list of (callable, callable) pairs consisting of
        (preprocessor, postprocessor) pairs of external optimization functions.
        Defaults to no optimizations.

    :return: This returns a pair ``(replacements, reduced_exprs)``.

        * ``replacements`` is a list of (Symbol, expression) pairs consisting of
          all of the common subexpressions that were replaced. Subexpressions
          earlier in this list might show up in subexpressions later in this list.
        * ``reduced_exprs`` is a list of sympy expressions. This contains the
          reduced expressions with all of the replacements above.
    """
    if isinstance(exprs, Basic):
        exprs = [exprs]

    exprs = list(exprs)

    if optimizations is None:
        optimizations = []

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    if symbols is None:
        symbols = numbered_symbols(cls=Symbol)
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.  (This must apply to caller-supplied values; the
        # default generator is already an iterator.)
        symbols = iter(symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs)

    # Postprocess the expressions to return the expressions to canonical form.
    replacements = [(sym, postprocess_for_cse(subtree, optimizations))
                    for sym, subtree in replacements]
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    return replacements, reduced_exprs
    def build_cse_fn(symname, symfunc, symbolslist):
        """Emit C++ source for ``double symname(double const& ...)``.

        ``symfunc`` is passed through ``cse`` with temporaries drawn from
        ``numbered_symbols("R")``; each replacement becomes a local ``double``
        and the first reduced expression is returned as ``result``.
        """
        replacements, reduced = cse(symfunc, symbols=numbered_symbols("R"))

        signature_args = ", ".join("double const& %s" % x for x in symbolslist)
        pieces = ["double %s(%s)\n" % (str(symname), signature_args), "{\n"]
        for lhs, rhs in replacements:
            pieces.append("    double %s = %s;\n" % (ccode(lhs), ccode(rhs)))
        pieces.append("    double result = %s;\n" % ccode(reduced[0]))
        pieces.append("    return result;\n")
        pieces.append("}\n")
        return "".join(pieces)
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The `numbered_symbols` generator is useful. The default is a stream
        of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator (a plain iterable is converted with ``iter()``).
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        `sympy.simplify.cse.cse_optimizations` is used.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    to_eliminate = []

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]
    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions: a non-atomic subtree becomes a
    # candidate the second time it is seen.  The traversal is post-order, so
    # repeated children are queued before the subtrees that contain them.
    for expr in exprs:
        for subtree in postorder_traversal(expr):
            if subtree.args == ():
                # Exclude atoms, since there is no point in renaming them.
                continue
            if subtree in seen_subexp and subtree not in to_eliminate:
                to_eliminate.append(subtree)
            seen_subexp.add(subtree)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        sym = next(symbols)
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent candidates.  Only
        # entries after ``i`` are rewritten, so the enumerate() above still
        # sees each updated entry when it reaches it.
        for j in range(i + 1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    return replacements, reduced_exprs
def test_numbered_symbols():
    """``numbered_symbols`` honours the ``cls``, ``start`` and ``exclude`` options."""
    s = numbered_symbols(cls=Dummy)
    assert isinstance(next(s), Dummy)
    # Counting starts at C1, but C1 is excluded, so the first symbol is C2.
    assert next(numbered_symbols('C', start=1, exclude=[symbols('C1')])) == \
        symbols('C2')
    def extract_sub_expressions(self, cache_prefix='cache', sub_prefix='sub', prefix='XoXoXoX'):
        """Pull common subexpressions out of ``self.expression_list``.

        Runs ``sym.cse`` with temporaries drawn from
        ``numbered_symbols(prefix=prefix)``, then renames each temporary to
        ``cache_prefix + str(i)`` or ``sub_prefix + str(i)`` (depending on
        whether it ends up in ``cacheable_list`` or ``params_change_list``).
        The reduced expressions are written back into ``self.expressions`` at
        ``self.expression_keys``, and the temporaries' definitions are stored
        under ``self.expressions['update_cache']`` /
        ``self.expressions['parameters_changed']``.

        NOTE(review): parts of this method appear to be missing (see the notes
        below), so it does not parse as shown.
        """
        # Do the common sub expression elimination.
        common_sub_expressions, expression_substituted_list = sym.cse(self.expression_list, numbered_symbols(prefix=prefix))

        self.variables[cache_prefix] = []
        self.variables[sub_prefix] = []

        # Create dictionary of new sub expressions
        # NOTE(review): sub_expression_dict is not read anywhere in this method.
        sub_expression_dict = {}
        for var, void in common_sub_expressions:
            sub_expression_dict[var.name] = var

        # Sort out any expression that's dependent on something that scales with data size (these are listed in cacheable).
        cacheable_list = []
        params_change_list = []
        # common_sub_expressions contains a list of paired tuples with the new variable and what it equals
        for var, expr in common_sub_expressions:
            arg_list = [e for e in expr.atoms() if e.is_Symbol]
            # List any cacheable dependencies of the sub-expression
            # (checking cacheable_list too lets cacheability propagate through
            # earlier temporaries).
            cacheable_symbols = [e for e in arg_list if e in cacheable_list or e in self.cacheable_vars]
            # NOTE(review): the body of this "if" is missing, and neither list is
            # appended to anywhere visible; presumably var goes to cacheable_list
            # here and to params_change_list in a missing "else" -- confirm.
            if cacheable_symbols:
                # list which ensures dependencies are cacheable.

        # Map each cse temporary's name to its final cache*/sub* symbol.
        replace_dict = {}
        for i, expr in enumerate(cacheable_list):
            sym_var = sym.var(cache_prefix + str(i))
            replace_dict[expr.name] = sym_var
        for i, expr in enumerate(params_change_list):
            sym_var = sym.var(sub_prefix + str(i))
            replace_dict[expr.name] = sym_var

        # NOTE(review): each pass of this loop restarts from the unmodified
        # expression_substituted_list, so only the last replacement survives
        # here (the loop over expression_keys below redoes all of them), and
        # the inner loop only rebinds its own loop variable, so it has no
        # effect.
        for replace, void in common_sub_expressions:
            for expr, keys in zip(expression_substituted_list, self.expression_keys):
                setInDict(self.expressions, keys, expr.subs(replace, replace_dict[replace.name]))
            for void, expr in common_sub_expressions:
                expr = expr.subs(replace, replace_dict[replace.name])

        # Replace original code with code including subexpressions.
        for keys in self.expression_keys:
            for replace, void in common_sub_expressions:
                setInDict(self.expressions, keys, getFromDict(self.expressions, keys).subs(replace, replace_dict[replace.name]))
        self.expressions['parameters_changed'] = {}
        self.expressions['update_cache'] = {}
        for var, expr in common_sub_expressions:
            # Rewrite the definition itself in terms of the renamed temporaries.
            for replace, void in common_sub_expressions:
                expr = expr.subs(replace, replace_dict[replace.name])
            # NOTE(review): an "else:" before the parameters_changed assignment
            # appears to be missing; as written a cacheable temporary is stored
            # under both keys and a non-cacheable one under neither.
            if var in cacheable_list:
                self.expressions['update_cache'][replace_dict[var.name].name] = expr
                self.expressions['parameters_changed'][replace_dict[var.name].name] = expr
Example No. 15
def test_numbered_symbols():
    """``numbered_symbols`` honours the ``cls``, ``start`` and ``exclude`` options."""
    s = numbered_symbols(cls=Dummy)
    assert isinstance(next(s), Dummy)
    # Counting starts at C1, but C1 is excluded, so the first symbol is C2.
    assert next(numbered_symbols("C", start=1, exclude=[symbols("C1")])) == symbols(
        "C2"
    )
    def __init__(
        # NOTE(review): the parameter list and the closing "):" of this
        # signature are missing. The body reads args, expr, print_lambda,
        # use_evalf, float_wrap_evalf, complex_wrap_evalf, use_np,
        # use_python_math, use_python_cmath and use_interval -- restore them
        # (with their defaults) from upstream.

        self.print_lambda = print_lambda
        self.use_evalf = use_evalf
        self.float_wrap_evalf = float_wrap_evalf
        self.complex_wrap_evalf = complex_wrap_evalf
        self.use_np = use_np
        self.use_python_math = use_python_math
        self.use_python_cmath = use_python_cmath
        self.use_interval = use_interval

        # Constructing the argument string
        # - check
        if not all([isinstance(a, Symbol) for a in args]):
            raise ValueError("The arguments must be Symbols.")
        # - use numbered symbols
        # Rename the arguments to numbered symbols that are not already free
        # in expr, so the generated lambda's parameter names cannot collide
        # with other symbols of the expression.
        syms = numbered_symbols(exclude=expr.free_symbols)
        newargs = [next(syms) for _ in args]
        expr = expr.xreplace(dict(zip(args, newargs)))
        argstr = ", ".join([str(a) for a in newargs])
        del syms, newargs, args

        # Constructing the translation dictionaries and making the translation
        # (the printed expression is parsed into a tree and re-emitted with
        # translated function names).
        self.dict_str = self.get_dict_str()
        self.dict_fun = self.get_dict_fun()
        exprstr = str(expr)
        newexpr = self.tree2str_translate(self.str2tree(exprstr))

        # Constructing the namespaces
        namespace = {}
        # XXX Workaround
        # Ugly workaround because Pow(a,Half) prints as sqrt(a)
        # and sympy_expression_namespace can not catch it.
        from sympy import sqrt

        namespace.update({"sqrt": sqrt})
        namespace.update({"Eq": lambda x, y: x == y})
        namespace.update({"Ne": lambda x, y: x != y})
        # End workaround.
        if use_python_math:
            namespace.update({"math": __import__("math")})
        if use_python_cmath:
            namespace.update({"cmath": __import__("cmath")})
        if use_np:
                # NOTE(review): the "try:" guarding this import appears to be
                # missing; the "except" below has nothing to attach to.
                namespace.update({"np": __import__("numpy")})
            except ImportError:
                raise ImportError(
                    "experimental_lambdify failed to import numpy.")
        if use_interval:
            namespace.update({"math": __import__("math")})

        # Construct the lambda
        if self.print_lambda:
        # NOTE(review): the body of the "if" above is missing.
        eval_str = "lambda %s : ( %s )" % (argstr, newexpr)
        self.eval_str = eval_str
        # NOTE(review): exec_ runs source generated from str(expr); only pass
        # trusted expressions. The call's remaining argument and closing ")"
        # are missing -- presumably ``namespace``, since the next line reads
        # MYNEWLAMBDA back out of it.
        exec_("from __future__ import division; MYNEWLAMBDA = %s" % eval_str,
        self.lambda_func = namespace["MYNEWLAMBDA"]
    def _eval_rewrite_as_sqrt(self, arg):
        # Rewrite the value at arg = pi_coeff*pi (pi_coeff rational) in
        # radicals.  The recursion below evaluates ``cos``, so this is
        # presumably cos._eval_rewrite_as_sqrt.  The denominators handled are:
        # entries of ``cst_table_some``, even denominators (via the half-angle
        # formula), and products of distinct table primes (via ``ipartfrac``
        # and the addition formulas).  Returns None when no rewrite applies.
        # NOTE(review): several lines of this method appear to be missing
        # (see the notes below), so it does not parse as shown.
        _EXPAND_INTS = False

        def migcdex(x):
            # recursive calculation of gcd and linear combination
            # for a sequence of integers.
            # Given  (x1, x2, x3)
            # Returns (y1, y2, y3, g)
            # such that g is the gcd and x1*y1+x2*y2+x3*y3 - g = 0
            # Note, that this is only one such linear combination.
            if len(x) == 1:
                return (1, x[0])
            if len(x) == 2:
                return igcdex(x[0], x[-1])
            g = migcdex(x[1:])
            u, v, h = igcdex(x[0], g[-1])
            return tuple([u] + [v * i for i in g[0:-1]] + [h])

        def ipartfrac(r, factors=None):
            # Split the rational r into a list of fractions whose denominators
            # divide the given factors of r.q (assumed pairwise coprime), using
            # the Bezout coefficients from migcdex; the assert below checks
            # that the pieces sum back to r.
            if isinstance(r, int):
                return r
            assert isinstance(r, C.Rational)
            n = r.q
            # NOTE(review): this returns the denominator rather than a list of
            # parts like the other paths -- looks suspicious; confirm intent.
            if 2 > r.q * r.q:
                return r.q

            if None == factors:
                a = [n // x**y for x, y in factorint(r.q).iteritems()]
                # NOTE(review): an "else:" seems to be missing here. As written,
                # the next line always overwrites ``a`` and fails when
                # ``factors`` is None.  ``iteritems()`` is also Python 2 only.
                a = [n // x for x in factors]
            if len(a) == 1:
                return [r]
            h = migcdex(a)
            ans = [r.p * C.Rational(i * j, r.q) for i, j in zip(h[:-1], a)]
            assert r == sum(ans)
            return ans

        pi_coeff = _pi_coeff(arg)
        if pi_coeff is None:
            return None

        assert not pi_coeff.is_integer, "should have been simplified already"

        if not pi_coeff.is_Rational:
            return None

        # Maps a Fermat prime q to a radical expression used as the base value
        # for denominator q (5 -> (sqrt(5) + 1)/4, i.e. cos(pi/5)).
        # NOTE(review): the "3:" entry, the "17:" key for the long expression
        # below, its final divisor, and the closing "}" appear to be missing.
        cst_table_some = {
            5: (sqrt(5) + 1) / 4,
            sqrt((15 + sqrt(17)) / 32 + sqrt(2) * (sqrt(17 - sqrt(17)) + sqrt(
                sqrt(2) *
                (-8 * sqrt(17 + sqrt(17)) -
                 (1 - sqrt(17)) * sqrt(17 - sqrt(17))) + 6 * sqrt(17) + 34)) /
            # 65537 and 257 are the only other known Fermat primes
            # Please add if you would like them

        def fermatCoords(n):
            # Return the tuple of distinct table primes whose product is n, or
            # False if n is 1, even, has a prime outside the table, or has a
            # repeated table prime.
            assert isinstance(n, int)
            assert n > 0
            if n == 1 or 0 == n % 2:
                return False
            primes = dict([(p, 0) for p in cst_table_some])
            assert 1 not in primes
            for p_i in primes:
                while 0 == n % p_i:
                    # NOTE(review): "/" is true division on Python 3 (n becomes
                    # a float); "//" is probably intended.
                    n = n / p_i
                    primes[p_i] += 1
            if 1 != n:
                return False
            if max(primes.values()) > 1:
                return False
            return tuple([p for p in primes if primes[p] == 1])

        if pi_coeff.q in cst_table_some:
            # NOTE(review): this call is cut off; presumably the second argument
            # is cst_table_some[pi_coeff.q] (cos(p*pi/q) = T_p(cos(pi/q))).
            return C.chebyshevt(pi_coeff.p,

        if 0 == pi_coeff.q % 2:  # recursively remove powers of 2
            # Half-angle formula: rewrite the value at 2*arg first, then take
            # +/-sqrt((1 + value)/2).
            narg = (pi_coeff * 2) * S.Pi
            nval = cos(narg)
            if None == nval:
                return None
            nval = nval.rewrite(sqrt)
            if not _EXPAND_INTS:
                # Give up if the doubled angle could not be rewritten.
                if (isinstance(nval, cos) or isinstance(-nval, cos)):
                    return None
            # Pick the sign of the square root from which half-period the
            # angle pi_coeff*pi falls into.
            x = (2 * pi_coeff + 1) / 2
            sign_cos = (-1)**((-1 if x < 0 else 1) * int(abs(x)))
            return sign_cos * sqrt((1 + nval) / 2)

        FC = fermatCoords(pi_coeff.q)
        if FC:
            # Split pi_coeff into fractions over the individual primes, expand
            # cos of their sum with the addition formulas (using placeholder
            # symbols z0, z1, ...), then substitute the angles back in.
            decomp = ipartfrac(pi_coeff, FC)
            X = [(x[1], x[0] * S.Pi)
                 for x in zip(decomp, numbered_symbols('z'))]
            pcls = cos(sum([x[0] for x in X]))._eval_expand_trig().subs(X)
            return pcls.rewrite(sqrt)
        if _EXPAND_INTS:
            decomp = ipartfrac(pi_coeff)
            X = [(x[1], x[0] * S.Pi)
                 for x in zip(decomp, numbered_symbols('z'))]
            pcls = cos(sum([x[0] for x in X]))._eval_expand_trig().subs(X)
            return pcls
        return None
def cse(exprs, symbols=None, optimizations=None, postprocess=None,
        order='canonical', ignore=()):
    # NOTE(review): this definition is incomplete as shown:
    # - the docstring opened below is never closed, so the code after it is
    #   currently part of the string literal;
    # - the bodies of the Matrix / SparseMatrix branches in the loop that
    #   builds ``temp`` are missing, and no branch appends ordinary
    #   expressions, so ``temp`` is never populated in the visible code;
    # - ``iter(symbols)`` only runs for the default generator, so a
    #   caller-supplied iterable is not converted, contrary to its comment.
    # Restore these from upstream before relying on this function.
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs
        The (preprocessor, postprocessor) pairs of external optimization
        functions. Optionally 'basic' can be passed for a set of predefined
        basic optimizations. Such 'basic' optimizations were used by default
        in old implementation, however they can be really slow on larger
        expressions. Now, no pre or post optimizations are made by default.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        lambda r, e: (reversed(r), e)
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. If set to
        'canonical', arguments will be canonically ordered. If set to 'none',
        ordering will be faster but dependent on expressions hashes, thus
        machine dependent and variable. For large expressions where speed is a
        concern, use the setting order='none'.
    ignore : iterable of Symbols
        Substitutions containing any Symbol from ``ignore`` will be ignored.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.

    Examples
    ========

    >>> from sympy import cse, SparseMatrix
    >>> from sympy.abc import x, y, z, w
    >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3)
    ([(x0, w + y + z)], [x0*(x + x0)/(w + x)**3])

    Note that currently, y + z will not get substituted if -y - z is used.

    >>> cse(((w + x + y + z)*(w - y - z))/(w + x)**3)
    ([(x0, w + x)], [(w - y - z)*(x0 + y + z)/x0**3])

    List of expressions with recursive substitutions:

    >>> m = SparseMatrix([x + y, x + y + z])
    >>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m])
    ([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([
    [x0],
    [x1]])])

    Note: the type and mutability of input matrices is retained.

    >>> isinstance(_[1][-1], SparseMatrix)
    True

    The user may disallow substitutions containing certain symbols:
    >>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,))
    ([(x0, x + 1)], [x0*y**2, 3*x0*y**2])

    from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix,
                                SparseMatrix, ImmutableSparseMatrix)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, (Basic, MatrixBase)):
        exprs = [exprs]

    # Keep the original inputs so matrices can be rebuilt after CSE.
    copy = exprs
    temp = []
    for e in exprs:
        if isinstance(e, (Matrix, ImmutableMatrix)):
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
    exprs = temp
    del temp

    if optimizations is None:
        optimizations = list()
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Symbols already present in the inputs; presumably filter_symbols skips
    # these so no label clashes with an existing symbol.
    excluded_symbols = set().union(*[expr.atoms(Symbol)
                                   for expr in reduced_exprs])

    if symbols is None:
        symbols = numbered_symbols()
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)

    symbols = filter_symbols(symbols, excluded_symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
                                           order, ignore)

    # Postprocess the expressions to return the expressions to canonical form.
    exprs = copy
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    # Get the matrices back
    for i, e in enumerate(exprs):
        if isinstance(e, (Matrix, ImmutableMatrix)):
            reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
            if isinstance(e, ImmutableMatrix):
                reduced_exprs[i] = reduced_exprs[i].as_immutable()
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            m = SparseMatrix(e.rows, e.cols, {})
            # Assumes the sparse input was flattened to (key, value) pairs
            # before CSE -- TODO confirm once the missing branch is restored.
            for k, v in reduced_exprs[i]:
                m[k] = v
            if isinstance(e, ImmutableSparseMatrix):
                m = m.as_immutable()
            reduced_exprs[i] = m

    if postprocess is None:
        return replacements, reduced_exprs

    return postprocess(replacements, reduced_exprs)
def cse(exprs, symbols=None, optimizations=None, postprocess=None):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        ``sympy.simplify.cse.cse_optimizations`` is used.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        ``lambda r, e: (list(reversed(r)), e)``

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above. When
        ``exprs`` is a Matrix this is a one-element list holding a Matrix of
        the same shape.
    """
    from sympy.matrices import Matrix

    if symbols is None:
        symbols = numbered_symbols()
    # In case we get passed an iterable with an __iter__ method instead of
    # an actual iterator.
    symbols = iter(symbols)
    seen_subexp = set()
    muls = set()
    adds = set()
    to_eliminate = set()

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    for expr in reduced_exprs:
        if not isinstance(expr, Basic):
            continue
        pt = preorder_traversal(expr)
        for subtree in pt:

            inv = 1 / subtree if subtree.is_Pow else None

            if subtree.is_Atom or iterable(subtree) or inv and inv.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue

            if subtree in seen_subexp:
                if inv and _coeff_isneg(subtree.exp):
                    # save the form with positive exponent
                    subtree = inv
                to_eliminate.add(subtree)
                # the whole subtree is already marked; do not descend into it
                pt.skip()
                continue

            if inv and inv in seen_subexp:
                if _coeff_isneg(subtree.exp):
                    # save the form with positive exponent
                    subtree = inv
                to_eliminate.add(subtree)
                pt.skip()
                continue
            elif subtree.is_Mul:
                muls.add(subtree)
            elif subtree.is_Add:
                adds.add(subtree)

            seen_subexp.add(subtree)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = [set(a.args) for a in ordered(adds)]
    for i in xrange(len(adds)):
        for j in xrange(i + 1, len(adds)):
            com = adds[i].intersection(adds[j])
            if len(com) > 1:
                to_eliminate.add(Add(*com))

                # remove this set of symbols so it doesn't appear again
                adds[i] = adds[i].difference(com)
                adds[j] = adds[j].difference(com)
                for k in xrange(j + 1, len(adds)):
                    if not com.difference(adds[k]):
                        adds[k] = adds[k].difference(com)

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common

    # use SequenceMatcher on the nc part to find the longest common expression
    # in common between the two nc parts
    sm = difflib.SequenceMatcher()

    muls = [a.args_cnc(cset=True) for a in ordered(muls)]
    for i in xrange(len(muls)):
        if muls[i][1]:
            sm.set_seq1(muls[i][1])
        for j in xrange(i + 1, len(muls)):
            # the commutative part in common
            ccom = muls[i][0].intersection(muls[j][0])

            # the non-commutative part in common
            if muls[i][1] and muls[j][1]:
                # see if there is any chance of an nc match
                ncom = set(muls[i][1]).intersection(set(muls[j][1]))
                if len(ccom) + len(ncom) < 2:
                    continue

                # now work harder to find the match
                sm.set_seq2(muls[j][1])
                i1, _, n = sm.find_longest_match(0, len(muls[i][1]), 0,
                                                 len(muls[j][1]))
                ncom = muls[i][1][i1:i1 + n]
            else:
                ncom = []

            com = list(ccom) + ncom
            if len(com) < 2:
                continue

            to_eliminate.add(Mul(*com))

            # remove ccom from all if there was no ncom; to update the nc part
            # would require finding the subexpr and then replacing it with a
            # dummy to keep bounding nc symbols from being identified as a
            # subexpr, e.g. removing B*C from A*B*C*D might allow A*D to be
            # identified as a subexpr which would not be right.
            if not ncom:
                muls[i][0] = muls[i][0].difference(ccom)
                for k in xrange(j, len(muls)):
                    if not ccom.difference(muls[k][0]):
                        muls[k][0] = muls[k][0].difference(ccom)

    # make to_eliminate canonical; we will prefer non-Muls to Muls
    # so select them first (non-Muls will have False for is_Mul and will
    # be first in the ordering.
    to_eliminate = list(ordered(to_eliminate, lambda _: _.is_Mul))

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(reduced_exprs)
    hit = True
    for i, subtree in enumerate(to_eliminate):
        # only draw a fresh symbol when the previous one was actually used
        if hit:
            sym = next(symbols)
        hit = False
        if subtree.is_Pow and subtree.exp.is_Rational:
            # also replace the reciprocal so x**-2 becomes 1/sym for x**2
            update = lambda x: x.xreplace({subtree: sym, 1 / subtree: 1 / sym})
        else:
            update = lambda x: x.subs(subtree, sym)
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            old = reduced_exprs[j]
            reduced_exprs[j] = update(expr)
            hit = hit or (old != reduced_exprs[j])
        # Make the substitution in all of the subsequent substitutions.
        for j in range(i + 1, len(to_eliminate)):
            old = to_eliminate[j]
            to_eliminate[j] = update(to_eliminate[j])
            hit = hit or (old != to_eliminate[j])
        if hit:
            replacements.append((sym, subtree))

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [
        postprocess_for_cse(e, optimizations) for e in reduced_exprs
    ]

    # remove replacements that weren't used more than once
    _remove_singletons(replacements, reduced_exprs)

    if isinstance(exprs, Matrix):
        reduced_exprs = [Matrix(exprs.rows, exprs.cols, reduced_exprs)]
    if postprocess is None:
        return replacements, reduced_exprs
    return postprocess(replacements, reduced_exprs)
    def __init__(self, args, expr, print_lambda=False, use_evalf=False,
                 float_wrap_evalf=False, complex_wrap_evalf=False,
                 use_np=False, use_python_math=False, use_python_cmath=False,
                 use_interval=False):
        """Compile ``expr`` into ``self.lambda_func``, a lambda of ``args``.

        The arguments are renamed to fresh numbered symbols that do not clash
        with ``expr.free_symbols``. The printed expression is then translated
        via ``str2tree``/``tree2str_translate`` and evaluated as
        ``lambda <args> : ( <expr> )`` in a namespace selected by the
        ``use_*`` flags. The generated source is kept in ``self.eval_str``.

        Raises
        ======
        ValueError
            If any element of ``args`` is not a Symbol.
        ImportError
            If ``use_np`` is set and numpy cannot be imported.
        """
        self.print_lambda = print_lambda
        self.use_evalf = use_evalf
        self.float_wrap_evalf = float_wrap_evalf
        self.complex_wrap_evalf = complex_wrap_evalf
        self.use_np = use_np
        self.use_python_math = use_python_math
        self.use_python_cmath = use_python_cmath
        self.use_interval = use_interval

        # Constructing the argument string
        # - check
        if not all(isinstance(a, Symbol) for a in args):
            raise ValueError('The arguments must be Symbols.')
        # - use numbered symbols (presumably so that arbitrary Symbol names
        #   become valid identifiers in the generated lambda source)
        syms = numbered_symbols(exclude=expr.free_symbols)
        newargs = [next(syms) for i in args]
        expr = expr.xreplace(dict(zip(args, newargs)))
        argstr = ', '.join([str(a) for a in newargs])
        del syms, newargs, args

        # Constructing the translation dictionaries and making the translation
        self.dict_str = self.get_dict_str()
        self.dict_fun = self.get_dict_fun()
        exprstr = str(expr)
        # the & and | operators don't work on tuples, see discussion #12108
        exprstr = exprstr.replace(" & "," and ").replace(" | "," or ")

        newexpr = self.tree2str_translate(self.str2tree(exprstr))

        # Constructing the namespaces
        namespace = {}
        # Names the printed expression refers to; the workaround below only
        # patches a gap in sympy_expression_namespace, so it must be merged.
        namespace.update(self.sympy_atoms_namespace(expr))
        namespace.update(self.sympy_expression_namespace(expr))
        # XXX Workaround
        # Ugly workaround because Pow(a,Half) prints as sqrt(a)
        # and sympy_expression_namespace can not catch it.
        from sympy import sqrt
        namespace.update({'sqrt': sqrt})
        namespace.update({'Eq': lambda x, y: x == y})
        # End workaround.
        if use_python_math:
            namespace.update({'math': __import__('math')})
        if use_python_cmath:
            namespace.update({'cmath': __import__('cmath')})
        if use_np:
            try:
                namespace.update({'np': __import__('numpy')})
            except ImportError:
                raise ImportError(
                    'experimental_lambdify failed to import numpy.')
        if use_interval:
            namespace.update({'imath': __import__(
                'sympy.plotting.intervalmath', fromlist=['intervalmath'])})
            namespace.update({'math': __import__('math')})

        # Construct the lambda
        if self.print_lambda:
            print(newexpr)
        eval_str = 'lambda %s : ( %s )' % (argstr, newexpr)
        self.eval_str = eval_str
        exec_("from __future__ import division; MYNEWLAMBDA = %s" % eval_str, namespace)
        self.lambda_func = namespace['MYNEWLAMBDA']
def cse(exprs, symbols=None, optimizations=None, postprocess=None):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        ``sympy.simplify.cse.cse_optimizations`` is used.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        ``lambda r, e: (list(reversed(r)), e)``

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above. When
        ``exprs`` is a Matrix this is a Matrix of the same shape.
    """
    from sympy.matrices import Matrix

    if symbols is None:
        symbols = numbered_symbols()
    # In case we get passed an iterable with an __iter__ method instead of
    # an actual iterator.
    symbols = iter(symbols)
    # Every distinct non-atomic subexpression is first labelled with one of
    # these temporaries; only the repeated ones get a name from ``symbols``.
    tmp_symbols = numbered_symbols('_csetmp')
    # subexpression (arguments already replaced by temporaries) -> temporary
    subexp_iv = dict()
    muls = set()
    adds = set()

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess the expressions to give us better optimization opportunities.
    prep_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all subexpressions.
    def _parse(expr):
        # Rebuild ``expr`` bottom-up with each non-atomic argument replaced by
        # its temporary, register it, and return its temporary symbol.
        if expr.is_Atom:
            # Exclude atoms, since there is no point in renaming them.
            return expr
        if iterable(expr):
            return expr
        subexpr = type(expr)(*map(_parse, expr.args))

        if subexpr in subexp_iv:
            return subexp_iv[subexpr]
        if subexpr.is_Mul:
            muls.add(subexpr)
        elif subexpr.is_Add:
            adds.add(subexpr)

        ivar = next(tmp_symbols)
        subexp_iv[subexpr] = ivar
        return ivar

    tmp_exprs = list()
    for expr in prep_exprs:
        if isinstance(expr, Basic):
            tmp_exprs.append(_parse(expr))
        else:
            # non-Basic entries are passed through untouched
            tmp_exprs.append(expr)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = list(ordered(adds))
    addargs = [set(a.args) for a in adds]
    for i in xrange(len(addargs)):
        for j in xrange(i + 1, len(addargs)):
            com = addargs[i].intersection(addargs[j])
            if len(com) > 1:
                add_subexp = Add(*com)
                diff_add_i = addargs[i].difference(com)
                diff_add_j = addargs[j].difference(com)
                if add_subexp in subexp_iv:
                    ivar = subexp_iv[add_subexp]
                else:
                    ivar = next(tmp_symbols)
                    subexp_iv[add_subexp] = ivar
                if diff_add_i:
                    newadd = Add(ivar,*diff_add_i)
                    subexp_iv[newadd] = subexp_iv.pop(adds[i])
                    adds[i] = newadd
                #else add_i is itself subexp_iv[add_subexp] -> ivar
                if diff_add_j:
                    newadd = Add(ivar,*diff_add_j)
                    subexp_iv[newadd] = subexp_iv.pop(adds[j])
                    adds[j] = newadd
                #else add_j is itself subexp_iv[add_subexp] -> ivar
                addargs[i] = diff_add_i
                addargs[j] = diff_add_j
                for k in xrange(j + 1, len(addargs)):
                    if com.issubset(addargs[k]):
                        diff_add_k = addargs[k].difference(com)
                        if diff_add_k:
                            newadd = Add(ivar,*diff_add_k)
                            subexp_iv[newadd] = subexp_iv.pop(adds[k])
                            adds[k] = newadd
                        #else add_k is itself subexp_iv[add_subexp] -> ivar
                        addargs[k] = diff_add_k

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common
    # *assumes that there are no non-commutative parts*
    muls = list(ordered(muls))
    mulargs = [set(a.args) for a in muls]
    for i in xrange(len(mulargs)):
        for j in xrange(i + 1, len(mulargs)):
            com = mulargs[i].intersection(mulargs[j])
            if len(com) > 1:
                mul_subexp = Mul(*com)
                diff_mul_i = mulargs[i].difference(com)
                diff_mul_j = mulargs[j].difference(com)
                if mul_subexp in subexp_iv:
                    ivar = subexp_iv[mul_subexp]
                else:
                    ivar = next(tmp_symbols)
                    subexp_iv[mul_subexp] = ivar
                if diff_mul_i:
                    newmul = Mul(ivar,*diff_mul_i)
                    subexp_iv[newmul] = subexp_iv.pop(muls[i])
                    muls[i] = newmul
                #else mul_i is itself subexp_iv[mul_subexp] -> ivar
                if diff_mul_j:
                    newmul = Mul(ivar,*diff_mul_j)
                    subexp_iv[newmul] = subexp_iv.pop(muls[j])
                    muls[j] = newmul
                #else mul_j is itself subexp_iv[mul_subexp] -> ivar
                mulargs[i] = diff_mul_i
                mulargs[j] = diff_mul_j
                for k in xrange(j + 1, len(mulargs)):
                    if com.issubset(mulargs[k]):
                        diff_mul_k = mulargs[k].difference(com)
                        if diff_mul_k:
                            newmul = Mul(ivar,*diff_mul_k)
                            subexp_iv[newmul] = subexp_iv.pop(muls[k])
                            muls[k] = newmul
                        #else mul_k is itself subexp_iv[mul_subexp] -> ivar
                        mulargs[k] = diff_mul_k

    # Find all of the repeated subexpressions.
    ivar_se = {iv: se for se, iv in subexp_iv.items()}
    used_ivs = set()
    repeated = set()

    def _find_repeated_subexprs(subexpr):
        # A temporary seen once is expanded (and marked used); seeing it
        # again anywhere means its subexpression is repeated.
        if subexpr.is_Atom:
            symbs = [subexpr]
        else:
            symbs = subexpr.args
        for symb in symbs:
            if symb in ivar_se:
                if symb not in used_ivs:
                    _find_repeated_subexprs(ivar_se[symb])
                    used_ivs.add(symb)
                else:
                    repeated.add(symb)

    for expr in tmp_exprs:
        if isinstance(expr, Basic):
            _find_repeated_subexprs(expr)

    # Substitute symbols for all of the repeated subexpressions.
    # remove temporary replacements that weren't used more than once
    tmpivs_ivs = dict()
    ordered_iv_se = OrderedDict()

    def _get_subexprs(args):
        # Expand temporaries back into expressions; repeated ones are bound
        # (once) to a fresh symbol from ``symbols`` in ``ordered_iv_se``.
        args = list(args)
        for i,symb in enumerate(args):
            if symb in ivar_se:
                if symb in tmpivs_ivs:
                    args[i] = tmpivs_ivs[symb]
                else:
                    subexpr = ivar_se[symb]
                    subexpr = type(subexpr)(*_get_subexprs(subexpr.args))
                    if symb in repeated:
                        ivar = next(symbols)
                        ordered_iv_se[ivar] = subexpr
                        tmpivs_ivs[symb] = ivar
                        args[i] = ivar
                    else:
                        args[i] = subexpr
        return args

    out_exprs = _get_subexprs(tmp_exprs)
    # Postprocess the expressions to return the expressions to canonical form.
    ordered_iv_se_notopt = ordered_iv_se
    ordered_iv_se = OrderedDict()
    for ivar, subexpr in ordered_iv_se_notopt.items():
        subexpr = postprocess_for_cse(subexpr, optimizations)
        ordered_iv_se[ivar] = subexpr
    out_exprs = [postprocess_for_cse(e, optimizations) for e in out_exprs]

    if isinstance(exprs, Matrix):
        out_exprs = Matrix(exprs.rows, exprs.cols, out_exprs)
    # list(...) so callers get a list on Python 3 too, not a dict view
    if postprocess is None:
        return list(ordered_iv_se.items()), out_exprs
    return postprocess(list(ordered_iv_se.items()), out_exprs)
def cse(exprs, symbols=None, optimizations=None, postprocess=None):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        ``sympy.simplify.cse.cse_optimizations`` is used.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        ``lambda r, e: (list(reversed(r)), e)``

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above. When
        ``exprs`` is a Matrix this is a one-element list holding a Matrix of
        the same shape.
    """
    from sympy.matrices import Matrix

    if symbols is None:
        symbols = numbered_symbols()
    # In case we get passed an iterable with an __iter__ method instead of
    # an actual iterator.
    symbols = iter(symbols)
    seen_subexp = set()
    muls = set()
    adds = set()
    to_eliminate = set()

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    for expr in reduced_exprs:
        if not isinstance(expr, Basic):
            continue
        pt = preorder_traversal(expr)
        for subtree in pt:

            inv = 1/subtree if subtree.is_Pow else None

            if subtree.is_Atom or iterable(subtree) or inv and inv.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue

            if subtree in seen_subexp:
                if inv and _coeff_isneg(subtree.exp):
                    # save the form with positive exponent
                    subtree = inv
                to_eliminate.add(subtree)
                # the whole subtree is already marked; do not descend into it
                pt.skip()
                continue

            if inv and inv in seen_subexp:
                if _coeff_isneg(subtree.exp):
                    # save the form with positive exponent
                    subtree = inv
                to_eliminate.add(subtree)
                pt.skip()
                continue
            elif subtree.is_Mul:
                muls.add(subtree)
            elif subtree.is_Add:
                adds.add(subtree)

            seen_subexp.add(subtree)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = [set(a.args) for a in ordered(adds)]
    for i in xrange(len(adds)):
        for j in xrange(i + 1, len(adds)):
            com = adds[i].intersection(adds[j])
            if len(com) > 1:
                to_eliminate.add(Add(*com))

                # remove this set of symbols so it doesn't appear again
                adds[i] = adds[i].difference(com)
                adds[j] = adds[j].difference(com)
                for k in xrange(j + 1, len(adds)):
                    if not com.difference(adds[k]):
                        adds[k] = adds[k].difference(com)

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common

    # use SequenceMatcher on the nc part to find the longest common expression
    # in common between the two nc parts
    sm = difflib.SequenceMatcher()

    muls = [a.args_cnc(cset=True) for a in ordered(muls)]
    for i in xrange(len(muls)):
        if muls[i][1]:
            sm.set_seq1(muls[i][1])
        for j in xrange(i + 1, len(muls)):
            # the commutative part in common
            ccom = muls[i][0].intersection(muls[j][0])

            # the non-commutative part in common
            if muls[i][1] and muls[j][1]:
                # see if there is any chance of an nc match
                ncom = set(muls[i][1]).intersection(set(muls[j][1]))
                if len(ccom) + len(ncom) < 2:
                    continue

                # now work harder to find the match
                sm.set_seq2(muls[j][1])
                i1, _, n = sm.find_longest_match(0, len(muls[i][1]),
                                                 0, len(muls[j][1]))
                ncom = muls[i][1][i1:i1 + n]
            else:
                ncom = []

            com = list(ccom) + ncom
            if len(com) < 2:
                continue

            to_eliminate.add(Mul(*com))

            # remove ccom from all if there was no ncom; to update the nc part
            # would require finding the subexpr and then replacing it with a
            # dummy to keep bounding nc symbols from being identified as a
            # subexpr, e.g. removing B*C from A*B*C*D might allow A*D to be
            # identified as a subexpr which would not be right.
            if not ncom:
                muls[i][0] = muls[i][0].difference(ccom)
                for k in xrange(j, len(muls)):
                    if not ccom.difference(muls[k][0]):
                        muls[k][0] = muls[k][0].difference(ccom)

    # make to_eliminate canonical; we will prefer non-Muls to Muls
    # so select them first (non-Muls will have False for is_Mul and will
    # be first in the ordering.
    to_eliminate = list(ordered(to_eliminate, lambda _: _.is_Mul))

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(reduced_exprs)
    hit = True
    for i, subtree in enumerate(to_eliminate):
        # only draw a fresh symbol when the previous one was actually used
        if hit:
            sym = next(symbols)
        hit = False
        if subtree.is_Pow and subtree.exp.is_Rational:
            # also replace the reciprocal so x**-2 becomes 1/sym for x**2
            update = lambda x: x.xreplace({subtree: sym, 1/subtree: 1/sym})
        else:
            update = lambda x: x.subs(subtree, sym)
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            old = reduced_exprs[j]
            reduced_exprs[j] = update(expr)
            hit = hit or (old != reduced_exprs[j])
        # Make the substitution in all of the subsequent substitutions.
        for j in range(i + 1, len(to_eliminate)):
            old = to_eliminate[j]
            to_eliminate[j] = update(to_eliminate[j])
            hit = hit or (old != to_eliminate[j])
        if hit:
            replacements.append((sym, subtree))

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
        for e in reduced_exprs]

    # remove replacements that weren't used more than once
    _remove_singletons(replacements, reduced_exprs)

    if isinstance(exprs, Matrix):
        reduced_exprs = [Matrix(exprs.rows, exprs.cols, reduced_exprs)]
    if postprocess is None:
        return replacements, reduced_exprs
    return postprocess(replacements, reduced_exprs)
def cgen_ncomp(ncomp=3, nporder=2, aggstat=False, debug=False):
    """Generates a C function for ncomp (int) number of components.
    The jth key component is always in the first position and the kth
    key component is always in the second.  The number of enrichment
    stages (NP) is calculated via a Taylor series approximation.  The
    order of this approximation may be set with nporder.  Only values
    of 1 or 2 are allowed. The aggstat argument determines whether the
    status messages should be aggregated and printed at the end or output
    as the function executes.

    Returns a ``(ccode, repnames, stat)`` tuple: the generated C code and
    replacement names (both as produced by ``cse_to_c`` for the stage and
    the remaining expressions, combined), and the status string built by
    ``_aggstatus``.

    Raises ValueError if nporder is not 1 or 2.
    """
    # Validate up front instead of after the expensive symbolic setup.
    if nporder not in (1, 2):
        raise ValueError("nporder must be 1 or 2")
    start_time = time.time()
    stat = _aggstatus('', "generating {0} component enrichment".format(ncomp), aggstat)
    r = range(0, ncomp)
    j = 0
    k = 1

    # setup-symbols
    alpha = Symbol('alpha', positive=True, real=True)
    LpF = Symbol('LpF', positive=True, real=True)
    PpF = Symbol('PpF', positive=True, real=True)
    TpF = Symbol('TpF', positive=True, real=True)
    SWUpF = Symbol('SWUpF', positive=True, real=True)
    SWUpP = Symbol('SWUpP', positive=True, real=True)
    NP = Symbol('NP', positive=True, real=True)   # Enrichment Stages
    NT = Symbol('NT', positive=True, real=True)   # De-enrichment Stages
    NP0 = Symbol('NP0', positive=True, real=True) # Enrichment Stages Initial Guess
    NT0 = Symbol('NT0', positive=True, real=True) # De-enrichment Stages Initial Guess
    NP1 = Symbol('NP1', positive=True, real=True) # Enrichment Stages Computed Value
    NT1 = Symbol('NT1', positive=True, real=True) # De-enrichment Stages Computed Value
    Mstar = Symbol('Mstar', positive=True, real=True)
    MW = [Symbol('MW[{0}]'.format(i), positive=True, real=True) for i in r]
    beta = [alpha**(Mstar - MWi) for MWi in MW]

    # np_closed helper terms
    NP_b = Symbol('NP_b', real=True)
    NP_2a = Symbol('NP_2a', real=True)
    NP_sqrt_base = Symbol('NP_sqrt_base', real=True)

    xF = [Symbol('xF[{0}]'.format(i), positive=True, real=True) for i in r]
    xPi = [Symbol('xP[{0}]'.format(i), positive=True, real=True) for i in r]
    xTi = [Symbol('xT[{0}]'.format(i), positive=True, real=True) for i in r]
    xPj = Symbol('xPj', positive=True, real=True)
    xFj = xF[j]
    xTj = Symbol('xTj', positive=True, real=True)
    ppf = (xFj - xTj)/(xPj - xTj)
    tpf = (xFj - xPj)/(xTj - xPj)

    xP = [(((xF[i]/ppf)*(beta[i]**(NT+1) - 1))/(beta[i]**(NT+1) - beta[i]**(-NP)))
          for i in r]
    xT = [(((xF[i]/tpf)*(1 - beta[i]**(-NP)))/(beta[i]**(NT+1) - beta[i]**(-NP)))
          for i in r]
    rfeed = xFj / xF[k]
    rprod = xPj / xP[k]
    rtail = xTj / xT[k]

    # setup constraint equations
    numer = [ppf*xP[i]*log(rprod) + tpf*xT[i]*log(rtail) - xF[i]*log(rfeed) for i in r]
    denom = [log(beta[j]) * ((beta[i] - 1.0)/(beta[i] + 1.0)) for i in r]
    LoverF = sum([n/d for n, d in zip(numer, denom)])
    SWUoverF = -1.0 * sum(numer)
    SWUoverP = SWUoverF / ppf

    prod_constraint = (xPj/xFj)*ppf - (beta[j]**(NT+1) - 1)/\
                      (beta[j]**(NT+1) - beta[j]**(-NP))
    tail_constraint = (xTj/xFj)*(sum(xT)) - (1 - beta[j]**(-NP))/\
                      (beta[j]**(NT+1) - beta[j]**(-NP))
    #xp_constraint = 1.0 - sum(xP)
    #xf_constraint = 1.0 - sum(xF)
    #xt_constraint = 1.0 - sum(xT)

    # This is NT(NP,...) and is correct!
    #nt_closed = solve(prod_constraint, NT)[0]

    # However, this is NT(NP,...) rewritten (by hand) to minimize the number of NP
    # and M* instances in the expression.  Luckily this is only depends on the key
    # component and remains general no matter the number of components.
    nt_closed = (-MW[0]*log(alpha) + Mstar*log(alpha) + log(xTj) + log((-1.0 + xPj/\
        xF[0])/(xPj - xTj)) - log(alpha**(NP*(MW[0] - Mstar))*(xF[0]*xPj - xPj*xTj)/\
        (-xF[0]*xPj + xF[0]*xTj) + 1))/((MW[0] - Mstar)*log(alpha))

    # new expression for normalized flow rate
    # NOTE: not needed, solved below
    #loverf = LoverF.xreplace({NT: nt_closed})

    # Define the constraint equation with which to solve NP. This is chosen such to
    # minimize the number of ops in the derivatives (and thus np_closed).  Other,
    # more verbose possibilities are commented out.
    #np_constraint = (xP[j]/sum(xP) - xPj).xreplace({NT: nt_closed})
    #np_constraint = (xP[j]- sum(xP)*xPj).xreplace({NT: nt_closed})
    #np_constraint = (xT[j]/sum(xT) - xTj).xreplace({NT: nt_closed})
    np_constraint = (xT[j] - sum(xT)*xTj).xreplace({NT: nt_closed})

    # get closed form approximation of NP via symbolic derivatives
    stat = _aggstatus(stat, "  order-{0} NP approximation".format(nporder), aggstat)
    d0NP = np_constraint.xreplace({NP: NP0})
    d1NP = diff(np_constraint, NP, 1).xreplace({NP: NP0})
    if 1 == nporder:
        # Root of the linear Taylor polynomial f0 + f1*(NP - NP0), i.e. a
        # Newton step NP0 - f(NP0)/f'(NP0).
        np_closed = NP0 - d0NP / d1NP
        # the linear approximation needs none of the quadratic helper terms
        helper_stages = []
    else:
        d2NP = diff(np_constraint, NP, 2).xreplace({NP: NP0})/2.0
        # taylor series polynomial coefficients, grouped by order
        # f(x) = ax**2 + bx + c
        a = d2NP
        b = d1NP - 2*NP0*d2NP
        c = d0NP - NP0*d1NP + NP0*NP0*d2NP
        # quadratic eq. (minus only)
        #np_closed = (-b - sqrt(b**2 - 4*a*c)) / (2*a)
        # However, we need to break up this expr as follows to prevent
        # a floating point arithmetic bug if b**2 - 4*a*c is very close
        # to zero but happens to be negative.  LAME!!!
        np_2a = 2*a
        np_sqrt_base = b**2 - 4*a*c
        np_closed = (-NP_b - sqrt(NP_sqrt_base)) / (NP_2a)
        helper_stages = [Eq(NP_b, b), Eq(NP_2a, np_2a),
                         # fix for floating point sqrt() error
                         Eq(NP_sqrt_base, np_sqrt_base),
                         Eq(NP_sqrt_base, Abs(NP_sqrt_base))]

    # generate cse for writing out
    msg = "  minimizing ops by eliminating common sub-expressions"
    stat = _aggstatus(stat, msg, aggstat)
    exprstages = helper_stages + [Eq(NP1, np_closed),
                                  Eq(NT1, nt_closed).xreplace({NP: NP1})]
    cse_stages = cse(exprstages, numbered_symbols('n'))
    exprothers = [Eq(LpF, LoverF), Eq(PpF, ppf), Eq(TpF, tpf),
                  Eq(SWUpF, SWUoverF), Eq(SWUpP, SWUoverP)] + \
                 [Eq(*z) for z in zip(xPi, xP)] + [Eq(*z) for z in zip(xTi, xT)]
    exprothers = [e.xreplace({NP: NP1, NT: NT1}) for e in exprothers]
    cse_others = cse(exprothers, numbered_symbols('g'))
    exprops = count_ops(exprstages + exprothers)
    cse_ops = count_ops(cse_stages + cse_others)
    msg = "    reduced {0} ops to {1}".format(exprops, cse_ops)
    stat = _aggstatus(stat, msg, aggstat)

    # create function body
    ccode, repnames = cse_to_c(*cse_stages, indent=6, debug=debug)
    ccode_others, repnames_others = cse_to_c(*cse_others, indent=6, debug=debug)
    ccode += ccode_others
    repnames |= repnames_others

    msg = "  completed in {0:.3G} s".format(time.time() - start_time)
    stat = _aggstatus(stat, msg, aggstat)
    if aggstat:
        # aggregated status messages are printed once, at the end
        print(stat)
    return ccode, repnames, stat
    def extract_sub_expressions(self, cache_prefix='cache', sub_prefix='sub', prefix='XoXoXoX'):
        """Pull common sub-expressions out of ``self.expression_list``.

        ``sym.cse`` is run over all expressions at once. Each common
        sub-expression is then classified:

        * cacheable -- it contains a symbol from ``self.cacheable_vars`` or a
          previously classified cacheable sub-expression. It is renamed
          ``<cache_prefix><i>`` and its definition is stored in
          ``self.expressions['update_cache']``.
        * otherwise -- it is renamed ``<sub_prefix><i>`` and its definition is
          stored in ``self.expressions['parameters_changed']``.

        When cse finds at least one sub-expression, the reduced expressions
        (using the final names) are written back into ``self.expressions`` at
        the paths in ``self.expression_keys``.

        Parameters
        ----------
        cache_prefix : str
            Name prefix for cacheable sub-expressions. The created symbols
            are recorded in ``self.variables[cache_prefix]``.
        sub_prefix : str
            Name prefix for the remaining sub-expressions. The created
            symbols are recorded in ``self.variables[sub_prefix]``.
        prefix : str
            Prefix for the temporary symbols produced by ``sym.cse``. Every
            temporary symbol is replaced by its final name in everything this
            method writes.
        """
        # Do the common sub expression elimination.
        common_sub_expressions, expression_substituted_list = sym.cse(
            self.expression_list, numbered_symbols(prefix=prefix))

        self.variables[cache_prefix] = []
        self.variables[sub_prefix] = []

        # Sort out any sub-expression that depends on something listed in
        # self.cacheable_vars. cse reports sub-expressions so that earlier ones
        # may appear inside later ones, so also checking cacheable_list
        # propagates cacheability through nested sub-expressions.
        cacheable_list = []
        params_change_list = []
        for var, expr in common_sub_expressions:
            arg_list = [e for e in expr.atoms() if e.is_Symbol]
            if any(e in cacheable_list or e in self.cacheable_vars for e in arg_list):
                cacheable_list.append(var)
            else:
                # Every sub-expression must land in one of the two lists,
                # otherwise the replace_dict lookups below raise KeyError.
                params_change_list.append(var)

        # Map each temporary cse symbol (by name) to its final symbol.
        replace_dict = {}
        for i, var in enumerate(cacheable_list):
            sym_var = sym.var(cache_prefix + str(i))
            self.variables[cache_prefix].append(sym_var)
            replace_dict[var.name] = sym_var
        for i, var in enumerate(params_change_list):
            sym_var = sym.var(sub_prefix + str(i))
            self.variables[sub_prefix].append(sym_var)
            replace_dict[var.name] = sym_var

        def rename(expr):
            """Replace every temporary cse symbol in ``expr`` by its final symbol."""
            for replace, _ in common_sub_expressions:
                expr = expr.subs(replace, replace_dict[replace.name])
            return expr

        # Replace original code with code including subexpressions. When cse
        # found nothing, self.expressions is left untouched.
        if common_sub_expressions:
            for expr, keys in zip(expression_substituted_list, self.expression_keys):
                setInDict(self.expressions, keys, expr)
            for keys in self.expression_keys:
                setInDict(self.expressions, keys, rename(getFromDict(self.expressions, keys)))

        # Store each sub-expression definition under its final name, in the
        # dictionary matching its classification.
        self.expressions['parameters_changed'] = {}
        self.expressions['update_cache'] = {}
        for var, expr in common_sub_expressions:
            target = 'update_cache' if var in cacheable_list else 'parameters_changed'
            self.expressions[target][replace_dict[var.name].name] = rename(expr)
    def __init__(self, args, expr, print_lambda=False, use_evalf=False,
                 float_wrap_evalf=False, complex_wrap_evalf=False,
                 use_np=False, use_python_math=False, use_python_cmath=False,
                 use_interval=False):
        """Translate ``expr`` to Python source and compile it into a lambda.

        Parameters
        ==========

        args : iterable of Symbol
            The lambda's arguments. Each must be a ``Symbol``. They are
            renamed to fresh numbered symbols that do not clash with the
            free symbols of ``expr``.
        expr : sympy expression
            The expression to translate.
        print_lambda : bool
            If True, print the translated expression string before the
            lambda is built.
        use_evalf, float_wrap_evalf, complex_wrap_evalf : bool
            Only stored on the instance here. Presumably consumed by the
            translation helpers of this class; confirm.
        use_np, use_python_math, use_python_cmath, use_interval : bool
            Make ``numpy`` (as ``np``), ``math``, ``cmath``, or
            ``sympy.plotting.intervalmath`` (as ``imath``, together with
            ``math``) available in the namespace the lambda is compiled in.
            They are also stored on the instance.

        Raises
        ======

        ValueError
            If any element of ``args`` is not a Symbol.
        ImportError
            If ``use_np`` is set and numpy cannot be imported.

        The compiled function is stored in ``self.lambda_func`` and its
        source text in ``self.eval_str``.
        """
        self.print_lambda = print_lambda
        self.use_evalf = use_evalf
        self.float_wrap_evalf = float_wrap_evalf
        self.complex_wrap_evalf = complex_wrap_evalf
        self.use_np = use_np
        self.use_python_math = use_python_math
        self.use_python_cmath = use_python_cmath
        self.use_interval = use_interval

        # Constructing the argument string
        # - check
        if not all(isinstance(a, Symbol) for a in args):
            raise ValueError('The arguments must be Symbols.')
        # - use numbered symbols that cannot collide with expr's own symbols
        syms = numbered_symbols(exclude=expr.free_symbols)
        newargs = [next(syms) for _ in args]
        expr = expr.xreplace(dict(zip(args, newargs)))
        argstr = ', '.join(str(a) for a in newargs)
        del syms, newargs, args

        # Constructing the translation dictionaries and making the translation
        self.dict_str = self.get_dict_str()
        self.dict_fun = self.get_dict_fun()
        exprstr = str(expr)
        # the & and | operators don't work on tuples, see discussion #12108
        exprstr = exprstr.replace(" & ", " and ").replace(" | ", " or ")

        newexpr = self.tree2str_translate(self.str2tree(exprstr))

        # Constructing the namespaces
        namespace = {}
        # XXX Workaround
        # Ugly workaround because Pow(a,Half) prints as sqrt(a)
        # and sympy_expression_namespace can not catch it.
        from sympy import sqrt
        namespace.update({'sqrt': sqrt})
        namespace.update({'Eq': lambda x, y: x == y})
        # End workaround.
        if use_python_math:
            namespace.update({'math': __import__('math')})
        if use_python_cmath:
            namespace.update({'cmath': __import__('cmath')})
        if use_np:
            try:
                namespace.update({'np': __import__('numpy')})
            except ImportError:
                raise ImportError(
                    'experimental_lambdify failed to import numpy.')
        if use_interval:
            namespace.update({'imath': __import__(
                'sympy.plotting.intervalmath', fromlist=['intervalmath'])})
            namespace.update({'math': __import__('math')})

        # Construct the lambda
        if self.print_lambda:
            print(newexpr)
        eval_str = 'lambda %s : ( %s )' % (argstr, newexpr)
        self.eval_str = eval_str
        # NOTE: eval_str is generated from the printed form of expr and then
        # executed; only use this with trusted expressions.
        exec_("from __future__ import division; MYNEWLAMBDA = %s" % eval_str, namespace)
        self.lambda_func = namespace['MYNEWLAMBDA']
def cse(exprs,
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs
        The (preprocessor, postprocessor) pairs of external optimization
        functions. Optionally 'basic' can be passed for a set of predefined
        basic optimizations. Such 'basic' optimizations were used by default
        in old implementation, however they can be really slow on larger
        expressions. Now, no pre or post optimizations are made by default.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        lambda r, e: (reversed(r), e)
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. If set to
        'canonical', arguments will be canonically ordered. If set to 'none',
        ordering will be faster but dependent on expressions hashes, thus
        machine dependent and variable. For large expressions where speed is a
        concern, use the setting order='none'.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above. If
        ``exprs`` is a ``Matrix``, the reduced entries are reassembled into a
        single Matrix of the same shape, returned as a one-element list.

    If ``postprocess`` is given, ``postprocess(replacements, reduced_exprs)``
    is returned instead.
    """
    # NOTE(review): this copy is truncated. The signature parameters after
    # ``exprs,`` (symbols, optimizations, postprocess, order, per the
    # docstring) are missing, so the function does not parse as written.
    from sympy.matrices import Matrix

    if symbols is None:
        symbols = numbered_symbols()
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        # NOTE(review): an ``else:`` appears to be missing before the next
        # line; as written iter() is applied to the default stream as well.
        symbols = iter(symbols)

    # 'basic' selects the predefined basic_optimizations list.
    if optimizations is None:
        optimizations = list()
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    # NOTE(review): the remaining arguments and closing parenthesis of this
    # tree_cse call are missing from this copy.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    # NOTE(review): the closing ``]`` of this comprehension is missing.
    reduced_exprs = [
        postprocess_for_cse(e, optimizations) for e in reduced_exprs

    # A Matrix input is rebuilt with its original shape.
    if isinstance(exprs, Matrix):
        reduced_exprs = [Matrix(exprs.rows, exprs.cols, reduced_exprs)]
    if postprocess is None:
        return replacements, reduced_exprs
    return postprocess(replacements, reduced_exprs)
def _gauss_jordan_solve(M, B, freevar=False):
    """
    Solves ``Ax = B`` using Gauss Jordan elimination.

    There may be zero, one, or infinite solutions.  If one solution
    exists, it will be returned. If infinite solutions exist, it will
    be returned parametrically. If no solutions exist, a ``ValueError``
    is raised.

    Parameters
    ==========

    M : Matrix
        The coefficient matrix ``A``.

    B : Matrix
        The right hand side of the equation to be solved for.  Must have
        the same number of rows as matrix A.

    freevar : bool, optional
        If the system is underdetermined (e.g. A has more columns than
        rows), infinite solutions are possible, in terms of arbitrary
        values of free variables. If ``freevar`` is ``True``, the indices
        of the free variables in the solution are returned as well.

    Returns
    =======

    x : Matrix
        The matrix that will satisfy ``Ax = B``.  Will have as many rows as
        matrix A has columns, and as many columns as matrix B.

    params : Matrix
        If the system is underdetermined (e.g. A has more columns than
        rows), infinite solutions are possible, in terms of arbitrary
        parameters. These arbitrary parameters are returned as params
        Matrix.

    free_var_index : sequence of int
        Only returned when ``freevar`` is True: the indices of the
        non-pivot (free) columns of A.

    Both ``x`` and ``params`` are converted to the class of ``M``.

    Raises
    ======

    ValueError
        If the linear system has no solution.

    Examples
    ========

    >>> from sympy import Matrix
    >>> A = Matrix([[1, 2, 1, 1], [1, 2, 2, -1], [2, 4, 0, 6]])
    >>> B = Matrix([7, 12, 4])
    >>> sol, params = A.gauss_jordan_solve(B)
    >>> sol
    Matrix([
    [-2*tau0 - 3*tau1 + 2],
    [                 tau0],
    [           2*tau1 + 5],
    [                 tau1]])
    >>> params
    Matrix([
    [tau0],
    [tau1]])
    >>> taus_zeroes = { tau:0 for tau in params }
    >>> sol_unique = sol.xreplace(taus_zeroes)
    >>> sol_unique
    Matrix([
    [2],
    [0],
    [5],
    [0]])

    >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
    >>> B = Matrix([3, 6, 9])
    >>> sol, params = A.gauss_jordan_solve(B)
    >>> sol
    Matrix([
    [-1],
    [ 2],
    [ 0]])
    >>> params
    Matrix(0, 1, [])

    >>> A = Matrix([[2, -7], [-1, 4]])
    >>> B = Matrix([[-21, 3], [12, -2]])
    >>> sol, params = A.gauss_jordan_solve(B)
    >>> sol
    Matrix([
    [0, -2],
    [3, -1]])
    >>> params
    Matrix(0, 2, [])

    References
    ==========

    .. [1] https://en.wikipedia.org/wiki/Gaussian_elimination

    """
    from sympy.matrices import Matrix, zeros

    cls = M.__class__
    aug = M.hstack(M.copy(), B.copy())
    B_cols = B.cols
    _, col = aug[:, :-B_cols].shape

    # solve by reduced row echelon form
    A, pivots = aug.rref(simplify=True)
    A, v = A[:, :-B_cols], A[:, -B_cols:]
    # Pivots that fall in the augmented (B) columns do not belong to A.
    pivots = [p for p in pivots if p < col]
    rank = len(pivots)

    # Bring to block form: move the pivot columns to the front.
    permutation = Matrix(range(col)).T

    for i, c in enumerate(pivots):
        permutation.col_swap(i, c)

    # check for existence of solutions
    # rank of aug Matrix should be equal to rank of coefficient matrix
    if not v[rank:, :].is_zero_matrix:
        raise ValueError("Linear system has no solution")

    # Get index of free symbols (free parameters)
    # non-pivots columns are free variables
    free_var_index = permutation[rank:]

    # Free parameters
    # what are current unnumbered free symbol names?
    name = _uniquely_named_symbol(
        'tau', aug, compare=lambda i: str(i).rstrip('1234567890')).name
    gen = numbered_symbols(name)
    tau = Matrix([next(gen) for _ in range((col - rank) * B_cols)
                  ]).reshape(col - rank, B_cols)

    # Full parametric solution
    V = A[:rank, [c for c in range(A.cols) if c not in pivots]]
    vt = v[:rank, :]
    free_sol = tau.vstack(vt - V * tau, tau)

    # Undo permutation
    sol = zeros(col, B_cols)

    for k in range(col):
        sol[permutation[k], :] = free_sol[k, :]

    sol, tau = cls(sol), cls(tau)

    if freevar:
        return sol, tau, free_var_index
    return sol, tau
# Example No. 28
def cse(exprs,
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs
        The (preprocessor, postprocessor) pairs of external optimization
        functions. Optionally 'basic' can be passed for a set of predefined
        basic optimizations. Such 'basic' optimizations were used by default
        in old implementation, however they can be really slow on larger
        expressions. Now, no pre or post optimizations are made by default.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        lambda r, e: (reversed(r), e)
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. If set to
        'canonical', arguments will be canonically ordered. If set to 'none',
        ordering will be faster but dependent on expressions hashes, thus
        machine dependent and variable. For large expressions where speed is a
        concern, use the setting order='none'.
    ignore : iterable of Symbols
        Substitutions containing any Symbol from ``ignore`` will be ignored.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.

    Examples
    ========

    >>> from sympy import cse, SparseMatrix
    >>> from sympy.abc import x, y, z, w
    >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3)
    ([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3])

    Note that currently, y + z will not get substituted if -y - z is used.

    >>> cse(((w + x + y + z)*(w - y - z))/(w + x)**3)
    ([(x0, w + x)], [(w - y - z)*(x0 + y + z)/x0**3])

    List of expressions with recursive substitutions:

    >>> m = SparseMatrix([x + y, x + y + z])
    >>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m])
    ([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([
    [x0],
    [x1]])])

    Note: the type and mutability of input matrices is retained.

    >>> isinstance(_[1][-1], SparseMatrix)
    True

    The user may disallow substitutions containing certain symbols:

    >>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,))
    ([(x0, x + 1)], [x0*y**2, 3*x0*y**2])
    """
    # NOTE(review): this copy is truncated. The signature parameters after
    # ``exprs,`` (symbols, optimizations, postprocess, order, ignore, per the
    # docstring) are missing, so the function does not parse as written.
    from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix,
                                SparseMatrix, ImmutableSparseMatrix)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, (Basic, MatrixBase)):
        exprs = [exprs]

    # Keep the caller's list so matrix inputs can be rebuilt after reduction.
    copy = exprs
    temp = []
    for e in exprs:
        # NOTE(review): the bodies of both branches below (and presumably an
        # ``else`` that appends plain expressions to ``temp``) are missing
        # from this copy. Judging from the rebuild loop at the end, dense
        # matrices are stored flat and sparse ones as (key, value) pairs;
        # confirm against upstream.
        if isinstance(e, (Matrix, ImmutableMatrix)):
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
    exprs = temp
    del temp

    # 'basic' selects the predefined basic_optimizations list.
    if optimizations is None:
        optimizations = list()
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    if symbols is None:
        symbols = numbered_symbols(cls=Symbol)
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        # NOTE(review): an ``else:`` appears to be missing before this line.
        symbols = iter(symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
                                           order, ignore)

    # Postprocess the expressions to return the expressions to canonical form.
    exprs = copy
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    # NOTE(review): the closing ``]`` of this comprehension is missing.
    reduced_exprs = [
        postprocess_for_cse(e, optimizations) for e in reduced_exprs

    # Get the matrices back, preserving dense/sparse kind and mutability.
    for i, e in enumerate(exprs):
        if isinstance(e, (Matrix, ImmutableMatrix)):
            reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
            if isinstance(e, ImmutableMatrix):
                reduced_exprs[i] = reduced_exprs[i].as_immutable()
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            m = SparseMatrix(e.rows, e.cols, {})
            for k, v in reduced_exprs[i]:
                m[k] = v
            if isinstance(e, ImmutableSparseMatrix):
                m = m.as_immutable()
            reduced_exprs[i] = m

    if postprocess is None:
        return replacements, reduced_exprs

    return postprocess(replacements, reduced_exprs)
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The `numbered_symbols` generator is useful. The default is a stream
        of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        `sympy.simplify.cse.cse_optimizations` is used.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.

    Notes
    =====

    This is an old, Python 2 only version (it uses ``xrange`` and
    ``symbols.next()``).
    """
    if symbols is None:
        symbols = numbered_symbols()
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        # NOTE(review): an ``else:`` appears to be missing before this line.
        symbols = iter(symbols)
    seen_subexp = set()
    muls = set()
    adds = set()
    # Candidate subexpressions, kept sorted by op count (see insert below).
    to_eliminate = []
    to_eliminate_ops_count = []

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]
    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    def insert(subtree):
        '''This helper will insert the subtree into to_eliminate while
        maintaining the ordering by op count and will skip the insertion
        if subtree is already present.'''
        ops_count = subtree.count_ops()
        index_to_insert = bisect.bisect(to_eliminate_ops_count, ops_count)
        # all i up to this index have op count <= the current op count
        # so check that subtree is not yet present from this index down
        # (if necessary) to zero.
        for i in xrange(index_to_insert - 1, -1, -1):
            if to_eliminate_ops_count[i] == ops_count and \
               subtree == to_eliminate[i]:
                return # already have it
        to_eliminate_ops_count.insert(index_to_insert, ops_count)
        to_eliminate.insert(index_to_insert, subtree)

    for expr in exprs:
        pt = preorder_traversal(expr)
        # NOTE(review): the statements belonging under each condition in this
        # loop are missing from this copy (the skip for atoms, the handling of
        # already-seen subtrees, and whatever is done with Mul and Add
        # subtrees), so it does not parse as written.
        for subtree in pt:
            if subtree.is_Atom:
                # Exclude atoms, since there is no point in renaming them.

            if subtree in seen_subexp:

            if subtree.is_Mul:
            elif subtree.is_Add:


    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = [set(a.args) for a in adds]
    for i in xrange(len(adds)):
        for j in xrange(i + 1, len(adds)):
            com = adds[i].intersection(adds[j])
            if len(com) > 1:
                # NOTE(review): the statement recording the shared terms
                # ``com`` as a candidate subexpression is missing here.

                # remove this set of symbols so it doesn't appear again
                adds[i] = adds[i].difference(com)
                adds[j] = adds[j].difference(com)
                for k in xrange(j + 1, len(adds)):
                    if not com.difference(adds[k]):
                        adds[k] = adds[k].difference(com)

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common

    # use SequenceMatcher on the nc part to find the longest common expression
    # in common between the two nc parts
    sm = difflib.SequenceMatcher()

    muls = [a.args_cnc() for a in muls]
    for i in xrange(len(muls)):
        # NOTE(review): the body of this ``if`` is missing from this copy.
        if muls[i][1]:
        for j in xrange(i + 1, len(muls)):
            # the commutative part in common
            ccom = muls[i][0].intersection(muls[j][0])

            # the non-commutative part in common
            if muls[i][1] and muls[j][1]:
                # see if there is any chance of an nc match
                ncom = set(muls[i][1]).intersection(set(muls[j][1]))
                # NOTE(review): the body of the next ``if`` is missing.
                if len(ccom) + len(ncom) < 2:

                # now work harder to find the match
                # NOTE(review): ``sm`` is never given sequences in the visible
                # code, and an ``else:`` appears to be missing before
                # ``ncom = []``; as written ``ncom`` is always reset to [].
                i1, _, n = sm.find_longest_match(0, len(muls[i][1]),
                                                 0, len(muls[j][1]))
                ncom = muls[i][1][i1:i1 + n]
                ncom = []

            com = list(ccom) + ncom
            # NOTE(review): the body of this ``if`` and the recording of
            # ``com`` as a candidate are missing from this copy.
            if len(com) < 2:


            # remove ccom from all if there was no ncom; to update the nc part
            # would require finding the subexpr and then replacing it with a
            # dummy to keep bounding nc symbols from being identified as a
            # subexpr, e.g. removing B*C from A*B*C*D might allow A*D to be
            # identified as a subexpr which would not be right.
            if not ncom:
                muls[i][0] = muls[i][0].difference(ccom)
                for k in xrange(j, len(muls)):
                    if not ccom.difference(muls[k][0]):
                        muls[k][0] = muls[k][0].difference(ccom)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        # NOTE(review): ``symbols.next()`` is Python 2 only; Python 3 needs
        # ``next(symbols)``.
        sym = symbols.next()
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        for j in range(i+1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs]

    return replacements, reduced_exprs
# File: cse.py  Project: rpep/fmmgen
def cse(exprs,
    """Perform common subexpression elimination on ``exprs``.

    In addition to ``ignore``, this version passes ``light_ignore`` through
    to ``tree_cse``. The meaning of ``light_ignore`` is defined by
    ``tree_cse``, which is not visible here.

    Matrix inputs are handled separately from plain expressions, and
    afterwards are rebuilt with their original dense/sparse kind and
    mutability.

    Returns ``(replacements, reduced_exprs)``, or
    ``postprocess(replacements, reduced_exprs)`` when ``postprocess`` is
    given.

    NOTE(review): the signature after ``exprs,`` is missing from this copy.
    The body reads ``symbols``, ``optimizations``, ``order``, ``ignore``,
    ``light_ignore`` and ``postprocess``. The matrix classes and ``Basic``
    used below must come from module scope, since there is no local import.
    """
    # Handle the case if just one expression (or matrix) was passed.
    if isinstance(exprs, (Basic, MatrixBase)):
        exprs = [exprs]

    # Keep the caller's list so matrix inputs can be rebuilt after reduction.
    copy = exprs
    temp = []
    for e in exprs:
        # NOTE(review): the bodies of both branches below (and presumably an
        # ``else`` that appends plain expressions to ``temp``) are missing
        # from this copy.
        if isinstance(e, (Matrix, ImmutableMatrix)):
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
    exprs = temp
    del temp

    # 'basic' selects the predefined basic_optimizations list.
    if optimizations is None:
        optimizations = list()
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    if symbols is None:
        symbols = numbered_symbols(cls=Symbol)
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        # NOTE(review): an ``else:`` appears to be missing before this line.
        symbols = iter(symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
                                           order, ignore, light_ignore)

    # Postprocess the expressions to return the expressions to canonical form.
    exprs = copy
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    # NOTE(review): the closing ``]`` of this comprehension is missing.
    reduced_exprs = [
        postprocess_for_cse(e, optimizations) for e in reduced_exprs

    # Get the matrices back
    for i, e in enumerate(exprs):
        if isinstance(e, (Matrix, ImmutableMatrix)):
            reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
            if isinstance(e, ImmutableMatrix):
                reduced_exprs[i] = reduced_exprs[i].as_immutable()
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            m = SparseMatrix(e.rows, e.cols, {})
            for k, v in reduced_exprs[i]:
                m[k] = v
            if isinstance(e, ImmutableSparseMatrix):
                m = m.as_immutable()
            reduced_exprs[i] = m

    if postprocess is None:
        return replacements, reduced_exprs

    return postprocess(replacements, reduced_exprs)
# Example No. 31
def test_numbered_symbols():
    """numbered_symbols honours ``cls``, ``start`` and ``exclude``."""
    s = numbered_symbols(cls=Dummy)
    assert isinstance(next(s), Dummy)
    # start=1 would yield C1 first, but C1 is excluded, so C2 comes next.
    assert next(numbered_symbols('C', start=1, exclude=[symbols('C1')])) == \
        symbols('C2')
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The `numbered_symbols` generator is useful. The default is a stream
        of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        `sympy.simplify.cse.cse_optimizations` is used.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.

    Notes
    =====

    Subtrees are visited in post-order, so a repeated subexpression is
    recorded before any repeated expression that contains it. This version
    uses ``symbols.next()`` and is therefore Python 2 only.
    """
    if symbols is None:
        symbols = numbered_symbols()
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        # NOTE(review): an ``else:`` appears to be missing before this line.
        symbols = iter(symbols)
    seen_subexp = set()
    to_eliminate = []

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]
    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    # NOTE(review): the bodies of both ``if`` statements below are missing
    # from this copy (presumably a skip for atoms, appending the repeated
    # subtree to ``to_eliminate``, and adding each subtree to
    # ``seen_subexp``, which is otherwise never populated).
    for expr in exprs:
        for subtree in postorder_traversal(expr):
            if subtree.args == ():
                # Exclude atoms, since there is no point in renaming them.
            if (subtree.args != () and
                subtree in seen_subexp and
                subtree not in to_eliminate):

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        # NOTE(review): ``symbols.next()`` is Python 2 only; Python 3 needs
        # ``next(symbols)``.
        sym = symbols.next()
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        # WARNING: modifying iterated list in-place! I think it's fine,
        # but there might be clearer alternatives.
        for j in range(i+1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs]

    return replacements, reduced_exprs
def test_numbered_symbols():
    """numbered_symbols(cls=Dummy) yields Dummy instances."""
    s = numbered_symbols(cls=Dummy)
    # Use the builtin next(); the ``.next()`` method exists only on Python 2.
    assert isinstance(next(s), Dummy)
def cse(exprs, symbols=None, optimizations=None, postprocess=None,
        order='canonical'):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs
        The (preprocessor, postprocessor) pairs of external optimization
        functions. Optionally 'basic' can be passed for a set of predefined
        basic optimizations. Such 'basic' optimizations were used by default
        in old implementation, however they can be really slow on larger
        expressions. Now, no pre or post optimizations are made by default.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        ``lambda r, e: (list(reversed(r)), e)``
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. If set to
        'canonical', arguments will be canonically ordered. If set to 'none',
        ordering will be faster but dependent on expressions hashes, thus
        machine dependent and variable. For large expressions where speed is a
        concern, use the setting order='none'.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    from sympy.matrices import Matrix

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    if optimizations is None:
        optimizations = list()
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Symbols already present in the input must never be reused as labels for
    # common subexpressions.  ``set().union`` (unlike ``set.union``) also
    # works when no expressions were passed at all.
    excluded_symbols = set().union(*[expr.atoms(Symbol)
                                     for expr in reduced_exprs])

    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)

    symbols = filter_symbols(symbols, excluded_symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
                                           order)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    # A Matrix input was iterated element-wise above; rebuild it with the
    # original shape so callers get a single reduced Matrix back.
    if isinstance(exprs, Matrix):
        reduced_exprs = [Matrix(exprs.rows, exprs.cols, reduced_exprs)]
    if postprocess is None:
        return replacements, reduced_exprs
    return postprocess(replacements, reduced_exprs)
def rsolve_hyper(coeffs, f, n, **hints):
    r"""
    Given linear recurrence operator `\operatorname{L}` of order `k`
    with polynomial coefficients and inhomogeneous equation
    `\operatorname{L} y = f` we seek for all hypergeometric solutions
    over field `K` of characteristic zero.

    The inhomogeneous part can be either hypergeometric or a sum
    of a fixed number of pairwise dissimilar hypergeometric terms.

    The algorithm performs three basic steps:

        (1) Group together similar hypergeometric terms in the
            inhomogeneous part of `\operatorname{L} y = f`, and find
            particular solution using Abramov's algorithm.

        (2) Compute generating set of `\operatorname{L}` and find basis
            in it, so that all solutions are linearly independent.

        (3) Form final solution with the number of arbitrary
            constants equal to dimension of basis of `\operatorname{L}`.

    Term `a(n)` is hypergeometric if it is annihilated by first order
    linear difference equations with polynomial coefficients or, in
    simpler words, if consecutive term ratio is a rational function.

    The output of this procedure is a linear combination of fixed
    number of hypergeometric terms. However the underlying method
    can generate larger class of solutions - D'Alembertian terms.

    Note also that this method not only computes the kernel of the
    inhomogeneous equation, but also reduces it to a basis so that
    solutions generated by this procedure are linearly independent.

    Returns ``None`` when the inhomogeneous part is not (a sum of)
    hypergeometric terms, when no particular solution is found, or when
    the kernel is empty.  With ``symbols=True`` in ``hints`` a tuple
    ``(solution, list_of_arbitrary_constants)`` is returned instead.

    Examples
    ========

    >>> from sympy.solvers import rsolve_hyper
    >>> from sympy.abc import x

    >>> rsolve_hyper([-1, -1, 1], 0, x)
    C0*(1/2 + sqrt(5)/2)**x + C1*(-sqrt(5)/2 + 1/2)**x

    >>> rsolve_hyper([-1, 1], 1 + x, x)
    C0 + x*(x + 1)/2

    References
    ==========

    .. [1] M. Petkovsek, Hypergeometric solutions of linear recurrences
           with polynomial coefficients, J. Symbolic Computation,
           14 (1992), 243-264.

    .. [2] M. Petkovsek, H. S. Wilf, D. Zeilberger, A = B, 1996.
    """
    # Materialize the coefficients: they are measured, sliced and indexed
    # below, which a lazy ``map`` object (Python 3) does not support.
    coeffs = [sympify(c) for c in coeffs]

    f = sympify(f)

    r, kernel, symbols = len(coeffs) - 1, [], set()

    if not f.is_zero:
        if f.is_Add:
            similar = {}

            for g in f.expand().args:
                if not g.is_hypergeometric(n):
                    return None

                # Merge g into the first group it is hypersimilar to;
                # otherwise g starts a new group of its own.
                for h in similar:
                    if hypersimilar(g, h, n):
                        similar[h] += g
                        break
                else:
                    similar[g] = S.Zero

            inhomogeneous = []

            for g, h in similar.items():
                inhomogeneous.append(g + h)
        elif f.is_hypergeometric(n):
            inhomogeneous = [f]
        else:
            return None

        for i, g in enumerate(inhomogeneous):
            coeff, polys = S.One, coeffs[:]
            denoms = [S.One] * (r + 1)

            s = hypersimp(g, n)

            for j in range(1, r + 1):
                coeff *= s.subs(n, n + j - 1)

                p, q = coeff.as_numer_denom()

                polys[j] *= p
                denoms[j] = q

            # Bring every coefficient over the common denominator.
            for j in range(0, r + 1):
                polys[j] *= Mul(*(denoms[:j] + denoms[j + 1:]))

            R = rsolve_poly(polys, Mul(*denoms), n)

            if not (R is None or R is S.Zero):
                inhomogeneous[i] *= R
            else:
                # No particular solution for this group of terms.
                return None

        result = Add(*inhomogeneous)
    else:
        result = S.Zero

    Z = Dummy('Z')

    p, q = coeffs[0], coeffs[r].subs(n, n - r + 1)

    p_factors = [z for z in roots(p, n)]
    q_factors = [z for z in roots(q, n)]

    factors = [(S.One, S.One)]

    for p in p_factors:
        for q in q_factors:
            if p.is_integer and q.is_integer and p <= q:
                factors += [(n - p, n - q)]

    p = [(n - p, S.One) for p in p_factors]
    q = [(S.One, n - q) for q in q_factors]

    factors = p + factors + q

    for A, B in factors:
        polys, degrees = [], []
        D = A*B.subs(n, n + r - 1)

        for i in range(0, r + 1):
            a = Mul(*[A.subs(n, n + j) for j in range(0, i)])
            b = Mul(*[B.subs(n, n + j) for j in range(i, r)])

            poly = quo(coeffs[i]*a*b, D, n)
            # Keep Poly objects: ``nth`` and ``degree`` are used below.
            polys.append(poly.as_poly(n))

            if not poly.is_zero:
                degrees.append(polys[i].degree())

        d, poly = max(degrees), S.Zero

        # Build the characteristic polynomial from the leading coefficients.
        for i in range(0, r + 1):
            coeff = polys[i].nth(d)

            if coeff is not S.Zero:
                poly += coeff * Z**i

        for z in roots(poly, Z):
            if z.is_zero:
                continue

            (C, s) = rsolve_poly([polys[i]*z**i for i in range(r + 1)], 0, n, symbols=True)

            if C is not None and C is not S.Zero:
                symbols |= set(s)

                ratio = z * A * C.subs(n, n + 1) / B / C
                ratio = simplify(ratio)
                # If there is a nonnegative root in the denominator of the ratio,
                # this indicates that the term y(n_root) is zero, and one should
                # start the product with the term y(n_root + 1).
                n0 = 0
                for n_root in roots(ratio.as_numer_denom()[1], n).keys():
                    if (n0 < (n_root + 1)) is True:
                        n0 = n_root + 1
                K = product(ratio, (n, n0, n - 1))
                if K.has(factorial, FallingFactorial, RisingFactorial):
                    K = simplify(K)

                # Only keep K if it is linearly independent of the kernel so far.
                if casoratian(kernel + [K], n, zero=False) != 0:
                    kernel.append(K)

    # ``list`` so the pairs can be truth-tested and iterated twice (Python 3
    # ``zip`` is a one-shot iterator that is always truthy).
    sk = list(zip(numbered_symbols('C'), kernel))

    if not sk:
        return None

    for C, ker in sk:
        result += C * ker

    if hints.get('symbols', False):
        symbols |= set([s for s, k in sk])
        return (result, list(symbols))
    return result
def test_filter_symbols():
    # Symbols named in the exclusion list must be skipped by the stream.
    excluded = symbols("x0 x2 x3")
    expected = list(symbols("x1 x4 x5"))
    first_three = take(filter_symbols(numbered_symbols(), excluded), 3)
    assert first_three == expected
# Example No. 37
def test_numbered_symbols():
    # ``cls`` selects the class of the generated symbols.  The builtin
    # ``next`` works on generators under Python 2.6+ and Python 3, whereas the
    # ``.next()`` method only exists on Python 2 generators.
    s = numbered_symbols(cls=Dummy)
    assert isinstance(next(s), Dummy)
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a stream
        of symbols of the form "x0", "x1", etc. This must be an infinite
        iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        ``sympy.simplify.cse.cse_optimizations`` is used.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    muls = set()
    adds = set()
    # Subexpressions to replace, kept sorted by op count (smallest first) so
    # that smaller pieces are substituted before the larger ones containing
    # them; to_eliminate_ops_count is the parallel list of sort keys.
    to_eliminate = []
    to_eliminate_ops_count = []

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]
    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    def insert(subtree):
        '''This helper will insert the subtree into to_eliminate while
        maintaining the ordering by op count and will skip the insertion
        if subtree is already present.'''
        ops_count = subtree.count_ops()
        index_to_insert = bisect.bisect(to_eliminate_ops_count, ops_count)
        # all i up to this index have op count <= the current op count
        # so check that subtree is not yet present from this index down
        # (if necessary) to zero.
        for i in xrange(index_to_insert - 1, -1, -1):
            if to_eliminate_ops_count[i] == ops_count and \
               subtree == to_eliminate[i]:
                return # already have it
        to_eliminate_ops_count.insert(index_to_insert, ops_count)
        to_eliminate.insert(index_to_insert, subtree)

    for expr in exprs:
        pt = preorder_traversal(expr)
        for subtree in pt:
            if subtree.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue

            if subtree in seen_subexp:
                # Second sighting: record it and do not descend into it, its
                # children were already seen the first time.
                insert(subtree)
                pt.skip()
                continue

            if subtree.is_Mul:
                muls.add(subtree)
            elif subtree.is_Add:
                adds.add(subtree)

            seen_subexp.add(subtree)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = [set(a.args) for a in adds]
    for i in xrange(len(adds)):
        for j in xrange(i + 1, len(adds)):
            com = adds[i].intersection(adds[j])
            if len(com) > 1:
                insert(Add(*com))

                # remove this set of symbols so it doesn't appear again
                adds[i] = adds[i].difference(com)
                adds[j] = adds[j].difference(com)
                for k in xrange(j + 1, len(adds)):
                    if not com.difference(adds[k]):
                        adds[k] = adds[k].difference(com)

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common

    # use SequenceMatcher on the nc part to find the longest common expression
    # in common between the two nc parts
    sm = difflib.SequenceMatcher()

    muls = [a.args_cnc() for a in muls]
    for i in xrange(len(muls)):
        if muls[i][1]:
            sm.set_seq1(muls[i][1])
        for j in xrange(i + 1, len(muls)):
            # the commutative part in common
            ccom = muls[i][0].intersection(muls[j][0])

            # the non-commutative part in common
            if muls[i][1] and muls[j][1]:
                # see if there is any chance of an nc match
                ncom = set(muls[i][1]).intersection(set(muls[j][1]))
                if len(ccom) + len(ncom) < 2:
                    continue

                # now work harder to find the match
                sm.set_seq2(muls[j][1])
                i1, _, n = sm.find_longest_match(0, len(muls[i][1]),
                                                 0, len(muls[j][1]))
                ncom = muls[i][1][i1:i1 + n]
            else:
                ncom = []

            com = list(ccom) + ncom
            if len(com) < 2:
                continue

            insert(Mul(*com))

            # remove ccom from all if there was no ncom; to update the nc part
            # would require finding the subexpr and then replacing it with a
            # dummy to keep bounding nc symbols from being identified as a
            # subexpr, e.g. removing B*C from A*B*C*D might allow A*D to be
            # identified as a subexpr which would not be right.
            if not ncom:
                muls[i][0] = muls[i][0].difference(ccom)
                for k in xrange(j, len(muls)):
                    if not ccom.difference(muls[k][0]):
                        muls[k][0] = muls[k][0].difference(ccom)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        sym = next(symbols)
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        for j in range(i+1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs]

    return replacements, reduced_exprs
def test_filter_symbols():
    # Symbols named in the exclusion list must be skipped by the stream.
    excluded = symbols("x0 x2 x3")
    expected = list(symbols("x1 x4 x5"))
    first_three = take(filter_symbols(numbered_symbols(), excluded), 3)
    assert first_three == expected
def cgen_ncomp(ncomp=3, nporder=2, aggstat=False, debug=False):
    """Generates a C function for ncomp (int) number of components.
    The jth key component is always in the first position and the kth
    key component is always in the second.  The number of enrichment
    stages (NP) is calculated via a Taylor series approximation.  The
    order of this approximation may be set with nporder.  Only values
    of 1 or 2 are allowed. The aggstat argument determines whether the
    status messages should be aggregated and printed at the end or output
    as the function executes.

    Returns
    -------
    ccode, repnames, stat
        The code and replacement names produced by ``cse_to_c`` for all of
        the generated expressions (stage expressions first), and the status
        string accumulated via ``_aggstatus``.

    Raises
    ------
    ValueError
        If nporder is not 1 or 2.
    """
    # Validate before doing any (expensive) symbolic work.
    if nporder not in (1, 2):
        raise ValueError("nporder must be 1 or 2")
    start_time = time.time()
    stat = _aggstatus('', "generating {0} component enrichment".format(ncomp), aggstat)
    r = range(0, ncomp)
    j = 0
    k = 1

    # setup-symbols
    alpha = Symbol('alpha', positive=True, real=True)
    LpF = Symbol('LpF', positive=True, real=True)
    PpF = Symbol('PpF', positive=True, real=True)
    TpF = Symbol('TpF', positive=True, real=True)
    SWUpF = Symbol('SWUpF', positive=True, real=True)
    SWUpP = Symbol('SWUpP', positive=True, real=True)
    NP = Symbol('NP', positive=True, real=True)   # Enrichment Stages
    NT = Symbol('NT', positive=True, real=True)   # De-enrichment Stages
    NP0 = Symbol('NP0', positive=True, real=True) # Enrichment Stages Initial Guess
    NT0 = Symbol('NT0', positive=True, real=True) # De-enrichment Stages Initial Guess
    NP1 = Symbol('NP1', positive=True, real=True) # Enrichment Stages Computed Value
    NT1 = Symbol('NT1', positive=True, real=True) # De-enrichment Stages Computed Value
    Mstar = Symbol('Mstar', positive=True, real=True)
    MW = [Symbol('MW[{0}]'.format(i), positive=True, real=True) for i in r]
    beta = [alpha**(Mstar - MWi) for MWi in MW]

    # np_closed helper terms
    NP_b = Symbol('NP_b', real=True)
    NP_2a = Symbol('NP_2a', real=True)
    NP_sqrt_base = Symbol('NP_sqrt_base', real=True)

    xF = [Symbol('xF[{0}]'.format(i), positive=True, real=True) for i in r]
    xPi = [Symbol('xP[{0}]'.format(i), positive=True, real=True) for i in r]
    xTi = [Symbol('xT[{0}]'.format(i), positive=True, real=True) for i in r]
    xPj = Symbol('xPj', positive=True, real=True)
    xFj = xF[j]
    xTj = Symbol('xTj', positive=True, real=True)
    ppf = (xFj - xTj)/(xPj - xTj)
    tpf = (xFj - xPj)/(xTj - xPj)

    xP = [(((xF[i]/ppf)*(beta[i]**(NT+1) - 1))/(beta[i]**(NT+1) - beta[i]**(-NP))) \
                                                                            for i in r]
    xT = [(((xF[i]/tpf)*(1 - beta[i]**(-NP)))/(beta[i]**(NT+1) - beta[i]**(-NP))) \
                                                                            for i in r]
    rfeed = xFj / xF[k]
    rprod = xPj / xP[k]
    rtail = xTj / xT[k]

    # setup constraint equations
    numer = [ppf*xP[i]*log(rprod) + tpf*xT[i]*log(rtail) - xF[i]*log(rfeed) for i in r]
    denom = [log(beta[j]) * ((beta[i] - 1.0)/(beta[i] + 1.0)) for i in r]
    LoverF = sum([n/d for n, d in zip(numer, denom)])
    SWUoverF = -1.0 * sum(numer)
    SWUoverP = SWUoverF / ppf

    prod_constraint = (xPj/xFj)*ppf - (beta[j]**(NT+1) - 1)/\
                      (beta[j]**(NT+1) - beta[j]**(-NP))
    tail_constraint = (xTj/xFj)*(sum(xT)) - (1 - beta[j]**(-NP))/\
                      (beta[j]**(NT+1) - beta[j]**(-NP))
    #xp_constraint = 1.0 - sum(xP)
    #xf_constraint = 1.0 - sum(xF)
    #xt_constraint = 1.0 - sum(xT)

    # This is NT(NP,...) and is correct!
    #nt_closed = solve(prod_constraint, NT)[0]

    # However, this is NT(NP,...) rewritten (by hand) to minimize the number of NP
    # and M* instances in the expression.  Luckily this is only depends on the key
    # component and remains general no matter the number of components.
    nt_closed = (-MW[0]*log(alpha) + Mstar*log(alpha) + log(xTj) + log((-1.0 + xPj/\
        xF[0])/(xPj - xTj)) - log(alpha**(NP*(MW[0] - Mstar))*(xF[0]*xPj - xPj*xTj)/\
        (-xF[0]*xPj + xF[0]*xTj) + 1))/((MW[0] - Mstar)*log(alpha))

    # new expression for normalized flow rate
    # NOTE: not needed, solved below
    #loverf = LoverF.xreplace({NT: nt_closed})

    # Define the constraint equation with which to solve NP. This is chosen such to
    # minimize the number of ops in the derivatives (and thus np_closed).  Other,
    # more verbose possibilities are commented out.
    #np_constraint = (xP[j]/sum(xP) - xPj).xreplace({NT: nt_closed})
    #np_constraint = (xP[j]- sum(xP)*xPj).xreplace({NT: nt_closed})
    #np_constraint = (xT[j]/sum(xT) - xTj).xreplace({NT: nt_closed})
    np_constraint = (xT[j] - sum(xT)*xTj).xreplace({NT: nt_closed})

    # get closed form approximation of NP via symbolic derivatives
    stat = _aggstatus(stat, "  order-{0} NP approximation".format(nporder), aggstat)
    d0NP = np_constraint.xreplace({NP: NP0})
    d1NP = diff(np_constraint, NP, 1).xreplace({NP: NP0})
    if 1 == nporder:
        # Newton step: root of f(NP0) + f'(NP0)*(NP - NP0) = 0.
        np_closed = NP0 - d0NP / d1NP
        # The order-1 form needs no helper assignments.
        helper_stages = []
    else:
        d2NP = diff(np_constraint, NP, 2).xreplace({NP: NP0})/2.0
        # taylor series polynomial coefficients, grouped by order
        # f(x) = ax**2 + bx + c
        a = d2NP
        b = d1NP - 2*NP0*d2NP
        c = d0NP - NP0*d1NP + NP0*NP0*d2NP
        # quadratic eq. (minus only)
        #np_closed = (-b - sqrt(b**2 - 4*a*c)) / (2*a)
        # However, we need to break up this expr as follows to prevent
        # a floating point arithmetic bug if b**2 - 4*a*c is very close
        # to zero but happens to be negative.  LAME!!!
        np_2a = 2*a
        np_sqrt_base = b**2 - 4*a*c
        np_closed = (-NP_b - sqrt(NP_sqrt_base)) / (NP_2a)
        helper_stages = [Eq(NP_b, b), Eq(NP_2a, np_2a),
                         # fix for floating point sqrt() error
                         Eq(NP_sqrt_base, np_sqrt_base),
                         Eq(NP_sqrt_base, Abs(NP_sqrt_base))]

    # generate cse for writing out
    msg = "  minimizing ops by eliminating common sub-expressions"
    stat = _aggstatus(stat, msg, aggstat)
    exprstages = helper_stages + \
                 [Eq(NP1, np_closed), Eq(NT1, nt_closed).xreplace({NP: NP1})]
    cse_stages = cse(exprstages, numbered_symbols('n'))
    exprothers = [Eq(LpF, LoverF), Eq(PpF, ppf), Eq(TpF, tpf),
                  Eq(SWUpF, SWUoverF), Eq(SWUpP, SWUoverP)] + \
                 [Eq(*z) for z in zip(xPi, xP)] + [Eq(*z) for z in zip(xTi, xT)]
    exprothers = [e.xreplace({NP: NP1, NT: NT1}) for e in exprothers]
    cse_others = cse(exprothers, numbered_symbols('g'))
    exprops = count_ops(exprstages + exprothers)
    cse_ops = count_ops(cse_stages + cse_others)
    msg = "    reduced {0} ops to {1}".format(exprops, cse_ops)
    stat = _aggstatus(stat, msg, aggstat)

    # create function body
    ccode, repnames = cse_to_c(*cse_stages, indent=6, debug=debug)
    ccode_others, repnames_others = cse_to_c(*cse_others, indent=6, debug=debug)
    ccode += ccode_others
    repnames |= repnames_others

    msg = "  completed in {0:.3G} s".format(time.time() - start_time)
    stat = _aggstatus(stat, msg, aggstat)
    if aggstat:
        # Aggregated mode: emit all collected status messages at the end.
        print(stat)
    return ccode, repnames, stat
    def _eval_rewrite_as_sqrt(self, arg):
        """Rewrite ``cos(arg)`` in terms of square roots, if possible.

        Succeeds when ``arg`` is a rational multiple of pi whose denominator
        is a power of two times a product of distinct Fermat primes from
        ``cst_table_some`` (3, 5, 17).  Returns ``None`` otherwise.
        """
        _EXPAND_INTS = False

        def migcdex(x):
            # recursive calculation of gcd and linear combination
            # for a sequence of integers.
            # Given  (x1, x2, x3)
            # Returns (y1, y2, y3, g)
            # such that g is the gcd and x1*y1+x2*y2+x3*y3 - g = 0
            # Note, that this is only one such linear combination.
            if len(x) == 1:
                return (1, x[0])
            if len(x) == 2:
                return igcdex(x[0], x[-1])
            g = migcdex(x[1:])
            u, v, h = igcdex(x[0], g[-1])
            return tuple([u] + [v*i for i in g[0:-1]] + [h])

        def ipartfrac(r, factors=None):
            # Split the rational r into a sum of fractions whose
            # denominators are the (co-prime) parts of r.q.
            if isinstance(r, int):
                return r
            assert isinstance(r, C.Rational)
            n = r.q
            if 2 > r.q*r.q:
                return r.q

            # Floor division keeps these exact integers (igcdex needs ints).
            if factors is None:
                a = [n//x**y for x, y in factorint(r.q).items()]
            else:
                a = [n//x for x in factors]
            if len(a) == 1:
                return [r]
            h = migcdex(a)
            ans = [r.p*C.Rational(i*j, r.q) for i, j in zip(h[:-1], a)]
            assert r == sum(ans)
            return ans
        pi_coeff = _pi_coeff(arg)
        if pi_coeff is None:
            return None

        assert not pi_coeff.is_integer, "should have been simplified already"

        if not pi_coeff.is_Rational:
            return None

        # cos(pi/p) in radicals for the Fermat primes p that are supported.
        cst_table_some = {
            3: S.Half,
            5: (sqrt(5) + 1)/4,
            17: sqrt((15 + sqrt(17))/32 + sqrt(2)*(sqrt(17 - sqrt(17)) +
                sqrt(sqrt(2)*(-8*sqrt(17 + sqrt(17)) - (1 - sqrt(17))
                *sqrt(17 - sqrt(17))) + 6*sqrt(17) + 34))/32)
            # 65537 and 257 are the only other known Fermat primes
            # Please add if you would like them
        }

        def fermatCoords(n):
            # Return the tuple of distinct table primes whose product is n,
            # or False if n is not such a square-free product.
            assert isinstance(n, int)
            assert n > 0
            if n == 1 or 0 == n % 2:
                return False
            primes = dict([(p, 0) for p in cst_table_some])
            assert 1 not in primes
            for p_i in primes:
                while 0 == n % p_i:
                    n = n//p_i
                    primes[p_i] += 1
            if 1 != n:
                return False
            if max(primes.values()) > 1:
                return False
            return tuple([p for p in primes if primes[p] == 1])

        if pi_coeff.q in cst_table_some:
            return C.chebyshevt(pi_coeff.p, cst_table_some[pi_coeff.q]).expand()

        if 0 == pi_coeff.q % 2:  # recursively remove powers of 2
            narg = (pi_coeff*2)*S.Pi
            nval = cos(narg)
            if nval is None:
                return None
            nval = nval.rewrite(sqrt)
            if not _EXPAND_INTS:
                if (isinstance(nval, cos) or isinstance(-nval, cos)):
                    return None
            # Half-angle formula; the sign follows the quadrant of arg.
            x = (2*pi_coeff + 1)/2
            sign_cos = (-1)**((-1 if x < 0 else 1)*int(abs(x)))
            return sign_cos*sqrt((1 + nval)/2)

        FC = fermatCoords(pi_coeff.q)
        if FC:
            decomp = ipartfrac(pi_coeff, FC)
            X = [(x[1], x[0]*S.Pi) for x in zip(decomp, numbered_symbols('z'))]
            pcls = cos(sum([x[0] for x in X]))._eval_expand_trig().subs(X)
            return pcls.rewrite(sqrt)
        if _EXPAND_INTS:
            decomp = ipartfrac(pi_coeff)
            X = [(x[1], x[0]*S.Pi) for x in zip(decomp, numbered_symbols('z'))]
            pcls = cos(sum([x[0] for x in X]))._eval_expand_trig().subs(X)
            return pcls
        return None
 def __init__(self, commands):
     self.commands = commands
     self.get_symbol = numbered_symbols("tmp")