예제 #1
0
파일: rootoftools.py 프로젝트: Lenqth/sympy
 def _refine_imaginary(cls, complexes):
     """Refine root regions until their x-bounds stop straddling zero.

     The ``(u, f, k)`` entries of ``complexes`` are grouped by ``f`` and
     the groups are visited in ``ordered`` order.  For every group,
     ``_imag_count_of_factor(f)`` tells how many regions may keep x-bounds
     whose product is non-positive (presumably the number of purely
     imaginary roots of ``f`` -- confirm against that helper); every other
     region is refined with ``_inner_refine`` until ``ax*bx > 0``.

     Returns a new list of ``(u, f, k)`` tuples holding the refined regions.
     """
     by_factor = sift(complexes, lambda c: c[1])
     refined = []
     for factor in ordered(by_factor):
         n_imag = _imag_count_of_factor(factor)
         group = by_factor[factor]
         if n_imag == 0:
             # no region may touch the y-axis: shrink each one until its
             # x-bounds are strictly negative or strictly positive
             for region, poly, extra in group:
                 while region.ax*region.bx <= 0:
                     region = region._inner_refine()
                 refined.append((region, poly, extra))
             continue

         # shrink regions until only ``n_imag`` of them still touch the y-axis
         candidates = list(range(len(group)))
         while True:
             assert len(candidates) > 1
             for pos in list(candidates):
                 region, poly, extra = group[pos]
                 if region.ax*region.bx > 0:
                     # clearly off the axis: no longer a candidate
                     candidates.remove(pos)
                 elif region.ax != region.bx:
                     # still ambiguous and not yet degenerate in x
                     group[pos] = region._inner_refine(), poly, extra
             if len(candidates) == n_imag:
                 break
         refined.extend(group)
     return refined
예제 #2
0
 def _refine_imaginary(cls, complexes):
     """Refine root regions until their x-bounds stop straddling zero.

     The ``(u, f, k)`` entries of ``complexes`` are grouped by ``f`` and
     the groups are visited in ``ordered`` order.  For every group,
     ``_imag_count_of_factor(f)`` tells how many regions may keep x-bounds
     whose product is non-positive (presumably the number of purely
     imaginary roots of ``f`` -- confirm against that helper); every other
     region is refined with ``_inner_refine`` until ``ax*bx > 0``.

     Returns a new list of ``(u, f, k)`` tuples holding the refined regions.
     """
     by_factor = sift(complexes, lambda c: c[1])
     refined = []
     for factor in ordered(by_factor):
         n_imag = _imag_count_of_factor(factor)
         group = by_factor[factor]
         if n_imag == 0:
             # no region may touch the y-axis: shrink each one until its
             # x-bounds are strictly negative or strictly positive
             for region, poly, extra in group:
                 while region.ax * region.bx <= 0:
                     region = region._inner_refine()
                 refined.append((region, poly, extra))
             continue

         # shrink regions until only ``n_imag`` of them still touch the y-axis
         candidates = list(range(len(group)))
         while True:
             assert len(candidates) > 1
             for pos in list(candidates):
                 region, poly, extra = group[pos]
                 if region.ax * region.bx > 0:
                     # clearly off the axis: no longer a candidate
                     candidates.remove(pos)
                 elif region.ax != region.bx:
                     # still ambiguous and not yet degenerate in x
                     group[pos] = region._inner_refine(), poly, extra
             if len(candidates) == n_imag:
                 break
         refined.extend(group)
     return refined
예제 #3
0
파일: matadd.py 프로젝트: asmeurer/sympy
def merge_explicit(matadd):
    """ Collapse the explicit matrices of a MatAdd into a single term

    Arguments that are ``MatrixBase`` instances are summed with ``add``;
    the remaining arguments are passed on unchanged, ahead of that sum.
    When fewer than two explicit matrices are present the input is
    returned as is.

    Examples
    ========

    >>> from sympy import MatrixSymbol, eye, Matrix, MatAdd, pprint
    >>> from sympy.matrices.expressions.matadd import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = eye(2)
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatAdd(A, B, C)
    >>> pprint(X)
        [1  0]   [1  2]
    A + [    ] + [    ]
        [0  1]   [3  4]
    >>> pprint(merge_explicit(X))
        [2  2]
    A + [    ]
        [3  5]
    """
    by_kind = sift(matadd.args, lambda arg: isinstance(arg, MatrixBase))
    explicit = by_kind[True]
    if len(explicit) <= 1:
        # zero or one explicit matrix: nothing to merge
        return matadd
    symbolic = by_kind[False]
    merged = reduce(add, explicit)
    return MatAdd(*(symbolic + [merged]))
예제 #4
0
def merge_explicit(matadd):
    """ Collapse the explicit matrices of a MatAdd into a single term

    Arguments that are ``MatrixBase`` instances are summed with ``add``;
    the remaining arguments are passed on unchanged, ahead of that sum.
    When fewer than two explicit matrices are present the input is
    returned as is.

    Examples
    ========

    >>> from sympy import MatrixSymbol, eye, Matrix, MatAdd, pprint
    >>> from sympy.matrices.expressions.matadd import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = eye(2)
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatAdd(A, B, C)
    >>> pprint(X)
        [1  0]   [1  2]
    A + [    ] + [    ]
        [0  1]   [3  4]
    >>> pprint(merge_explicit(X))
        [2  2]
    A + [    ]
        [3  5]
    """
    by_kind = sift(matadd.args, lambda arg: isinstance(arg, MatrixBase))
    explicit = by_kind[True]
    if len(explicit) <= 1:
        # zero or one explicit matrix: nothing to merge
        return matadd
    symbolic = by_kind[False]
    merged = reduce(add, explicit)
    return MatAdd(*(symbolic + [merged]))
예제 #5
0
    def _separate_imaginary_from_complex(cls, complexes):
        """Split ``(u, f, k)`` root entries into imaginary and other ones.

        The entries of ``complexes`` are grouped by their polynomial ``f``.
        ``is_imag`` classifies each entry from the shape of ``f`` alone;
        entries it cannot decide (``None``) are refined until all but two
        of them have x-bounds of one strict sign.

        Returns a pair ``(imag, complexes)`` of lists of ``(u, f, k)``
        tuples.
        """
        def is_imag(c):
            '''
            return True if all roots are imaginary (ax**2 + b)
            return False if no roots are imaginary
            return None if 2 roots are imaginary (ax**N + b with N a
            power of 2, per ``_ispow2``, other than 2 and with a*b < 0)
            '''
            u, f, k = c
            deg = f.degree()
            # only two-term polynomials can yield imaginary roots here
            if f.length() == 2:
                if deg == 2:
                    return True  # both imag
                elif _ispow2(deg):
                    # leading and trailing coefficients of opposite sign
                    if f.LC() * f.TC() < 0:
                        return None  # 2 are imag
            return False  # none are imag

        # separate according to the function
        sifted = sift(complexes, lambda c: c[1])
        del complexes
        imag = []
        complexes = []
        for f in sifted:
            # classify this polynomial's entries as True / False / None
            isift = sift(sifted[f], lambda c: is_imag(c))
            imag.extend(isift.pop(True, []))
            complexes.extend(isift.pop(False, []))
            mixed = isift.pop(None, [])
            # is_imag only returns True, False or None
            assert not isift
            if not mixed:
                continue
            # NOTE(review): this loop ends only when exactly two entries
            # remain in ``mixed``; presumably callers guarantee that --
            # confirm.
            while True:
                # the non-imaginary ones will be on one side or the other
                # of the y-axis
                i = 0
                while i < len(mixed):
                    u, f, k = mixed[i]
                    if u.ax * u.bx > 0:
                        # popping shifts later items down, so ``i`` stays
                        complexes.append(mixed.pop(i))
                    else:
                        i += 1
                if len(mixed) == 2:
                    imag.extend(mixed)
                    break
                # refine
                for i, (u, f, k) in enumerate(mixed):
                    u = u._inner_refine()
                    mixed[i] = u, f, k
        return imag, complexes
예제 #6
0
파일: kronecker.py 프로젝트: cklb/sympy
def kronecker_mat_add(expr):
    """Combine the terms of a matrix sum that share a Kronecker key.

    The arguments of ``expr`` are grouped by ``_kronecker_dims_key``.
    Terms under the key ``(0,)`` are set aside (presumably the terms that
    are not Kronecker products -- confirm against that helper); every
    other group is folded into a single term with ``_kronecker_add``.

    Returns ``expr`` unchanged when no group other than ``(0,)`` exists,
    otherwise the sum of the folded groups and the set-aside terms.
    """
    from functools import reduce
    args = sift(expr.args, _kronecker_dims_key)
    nonkrons = args.pop((0,), None)
    if not args:
        return expr

    krons = [reduce(lambda x, y: x._kronecker_add(y), group)
             for group in args.values()]

    if not nonkrons:
        return MatAdd(*krons)
    # ``nonkrons`` is a plain list of terms (a ``sift`` bucket); it must be
    # wrapped in a MatAdd before it can be added to a matrix expression.
    return MatAdd(*krons) + MatAdd(*nonkrons)
예제 #7
0
def kronecker_mat_add(expr):
    """Combine the terms of a matrix sum that share a Kronecker key.

    The arguments of ``expr`` are grouped by ``_kronecker_dims_key``.
    Terms under the key ``(0,)`` are set aside (presumably the terms that
    are not Kronecker products -- confirm against that helper); every
    other group is folded into a single term with ``_kronecker_add``.

    Returns ``expr`` unchanged when no group other than ``(0,)`` exists,
    otherwise the sum of the folded groups and the set-aside terms.
    """
    from functools import reduce
    args = sift(expr.args, _kronecker_dims_key)
    nonkrons = args.pop((0,), None)
    if not args:
        return expr

    krons = [reduce(lambda x, y: x._kronecker_add(y), group)
             for group in args.values()]

    if not nonkrons:
        return MatAdd(*krons)
    # ``nonkrons`` is a plain list of terms (a ``sift`` bucket); it must be
    # wrapped in a MatAdd before it can be added to a matrix expression.
    return MatAdd(*krons) + MatAdd(*nonkrons)
예제 #8
0
def bc_matadd(expr):
    """Fold the BlockMatrix arguments of ``expr`` into a single block.

    The ``BlockMatrix`` arguments are accumulated left to right with
    ``_blockadd``.  If ``expr`` has no such argument it is returned
    unchanged; if it also has other arguments, their ``MatAdd`` is added
    to the accumulated block.
    """
    by_kind = sift(expr.args, lambda arg: isinstance(arg, BlockMatrix))
    block_terms = by_kind[True]
    if not block_terms:
        return expr

    others = by_kind[False]
    total = block_terms[0]
    for term in block_terms[1:]:
        total = total._blockadd(term)
    if not others:
        return total
    return MatAdd(*others) + total
예제 #9
0
def bc_matadd(expr):
    """Fold the BlockMatrix arguments of ``expr`` into a single block.

    The ``BlockMatrix`` arguments are accumulated left to right with
    ``_blockadd``.  If ``expr`` has no such argument it is returned
    unchanged; if it also has other arguments, their ``MatAdd`` is added
    to the accumulated block.
    """
    by_kind = sift(expr.args, lambda arg: isinstance(arg, BlockMatrix))
    block_terms = by_kind[True]
    if not block_terms:
        return expr

    others = by_kind[False]
    total = block_terms[0]
    for term in block_terms[1:]:
        total = total._blockadd(term)
    if not others:
        return total
    return MatAdd(*others) + total
예제 #10
0
def _unique_and_repeated(inds):
    """
    Split ``inds`` into the indices that occur once and those that repeat.

    Both results keep the order in which the indices first appear in the
    input, as the example below shows.

    Examples
    ========

    >>> from sympy.tensor.index_methods import _unique_and_repeated
    >>> _unique_and_repeated([2, 3, 1, 3, 0, 4, 0])
    ([2, 1, 4], [3, 0])
    """
    # flag per index: 1 while it has been seen once, 0 once it repeats
    once_flag = OrderedDict()
    for index in inds:
        once_flag[index] = 0 if index in once_flag else 1
    return sift(once_flag, lambda index: once_flag[index], binary=True)
def _unique_and_repeated(inds):
    """
    Split ``inds`` into the indices that occur once and those that repeat.

    Both results keep the order in which the indices first appear in the
    input, as the example below shows.

    Examples
    ========

    >>> from sympy.tensor.index_methods import _unique_and_repeated
    >>> _unique_and_repeated([2, 3, 1, 3, 0, 4, 0])
    ([2, 1, 4], [3, 0])
    """
    # flag per index: 1 while it has been seen once, 0 once it repeats
    once_flag = OrderedDict()
    for index in inds:
        once_flag[index] = 0 if index in once_flag else 1
    return sift(once_flag, lambda index: once_flag[index], binary=True)
예제 #12
0
파일: str.py 프로젝트: thorek1/sympy
    def _print_Mul(self, expr):
        """Return the string form of the product ``expr``.

        Two paths exist.  If ``expr`` looks like an unevaluated product
        (its first argument is ``S.One``, or a later argument is a
        ``Number`` or a ``Pow`` whose arguments are all integers), the
        arguments are printed in their stored order, split into a
        numerator and a denominator.  Otherwise a leading minus sign is
        pulled out, the factors are collected into numerator and
        denominator lists, and the result is assembled as
        ``sign + num`` or ``sign + num/den`` with the denominator
        parenthesized when it has more than one factor.
        """

        prec = precedence(expr)

        # Check for unevaluated Mul. In this case we need to make sure the
        # identities are visible, multiple Rational factors are not combined
        # etc so we display in a straight-forward form that fully preserves all
        # args and their order.
        args = expr.args
        if args[0] is S.One or any(
                isinstance(a, Number) or a.is_Pow and all(ai.is_Integer
                                                          for ai in a.args)
                for a in args[1:]):
            # d: powers with a negative leading exponent coefficient
            # (denominator); n: everything else (numerator)
            d, n = sift(args,
                        lambda x: isinstance(x, Pow) and bool(
                            x.exp.as_coeff_Mul()[0] < 0),
                        binary=True)
            # flip the sign of each denominator exponent so it can be
            # printed after the division sign
            for i, di in enumerate(d):
                if di.exp.is_Number:
                    e = -di.exp
                else:
                    dargs = list(di.exp.args)
                    dargs[0] = -dargs[0]
                    e = Mul._from_args(dargs)
                # an exponent of exactly 1 is dropped (``e - 1`` is zero)
                d[i] = Pow(di.base, e, evaluate=False) if e - 1 else di.base

            # don't parenthesize first factor if negative
            if _coeff_isneg(n[0]):
                pre = [str(n.pop(0))]
            else:
                pre = []
            nfactors = pre + [
                self.parenthesize(a, prec, strict=False) for a in n
            ]

            # don't parenthesize first of denominator unless singleton
            if len(d) > 1 and _coeff_isneg(d[0]):
                pre = [str(d.pop(0))]
            else:
                pre = []
            dfactors = pre + [
                self.parenthesize(a, prec, strict=False) for a in d
            ]

            # from here on ``n`` and ``d`` are the joined strings
            n = '*'.join(nfactors)
            d = '*'.join(dfactors)
            if len(dfactors) > 1:
                return '%s/(%s)' % (n, d)
            elif dfactors:
                return '%s/%s' % (n, d)
            return n

        # evaluated Mul: pull a negative numeric coefficient out as a sign
        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = "-"
        else:
            sign = ""

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        pow_paren = [
        ]  # Will collect all pow with more than one base element and exp = -1

        if self.order not in ('old', 'none'):
            args = expr.as_ordered_factors()
        else:
            # use make_args in case expr was something like -x -> x
            args = Mul.make_args(expr)

        # Gather args for numerator/denominator
        def apow(i):
            # rebuild ``i`` with the sign of its exponent flipped, without
            # evaluation (the local ``b`` here shadows the outer list only
            # inside this helper)
            b, e = i.as_base_exp()
            eargs = list(Mul.make_args(e))
            if eargs[0] is S.NegativeOne:
                eargs = eargs[1:]
            else:
                eargs[0] = -eargs[0]
            e = Mul._from_args(eargs)
            if isinstance(i, Pow):
                return i.func(b, e, evaluate=False)
            return i.func(e, evaluate=False)

        for item in args:
            if (item.is_commutative and isinstance(item, Pow)
                    and bool(item.exp.as_coeff_Mul()[0] < 0)):
                # commutative power with a negative exponent -> denominator
                if item.exp is not S.NegativeOne:
                    b.append(apow(item))
                else:
                    if (len(item.args[0].args) != 1
                            and isinstance(item.base, Mul)):
                        # To avoid situations like #14160
                        pow_paren.append(item)
                    b.append(item.base)
            elif item.is_Rational and item is not S.Infinity:
                # split a rational into its numerator and denominator,
                # skipping parts equal to 1
                if item.p != 1:
                    a.append(Rational(item.p))
                if item.q != 1:
                    b.append(Rational(item.q))
            else:
                a.append(item)

        # an empty numerator is printed as 1
        a = a or [S.One]

        a_str = [self.parenthesize(x, prec, strict=False) for x in a]
        b_str = [self.parenthesize(x, prec, strict=False) for x in b]

        # To parenthesize Pow with exp = -1 and having more than one Symbol
        for item in pow_paren:
            if item.base in b:
                b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)]

        if not b:
            return sign + '*'.join(a_str)
        elif len(b) == 1:
            return sign + '*'.join(a_str) + "/" + b_str[0]
        else:
            return sign + '*'.join(a_str) + "/(%s)" % '*'.join(b_str)
예제 #13
0
def _minpoly_compose(ex, x, dom):
    """
    Computes the minimal polynomial of an algebraic element
    using operations on minimal polynomials

    ``ex`` is the element, ``x`` the variable of the resulting polynomial
    and ``dom`` the domain object consulted through ``dom.is_QQ`` and
    ``dom.symbols``.  Rationals, ``I`` and symbols of ``dom`` are answered
    directly; sums of surds over ``QQ`` are handled by repeated
    ``_separate_sq``; everything else is dispatched on the type of ``ex``.

    Raises ``NotAlgebraic`` when no branch applies.

    Examples
    ========

    >>> from sympy import minimal_polynomial, sqrt, Rational
    >>> from sympy.abc import x, y
    >>> minimal_polynomial(sqrt(2) + 3*Rational(1, 3), x, compose=True)
    x**2 - 2*x - 1
    >>> minimal_polynomial(sqrt(y) + 1/y, x, compose=True)
    x**2*y**2 - 2*x*y - y**3 + 1

    """
    # trivial cases, answered without recursion
    if ex.is_Rational:
        return ex.q * x - ex.p
    if ex is I:
        return x**2 + 1
    if hasattr(dom, 'symbols') and ex in dom.symbols:
        return x - ex

    if dom.is_QQ and _is_sum_surds(ex):
        # eliminate the square roots
        ex -= x
        # _separate_sq returns its argument itself once nothing is left to
        # eliminate, which is what the identity test below detects
        while 1:
            ex1 = _separate_sq(ex)
            if ex1 is ex:
                return ex
            else:
                ex = ex1

    if ex.is_Add:
        res = _minpoly_add(x, dom, *ex.args)
    elif ex.is_Mul:
        f = Factors(ex).factors
        # group (base, exponent) pairs by whether both are rational; the
        # key may be True, False or None, hence ``r[None]`` below
        r = sift(f.items(),
                 lambda itx: itx[0].is_rational and itx[1].is_rational)
        if r[True] and dom == QQ:
            # ex1: product of the remaining factors.  NOTE: the
            # comprehension reuses the name ``ex``; under Python 2 scoping
            # rules this rebinds the outer ``ex``, which is not read again
            # on this path.
            ex1 = Mul(*[bx**ex for bx, ex in r[False] + r[None]])
            r1 = r[True]
            dens = [y.q for _, y in r1]
            lcmdens = reduce(lcm, dens, 1)
            # scale every exponent by lcmdens so that it becomes an integer
            nums = [base**(y.p * lcmdens // y.q) for base, y in r1]
            ex2 = Mul(*nums)
            mp1 = minimal_polynomial(ex1, x)
            # use the fact that in SymPy canonicalization products of integers
            # raised to rational powers are organized in relatively prime
            # bases, and that in ``base**(n/d)`` a perfect power is
            # simplified with the root
            mp2 = ex2.q * x**lcmdens - ex2.p
            ex2 = ex2**Rational(1, lcmdens)
            res = _minpoly_op_algebraic_element(Mul,
                                                ex1,
                                                ex2,
                                                x,
                                                dom,
                                                mp1=mp1,
                                                mp2=mp2)
        else:
            res = _minpoly_mul(x, dom, *ex.args)
    elif ex.is_Pow:
        res = _minpoly_pow(ex.base, ex.exp, x, dom)
    elif ex.__class__ is C.sin:
        res = _minpoly_sin(ex, x)
    elif ex.__class__ is C.cos:
        res = _minpoly_cos(ex, x)
    elif ex.__class__ is C.exp:
        res = _minpoly_exp(ex, x)
    elif ex.__class__ is RootOf:
        res = _minpoly_rootof(ex, x)
    else:
        raise NotAlgebraic("%s doesn't seem to be an algebraic element" % ex)
    return res
예제 #14
0
def _separate_sq(p):
    """
    Eliminate one family of square roots from the sum ``p``.

    Helper function for ``_minimal_polynomial_sq``.

    A rational ``g`` is selected so that ``p`` splits into a sum of terms
    whose surds squared have gcd equal to ``g`` and a sum of terms whose
    surds squared are prime with ``g``; the field norm is then taken to
    eliminate ``sqrt(g)``.  ``p`` itself is returned when it contains no
    surds.

    See simplify.simplify.split_surds and polytools.sqf_norm.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.abc import x
    >>> from sympy.polys.numberfields.minpoly import _separate_sq
    >>> p= -x + sqrt(2) + sqrt(3) + sqrt(7)
    >>> p = _separate_sq(p); p
    -x**2 + 2*sqrt(3)*x + 2*sqrt(7)*x - 2*sqrt(21) - 8
    >>> p = _separate_sq(p); p
    -x**4 + 4*sqrt(7)*x**3 - 32*x**2 + 8*sqrt(7)*x + 20
    >>> p = _separate_sq(p); p
    -x**8 + 48*x**6 - 536*x**4 + 1728*x**2 - 400

    """
    def is_sqrt(expr):
        return expr.is_Pow and expr.exp is S.Half

    # p = c1*sqrt(q1) + ... + cn*sqrt(qn) -> pairs = [(c1, q1), .., (cn, qn)]
    pairs = []
    for term in p.args:
        if term.is_Mul:
            roots, others = sift(term.args, is_sqrt, binary=True)
            pairs.append((Mul(*others), Mul(*roots)**2))
        elif is_sqrt(term):
            pairs.append((S.One, term**2))
        elif term.is_Atom or (term.is_Pow and term.exp.is_integer):
            # surd-free term: its radicand is 1
            pairs.append((term, S.One))
        else:
            raise NotImplementedError
    pairs.sort(key=lambda pair: pair[1])
    if pairs[-1][1] is S.One:
        # there are no surds
        return p
    radicands = [rad for _, rad in pairs]
    # index of the first radicand different from 1
    for i, rad in enumerate(radicands):
        if rad != 1:
            break
    from sympy.simplify.radsimp import _split_gcd
    _, with_g, _ = _split_gcd(*radicands[i:])
    in_g = []
    not_in_g = []
    for coeff, rad in pairs:
        target = in_g if rad in with_g else not_in_g
        target.append(coeff * rad**S.Half)
    p1 = Add(*in_g)
    p2 = Add(*not_in_g)
    return _mexpand(p1**2) - _mexpand(p2**2)
예제 #15
0
def _minpoly_compose(ex, x, dom):
    """
    Computes the minimal polynomial of an algebraic element
    using operations on minimal polynomials

    ``ex`` is the element, ``x`` the variable of the resulting polynomial
    and ``dom`` the domain object passed to ``factor_list`` and consulted
    through ``dom.is_QQ`` and ``dom.symbols``.  Rationals, ``I``,
    ``GoldenRatio``, ``TribonacciConstant`` and symbols of ``dom`` are
    answered directly; sums of surds over ``QQ`` are handled by repeated
    ``_separate_sq``; everything else is dispatched on the type of ``ex``.

    Raises ``NotAlgebraic`` when no branch applies.

    Examples
    ========

    >>> from sympy import minimal_polynomial, sqrt, Rational
    >>> from sympy.abc import x, y
    >>> minimal_polynomial(sqrt(2) + 3*Rational(1, 3), x, compose=True)
    x**2 - 2*x - 1
    >>> minimal_polynomial(sqrt(y) + 1/y, x, compose=True)
    x**2*y**2 - 2*x*y - y**3 + 1

    """
    if ex.is_Rational:
        return ex.q * x - ex.p
    if ex is I:
        # x**2 + 1 unless it factors over ``dom``
        _, factors = factor_list(x**2 + 1, x, domain=dom)
        return x**2 + 1 if len(factors) == 1 else x - I

    if ex is S.GoldenRatio:
        # x**2 - x - 1 unless it factors over ``dom``; then the factor is
        # picked by _choose_factor using the explicit value
        _, factors = factor_list(x**2 - x - 1, x, domain=dom)
        if len(factors) == 1:
            return x**2 - x - 1
        else:
            return _choose_factor(factors, x, (1 + sqrt(5)) / 2, dom=dom)

    if ex is S.TribonacciConstant:
        # same scheme with the cubic x**3 - x**2 - x - 1
        _, factors = factor_list(x**3 - x**2 - x - 1, x, domain=dom)
        if len(factors) == 1:
            return x**3 - x**2 - x - 1
        else:
            fac = (1 + cbrt(19 - 3 * sqrt(33)) + cbrt(19 + 3 * sqrt(33))) / 3
            return _choose_factor(factors, x, fac, dom=dom)

    if hasattr(dom, 'symbols') and ex in dom.symbols:
        return x - ex

    if dom.is_QQ and _is_sum_surds(ex):
        # eliminate the square roots
        ex -= x
        # _separate_sq returns its argument itself once nothing is left to
        # eliminate, which is what the identity test below detects
        while 1:
            ex1 = _separate_sq(ex)
            if ex1 is ex:
                return ex
            else:
                ex = ex1

    if ex.is_Add:
        res = _minpoly_add(x, dom, *ex.args)
    elif ex.is_Mul:
        f = Factors(ex).factors
        # group (base, exponent) pairs by whether both are Rational
        r = sift(f.items(),
                 lambda itx: itx[0].is_Rational and itx[1].is_Rational)
        if r[True] and dom == QQ:
            # ex1: product of the factors that are not rational powers of
            # rationals (the comprehension reuses the name ``ex`` locally)
            ex1 = Mul(*[bx**ex for bx, ex in r[False] + r[None]])
            r1 = dict(r[True])
            dens = [y.q for y in r1.values()]
            lcmdens = reduce(lcm, dens, 1)
            neg1 = S.NegativeOne
            # take the power of -1 out of r1; it is re-applied below
            expn1 = r1.pop(neg1, S.Zero)
            # scale every exponent by lcmdens so that it becomes an integer
            nums = [base**(y.p * lcmdens // y.q) for base, y in r1.items()]
            ex2 = Mul(*nums)
            mp1 = minimal_polynomial(ex1, x)
            # use the fact that in SymPy canonicalization products of integers
            # raised to rational powers are organized in relatively prime
            # bases, and that in ``base**(n/d)`` a perfect power is
            # simplified with the root
            # Powers of -1 have to be treated separately to preserve sign.
            mp2 = ex2.q * x**lcmdens - ex2.p * neg1**(expn1 * lcmdens)
            ex2 = neg1**expn1 * ex2**Rational(1, lcmdens)
            res = _minpoly_op_algebraic_element(Mul,
                                                ex1,
                                                ex2,
                                                x,
                                                dom,
                                                mp1=mp1,
                                                mp2=mp2)
        else:
            res = _minpoly_mul(x, dom, *ex.args)
    elif ex.is_Pow:
        res = _minpoly_pow(ex.base, ex.exp, x, dom)
    elif ex.__class__ is sin:
        res = _minpoly_sin(ex, x)
    elif ex.__class__ is cos:
        res = _minpoly_cos(ex, x)
    elif ex.__class__ is tan:
        res = _minpoly_tan(ex, x)
    elif ex.__class__ is exp:
        res = _minpoly_exp(ex, x)
    elif ex.__class__ is CRootOf:
        res = _minpoly_rootof(ex, x)
    else:
        raise NotAlgebraic("%s does not seem to be an algebraic element" % ex)
    return res
예제 #16
0
def _minpoly_compose(ex, x, dom):
    """
    Computes the minimal polynomial of an algebraic element
    using operations on minimal polynomials

    ``ex`` is the element, ``x`` the variable of the resulting polynomial
    and ``dom`` the domain object consulted through ``dom.is_QQ`` and
    ``dom.symbols``.  Rationals, ``I`` and symbols of ``dom`` are answered
    directly; sums of surds over ``QQ`` are handled by repeated
    ``_separate_sq``; everything else is dispatched on the type of ``ex``.

    Raises ``NotAlgebraic`` when no branch applies.

    Examples
    ========

    >>> from sympy import minimal_polynomial, sqrt, Rational
    >>> from sympy.abc import x, y
    >>> minimal_polynomial(sqrt(2) + 3*Rational(1, 3), x, compose=True)
    x**2 - 2*x - 1
    >>> minimal_polynomial(sqrt(y) + 1/y, x, compose=True)
    x**2*y**2 - 2*x*y - y**3 + 1

    """
    # trivial cases, answered without recursion
    if ex.is_Rational:
        return ex.q*x - ex.p
    if ex is I:
        return x**2 + 1
    if hasattr(dom, 'symbols') and ex in dom.symbols:
        return x - ex

    if dom.is_QQ and _is_sum_surds(ex):
        # eliminate the square roots
        ex -= x
        # _separate_sq returns its argument itself once nothing is left to
        # eliminate, which is what the identity test below detects
        while 1:
            ex1 = _separate_sq(ex)
            if ex1 is ex:
                return ex
            else:
                ex = ex1

    if ex.is_Add:
        res = _minpoly_add(x, dom, *ex.args)
    elif ex.is_Mul:
        f = Factors(ex).factors
        # group (base, exponent) pairs by whether both are rational; the
        # key may be True, False or None, hence ``r[None]`` below
        r = sift(f.items(), lambda itx: itx[0].is_rational and itx[1].is_rational)
        if r[True] and dom == QQ:
            # ex1: product of the remaining factors.  NOTE: the
            # comprehension reuses the name ``ex``; under Python 2 scoping
            # rules this rebinds the outer ``ex``, which is not read again
            # on this path.
            ex1 = Mul(*[bx**ex for bx, ex in r[False] + r[None]])
            r1 = r[True]
            dens = [y.q for _, y in r1]
            lcmdens = reduce(lcm, dens, 1)
            # scale every exponent by lcmdens so that it becomes an integer
            nums = [base**(y.p*lcmdens // y.q) for base, y in r1]
            ex2 = Mul(*nums)
            mp1 = minimal_polynomial(ex1, x)
            # use the fact that in SymPy canonicalization products of integers
            # raised to rational powers are organized in relatively prime
            # bases, and that in ``base**(n/d)`` a perfect power is
            # simplified with the root
            mp2 = ex2.q*x**lcmdens - ex2.p
            ex2 = ex2**Rational(1, lcmdens)
            res = _minpoly_op_algebraic_element(Mul, ex1, ex2, x, dom, mp1=mp1, mp2=mp2)
        else:
            res = _minpoly_mul(x, dom, *ex.args)
    elif ex.is_Pow:
        res = _minpoly_pow(ex.base, ex.exp, x, dom)
    elif ex.__class__ is C.sin:
        res = _minpoly_sin(ex, x)
    elif ex.__class__ is C.cos:
        res = _minpoly_cos(ex, x)
    elif ex.__class__ is C.exp:
        res = _minpoly_exp(ex, x)
    elif ex.__class__ is RootOf:
        res = _minpoly_rootof(ex, x)
    else:
        raise NotAlgebraic("%s doesn't seem to be an algebraic element" % ex)
    return res
예제 #17
0
def _separate_sq(p):
    """
    Eliminate one family of square roots from the sum ``p``.

    Helper function for ``_minimal_polynomial_sq``.

    A rational ``g`` is selected so that ``p`` splits into a sum of terms
    whose surds squared have gcd equal to ``g`` and a sum of terms whose
    surds squared are prime with ``g``; the field norm is then taken to
    eliminate ``sqrt(g)``.  ``p`` itself is returned when it contains no
    surds.

    See simplify.simplify.split_surds and polytools.sqf_norm.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.abc import x
    >>> from sympy.polys.numberfields import _separate_sq
    >>> p= -x + sqrt(2) + sqrt(3) + sqrt(7)
    >>> p = _separate_sq(p); p
    -x**2 + 2*sqrt(3)*x + 2*sqrt(7)*x - 2*sqrt(21) - 8
    >>> p = _separate_sq(p); p
    -x**4 + 4*sqrt(7)*x**3 - 32*x**2 + 8*sqrt(7)*x + 20
    >>> p = _separate_sq(p); p
    -x**8 + 48*x**6 - 536*x**4 + 1728*x**2 - 400

    """
    from sympy.simplify.simplify import _split_gcd, _mexpand
    from sympy.utilities.iterables import sift

    def is_sqrt(expr):
        return expr.is_Pow and expr.exp is S.Half

    # p = c1*sqrt(q1) + ... + cn*sqrt(qn) -> pairs = [(c1, q1), .., (cn, qn)]
    pairs = []
    for term in p.args:
        if term.is_Mul:
            parts = sift(term.args, is_sqrt)
            pairs.append((Mul(*parts[False]), Mul(*parts[True])**2))
        elif is_sqrt(term):
            pairs.append((S.One, term**2))
        elif term.is_Atom or (term.is_Pow and term.exp.is_integer):
            # surd-free term: its radicand is 1
            pairs.append((term, S.One))
        else:
            raise NotImplementedError
    pairs.sort(key=lambda pair: pair[1])
    if pairs[-1][1] is S.One:
        # there are no surds
        return p
    radicands = [rad for _, rad in pairs]
    # index of the first radicand different from 1
    for i, rad in enumerate(radicands):
        if rad != 1:
            break
    _, with_g, _ = _split_gcd(*radicands[i:])
    in_g = []
    not_in_g = []
    for coeff, rad in pairs:
        target = in_g if rad in with_g else not_in_g
        target.append(coeff*rad**S.Half)
    p1 = Add(*in_g)
    p2 = Add(*not_in_g)
    return _mexpand(p1**2) - _mexpand(p2**2)