def test_Factors():
    """The repr of a Factors object lists every base with its exponent."""
    # NOTE(review): test_Factors is defined more than once in this file; if the
    # definitions share a module only the last one is bound and collected.
    expected = 'Factors({x: 1, y: 2})'
    factored = Factors(x * y**2)
    assert repr(factored) == expected
def test_Term():
    """Check Term construction, arithmetic, gcd/lcm and argument validation."""
    # NOTE(review): test_Term is defined more than once in this file; if the
    # definitions share a module only the last one is bound and collected.
    expr_a = 4 * x * y**2 / z / t**3
    expr_b = 2 * x**3 * y**5 / t**3
    ta = Term(expr_a)
    tb = Term(expr_b)

    # construction splits coefficient, numerator factors and denominator factors
    assert ta == Term(4, Factors({x: 1, y: 2}), Factors({z: 1, t: 3}))
    assert tb == Term(2, Factors({x: 3, y: 5}), Factors({t: 3}))
    assert ta.as_expr() == expr_a
    assert tb.as_expr() == expr_b

    # inversion swaps numerator and denominator
    assert ta.inv() == Term(Rational(1, 4), Factors({z: 1, t: 3}),
                            Factors({x: 1, y: 2}))
    assert tb.inv() == Term(Rational(1, 2), Factors({t: 3}),
                            Factors({x: 3, y: 5}))

    product = Term(8, Factors({x: 4, y: 7}), Factors({z: 1, t: 6}))
    assert ta.mul(tb) == ta*tb == product
    quotient = Term(2, Factors({}), Factors({x: 2, y: 3, z: 1}))
    assert ta.quo(tb) == ta / tb == quotient

    powers = [
        (ta, 3, Term(64, Factors({x: 3, y: 6}), Factors({z: 3, t: 9}))),
        (tb, 3, Term(8, Factors({x: 9, y: 15}), Factors({t: 9}))),
        (ta, -3, Term(Rational(1, 64), Factors({z: 3, t: 9}),
                      Factors({x: 3, y: 6}))),
        (tb, -3, Term(Rational(1, 8), Factors({t: 9}),
                      Factors({x: 9, y: 15}))),
    ]
    for term, k, expected in powers:
        assert term.pow(k) == term**k == expected

    assert ta.gcd(tb) == Term(2, Factors({x: 1, y: 2}), Factors({t: 3}))
    assert ta.lcm(tb) == Term(4, Factors({x: 3, y: 5}), Factors({z: 1, t: 3}))

    # a factor present in both numerator and denominator cancels on mul
    tc = Term(4 * x * y**2 / z / t**3)
    td = Term(2 * x**3 * y**5 * t**7)
    assert tc.mul(td) == Term(8, Factors({x: 4, y: 7, t: 4}), Factors({z: 1}))

    # numeric content is pulled out of sums
    assert Term((2 * x + 2)**3) == Term(8, Factors({x + 1: 3}), Factors({}))
    assert Term((2*x + 2)*(3*x + 6)**2) == \
        Term(18, Factors({x + 1: 1, x + 2: 2}), Factors({}))

    noncomm = Symbol('A', commutative=False)
    pytest.raises(NonCommutativeExpression, lambda: Term(noncomm))

    squared, empty = Factors({x: 2}), Factors()
    assert Term(2, numer=squared) == Term(2, squared, empty)
    assert Term(2, denom=squared) == Term(2, empty, squared)

    # arithmetic with plain numbers / non-integer powers is rejected
    for bad in (lambda: tc * 2, lambda: tc / 3, lambda: tc**3.1):
        pytest.raises(TypeError, bad)
def test_Factors():
    """Exercise Factors construction, mul/div/quo/rem/pow, gcd/lcm and normal."""
    # NOTE(review): test_Factors is defined more than once in this file; if the
    # definitions share a module only the last one is bound and collected.

    # construction / identity
    assert Factors() == Factors({}) == Factors(Integer(1))
    assert Factors(Integer(1)) == Factors(Factors(Integer(1)))
    assert Factors().as_expr() == 1
    assert Factors({x: 2, y: 3, sin(x): 4}).as_expr() == x**2 * y**3 * sin(x)**4
    assert Factors(+oo) == Factors({oo: 1})
    assert Factors(-oo) == Factors({oo: 1, -1: 1})

    # equal Factors hash equally
    f1 = Factors({oo: 1})
    f2 = Factors({oo: 1})
    assert hash(f1) == hash(f2)

    # arithmetic between two Factors
    a = Factors({x: 5, y: 3, z: 7})
    b = Factors({y: 4, z: 3, t: 10})
    assert a.mul(b) == a * b == Factors({x: 5, y: 7, z: 10, t: 10})
    assert a.div(b) == divmod(a, b) == \
        (Factors({x: 5, z: 4}), Factors({y: 1, t: 10}))
    assert a.quo(b) == a / b == Factors({x: 5, z: 4})
    assert a.rem(b) == a % b == Factors({y: 1, t: 10})
    assert a.pow(3) == a**3 == Factors({x: 15, y: 9, z: 21})
    assert b.pow(3) == b**3 == Factors({y: 12, z: 9, t: 30})
    # non-integer exponents are rejected
    pytest.raises(ValueError, lambda: a.pow(3.1))
    pytest.raises(ValueError, lambda: a.pow(Factors(3.1)))
    assert a.pow(0) == Factors()
    assert a.gcd(b) == Factors({y: 3, z: 3})
    assert a.lcm(b) == a.lcm(b.as_expr()) == Factors({x: 5, y: 4, z: 7, t: 10})

    # normal() cancels common factors (a and b are rebound here)
    a = Factors({x: 4, y: 7, t: 7})
    b = Factors({z: 1, t: 3})
    assert a.normal(b) == (Factors({x: 4, y: 7, t: 4}), Factors({z: 1}))

    assert Factors(sqrt(2) * x).as_expr() == sqrt(2) * x

    # powers of -1 and I combine
    assert Factors(-I) * I == Factors()
    assert Factors({Integer(-1): Integer(3)})*Factors({Integer(-1): Integer(1), I: Integer(5)}) == \
        Factors(I)

    assert Factors(Integer(2)**x).div(Integer(3)**x) == \
        (Factors({Integer(2): x}), Factors({Integer(3): x}))
    assert Factors(2**(2*x + 2)).div(Integer(8)) == \
        (Factors({Integer(2): 2*x + 2}), Factors({Integer(8): Integer(1)}))

    # coverage
    # /!\ things break if this is not True
    assert Factors({Integer(-1): Rational(3, 2)}) == Factors({I: 1, -1: 1})
    assert Factors({I: Integer(1), Integer(-1): Rational(1, 3)}).as_expr() == I * cbrt(-1)

    # negative numbers split off a -1 factor
    assert Factors(-1.) == Factors({Integer(-1): Integer(1), Float(1.): 1})
    assert Factors(-2.) == Factors({Integer(-1): Integer(1), Float(2.): 1})
    assert Factors((-2.)**x) == Factors({Float(-2.): x})
    assert Factors(Integer(-2)) == Factors({Integer(-1): Integer(1), Integer(2): 1})
    assert Factors(Rational(1, 2)) == Factors({Integer(2): -1})
    assert Factors(Rational(3, 2)) == Factors({Integer(3): 1, Integer(2): Integer(-1)})

    assert Factors({I: Integer(1)}) == Factors(I)
    assert Factors({-1.0: 2, I: 1}) == Factors({Float(1.0): 1, I: 1})
    assert Factors({-1: -Rational(3, 2)}).as_expr() == I

    # non-commutative bases are kept as opaque factors
    A = symbols('A', commutative=False)
    assert Factors(2 * A**2) == Factors({Integer(2): 1, A**2: 1})

    assert Factors(I) == Factors({I: 1})

    # interaction with plain expressions, including zero
    assert Factors(x).normal(Integer(2)) == (Factors(x), Factors(Integer(2)))
    assert Factors(x).normal(Integer(0)) == (Factors(), Factors(Integer(0)))
    pytest.raises(ZeroDivisionError, lambda: Factors(x).div(Integer(0)))
    assert Factors(x).mul(Integer(2)) == Factors(2 * x)
    assert Factors(x).mul(Integer(0)).is_zero
    assert Factors(x).mul(1 / x).is_one
    assert Factors(x**sqrt(8)).as_expr() == x**(2 * sqrt(2))
    assert Factors(x)**Factors(Integer(2)) == Factors(x**2)
    assert Factors(x).gcd(Integer(0)) == Factors(x)
    assert Factors(x).lcm(Integer(0)).is_zero
    assert Factors(Integer(0)).div(x) == (Factors(Integer(0)), Factors())
    assert Factors(x).div(x) == (Factors(), Factors())
    assert Factors({x: .2}) / Factors({x: .2}) == Factors()
    assert Factors(x) != Factors()
    assert Factors(x) == x
    assert Factors(Integer(0)).normal(x) == (Factors(Integer(0)), Factors())

    # symbolic exponents (n, d and f are rebound several times below)
    n, d = x**(2 + y), x**2
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors(x**y), Factors())
    assert f.gcd(d) == Factors()
    d = x**y
    assert f.div(d) == f.normal(d) == (Factors(x**2), Factors())
    assert f.gcd(d) == Factors(d)
    n = d = 2**x
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors(), Factors())
    assert f.gcd(d) == Factors(d)
    n, d = 2**x, 2**y
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors({Integer(2): x}), Factors({Integer(2): y}))
    assert f.gcd(d) == Factors()

    assert f.div(f) == (Factors(), Factors())

    # extraction of constant only
    n = x**(x + 3)
    assert Factors(n).normal(x**-3) == (Factors({x: x + 6}), Factors({}))
    assert Factors(n).normal(x**3) == (Factors({x: x}), Factors({}))
    assert Factors(n).normal(x**4) == (Factors({x: x}), Factors({x: 1}))
    assert Factors(n).normal(x**(y - 3)) == \
        (Factors({x: x + 6}), Factors({x: y}))
    assert Factors(n).normal(x**(y + 3)) == (Factors({x: x}), Factors({x: y}))
    assert Factors(n).normal(x**(y + 4)) == \
        (Factors({x: x}), Factors({x: y + 1}))

    assert Factors(n).div(x**-3) == (Factors({x: x + 6}), Factors({}))
    assert Factors(n).div(x**3) == (Factors({x: x}), Factors({}))
    assert Factors(n).div(x**4) == (Factors({x: x}), Factors({x: 1}))
    assert Factors(n).div(x**(y - 3)) == \
        (Factors({x: x + 6}), Factors({x: y}))
    assert Factors(n).div(x**(y + 3)) == (Factors({x: x}), Factors({x: y}))
    assert Factors(n).div(x**(y + 4)) == \
        (Factors({x: x}), Factors({x: y + 1}))

    assert Factors({I: I}).as_expr() == (-1)**(I / 2)
    assert Factors({-1: Rational(4, 3)}).as_expr() == -cbrt(-1)
def test_Factors():
    """Exercise Factors construction, mul/div/quo/rem/pow, gcd/lcm and normal."""
    # NOTE(review): test_Factors is defined more than once in this file; if the
    # definitions share a module, this later definition shadows earlier ones.

    # construction / identity
    assert Factors() == Factors({}) == Factors(Integer(1))
    assert Factors(Integer(1)) == Factors(Factors(Integer(1)))
    assert Factors().as_expr() == 1
    assert Factors({x: 2, y: 3, sin(x): 4}).as_expr() == x**2*y**3*sin(x)**4
    assert Factors(+oo) == Factors({oo: 1})
    assert Factors(-oo) == Factors({oo: 1, -1: 1})

    # equal Factors hash equally
    f1 = Factors({oo: 1})
    f2 = Factors({oo: 1})
    assert hash(f1) == hash(f2)

    # arithmetic between two Factors
    a = Factors({x: 5, y: 3, z: 7})
    b = Factors({y: 4, z: 3, t: 10})
    assert a.mul(b) == a*b == Factors({x: 5, y: 7, z: 10, t: 10})
    assert a.div(b) == divmod(a, b) == \
        (Factors({x: 5, z: 4}), Factors({y: 1, t: 10}))
    assert a.quo(b) == a/b == Factors({x: 5, z: 4})
    assert a.rem(b) == a % b == Factors({y: 1, t: 10})
    assert a.pow(3) == a**3 == Factors({x: 15, y: 9, z: 21})
    assert b.pow(3) == b**3 == Factors({y: 12, z: 9, t: 30})
    # non-integer exponents are rejected
    pytest.raises(ValueError, lambda: a.pow(3.1))
    pytest.raises(ValueError, lambda: a.pow(Factors(3.1)))
    assert a.pow(0) == Factors()
    assert a.gcd(b) == Factors({y: 3, z: 3})
    assert a.lcm(b) == a.lcm(b.as_expr()) == Factors({x: 5, y: 4, z: 7, t: 10})

    # normal() cancels common factors (a and b are rebound here)
    a = Factors({x: 4, y: 7, t: 7})
    b = Factors({z: 1, t: 3})
    assert a.normal(b) == (Factors({x: 4, y: 7, t: 4}), Factors({z: 1}))

    assert Factors(sqrt(2)*x).as_expr() == sqrt(2)*x

    # powers of -1 and I combine
    assert Factors(-I)*I == Factors()
    assert Factors({Integer(-1): Integer(3)})*Factors({Integer(-1): Integer(1), I: Integer(5)}) == \
        Factors(I)

    assert Factors(Integer(2)**x).div(Integer(3)**x) == \
        (Factors({Integer(2): x}), Factors({Integer(3): x}))
    assert Factors(2**(2*x + 2)).div(Integer(8)) == \
        (Factors({Integer(2): 2*x + 2}), Factors({Integer(8): Integer(1)}))

    # coverage
    # /!\ things break if this is not True
    assert Factors({Integer(-1): Rational(3, 2)}) == Factors({I: 1, -1: 1})
    assert Factors({I: Integer(1), Integer(-1): Rational(1, 3)}).as_expr() == I*cbrt(-1)

    # negative numbers split off a -1 factor
    assert Factors(-1.) == Factors({Integer(-1): Integer(1), Float(1.): 1})
    assert Factors(-2.) == Factors({Integer(-1): Integer(1), Float(2.): 1})
    assert Factors((-2.)**x) == Factors({Float(-2.): x})
    assert Factors(Integer(-2)) == Factors({Integer(-1): Integer(1), Integer(2): 1})
    assert Factors(Rational(1, 2)) == Factors({Integer(2): -1})
    assert Factors(Rational(3, 2)) == Factors({Integer(3): 1, Integer(2): Integer(-1)})

    assert Factors({I: Integer(1)}) == Factors(I)
    assert Factors({-1.0: 2, I: 1}) == Factors({Float(1.0): 1, I: 1})
    assert Factors({-1: -Rational(3, 2)}).as_expr() == I

    # non-commutative bases are kept as opaque factors
    A = symbols('A', commutative=False)
    assert Factors(2*A**2) == Factors({Integer(2): 1, A**2: 1})

    assert Factors(I) == Factors({I: 1})

    # interaction with plain expressions, including zero
    assert Factors(x).normal(Integer(2)) == (Factors(x), Factors(Integer(2)))
    assert Factors(x).normal(Integer(0)) == (Factors(), Factors(Integer(0)))
    pytest.raises(ZeroDivisionError, lambda: Factors(x).div(Integer(0)))
    assert Factors(x).mul(Integer(2)) == Factors(2*x)
    assert Factors(x).mul(Integer(0)).is_zero
    assert Factors(x).mul(1/x).is_one
    assert Factors(x**sqrt(8)).as_expr() == x**(2*sqrt(2))
    assert Factors(x)**Factors(Integer(2)) == Factors(x**2)
    assert Factors(x).gcd(Integer(0)) == Factors(x)
    assert Factors(x).lcm(Integer(0)).is_zero
    assert Factors(Integer(0)).div(x) == (Factors(Integer(0)), Factors())
    assert Factors(x).div(x) == (Factors(), Factors())
    assert Factors({x: .2})/Factors({x: .2}) == Factors()
    assert Factors(x) != Factors()
    assert Factors(x) == x
    assert Factors(Integer(0)).normal(x) == (Factors(Integer(0)), Factors())

    # symbolic exponents (n, d and f are rebound several times below)
    n, d = x**(2 + y), x**2
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors(x**y), Factors())
    assert f.gcd(d) == Factors()
    d = x**y
    assert f.div(d) == f.normal(d) == (Factors(x**2), Factors())
    assert f.gcd(d) == Factors(d)
    n = d = 2**x
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors(), Factors())
    assert f.gcd(d) == Factors(d)
    n, d = 2**x, 2**y
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors({Integer(2): x}), Factors({Integer(2): y}))
    assert f.gcd(d) == Factors()

    assert f.div(f) == (Factors(), Factors())

    # extraction of constant only
    n = x**(x + 3)
    assert Factors(n).normal(x**-3) == (Factors({x: x + 6}), Factors({}))
    assert Factors(n).normal(x**3) == (Factors({x: x}), Factors({}))
    assert Factors(n).normal(x**4) == (Factors({x: x}), Factors({x: 1}))
    assert Factors(n).normal(x**(y - 3)) == \
        (Factors({x: x + 6}), Factors({x: y}))
    assert Factors(n).normal(x**(y + 3)) == (Factors({x: x}), Factors({x: y}))
    assert Factors(n).normal(x**(y + 4)) == \
        (Factors({x: x}), Factors({x: y + 1}))

    assert Factors(n).div(x**-3) == (Factors({x: x + 6}), Factors({}))
    assert Factors(n).div(x**3) == (Factors({x: x}), Factors({}))
    assert Factors(n).div(x**4) == (Factors({x: x}), Factors({x: 1}))
    assert Factors(n).div(x**(y - 3)) == \
        (Factors({x: x + 6}), Factors({x: y}))
    assert Factors(n).div(x**(y + 3)) == (Factors({x: x}), Factors({x: y}))
    assert Factors(n).div(x**(y + 4)) == \
        (Factors({x: x}), Factors({x: y + 1}))

    assert Factors({I: I}).as_expr() == (-1)**(I/2)
    assert Factors({-1: Rational(4, 3)}).as_expr() == -cbrt(-1)
def test_Term():
    """Check Term construction, inversion, arithmetic and gcd/lcm."""
    # NOTE(review): test_Term is defined more than once in this file; if the
    # definitions share a module, this later definition shadows earlier ones.
    expr_a = 4 * x * y**2 / z / t**3
    expr_b = 2 * x**3 * y**5 / t**3
    ta = Term(expr_a)
    tb = Term(expr_b)

    # construction splits coefficient, numerator factors and denominator factors
    assert ta == Term(4, Factors({x: 1, y: 2}), Factors({z: 1, t: 3}))
    assert tb == Term(2, Factors({x: 3, y: 5}), Factors({t: 3}))
    assert ta.as_expr() == expr_a
    assert tb.as_expr() == expr_b

    # inversion swaps numerator and denominator
    assert ta.inv() == Term(Rational(1, 4), Factors({z: 1, t: 3}),
                            Factors({x: 1, y: 2}))
    assert tb.inv() == Term(Rational(1, 2), Factors({t: 3}),
                            Factors({x: 3, y: 5}))

    product = Term(8, Factors({x: 4, y: 7}), Factors({z: 1, t: 6}))
    assert ta.mul(tb) == ta*tb == product
    quotient = Term(2, Factors({}), Factors({x: 2, y: 3, z: 1}))
    assert ta.quo(tb) == ta / tb == quotient

    powers = [
        (ta, 3, Term(64, Factors({x: 3, y: 6}), Factors({z: 3, t: 9}))),
        (tb, 3, Term(8, Factors({x: 9, y: 15}), Factors({t: 9}))),
        (ta, -3, Term(Rational(1, 64), Factors({z: 3, t: 9}),
                      Factors({x: 3, y: 6}))),
        (tb, -3, Term(Rational(1, 8), Factors({t: 9}),
                      Factors({x: 9, y: 15}))),
    ]
    for term, k, expected in powers:
        assert term.pow(k) == term**k == expected

    assert ta.gcd(tb) == Term(2, Factors({x: 1, y: 2}), Factors({t: 3}))
    assert ta.lcm(tb) == Term(4, Factors({x: 3, y: 5}), Factors({z: 1, t: 3}))

    # a factor present in both numerator and denominator cancels on mul
    tc = Term(4 * x * y**2 / z / t**3)
    td = Term(2 * x**3 * y**5 * t**7)
    assert tc.mul(td) == Term(8, Factors({x: 4, y: 7, t: 4}), Factors({z: 1}))

    # numeric content is pulled out of sums
    assert Term((2 * x + 2)**3) == Term(8, Factors({x + 1: 3}), Factors({}))
    assert Term((2*x + 2)*(3*x + 6)**2) == \
        Term(18, Factors({x + 1: 1, x + 2: 2}), Factors({}))
def _minpoly_compose(ex, x, dom):
    """
    Computes the minimal polynomial of an algebraic element
    using operations on minimal polynomials

    ``ex`` is dispatched on its structure: rationals, ``I``, ``GoldenRatio``
    and symbols of ``dom`` have closed forms; sums of surds over QQ are
    handled by repeated square-root elimination; Add, Mul, Pow, sin, cos and
    RootOf are delegated to the matching ``_minpoly_*`` helper.  Anything
    else raises ``NotAlgebraic``.  The result is an expression in ``x``.

    Examples
    ========

    >>> from diofant import minimal_polynomial, sqrt, Rational
    >>> from diofant.abc import x, y
    >>> minimal_polynomial(sqrt(2) + 3*Rational(1, 3), x, compose=True)
    x**2 - 2*x - 1
    >>> minimal_polynomial(sqrt(y) + 1/y, x, compose=True)
    x**2*y**2 - 2*x*y - y**3 + 1

    """
    # leaf cases with closed-form minimal polynomials
    if ex.is_Rational:
        return ex.q * x - ex.p
    if ex is I:
        return x**2 + 1
    if ex is GoldenRatio:
        return x**2 - x - 1
    if hasattr(dom, 'symbols') and ex in dom.symbols:
        return x - ex

    if dom.is_QQ and _is_sum_surds(ex):
        # strip square roots one layer at a time; _separate_sq returns the
        # very same object once nothing is left to eliminate
        ex -= x
        while True:
            reduced = _separate_sq(ex)
            if reduced is ex:
                return ex
            ex = reduced

    if ex.is_Add:
        return _minpoly_add(x, dom, *ex.args)

    if ex.is_Mul:
        pairs = Factors(ex).factors.items()
        split = sift(pairs,
                     lambda item: item[0].is_Rational and item[1].is_Rational)
        rational_pairs = split[True]
        if not (rational_pairs and dom == QQ):
            return _minpoly_mul(x, dom, *ex.args)
        # ex == rest * radical, where radical collects rational**rational
        rest = Mul(*[base**power for base, power in split[False] + split[None]])
        common = reduce(lcm, [power.q for _, power in rational_pairs], 1)
        radicand = Mul(*[base**(power.p * common // power.q)
                         for base, power in rational_pairs])
        poly_rest = minimal_polynomial(rest, x)
        # Diofant canonicalizes products of integers raised to rational
        # powers into relatively prime bases, and in ``base**(n/d)`` a perfect
        # power is simplified with the root, so radicand**(1/common) has the
        # binomial minimal polynomial below
        poly_radical = radicand.q * x**common - radicand.p
        radical = radicand**Rational(1, common)
        return _minpoly_op_algebraic_element(Mul, rest, radical, x, dom,
                                             mp1=poly_rest, mp2=poly_radical)

    if ex.is_Pow:
        if ex.base is S.Exp1:
            return _minpoly_exp(ex, x)
        return _minpoly_pow(ex.base, ex.exp, x, dom)

    kind = ex.__class__
    if kind is sin:
        return _minpoly_sin(ex, x)
    if kind is cos:
        return _minpoly_cos(ex, x)
    if kind is RootOf:
        return _minpoly_rootof(ex, x)
    raise NotAlgebraic("%s doesn't seem to be an algebraic element" % ex)
def radsimp(expr, symbolic=True, max_terms=4):
    r"""
    Rationalize the denominator by removing square roots.

    Note: the expression returned from radsimp must be used with caution
    since if the denominator contains symbols, it will be possible to make
    substitutions that violate the assumptions of the simplification process:
    that for a denominator matching a + b*sqrt(c), a != +/-b*sqrt(c). (If
    there are no symbols, this assumption is made valid by collecting terms
    of sqrt(c) so the match variable ``a`` does not contain ``sqrt(c)``.) If
    you do not want the simplification to occur for symbolic denominators, set
    ``symbolic`` to False.

    If there are more than ``max_terms`` radical terms then the expression is
    returned unchanged.

    Examples
    ========

    >>> from diofant import radsimp, sqrt, Symbol, denom, pprint, I
    >>> from diofant import factor_terms, fraction, signsimp
    >>> from diofant.simplify.radsimp import collect_sqrt
    >>> from diofant.abc import a, b, c

    >>> radsimp(1/(I + 1))
    (1 - I)/2
    >>> radsimp(1/(2 + sqrt(2)))
    (-sqrt(2) + 2)/2
    >>> x,y = map(Symbol, 'xy')
    >>> e = ((2 + 2*sqrt(2))*x + (2 + sqrt(8))*y)/(2 + sqrt(2))
    >>> radsimp(e)
    sqrt(2)*(x + y)

    No simplification beyond removal of the gcd is done. One might
    want to polish the result a little, however, by collecting
    square root terms:

    >>> r2 = sqrt(2)
    >>> r5 = sqrt(5)
    >>> ans = radsimp(1/(y*r2 + x*r2 + a*r5 + b*r5))
    >>> pprint(ans, use_unicode=False)
        ___       ___       ___       ___
      \/ 5 *a + \/ 5 *b - \/ 2 *x - \/ 2 *y
    ------------------------------------------
       2               2      2              2
    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y

    >>> n, d = fraction(ans)
    >>> pprint(factor_terms(signsimp(collect_sqrt(n))/d, radical=True), use_unicode=False)
            ___             ___
          \/ 5 *(a + b) - \/ 2 *(x + y)
    ------------------------------------------
       2               2      2              2
    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y

    If radicals in the denominator cannot be removed or there is no denominator,
    the original expression will be returned.

    >>> radsimp(sqrt(2)*x + sqrt(2))
    sqrt(2)*x + sqrt(2)

    Results with symbols will not always be valid for all substitutions:

    >>> eq = 1/(a + b*sqrt(c))
    >>> eq.subs(a, b*sqrt(c))
    1/(2*b*sqrt(c))
    >>> radsimp(eq).subs(a, b*sqrt(c))
    nan

    If symbolic=False, symbolic denominators will not be transformed (but
    numeric denominators will still be processed):

    >>> radsimp(eq, symbolic=False)
    1/(a + b*sqrt(c))

    """
    from diofant.simplify.simplify import signsimp

    syms = symbols("a:d A:D")

    def _num(rterms):
        # return the multiplier that will simplify the expression described
        # by rterms [(sqrt arg, coeff), ... ]; only 1 to 4 terms are supported
        a, b, c, d, A, B, C, D = syms
        if len(rterms) == 2:
            reps = dict(zip([A, a, B, b], [j for i in rterms for j in i]))
            return (sqrt(A) * a - sqrt(B) * b).xreplace(reps)
        if len(rterms) == 3:
            reps = dict(zip([A, a, B, b, C, c],
                            [j for i in rterms for j in i]))
            return ((sqrt(A) * a + sqrt(B) * b - sqrt(C) * c) *
                    (2 * sqrt(A) * sqrt(B) * a * b - A * a**2 - B * b**2 +
                     C * c**2)).xreplace(reps)
        elif len(rterms) == 4:
            reps = dict(zip([A, a, B, b, C, c, D, d],
                            [j for i in rterms for j in i]))
            return ((sqrt(A) * a + sqrt(B) * b - sqrt(C) * c - sqrt(D) * d) *
                    (2 * sqrt(A) * sqrt(B) * a * b - A * a**2 - B * b**2 -
                     2 * sqrt(C) * sqrt(D) * c * d + C * c**2 + D * d**2) *
                    (-8 * sqrt(A) * sqrt(B) * sqrt(C) * sqrt(D) * a * b * c * d +
                     A**2 * a**4 - 2 * A * B * a**2 * b**2 -
                     2 * A * C * a**2 * c**2 - 2 * A * D * a**2 * d**2 +
                     B**2 * b**4 - 2 * B * C * b**2 * c**2 -
                     2 * B * D * b**2 * d**2 + C**2 * c**4 -
                     2 * C * D * c**2 * d**2 + D**2 * d**4)).xreplace(reps)
        elif len(rterms) == 1:
            return sqrt(rterms[0][0])
        else:
            raise NotImplementedError

    def ispow2(d, log2=False):
        # True if d is a Pow whose exponent has denominator 2 (or, with
        # log2=True, any power-of-two denominator)
        if not d.is_Pow:
            return False
        e = d.exp
        if e.is_Rational and e.q == 2 or symbolic and fraction(e)[1] == 2:
            return True
        if log2:
            q = 1
            if e.is_Rational:
                q = e.q
            elif symbolic:
                d = fraction(e)[1]
                if d.is_Integer:
                    q = d
            if q != 1 and log(q, 2).is_Integer:
                return True
        return False

    def handle(expr):
        # Handle first reduces to the case
        # expr = 1/d, where d is an add, or d is base**p/2.
        # We do this by recursively calling handle on each piece.
        from diofant.simplify.simplify import nsimplify
        from diofant.simplify.powsimp import powsimp, powdenest

        n, d = fraction(expr)

        if expr.is_Atom or (d.is_Atom and n.is_Atom):
            return expr
        elif not n.is_Atom:
            n = n.func(*[handle(a) for a in n.args])
            return _unevaluated_Mul(n, handle(1 / d))
        elif n is not S.One:
            return _unevaluated_Mul(n, handle(1 / d))
        elif d.is_Mul:
            return _unevaluated_Mul(*[handle(1 / d) for d in d.args])

        # By this step, expr is 1/d, and d is not a mul.
        if not symbolic and d.free_symbols:
            return expr

        if ispow2(d):
            d2 = sqrtdenest(sqrt(d.base))**fraction(d.exp)[0]
            if d2 != d:
                return handle(1 / d2)
        elif d.is_Pow and (d.exp.is_integer or d.base.is_positive):
            # (1/d**i) = (1/d)**i
            return handle(1 / d.base)**d.exp

        if not (d.is_Add or ispow2(d)):
            return 1 / d.func(*[handle(a) for a in d.args])

        # handle 1/d treating d as an Add (though it may not be)

        keep = True  # keep changes that are made

        # flatten it and collect radicals after checking for special
        # conditions
        d = _mexpand(d)

        # did it change?
        if d.is_Atom:
            return 1 / d

        # is it a number that might be handled easily?
        if d.is_number:
            _d = nsimplify(d)
            if _d.is_Number and _d.equals(d):
                return 1 / _d

        while True:
            # collect similar terms
            collected = defaultdict(list)
            for m in Add.make_args(d):  # d might have become non-Add
                p2 = []
                other = []
                for i in Mul.make_args(m):
                    if ispow2(i, log2=True):
                        p2.append(i.base if i.exp is S.Half
                                  else i.base**(2 * i.exp))
                    elif i is S.ImaginaryUnit:
                        p2.append(S.NegativeOne)
                    else:
                        other.append(i)
                collected[tuple(ordered(p2))].append(Mul(*other))
            rterms = list(ordered(list(collected.items())))
            rterms = [(Mul(*i), Add(*j)) for i, j in rterms]
            nrad = len(rterms) - (1 if rterms[0][0] is S.One else 0)
            if nrad < 1:
                break
            elif nrad > max_terms:
                # there may have been invalid operations leading to this point
                # so don't keep changes, e.g. this expression is troublesome
                # in collecting terms so as not to raise the issue of 2834:
                # r = sqrt(sqrt(5) + 5)
                # eq = 1/(sqrt(5)*r + 2*sqrt(5)*sqrt(-sqrt(5) + 5) + 5*r)
                keep = False
                break
            if len(rterms) > 4:
                # in general, only 4 terms can be removed with repeated squaring
                # but other considerations can guide selection of radical terms
                # so that radicals are removed
                if all(base.is_Integer and (coeff**2).is_Rational
                       for base, coeff in rterms):
                    nd, d = rad_rationalize(
                        S.One,
                        Add._from_args([sqrt(base) * coeff
                                        for base, coeff in rterms]))
                    n *= nd
                    # d has been rationalized, so the rterms collected above
                    # are stale (and _num cannot take more than 4 of them):
                    # re-collect from the new d instead of falling through.
                    continue
                else:
                    # is there anything else that might be attempted?
                    keep = False
                    break

            num = powsimp(_num(rterms))
            n *= num
            d *= num
            d = powdenest(_mexpand(d), force=symbolic)
            if d.is_Atom:
                break

        if not keep:
            return expr
        return _unevaluated_Mul(n, 1 / d)

    coeff, expr = expr.as_coeff_Add()
    expr = expr.normal()
    old = fraction(expr)
    n, d = fraction(handle(expr))
    if old != (n, d):
        if not d.is_Atom:
            was = (n, d)
            n = signsimp(n, evaluate=False)
            d = signsimp(d, evaluate=False)
            u = Factors(_unevaluated_Mul(n, 1 / d))
            u = _unevaluated_Mul(*[k**v for k, v in u.factors.items()])
            n, d = fraction(u)
            if old == (n, d):
                # sign normalization undid the work; keep handle()'s result
                n, d = was
        n = expand_mul(n)
        if d.is_Number or d.is_Add:
            n2, d2 = fraction(gcd_terms(_unevaluated_Mul(n, 1 / d)))
            if d2.is_Number or (d2.count_ops() <= d.count_ops()):
                n, d = [signsimp(i) for i in (n2, d2)]
                if n.is_Mul and n.args[0].is_Number:
                    n = n.func(*n.args)

    return coeff + _unevaluated_Mul(n, 1 / d)
def collect_const(expr, *vars, **kwargs):
    """Collect, non-greedily, the terms of an Add that share a constant factor.

    When ``vars`` are supplied only those constants are tried as common
    factors. Otherwise every number that appears as a multiplicative factor
    of some term is a candidate, and sums formed along the way are queued as
    further candidates. Any Number may be collected; pass ``Numbers=False`` to
    skip Float and Rational candidates. Input that is not an Add is returned
    unchanged.

    Examples
    ========

    >>> from diofant import sqrt
    >>> from diofant.abc import a, s, x, y, z
    >>> from diofant.simplify.radsimp import collect_const
    >>> collect_const(sqrt(3) + sqrt(3)*(1 + sqrt(2)))
    sqrt(3)*(sqrt(2) + 2)
    >>> collect_const(sqrt(3)*s + sqrt(7)*s + sqrt(3) + sqrt(7))
    (sqrt(3) + sqrt(7))*(s + 1)
    >>> s = sqrt(2) + 2
    >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7))
    (sqrt(2) + 3)*(sqrt(3) + sqrt(7))
    >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7), sqrt(3))
    sqrt(7) + sqrt(3)*(sqrt(2) + 3) + sqrt(7)*(sqrt(2) + 2)

    Collection is sign-sensitive; unsigned constants take precedence:

    >>> collect_const(x - y - z)
    x - (y + z)
    >>> collect_const(-y - z)
    -(y + z)
    >>> collect_const(2*x - 2*y - 2*z, 2)
    2*(x - y - z)
    >>> collect_const(2*x - 2*y - 2*z, -2)
    2*x - 2*(y + z)

    See Also
    ========

    collect, collect_sqrt, rcollect
    """
    if not expr.is_Add:
        return expr

    Numbers = kwargs.get('Numbers', True)
    recurse = not vars

    if recurse:
        # every numeric factor of every term is a candidate
        vars = {factor for term in expr.args
                for factor in Mul.make_args(term) if factor.is_number}
    else:
        vars = sympify(vars)
    if not Numbers:
        vars = [v for v in vars if not v.is_Number]

    # this list may grow while being iterated: new sums are queued below
    candidates = list(ordered(vars))
    for const in candidates:
        groups = defaultdict(list)
        const_factors = Factors(const)
        for term in Add.make_args(expr):
            term_factors = Factors(term)
            quotient, remainder = term_factors.div(const_factors)
            if remainder.is_one:
                # reject a "factor" whose removal turned an Integer exponent
                # into a non-Integer one, e.g. 2/sqrt(2) -> sqrt(2)
                before = term_factors.factors.copy()
                after = quotient.factors
                exponent_broken = any(
                    base in before and before[base].is_Integer
                    and not after[base].is_Integer
                    for base in after)
                if not exponent_broken:
                    groups[const].append(quotient.as_expr())
                    continue
            groups[S.One].append(term)

        new_args = []
        merged = False
        needs_uneval = False
        for key in ordered(groups):
            members = groups[key]
            if key is S.One:
                new_args.extend(members)
                continue

            if len(members) > 1:
                cofactor = Add(*members)
                merged = True
                if recurse and cofactor != expr:
                    candidates.append(cofactor)
            else:
                cofactor = members[0]

            # only pay for an unevaluated rebuild when _keep_coeff was needed
            if Numbers and key.is_Number and cofactor.is_Add:
                new_args.append(_keep_coeff(key, cofactor, sign=True))
                needs_uneval = True
            else:
                new_args.append(key * cofactor)

        if merged:
            builder = _unevaluated_Add if needs_uneval else Add
            expr = builder(*new_args)
            if not expr.is_Add:
                break

    return expr