def _sqrt_symbolic_denest(a, b, r):
    """Denest sqrt(a + b*sqrt(r)) symbolically; return None on failure.

    Algorithm:
    If r = ra + rb*sqrt(rr), try replacing sqrt(rr) in ``a`` with
    (y**2 - ra)/rb, and if the result is a quadratic, ca*y**2 + cb*y + cc, and
    (cb + b)**2 - 4*ca*cc is 0, then sqrt(a + b*sqrt(r)) can be rewritten as
    sqrt(ca*(sqrt(r) + (cb + b)/(2*ca))**2).

    Examples
    ========

    >>> from diofant.simplify.sqrtdenest import _sqrt_symbolic_denest, sqrtdenest
    >>> from diofant import sqrt, Symbol
    >>> from diofant.abc import x

    >>> a, b, r = 16 - 2*sqrt(29), 2, -10*sqrt(29) + 55
    >>> _sqrt_symbolic_denest(a, b, r)
    sqrt(-2*sqrt(29) + 11) + sqrt(5)

    If the expression is numeric, it will be simplified:

    >>> w = sqrt(sqrt(sqrt(3) + 1) + 1) + 1 + sqrt(2)
    >>> sqrtdenest(sqrt((w**2).expand()))
    1 + sqrt(2) + sqrt(1 + sqrt(1 + sqrt(3)))

    Otherwise, it will only be simplified if assumptions allow:

    >>> w = w.subs(sqrt(3), sqrt(x + 3))
    >>> sqrtdenest(sqrt((w**2).expand()))
    sqrt((sqrt(sqrt(sqrt(x + 3) + 1) + 1) + 1 + sqrt(2))**2)

    Notice that the argument of the sqrt is a square. If x is made positive
    then the sqrt of the square is resolved:

    >>> _.subs(x, Symbol('x', positive=True))
    sqrt(sqrt(sqrt(x + 3) + 1) + 1) + 1 + sqrt(2)

    """
    a, b, r = map(sympify, (a, b, r))

    # Split r as ra + rb*sqrt(rr); give up when r has no such form.
    matched = _sqrt_match(r)
    if not matched:
        return
    ra, rb, rr = matched
    if not rb:
        return

    # Express ``a`` as a polynomial in y, where y stands for sqrt(r).
    y = Dummy('y', positive=True)
    try:
        quadratic = Poly(a.subs(sqrt(rr), (y**2 - ra) / rb), y)
    except PolynomialError:
        return
    if quadratic.degree() != 2:
        return

    ca, cb, cc = quadratic.all_coeffs()
    cb = cb + b
    # Only a vanishing discriminant makes the radicand a perfect square.
    if not _mexpand(cb**2 - 4 * ca * cc).equals(0):
        return

    z = sqrt(ca * (sqrt(r) + cb / (2 * ca))**2)
    if z.is_number:
        z = _mexpand(Mul._from_args(z.as_content_primitive()))
    return z
def sub_pre(e):
    """Replace y - x with -(x - y) if -1 can be extracted from y - x."""
    # Collect the Adds whose sign can be pulled out, in canonical order.
    reps = sorted((term for term in e.atoms(Add)
                   if term.could_extract_minus_sign()),
                  key=default_sort_key)

    e = e.xreplace({term: Mul._from_args([S.NegativeOne, -term])
                    for term in reps})

    if not isinstance(e, Basic):
        return e

    # Second pass for Adds that persist: mark them with a leading 1, -1,
    # e.g. y - x -> 1*-1*(x - y)
    negs = {term: Mul._from_args([S.One, S.NegativeOne, -term])
            for term in sorted(e.atoms(Add), key=default_sort_key)
            if term in reps or term.could_extract_minus_sign()}
    return e.xreplace(negs)
def sub_post(e):
    """Replace 1*-1*x with -x."""
    # Find every Mul carrying the leading (1, -1) marker and pair it with
    # its negated remainder before touching the expression.
    marked = [
        (node, -Mul._from_args(node.args[2:]))
        for node in preorder_traversal(e)
        if isinstance(node, Mul)
        and node.args[0] is S.One
        and node.args[1] is S.NegativeOne
    ]

    # Apply the substitutions one at a time, in traversal order.
    for old, new in marked:
        e = e.xreplace({old: new})
    return e
def _sqrt_match(p):
    """Return [a, b, r] for p.match(a + b*sqrt(r)) where, in addition to
    matching, sqrt(r) also has the maximal sqrt_depth among addends of p.

    An empty list is returned when no such decomposition is found.

    Examples
    ========

    >>> from diofant.functions.elementary.miscellaneous import sqrt
    >>> from diofant.simplify.sqrtdenest import _sqrt_match
    >>> _sqrt_match(1 + sqrt(2) + sqrt(2)*sqrt(3) + 2*sqrt(1+sqrt(5)))
    [1 + sqrt(2) + sqrt(6), 2, 1 + sqrt(5)]

    """
    from diofant.simplify.radsimp import split_surds

    p = _mexpand(p)

    if p.is_Number:
        return [p, S.Zero, S.Zero]

    if not p.is_Add:
        # Single term: only coeff*sqrt(...) matches.
        coeff, rest = p.as_coeff_Mul()
        if is_sqrt(rest):
            return [S.Zero, coeff, rest**2]
        return []

    addends = sorted(p.args, key=default_sort_key)
    if all((term**2).is_Rational for term in addends):
        r, b, a = split_surds(p)
        return [a, b, r]

    # to make the process canonical, the argument is included in the tuple
    # so when the max is selected, it will be the largest arg having a
    # given depth
    ranked = [(sqrt_depth(term), term, pos)
              for pos, term in enumerate(addends)]
    depth, r, pos = max(ranked, key=default_sort_key)
    if depth == 0:
        return []

    # select r: drop the chosen addend from the remaining candidates
    del ranked[pos]
    b = S.One
    if r.is_Mul:
        # Shallow factors go to the coefficient, deep ones stay in r.
        shallow, deep = [], []
        for factor_ in r.args:
            (shallow if sqrt_depth(factor_) < depth else deep).append(factor_)
        b = Mul._from_args(shallow)
        r = Mul._from_args(deep)

    # collect terms containing r
    rest_terms = []
    coeffs = [b]
    for term_depth, term, _ in ranked:
        if term_depth < depth:
            rest_terms.append(term)
        elif term == r:
            coeffs.append(1)
        elif term.is_Mul and r in term.args:
            factors = list(term.args)
            factors.remove(r)
            coeffs.append(Mul(*factors))
        else:
            rest_terms.append(term)

    return [Add(*rest_terms), Add(*coeffs), r**2]
def __trigsimp(expr, deep=False):
    """Recursive helper for trigsimp.

    Parameters
    ==========

    expr : expression to simplify.
    deep : passed through to the recursive ``_trigsimp`` calls and to
        ``expr.rewrite(exp, deep=deep)``; when true, arguments of
        expressions other than Add/Mul/Pow are simplified as well.

    Returns
    =======

    The (possibly) simplified expression.

    The pattern tables and the wilds ``a, b, c, d`` are taken from the
    module-level ``_trigpat``, which is built by ``_trigpats()`` on first use.

    """
    from diofant.simplify.fu import TR10i

    if _trigpat is None:
        _trigpats()
    a, b, c, d, matchers_division, matchers_add, \
        matchers_identity, artifacts = _trigpat

    if expr.is_Mul:
        # do some simplifications like sin/cos -> tan:
        if not expr.is_commutative:
            com, nc = expr.args_cnc()
            expr = _trigsimp(Mul._from_args(com), deep) * Mul._from_args(nc)
        else:
            for i, (pattern, simp, ok1, ok2) in enumerate(matchers_division):
                if not _dotrig(expr, pattern):
                    continue

                newexpr = _match_div_rewrite(expr, i)
                if newexpr is not None:
                    if newexpr != expr:
                        expr = newexpr
                        break
                    else:
                        continue

                # use Diofant matching instead
                res = expr.match(pattern)
                if res and res.get(c, 0):
                    if not res[c].is_integer:
                        ok = ok1.subs(res)
                        if not ok.is_positive:
                            continue
                        ok = ok2.subs(res)
                        if not ok.is_positive:
                            continue
                    # if "a" contains any of trig or hyperbolic funcs with
                    # argument "b" then skip the simplification
                    if any(w.args[0] == res[b] for w in res[a].atoms(
                            TrigonometricFunction, HyperbolicFunction)):
                        continue
                    # simplify and finish:
                    expr = simp.subs(res)
                    break  # process below

    if expr.is_Add:
        args = []
        for term in expr.args:
            # Simplify only the commutative part of each term; the
            # non-commutative factors are re-attached afterwards.
            if not term.is_commutative:
                com, nc = term.args_cnc()
                nc = Mul._from_args(nc)
                term = Mul._from_args(com)
            else:
                nc = S.One
            term = _trigsimp(term, deep)
            for pattern, result in matchers_identity:
                res = term.match(pattern)
                if res is not None:
                    term = result.subs(res)
                    break
            args.append(term * nc)
        if args != expr.args:
            expr = Add(*args)
            expr = min(expr, expand(expr), key=count_ops)
        if expr.is_Add:
            for pattern, result in matchers_add:
                if not _dotrig(expr, pattern):
                    continue
                expr = TR10i(expr)
                if expr.has(HyperbolicFunction):
                    res = expr.match(pattern)
                    # if "d" contains any trig or hyperbolic funcs with
                    # argument "a" or "b" then skip the simplification;
                    # this isn't perfect -- see tests
                    if res is None or not (a in res and b in res) or any(
                        w.args[0] in (res[a], res[b]) for w in res[d].atoms(
                            TrigonometricFunction, HyperbolicFunction)):
                        continue
                    expr = result.subs(res)
                    break

        # Reduce any lingering artifacts, such as sin(x)**2 changing
        # to 1 - cos(x)**2 when sin(x)**2 was "simpler"
        for pattern, result, ex in artifacts:
            if not _dotrig(expr, pattern):
                continue
            # Substitute a new wild that excludes some function(s)
            # to help influence a better match. This is because
            # sometimes, for example, 'a' would match sec(x)**2
            a_t = Wild('a', exclude=[ex])
            pattern = pattern.subs(a, a_t)
            result = result.subs(a, a_t)

            m = expr.match(pattern)
            was = None
            while m and was != expr:
                was = expr
                if m[a_t] == 0 or \
                        -m[a_t] in m[c].args or m[a_t] + m[c] == 0:
                    break
                if d in m and m[a_t] * m[d] + m[c] == 0:
                    break
                expr = result.subs(m)
                m = expr.match(pattern)
                if m is None:
                    # The rewritten expression no longer matches; stop
                    # here instead of calling setdefault() on None.
                    break
                m.setdefault(c, S.Zero)

    elif expr.is_Mul or expr.is_Pow or deep and expr.args:
        expr = expr.func(*[_trigsimp(a, deep) for a in expr.args])

    # Try an exponential rewrite; TypeError is used as the bail-out signal.
    try:
        if not expr.has(*_trigs):
            raise TypeError
        e = {a for a in expr.atoms(Pow) if a.base is S.Exp1}
        new = expr.rewrite(exp, deep=deep)
        # NOTE(review): ``e`` is a set of exp atoms, so this comparison
        # with an expression appears unable to succeed -- confirm intent.
        if new == e:
            raise TypeError
        fnew = factor(new)
        if fnew != new:
            # Keep whichever form has fewer operations (``new`` on a tie).
            new = min(new, fnew, key=count_ops)
        # if all exp that were introduced disappeared then accept it
        ne = {a for a in new.atoms(Pow) if a.base is S.Exp1}
        if not (ne - e):
            expr = new
    except TypeError:
        pass

    return expr