Exemple #1
0
def test_Factors():
    """The ``repr`` of a Factors built from a monomial shows base: exponent pairs."""
    expected = 'Factors({x: 1, y: 2})'
    assert repr(Factors(x * y**2)) == expected
Exemple #2
0
def test_Term():
    """Check construction, arithmetic and error handling of ``Term``.

    A ``Term`` is compared against an explicit ``(coeff, numer, denom)``
    triple, and each arithmetic method (``mul``, ``quo``, ``pow``) must agree
    with its operator form.
    """
    a = Term(4*x*y**2/z/t**3)
    b = Term(2*x**3*y**5/t**3)

    # construction splits coefficient, numerator and denominator
    assert a == Term(4, Factors({x: 1, y: 2}), Factors({z: 1, t: 3}))
    assert b == Term(2, Factors({x: 3, y: 5}), Factors({t: 3}))

    assert a.as_expr() == 4*x*y**2/z/t**3
    assert b.as_expr() == 2*x**3*y**5/t**3

    # inversion swaps numerator and denominator
    assert a.inv() == \
        Term(Rational(1, 4), Factors({z: 1, t: 3}), Factors({x: 1, y: 2}))
    assert b.inv() == \
        Term(Rational(1, 2), Factors({t: 3}), Factors({x: 3, y: 5}))

    # method forms must agree with the operator forms
    assert a.mul(b) == a*b == \
        Term(8, Factors({x: 4, y: 7}), Factors({z: 1, t: 6}))
    assert a.quo(b) == a/b == \
        Term(2, Factors({}), Factors({x: 2, y: 3, z: 1}))

    assert a.pow(3) == a**3 == \
        Term(64, Factors({x: 3, y: 6}), Factors({z: 3, t: 9}))
    assert b.pow(3) == b**3 == Term(8, Factors({x: 9, y: 15}), Factors({t: 9}))

    assert a.pow(-3) == a**(-3) == \
        Term(Rational(1, 64), Factors({z: 3, t: 9}), Factors({x: 3, y: 6}))
    assert b.pow(-3) == b**(-3) == \
        Term(Rational(1, 8), Factors({t: 9}), Factors({x: 9, y: 15}))

    assert a.gcd(b) == Term(2, Factors({x: 1, y: 2}), Factors({t: 3}))
    assert a.lcm(b) == Term(4, Factors({x: 3, y: 5}), Factors({z: 1, t: 3}))

    # t**-3 * t**7 leaves t**4 in the numerator
    a = Term(4*x*y**2/z/t**3)
    b = Term(2*x**3*y**5*t**7)

    assert a.mul(b) == Term(8, Factors({x: 4, y: 7, t: 4}), Factors({z: 1}))

    # integer content of polynomial bases moves into the coefficient
    assert Term((2*x + 2)**3) == Term(8, Factors({x + 1: 3}), Factors({}))
    assert Term((2*x + 2)*(3*x + 6)**2) == \
        Term(18, Factors({x + 1: 1, x + 2: 2}), Factors({}))

    # non-commutative expressions are rejected
    A = Symbol('A', commutative=False)
    with pytest.raises(NonCommutativeExpression):
        Term(A)

    # an omitted numer/denom behaves like an empty Factors
    f1, f2 = Factors({x: 2}), Factors()
    assert Term(2, numer=f1) == Term(2, f1, f2)
    assert Term(2, denom=f1) == Term(2, f2, f1)

    # mixing with a plain int, or raising to a float power, is a TypeError
    with pytest.raises(TypeError):
        a * 2
    with pytest.raises(TypeError):
        a / 3
    with pytest.raises(TypeError):
        a**3.1
Exemple #3
0
def test_Factors():
    """Check construction, arithmetic and normalization of ``Factors``."""
    # trivial factorizations and conversion back to expressions
    assert Factors() == Factors({}) == Factors(Integer(1))
    assert Factors(Integer(1)) == Factors(Factors(Integer(1)))
    assert Factors().as_expr() == 1
    assert Factors({x: 2, y: 3, sin(x): 4}).as_expr() == \
        x**2*y**3*sin(x)**4
    assert Factors(+oo) == Factors({oo: 1})
    assert Factors(-oo) == Factors({oo: 1, -1: 1})

    # equal Factors must hash equally
    f1 = Factors({oo: 1})
    f2 = Factors({oo: 1})
    assert hash(f1) == hash(f2)

    a = Factors({x: 5, y: 3, z: 7})
    b = Factors({y: 4, z: 3, t: 10})

    # method forms must agree with the operator forms
    assert a.mul(b) == a*b == Factors({x: 5, y: 7, z: 10, t: 10})

    assert a.div(b) == divmod(a, b) == \
        (Factors({x: 5, z: 4}), Factors({y: 1, t: 10}))
    assert a.quo(b) == a/b == Factors({x: 5, z: 4})
    assert a.rem(b) == a % b == Factors({y: 1, t: 10})

    assert a.pow(3) == a**3 == Factors({x: 15, y: 9, z: 21})
    assert b.pow(3) == b**3 == Factors({y: 12, z: 9, t: 30})

    # float exponents raise ValueError
    with pytest.raises(ValueError):
        a.pow(3.1)
    with pytest.raises(ValueError):
        a.pow(Factors(3.1))

    assert a.pow(0) == Factors()

    assert a.gcd(b) == Factors({y: 3, z: 3})
    assert a.lcm(b) == a.lcm(b.as_expr()) == \
        Factors({x: 5, y: 4, z: 7, t: 10})

    a = Factors({x: 4, y: 7, t: 7})
    b = Factors({z: 1, t: 3})

    assert a.normal(b) == (Factors({x: 4, y: 7, t: 4}), Factors({z: 1}))

    assert Factors(sqrt(2)*x).as_expr() == sqrt(2)*x

    # powers of I and -1 combine
    assert Factors(-I)*I == Factors()
    assert Factors({Integer(-1): Integer(3)}) * \
        Factors({Integer(-1): Integer(1), I: Integer(5)}) == Factors(I)

    # symbolic exponents on numeric bases
    assert Factors(Integer(2)**x).div(Integer(3)**x) == \
        (Factors({Integer(2): x}), Factors({Integer(3): x}))
    assert Factors(2**(2*x + 2)).div(Integer(8)) == \
        (Factors({Integer(2): 2*x + 2}), Factors({Integer(8): Integer(1)}))

    # coverage
    # /!\ things break if this is not True
    assert Factors({Integer(-1): Rational(3, 2)}) == Factors({I: 1, -1: 1})
    assert Factors({I: Integer(1), Integer(-1): Rational(1, 3)}).as_expr() == \
        I*cbrt(-1)

    # numeric bases, signs and the imaginary unit
    assert Factors(-1.) == Factors({Integer(-1): Integer(1), Float(1.): 1})
    assert Factors(-2.) == Factors({Integer(-1): Integer(1), Float(2.): 1})
    assert Factors((-2.)**x) == Factors({Float(-2.): x})
    assert Factors(Integer(-2)) == \
        Factors({Integer(-1): Integer(1), Integer(2): 1})
    assert Factors(Rational(1, 2)) == Factors({Integer(2): -1})
    assert Factors(Rational(3, 2)) == \
        Factors({Integer(3): 1, Integer(2): Integer(-1)})
    assert Factors({I: Integer(1)}) == Factors(I)
    assert Factors({-1.0: 2, I: 1}) == Factors({Float(1.0): 1, I: 1})
    assert Factors({-1: -Rational(3, 2)}).as_expr() == I
    A = symbols('A', commutative=False)
    assert Factors(2*A**2) == Factors({Integer(2): 1, A**2: 1})
    assert Factors(I) == Factors({I: 1})

    # zero, one and assorted single-operation checks
    assert Factors(x).normal(Integer(2)) == (Factors(x), Factors(Integer(2)))
    assert Factors(x).normal(Integer(0)) == (Factors(), Factors(Integer(0)))
    with pytest.raises(ZeroDivisionError):
        Factors(x).div(Integer(0))
    assert Factors(x).mul(Integer(2)) == Factors(2*x)
    assert Factors(x).mul(Integer(0)).is_zero
    assert Factors(x).mul(1/x).is_one
    assert Factors(x**sqrt(8)).as_expr() == x**(2*sqrt(2))
    assert Factors(x)**Factors(Integer(2)) == Factors(x**2)
    assert Factors(x).gcd(Integer(0)) == Factors(x)
    assert Factors(x).lcm(Integer(0)).is_zero
    assert Factors(Integer(0)).div(x) == (Factors(Integer(0)), Factors())
    assert Factors(x).div(x) == (Factors(), Factors())
    assert Factors({x: .2})/Factors({x: .2}) == Factors()
    assert Factors(x) != Factors()
    assert Factors(x) == x
    assert Factors(Integer(0)).normal(x) == (Factors(Integer(0)), Factors())

    # symbolic exponents sharing a common base
    n, d = x**(2 + y), x**2
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors(x**y), Factors())
    assert f.gcd(d) == Factors()
    d = x**y
    assert f.div(d) == f.normal(d) == (Factors(x**2), Factors())
    assert f.gcd(d) == Factors(d)
    n = d = 2**x
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors(), Factors())
    assert f.gcd(d) == Factors(d)
    n, d = 2**x, 2**y
    f = Factors(n)
    assert f.div(d) == f.normal(d) == \
        (Factors({Integer(2): x}), Factors({Integer(2): y}))
    assert f.gcd(d) == Factors()

    assert f.div(f) == (Factors(), Factors())

    # extraction of constant only
    n = x**(x + 3)
    assert Factors(n).normal(x**-3) == (Factors({x: x + 6}), Factors({}))
    assert Factors(n).normal(x**3) == (Factors({x: x}), Factors({}))
    assert Factors(n).normal(x**4) == (Factors({x: x}), Factors({x: 1}))
    assert Factors(n).normal(x**(y - 3)) == \
        (Factors({x: x + 6}), Factors({x: y}))
    assert Factors(n).normal(x**(y + 3)) == (Factors({x: x}), Factors({x: y}))
    assert Factors(n).normal(x**(y + 4)) == \
        (Factors({x: x}), Factors({x: y + 1}))

    assert Factors(n).div(x**-3) == (Factors({x: x + 6}), Factors({}))
    assert Factors(n).div(x**3) == (Factors({x: x}), Factors({}))
    assert Factors(n).div(x**4) == (Factors({x: x}), Factors({x: 1}))
    assert Factors(n).div(x**(y - 3)) == \
        (Factors({x: x + 6}), Factors({x: y}))
    assert Factors(n).div(x**(y + 3)) == (Factors({x: x}), Factors({x: y}))
    assert Factors(n).div(x**(y + 4)) == \
        (Factors({x: x}), Factors({x: y + 1}))

    assert Factors({I: I}).as_expr() == (-1)**(I/2)
    assert Factors({-1: Rational(4, 3)}).as_expr() == -cbrt(-1)
Exemple #4
0
def test_Factors():
    """Check construction, arithmetic and normalization of ``Factors``."""
    # trivial factorizations and conversion back to expressions
    assert Factors() == Factors({}) == Factors(Integer(1))
    assert Factors(Integer(1)) == Factors(Factors(Integer(1)))
    assert Factors().as_expr() == 1
    assert Factors({x: 2, y: 3, sin(x): 4}).as_expr() == \
        x**2*y**3*sin(x)**4
    assert Factors(+oo) == Factors({oo: 1})
    assert Factors(-oo) == Factors({oo: 1, -1: 1})

    # equal Factors must hash equally
    f1 = Factors({oo: 1})
    f2 = Factors({oo: 1})
    assert hash(f1) == hash(f2)

    a = Factors({x: 5, y: 3, z: 7})
    b = Factors({y: 4, z: 3, t: 10})

    # method forms must agree with the operator forms
    assert a.mul(b) == a*b == Factors({x: 5, y: 7, z: 10, t: 10})

    assert a.div(b) == divmod(a, b) == \
        (Factors({x: 5, z: 4}), Factors({y: 1, t: 10}))
    assert a.quo(b) == a/b == Factors({x: 5, z: 4})
    assert a.rem(b) == a % b == Factors({y: 1, t: 10})

    assert a.pow(3) == a**3 == Factors({x: 15, y: 9, z: 21})
    assert b.pow(3) == b**3 == Factors({y: 12, z: 9, t: 30})

    # float exponents raise ValueError
    with pytest.raises(ValueError):
        a.pow(3.1)
    with pytest.raises(ValueError):
        a.pow(Factors(3.1))

    assert a.pow(0) == Factors()

    assert a.gcd(b) == Factors({y: 3, z: 3})
    assert a.lcm(b) == a.lcm(b.as_expr()) == \
        Factors({x: 5, y: 4, z: 7, t: 10})

    a = Factors({x: 4, y: 7, t: 7})
    b = Factors({z: 1, t: 3})

    assert a.normal(b) == (Factors({x: 4, y: 7, t: 4}), Factors({z: 1}))

    assert Factors(sqrt(2)*x).as_expr() == sqrt(2)*x

    # powers of I and -1 combine
    assert Factors(-I)*I == Factors()
    assert Factors({Integer(-1): Integer(3)}) * \
        Factors({Integer(-1): Integer(1), I: Integer(5)}) == Factors(I)

    # symbolic exponents on numeric bases
    assert Factors(Integer(2)**x).div(Integer(3)**x) == \
        (Factors({Integer(2): x}), Factors({Integer(3): x}))
    assert Factors(2**(2*x + 2)).div(Integer(8)) == \
        (Factors({Integer(2): 2*x + 2}), Factors({Integer(8): Integer(1)}))

    # coverage
    # /!\ things break if this is not True
    assert Factors({Integer(-1): Rational(3, 2)}) == Factors({I: 1, -1: 1})
    assert Factors({I: Integer(1), Integer(-1): Rational(1, 3)}).as_expr() == \
        I*cbrt(-1)

    # numeric bases, signs and the imaginary unit
    assert Factors(-1.) == Factors({Integer(-1): Integer(1), Float(1.): 1})
    assert Factors(-2.) == Factors({Integer(-1): Integer(1), Float(2.): 1})
    assert Factors((-2.)**x) == Factors({Float(-2.): x})
    assert Factors(Integer(-2)) == \
        Factors({Integer(-1): Integer(1), Integer(2): 1})
    assert Factors(Rational(1, 2)) == Factors({Integer(2): -1})
    assert Factors(Rational(3, 2)) == \
        Factors({Integer(3): 1, Integer(2): Integer(-1)})
    assert Factors({I: Integer(1)}) == Factors(I)
    assert Factors({-1.0: 2, I: 1}) == Factors({Float(1.0): 1, I: 1})
    assert Factors({-1: -Rational(3, 2)}).as_expr() == I
    A = symbols('A', commutative=False)
    assert Factors(2*A**2) == Factors({Integer(2): 1, A**2: 1})
    assert Factors(I) == Factors({I: 1})

    # zero, one and assorted single-operation checks
    assert Factors(x).normal(Integer(2)) == (Factors(x), Factors(Integer(2)))
    assert Factors(x).normal(Integer(0)) == (Factors(), Factors(Integer(0)))
    with pytest.raises(ZeroDivisionError):
        Factors(x).div(Integer(0))
    assert Factors(x).mul(Integer(2)) == Factors(2*x)
    assert Factors(x).mul(Integer(0)).is_zero
    assert Factors(x).mul(1/x).is_one
    assert Factors(x**sqrt(8)).as_expr() == x**(2*sqrt(2))
    assert Factors(x)**Factors(Integer(2)) == Factors(x**2)
    assert Factors(x).gcd(Integer(0)) == Factors(x)
    assert Factors(x).lcm(Integer(0)).is_zero
    assert Factors(Integer(0)).div(x) == (Factors(Integer(0)), Factors())
    assert Factors(x).div(x) == (Factors(), Factors())
    assert Factors({x: .2})/Factors({x: .2}) == Factors()
    assert Factors(x) != Factors()
    assert Factors(x) == x
    assert Factors(Integer(0)).normal(x) == (Factors(Integer(0)), Factors())

    # symbolic exponents sharing a common base
    n, d = x**(2 + y), x**2
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors(x**y), Factors())
    assert f.gcd(d) == Factors()
    d = x**y
    assert f.div(d) == f.normal(d) == (Factors(x**2), Factors())
    assert f.gcd(d) == Factors(d)
    n = d = 2**x
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors(), Factors())
    assert f.gcd(d) == Factors(d)
    n, d = 2**x, 2**y
    f = Factors(n)
    assert f.div(d) == f.normal(d) == \
        (Factors({Integer(2): x}), Factors({Integer(2): y}))
    assert f.gcd(d) == Factors()

    assert f.div(f) == (Factors(), Factors())

    # extraction of constant only
    n = x**(x + 3)
    assert Factors(n).normal(x**-3) == (Factors({x: x + 6}), Factors({}))
    assert Factors(n).normal(x**3) == (Factors({x: x}), Factors({}))
    assert Factors(n).normal(x**4) == (Factors({x: x}), Factors({x: 1}))
    assert Factors(n).normal(x**(y - 3)) == \
        (Factors({x: x + 6}), Factors({x: y}))
    assert Factors(n).normal(x**(y + 3)) == (Factors({x: x}), Factors({x: y}))
    assert Factors(n).normal(x**(y + 4)) == \
        (Factors({x: x}), Factors({x: y + 1}))

    assert Factors(n).div(x**-3) == (Factors({x: x + 6}), Factors({}))
    assert Factors(n).div(x**3) == (Factors({x: x}), Factors({}))
    assert Factors(n).div(x**4) == (Factors({x: x}), Factors({x: 1}))
    assert Factors(n).div(x**(y - 3)) == \
        (Factors({x: x + 6}), Factors({x: y}))
    assert Factors(n).div(x**(y + 3)) == (Factors({x: x}), Factors({x: y}))
    assert Factors(n).div(x**(y + 4)) == \
        (Factors({x: x}), Factors({x: y + 1}))

    assert Factors({I: I}).as_expr() == (-1)**(I/2)
    assert Factors({-1: Rational(4, 3)}).as_expr() == -cbrt(-1)
Exemple #5
0
def test_Term():
    """Check construction and arithmetic of ``Term`` objects.

    Each term is compared with an explicit ``(coeff, numer, denom)`` triple,
    and method forms are checked against their operator counterparts.
    """
    a = Term(4*x*y**2/z/t**3)
    b = Term(2*x**3*y**5/t**3)

    # construction splits coefficient, numerator and denominator
    assert a == Term(4, Factors({x: 1, y: 2}), Factors({z: 1, t: 3}))
    assert b == Term(2, Factors({x: 3, y: 5}), Factors({t: 3}))

    # and as_expr rebuilds the original expression
    assert a.as_expr() == 4*x*y**2/z/t**3
    assert b.as_expr() == 2*x**3*y**5/t**3

    # inversion swaps numerator and denominator
    assert a.inv() == \
        Term(Rational(1, 4), Factors({z: 1, t: 3}), Factors({x: 1, y: 2}))
    assert b.inv() == \
        Term(Rational(1, 2), Factors({t: 3}), Factors({x: 3, y: 5}))

    # products and quotients
    assert a.mul(b) == a*b == \
        Term(8, Factors({x: 4, y: 7}), Factors({z: 1, t: 6}))
    assert a.quo(b) == a/b == \
        Term(2, Factors({}), Factors({x: 2, y: 3, z: 1}))

    # positive and negative integer powers
    assert a.pow(3) == a**3 == \
        Term(64, Factors({x: 3, y: 6}), Factors({z: 3, t: 9}))
    assert b.pow(3) == b**3 == Term(8, Factors({x: 9, y: 15}), Factors({t: 9}))

    assert a.pow(-3) == a**(-3) == \
        Term(Rational(1, 64), Factors({z: 3, t: 9}), Factors({x: 3, y: 6}))
    assert b.pow(-3) == b**(-3) == \
        Term(Rational(1, 8), Factors({t: 9}), Factors({x: 9, y: 15}))

    # gcd / lcm
    assert a.gcd(b) == Term(2, Factors({x: 1, y: 2}), Factors({t: 3}))
    assert a.lcm(b) == Term(4, Factors({x: 3, y: 5}), Factors({z: 1, t: 3}))

    # t**-3 * t**7 leaves t**4 in the numerator
    a = Term(4*x*y**2/z/t**3)
    b = Term(2*x**3*y**5*t**7)

    assert a.mul(b) == Term(8, Factors({x: 4, y: 7, t: 4}), Factors({z: 1}))

    # integer content of polynomial bases moves into the coefficient
    assert Term((2*x + 2)**3) == Term(8, Factors({x + 1: 3}), Factors({}))
    assert Term((2*x + 2)*(3*x + 6)**2) == \
        Term(18, Factors({x + 1: 1, x + 2: 2}), Factors({}))
Exemple #6
0
def _minpoly_compose(ex, x, dom):
    """
    Computes the minimal polynomial of an algebraic element
    using operations on minimal polynomials

    Parameters
    ==========

    ex : Expr
        The algebraic element.
    x : Symbol
        The variable of the returned polynomial.
    dom : Domain
        The ground domain.  If ``ex`` is one of ``dom.symbols`` it already
        lies in the domain, and ``x - ex`` is returned.

    Returns
    =======

    An expression in ``x`` having ``ex`` as a root.  NOTE(review): e.g. the
    sum-of-surds branch returns the squared-out polynomial, which may not be
    irreducible; callers presumably select the right factor — confirm.

    Raises
    ======

    NotAlgebraic
        If ``ex`` is not of a form handled here.

    Examples
    ========

    >>> from diofant import minimal_polynomial, sqrt, Rational
    >>> from diofant.abc import x, y
    >>> minimal_polynomial(sqrt(2) + 3*Rational(1, 3), x, compose=True)
    x**2 - 2*x - 1
    >>> minimal_polynomial(sqrt(y) + 1/y, x, compose=True)
    x**2*y**2 - 2*x*y - y**3 + 1

    """
    # elements with a directly known minimal polynomial
    if ex.is_Rational:
        return ex.q*x - ex.p
    if ex is I:
        return x**2 + 1
    if ex is GoldenRatio:
        return x**2 - x - 1
    if hasattr(dom, 'symbols') and ex in dom.symbols:
        return x - ex

    if dom.is_QQ and _is_sum_surds(ex):
        # eliminate the square roots: apply _separate_sq until it returns
        # its argument unchanged
        ex -= x
        while True:
            reduced = _separate_sq(ex)
            if reduced is ex:
                return ex
            ex = reduced

    if ex.is_Add:
        res = _minpoly_add(x, dom, *ex.args)
    elif ex.is_Mul:
        f = Factors(ex).factors
        # bucket the (base, exponent) pairs by whether both are Rationals
        r = sift(f.items(),
                 lambda itx: itx[0].is_Rational and itx[1].is_Rational)
        if r[True] and dom == QQ:
            # ex == ex1 * (rational powers of rational bases)
            ex1 = Mul(*[base**e for base, e in r[False] + r[None]])
            r1 = r[True]
            dens = [e.q for _, e in r1]
            lcmdens = reduce(lcm, dens, 1)
            # raising each rational power to lcmdens gives an integer
            # exponent, so ex2 below is a Rational
            nums = [base**(e.p*lcmdens // e.q) for base, e in r1]
            ex2 = Mul(*nums)
            mp1 = minimal_polynomial(ex1, x)
            # use the fact that in Diofant canonicalization products of integers
            # raised to rational powers are organized in relatively prime
            # bases, and that in ``base**(n/d)`` a perfect power is
            # simplified with the root
            mp2 = ex2.q*x**lcmdens - ex2.p
            ex2 = ex2**Rational(1, lcmdens)
            res = _minpoly_op_algebraic_element(Mul, ex1, ex2, x, dom,
                                                mp1=mp1, mp2=mp2)
        else:
            res = _minpoly_mul(x, dom, *ex.args)
    elif ex.is_Pow:
        if ex.base is S.Exp1:
            res = _minpoly_exp(ex, x)
        else:
            res = _minpoly_pow(ex.base, ex.exp, x, dom)
    elif ex.__class__ is sin:
        res = _minpoly_sin(ex, x)
    elif ex.__class__ is cos:
        res = _minpoly_cos(ex, x)
    elif ex.__class__ is RootOf:
        res = _minpoly_rootof(ex, x)
    else:
        raise NotAlgebraic("%s doesn't seem to be an algebraic element" % ex)
    return res
Exemple #7
0
def radsimp(expr, symbolic=True, max_terms=4):
    r"""
    Rationalize the denominator by removing square roots.

    Note: the expression returned from radsimp must be used with caution
    since if the denominator contains symbols, it will be possible to make
    substitutions that violate the assumptions of the simplification process:
    that for a denominator matching a + b*sqrt(c), a != +/-b*sqrt(c). (If
    there are no symbols, this assumption is made valid by collecting terms
    of sqrt(c) so the match variable ``a`` does not contain ``sqrt(c)``.) If
    you do not want the simplification to occur for symbolic denominators, set
    ``symbolic`` to False.

    If there are more than ``max_terms`` radical terms then the expression is
    returned unchanged.

    Parameters
    ==========

    expr : Expr
        The expression whose denominator is to be rationalized.
    symbolic : bool, optional
        If False, denominators containing free symbols are left untouched.
    max_terms : int, optional
        The largest number of radical terms that will be eliminated.

    Examples
    ========

    >>> from diofant import radsimp, sqrt, Symbol, denom, pprint, I
    >>> from diofant import factor_terms, fraction, signsimp
    >>> from diofant.simplify.radsimp import collect_sqrt
    >>> from diofant.abc import a, b, c

    >>> radsimp(1/(I + 1))
    (1 - I)/2
    >>> radsimp(1/(2 + sqrt(2)))
    (-sqrt(2) + 2)/2
    >>> x,y = map(Symbol, 'xy')
    >>> e = ((2 + 2*sqrt(2))*x + (2 + sqrt(8))*y)/(2 + sqrt(2))
    >>> radsimp(e)
    sqrt(2)*(x + y)

    No simplification beyond removal of the gcd is done. One might
    want to polish the result a little, however, by collecting
    square root terms:

    >>> r2 = sqrt(2)
    >>> r5 = sqrt(5)
    >>> ans = radsimp(1/(y*r2 + x*r2 + a*r5 + b*r5))
    >>> pprint(ans, use_unicode=False)
        ___       ___       ___       ___
      \/ 5 *a + \/ 5 *b - \/ 2 *x - \/ 2 *y
    ------------------------------------------
       2               2      2              2
    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y

    >>> n, d = fraction(ans)
    >>> pprint(factor_terms(signsimp(collect_sqrt(n))/d, radical=True), use_unicode=False)
            ___             ___
          \/ 5 *(a + b) - \/ 2 *(x + y)
    ------------------------------------------
       2               2      2              2
    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y

    If radicals in the denominator cannot be removed or there is no denominator,
    the original expression will be returned.

    >>> radsimp(sqrt(2)*x + sqrt(2))
    sqrt(2)*x + sqrt(2)

    Results with symbols will not always be valid for all substitutions:

    >>> eq = 1/(a + b*sqrt(c))
    >>> eq.subs(a, b*sqrt(c))
    1/(2*b*sqrt(c))
    >>> radsimp(eq).subs(a, b*sqrt(c))
    nan

    If symbolic=False, symbolic denominators will not be transformed (but
    numeric denominators will still be processed):

    >>> radsimp(eq, symbolic=False)
    1/(a + b*sqrt(c))

    """
    from diofant.simplify.simplify import signsimp

    syms = symbols("a:d A:D")

    def _num(rterms):
        # return the multiplier that will simplify the expression described
        # by rterms [(sqrt arg, coeff), ... ]
        a, b, c, d, A, B, C, D = syms
        if len(rterms) == 2:
            reps = dict(zip([A, a, B, b], [j for i in rterms for j in i]))
            return (sqrt(A) * a - sqrt(B) * b).xreplace(reps)
        elif len(rterms) == 3:
            reps = dict(zip([A, a, B, b, C, c],
                            [j for i in rterms for j in i]))
            return ((sqrt(A) * a + sqrt(B) * b - sqrt(C) * c) *
                    (2 * sqrt(A) * sqrt(B) * a * b - A * a**2 - B * b**2 +
                     C * c**2)).xreplace(reps)
        elif len(rterms) == 4:
            reps = dict(
                zip([A, a, B, b, C, c, D, d], [j for i in rterms for j in i]))
            return (
                (sqrt(A) * a + sqrt(B) * b - sqrt(C) * c - sqrt(D) * d) *
                (2 * sqrt(A) * sqrt(B) * a * b - A * a**2 - B * b**2 -
                 2 * sqrt(C) * sqrt(D) * c * d + C * c**2 + D * d**2) *
                (-8 * sqrt(A) * sqrt(B) * sqrt(C) * sqrt(D) * a * b * c * d +
                 A**2 * a**4 - 2 * A * B * a**2 * b**2 -
                 2 * A * C * a**2 * c**2 - 2 * A * D * a**2 * d**2 +
                 B**2 * b**4 - 2 * B * C * b**2 * c**2 -
                 2 * B * D * b**2 * d**2 + C**2 * c**4 -
                 2 * C * D * c**2 * d**2 + D**2 * d**4)).xreplace(reps)
        elif len(rterms) == 1:
            return sqrt(rterms[0][0])
        else:
            raise NotImplementedError

    def ispow2(d, log2=False):
        # True if d is a Pow whose exponent has denominator 2 (or, with
        # log2=True, a denominator that is a power of 2)
        if not d.is_Pow:
            return False
        e = d.exp
        if e.is_Rational and e.q == 2 or symbolic and fraction(e)[1] == 2:
            return True
        if log2:
            q = 1
            if e.is_Rational:
                q = e.q
            elif symbolic:
                den = fraction(e)[1]
                if den.is_Integer:
                    q = den
            if q != 1 and log(q, 2).is_Integer:
                return True
        return False

    def handle(expr):
        # Handle first reduces to the case
        # expr = 1/d, where d is an add, or d is base**p/2.
        # We do this by recursively calling handle on each piece.
        from diofant.simplify.simplify import nsimplify

        n, d = fraction(expr)

        if expr.is_Atom or (d.is_Atom and n.is_Atom):
            return expr
        elif not n.is_Atom:
            n = n.func(*[handle(a) for a in n.args])
            return _unevaluated_Mul(n, handle(1 / d))
        elif n is not S.One:
            return _unevaluated_Mul(n, handle(1 / d))
        elif d.is_Mul:
            return _unevaluated_Mul(*[handle(1 / factor) for factor in d.args])

        # By this step, expr is 1/d, and d is not a mul.
        if not symbolic and d.free_symbols:
            return expr

        if ispow2(d):
            d2 = sqrtdenest(sqrt(d.base))**fraction(d.exp)[0]
            if d2 != d:
                return handle(1 / d2)
        elif d.is_Pow and (d.exp.is_integer or d.base.is_positive):
            # (1/d**i) = (1/d)**i
            return handle(1 / d.base)**d.exp

        if not (d.is_Add or ispow2(d)):
            return 1 / d.func(*[handle(a) for a in d.args])

        # handle 1/d treating d as an Add (though it may not be)

        keep = True  # keep changes that are made

        # flatten it and collect radicals after checking for special
        # conditions
        d = _mexpand(d)

        # did it change?
        if d.is_Atom:
            return 1 / d

        # is it a number that might be handled easily?
        if d.is_number:
            _d = nsimplify(d)
            if _d.is_Number and _d.equals(d):
                return 1 / _d

        while True:
            # collect similar terms
            collected = defaultdict(list)
            for m in Add.make_args(d):  # d might have become non-Add
                p2 = []
                other = []
                for i in Mul.make_args(m):
                    if ispow2(i, log2=True):
                        p2.append(i.base if i.exp is S.Half else i.base**(
                            2 * i.exp))
                    elif i is S.ImaginaryUnit:
                        p2.append(S.NegativeOne)
                    else:
                        other.append(i)
                collected[tuple(ordered(p2))].append(Mul(*other))
            rterms = list(ordered(list(collected.items())))
            rterms = [(Mul(*i), Add(*j)) for i, j in rterms]
            nrad = len(rterms) - (1 if rterms[0][0] is S.One else 0)
            if nrad < 1:
                break
            elif nrad > max_terms:
                # there may have been invalid operations leading to this point
                # so don't keep changes, e.g. this expression is troublesome
                # in collecting terms so as not to raise the issue of 2834:
                # r = sqrt(sqrt(5) + 5)
                # eq = 1/(sqrt(5)*r + 2*sqrt(5)*sqrt(-sqrt(5) + 5) + 5*r)
                keep = False
                break
            if len(rterms) > 4:
                # in general, only 4 terms can be removed with repeated squaring
                # but other considerations can guide selection of radical terms
                # so that radicals are removed
                if all(x.is_Integer and (y**2).is_Rational
                       for x, y in rterms):
                    nd, d = rad_rationalize(
                        S.One,
                        Add._from_args([sqrt(x) * y for x, y in rterms]))
                    n *= nd
                else:
                    # is there anything else that might be attempted?
                    keep = False
                break

            from diofant.simplify.powsimp import powsimp, powdenest

            # multiply numerator and denominator by the conjugate-like
            # factor from _num, then re-expand the denominator
            num = powsimp(_num(rterms))
            n *= num
            d *= num
            d = powdenest(_mexpand(d), force=symbolic)
            if d.is_Atom:
                break

        if not keep:
            return expr
        return _unevaluated_Mul(n, 1 / d)

    coeff, expr = expr.as_coeff_Add()
    expr = expr.normal()
    old = fraction(expr)
    n, d = fraction(handle(expr))
    if old != (n, d):
        if not d.is_Atom:
            was = (n, d)
            n = signsimp(n, evaluate=False)
            d = signsimp(d, evaluate=False)
            u = Factors(_unevaluated_Mul(n, 1 / d))
            u = _unevaluated_Mul(*[k**v for k, v in u.factors.items()])
            n, d = fraction(u)
            if old == (n, d):
                n, d = was
        n = expand_mul(n)
        if d.is_Number or d.is_Add:
            n2, d2 = fraction(gcd_terms(_unevaluated_Mul(n, 1 / d)))
            if d2.is_Number or (d2.count_ops() <= d.count_ops()):
                n, d = [signsimp(i) for i in (n2, d2)]
                if n.is_Mul and n.args[0].is_Number:
                    n = n.func(*n.args)

    return coeff + _unevaluated_Mul(n, 1 / d)
Exemple #8
0
def collect_const(expr, *vars, **kwargs):
    """A non-greedy collection of terms with similar number coefficients in
    an Add expr. If ``vars`` is given then only those constants will be
    targeted. Although any Number can also be targeted, if this is not
    desired set ``Numbers=False`` and no Float or Rational will be collected.

    If ``vars`` is omitted, every numeric factor found in the terms of
    ``expr`` is tried, and any newly collected sum is itself tried as a
    factor in a later pass.

    Examples
    ========

    >>> from diofant import sqrt
    >>> from diofant.abc import a, s, x, y, z
    >>> from diofant.simplify.radsimp import collect_const
    >>> collect_const(sqrt(3) + sqrt(3)*(1 + sqrt(2)))
    sqrt(3)*(sqrt(2) + 2)
    >>> collect_const(sqrt(3)*s + sqrt(7)*s + sqrt(3) + sqrt(7))
    (sqrt(3) + sqrt(7))*(s + 1)
    >>> s = sqrt(2) + 2
    >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7))
    (sqrt(2) + 3)*(sqrt(3) + sqrt(7))
    >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7), sqrt(3))
    sqrt(7) + sqrt(3)*(sqrt(2) + 3) + sqrt(7)*(sqrt(2) + 2)

    The collection is sign-sensitive, giving higher precedence to the
    unsigned values:

    >>> collect_const(x - y - z)
    x - (y + z)
    >>> collect_const(-y - z)
    -(y + z)
    >>> collect_const(2*x - 2*y - 2*z, 2)
    2*(x - y - z)
    >>> collect_const(2*x - 2*y - 2*z, -2)
    2*x - 2*(y + z)

    See Also
    ========

    collect, collect_sqrt, rcollect
    """
    if not expr.is_Add:
        return expr

    recurse = False
    Numbers = kwargs.get('Numbers', True)

    if not vars:
        # no targets given: try every numeric factor of every term
        recurse = True
        vars = set()
        for a in expr.args:
            for m in Mul.make_args(a):
                if m.is_number:
                    vars.add(m)
    else:
        vars = sympify(vars)
    if not Numbers:
        vars = [v for v in vars if not v.is_Number]

    vars = list(ordered(vars))
    # NOTE: when recursing, collected sums are appended to ``vars`` inside
    # this loop; iterating the list directly picks those new targets up.
    for v in vars:
        terms = defaultdict(list)
        Fv = Factors(v)
        for m in Add.make_args(expr):
            f = Factors(m)
            q, r = f.div(Fv)
            if r.is_one:
                # only accept this as a true factor if
                # it didn't change an exponent from an Integer
                # to a non-Integer, e.g. 2/sqrt(2) -> sqrt(2)
                # -- we aren't looking for this sort of change
                fwas = f.factors
                fnow = q.factors
                if not any(k in fwas and fwas[k].is_Integer
                           and not fnow[k].is_Integer for k in fnow):
                    terms[v].append(q.as_expr())
                    continue
            terms[S.One].append(m)

        args = []
        hit = False
        uneval = False
        for k in ordered(terms):
            group = terms[k]
            if k is S.One:
                # terms that did not contain v are kept as they are
                args.extend(group)
                continue

            if len(group) > 1:
                group = Add(*group)
                hit = True
                if recurse and group != expr:
                    vars.append(group)
            else:
                group = group[0]

            # be careful not to let uneval become True unless
            # it must be because it's going to be more expensive
            # to rebuild the expression as an unevaluated one
            if Numbers and k.is_Number and group.is_Add:
                args.append(_keep_coeff(k, group, sign=True))
                uneval = True
            else:
                args.append(k * group)

        if hit:
            if uneval:
                expr = _unevaluated_Add(*args)
            else:
                expr = Add(*args)
            if not expr.is_Add:
                break

    return expr