示例#1
0
def test_cse_single():
    """A repeated ``x + y`` is extracted as the single replacement ``x0``."""
    expr = Add(Pow(x + y, 2), sqrt(x + y))
    replacements, reduced_exprs = cse([expr])
    assert replacements == [(x0, x + y)]
    assert reduced_exprs == [sqrt(x0) + x0**2]
    # Passing order='none' is expected to give the same result as the default.
    unordered = cse([expr], order='none')
    assert unordered == cse([expr])
示例#2
0
def test_non_commutative_order():
    """cse on non-commutative symbols: shared ``B + C`` and ``A - B`` become ``x0``."""
    A, B, C = symbols('A B C', commutative=False)
    x0 = symbols('x0', commutative=False)

    exprs = [B+C, A*(B+C)]
    result = cse(exprs)
    assert result == ([(x0, B+C)], [x0, A*x0])

    exprs = [(A - B)**2 + A - B]
    result = cse(exprs)
    assert result == ([(x0, A - B)], [x0**2 + x0])
示例#3
0
def test_sympyissue_7840():
    """cse applied to nested Piecewise expressions (two reported examples)."""
    # daveknippers' example
    outer = sympify(
        'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \
        C391 > 2.35), (C392, True)), True))'
    )
    inner = sympify(
        'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))'
    )
    outer = outer.subs({'C391': inner})
    # Route 1: substitute the numeric values straight into the expression.
    values = {'C390': 0.703451854, 'C392': 1.01417794}
    direct = outer.subs(values)
    # Route 2: evaluate each cse replacement in turn, then the reduced form.
    replacements, reduced = cse(outer)
    for symbol, definition in replacements:
        values[symbol.name] = definition.subs(values)
    via_cse = reduced[0].subs(values)
    # Both routes should agree.
    assert direct == via_cse

    # GitRay's example
    on, off, auto = Symbol('ON'), Symbol('OFF'), Symbol('AUTO')
    mode = Symbol('mode')
    by_threshold = Piecewise((off, Symbol('x') < Symbol('threshold')),
                             (on, true))
    auto_branch = Piecewise((by_threshold, Eq(mode, auto)), (off, true))
    expr = Piecewise((on, Eq(mode, on)), (auto_branch, true))
    replacements, reduced = cse(expr)
    # The Piecewise should come back exactly as it went in ...
    assert reduced[0] == expr
    # ... and without any replacements.
    assert len(replacements) == 0
示例#4
0
def test_cse_single2():
    """cse accepts a bare expression (not wrapped in a list) and a Matrix."""
    expr = Add(Pow(x + y, 2), sqrt(x + y))
    replacements, reduced_exprs = cse(expr)
    assert replacements == [(x0, x + y)]
    assert reduced_exprs == [sqrt(x0) + x0**2]
    # A Matrix input should come back as a Matrix in the reduced list.
    _, reduced_exprs = cse(Matrix([[1]]))
    assert isinstance(reduced_exprs[0], Matrix)
示例#5
0
def test_bypass_non_commutatives():
    """Each list of non-commutative products comes back with no replacements."""
    A, B, C = symbols('A B C', commutative=False)
    cases = (
        [A*B*C, A*C],
        [A*B*C, A*B],
        [B*C, A*B*C],
    )
    for exprs in cases:
        assert cse(exprs) == ([], exprs)
示例#6
0
def test_cse_not_possible():
    """Inputs without a repeated subexpression are returned unchanged."""
    replacements, reduced_exprs = cse([Add(x, y)])
    assert replacements == []
    assert reduced_exprs == [x + y]

    # issue sympy/sympy#6329
    first = meijerg((1, 2), (y, 4), (5,), [], x)
    second = meijerg((1, 3), (y, 4), (5,), [], x)
    total = first + second
    assert cse(total) == ([], [total])
示例#7
0
def test_derivative_subs():
    """Derivative survives substituting f(x) away and back, and passing through cse."""
    y = Symbol('y')
    f = Function('f')
    swapped = Derivative(f(x), x).subs({f(x): y})
    assert swapped != 0
    restored = Derivative(f(x), x).subs({f(x): y}).subs({y: f(x)})
    assert restored == Derivative(f(x), x)

    # issues sympy/sympy#5085, sympy/sympy#5037
    single_arg = cse(Derivative(f(x), x) + f(x))
    assert single_arg[1][0].has(Derivative)
    two_args = cse(Derivative(f(x, y), x) + Derivative(f(x, y), y))
    assert two_args[1][0].has(Derivative)
示例#8
0
def test_dont_cse_tuples():
    """cse on sums of Subs: nothing extracted for constant points, ``x + y`` extracted when shared."""
    f = Function("f")
    g = Function("g")

    # Constant substitution points: no replacement expected.
    total = Subs(f(x, y), (x, 0), (y, 1)) + Subs(g(x, y), (x, 0), (y, 1))
    replacements, (reduced,) = cse(total)
    assert replacements == []
    assert reduced == (Subs(f(x, y), (x, 0), (y, 1))
                       + Subs(g(x, y), (x, 0), (y, 1)))

    # Both Subs use x + y as a point value: it is expected to become x0.
    total = (Subs(f(x, y), (x, 0), (y, x + y)) +
             Subs(g(x, y), (x, 0), (y, x + y)))
    replacements, (reduced,) = cse(total)
    assert replacements == [(x0, x + y)]
    expected = Subs(f(x, y), (x, 0), (y, x0)) + Subs(g(x, y), (x, 0), (y, x0))
    assert reduced == expected
示例#9
0
def test_sympyissue_8891():
    """The reduced matrix keeps the class of the input matrix, for four matrix types."""
    matrix_types = (MutableDenseMatrix, MutableSparseMatrix,
                    ImmutableDenseMatrix, ImmutableSparseMatrix)
    for matrix_type in matrix_types:
        mat = matrix_type(2, 2, [x + y, 0, 0, 0])
        result = cse([x + y, mat])
        expected = ([(x0, x + y)],
                    [x0, matrix_type([[x0, 0], [0, 0]])])
        assert result == expected
        reduced_matrix = result[1][-1]
        assert isinstance(reduced_matrix, matrix_type)
示例#10
0
def test_sympyissue_6559():
    """Substituting ``-x`` inside ``-12*x``, plus the cse case from the issue."""
    replaced = (-12*x + y).subs({-x: 1})
    assert replaced == 12 + y
    # though this involves cse it generated a failure in Mul._eval_subs
    x0, x1 = symbols('x0 x1')
    expr = -log(-12*sqrt(2) + 17)/24 - log(-2*sqrt(2) + 3)/12 + sqrt(2)/3
    result = cse(expr)
    # XXX modify cse so x1 is eliminated and x0 = -sqrt(2)?
    expected_replacements = [(x0, sqrt(2))]
    expected_reduced = [x0/3 - log(-12*x0 + 17)/24 - log(-2*x0 + 3)/12]
    assert result == (expected_replacements, expected_reduced)
示例#11
0
def test_subtraction_opt():
    """cse with the sub_pre/sub_post optimization pair on expressions built from differences."""
    def optimized(target):
        # Build the optimizations list anew for every call.
        return cse(target,
                   optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])

    # Make sure subtraction is optimized.
    expr = (x - y)*(z - y) + exp((x - y)*(z - y))
    replacements, reduced_exprs = optimized([expr])
    assert replacements == [(x0, (x - y)*(y - z))]
    assert reduced_exprs == [-x0 + exp(-x0)]

    expr = -(x - y)*(z - y) + exp(-(x - y)*(z - y))
    replacements, reduced_exprs = optimized([expr])
    assert replacements == [(x0, (x - y)*(y - z))]
    assert reduced_exprs == [x0 + exp(x0)]

    # issue sympy/sympy#4077
    n = -1 + 1/x
    expr = n/x/(-n)**2 - 1/n/x
    assert optimized(expr) == ([], [0])
示例#12
0
def test_cse_MatrixSymbol():
    """cse finds at least one replacement in two MatrixSymbol products."""
    A = MatrixSymbol('A', 3, 3)
    y = MatrixSymbol('y', 3, 1)

    expr1 = (A.T*A).inverse() * A * y
    expr2 = (A.T*A) * A * y
    # Only the replacements are checked; the reduced expressions are
    # discarded instead of being bound to an unused local.
    replacements, _ = cse([expr1, expr2])
    assert len(replacements) > 0
示例#13
0
def test_cse_Indexed():
    """cse finds at least one replacement in two expressions over Indexed objects."""
    len_y = 5
    y = IndexedBase('y', shape=(len_y,))
    x = IndexedBase('x', shape=(len_y,))
    i = Idx('i', len_y-1)

    expr1 = (y[i+1]-y[i])/(x[i+1]-x[i])
    expr2 = 1/(x[i+1]-x[i])
    # Only the replacements are checked; the reduced expressions are
    # discarded instead of being bound to an unused local.
    replacements, _ = cse([expr1, expr2])
    assert len(replacements) > 0
示例#14
0
def test_pow_invpow():
    """cse results for expressions that mix a power with its reciprocal."""
    # Each entry pairs an input expression with the expected cse output.
    cases = [
        (1/x**2 + x**2,
         ([(x0, x**2)], [x0 + 1/x0])),
        (x**2 + (1 + 1/x**2)/x**2,
         ([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)])),
        (1/x**2 + (1 + 1/x**2)*x**2,
         ([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1])),
        (cos(1/x**2) + sin(1/x**2),
         ([(x0, x**(-2))], [sin(x0) + cos(x0)])),
        (cos(x**2) + sin(x**2),
         ([(x0, x**2)], [sin(x0) + cos(x0)])),
        (y/(2 + x**2) + z/x**2/y,
         ([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)])),
        (exp(x**2) + x**2*cos(1/x**2),
         ([(x0, x**2)], [x0*cos(1/x0) + exp(x0)])),
        ((1 + 1/x**2)/x**2,
         ([(x0, x**(-2))], [x0*(x0 + 1)])),
        (x**(2*y) + x**(-2*y),
         ([(x0, x**(2*y))], [x0 + 1/x0])),
    ]
    for expr, expected in cases:
        assert cse(expr) == expected
示例#15
0
def test_sympyissue_4499():
    """cse on a 16-element Tuple yields the thirteen replacements listed below."""
    # previously, this gave 16 constants
    B = Function('B')
    G = Function('G')
    t = Tuple(
        a,
        a + Rational(1, 2),
        2*a,
        b,
        2*a - b + 1,
        (sqrt(z)/2)**(-2*a + 1)*B(2*a - b, sqrt(z))*B(b - 1, sqrt(z))
        * G(b)*G(2*a - b + 1),
        sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b, sqrt(z))
        * G(b)*G(2*a - b + 1),
        sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1, sqrt(z))
        * B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1),
        (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1, sqrt(z))
        * G(b)*G(2*a - b + 1),
        1,
        0,
        Rational(1, 2),
        z/2,
        -b + 1,
        -2*a + b,
        -2*a,
    )
    c = cse(t)
    expected_replacements = [
        (x0, 2*a),
        (x1, -b),
        (x2, x1 + 1),
        (x3, x0 + x2),
        (x4, sqrt(z)),
        (x5, B(x0 + x1, x4)),
        (x6, G(b)),
        (x7, G(x3)),
        (x8, -x0),
        (x9, (x4/2)**(x8 + 1)),
        (x10, x6*x7*x9*B(b - 1, x4)),
        (x11, x6*x7*x9*B(b, x4)),
        (x12, B(x3, x4)),
    ]
    # The reduced list holds a single plain tuple mirroring the input Tuple.
    expected_reduced = [(
        a, a + Rational(1, 2), x0, b, x3,
        x10*x5, x11*x4*x5, x10*x12*x4, x11*x12,
        1, 0, Rational(1, 2), z/2, x2, b + x8, x8,
    )]
    ans = (expected_replacements, expected_reduced)
    assert ans == c
示例#16
0
def test_postprocess():
    """cse with ``cse_main.cse_separate`` as postprocess: ``Eq(x, z + 1)`` appears as the pair ``(x, x2)``."""
    eq = (x + 1 + exp((x + 1) / (y + 1)) + cos(y + 1))
    result = cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)],
                 postprocess=cse_main.cse_separate)
    expected_replacements = [(x1, y + 1), (x2, z + 1), (x, x2), (x0, x + 1)]
    expected_reduced = [x0 + exp(x0/x1) + cos(x1), z - 2, x0*x2]
    # The postprocessed result is compared as a list of two lists.
    assert result == [expected_replacements, expected_reduced]
示例#17
0
def test_sympyissue_4203():
    """``x**x`` occurring twice is replaced by ``x0``."""
    result = cse(sin(x**x) / x**x)
    expected = ([(x0, x**x)], [sin(x0) / x0])
    assert result == expected
示例#18
0
def test_sympyissue_4498():
    """With 'basic' optimizations the two fractions combine and nothing is extracted."""
    expr = w/(x - y) + z/(y - x)
    result = cse(expr, optimizations='basic')
    assert result == ([], [(w - z)/(x - y)])
示例#19
0
def test_non_commutative_cse_mul():
    """The leading non-commutative product ``A*B`` is expected to be extracted as ``x0``."""
    A, B, C = symbols('A B C', commutative=False)
    nc_x0 = symbols('x0', commutative=False)
    exprs = [A*B*C, A*B]
    result = cse(exprs)
    assert result == ([(nc_x0, A*B)], [nc_x0*C, nc_x0])
示例#20
0
def test_multiple_expressions():
    """Subexpressions shared across several inputs, in forward and reversed input order."""
    replacements, reduced_exprs = cse([(x + y)*z, (x + y)*w])
    assert replacements == [(x0, x + y)]
    assert reduced_exprs == [x0*z, x0*w]

    exprs = [w*x*y + z, w*y]
    replacements, reduced_exprs = cse(exprs)
    reversed_replacements, _ = cse(reversed(exprs))
    assert replacements == reversed_replacements
    assert reduced_exprs == [z + x*x0, x0]

    exprs = [w*x*y, w*x*y + z, w*y]
    replacements, reduced_exprs = cse(exprs)
    reversed_replacements, _ = cse(reversed(exprs))
    assert replacements == reversed_replacements
    assert reduced_exprs == [x1, x1 + z, x0]

    # Here the reversed input is expected to swap the last two replacements.
    exprs = [(x - z)*(y - z), x - z, y - z]
    replacements, reduced_exprs = cse(exprs)
    reversed_replacements, _ = cse(reversed(exprs))
    assert replacements == [(x0, -z), (x1, x + x0), (x2, x0 + y)]
    assert reversed_replacements == [(x0, -z), (x1, x0 + y), (x2, x + x0)]
    assert reduced_exprs == [x1*x2, x1, x2]

    exprs = [w*y + w + x + y + z, w*x*y]
    assert cse(exprs) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0])

    assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0])
    assert cse([x + y, x + z]) == ([], [x + y, x + z])
    assert cse([x*y, z + x*y, x*y*z + 3]) == (
        [(x0, x*y)], [x0, z + x0, 3 + x0*z])
示例#21
0
def test_multiple_expressions():
    """cse over lists of expressions, also checking the reversed input order."""
    def both_orders(items):
        # Run cse on the list, then on its reversed iterator.
        forward_subs, forward_reduced = cse(items)
        backward_subs, _ = cse(reversed(items))
        return forward_subs, forward_reduced, backward_subs

    first = (x + y)*z
    second = (x + y)*w
    subs_out, reduced_out = cse([first, second])
    assert subs_out == [(x0, x + y)]
    assert reduced_out == [x0*z, x0*w]

    subs_out, reduced_out, rev_subs = both_orders([w*x*y + z, w*y])
    assert subs_out == rev_subs
    assert reduced_out == [z + x*x0, x0]

    subs_out, reduced_out, rev_subs = both_orders([w*x*y, w*x*y + z, w*y])
    assert subs_out == rev_subs
    assert reduced_out == [x1, x1 + z, x0]

    subs_out, reduced_out, rev_subs = both_orders(
        [(x - z)*(y - z), x - z, y - z])
    assert subs_out == [(x0, -z), (x1, x + x0), (x2, x0 + y)]
    assert rev_subs == [(x0, -z), (x1, x0 + y), (x2, x + x0)]
    assert reduced_out == [x1*x2, x1, x2]

    mixed = [w*y + w + x + y + z, w*x*y]
    assert cse(mixed) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0])
    assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0])
    assert cse([x + y, x + z]) == ([], [x + y, x + z])
    expected = ([(x0, x*y)], [x0, z + x0, 3 + x0*z])
    assert cse([x*y, z + x*y, x*y*z + 3]) == expected
示例#22
0
def test_ignore_order_terms():
    """An expression containing an O(x**3) term comes back with no replacements."""
    truncated = exp(x).series(x, 0, 3)
    eq = truncated + sin(y + x**3) - 1
    result = cse(eq)
    assert result == ([], [sin(x**3 + y) + x + x**2/2 + O(x**3)])
示例#23
0
def test_Piecewise():
    """``-z`` and ``x*y`` shared by both Piecewise branches are extracted."""
    piecewise = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True))
    result = cse(piecewise)
    expected_replacements = [(x0, -z), (x1, x*y)]
    expected_reduced = [Piecewise((x0+x1, Eq(y, 0)), (x0 - x1, True))]
    assert result == (expected_replacements, expected_reduced)
示例#24
0
def test_sympyissue_6169():
    """A RootOf instance passes through cse unchanged; sub_pre/sub_post round-trip check."""
    root = RootOf(x**6 - 4*x**5 - 2, 1)
    result = cse(root)
    assert result == ([], [root])
    # and a check that the right thing is done with the new
    # mechanism
    preprocessed = sub_pre((-x - y)*z - x - y)
    assert sub_post(preprocessed) == -z*(x + y) - x - y
示例#25
0
def test_postprocess():
    """Check the output of cse when ``cse_main.cse_separate`` is the postprocess hook."""
    eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
    inputs = [eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)]
    outcome = cse(inputs, postprocess=cse_main.cse_separate)
    # Expected shape: [replacements, reduced expressions], both lists.
    wanted = [
        [(x1, y + 1), (x2, z + 1), (x, x2), (x0, x + 1)],
        [x0 + exp(x0/x1) + cos(x1), z - 2, x0*x2],
    ]
    assert outcome == wanted
示例#26
0
def test_Piecewise():
    """cse on a two-branch Piecewise whose branches share ``-z`` and ``x*y``."""
    source = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True))
    outcome = cse(source)
    reduced_piecewise = Piecewise((x0 + x1, Eq(y, 0)), (x0 - x1, True))
    wanted = ([(x0, -z), (x1, x*y)], [reduced_piecewise])
    assert outcome == wanted
示例#27
0
def test_non_commutative_order():
    """A shared non-commutative sum ``B + C`` is replaced by ``x0``."""
    A, B, C = symbols('A B C', commutative=False)
    x0 = symbols('x0', commutative=False)
    exprs = [B+C, A*(B+C)]
    outcome = cse(exprs)
    wanted = ([(x0, B+C)], [x0, A*x0])
    assert outcome == wanted
示例#28
0
def test_sympyissue_6263():
    """With 'basic' optimizations the equation is expected to reduce to ``True``."""
    equation = Eq(x*(-x + 1) + x*(x - 1), 0)
    result = cse(equation, optimizations='basic')
    assert result == ([], [True])
示例#29
0
def test_cse_single():
    """One repeated subexpression gives exactly one replacement."""
    # Simple substitution: x + y appears under both the square and the root.
    source = Add(Pow(x + y, 2), sqrt(x + y))
    found, rewritten = cse([source])
    assert found == [(x0, x + y)]
    assert rewritten == [sqrt(x0) + x0**2]
示例#30
0
def test_name_conflict():
    """Inputs that already contain x0, x2, x3 are recovered by substituting back."""
    first = x0 + y
    second = x2 + x3
    exprs = [cos(first) + first, cos(second) + second, x0 + x2]
    replacements, reduced_exprs = cse(exprs)
    # Apply the replacements in reverse order to undo them.
    restored = [item.subs(reversed(replacements)) for item in reduced_exprs]
    assert restored == exprs
示例#31
0
def test_sympyissue_4020():
    """cse of a polynomial in x with 'basic' optimizations extracts ``x**2``."""
    polynomial = x**5 + x**4 + x**3 + x**2
    result = cse(polynomial, optimizations='basic')
    assert result == ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)])
示例#32
0
def test_symbols_exhausted_error():
    """cse is expected to raise ValueError for this expression and symbol list."""
    expr = cos(x+y)+x+y+cos(w+y)+sin(w+y)
    sym = [x, y, z]
    # The raised exception is never inspected, so it is not bound to a name.
    with pytest.raises(ValueError):
        cse(expr, symbols=sym)
示例#33
0
def test_name_conflict_cust_symbols():
    """Same round trip as a plain cse call, but with caller-supplied symbols."""
    first = x0 + y
    second = x2 + x3
    exprs = [cos(first) + first, cos(second) + second, x0 + x2]
    replacements, reduced_exprs = cse(exprs, symbols("x:10"))
    # Substituting the replacements back in reverse order restores the input.
    restored = [item.subs(reversed(replacements)) for item in reduced_exprs]
    assert restored == exprs
示例#34
0
def test_nested_substitution():
    """``w*x + y`` is extracted whole, as one replacement."""
    # Substitution within a substitution.
    expr = Add(Pow(w*x + y, 2), sqrt(w*x + y))
    replacements, reduced_exprs = cse([expr])
    assert replacements == [(x0, w*x + y)]
    assert reduced_exprs == [sqrt(x0) + x0**2]
示例#35
0
def test_symbols_exhausted_error():
    """cse must raise ValueError when handed ``[x, y, z]`` as its symbols here."""
    expr = cos(x+y)+x+y+cos(w+y)+sin(w+y)
    available = [x, y, z]
    with pytest.raises(ValueError):
        cse(expr, symbols=available)
示例#36
0
def test_non_commutative_cse():
    """Two non-commutative products are returned with no replacements."""
    A, B, C = symbols('A B C', commutative=False)
    exprs = [A*B*C, A*C]
    result = cse(exprs)
    assert result == ([], exprs)
示例#37
0
def test_name_conflict():
    """cse output, substituted back, equals inputs that already use x0, x2 and x3."""
    sum_a, sum_b = x0 + y, x2 + x3
    originals = [cos(sum_a) + sum_a, cos(sum_b) + sum_b, x0 + x2]
    pairs, rewritten = cse(originals)
    recovered = []
    for term in rewritten:
        # Reverse order undoes later replacements before earlier ones.
        recovered.append(term.subs(reversed(pairs)))
    assert recovered == originals
示例#38
0
def test_powers():
    """``x*y`` is extracted from ``x*y**2 + x*y``."""
    result = cse(x*y**2 + x*y)
    expected = ([(x0, x*y)], [x0*y + x0])
    assert result == expected
示例#39
0
def test_sympyissue_4498():
    """cse with 'basic' optimizations on two fractions with opposite denominators."""
    fractions = w/(x - y) + z/(y - x)
    outcome = cse(fractions, optimizations='basic')
    wanted = ([], [(w - z)/(x - y)])
    assert outcome == wanted
示例#40
0
def test_sympyissue_4020():
    """``x**2`` is the one replacement for this polynomial under 'basic' optimizations."""
    outcome = cse(x**5 + x**4 + x**3 + x**2, optimizations='basic')
    wanted_replacements = [(x0, x**2)]
    wanted_reduced = [x0*(x**3 + x + x0 + 1)]
    assert outcome == (wanted_replacements, wanted_reduced)
示例#41
0
def test_powers():
    """cse on ``x*y**2 + x*y`` gives ``x0 = x*y``."""
    expr = x*y**2 + x*y
    replacements_and_reduced = cse(expr)
    assert replacements_and_reduced == ([(x0, x*y)], [x0*y + x0])
示例#42
0
def test_sympyissue_6263():
    """The equation is expected to come back as ``True`` with no replacements."""
    lhs = x*(-x + 1) + x*(x - 1)
    outcome = cse(Eq(lhs, 0), optimizations='basic')
    wanted = ([], [True])
    assert outcome == wanted
示例#43
0
def test_non_commutative_cse_mul():
    """cse on ``[A*B*C, A*B]`` with non-commutative symbols extracts ``A*B``."""
    x0 = symbols('x0', commutative=False)
    A, B, C = symbols('A B C', commutative=False)
    outcome = cse([A*B*C, A*B])
    wanted_replacements = [(x0, A*B)]
    wanted_reduced = [x0*C, x0]
    assert outcome == (wanted_replacements, wanted_reduced)
示例#44
0
def test_sympyissue_6169():
    """cse leaves a RootOf untouched; sub_post(sub_pre(...)) gives the expected form."""
    poly_root = RootOf(x**6 - 4*x**5 - 2, 1)
    assert cse(poly_root) == ([], [poly_root])
    # and a check that the right thing is done with the new
    # mechanism
    round_trip = sub_post(sub_pre((-x - y)*z - x - y))
    assert round_trip == -z*(x + y) - x - y
示例#45
0
def test_nested_substitution():
    """cse replaces the repeated ``w*x + y`` by ``x0``."""
    # Substitution within a substitution.
    source = Add(Pow(w*x + y, 2), sqrt(w*x + y))
    found, rewritten = cse([source])
    assert found == [(x0, w*x + y)]
    assert rewritten == [sqrt(x0) + x0**2]
示例#46
0
def test_ignore_order_terms():
    """cse returns no replacements for a truncated series plus a sine term."""
    eq = exp(x).series(x, 0, 3) + sin(y + x**3) - 1
    outcome = cse(eq)
    wanted_reduced = [sin(x**3 + y) + x + x**2/2 + O(x**3)]
    assert outcome == ([], wanted_reduced)
示例#47
0
def test_sympyissue_4203():
    """cse on ``sin(x**x)/x**x`` extracts ``x**x``."""
    quotient = sin(x**x)/x**x
    outcome = cse(quotient)
    assert outcome == ([(x0, x**x)], [sin(x0)/x0])
示例#48
0
def test_name_conflict_cust_symbols():
    """Round trip through cse when the replacement symbols are passed in explicitly."""
    sum_a, sum_b = x0 + y, x2 + x3
    originals = [cos(sum_a) + sum_a, cos(sum_b) + sum_b, x0 + x2]
    pairs, rewritten = cse(originals, symbols('x:10'))
    recovered = []
    for term in rewritten:
        # Reverse order undoes later replacements before earlier ones.
        recovered.append(term.subs(reversed(pairs)))
    assert recovered == originals
示例#49
0
def test_basic_optimization():
    """'basic' optimizations merge ``w/(x - y)`` and ``z/(y - x)`` into one fraction."""
    # issue sympy/sympy#4498
    expr = w/(x - y) + z/(y - x)
    result = cse(expr, optimizations='basic')
    expected = ([], [(w - z)/(x - y)])
    assert result == expected