def _sqrtdenest_rec(expr):
    """Helper that denests the square root of three or more surds.

    It returns the denested expression; if it cannot be denested it
    raises SqrtdenestStopIteration.

    Algorithm: expr.base is in the extension Q_m = Q(sqrt(r_1),..,sqrt(r_k));
    split expr.base = a + b*sqrt(r_k), where `a` and `b` are on
    Q_(m-1) = Q(sqrt(r_1),..,sqrt(r_(k-1))); then a**2 - b**2*r_k is
    on Q_(m-1); denest sqrt(a**2 - b**2*r_k) and so on.
    See [1], section 6.

    The final reconstruction uses the identity

        sqrt(a + b) = d/sqrt(2) + b/(d*sqrt(2)),
        with c = sqrt(a**2 - b**2) and d = sqrt(a + c),

    which follows from d**2 = a + c and c**2 + b**2 = a**2 (subject to
    the usual branch/sign conditions on the square roots).

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.simplify.sqrtdenest import _sqrtdenest_rec
    >>> _sqrtdenest_rec(sqrt(-72*sqrt(2) + 158*sqrt(5) + 498))
    -sqrt(10) + sqrt(2) + 9 + 9*sqrt(5)
    >>> w=-6*sqrt(55)-6*sqrt(35)-2*sqrt(22)-2*sqrt(14)+2*sqrt(77)+6*sqrt(10)+65
    >>> _sqrtdenest_rec(sqrt(w))
    -sqrt(11) - sqrt(7) + sqrt(2) + 3*sqrt(5)
    """
    from sympy.simplify.simplify import radsimp, split_surds, rad_rationalize
    # Not a power (hence not a square root): delegate to the generic
    # entry point instead.
    if not expr.is_Pow:
        return sqrtdenest(expr)
    # Negative radicand: factor out sqrt(-1) and denest sqrt(-base).
    if expr.base < 0:
        return sqrt(-1)*_sqrtdenest_rec(sqrt(-expr.base))
    # Split the radicand in two parts. Judging by how _sqrt_match unpacks
    # split_surds, base == b + a*sqrt(g); after the rescaling below,
    # base == a + b.  (Assumption based on that caller -- confirm.)
    g, a, b = split_surds(expr.base)
    a = a*sqrt(g)
    # Order the two parts so that a >= b.
    if a < b:
        a, b = b, a
    # c2 = a**2 - b**2 = (a + b)*(a - b), expanded.
    c2 = _mexpand(a**2 - b**2)
    if len(c2.args) > 2:
        # c2 still has more than two args (terms, when it is an Add):
        # compute c = sqrt(c2) by applying the same split/identity one
        # level down, recursing on the inner square roots.
        g, a1, b1 = split_surds(c2)
        a1 = a1*sqrt(g)
        if a1 < b1:
            a1, b1 = b1, a1
        c2_1 = _mexpand(a1**2 - b1**2)
        c_1 = _sqrtdenest_rec(sqrt(c2_1))
        d_1 = _sqrtdenest_rec(sqrt(a1 + c_1))
        num, den = rad_rationalize(b1, d_1)
        c = _mexpand(d_1/sqrt(2) + num/(den*sqrt(2)))
    else:
        # Otherwise hand sqrt(c2) to the single-step denester.
        c = _sqrtdenest1(sqrt(c2))

    # If c still contains nested roots, this route does not help.
    if sqrt_depth(c) > 1:
        raise SqrtdenestStopIteration
    ac = a + c
    # Give up when a + c has at least as many args as expr and is not
    # cheaper (by count_ops) than the original radicand.
    # NOTE(review): expr is a Pow at this point, so expr.args is
    # presumably (base, exp) and this always compares against 2;
    # expr.base.args may have been intended. Confirm before changing,
    # since it alters when denesting is abandoned.
    if len(ac.args) >= len(expr.args):
        if count_ops(ac) >= count_ops(expr.base):
            raise SqrtdenestStopIteration
    d = sqrtdenest(sqrt(ac))
    if sqrt_depth(d) > 1:
        raise SqrtdenestStopIteration
    # Apply sqrt(a + b) = d/sqrt(2) + b/(d*sqrt(2)), with b/d rewritten
    # as num/den by rad_rationalize (presumably so den is radical-free).
    num, den = rad_rationalize(b, d)
    r = d/sqrt(2) + num/(den*sqrt(2))
    r = radsimp(r)
    return _mexpand(r)
def _sqrt_match(p):
    """Return [a, b, r] for p.match(a + b*sqrt(r)) where, in addition to
    matching, sqrt(r) also has the maximal sqrt_depth among addends of p.

    An empty list is returned when no such decomposition is found.

    Examples
    ========

    >>> from sympy.functions.elementary.miscellaneous import sqrt
    >>> from sympy.simplify.sqrtdenest import _sqrt_match
    >>> _sqrt_match(1 + sqrt(2) + sqrt(2)*sqrt(3) + 2*sqrt(1+sqrt(5)))
    [1 + sqrt(2) + sqrt(6), 2, 1 + sqrt(5)]
    """
    from sympy.simplify.simplify import split_surds

    p = _mexpand(p)
    if p.is_Number:
        res = (p, S.Zero, S.Zero)
    elif p.is_Add:
        pargs = sorted(p.args, key=default_sort_key)
        # split_surds is meant for sums of real surds, i.e. terms whose
        # squares are *positive* rationals. An imaginary addend such as
        # 3*I squares to the negative rational -9; accepting it here can
        # make split_surds fail (e.g. IndexError for 4 + I, SymPy issue
        # #12420), so positivity is required as well.
        sqargs = [x**2 for x in pargs]
        if all(sq.is_Rational and sq.is_positive for sq in sqargs):
            r, b, a = split_surds(p)
            return [a, b, r]
        # To make the process canonical, the argument is included in the
        # tuple, so when the max is selected it will be the largest arg
        # having a given depth.
        v = [(sqrt_depth(x), x, i) for i, x in enumerate(pargs)]
        nmax = max(v, key=default_sort_key)
        if nmax[0] == 0:
            # every addend has sqrt_depth 0: nothing to match
            res = []
        else:
            # select r: the addend of maximal depth
            depth, _, i = nmax
            r = pargs.pop(i)
            v.pop(i)
            b = S.One
            if r.is_Mul:
                # separate the shallower factors (coefficient b) from the
                # maximal-depth factors (r)
                bv = []
                rv = []
                for x in r.args:
                    if sqrt_depth(x) < depth:
                        bv.append(x)
                    else:
                        rv.append(x)
                b = Mul._from_args(bv)
                r = Mul._from_args(rv)
            # collect terms containing r into b1, everything else into a1
            a1 = []
            b1 = [b]
            for x_depth, x1, _ in v:
                if x_depth < depth:
                    a1.append(x1)
                elif x1 == r:
                    b1.append(1)
                elif x1.is_Mul and r in x1.args:
                    x1args = list(x1.args)
                    x1args.remove(r)
                    b1.append(Mul(*x1args))
                else:
                    a1.append(x1)
            a = Add(*a1)
            b = Add(*b1)
            res = (a, b, r**2)
    else:
        # a single term: match b*sqrt(r) with a = 0
        b, r = p.as_coeff_Mul()
        if is_sqrt(r):
            res = (S.Zero, b, r**2)
        else:
            res = []
    return list(res)