Beispiel #1
0
    def run(expr):
        """
        Recursively rebuild ``expr`` while gathering factorization candidates.

        Returns a pair ``(rebuilt expression, factorization candidates)``:

        * Numbers and Symbols are returned unchanged and are their own
          (single) candidate.
        * Indexed objects and any other atom are returned unchanged with no
          candidates.
        * An Add is split into the terms having at least one Number among
          their arguments and the remaining terms; the former go through
          ``collect_const``, the latter are optionally passed to ``collect``
          once per candidate. No candidates are propagated upwards.
        * A Mul is rebuilt and passed through ``collect_const``.
        * An Equality is rebuilt side by side with ``evaluate=False``.
        * Anything else is rebuilt from its processed arguments.

        ``aggressive``, ``collect``, ``collect_const`` and ``flatten`` are
        resolved from outside this function (not visible in this excerpt).
        """
        # Leaves of the expression tree
        if expr.is_Number or expr.is_Symbol:
            return expr, [expr]
        if expr.is_Indexed or expr.is_Atom:
            return expr, []

        if expr.is_Add:
            terms, found = zip(*(run(arg) for arg in expr.args))

            # Terms carrying an explicit numeric factor vs. all the others
            numeric = []
            for term in terms:
                if any(factor.is_Number for factor in term.args):
                    numeric.append(term)
            symbolic = [term for term in terms if term not in numeric]

            numeric_part = collect_const(expr.func(*numeric))
            symbolic_part = expr.func(*symbolic)

            # Only in aggressive mode, and only if something is left to
            # collect, try every candidate coming from the sub-expressions
            if aggressive is True and symbolic_part:
                for candidate in flatten(found):
                    symbolic_part = collect(symbolic_part, candidate)

            return expr.func(numeric_part, symbolic_part), []

        if expr.is_Mul:
            factors, found = zip(*(run(arg) for arg in expr.args))
            product = collect_const(expr.func(*factors))
            return product, flatten(found)

        if expr.is_Equality:
            sides, found = zip(run(expr.lhs), run(expr.rhs))
            return expr.func(*sides, evaluate=False), flatten(found)

        # Generic node: rebuild from the processed children
        children, found = zip(*(run(arg) for arg in expr.args))
        return expr.func(*children), flatten(found)
Beispiel #2
0
    def run(expr):
        """
        Walk ``expr`` bottom-up and return ``(rebuilt, candidates)``.

        ``rebuilt`` is the reconstructed expression; ``candidates`` is the
        list of factorization candidates found below ``expr``. Numbers and
        Symbols are their own candidate; Indexed objects and other atoms
        yield none; an Add consumes the candidates of its terms and returns
        an empty list.

        ``aggressive``, ``collect``, ``collect_const`` and ``flatten`` come
        from outside this function (not visible in this excerpt).
        """

        def descend(nodes):
            # Process each child and transpose the (rebuilt, candidates) pairs
            return zip(*[run(node) for node in nodes])

        if expr.is_Number or expr.is_Symbol:
            return expr, [expr]
        if expr.is_Indexed or expr.is_Atom:
            return expr, []

        if expr.is_Add:
            addends, nested = descend(expr.args)

            # Addends that have a Number among their own arguments
            has_number = [
                addend for addend in addends
                if any(piece.is_Number for piece in addend.args)
            ]
            lacks_number = [
                addend for addend in addends if addend not in has_number
            ]

            collected = collect_const(expr.func(*has_number))
            leftover = expr.func(*lacks_number)

            if aggressive is True and leftover:
                # Try to factor out each candidate, one after the other
                for pivot in flatten(nested):
                    leftover = collect(leftover, pivot)

            return expr.func(collected, leftover), []

        if expr.is_Mul:
            operands, nested = descend(expr.args)
            result = collect_const(expr.func(*operands))
            return result, flatten(nested)

        if expr.is_Equality:
            # Keep the relation unevaluated after rebuilding both sides
            operands, nested = zip(run(expr.lhs), run(expr.rhs))
            return expr.func(*operands, evaluate=False), flatten(nested)

        operands, nested = descend(expr.args)
        return expr.func(*operands), flatten(nested)
Beispiel #3
0
def test_collect_const():
    """Check ``collect_const`` (and one ``collect_sqrt`` case) on sums sharing constant factors."""
    # coverage not provided by above tests
    radical_sum = 2*sqrt(3) + 4*a*sqrt(5)
    # let the primitive reabsorb
    assert collect_const(radical_sum) == 2*(2*sqrt(5)*a + sqrt(3))
    assert collect_const(radical_sum, sqrt(3)) == 2*sqrt(3) + 4*a*sqrt(5)
    mixed_sum = sqrt(2)*(1 + sqrt(2)) + sqrt(3) + x*sqrt(2)
    assert collect_const(mixed_sum) == sqrt(2)*(x + 1 + sqrt(2)) + sqrt(3)

    # issue 5290
    linear = 2*x + 2*y + 1
    explicit_two = collect_const(linear, 2)
    implicit = collect_const(linear)
    assert explicit_two == implicit
    assert implicit == Add(
        S(1), Mul(2, x + y, evaluate=False), evaluate=False)
    assert collect_const(-y - z) == Mul(-1, y + z, evaluate=False)
    difference = 2*x - 2*y - 2*z
    assert collect_const(difference, 2) == Mul(2, x - y - z, evaluate=False)
    assert collect_const(difference, -2) == _unevaluated_Add(
        2*x, Mul(-2, y + z, evaluate=False))

    # this is why the content_primitive is used
    nested_roots = (sqrt(15 + 5*sqrt(2))*x + sqrt(3 + sqrt(2))*y)*2
    assert collect_sqrt(nested_roots + 2) == (
        2*sqrt(sqrt(2) + 3)*(sqrt(5)*x + y) + 2)

    # issue 16296
    halves = a + b + x/2 + y/2
    assert collect_const(halves) == (
        a + b + Mul(S.Half, x + y, evaluate=False))
Beispiel #4
0
def test_collect_const():
    """Compare ``collect_const`` / ``collect_sqrt`` results against hand-written expected forms."""
    # coverage not provided by above tests
    actual = collect_const(2*sqrt(3) + 4*a*sqrt(5))
    expected = 2*(2*sqrt(5)*a + sqrt(3))  # let the primitive reabsorb
    assert actual == expected

    actual = collect_const(2*sqrt(3) + 4*a*sqrt(5), sqrt(3))
    expected = 2*sqrt(3) + 4*a*sqrt(5)
    assert actual == expected

    actual = collect_const(sqrt(2)*(1 + sqrt(2)) + sqrt(3) + x*sqrt(2))
    expected = sqrt(2)*(x + 1 + sqrt(2)) + sqrt(3)
    assert actual == expected

    # issue 5290
    restricted = collect_const(2*x + 2*y + 1, 2)
    unrestricted = collect_const(2*x + 2*y + 1)
    expected = Add(S(1), Mul(2, x + y, evaluate=False), evaluate=False)
    assert restricted == unrestricted == expected

    actual = collect_const(-y - z)
    assert actual == Mul(-1, y + z, evaluate=False)

    actual = collect_const(2*x - 2*y - 2*z, 2)
    assert actual == Mul(2, x - y - z, evaluate=False)

    actual = collect_const(2*x - 2*y - 2*z, -2)
    expected = _unevaluated_Add(2*x, Mul(-2, y + z, evaluate=False))
    assert actual == expected

    # this is why the content_primitive is used
    eq = (sqrt(15 + 5*sqrt(2))*x + sqrt(3 + sqrt(2))*y)*2
    actual = collect_sqrt(eq + 2)
    expected = 2*sqrt(sqrt(2) + 3)*(sqrt(5)*x + y) + 2
    assert actual == expected

    # issue 16296
    actual = collect_const(a + b + x/2 + y/2)
    expected = a + b + Mul(S.Half, x + y, evaluate=False)
    assert actual == expected
Beispiel #5
0
def test_radsimp():
    """Check ``radsimp`` on numeric and symbolic radical denominators, plus a few ``collect_const`` cases."""
    s2 = sqrt(2)
    s3 = sqrt(3)
    s5 = sqrt(5)
    s7 = sqrt(7)

    # Purely numeric denominators
    assert radsimp(1/s2) == sqrt(2)/2
    assert radsimp(1/(1 + s2)) == -1 + sqrt(2)
    assert radsimp(1/(s2 + s3)) == -sqrt(2) + sqrt(3)
    assert fraction(radsimp(1/(1 + s2 + s3))) == (-sqrt(6) + sqrt(2) + 2, 4)
    assert fraction(radsimp(1/(s2 + s3 + s5))) == (
        -sqrt(30) + 2*sqrt(3) + 3*sqrt(2), 12)
    expected_numer = (
        -34*sqrt(10) - 26*sqrt(15) - 55*sqrt(3) - 61*sqrt(2) +
        14*sqrt(30) + 93 + 46*sqrt(6) + 53*sqrt(5))
    assert fraction(radsimp(1/(1 + s2 + s3 + s5))) == (expected_numer, 71)
    expected_numer = (
        -50*sqrt(42) - 133*sqrt(5) - 34*sqrt(70) - 145*sqrt(3) +
        22*sqrt(105) + 185*sqrt(2) + 62*sqrt(30) + 135*sqrt(7))
    assert fraction(radsimp(1/(s2 + s3 + s5 + s7))) == (expected_numer, 215)

    # Five-term denominator: left alone by default, handled with max_terms=20
    five_terms = 1 + s2/3 + s3/5 + s5 + s7
    inverse = radsimp(1/five_terms)
    assert len((3616791619821680643598*inverse).args) == 16
    assert radsimp(1/inverse) == 1/inverse
    assert radsimp(1/inverse, max_terms=20).expand() == five_terms
    assert radsimp(1/(s2*3)) == sqrt(2)/6

    # Denominators with symbolic coefficients
    unchanged = 1/(s2*a + s3 + s5 + s7)
    assert radsimp(unchanged) == unchanged
    numer = (
        sqrt(42)*(a + b) +
        sqrt(3)*(-a**2 - 2*a*b - b**2 - 2) +
        sqrt(7)*(-a**2 - 2*a*b - b**2 + 2) +
        sqrt(2)*(a**3 + 3*a**2*b + 3*a*b**2 - 5*a + b**3 - 5*b))
    denom = (
        a**4 + 4*a**3*b + 6*a**2*b**2 - 10*a**2 +
        4*a*b**3 - 20*a*b + b**4 - 10*b**2 + 4)
    assert radsimp(1/(s2*a + s2*b + s3 + s7)) == (numer/denom)/2
    assert radsimp(1/(s2*a + s2*b + s2*c + s2*d)) == (
        (sqrt(2)/(a + b + c + d))/2)
    numer = sqrt(2)*(-a - b - c - d) + 1
    denom = (
        -2*a**2 - 4*a*b - 4*a*c - 4*a*d - 2*b**2 -
        4*b*c - 4*b*d - 2*c**2 - 4*c*d - 2*d**2 + 1)
    assert radsimp(1/(1 + s2*a + s2*b + s2*c + s2*d)) == numer/denom
    assert radsimp((y**2 - x)/(y - sqrt(x))) == sqrt(x) + y
    assert radsimp(-(y**2 - x)/(y - sqrt(x))) == -(sqrt(x) + y)
    assert radsimp(1/(1 - I + a*I)) == (I*(-a + 1) + 1)/(a**2 - 2*a + 2)
    assert radsimp(1/((-x + y)*(x - sqrt(y)))) == (
        (x + sqrt(y))/((-x + y)*(x**2 - y)))
    product = (3 + 3*sqrt(2))*x*(3*x - 3*sqrt(y))
    assert radsimp(product) == 9*x*(1 + sqrt(2))*(x - sqrt(y))
    assert radsimp(1/product) == (
        (-1 + sqrt(2))*(x + sqrt(y))/(9*x*(x**2 - y)))
    assert radsimp(1 + 1/(1 + sqrt(3))) == Mul(
        S(1)/2, 1 + sqrt(3), evaluate=False)

    # Non-commutative factor
    noncomm = symbols("A", commutative=False)
    assert radsimp(x**2 + sqrt(2)*x**2 - sqrt(2)*x*noncomm) == (
        x**2 + sqrt(2)*(x**2 - x*noncomm))

    # Nested radicals
    assert radsimp(1/sqrt(5 + 2*sqrt(6))) == -sqrt(2) + sqrt(3)
    assert radsimp(1/sqrt(5 + 2*sqrt(6))**3) == -11*sqrt(2) + 9*sqrt(3)

    # coverage not provided by above tests
    radical_sum = 2*sqrt(3) + 4*a*sqrt(5)
    assert collect_const(radical_sum) == Mul(
        2, (2*sqrt(5)*a + sqrt(3)), evaluate=False)
    assert collect_const(radical_sum, sqrt(3)) == 2*(2*sqrt(5)*a + sqrt(3))
    assert collect_const(sqrt(2)*(1 + sqrt(2)) + sqrt(3) + x*sqrt(2)) == (
        sqrt(2)*(x + 1 + sqrt(2)) + sqrt(3))
Beispiel #6
0
    def run(expr):
        """
        Recursively rebuild ``expr`` while gathering factorization candidates.

        Returns a pair ``(rebuilt expression, factorization candidates)``.
        The candidates of a leaf are a plain dict keyed by kind
        (``'coeffs'`` for a Number, ``'funcs'`` for a Function, ``'pows'``
        for a Pow, empty for Symbols, Indexed objects and other atoms).
        Inner nodes merge their children's candidates through
        ``ReducerMap.fromdicts``; an Add consumes them and returns ``{}``.

        ``ReducerMap``, ``collect`` and ``collect_const`` are resolved from
        outside this function (not visible in this excerpt).
        """
        # Leaves: tag the expression with the kind of candidate it provides
        if expr.is_Number:
            return expr, {'coeffs': expr}
        if expr.is_Function:
            return expr, {'funcs': expr}
        if expr.is_Pow:
            return expr, {'pows': expr}
        if expr.is_Symbol or expr.is_Indexed or expr.is_Atom:
            return expr, {}

        if expr.is_Add:
            terms, found = zip(*(run(arg) for arg in expr.args))
            found = ReducerMap.fromdicts(*found)

            funcs = found.getall('funcs', [])
            pows = found.getall('pows', [])
            coeffs = found.getall('coeffs', [])

            def touching(pool, pivots):
                # Terms of `pool` having one of `pivots` among their arguments
                return {
                    term for term in pool
                    if any(piece in pivots for piece in term.args)
                }

            def gather(group, pivots):
                # Group `group` by `pivots`, then collect the constants
                # within each group
                grouped = collect(expr.func(*group), pivots, evaluate=False)
                try:
                    return Add(*[
                        Mul(key, collect_const(value))
                        for key, value in grouped.items()
                    ])
                except AttributeError:
                    # No `.items()` to iterate: `collect` is expected to
                    # have produced zero in this case
                    assert grouped == 0
                    return grouped

            # Functions/Pows are collected first, coefficients afterwards
            # Note: below we use sets, but SymPy will ensure determinism
            remaining = set(terms)
            w_funcs = touching(remaining, funcs)
            remaining -= w_funcs
            w_pows = touching(remaining, pows)
            remaining -= w_pows
            w_coeffs = touching(remaining, coeffs)
            remaining -= w_coeffs

            by_func = gather(w_funcs, funcs)
            by_pow = gather(w_pows, pows)
            by_coeff = collect_const(expr.func(*w_coeffs))

            return Add(by_func, by_pow, by_coeff, *remaining), {}

        if expr.is_Mul:
            factors, found = zip(*(run(arg) for arg in expr.args))

            # Always collect coefficients
            product = collect_const(expr.func(*factors))
            try:
                if product.args:
                    # Note: Mul(*()) -> 1, and since sympy.S.Zero.args == (),
                    # the `if` prevents turning 0 into 1
                    product = Mul(*product.args)
            except AttributeError:
                pass

            return product, ReducerMap.fromdicts(*found)

        if expr.is_Equality:
            sides, found = zip(run(expr.lhs), run(expr.rhs))
            relation = expr.func(*sides, evaluate=False)
            return relation, ReducerMap.fromdicts(*found)

        # Generic node: rebuild from the processed children
        children, found = zip(*(run(arg) for arg in expr.args))
        return expr.func(*children), ReducerMap.fromdicts(*found)
Beispiel #7
0
def test_radsimp():
    """Verify ``radsimp`` against known rationalized forms, then a few ``collect_const`` results."""
    root2 = sqrt(2)
    root3 = sqrt(3)
    root5 = sqrt(5)
    root7 = sqrt(7)

    # Numeric denominators with one to four radical terms
    result = radsimp(1/root2)
    assert result == sqrt(2)/2
    result = radsimp(1/(1 + root2))
    assert result == -1 + sqrt(2)
    result = radsimp(1/(root2 + root3))
    assert result == -sqrt(2) + sqrt(3)
    result = fraction(radsimp(1/(1 + root2 + root3)))
    assert result == (-sqrt(6) + sqrt(2) + 2, 4)
    result = fraction(radsimp(1/(root2 + root3 + root5)))
    assert result == (-sqrt(30) + 2*sqrt(3) + 3*sqrt(2), 12)
    result = fraction(radsimp(1/(1 + root2 + root3 + root5)))
    assert result == (
        -34*sqrt(10) - 26*sqrt(15) - 55*sqrt(3) - 61*sqrt(2) +
        14*sqrt(30) + 93 + 46*sqrt(6) + 53*sqrt(5),
        71)
    result = fraction(radsimp(1/(root2 + root3 + root5 + root7)))
    assert result == (
        -50*sqrt(42) - 133*sqrt(5) - 34*sqrt(70) - 145*sqrt(3) +
        22*sqrt(105) + 185*sqrt(2) + 62*sqrt(30) + 135*sqrt(7),
        215)
    result = radsimp(1/(root2*3))
    assert result == sqrt(2)/6

    # Denominators with symbolic coefficients
    result = radsimp(1/(root2*a + root2*b + root3 + root7))
    top = (
        sqrt(42)*(a + b) +
        sqrt(3)*(-a**2 - 2*a*b - b**2 - 2) +
        sqrt(7)*(-a**2 - 2*a*b - b**2 + 2) +
        sqrt(2)*(a**3 + 3*a**2*b + 3*a*b**2 - 5*a + b**3 - 5*b))
    bottom = (
        a**4 + 4*a**3*b + 6*a**2*b**2 - 10*a**2 +
        4*a*b**3 - 20*a*b + b**4 - 10*b**2 + 4)
    assert result == (top/bottom)/2
    result = radsimp(1/(root2*a + root2*b + root2*c + root2*d))
    assert result == (sqrt(2)/(a + b + c + d))/2
    result = radsimp(1/(1 + root2*a + root2*b + root2*c + root2*d))
    top = sqrt(2)*(-a - b - c - d) + 1
    bottom = (
        -2*a**2 - 4*a*b - 4*a*c - 4*a*d - 2*b**2 -
        4*b*c - 4*b*d - 2*c**2 - 4*c*d - 2*d**2 + 1)
    assert result == top/bottom
    result = radsimp((y**2 - x)/(y - sqrt(x)))
    assert result == sqrt(x) + y
    result = radsimp(-(y**2 - x)/(y - sqrt(x)))
    assert result == -(sqrt(x) + y)
    result = radsimp(1/(1 - I + a*I))
    assert result == (I*(-a + 1) + 1)/(a**2 - 2*a + 2)
    result = radsimp(1/((-x + y)*(x - sqrt(y))))
    assert result == (x + sqrt(y))/((-x + y)*(x**2 - y))
    factored = (3 + 3*sqrt(2))*x*(3*x - 3*sqrt(y))
    assert radsimp(factored) == 9*x*(1 + sqrt(2))*(x - sqrt(y))
    result = radsimp(1/factored)
    assert result == (-1 + sqrt(2))*(x + sqrt(y))/(9*x*(x**2 - y))
    result = radsimp(1 + 1/(1 + sqrt(3)))
    assert result == Mul(S(1)/2, 1 + sqrt(3), evaluate=False)

    # Non-commutative factor
    nc = symbols("A", commutative=False)
    result = radsimp(x**2 + sqrt(2)*x**2 - sqrt(2)*x*nc)
    assert result == x**2 + sqrt(2)*(x**2 - x*nc)

    # Nested radicals
    result = radsimp(1/sqrt(5 + 2*sqrt(6)))
    assert result == -sqrt(2) + sqrt(3)
    result = radsimp(1/sqrt(5 + 2*sqrt(6))**3)
    assert result == -11*sqrt(2) + 9*sqrt(3)

    # coverage not provided by above tests
    result = collect_const(2*sqrt(3) + 4*a*sqrt(5))
    assert result == Mul(2, (2*sqrt(5)*a + sqrt(3)), evaluate=False)
    result = collect_const(2*sqrt(3) + 4*a*sqrt(5), sqrt(3))
    assert result == 2*(2*sqrt(5)*a + sqrt(3))
    result = collect_const(sqrt(2)*(1 + sqrt(2)) + sqrt(3) + x*sqrt(2))
    assert result == sqrt(2)*(x + 1 + sqrt(2)) + sqrt(3)
Beispiel #8
0
def wtf():
    """Print what sympy's ``collect_const`` returns for ``3*a + 3*b`` built from two ``VectorSymbol`` objects."""
    from sympy import collect_const

    # VectorSymbol is defined outside this excerpt
    first, second = VectorSymbol("a"), VectorSymbol("b")
    print(collect_const(3 * first + 3 * second))