예제 #1
0
def test_ordered():
    """Exercise ``ordered`` on dicts, duplicates, nested lists and tied keys."""
    # Regression for issue 7210 (python2/3 problems): dicts must be sortable.
    small, large = {1: 3}, {1: 3, 2: 4, 9: 10}
    assert list(ordered([large, small])) == [small, large]

    # Identical items must not trigger the warn=True error path.
    duplicates = [1, 1]
    assert list(ordered(duplicates, warn=True)) == duplicates

    nested = [[1], [2], [1]]
    assert list(ordered(nested, warn=True)) == [[1], [1], [2]]

    # "a" and "ab" tie under the single first-character key; with
    # default=False and warn=True this is expected to raise ValueError.
    def sort_tied_strings():
        return list(ordered(["a", "ab"], keys=[lambda s: s[0]], default=False, warn=True))

    raises(ValueError, sort_tied_strings)
예제 #2
0
    def intersection(self, o):
        """Return the points where this parabola meets the entity `o`.

        Parameters
        ==========

        o : GeometryEntity, LinearEntity

        Returns
        =======

        intersection : list of GeometryEntity objects

        Raises
        ======

        TypeError
            If `o` is a ``LinearEntity3D`` or any other unsupported type.

        Examples
        ========

        >>> from sympy import Parabola, Point, Ellipse, Line, Segment
        >>> p1 = Point(0,0)
        >>> l1 = Line(Point(1, -2), Point(-1,-2))
        >>> parabola1 = Parabola(p1, l1)
        >>> parabola1.intersection(Ellipse(Point(0, 0), 2, 5))
        [Point2D(-2, 0), Point2D(2, 0)]
        >>> parabola1.intersection(Line(Point(-7, 3), Point(12, 3)))
        [Point2D(-4, 3), Point2D(4, 3)]
        >>> parabola1.intersection(Segment((-12, -65), (14, -68)))
        []

        """
        x, y = symbols('x y', real=True)
        own_eq = self.equation()

        if isinstance(o, Parabola):
            if o in self:
                return [o]
            common = solve([own_eq, o.equation()], [x, y])
            return list(ordered([Point(sol) for sol in common]))

        if isinstance(o, Point2D):
            # The point lies on the parabola when the equation vanishes there.
            residual = own_eq.subs(([(x, o._args[0]), (y, o._args[1])]))
            return [o] if simplify(residual) == 0 else []

        if isinstance(o, (Segment2D, Ray2D)):
            # Intersect with the full line through the two defining points,
            # then keep only the solutions that lie on `o` itself.
            carrier = Line2D(o.points[0], o.points[1])
            common = solve([own_eq, carrier.equation()], [x, y])
            return list(ordered([Point2D(sol) for sol in common if sol in o]))

        if isinstance(o, (Line2D, Ellipse)):
            common = solve([own_eq, o.equation()], [x, y])
            return list(ordered([Point2D(sol) for sol in common]))

        if isinstance(o, LinearEntity3D):
            raise TypeError('Entity must be two dimensional, not three dimensional')
        raise TypeError('Wrong type of argument were put')
예제 #3
0
파일: polyroots.py 프로젝트: bjodah/sympy
def roots_cyclotomic(f, factor=False):
    """Compute roots of cyclotomic polynomials.

    The index ``n`` is found by comparing ``f`` against ``cyclotomic_poly(n)``
    for every ``n`` in the range given by ``_inv_totient_estimate``.

    With ``factor=False`` the roots are built as ``exp(2*I*pi*k/n)`` for each
    ``k`` in ``1..n`` coprime to ``n`` and expanded with ``complex=True``.
    With ``factor=True`` ``f`` is factored over the extension ``root(-1, n)``
    and each factor contributes the negated value of its ``TC()``.

    Raises ``RuntimeError`` when no index in the estimated range matches.
    """
    lower, upper = _inv_totient_estimate(f.degree())

    for n in range(lower, upper + 1):
        if f == cyclotomic_poly(n, f.gen, polys=True):
            break
    else:  # pragma: no cover
        raise RuntimeError("failed to find index of a cyclotomic polynomial")

    if factor:
        extended = Poly(f, extension=root(-1, n))
        return [-fac.TC() for fac, _ in ordered(extended.factor_list()[1])]

    # Arrange the exponents so that the computed roots come out sorted.
    half = n//2
    exponents = [i for i in range(1, n + 1) if igcd(i, n) == 1]
    exponents.sort(key=lambda i: (i, -1) if i <= half else (abs(i - n), 1))
    step = 2*I*pi/n
    return [exp(k*step).expand(complex=True) for k in reversed(exponents)]
예제 #4
0
    def __new__(cls, *args, **kwargs):
        """Create a term-wise product of sequences.

        ``args`` may be sequences or (nested) iterables of sequences; nested
        ``SeqMul`` instances are merged into a single flat argument list.

        Returns ``S.EmptySequence`` when no sequence is given or when the
        intersection of the sequences' intervals is ``S.EmptySet``. When
        ``evaluate`` is true (keyword, defaulting to ``global_evaluate[0]``)
        the result of ``SeqMul.reduce`` is returned; otherwise an unevaluated
        object with its args ordered by ``SeqBase._start_key``.

        Raises ``TypeError`` for an argument that is neither a sequence nor
        an iterable.
        """
        evaluate = kwargs.get("evaluate", global_evaluate[0])

        # flatten inputs
        args = list(args)

        # adapted from sympy.sets.sets.Union
        def _flatten(arg):
            if isinstance(arg, SeqBase):
                if isinstance(arg, SeqMul):
                    return sum(map(_flatten, arg.args), [])
                else:
                    return [arg]
            elif iterable(arg):
                return sum(map(_flatten, arg), [])
            raise TypeError("Input must be Sequences or " " iterables of Sequences")

        args = _flatten(args)

        # Multiplication of no sequences is EmptySequence
        if not args:
            return S.EmptySequence

        # Pass every interval as its own argument instead of handing
        # Intersection a single generator: this does not depend on
        # Intersection flattening iterable arguments.
        if Intersection(*(a.interval for a in args)) is S.EmptySet:
            return S.EmptySequence

        # reduce using known rules
        if evaluate:
            return SeqMul.reduce(args)

        args = list(ordered(args, SeqBase._start_key))

        return Basic.__new__(cls, *args)
예제 #5
0
def roots_cyclotomic(f, factor=False):
    """Compute roots of cyclotomic polynomials.

    ``f`` is matched against ``cyclotomic_poly(n)`` for each ``n`` in the
    range returned by ``_inv_totient_estimate`` to find its index.

    If ``factor`` is false, the result lists ``exp(2*I*pi*k/n)`` (expanded
    with ``complex=True``) for the ``k`` in ``1..n`` that are coprime to
    ``n``. If ``factor`` is true, ``f`` is factored over the extension
    ``root(-1, n)`` and the negated ``TC()`` of each factor is returned.

    Raises ``RuntimeError`` if no index in the estimated range matches.
    """
    low, high = _inv_totient_estimate(f.degree())

    for n in range(low, high + 1):
        if f == cyclotomic_poly(n, f.gen, polys=True):
            break
    else:  # pragma: no cover
        raise RuntimeError("failed to find index of a cyclotomic polynomial")

    if factor:
        over_extension = Poly(f, extension=root(-1, n))
        linear_factors = ordered(over_extension.factor_list()[1])
        return [-fac.TC() for fac, _ in linear_factors]

    # Order the exponents so the computed roots will be sorted.
    midpoint = n // 2
    coprime = [i for i in range(1, n + 1) if igcd(i, n) == 1]
    coprime.sort(key=lambda i: (i, -1) if i <= midpoint else (abs(i - n), 1))
    unit_angle = 2 * I * pi / n
    return [exp(k * unit_angle).expand(complex=True) for k in reversed(coprime)]
예제 #6
0
파일: util.py 프로젝트: bjodah/sympy
def quantity_simplify(expr):
    """Return an equivalent expression with prefixes replaced by their
    numerical values and with all units of the same dimension expressed
    through a single unit chosen from the expression.

    Examples
    ========

    >>> from sympy.physics.units.util import quantity_simplify
    >>> from sympy.physics.units.prefixes import kilo
    >>> from sympy.physics.units import foot, inch
    >>> quantity_simplify(kilo*foot*inch)
    250*foot**2/3
    >>> quantity_simplify(foot - 6*inch)
    foot/2
    """

    if expr.is_Atom or not expr.has(Prefix, Quantity):
        return expr

    # Substitute every prefix by its scale factor.
    prefix_values = {prefix: prefix.scale_factor for prefix in expr.atoms(Prefix)}
    expr = expr.xreplace(prefix_values)

    # Group the quantities by dimension; in each group of two or more,
    # rewrite all but the first (in ``ordered`` order) in terms of the first.
    by_dimension = sift(expr.atoms(Quantity), lambda q: q.dimension)
    for quantities in by_dimension.values():
        if len(quantities) == 1:
            continue
        units = list(ordered(quantities))
        canonical = units[0]
        ref = canonical/canonical.scale_factor
        rewrites = {other: ref*other.scale_factor for other in units[1:]}
        expr = expr.xreplace(rewrites)

    return expr
예제 #7
0
파일: sets.py 프로젝트: alphaitis/sympy
    def __new__(cls, *args, **kwargs):
        """Create an intersection from sets or (nested) iterables of sets.

        Nested intersections are merged into one flat argument list. With
        ``evaluate`` true (keyword, defaulting to ``global_evaluate[0]``) the
        result of ``Intersection.reduce`` is returned; otherwise an
        unevaluated object whose args are ordered by ``Set._infimum_key``.

        Raises ``TypeError`` when no set is supplied or when an argument is
        neither a ``Set`` nor an iterable.
        """
        evaluate = kwargs.get('evaluate', global_evaluate[0])

        def flatten(arg):
            # Sets are tested first, so a Set that is also iterable is never
            # expanded into its elements.
            if isinstance(arg, Set):
                if not arg.is_Intersection:
                    return [arg]
                nested = arg.args
            elif iterable(arg):
                nested = arg
            else:
                raise TypeError("Input must be Sets or iterables of Sets")
            flat = []
            for item in nested:
                flat.extend(flatten(item))
            return flat

        sets = flatten(list(args))

        if not sets:
            raise TypeError("Intersection expected at least one argument")

        # Reduce sets using known rules
        if evaluate:
            return Intersection.reduce(sets)

        return Basic.__new__(cls, *list(ordered(sets, Set._infimum_key)))
예제 #8
0
 def _refine_imaginary(cls, complexes):
     """Refine complex intervals until their x-bounds have a definite sign.

     ``complexes`` is a sequence of ``(u, f, k)`` triples. The triples are
     grouped by their middle element ``f`` and the groups are processed in
     ``ordered`` order. Each interval ``u`` is refined with
     ``u._inner_refine()`` until ``u.ax*u.bx > 0``, except that in a group
     where ``_imag_count_of_factor(f)`` is nonzero that many intervals are
     left straddling zero (presumably the ones holding the purely imaginary
     roots -- confirm against ``_imag_count_of_factor``).

     Returns a new list with the refined triples.
     """
     sifted = sift(complexes, lambda c: c[1])
     complexes = []
     for f in ordered(sifted):
         nimag = _imag_count_of_factor(f)
         if nimag == 0:
             # refine until xbounds are neg or pos
             # (the loop rebinds ``f``; it equals the group key because the
             # groups were formed on the middle element of each triple)
             for u, f, k in sifted[f]:
                 while u.ax * u.bx <= 0:
                     u = u._inner_refine()
                 complexes.append((u, f, k))
         else:
             # refine until all but nimag xbounds are neg or pos
             potential_imag = list(range(len(sifted[f])))
             while True:
                 assert len(potential_imag) > 1
                 # iterate over a copy because indices are removed in the loop
                 for i in list(potential_imag):
                     u, f, k = sifted[f][i]
                     if u.ax * u.bx > 0:
                         # x-bounds share a sign: no longer a candidate
                         potential_imag.remove(i)
                     elif u.ax != u.bx:
                         # still straddles zero with nonzero width: refine
                         # and store the refined interval back in the group
                         u = u._inner_refine()
                         sifted[f][i] = u, f, k
                 if len(potential_imag) == nimag:
                     break
             complexes.extend(sifted[f])
     return complexes
예제 #9
0
    def __new__(cls, *args, **kwargs):
        """Create a term-wise product of sequences.

        Accepts sequences or (nested) iterables of sequences and merges
        nested ``SeqMul`` instances into one flat list. Returns
        ``S.EmptySequence`` when there is nothing to multiply or when the
        sequences' intervals have an empty intersection; returns
        ``SeqMul.reduce(...)`` when ``evaluate`` is true; otherwise builds an
        unevaluated object with args ordered by ``SeqBase._start_key``.

        Raises ``TypeError`` for an argument that is neither a sequence nor
        an iterable.
        """
        evaluate = kwargs.get('evaluate', global_parameters.evaluate)

        # adapted from sympy.sets.sets.Union
        def _flatten(arg):
            if isinstance(arg, SeqBase):
                if not isinstance(arg, SeqMul):
                    return [arg]
                children = arg.args
            elif iterable(arg):
                children = arg
            else:
                raise TypeError("Input must be Sequences or "
                                " iterables of Sequences")
            flat = []
            for child in children:
                flat.extend(_flatten(child))
            return flat

        sequences = _flatten(list(args))

        # Multiplication of no sequences is EmptySequence
        if not sequences:
            return S.EmptySequence

        intervals = (seq.interval for seq in sequences)
        if Intersection(*intervals) is S.EmptySet:
            return S.EmptySequence

        # reduce using known rules
        if evaluate:
            return SeqMul.reduce(sequences)

        return Basic.__new__(cls, *list(ordered(sequences, SeqBase._start_key)))
예제 #10
0
    def __new__(cls, *args, **assumptions):
        """Build the object from sympified, de-duplicated arguments.

        The ``evaluate`` keyword (default ``True``) is removed from
        ``assumptions``. When evaluating, the arguments pass through
        ``cls._new_args_filter``, ``cls._collapse_arguments`` and
        ``cls._find_localzeros``; a ``ShortCircuit`` raised by the filter
        yields ``cls.zero``.

        Returns ``cls.identity`` when no argument remains, the sole argument
        when exactly one remains, and otherwise a new object whose args are
        the ``ordered`` remaining arguments and whose ``_argset`` is their
        frozenset.
        """
        evaluate = assumptions.pop('evaluate', True)
        sympified = (sympify(arg) for arg in args)

        if evaluate:
            # first standard filter, for cls.zero and cls.identity
            # also reshape Max(a, Max(b, c)) to Max(a, b, c)
            try:
                args = frozenset(cls._new_args_filter(sympified))
            except ShortCircuit:
                return cls.zero
            # remove redundant args that are easily identified
            args = cls._collapse_arguments(args, **assumptions)
            # find local zeros
            args = cls._find_localzeros(args, **assumptions)
        else:
            args = frozenset(sympified)

        if not args:
            return cls.identity
        if len(args) == 1:
            return list(args).pop()

        # base creation
        _args = frozenset(args)
        obj = Expr.__new__(cls, *ordered(_args), **assumptions)
        obj._argset = _args
        return obj
예제 #11
0
    def __new__(cls, *args, **kwargs):
        """Create an intersection from sets or (nested) iterables of sets.

        Arguments that are themselves intersections are merged into one flat
        list. If ``evaluate`` (keyword, defaulting to ``global_evaluate[0]``)
        is true, ``Intersection.reduce`` decides the result; otherwise the
        object is created unevaluated with args ordered by
        ``Set._infimum_key``.

        Raises ``TypeError`` when there are no sets or when an argument is
        neither a ``Set`` nor an iterable.
        """
        evaluate = kwargs.get('evaluate', global_evaluate[0])

        def flatten(arg):
            # The Set check comes first so that iterable Sets stay intact.
            if isinstance(arg, Set):
                if not arg.is_Intersection:
                    return [arg]
                members = arg.args
            elif iterable(arg):
                members = arg
            else:
                raise TypeError("Input must be Sets or iterables of Sets")
            collected = []
            for member in members:
                collected.extend(flatten(member))
            return collected

        flat_sets = flatten(list(args))

        if not flat_sets:
            raise TypeError("Intersection expected at least one argument")

        # Reduce sets using known rules
        if evaluate:
            return Intersection.reduce(flat_sets)

        return Basic.__new__(cls, *list(ordered(flat_sets, Set._infimum_key)))
예제 #12
0
파일: rootoftools.py 프로젝트: Lenqth/sympy
 def _refine_imaginary(cls, complexes):
     """Refine complex intervals until their x-bounds have a definite sign.

     ``complexes`` is a sequence of ``(u, f, k)`` triples, grouped here by
     the middle element ``f`` and processed group by group in ``ordered``
     order. Every interval is refined through ``u._inner_refine()`` until
     ``u.ax*u.bx > 0``; in a group where ``_imag_count_of_factor(f)`` is
     nonzero, that many intervals are allowed to keep straddling zero
     (presumably those containing purely imaginary roots -- confirm against
     ``_imag_count_of_factor``).

     Returns a new list of the refined triples.
     """
     sifted = sift(complexes, lambda c: c[1])
     complexes = []
     for f in ordered(sifted):
         nimag = _imag_count_of_factor(f)
         if nimag == 0:
             # refine until xbounds are neg or pos
             # (``f`` is rebound by the loop to the same value as the group
             # key, since the grouping used the middle element of the triple)
             for u, f, k in sifted[f]:
                 while u.ax*u.bx <= 0:
                     u = u._inner_refine()
                 complexes.append((u, f, k))
         else:
             # refine until all but nimag xbounds are neg or pos
             potential_imag = list(range(len(sifted[f])))
             while True:
                 assert len(potential_imag) > 1
                 # a copy is iterated because entries are removed in the loop
                 for i in list(potential_imag):
                     u, f, k = sifted[f][i]
                     if u.ax*u.bx > 0:
                         # definite sign reached: drop from the candidates
                         potential_imag.remove(i)
                     elif u.ax != u.bx:
                         # nonzero width around zero: refine and write back
                         u = u._inner_refine()
                         sifted[f][i] = u, f, k
                 if len(potential_imag) == nimag:
                     break
             complexes.extend(sifted[f])
     return complexes
예제 #13
0
파일: util.py 프로젝트: msgoff/sympy
def quantity_simplify(expr):
    """Return an equivalent expression in which prefixes are replaced by
    their numerical values and all units sharing a dimension are rewritten
    in terms of one unit taken from the expression.

    Examples
    ========

    >>> from sympy.physics.units.util import quantity_simplify
    >>> from sympy.physics.units.prefixes import kilo
    >>> from sympy.physics.units import foot, inch
    >>> quantity_simplify(kilo*foot*inch)
    250*foot**2/3
    >>> quantity_simplify(foot - 6*inch)
    foot/2
    """

    if expr.is_Atom or not expr.has(Prefix, Quantity):
        return expr

    # Replace every prefix with its scale factor.
    scale_of = {prefix: prefix.scale_factor for prefix in expr.atoms(Prefix)}
    expr = expr.xreplace(scale_of)

    # For each dimension with more than one quantity, express the others
    # through the first quantity in ``ordered`` order.
    groups = sift(expr.atoms(Quantity), lambda quantity: quantity.dimension)
    for same_dimension in groups.values():
        if len(same_dimension) == 1:
            continue
        sorted_units = list(ordered(same_dimension))
        base = sorted_units[0]
        ref = base / base.scale_factor
        substitutions = {unit: ref * unit.scale_factor for unit in sorted_units[1:]}
        expr = expr.xreplace(substitutions)

    return expr
예제 #14
0
def test_ternary_quadratic():
    """Check parametrized solutions of ternary quadratic equations.

    Covers ``diophantine`` on several homogeneous quadratics in x, y, z and
    ``parametrize_ternary_quadratic`` called directly.
    """
    # solution with 3 parameters
    s = diophantine(2*x**2 + y**2 - 2*z**2)
    # name the solution's parameters in canonical order so that the expected
    # tuples below can be written in terms of them
    p, q, r = ordered(S(s).free_symbols)
    assert s == {(
        p**2 - 2*q**2,
        -2*p**2 + 4*p*q - 4*p*r - 4*q**2,
        p**2 - 4*p*q + 2*q**2 - 4*q*r)}
    # solution with Mul in solution
    s = diophantine(x**2 + 2*y**2 - 2*z**2)
    assert s == {(4*p*q, p**2 - 2*q**2, p**2 + 2*q**2)}
    # solution with no Mul in solution
    s = diophantine(2*x**2 + 2*y**2 - z**2)
    assert s == {(2*p**2 - q**2, -2*p**2 + 4*p*q - q**2,
        4*p**2 - 4*p*q + 2*q**2)}
    # reduced form when parametrized
    s = diophantine(3*x**2 + 72*y**2 - 27*z**2)
    assert s == {(24*p**2 - 9*q**2, 6*p*q, 8*p**2 + 3*q**2)}
    # direct calls to the parametrization helper; the first input also has
    # cross terms (x*y and y*z)
    assert parametrize_ternary_quadratic(
        3*x**2 + 2*y**2 - z**2 - 2*x*y + 5*y*z - 7*y*z) == (
        2*p**2 - 2*p*q - q**2, 2*p**2 + 2*p*q - q**2, 2*p**2 -
        2*p*q + 3*q**2)
    assert parametrize_ternary_quadratic(
        124*x**2 - 30*y**2 - 7729*z**2) == (
        -1410*p**2 - 363263*q**2, 2700*p**2 + 30916*p*q -
        695610*q**2, -60*p**2 + 5400*p*q + 15458*q**2)
 def f(rv):
     """Rewrite products of two tangents or two cotangents in ``rv``.

     Non-``Mul`` input, and a ``Mul`` with fewer than two ``tan`` factors
     and fewer than two ``cot`` factors, is returned unchanged. Otherwise
     the tan and cot arguments are consumed in pairs and each pair is
     replaced by an expression in the summed angle; an unpaired factor is
     kept as it is.
     """
     if not rv.is_Mul:
         return rv
     # collect the arguments of tan(...) and cot(...) factors separately;
     # every other factor is kept whole under the None key
     args = {tan: [], cot: [], None: []}
     for a in ordered(Mul.make_args(rv)):
         if a.func in (tan, cot):
             args[a.func].append(a.args[0])
         else:
             args[None].append(a)
     t = args[tan]
     c = args[cot]
     if len(t) < 2 and len(c) < 2:
         return rv
     # from here on ``args`` is the list of factors of the result
     args = args[None]
     while len(t) > 1:
         t1 = t.pop()
         t2 = t.pop()
         # replacement for tan(t1)*tan(t2)
         args.append(1 - (tan(t1) / tan(t1 + t2) + tan(t2) / tan(t1 + t2)))
     if t:
         # odd number of tan factors: the leftover stays unchanged
         args.append(tan(t.pop()))
     while len(c) > 1:
         t1 = c.pop()
         t2 = c.pop()
         # replacement for cot(t1)*cot(t2)
         args.append(1 + cot(t1) * cot(t1 + t2) + cot(t2) * cot(t1 + t2))
     if c:
         # odd number of cot factors: the leftover stays unchanged
         args.append(cot(c.pop()))
     return Mul(*args)
예제 #16
0
파일: boolalg.py 프로젝트: Bercio/sympy
 def __new__(cls, *args, **kwargs):
     """Build an Xor in which arguments occurring in pairs cancel out.

     Falsy numbers/booleans are dropped, truthy ones become ``true``, and
     nested ``Xor`` arguments contribute their own args. Every operand
     toggles its membership in the argument set, so an operand seen an even
     number of times disappears.

     Returns ``false`` for an empty set, the single operand when only one
     remains, ``Not(Xor(...))`` of the rest when ``True`` is among the
     operands, and otherwise ``obj`` with ordered ``_args`` and ``_argset``.
     """
     obj = super(Xor, cls).__new__(cls, *args, **kwargs)
     argset = set()
     for arg in obj._args:
         if isinstance(arg, Number) or arg in (True, False):
             if not arg:
                 continue
             arg = true
         operands = arg.args if isinstance(arg, Xor) else (arg,)
         for operand in operands:
             if operand in argset:
                 argset.remove(operand)
             else:
                 argset.add(operand)

     if not argset:
         return false
     if len(argset) == 1:
         return argset.pop()
     if True in argset:
         argset.remove(True)
         return Not(Xor(*argset))
     obj._args = tuple(ordered(argset))
     obj._argset = frozenset(argset)
     return obj
예제 #17
0
def _mostfunc(lhs, func, X=None):
    """Return the ``func``-type term of ``lhs`` that contains the most
    occurrences of ``func``; e.g. log(log(x)) wins over log(x) if both
    terms appear. ``None`` is returned when there is no candidate term.

    ``func`` can be a function (exp, log, etc...) or any other SymPy object,
    like Pow.

    If ``X`` is not ``None``, only terms involving ``X`` are considered.

    Examples
    ========

    >>> from sympy.solvers.bivariate import _mostfunc
    >>> from sympy.functions.elementary.exponential import exp
    >>> from sympy.abc import x, y
    >>> _mostfunc(exp(x) + exp(exp(x) + 2), exp)
    exp(exp(x) + 2)
    >>> _mostfunc(exp(x) + exp(exp(y) + 2), exp)
    exp(exp(y) + 2)
    >>> _mostfunc(exp(x) + exp(exp(y) + 2), exp, x)
    exp(x)
    >>> _mostfunc(x, exp, x) is None
    True
    >>> _mostfunc(exp(x) + exp(x*y), exp, x)
    exp(x)
    """
    def _involves_X(term):
        # No X given: every term qualifies. A Symbol must be a free symbol
        # of the term; anything else is looked up with ``has``.
        if not X:
            return True
        if X.is_Symbol:
            return X in term.free_symbols
        return term.has(X)

    candidates = [term for term in lhs.atoms(func) if _involves_X(term)]
    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0]
    # Sort first so that ties in the count are resolved deterministically.
    return max(list(ordered(candidates)), key=lambda term: term.count(func))
예제 #18
0
파일: bivariate.py 프로젝트: AALEKH/sympy
def _mostfunc(lhs, func, X=None):
    """Returns the term in lhs which contains the most of the
    func-type things e.g. log(log(x)) wins over log(x) if both terms appear.

    ``func`` can be a function (exp, log, etc...) or any other SymPy object,
    like Pow.

    If ``X`` is given, only terms involving ``X`` are considered: a Symbol
    must be among the term's free symbols, any other object is tested with
    ``has``. ``None`` is returned when no term qualifies.

    Examples
    ========

    >>> from sympy.solvers.bivariate import _mostfunc
    >>> from sympy.functions.elementary.exponential import exp
    >>> from sympy.abc import x, y
    >>> _mostfunc(exp(x) + exp(exp(x) + 2), exp)
    exp(exp(x) + 2)
    >>> _mostfunc(exp(x) + exp(exp(y) + 2), exp)
    exp(exp(y) + 2)
    >>> _mostfunc(exp(x) + exp(exp(y) + 2), exp, x)
    exp(x)
    >>> _mostfunc(x, exp, x) is None
    True
    >>> _mostfunc(exp(x) + exp(x*y), exp, x)
    exp(x)
    """
    # candidate terms: all ``func`` atoms, optionally restricted to those
    # that involve X
    fterms = [tmp for tmp in lhs.atoms(func) if (not X or
        X.is_Symbol and X in tmp.free_symbols or
        not X.is_Symbol and tmp.has(X))]
    if len(fterms) == 1:
        return fterms[0]
    elif fterms:
        # order first so that a tie in the count is broken deterministically
        return max(list(ordered(fterms)), key=lambda x: x.count(func))
    return None
    def f(rv):
        """Expand ``sin`` or ``cos`` of a sum using the angle-addition rules.

        Anything that is not a ``sin``/``cos`` whose argument is an ``Add``
        is returned unchanged. One term ``a`` is split off the sum (after
        ordering the terms when ``first`` is true) and the remainder ``b``
        is expanded recursively through ``TR10`` while it is still a sum.
        """
        if rv.func not in (cos, sin):
            return rv

        arg = rv.args[0]
        if not arg.is_Add:
            return rv

        terms = list(ordered(arg.args)) if first else list(arg.args)
        a = terms.pop()
        b = Add._from_args(terms)

        if b.is_Add:
            # the remainder is itself a sum: keep expanding it
            cos_b = TR10(cos(b), first=False)
            sin_b = TR10(sin(b), first=False)
        else:
            cos_b = cos(b)
            sin_b = sin(b)

        if rv.func == sin:
            return sin(a) * cos_b + cos(a) * sin_b
        return cos(a) * cos_b - sin(a) * sin_b
예제 #20
0
파일: rv.py 프로젝트: akshayah3/sympy
def random_symbols(expr):
    """
    Return the RandomSymbols found in ``expr`` as a sorted list.

    An empty list is returned when the lookup raises ``AttributeError``
    (for example when ``expr`` has no ``atoms`` method).
    """
    try:
        found = ordered(expr.atoms(RandomSymbol), warn=True)
        return list(found)
    except AttributeError:
        return []
예제 #21
0
파일: rv.py 프로젝트: akshayah3/sympy
def random_symbols(expr):
    """
    Collect every RandomSymbol inside ``expr`` and return them sorted.

    If the lookup raises ``AttributeError`` (e.g. ``expr`` lacks ``atoms``),
    an empty list is returned instead.
    """
    try:
        symbols_in_order = ordered(expr.atoms(RandomSymbol), warn=True)
        return list(symbols_in_order)
    except AttributeError:
        return []
예제 #22
0
파일: basic.py 프로젝트: bladewang/sympy
 def _preorder_traversal(self, node, keys):
     """Yield ``node`` and then, recursively, everything below it.

     After ``node`` has been yielded, a set ``self._skip_flag`` is cleared
     and the children of ``node`` are not visited. The children of a
     ``Basic`` node are its ``args``: sorted with ``ordered`` when ``keys``
     is truthy (``keys`` itself is passed as the sort keys unless it equals
     ``True``), otherwise in their stored order. Any other iterable node is
     walked item by item.
     """
     yield node
     if self._skip_flag:
         self._skip_flag = False
         return

     if isinstance(node, Basic):
         children = node.args
         if keys:
             children = (ordered(children, keys, default=False)
                         if keys != True else ordered(children))
     elif iterable(node):
         children = node
     else:
         return

     for child in children:
         for subtree in self._preorder_traversal(child, keys):
             yield subtree
예제 #23
0
 def _preorder_traversal(self, node, keys):
     """Generate ``node`` followed by the pre-order walk of its children.

     If ``self._skip_flag`` is set once ``node`` has been yielded, the flag
     is reset and the subtree below ``node`` is skipped. For a ``Basic``
     node the children are ``node.args``, passed through ``ordered`` when
     ``keys`` is truthy (using ``keys`` as the sort keys unless it equals
     ``True``). For any other iterable the items themselves are walked.
     """
     yield node
     if self._skip_flag:
         self._skip_flag = False
         return

     if isinstance(node, Basic):
         subnodes = node.args
         if keys:
             if keys != True:
                 subnodes = ordered(subnodes, keys, default=False)
             else:
                 subnodes = ordered(subnodes)
     elif iterable(node):
         subnodes = node
     else:
         # leaf that is neither Basic nor iterable: nothing below it
         return

     for subnode in subnodes:
         for visited in self._preorder_traversal(subnode, keys):
             yield visited
예제 #24
0
def _construct_algebraic(coeffs, opt):
    """We know that coefficients are algebraic so construct the extension.

    Each coefficient is parsed into a small expression tree, the
    non-rational leaves are collected as extension elements, a primitive
    element for them is computed, and the trees are then evaluated in the
    resulting algebraic field.

    Returns ``(domain, result)`` where ``result`` holds one element of
    ``domain`` per entry of ``coeffs``. ``opt`` is not used in this body.
    """
    from sympy.polys.numberfields import primitive_element

    # filled by build_trees (through the closure) with every leaf that is
    # not Rational, Add or Mul
    exts = set()

    def build_trees(args):
        # Tagged tuples: 'Q' rational, '+' sum, '*' product, 'e' extension
        # element.
        trees = []
        for a in args:
            if a.is_Rational:
                tree = ('Q', QQ.from_sympy(a))
            elif a.is_Add:
                tree = ('+', build_trees(a.args))
            elif a.is_Mul:
                tree = ('*', build_trees(a.args))
            else:
                tree = ('e', a)
                exts.add(a)
            trees.append(tree)
        return trees

    trees = build_trees(coeffs)
    # fix an order for the collected elements; this must come after
    # build_trees has finished adding to the set
    exts = list(ordered(exts))

    g, span, H = primitive_element(exts, ex=True, polys=True)
    # the primitive element as the combination of the extension elements
    # with the coefficients in ``span``
    root = sum([ s*ext for s, ext in zip(span, exts) ])

    domain, g = QQ.algebraic_field((g, root)), g.rep.rep

    # each extension element as an element of the field, built from the
    # matching coefficient list in ``H``
    exts_dom = [domain.dtype.from_list(h, g, QQ) for h in H]
    exts_map = dict(zip(exts, exts_dom))

    def convert_tree(tree):
        # Evaluate a tagged tree from build_trees inside ``domain``.
        op, args = tree
        if op == 'Q':
            return domain.dtype.from_list([args], g, QQ)
        elif op == '+':
            return sum((convert_tree(a) for a in args), domain.zero)
        elif op == '*':
            # product of the converted factors, accumulated in a loop
            t = convert_tree(args[0])
            for a in args[1:]:
                t *= convert_tree(a)
            return t
        elif op == 'e':
            return exts_map[args]
        else:
            raise RuntimeError

    result = [convert_tree(tree) for tree in trees]

    return domain, result
예제 #25
0
    def canonical(self):
        """
        Return a canonical form of the relational by putting a
        number on the rhs, canonically removing a sign or else
        ordering the args canonically. No other simplification is
        attempted.

        If either side has ``kind`` equal to ``BooleanKind`` the relational
        is returned after the reordering step, without the sign step.
        """
        args = self.arguments
        r = self
        if r.rhs.is_number:
            if any(side is S.NaN for side in (r.lhs, r.rhs)):
                # never reorder when NaN is involved
                pass
            elif r.rhs.is_Number and r.lhs.is_Number and r.lhs > r.rhs:
                # two Numbers: the smaller one goes on the left
                r = r.reversed
        elif r.lhs.is_number:
            # only the lhs is a number: move it to the rhs
            r = r.reversed
        elif tuple(ordered(args)) != args:
            # neither side is a number: use the canonical order of the args
            r = r.reversed

        # either side may lack the method, hence getattr with a default
        LHS_CEMS = getattr(r.lhs, 'could_extract_minus_sign', None)
        RHS_CEMS = getattr(r.rhs, 'could_extract_minus_sign', None)

        if any(side.kind is BooleanKind for side in r.arguments):
            return r

        # Check if first value has negative sign
        if LHS_CEMS and LHS_CEMS():
            return r.reversedsign
        elif not r.rhs.is_number and RHS_CEMS and RHS_CEMS():
            # Right hand side has a minus, but not lhs.
            # How does the expression with reversed signs behave?
            # This is so that expressions of the type
            # Eq(x, -y) and Eq(-x, y)
            # have the same canonical representation
            expr1, _ = ordered([r.lhs, -r.rhs])
            if expr1 != r.lhs:
                return r.reversed.reversedsign

        return r
    def f(rv):
        """Rewrite products of ``sin``/``cos`` factors in ``rv`` as sums.

        ``rv`` is returned unchanged unless it is a ``Mul``, or a ``Pow``
        whose base is a ``sin``/``cos`` and whose exponent is an integer or
        whose base is positive. When the enclosing ``first`` flag is set,
        numerator and denominator are processed separately; otherwise the
        sin/cos arguments are paired off and each pair is replaced by a sum,
        and the expanded result is passed through ``TR8`` again.
        """
        if not (rv.is_Mul or rv.is_Pow and rv.base.func in (cos, sin) and
                (rv.exp.is_integer or rv.base.is_positive)):
            return rv

        if first:
            # handle numerator and denominator independently
            n, d = [expand_mul(i) for i in rv.as_numer_denom()]
            newn = TR8(n, first=False)
            newd = TR8(d, first=False)
            if newn != n or newd != d:
                rv = gcd_terms(newn / newd)
                if rv.is_Mul and rv.args[0].is_Rational and \
                        len(rv.args) == 2 and rv.args[1].is_Add:
                    # Rational*Add: rebuild the product from as_coeff_Mul()
                    rv = Mul(*rv.as_coeff_Mul())
            return rv

        # collect the arguments of cos(...) and sin(...) factors; all other
        # factors are kept whole under the None key
        args = {cos: [], sin: [], None: []}
        for a in ordered(Mul.make_args(rv)):
            if a.func in (cos, sin):
                args[a.func].append(a.args[0])
            elif (a.is_Pow and a.exp.is_Integer and a.exp > 0 and \
                  a.base.func in (cos, sin)):
                # XXX this is ok but pathological expression could be handled
                # more efficiently as in TRmorrie
                # a positive integer power counts as that many equal factors
                args[a.base.func].extend([a.base.args[0]] * a.exp)
            else:
                args[None].append(a)
        c = args[cos]
        s = args[sin]
        # nothing to combine unless there are at least two trig factors
        if not (c and s or len(c) > 1 or len(s) > 1):
            return rv

        # from here on ``args`` is the list of factors of the result
        args = args[None]
        n = min(len(c), len(s))
        for i in range(n):
            # replacement for sin(a1)*cos(a2)
            a1 = s.pop()
            a2 = c.pop()
            args.append((sin(a1 + a2) + sin(a1 - a2)) / 2)
        while len(c) > 1:
            # replacement for cos(a1)*cos(a2)
            a1 = c.pop()
            a2 = c.pop()
            args.append((cos(a1 + a2) + cos(a1 - a2)) / 2)
        if c:
            args.append(cos(c.pop()))
        while len(s) > 1:
            # replacement for sin(a1)*sin(a2)
            a1 = s.pop()
            a2 = s.pop()
            args.append((-cos(a1 + a2) + cos(a1 - a2)) / 2)
        if s:
            args.append(sin(s.pop()))
        return TR8(expand_mul(Mul(*args)))
예제 #27
0
    def __new__(cls, *args, **kwargs):
        """Create a finite set from the sympified, de-duplicated arguments.

        With ``evaluate`` true (keyword, defaulting to
        ``global_evaluate[0]``) and no arguments, ``EmptySet()`` is returned.
        Otherwise the distinct elements are stored in ``ordered`` order as
        the object's args and as a frozenset in ``_elements``.
        """
        evaluate = kwargs.get('evaluate', global_evaluate[0])
        elements = list(map(sympify, args))
        if evaluate and not elements:
            return EmptySet()

        unique_sorted = list(ordered(frozenset(elements)))
        obj = Basic.__new__(cls, *unique_sorted)
        obj._elements = frozenset(unique_sorted)
        return obj
예제 #28
0
파일: sets.py 프로젝트: alphaitis/sympy
    def __new__(cls, *args, **kwargs):
        """Create a finite set from the sympified, de-duplicated arguments.

        With ``evaluate`` true (keyword, defaulting to
        ``global_evaluate[0]``) and no arguments, ``EmptySet()`` is returned.
        Otherwise the distinct elements, ordered with ``Set._infimum_key``,
        become the object's args and are also kept as a frozenset in
        ``_elements``.
        """
        evaluate = kwargs.get('evaluate', global_evaluate[0])
        members = list(map(sympify, args))
        if evaluate and not members:
            return EmptySet()

        unique_sorted = list(ordered(frozenset(members), Set._infimum_key))
        obj = Basic.__new__(cls, *unique_sorted)
        obj._elements = frozenset(unique_sorted)
        return obj
        def do(rv, first=True):
            """Turn a sum of two compatible trig terms into a product.

            Non-``Add`` input is returned unchanged. For a sum with a number
            of terms other than two, every pair of terms is tried and each
            successful combination replaces the pair. For a two-term sum the
            rewrite is applied when ``trig_split`` finds a match.

            ``first`` is not read in this body.
            """
            if not rv.is_Add:
                return rv

            args = list(ordered(rv.args))
            if len(args) != 2:
                hit = False
                for i in range(len(args)):
                    ai = args[i]
                    if ai is None:
                        continue
                    for j in range(i + 1, len(args)):
                        aj = args[j]
                        if aj is None:
                            continue
                        was = ai + aj
                        new = do(was)
                        if new != was:
                            args[i] = new  # update in place
                            # mark the consumed partner so it is skipped
                            args[j] = None
                            hit = True
                            break  # go to next i
                if hit:
                    # rebuild the sum without the None placeholders
                    rv = Add(*[_f for _f in args if _f])
                    if rv.is_Add:
                        rv = do(rv)

                return rv

            # two-arg Add
            split = trig_split(*args)
            if not split:
                return rv
            # NOTE(review): trig_split is defined elsewhere; from the usage
            # below, gcd is a common factor, n1/n2 the coefficients of the
            # two terms, a/b their angles and iscos tells cos from sin --
            # confirm against trig_split
            gcd, n1, n2, a, b, iscos = split

            # application of rule if possible
            if iscos:
                if n1 == n2:
                    return gcd * n1 * 2 * cos((a + b) / 2) * cos((a - b) / 2)
                if n1 < 0:
                    a, b = b, a
                return -2 * gcd * sin((a + b) / 2) * sin((a - b) / 2)
            else:
                if n1 == n2:
                    return gcd * n1 * 2 * sin((a + b) / 2) * cos((a - b) / 2)
                if n1 < 0:
                    a, b = b, a
                return 2 * gcd * cos((a + b) / 2) * sin((a - b) / 2)
예제 #30
0
파일: boolalg.py 프로젝트: yang603/sympy
 def __new__(cls, *args, **kwargs):
     """Build an Xor, cancelling repeated and complementary arguments.

     Falsy numbers/booleans are dropped, truthy ones become ``true``, and
     nested ``Xor`` arguments contribute their own args. Every operand
     toggles its membership in the argument set, so operands seen an even
     number of times vanish. Relational operands are then compared through
     their canonical forms: a pair with equal canonical forms is removed,
     and a pair in which one is the canonical negation of the other is
     removed while toggling ``true`` in the set.

     Returns ``false`` for an empty set, the single operand when only one
     remains, ``Not(Xor(...))`` of the rest when ``True`` is among the
     operands, and otherwise ``obj`` with ordered ``_args`` and ``_argset``.
     """
     argset = set()
     obj = super(Xor, cls).__new__(cls, *args, **kwargs)
     for arg in obj._args:
         if isinstance(arg, Number) or arg in (True, False):
             if arg:
                 arg = true
             else:
                 continue
         if isinstance(arg, Xor):
             for a in arg.args:
                 if a in argset:
                     argset.remove(a)
                 else:
                     argset.add(a)
         elif arg in argset:
             argset.remove(arg)
         else:
             argset.add(arg)
     # (relational, its canonical form, canonical form of its negation)
     rel = [(r, r.canonical, (~r).canonical) for r in argset
            if r.is_Relational]
     odd = False  # is the number of complementary pairs odd?
     remove = []
     for i, (r, c, nc) in enumerate(rel):
         for j in range(i + 1, len(rel)):
             rj, cj = rel[j][:2]
             if cj == nc:
                 # r and rj are complementary. Toggle with ``not``: ``~`` on
                 # a bool is integer inversion and only worked here because
                 # the value cycled False -> -1 -> 0.
                 odd = not odd
                 break
             elif cj == c:
                 # r and rj are equivalent and cancel each other
                 break
         else:
             continue
         remove.append((r, rj))
     if odd:
         if true in argset:
             argset.remove(true)
         else:
             argset.add(true)
     for a, b in remove:
         argset.remove(a)
         argset.remove(b)
     if len(argset) == 0:
         return false
     elif len(argset) == 1:
         return argset.pop()
     elif True in argset:
         argset.remove(True)
         return Not(Xor(*argset))
     else:
         obj._args = tuple(ordered(argset))
         obj._argset = frozenset(argset)
         return obj
        def do(rv, first=True):
            """Collapse a sum of two trig products into one ``sin``/``cos``.

            Non-``Add`` input is returned unchanged. For a sum with a number
            of terms other than two, every pair of terms is tried and each
            successful combination replaces the pair. For a two-term sum the
            rewrite is applied when ``trig_split(..., two=True)`` finds a
            match.

            ``first`` is not read in this body.
            """

            if not rv.is_Add:
                return rv

            args = list(ordered(rv.args))
            if len(args) != 2:
                hit = False
                for i in range(len(args)):
                    ai = args[i]
                    if ai is None:
                        continue
                    for j in range(i + 1, len(args)):
                        aj = args[j]
                        if aj is None:
                            continue
                        was = ai + aj
                        new = do(was)
                        if new != was:
                            args[i] = new  # update in place
                            # mark the consumed partner so it is skipped
                            args[j] = None
                            hit = True
                            break  # go to next i
                if hit:
                    # rebuild the sum without the None placeholders
                    rv = Add(*[_f for _f in args if _f])
                    if rv.is_Add:
                        rv = do(rv)

                return rv

            # two-arg Add
            split = trig_split(*args, two=True)
            if not split:
                return rv
            # NOTE(review): trig_split is defined elsewhere; from the usage
            # below, gcd is a common factor, n1/n2 the coefficients of the
            # two terms, a/b the angles and ``same`` tells cos*cos/sin*sin
            # apart from cos*sin -- confirm against trig_split
            gcd, n1, n2, a, b, same = split

            # identify and get c1 to be cos then apply rule if possible
            if same:  # coscos, sinsin
                gcd = n1 * gcd
                if n1 == n2:
                    return gcd * cos(a - b)
                return gcd * cos(a + b)
            else:  # cossin, cossin
                gcd = n1 * gcd
                if n1 == n2:
                    return gcd * sin(a + b)
                return gcd * sin(b - a)
예제 #32
0
    def _get_complexes(cls, factors, use_cache=True):
        """Compute complex root isolating intervals for a list of factors.

        ``factors`` is iterated (through ``ordered``) as ``(factor, k)``
        pairs.  For each factor the intervals come from ``_complexes_cache``
        when ``use_cache`` is true and the factor is present there;
        otherwise they are computed with ``cls._get_complexes_sqf``.  The
        collected ``(interval, factor, k)`` triples are passed through
        ``cls._complexes_sorted`` before being returned.
        """
        collected = []

        for poly_factor, k in ordered(factors):
            if use_cache:
                try:
                    intervals = _complexes_cache[poly_factor]
                except KeyError:
                    # not cached: compute the intervals
                    intervals = cls._get_complexes_sqf(poly_factor, use_cache)
            else:
                intervals = cls._get_complexes_sqf(poly_factor, use_cache)
            collected.extend([(root, poly_factor, k) for root in intervals])

        return cls._complexes_sorted(collected)
예제 #33
0
 def __new__(cls, *args, **kwargs):
     """Build an Xor, cancelling arguments that appear an even number of times.

     The arguments of the freshly created object are folded into a set in
     which adding an element that is already present removes it.  Numbers
     and Python booleans are normalised first: truthy ones become ``true``
     and falsy ones are dropped.  Arguments of nested ``Xor`` instances are
     folded in individually.

     Relational arguments are then compared pairwise through their
     canonical forms.  A pair with equal canonical forms is removed; a pair
     in which one canonical form equals the canonical form of the other's
     negation is removed too and toggles the parity flag.  When the number
     of such complementary pairs is odd, ``true`` is toggled in the set.

     Returns ``false`` for an empty set, the single remaining argument when
     only one is left, ``Not(Xor(*rest))`` when ``True`` is among the
     remaining arguments, and otherwise ``obj`` with ``_args`` and
     ``_argset`` set from the reduced set.
     """
     argset = set()
     obj = super(Xor, cls).__new__(cls, *args, **kwargs)
     for arg in obj._args:
         if isinstance(arg, Number) or arg in (True, False):
             if arg:
                 arg = true
             else:
                 continue
         if isinstance(arg, Xor):
             for a in arg.args:
                 if a in argset:
                     argset.remove(a)
                 else:
                     argset.add(a)
         elif arg in argset:
             argset.remove(arg)
         else:
             argset.add(arg)
     rel = [(r, r.canonical, (~r).canonical) for r in argset if r.is_Relational]
     odd = False  # is number of complimentary pairs odd? start 0 -> False
     remove = []
     for i, (r, c, nc) in enumerate(rel):
         for j in range(i + 1, len(rel)):
             rj, cj = rel[j][:2]
             if cj == nc:
                 # Logical negation keeps ``odd`` a real bool; ``~`` on a
                 # bool is deprecated since Python 3.12 and yields -1/0.
                 odd = not odd
                 break
             elif cj == c:
                 break
         else:
             # no partner found for r
             continue
         remove.append((r, rj))
     if odd:
         if true in argset:
             argset.remove(true)
         else:
             argset.add(true)
     for a, b in remove:
         argset.remove(a)
         argset.remove(b)
     if len(argset) == 0:
         return false
     elif len(argset) == 1:
         return argset.pop()
     elif True in argset:
         argset.remove(True)
         return Not(Xor(*argset))
     else:
         obj._args = tuple(ordered(argset))
         obj._argset = frozenset(argset)
         return obj
예제 #34
0
파일: rootoftools.py 프로젝트: Lenqth/sympy
    def _get_complexes(cls, factors, use_cache=True):
        """Compute complex root isolating intervals for a list of factors.

        Each ``(factor, k)`` pair of ``factors`` is visited in ``ordered``
        order.  The intervals for a factor are read from
        ``_complexes_cache`` if ``use_cache`` is true and the factor is a
        key of the cache, and are obtained from ``cls._get_complexes_sqf``
        otherwise.  The result is the list of ``(interval, factor, k)``
        triples after ``cls._complexes_sorted`` has been applied.
        """
        triples = []

        for sqf_factor, k in ordered(factors):
            if not use_cache:
                isolating = cls._get_complexes_sqf(sqf_factor, use_cache)
            else:
                try:
                    isolating = _complexes_cache[sqf_factor]
                except KeyError:
                    # cache miss: fall back to computing the intervals
                    isolating = cls._get_complexes_sqf(sqf_factor, use_cache)
            triples.extend([(interval, sqf_factor, k) for interval in isolating])

        return cls._complexes_sorted(triples)
예제 #35
0
파일: sets.py 프로젝트: axiom24/sympy
    def __new__(cls, *args, **kwargs):
        """Create the set from its elements, stored in ``ordered`` order.

        The ``evaluate`` keyword (default ``global_evaluate[0]``) enables two
        conveniences: a single iterable argument is unpacked into the
        elements, and an empty element list yields ``EmptySet()``.  In every
        case the elements are sympified and de-duplicated through a
        ``frozenset``, which is also stored as ``_elements``.
        """
        evaluate = kwargs.get('evaluate', global_evaluate[0])

        if evaluate and len(args) == 1 and iterable(args[0]):
            args = args[0]

        elements = [sympify(item) for item in args]

        if evaluate and not elements:
            return EmptySet()

        elements = list(ordered(frozenset(elements)))
        obj = Basic.__new__(cls, *elements)
        obj._elements = frozenset(elements)
        return obj
예제 #36
0
def root_factors(f, *gens, **args):
    """
    Return the factors of a univariate polynomial built from its roots.

    A linear factor ``x - r`` is produced for every root ``r`` found by
    ``roots`` (repeated according to its multiplicity).  If the roots found
    do not account for the full degree, the remaining quotient is appended
    as a final factor.  The ``filter`` keyword is forwarded to ``roots``;
    all other keywords go to ``Poly``.  Factors are returned as expressions
    unless ``f`` was given as a ``Poly``.

    Raises ``ValueError`` for multivariate input.

    Examples
    ========

    >>> from sympy.abc import x, y
    >>> from sympy.polys.polyroots import root_factors

    >>> root_factors(x**2 - y, x)
    [x - sqrt(y), x + sqrt(y)]

    """
    options = dict(args)
    root_filter = options.pop('filter', None)

    poly = Poly(f, *gens, **options)

    if not poly.is_Poly:
        return [f]

    if poly.is_multivariate:
        raise ValueError('multivariate polynomials are not supported')

    gen = poly.gens[0]

    zeros = roots(poly, filter=root_filter)

    if zeros:
        factors = []
        found = 0

        for root, mult in ordered(zeros.items()):
            factors.extend([Poly(gen - root, gen)]*mult)
            found += mult

        if found < poly.degree():
            # not every root was found: keep the leftover as one factor
            product = reduce(lambda p, q: p*q, factors)
            factors.append(poly.quo(product))
    else:
        factors = [poly]

    if not isinstance(f, Poly):
        factors = [factor.as_expr() for factor in factors]

    return factors
    def f(rv):
        """Expand ``tan`` of a sum with the tangent addition formula.

        Anything other than ``tan(Add)`` is returned unchanged.  The last
        term of the argument is split off and the rest is expanded
        recursively through ``TR12(..., first=False)`` while it is still an
        Add.  ``first`` is read from the enclosing scope: the terms are put
        in ``ordered`` order only when it is true.
        """
        if not rv.func == tan:
            return rv

        arg = rv.args[0]
        if not arg.is_Add:
            return rv

        terms = list(ordered(arg.args)) if first else list(arg.args)
        a = terms.pop()
        rest = Add._from_args(terms)
        tan_rest = TR12(tan(rest), first=False) if rest.is_Add else tan(rest)
        return (tan(a) + tan_rest) / (1 - tan(a) * tan_rest)
예제 #38
0
    def _from_args(cls, args, is_commutative=None):
        """Create a new instance from already-processed args.

        An empty ``args`` gives ``cls.identity`` and a single argument is
        returned as is.  Otherwise the args are sorted with ``ordered``
        (using ``default_sort_key``) and passed straight to the parent
        constructor.

        ``is_commutative`` is accepted for interface compatibility but is
        not used by this implementation.
        """
        count = len(args)
        if count == 0:
            return cls.identity
        if count == 1:
            return args[0]

        sorted_args = tuple(ordered(args, default_sort_key))
        return super(AssocOp, cls).__new__(cls, *sorted_args)
예제 #39
0
def quantity_simplify(expr):
    """Return an equivalent expression in which prefixes are replaced
    with numerical values and products of related quantities are
    unified in a canonical manner.

    Atoms are returned unchanged and a ``Prefix`` is replaced by its
    ``scale_factor``.  Any other expression is rebuilt from its simplified
    arguments; if the result is a ``Mul`` containing quantities, the
    quantities that share a dimension are rewritten in terms of one
    reference quantity from that group.

    Examples
    ========

    >>> from sympy.physics.units.util import quantity_simplify
    >>> from sympy.physics.units.prefixes import kilo
    >>> from sympy.physics.units import foot, inch
    >>> quantity_simplify(kilo*foot*inch)
    250*foot**2/3
    """

    if expr.is_Atom:
        return expr

    if isinstance(expr, Prefix):
        return expr.scale_factor

    expr = expr.func(*map(quantity_simplify, expr.args))

    if not (expr.is_Mul and expr.has(Quantity)):
        return expr

    base_exp_pairs = [arg.as_base_exp() for arg in expr.args]
    quantity_pairs, other_pairs = sift(
        base_exp_pairs,
        lambda pair: isinstance(pair[0], Quantity),
        binary=True)
    coeff = Mul.fromiter([Pow(b, e, evaluate=False) for b, e in other_pairs])
    by_dimension = sift(quantity_pairs, lambda pair: pair[0].dimension)

    unified = []
    for group in by_dimension.values():
        if len(group) == 1:
            unified.append(group[0][0]**group[0][1])
            continue
        # the reference quantity is simply the first one of the group
        # once the group has been put in order
        group = list(ordered(group))
        leader = group[0][0]
        ref = leader / leader.scale_factor
        unified.append(
            Mul.fromiter(((ref * b.scale_factor)**p for b, p in group)))
    return coeff * Mul.fromiter(unified)
예제 #40
0
파일: boolalg.py 프로젝트: Bercio/sympy
def _finger(eq):
    """
    Build a 5-item fingerprint for every free symbol of ``eq`` and group
    the symbols by fingerprint.  The five counters are:
    [
    # of times it appeared as a Symbol,
    # of times it appeared as a Not(symbol),
    # of times it appeared as a Symbol in an And or Or,
    # of times it appeared as a Not(Symbol) in an And or Or,
    sum of the number of arguments with which it appeared,
    counting Symbol as 1 and Not(Symbol) as 2
    ]

    The result maps each fingerprint tuple to the list of symbols that
    have it.

    >>> from sympy.logic.boolalg import _finger as finger
    >>> from sympy import And, Or, Not
    >>> from sympy.abc import a, b, x, y
    >>> eq = Or(And(Not(y), a), And(Not(y), b), And(x, y))
    >>> dict(finger(eq))
    {(0, 0, 1, 0, 2): [x], (0, 0, 1, 0, 3): [a, b], (0, 0, 1, 2, 8): [y]}

    So y and x have unique fingerprints, but a and b do not.
    """
    counts = {sym: [0] * 5 for sym in eq.free_symbols}
    for arg in eq.args:
        if arg.is_Symbol:
            counts[arg][0] += 1
        elif arg.is_Not:
            counts[arg.args[0]][1] += 1
        else:
            # compound argument: a Not counts as two, anything else as one
            weight = len(arg.args) + sum(sub.func is Not for sub in arg.args)
            for sub in arg.args:
                if sub.is_Symbol:
                    slot, sym = 2, sub
                else:
                    slot, sym = 3, sub.args[0]
                counts[sym][slot] += 1
                counts[sym][-1] += weight
    by_fingerprint = defaultdict(list)
    for sym, tally in ordered(iter(counts.items())):
        by_fingerprint[tuple(tally)].append(sym)
    return by_fingerprint
예제 #41
0
def _finger(eq):
    """
    Group the free symbols of ``eq`` by a 5-item fingerprint made of:
    [
    # of times it appeared as a Symbol,
    # of times it appeared as a Not(symbol),
    # of times it appeared as a Symbol in an And or Or,
    # of times it appeared as a Not(Symbol) in an And or Or,
    sum of the number of arguments with which it appeared,
    counting Symbol as 1 and Not(Symbol) as 2
    ]

    A dictionary from fingerprint tuple to list of symbols is returned.

    >>> from sympy.logic.boolalg import _finger as finger
    >>> from sympy import And, Or, Not
    >>> from sympy.abc import a, b, x, y
    >>> eq = Or(And(Not(y), a), And(Not(y), b), And(x, y))
    >>> dict(finger(eq))
    {(0, 0, 1, 0, 2): [x], (0, 0, 1, 0, 3): [a, b], (0, 0, 1, 2, 8): [y]}

    So y and x have unique fingerprints, but a and b do not.
    """
    tallies = dict((symbol, [0] * 5) for symbol in eq.free_symbols)
    for term in eq.args:
        if term.is_Symbol:
            tallies[term][0] += 1
            continue
        if term.is_Not:
            tallies[term.args[0]][1] += 1
            continue
        # compound term: every Not among its args adds one extra
        size = len(term.args) + sum(part.func is Not for part in term.args)
        for part in term.args:
            if part.is_Symbol:
                entry = tallies[part]
                entry[2] += 1
            else:
                entry = tallies[part.args[0]]
                entry[3] += 1
            entry[-1] += size
    groups = defaultdict(list)
    for symbol, tally in ordered(iter(tallies.items())):
        groups[tuple(tally)].append(symbol)
    return groups
예제 #42
0
    def __new__(cls, *args, **options):
        """Create the operation from the set of its filtered arguments.

        The arguments are sympified lazily and passed through
        ``cls._new_args_filter``.  If the filter raises ``ShortCircuit`` the
        sympified ``cls.zero`` is returned; no remaining arguments give the
        sympified ``cls.identity``; a single remaining argument is returned
        itself.  Otherwise the instance is built from the arguments in
        ``ordered`` order and the frozenset is stored as ``_argset``.
        """
        lazy_args = (_sympify_(arg) for arg in args)

        try:
            # /!\ lazy_args is a generator and _new_args_filter must treat
            # it as one; keeping it lazy lets a short-circuit happen
            # without sympifying every value
            unique_args = frozenset(cls._new_args_filter(lazy_args))
        except ShortCircuit:
            return sympify(cls.zero)

        if not unique_args:
            return sympify(cls.identity)
        if len(unique_args) == 1:
            return next(iter(unique_args))

        # XXX in almost every other case for __new__, *_args is
        # passed along, but the expectation here is for _args
        obj = super(AssocOp, cls).__new__(cls, *ordered(unique_args))
        obj._argset = unique_args
        return obj
예제 #43
0
파일: ellipse.py 프로젝트: aprasanna/sympy
    def intersection(self, o):
        """The intersection of this ellipse and another geometrical entity
        `o`.

        Parameters
        ==========

        o : GeometryEntity

        Returns
        =======

        intersection : list of GeometryEntity objects
            When ``o`` is an Ellipse or Line2D that compares equal to this
            ellipse, the ellipse itself is returned instead of a list.

        Raises
        ======

        TypeError
            If ``o`` is a three dimensional linear entity or any other
            entity that is not handled.

        Notes
        =====

        Currently supports intersections with Point, Line, Segment, Ray,
        Circle and Ellipse types.  A Polygon is handled by delegating to
        its own ``intersection`` method.

        See Also
        ========

        sympy.geometry.entity.GeometryEntity

        Examples
        ========

        >>> from sympy import Ellipse, Point, Line, sqrt
        >>> e = Ellipse(Point(0, 0), 5, 7)
        >>> e.intersection(Point(0, 0))
        []
        >>> e.intersection(Point(5, 0))
        [Point2D(5, 0)]
        >>> e.intersection(Line(Point(0,0), Point(0, 1)))
        [Point2D(0, -7), Point2D(0, 7)]
        >>> e.intersection(Line(Point(5,0), Point(5, 1)))
        [Point2D(5, 0)]
        >>> e.intersection(Line(Point(6,0), Point(6, 1)))
        []
        >>> e = Ellipse(Point(-1, 0), 4, 3)
        >>> e.intersection(Ellipse(Point(1, 0), 4, 3))
        [Point2D(0, -3*sqrt(15)/4), Point2D(0, 3*sqrt(15)/4)]
        >>> e.intersection(Ellipse(Point(5, 0), 4, 3))
        [Point2D(2, -3*sqrt(7)/4), Point2D(2, 3*sqrt(7)/4)]
        >>> e.intersection(Ellipse(Point(100500, 0), 4, 3))
        []
        >>> e.intersection(Ellipse(Point(0, 0), 3, 4))
        [Point2D(3, 0), Point2D(-363/175, -48*sqrt(111)/175), Point2D(-363/175, 48*sqrt(111)/175)]
        >>> e.intersection(Ellipse(Point(-1, 0), 3, 4))
        [Point2D(-17/5, -12/5), Point2D(-17/5, 12/5), Point2D(7/5, -12/5), Point2D(7/5, 12/5)]
        """
        # TODO: Replace solve with nonlinsolve, when nonlinsolve will be able to solve in real domain
        x = Dummy('x', real=True)
        y = Dummy('y', real=True)

        if isinstance(o, Point):
            return [o] if o in self else []

        if isinstance(o, (Segment2D, Ray2D)):
            # intersect with the full line through o, then keep only the
            # solutions that actually lie on o
            ellipse_equation = self.equation(x, y)
            line_equation = Line(o.points[0], o.points[1]).equation(x, y)
            candidates = solve([ellipse_equation, line_equation], [x, y])
            return list(ordered([Point(pt) for pt in candidates if pt in o]))

        if isinstance(o, Polygon):
            return o.intersection(self)

        if isinstance(o, (Ellipse, Line2D)):
            if o == self:
                return self
            ellipse_equation = self.equation(x, y)
            solutions = solve([ellipse_equation, o.equation(x, y)], [x, y])
            return list(ordered([Point(pt) for pt in solutions]))

        if isinstance(o, LinearEntity3D):
            raise TypeError('Entity must be two dimensional, not three dimensional')

        raise TypeError('Intersection not handled for %s' % func_name(o))
예제 #44
0
파일: inference.py 프로젝트: Eskatrem/sympy
 def clauses(self):
     """Return the contents of ``self.clauses_`` as a list sorted by ``ordered``."""
     in_order = ordered(self.clauses_)
     return list(in_order)
예제 #45
0
def _solve_lambert(f, symbol, gens):
    """Return solution to ``f`` if it is a Lambert-type expression
    else raise NotImplementedError.

    The equality, ``f(x, a..f) = a*log(b*X + c) + d*X - f = 0`` has the
    solution,  `X = -c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(f/a))`. There
    are a variety of forms for `f(X, a..f)` as enumerated below:

    1a1)
      if B**B = R for R not [0, 1] then
      log(B) + log(log(B)) = log(log(R))
      X = log(B), a = 1, b = 1, c = 0, d = 1, f = log(log(R))
    1a2)
      if B*(b*log(B) + c)**a = R then
      log(B) + a*log(b*log(B) + c) = log(R)
      X = log(B); d=1, f=log(R)
    1b)
      if a*log(b*B + c) + d*B = R then
      X = B, f = R
    2a)
      if (b*B + c)*exp(d*B + g) = R then
      log(b*B + c) + d*B + g = log(R)
      a = 1, f = log(R) - g, X = B
    2b)
      if -b*B + g*exp(d*B + h) = c then
      log(g) + d*B + h - log(b*B + c) = 0
      a = -1, f = -h - log(g), X = B
    3)
      if d*p**(a*B + g) - b*B = c then
      log(d) + (a*B + g)*log(p) - log(c + b*B) = 0
      a = -1, d = a*log(p), f = -log(d) - g*log(p)

    Parameters
    ==========

    f : expression assumed to be equal to zero
    symbol : the symbol to solve for
    gens : iterable of expressions that is scanned for ``exp``/``log``
        terms or powers with ``symbol`` in the exponent (presumably the
        generators of ``f`` -- confirm against the caller)

    Returns
    =======

    list of the solutions found by ``_lambert``, sorted with ``ordered``

    Raises
    ======

    NotImplementedError
        If ``gens`` holds no ``exp``, ``log`` or power with ``symbol`` in
        its exponent, or if none of the three strategies below produces a
        solution.
    """

    # split off the symbol-free part of f and move it to the right-hand side
    nrhs, lhs = f.as_independent(symbol, as_Add=True)
    rhs = -nrhs

    # quick rejection: there must be an exp, a log or a power whose
    # exponent contains the symbol among the generators
    lamcheck = [tmp for tmp in gens
                if (tmp.func in [exp, log] or
                (tmp.is_Pow and symbol in tmp.exp.free_symbols))]
    if not lamcheck:
        raise NotImplementedError()

    # a product on the left is turned into a sum by taking logs of both sides
    if lhs.is_Mul:
        lhs = expand_log(log(lhs))
        rhs = log(rhs)

    lhs = factor(lhs, deep=True)
    # make sure we are inverted as completely as possible
    r = Dummy()
    i, lhs = _invert(lhs - r, symbol)
    # r stood in for rhs during the inversion; substitute the real value back
    rhs = i.xreplace({r: rhs})

    # For the first ones:
    # 1a1) B**B = R != 0 (when 0, there is only a solution if the base is 0,
    #                     but if it is, the exp is 0 and 0**0=1
    #                     comes back as B*log(B) = log(R)
    # 1a2) B*(a + b*log(B))**p = R or with monomial expanded or with whole
    #                              thing expanded comes back unchanged
    #     log(B) + p*log(a + b*log(B)) = log(R)
    #     lhs is Mul:
    #         expand log of both sides to give:
    #         log(B) + log(log(B)) = log(log(R))
    # 1b) d*log(a*B + b) + c*B = R
    #     lhs is Add:
    #         isolate c*B and expand log of both sides:
    #         log(c) + log(B) = log(R - d*log(a*B + b))

    soln = []
    # NOTE: soln was just set to [], so this first guard is always true; it
    # mirrors the guards of the two later strategies
    if not soln:
        mainlog = _mostfunc(lhs, log, symbol)
        if mainlog:
            if lhs.is_Mul and rhs != 0:
                soln = _lambert(log(lhs) - log(rhs), symbol)
            elif lhs.is_Add:
                # what is left of lhs when the main log is replaced by 0
                other = lhs.subs(mainlog, 0)
                if other and not other.is_Add and [
                        tmp for tmp in other.atoms(Pow)
                        if symbol in tmp.free_symbols]:
                    if not rhs:
                        diff = log(other) - log(other - lhs)
                    else:
                        diff = log(lhs - other) - log(rhs - other)
                    soln = _lambert(expand_log(diff), symbol)
                else:
                    # it's ready to go
                    soln = _lambert(lhs - rhs, symbol)

    # For the next two,
    #     collect on main exp
    #     2a) (b*B + c)*exp(d*B + g) = R
    #         lhs is mul:
    #             log to give
    #             log(b*B + c) + d*B = log(R) - g
    #     2b) -b*B + g*exp(d*B + h) = R
    #         lhs is add:
    #             add b*B
    #             log and rearrange
    #             log(R + b*B) - d*B = log(g) + h

    if not soln:
        mainexp = _mostfunc(lhs, exp, symbol)
        if mainexp:
            lhs = collect(lhs, mainexp)
            if lhs.is_Mul and rhs != 0:
                soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol)
            elif lhs.is_Add:
                # move all but mainexp-containing term to rhs
                other = lhs.subs(mainexp, 0)
                mainterm = lhs - other
                rhs=rhs - other
                # when both sides could extract a minus sign, flip both
                # before taking logs
                if (mainterm.could_extract_minus_sign() and
                    rhs.could_extract_minus_sign()):
                    mainterm *= -1
                    rhs *= -1
                diff = log(mainterm) - log(rhs)
                soln = _lambert(expand_log(diff), symbol)

    # 3) d*p**(a*B + b) + c*B = R
    #     collect on main pow
    #     log(R - c*B) - a*B*log(p) = log(d) + b*log(p)

    if not soln:
        mainpow = _mostfunc(lhs, Pow, symbol)
        if mainpow and symbol in mainpow.exp.free_symbols:
            lhs = collect(lhs, mainpow)
            if lhs.is_Mul and rhs != 0:
                soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol)
            elif lhs.is_Add:
                # move all but mainpow-containing term to rhs
                other = lhs.subs(mainpow, 0)
                mainterm = lhs - other
                rhs = rhs - other
                diff = log(mainterm) - log(rhs)
                soln = _lambert(expand_log(diff), symbol)

    if not soln:
        raise NotImplementedError('%s does not appear to have a solution in '
            'terms of LambertW' % f)

    return list(ordered(soln))
예제 #46
0
파일: symbol.py 프로젝트: Daaofer/sympyTest
def disambiguate(*iter):
    """
    Return a Tuple containing the passed expressions with symbols
    that appear the same when printed replaced with numerically
    subscripted symbols, and all Dummy symbols replaced with Symbols.

    Symbols are grouped by their printed name with leading underscores
    stripped.  The first symbol of each group keeps the bare name and the
    others receive ``_1``, ``_2``, ... suffixes, skipping any suffixed
    name that is itself the name of a group.  Assumptions are carried
    over to the replacement symbols.

    Parameters
    ==========

    iter: list of symbols or expressions.

    Examples
    ========

    >>> from sympy.core.symbol import disambiguate
    >>> from sympy import Dummy, Symbol, Tuple
    >>> from sympy.abc import y

    >>> tup = Symbol('_x'), Dummy('x'), Dummy('x')
    >>> disambiguate(*tup)
    (x_2, x, x_1)

    >>> eqs = Tuple(Symbol('x')/y, Dummy('x')/y)
    >>> disambiguate(*eqs)
    (x_1/y, x/y)

    >>> ix = Symbol('x', integer=True)
    >>> vx = Symbol('x')
    >>> disambiguate(vx + ix)
    (x + x_1,)

    To make your own mapping of symbols to use, pass only the free symbols
    of the expressions and create a dictionary:

    >>> free = eqs.free_symbols
    >>> mapping = dict(zip(free, disambiguate(*free)))
    >>> eqs.xreplace(mapping)
    (x_1/y, x/y)

    """
    exprs = Tuple(*iter)

    def assumption_key(sym):
        return tuple(sorted(sym.assumptions0.items()))

    # group the free symbols by printed name without leading underscores
    by_name = {}
    for sym in ordered(exprs.free_symbols, keys=assumption_key):
        by_name.setdefault(str(sym).lstrip('_'), []).append(sym)

    replacements = {}
    for base_name, group in by_name.items():
        # the first or only symbol doesn't get subscripted but make
        # sure that it's a Symbol, not a Dummy
        head = group[0]
        plain = Symbol("%s" % (base_name), **head.assumptions0)
        if head != plain:
            replacements[head] = plain
        # the others get subscripts (and are made into Symbols)
        offset = 0
        for index, sym in enumerate(group[1:], 1):
            candidate = "%s_%i" % (base_name, index + offset)
            while candidate in by_name:
                offset += 1
                candidate = "%s_%i" % (base_name, index + offset)
            replacements[sym] = Symbol(candidate, **sym.assumptions0)
    return exprs.xreplace(replacements)
예제 #47
0
파일: sets.py 프로젝트: alphaitis/sympy
 def _sorted_args(self):
     """Return ``self.args`` as a tuple sorted by ``ordered`` with ``Set._infimum_key``."""
     # The function-local import of default_sort_key was never used here
     # and has been dropped.
     return tuple(ordered(self.args, Set._infimum_key))
예제 #48
0
파일: powsimp.py 프로젝트: bjodah/sympy
def powsimp(expr, deep=False, combine='all', force=False, measure=count_ops):
    """
    reduces expression by combining powers with similar bases and exponents.

    Notes
    =====

    If deep is True then powsimp() will also simplify arguments of
    functions. By default deep is set to False.

    If force is True then bases will be combined without checking for
    assumptions, e.g. sqrt(x)*sqrt(y) -> sqrt(x*y) which is not true
    if x and y are both negative.

    You can make powsimp() only combine bases or only combine exponents by
    changing combine='base' or combine='exp'.  By default, combine='all',
    which does both.  combine='base' will only combine::

         a   a          a                          2x      x
        x * y  =>  (x*y)   as well as things like 2   =>  4

    and combine='exp' will only combine
    ::

         a   b      (a + b)
        x * x  =>  x

    combine='exp' will strictly only combine exponents in the way that used
    to be automatic.  Also use deep=True if you need the old behavior.

    When combine='all', 'exp' is evaluated first.  Consider the first
    example below for when there could be an ambiguity relating to this.
    This is done so things like the second example can be completely
    combined.  If you want 'base' combined first, do something like
    powsimp(powsimp(expr, combine='base'), combine='exp').

    Examples
    ========

    >>> from sympy import powsimp, exp, log, symbols
    >>> from sympy.abc import x, y, z, n
    >>> powsimp(x**y*x**z*y**z, combine='all')
    x**(y + z)*y**z
    >>> powsimp(x**y*x**z*y**z, combine='exp')
    x**(y + z)*y**z
    >>> powsimp(x**y*x**z*y**z, combine='base', force=True)
    x**y*(x*y)**z

    >>> powsimp(x**z*x**y*n**z*n**y, combine='all', force=True)
    (n*x)**(y + z)
    >>> powsimp(x**z*x**y*n**z*n**y, combine='exp')
    n**(y + z)*x**(y + z)
    >>> powsimp(x**z*x**y*n**z*n**y, combine='base', force=True)
    (n*x)**y*(n*x)**z

    >>> x, y = symbols('x y', positive=True)
    >>> powsimp(log(exp(x)*exp(y)))
    log(exp(x)*exp(y))
    >>> powsimp(log(exp(x)*exp(y)), deep=True)
    x + y

    Radicals with Mul bases will be combined if combine='exp'

    >>> from sympy import sqrt, Mul
    >>> x, y = symbols('x y')

    Two radicals are automatically joined through Mul:

    >>> a=sqrt(x*sqrt(y))
    >>> a*a**3 == a**4
    True

    But if an integer power of that radical has been
    autoexpanded then Mul does not join the resulting factors:

    >>> a**4 # auto expands to a Mul, no longer a Pow
    x**2*y
    >>> _*a # so Mul doesn't combine them
    x**2*y*sqrt(x*sqrt(y))
    >>> powsimp(_) # but powsimp will
    (x*sqrt(y))**(5/2)
    >>> powsimp(x*y*a) # but won't when doing so would violate assumptions
    x*y*sqrt(x*sqrt(y))

    """
    from sympy.matrices.expressions.matexpr import MatrixSymbol

    def recurse(arg, **kwargs):
        _deep = kwargs.get('deep', deep)
        _combine = kwargs.get('combine', combine)
        _force = kwargs.get('force', force)
        _measure = kwargs.get('measure', measure)
        return powsimp(arg, _deep, _combine, _force, _measure)

    expr = sympify(expr)

    if (not isinstance(expr, Basic) or isinstance(expr, MatrixSymbol) or (
            expr.is_Atom or expr in (exp_polar(0), exp_polar(1)))):
        return expr

    if deep or expr.is_Add or expr.is_Mul and _y not in expr.args:
        expr = expr.func(*[recurse(w) for w in expr.args])

    if expr.is_Pow:
        return recurse(expr*_y, deep=False)/_y

    if not expr.is_Mul:
        return expr

    # handle the Mul
    if combine in ('exp', 'all'):
        # Collect base/exp data, while maintaining order in the
        # non-commutative parts of the product
        c_powers = defaultdict(list)
        nc_part = []
        newexpr = []
        coeff = S.One
        for term in expr.args:
            if term.is_Rational:
                coeff *= term
                continue
            if term.is_Pow:
                term = _denest_pow(term)
            if term.is_commutative:
                b, e = term.as_base_exp()
                if deep:
                    b, e = [recurse(i) for i in [b, e]]
                if b.is_Pow or isinstance(b, exp):
                    # don't let smthg like sqrt(x**a) split into x**a, 1/2
                    # or else it will be joined as x**(a/2) later
                    b, e = b**e, S.One
                c_powers[b].append(e)
            else:
                # This is the logic that combines exponents for equal,
                # but non-commutative bases: A**x*A**y == A**(x+y).
                if nc_part:
                    b1, e1 = nc_part[-1].as_base_exp()
                    b2, e2 = term.as_base_exp()
                    if (b1 == b2 and
                            e1.is_commutative and e2.is_commutative):
                        nc_part[-1] = Pow(b1, Add(e1, e2))
                        continue
                nc_part.append(term)

        # add up exponents of common bases
        for b, e in ordered(iter(c_powers.items())):
            # allow 2**x/4 -> 2**(x - 2); don't do this when b and e are
            # Numbers since autoevaluation will undo it, e.g.
            # 2**(1/3)/4 -> 2**(1/3 - 2) -> 2**(1/3)/4
            if (b and b.is_Rational and not all(ei.is_Number for ei in e) and \
                    coeff is not S.One and
                    b not in (S.One, S.NegativeOne)):
                m = multiplicity(abs(b), abs(coeff))
                if m:
                    e.append(m)
                    coeff /= b**m
            c_powers[b] = Add(*e)
        if coeff is not S.One:
            if coeff in c_powers:
                c_powers[coeff] += S.One
            else:
                c_powers[coeff] = S.One

        # convert to plain dictionary
        c_powers = dict(c_powers)

        # check for base and inverted base pairs
        be = list(c_powers.items())
        skip = set()  # skip if we already saw them
        for b, e in be:
            if b in skip:
                continue
            bpos = b.is_positive or b.is_polar
            if bpos:
                binv = 1/b
                if b != binv and binv in c_powers:
                    if b.as_numer_denom()[0] is S.One:
                        c_powers.pop(b)
                        c_powers[binv] -= e
                    else:
                        skip.add(binv)
                        e = c_powers.pop(binv)
                        c_powers[b] -= e

        # check for base and negated base pairs
        be = list(c_powers.items())
        _n = S.NegativeOne
        for i, (b, e) in enumerate(be):
            if ((-b).is_Symbol or b.is_Add) and -b in c_powers:
                if (b.is_positive in (0, 1) or e.is_integer):
                    c_powers[-b] += c_powers.pop(b)
                    if _n in c_powers:
                        c_powers[_n] += e
                    else:
                        c_powers[_n] = e

        # filter c_powers and convert to a list
        c_powers = [(b, e) for b, e in c_powers.items() if e]

        # ==============================================================
        # check for Mul bases of Rational powers that can be combined with
        # separated bases, e.g. x*sqrt(x*y)*sqrt(x*sqrt(x*y)) ->
        # (x*sqrt(x*y))**(3/2)
        # ---------------- helper functions

        def ratq(x):
            '''Return Rational part of x's exponent as it appears in the bkey.
            '''
            return bkey(x)[0][1]

        def bkey(b, e=None):
            '''Return (b**s, c.q), c.p where e -> c*s. If e is not given then
            it will be taken by using as_base_exp() on the input b.
            e.g.
                x**3/2 -> (x, 2), 3
                x**y -> (x**y, 1), 1
                x**(2*y/3) -> (x**y, 3), 2
                exp(x/2) -> (exp(a), 2), 1

            '''
            if e is not None:  # coming from c_powers or from below
                if e.is_Integer:
                    return (b, S.One), e
                elif e.is_Rational:
                    return (b, Integer(e.q)), Integer(e.p)
                else:
                    c, m = e.as_coeff_Mul(rational=True)
                    if c is not S.One:
                        if m.is_integer:
                            return (b, Integer(c.q)), m*Integer(c.p)
                        return (b**m, Integer(c.q)), Integer(c.p)
                    else:
                        return (b**e, S.One), S.One
            else:
                return bkey(*b.as_base_exp())

        def update(b):
            '''Decide what to do with base, b. If its exponent is now an
            integer multiple of the Rational denominator, then remove it
            and put the factors of its base in the common_b dictionary or
            update the existing bases if necessary. If it has been zeroed
            out, simply remove the base.
            '''
            newe, r = divmod(common_b[b], b[1])
            if not r:
                common_b.pop(b)
                if newe:
                    for m in Mul.make_args(b[0]**newe):
                        b, e = bkey(m)
                        if b not in common_b:
                            common_b[b] = 0
                        common_b[b] += e
                        if b[1] != 1:
                            bases.append(b)
        # ---------------- end of helper functions

        # assemble a dictionary of the factors having a Rational power
        common_b = {}
        done = []
        bases = []
        for b, e in c_powers:
            b, e = bkey(b, e)
            if b in common_b:
                common_b[b] = common_b[b] + e
            else:
                common_b[b] = e
            if b[1] != 1 and b[0].is_Mul:
                bases.append(b)
        bases.sort(key=default_sort_key)  # this makes tie-breaking canonical
        bases.sort(key=measure, reverse=True)  # handle longest first
        for base in bases:
            if base not in common_b:  # it may have been removed already
                continue
            b, exponent = base
            last = False  # True when no factor of base is a radical
            qlcm = 1  # the lcm of the radical denominators
            while True:
                bstart = b
                qstart = qlcm

                bb = []  # list of factors
                ee = []  # (factor's expo. and it's current value in common_b)
                for bi in Mul.make_args(b):
                    bib, bie = bkey(bi)
                    if bib not in common_b or common_b[bib] < bie:
                        ee = bb = []  # failed
                        break
                    ee.append([bie, common_b[bib]])
                    bb.append(bib)
                if ee:
                    # find the number of integral extractions possible
                    # e.g. [(1, 2), (2, 2)] -> min(2/1, 2/2) -> 1
                    min1 = ee[0][1]//ee[0][0]
                    for i in range(1, len(ee)):
                        rat = ee[i][1]//ee[i][0]
                        if rat < 1:
                            break
                        min1 = min(min1, rat)
                    else:
                        # update base factor counts
                        # e.g. if ee = [(2, 5), (3, 6)] then min1 = 2
                        # and the new base counts will be 5-2*2 and 6-2*3
                        for i in range(len(bb)):
                            common_b[bb[i]] -= min1*ee[i][0]
                            update(bb[i])
                        # update the count of the base
                        # e.g. x**2*y*sqrt(x*sqrt(y)) the count of x*sqrt(y)
                        # will increase by 4 to give bkey (x*sqrt(y), 2, 5)
                        common_b[base] += min1*qstart*exponent
                if (last  # no more radicals in base
                    or len(common_b) == 1  # nothing left to join with
                    or all(k[1] == 1 for k in common_b)  # no rad's in common_b
                        ):
                    break
                # see what we can exponentiate base by to remove any radicals
                # so we know what to search for
                # e.g. if base were x**(1/2)*y**(1/3) then we should
                # exponentiate by 6 and look for powers of x and y in the ratio
                # of 2 to 3
                qlcm = lcm([ratq(bi) for bi in Mul.make_args(bstart)])
                if qlcm == 1:
                    break  # we are done
                b = bstart**qlcm
                qlcm *= qstart
                if all(ratq(bi) == 1 for bi in Mul.make_args(b)):
                    last = True  # we are going to be done after this next pass
            # this base no longer can find anything to join with and
            # since it was longer than any other we are done with it
            b, q = base
            done.append((b, common_b.pop(base)*Rational(1, q)))

        # update c_powers and get ready to continue with powsimp
        c_powers = done
        # there may be terms still in common_b that were bases that were
        # identified as needing processing, so remove those, too
        for (b, q), e in common_b.items():
            if (b.is_Pow or isinstance(b, exp)) and \
                    q is not S.One and not b.exp.is_Rational:
                b, be = b.as_base_exp()
                b = b**(be/q)
            else:
                b = root(b, q)
            c_powers.append((b, e))
        check = len(c_powers)
        c_powers = dict(c_powers)
        assert len(c_powers) == check  # there should have been no duplicates
        # ==============================================================

        # rebuild the expression
        newexpr = expr.func(*(newexpr + [Pow(b, e) for b, e in c_powers.items()]))
        if combine == 'exp':
            return expr.func(newexpr, expr.func(*nc_part))
        else:
            return recurse(expr.func(*nc_part), combine='base') * \
                recurse(newexpr, combine='base')

    elif combine == 'base':

        # Build c_powers and nc_part.  These must both be lists not
        # dicts because exp's are not combined.
        c_powers = []
        nc_part = []
        for term in expr.args:
            if term.is_commutative:
                c_powers.append(list(term.as_base_exp()))
            else:
                nc_part.append(term)

        # Pull out numerical coefficients from exponent if assumptions allow
        # e.g., 2**(2*x) => 4**x
        for i in range(len(c_powers)):
            b, e = c_powers[i]
            if not (all(x.is_nonnegative for x in b.as_numer_denom()) or e.is_integer or force or b.is_polar):
                continue
            exp_c, exp_t = e.as_coeff_Mul(rational=True)
            if exp_c is not S.One and exp_t is not S.One:
                c_powers[i] = [Pow(b, exp_c), exp_t]

        # Combine bases whenever they have the same exponent and
        # assumptions allow
        # first gather the potential bases under the common exponent
        c_exp = defaultdict(list)
        for b, e in c_powers:
            if deep:
                e = recurse(e)
            c_exp[e].append(b)
        del c_powers

        # Merge back in the results of the above to form a new product
        c_powers = defaultdict(list)
        for e in c_exp:
            bases = c_exp[e]

            # calculate the new base for e

            if len(bases) == 1:
                new_base = bases[0]
            elif e.is_integer or force:
                new_base = expr.func(*bases)
            else:
                # see which ones can be joined
                unk = []
                nonneg = []
                neg = []
                for bi in bases:
                    if bi.is_negative:
                        neg.append(bi)
                    elif bi.is_nonnegative:
                        nonneg.append(bi)
                    elif bi.is_polar:
                        nonneg.append(
                            bi)  # polar can be treated like non-negative
                    else:
                        unk.append(bi)
                if len(unk) == 1 and not neg or len(neg) == 1 and not unk:
                    # a single neg or a single unk can join the rest
                    nonneg.extend(unk + neg)
                    unk = neg = []
                elif neg:
                    # their negative signs cancel in groups of 2*q if we know
                    # that e = p/q else we have to treat them as unknown
                    israt = False
                    if e.is_Rational:
                        israt = True
                    else:
                        p, d = e.as_numer_denom()
                        if p.is_integer and d.is_integer:
                            israt = True
                    if israt:
                        neg = [-w for w in neg]
                        unk.extend([S.NegativeOne]*len(neg))
                    else:
                        unk.extend(neg)
                        neg = []
                    del israt

                # these shouldn't be joined
                for b in unk:
                    c_powers[b].append(e)
                # here is a new joined base
                new_base = expr.func(*(nonneg + neg))
                # if there are positive parts they will just get separated
                # again unless some change is made

                def _terms(e):
                    # return the number of terms of this expression
                    # when multiplied out -- assuming no joining of terms
                    if e.is_Add:
                        return sum([_terms(ai) for ai in e.args])
                    if e.is_Mul:
                        return prod([_terms(mi) for mi in e.args])
                    return 1
                xnew_base = expand_mul(new_base, deep=False)
                if len(Add.make_args(xnew_base)) < _terms(new_base):
                    new_base = factor_terms(xnew_base)

            c_powers[new_base].append(e)

        # break out the powers from c_powers now
        c_part = [Pow(b, ei) for b, e in c_powers.items() for ei in e]

        # we're done
        return expr.func(*(c_part + nc_part))

    else:
        raise ValueError("combine must be one of ('all', 'exp', 'base').")
예제 #49
0
파일: symbol.py 프로젝트: Lenqth/sympy
def disambiguate(*iter):
    """
    Return a Tuple containing the passed expressions with symbols
    that appear the same when printed replaced with numerically
    subscripted symbols, and all Dummy symbols replaced with Symbols.

    Parameters
    ==========

    iter: list of symbols or expressions.

    Examples
    ========

    >>> from sympy.core.symbol import disambiguate
    >>> from sympy import Dummy, Symbol, Tuple
    >>> from sympy.abc import y

    >>> tup = Symbol('_x'), Dummy('x'), Dummy('x')
    >>> disambiguate(*tup)
    (x_2, x, x_1)

    >>> eqs = Tuple(Symbol('x')/y, Dummy('x')/y)
    >>> disambiguate(*eqs)
    (x_1/y, x/y)

    >>> ix = Symbol('x', integer=True)
    >>> vx = Symbol('x')
    >>> disambiguate(vx + ix)
    (x + x_1,)

    To make your own mapping of symbols to use, pass only the free symbols
    of the expressions and create a dictionary:

    >>> free = eqs.free_symbols
    >>> mapping = dict(zip(free, disambiguate(*free)))
    >>> eqs.xreplace(mapping)
    (x_1/y, x/y)

    """
    new_iter = Tuple(*iter)

    def assumption_key(sym):
        # order symbols by their sorted (assumption, value) pairs
        return tuple(sorted(sym.assumptions0.items()))

    syms = ordered(new_iter.free_symbols, keys=assumption_key)

    # group the symbols by printed name with leading underscores removed
    mapping = {}
    for s in syms:
        mapping.setdefault(str(s).lstrip('_'), []).append(s)

    reps = {}
    for k, group in mapping.items():
        # the first or only symbol doesn't get subscripted but make
        # sure that it's a Symbol, not a Dummy
        first = group[0]
        k0 = Symbol(k, **first.assumptions0)
        # compare the symbol itself (not the string key) with its
        # replacement so that no replacement is recorded when the symbol
        # already is the wanted Symbol
        if first != k0:
            reps[first] = k0
        # the others get subscripts (and are made into Symbols)
        skip = 0
        for i, ki in enumerate(group[1:], 1):
            # advance the subscript past any name that is already the
            # key of another group
            while True:
                name = "%s_%i" % (k, i + skip)
                if name not in mapping:
                    break
                skip += 1
            reps[ki] = Symbol(name, **ki.assumptions0)
    return new_iter.xreplace(reps)
예제 #50
0
파일: heurisch.py 프로젝트: sympy/sympy
def heurisch(f, x, rewrite=False, hints=None, mappings=None, retries=3,
             degree_offset=0, unnecessary_permutations=None):
    """
    Compute indefinite integral using heuristic Risch algorithm.

    This is a heuristic approach to indefinite integration in finite
    terms using the extended heuristic (parallel) Risch algorithm, based
    on Manuel Bronstein's "Poor Man's Integrator".

    The algorithm supports various classes of functions including
    transcendental elementary or special functions like Airy,
    Bessel, Whittaker and Lambert.

    Note that this algorithm is not a decision procedure. If it isn't
    able to compute the antiderivative for a given function, then this is
    not a proof that such a function does not exist.  One should use
    recursive Risch algorithm in such case.  It's an open question if
    this algorithm can be made a full decision procedure.

    This is an internal integrator procedure. You should use toplevel
    'integrate' function in most cases,  as this procedure needs some
    preprocessing steps and otherwise may fail.

    Specification
    =============

     heurisch(f, x, rewrite=False, hints=None)

       where
         f : expression
         x : symbol

         rewrite -> force rewrite 'f' in terms of 'tan' and 'tanh'
         hints   -> a list of functions that may appear in antiderivative

          - hints = None          --> no suggestions at all
          - hints = [ ]           --> try to figure out
          - hints = [f1, ..., fn] --> we know better

         mappings -> iterable of orderings of the (term, symbol) pairs to
                     try; when None the permutations are generated here
         retries  -> when no solution is found and retries >= 0 the
                     function calls itself again with retries - 1
         degree_offset -> added to the maximal degree of the monomials
                     used for the polynomial part of the candidate
         unnecessary_permutations -> pairs appended unchanged to every
                     ordering taken from mappings

    Returns the antiderivative expression, or None when no solution
    was found.

    Examples
    ========

    >>> from sympy import tan
    >>> from sympy.integrals.heurisch import heurisch
    >>> from sympy.abc import x, y

    >>> heurisch(y*tan(x), x)
    y*log(tan(x)**2 + 1)/2

    See Manuel Bronstein's "Poor Man's Integrator":

    [1] http://www-sop.inria.fr/cafe/Manuel.Bronstein/pmint/index.html

    For more information on the implemented algorithm refer to:

    [2] K. Geddes, L. Stefanus, On the Risch-Norman Integration
       Method and its Implementation in Maple, Proceedings of
       ISSAC'89, ACM Press, 212-217.

    [3] J. H. Davenport, On the Parallel Risch Algorithm (I),
       Proceedings of EUROCAM'82, LNCS 144, Springer, 144-157.

    [4] J. H. Davenport, On the Parallel Risch Algorithm (III):
       Use of Tangents, SIGSAM Bulletin 16 (1982), 3-6.

    [5] J. H. Davenport, B. M. Trager, On the Parallel Risch
       Algorithm (II), ACM Transactions on Mathematical
       Software 11 (1985), 356-362.

    See Also
    ========

    sympy.integrals.integrals.Integral.doit
    sympy.integrals.integrals.Integral
    components
    """
    f = sympify(f)
    # an integrand that does not contain x integrates to f*x
    if x not in f.free_symbols:
        return f*x

    # for a non-Add integrand split off the factor independent of x; it is
    # multiplied back onto the result before returning
    if not f.is_Add:
        indep, f = f.as_independent(x)
    else:
        indep = S.One

    # keys: families of functions; values: the function they are rewritten to
    rewritables = {
        (sin, cos, cot): tan,
        (sinh, cosh, coth): tanh,
    }

    if rewrite:
        for candidates, rule in rewritables.items():
            f = f.rewrite(candidates, rule)
    else:
        for candidates in rewritables.keys():
            if f.has(*candidates):
                break
        else:
            # f contains none of the rewritable functions, so flag the
            # rewrite as done; this disables the rewrite=True retry below
            rewrite = True

    terms = components(f, x)

    if hints is not None:
        if not hints:
            # empty hints: derive extra candidate terms from patterns
            # found in the components of f
            a = Wild('a', exclude=[x])
            b = Wild('b', exclude=[x])
            c = Wild('c', exclude=[x])

            for g in set(terms):  # using copy of terms
                if g.is_Function:
                    if isinstance(g, li):
                        M = g.args[0].match(a*x**b)

                        if M is not None:
                            terms.add( x*(li(M[a]*x**M[b]) - (M[a]*x**M[b])**(-1/M[b])*Ei((M[b]+1)*log(M[a]*x**M[b])/M[b])) )
                            #terms.add( x*(li(M[a]*x**M[b]) - (x**M[b])**(-1/M[b])*Ei((M[b]+1)*log(M[a]*x**M[b])/M[b])) )
                            #terms.add( x*(li(M[a]*x**M[b]) - x*Ei((M[b]+1)*log(M[a]*x**M[b])/M[b])) )
                            #terms.add( li(M[a]*x**M[b]) - Ei((M[b]+1)*log(M[a]*x**M[b])/M[b]) )

                    elif isinstance(g, exp):
                        # exp of a quadratic in x (or in log(x)) suggests
                        # erf/erfi terms
                        M = g.args[0].match(a*x**2)

                        if M is not None:
                            if M[a].is_positive:
                                terms.add(erfi(sqrt(M[a])*x))
                            else: # M[a].is_negative or unknown
                                terms.add(erf(sqrt(-M[a])*x))

                        M = g.args[0].match(a*x**2 + b*x + c)

                        if M is not None:
                            if M[a].is_positive:
                                terms.add(sqrt(pi/4*(-M[a]))*exp(M[c] - M[b]**2/(4*M[a]))*
                                          erfi(sqrt(M[a])*x + M[b]/(2*sqrt(M[a]))))
                            elif M[a].is_negative:
                                terms.add(sqrt(pi/4*(-M[a]))*exp(M[c] - M[b]**2/(4*M[a]))*
                                          erf(sqrt(-M[a])*x - M[b]/(2*sqrt(-M[a]))))

                        M = g.args[0].match(a*log(x)**2)

                        if M is not None:
                            if M[a].is_positive:
                                terms.add(erfi(sqrt(M[a])*log(x) + 1/(2*sqrt(M[a]))))
                            if M[a].is_negative:
                                terms.add(erf(sqrt(-M[a])*log(x) - 1/(2*sqrt(-M[a]))))

                elif g.is_Pow:
                    # rational powers with denominator 2 of a*x**2 +/- b
                    # suggest inverse trigonometric/hyperbolic terms
                    if g.exp.is_Rational and g.exp.q == 2:
                        M = g.base.match(a*x**2 + b)

                        if M is not None and M[b].is_positive:
                            if M[a].is_positive:
                                terms.add(asinh(sqrt(M[a]/M[b])*x))
                            elif M[a].is_negative:
                                terms.add(asin(sqrt(-M[a]/M[b])*x))

                        M = g.base.match(a*x**2 - b)

                        if M is not None and M[b].is_positive:
                            if M[a].is_positive:
                                terms.add(acosh(sqrt(M[a]/M[b])*x))
                            elif M[a].is_negative:
                                terms.add((-M[b]/2*sqrt(-M[a])*
                                           atan(sqrt(-M[a])*x/sqrt(M[a]*x**2 - M[b]))))

        else:
            # explicit hints are added to the terms as given
            terms |= set(hints)

    dcache = DiffCache(x)

    # also include the components of the derivative of every term
    for g in set(terms):  # using copy of terms
        terms |= components(dcache.get_diff(g), x)

    # TODO: caching is significant factor for why permutations work at all. Change this.
    # one symbol per term
    V = _symbols('x', len(terms))


    # sort mapping expressions from largest to smallest (last is always x).
    mapping = list(reversed(list(zip(*ordered(                          #
        [(a[0].as_independent(x)[1], a) for a in zip(terms, V)])))[1])) #
    # symbol -> term, used to substitute the terms back into the solution
    rev_mapping = {v: k for k, v in mapping}                            #
    if mappings is None:                                                #
        # optimizing the number of permutations of mapping              #
        assert mapping[-1][0] == x # if not, find it and correct this comment
        unnecessary_permutations = [mapping.pop(-1)]
        mappings = permutations(mapping)
    else:
        unnecessary_permutations = unnecessary_permutations or []

    def _substitute(expr):
        # reads `mapping` from the enclosing scope at call time, i.e. the
        # ordering currently selected by the loop below
        return expr.subs(mapping)

    # look for an ordering under which every derivative has a polynomial
    # denominator in V and f becomes a rational function in V
    for mapping in mappings:
        mapping = list(mapping)
        mapping = mapping + unnecessary_permutations
        diffs = [ _substitute(dcache.get_diff(g)) for g in terms ]
        denoms = [ g.as_numer_denom()[1] for g in diffs ]
        if all(h.is_polynomial(*V) for h in denoms) and _substitute(f).is_rational_function(*V):
            denom = reduce(lambda p, q: lcm(p, q, *V), denoms)
            break
    else:
        # no ordering worked: retry once with the rewrite forced, else fail
        if not rewrite:
            result = heurisch(f, x, rewrite=True, hints=hints,
                unnecessary_permutations=unnecessary_permutations)

            if result is not None:
                return indep*result
        return None

    # derivatives of the terms multiplied by the common denominator
    numers = [ cancel(denom*g) for g in diffs ]
    def _derivation(h):
        # sum over V of numers[i] * d(h)/d(V[i])
        return Add(*[ d * h.diff(v) for d, v in zip(numers, V) ])

    def _deflation(p):
        # recurses on the content of p with respect to the first symbol of V
        # that p contains, provided the derivation of p is not zero
        for y in V:
            if not p.has(y):
                continue

            if _derivation(p) is not S.Zero:
                c, q = p.as_poly(y).primitive()
                return _deflation(c)*gcd(q, q.diff(y)).as_expr()

        return p

    def _splitter(p):
        # returns a pair of factors whose product is built from p; p is
        # returned as (1, p) when no symbol of V with a nonzero derivation
        # occurs in it
        for y in V:
            if not p.has(y):
                continue

            if _derivation(y) is not S.Zero:
                c, q = p.as_poly(y).primitive()

                q = q.as_expr()

                h = gcd(q, _derivation(q), y)
                s = quo(h, gcd(q, q.diff(y), y), y)

                c_split = _splitter(c)

                if s.as_poly(y).degree() == 0:
                    return (c_split[0], q * c_split[1])

                q_split = _splitter(cancel(q / s))

                return (c_split[0]*q_split[0]*s, c_split[1]*q_split[1])

        return (S.One, p)

    # extra polynomials contributed by tan, tanh and LambertW terms; the
    # boolean value selects whether the key is also multiplied into `s` below
    special = {}

    for term in terms:
        if term.is_Function:
            if isinstance(term, tan):
                special[1 + _substitute(term)**2] = False
            elif isinstance(term, tanh):
                special[1 + _substitute(term)] = False
                special[1 - _substitute(term)] = False
            elif isinstance(term, LambertW):
                special[_substitute(term)] = True

    # the integrand expressed in the symbols V
    F = _substitute(f)

    P, Q = F.as_numer_denom()

    u_split = _splitter(denom)
    v_split = _splitter(Q)

    polys = set(list(v_split) + [ u_split[0] ] + list(special.keys()))

    s = u_split[0] * Mul(*[ k for k, v in special.items() if v ])
    polified = [ p.as_poly(*V) for p in [s, P, Q] ]

    # give up if any of s, P, Q is not a polynomial in V
    if None in polified:
        return None

    #--- definitions for _integrate
    a, b, c = [ p.total_degree() for p in polified ]

    # denominator of the polynomial part of the candidate
    poly_denom = (s * v_split[0] * _deflation(v_split[1])).as_expr()

    def _exponent(g):
        # largest value over the expression tree of g, where a power with a
        # non-integer Rational exponent p/q counts p + q - 1 (p > 0) or
        # abs(p + q), and anything else counts 1
        if g.is_Pow:
            if g.exp.is_Rational and g.exp.q != 1:
                if g.exp.p > 0:
                    return g.exp.p + g.exp.q - 1
                else:
                    return abs(g.exp.p + g.exp.q)
            else:
                return 1
        elif not g.is_Atom and g.args:
            return max([ _exponent(h) for h in g.args ])
        else:
            return 1

    A, B = _exponent(f), a + max(b, c)

    # monomials in V up to the chosen degree form the polynomial part
    if A > 1 and B > 1:
        monoms = tuple(itermonomials(V, A + B - 1 + degree_offset))
    else:
        monoms = tuple(itermonomials(V, A + B + degree_offset))

    # one undetermined coefficient per monomial
    poly_coeffs = _symbols('A', len(monoms))

    poly_part = Add(*[ poly_coeffs[i]*monomial
        for i, monomial in enumerate(monoms) ])

    # bases of the factors of every polynomial in polys that contains a
    # symbol of V
    reducibles = set()

    for poly in polys:
        if poly.has(*V):
            try:
                factorization = factor(poly, greedy=True)
            except PolynomialError:
                factorization = poly

            if factorization.is_Mul:
                factors = factorization.args
            else:
                factors = (factorization, )

            for fact in factors:
                if fact.is_Pow:
                    reducibles.add(fact.base)
                else:
                    reducibles.add(fact)

    def _integrate(field=None):
        # Build a candidate antiderivative with undetermined coefficients
        # and solve for them; `field` is passed to root_factors as `filter`.
        # Returns the solved candidate or None.
        irreducibles = set()
        atans = set()
        pairs = set()

        # for each poly pick a free symbol z that belongs to V (skipping
        # polys that have none) and factor the poly with respect to z
        for poly in reducibles:
            for z in poly.free_symbols:
                if z in V:
                    break  # should this be: `irreducibles |= \
            else:          # set(root_factors(poly, z, filter=field))`
                continue   # and the line below deleted?
                           #               |
                           #               V
            irreducibles |= set(root_factors(poly, z, filter=field))

        log_part, atan_part = [], []

        # move factors of the form x + I*y (x, y free of I) into `pairs`
        for poly in list(irreducibles):
            m = collect(poly, I, evaluate=False)
            y = m.get(I, S.Zero)
            if y:
                x = m.get(S.One, S.Zero)
                if x.has(I) or y.has(I):
                    continue  # nontrivial x + I*y
                pairs.add((x, y))
                irreducibles.remove(poly)

        # a pair together with its conjugate gives x**2 + y**2 (for a log)
        # and atan(x/y); an unmatched pair is put back as x + I*y
        while pairs:
            x, y = pairs.pop()
            if (x, -y) in pairs:
                pairs.remove((x, -y))
                # Choosing b with no minus sign
                if y.could_extract_minus_sign():
                    y = -y
                irreducibles.add(x*x + y*y)
                atans.add(atan(x/y))
            else:
                irreducibles.add(x + I*y)


        B = _symbols('B', len(irreducibles))
        C = _symbols('C', len(atans))

        # poly_coeffs belongs to the enclosing scope and is extended in
        # place by the two loops below
        # Note: the ordering matters here
        for poly, b in reversed(list(ordered(zip(irreducibles, B)))):
            if poly.has(*V):
                poly_coeffs.append(b)
                log_part.append(b * log(poly))

        for poly, c in reversed(list(ordered(zip(atans, C)))):
            if poly.has(*V):
                poly_coeffs.append(c)
                atan_part.append(c * poly)

        # TODO: Currently it's better to use symbolic expressions here instead
        # of rational functions, because it's simpler and FracElement doesn't
        # give big speed improvement yet. This is because cancellation is slow
        # due to slow polynomial GCD algorithms. If this gets improved then
        # revise this code.
        candidate = poly_part/poly_denom + Add(*log_part) + Add(*atan_part)
        # the numerator of F - candidate' must vanish for a solution
        h = F - _derivation(candidate) / denom
        raw_numer = h.as_numer_denom()[0]

        # Rewrite raw_numer as a polynomial in K[coeffs][V] where K is a field
        # that we have to determine. We can't use simply atoms() because log(3),
        # sqrt(y) and similar expressions can appear, leading to non-trivial
        # domains.
        syms = set(poly_coeffs) | set(V)
        non_syms = set([])

        def find_non_syms(expr):
            # collect into non_syms the subexpressions free of syms;
            # raises PolynomialError on a non-polynomial expression
            if expr.is_Integer or expr.is_Rational:
                pass # ignore trivial numbers
            elif expr in syms:
                pass # ignore variables
            elif not expr.has(*syms):
                non_syms.add(expr)
            elif expr.is_Add or expr.is_Mul or expr.is_Pow:
                list(map(find_non_syms, expr.args))
            else:
                # TODO: Non-polynomial expression. This should have been
                # filtered out at an earlier stage.
                raise PolynomialError

        try:
            find_non_syms(raw_numer)
        except PolynomialError:
            return None
        else:
            ground, _ = construct_domain(non_syms, field=True)

        coeff_ring = PolyRing(poly_coeffs, ground)
        ring = PolyRing(V, coeff_ring)
        try:
            numer = ring.from_expr(raw_numer)
        except ValueError:
            raise PolynomialError
        # every coefficient of numer must be zero: a linear system in the
        # undetermined coefficients
        solution = solve_lin_sys(numer.coeffs(), coeff_ring, _raw=False)

        if solution is None:
            return None
        else:
            # coefficients left undetermined by the solution are set to zero
            return candidate.subs(solution).subs(
                list(zip(poly_coeffs, [S.Zero]*len(poly_coeffs))))

    # when F has no free symbols other than V, try filter 'Q' first and
    # fall back to the unfiltered attempt
    if not (F.free_symbols - set(V)):
        solution = _integrate('Q')

        if solution is None:
            solution = _integrate()
    else:
        solution = _integrate()

    if solution is not None:
        # substitute the original terms back for the symbols V
        antideriv = solution.subs(rev_mapping)
        antideriv = cancel(antideriv).expand(force=True)

        # drop the terms of the sum that do not depend on x
        if antideriv.is_Add:
            antideriv = antideriv.as_independent(x)[1]

        return indep*antideriv
    else:
        # retry with the same `mappings` object and one retry fewer
        if retries >= 0:
            result = heurisch(f, x, mappings=mappings, rewrite=rewrite, hints=hints, retries=retries - 1, unnecessary_permutations=unnecessary_permutations)

            if result is not None:
                return indep*result

        return None
예제 #51
0
파일: operations.py 프로젝트: AStorus/sympy
 def args(self):
     """Return the members of ``self._argset`` as a tuple, in the order
     produced by ``ordered``."""
     in_order = ordered(self._argset)
     return tuple(in_order)
예제 #52
0
파일: operations.py 프로젝트: AStorus/sympy
    def _matches_commutative(self, expr, repl_dict=None, old=False):
        """
        Matches Add/Mul "pattern" to an expression "expr".

        repl_dict ... a dictionary of (wild: expression) pairs, that get
                      returned with the results; when omitted a new empty
                      dictionary is used for this call
        old ......... when False (default) and expr is an Add or Mul, the
                      match is rejected if removing the exact part of the
                      pattern from expr increases the operation count

        Returns the dictionary of matches, or None if there is no match.

        This function is the main workhorse for Add/Mul.

        For instance:

        >>> from sympy import symbols, Wild, sin
        >>> a = Wild("a")
        >>> b = Wild("b")
        >>> c = Wild("c")
        >>> x, y, z = symbols("x y z")
        >>> (a+sin(b)*c)._matches_commutative(x+sin(y)*z)
        {a_: x, b_: y, c_: z}

        In the example above, "a+sin(b)*c" is the pattern, and "x+sin(y)*z" is
        the expression.

        The repl_dict contains parts that were already matched. For example
        here:

        >>> (x+sin(b)*c)._matches_commutative(x+sin(y)*z, repl_dict={a: x})
        {a_: x, b_: y, c_: z}

        the only function of the repl_dict is to return it in the
        result, e.g. if you omit it:

        >>> (x+sin(b)*c)._matches_commutative(x+sin(y)*z)
        {b_: y, c_: z}

        the "a: x" is not returned in the result, but otherwise it is
        equivalent.

        """
        # make sure expr is Expr if pattern is Expr
        from .expr import Add, Expr
        from sympy import Mul
        if isinstance(self, Expr) and not isinstance(expr, Expr):
            return None

        # A default of ``{}`` would be one dict shared by every call that
        # omits repl_dict; since it is returned and handed to sub-matches,
        # build a fresh one per call instead.
        if repl_dict is None:
            repl_dict = {}

        # handle simple patterns
        if self == expr:
            return repl_dict

        d = self._matches_simple(expr, repl_dict)
        if d is not None:
            return d

        # eliminate exact part from pattern: (2+a+w1+w2).matches(expr) -> (w1+w2).matches(expr-a-2)
        from .function import WildFunction
        from .symbol import Wild
        wild_part = []
        exact_part = []
        for p in ordered(self.args):
            if p.has(Wild, WildFunction) and (not expr.has(p)):
                # not all Wild should stay Wilds, for example:
                # (w2+w3).matches(w1) -> (w1+w3).matches(w1) -> w3.matches(0)
                wild_part.append(p)
            else:
                exact_part.append(p)

        if exact_part:
            exact = self.func(*exact_part)
            free = expr.free_symbols
            if free and (exact.free_symbols - free):
                # there are symbols in the exact part that are not
                # in the expr; but if there are no free symbols, let
                # the matching continue
                return None
            newpattern = self.func(*wild_part)
            newexpr = self._combine_inverse(expr, exact)
            if not old and (expr.is_Add or expr.is_Mul):
                if newexpr.count_ops() > expr.count_ops():
                    return None
            return newpattern.matches(newexpr, repl_dict)

        # now to real work ;)
        i = 0
        # expressions already tried; the loop stops once one repeats
        saw = set()
        while expr not in saw:
            saw.add(expr)
            expr_list = (self.identity,) + tuple(ordered(self.make_args(expr)))
            # try to match each wild term against each argument of expr
            # (and the identity), then match the rest of the pattern
            for last_op in reversed(expr_list):
                for w in reversed(wild_part):
                    d1 = w.matches(last_op, repl_dict)
                    if d1 is not None:
                        d2 = self.xreplace(d1).matches(expr, d1)
                        if d2 is not None:
                            return d2

            if i == 0:
                if self.is_Mul:
                    # make e**i look like Mul
                    if expr.is_Pow and expr.exp.is_Integer:
                        if expr.exp > 0:
                            expr = Mul(*[expr.base, expr.base**(expr.exp - 1)], evaluate=False)
                        else:
                            expr = Mul(*[1/expr.base, expr.base**(expr.exp + 1)], evaluate=False)
                        i += 1
                        continue

                elif self.is_Add:
                    # make i*e look like Add
                    c, e = expr.as_coeff_Mul()
                    if abs(c) > 1:
                        if c > 0:
                            expr = Add(*[e, (c - 1)*e], evaluate=False)
                        else:
                            expr = Add(*[-e, (c + 1)*e], evaluate=False)
                        i += 1
                        continue

                    # try collection on non-Wild symbols
                    from sympy.simplify.radsimp import collect
                    was = expr
                    did = set()
                    for w in reversed(wild_part):
                        c, w = w.as_coeff_mul(Wild)
                        free = c.free_symbols - did
                        if free:
                            did.update(free)
                            expr = collect(expr, free)
                    if expr != was:
                        # `i` is left at its current value (the statement
                        # below is a no-op), so this branch can be entered
                        # again; the `saw` check ends the loop when the
                        # expression stops changing
                        i += 0
                        continue

                break  # if we didn't continue, there is nothing more to do

        return None
    def f(rv):
        """Combine the logs found in the terms of ``rv`` and return the
        rebuilt sum; anything that is not an Add or Mul is returned as is.

        Terms are sorted into ``other`` (left alone), ``log1`` (exactly one
        usable log, grouped by the factors that cannot enter the log) and
        ``logs`` (more than one usable log in a single term). ``force`` is
        read from the enclosing scope, which is not visible in this block.
        """
        if not (rv.is_Add or rv.is_Mul):
            return rv

        def gooda(a):
            # bool to tell whether the leading ``a`` in ``a*log(x)``
            # could appear as log(x**a)
            return (a is not S.NegativeOne and  # -1 *could* go, but we disallow
                (a.is_real or force and a.is_real is not False))

        def goodlog(l):
            # bool to tell whether log ``l``'s argument can combine with others
            a = l.args[0]
            return a.is_positive or force and a.is_nonpositive is not False

        other = []  # terms passed through unchanged
        logs = []  # (ot, co, lo) for terms holding more than one good log
        # key: tuple of outside factors; value: list of (coeffs, log) pairs
        log1 = defaultdict(list)
        for a in Add.make_args(rv):
            if a.func is log and goodlog(a):
                log1[()].append(([], a))
            elif not a.is_Mul:
                other.append(a)
            else:
                ot = []  # factors that stay outside the log
                co = []  # factors that may become the exponent inside the log
                lo = []  # good logs of this term
                for ai in a.args:
                    if ai.is_Rational and ai < 0:
                        # keep the sign outside, the magnitude may go inside
                        ot.append(S.NegativeOne)
                        co.append(-ai)
                    elif ai.func is log and goodlog(ai):
                        lo.append(ai)
                    elif gooda(ai):
                        co.append(ai)
                    else:
                        ot.append(ai)
                if len(lo) > 1:
                    logs.append((ot, co, lo))
                elif lo:
                    log1[tuple(ot)].append((co, lo[0]))
                else:
                    other.append(a)

        # if there is only one log at each coefficient and none have
        # an exponent to place inside the log then there is nothing to do
        # NOTE(review): the entries of log1[k] are (coeffs, log) tuples, so
        # ``log1[k][0] == []`` compares a tuple with a list; as written the
        # test below can only pass when log1 is empty. Confirm the intent
        # before changing it, since the steps further down still combine
        # logs whose outside coefficients are k and -k.
        if not logs and all(len(log1[k]) == 1 and log1[k][0] == [] for k in log1):
            return rv

        # collapse multi-logs as far as possible in a canonical way
        # TODO: see if x*log(a)+x*log(a)*log(b) -> x*log(a)*(1+log(b))?
        # -- in this case, it's unambiguous, but if there were a log(c) in
        # each term then it's arbitrary whether they are grouped by log(a) or
        # by log(c). So for now, just leave this alone; it's probably better to
        # let the user decide
        for o, e, l in logs:
            l = list(ordered(l))
            # nest the logs: each remaining log takes the previous result
            # as the exponent of its argument
            e = log(l.pop(0).args[0]**Mul(*e))
            while l:
                li = l.pop(0)
                e = log(li.args[0]**e)
            c, l = Mul(*o), e
            if l.func is log:  # it should be, but check to be sure
                log1[(c,)].append(([], l))
            else:
                other.append(c*l)

        # logs that have the same coefficient can multiply
        # (each tuple key is replaced by the product of its factors, and its
        # list of pairs by a single log of the product of arg**coeff)
        for k in list(log1.keys()):
            log1[Mul(*k)] = log(logcombine(Mul(*[
                l.args[0]**Mul(*c) for c, l in log1.pop(k)]),
                force=force))

        # logs that have oppositely signed coefficients can divide
        for k in ordered(list(log1.keys())):
            if not k in log1:  # already popped as -k
                continue
            if -k in log1:
                # figure out which has the minus sign; the one with
                # more op counts should be the one
                num, den = k, -k
                if num.count_ops() > den.count_ops():
                    num, den = den, num
                other.append(num*log(log1.pop(num).args[0]/log1.pop(den).args[0]))
            else:
                other.append(k*log1.pop(k))

        return Add(*other)
예제 #54
0
파일: util.py 프로젝트: asmeurer/sympy
def intersection(*entities, **kwargs):
    """Return the intersection of a collection of GeometryEntity instances.

    Parameters
    ==========
    entities : sequence of GeometryEntity
        Items that are not GeometryEntity instances are passed to ``Point``.
    pairwise (keyword argument) : Can be either True or False

    Returns
    =======
    intersection : list of GeometryEntity

    Raises
    ======
    NotImplementedError
        When unable to calculate intersection.

    Notes
    =====
    The intersection of any geometrical entity with itself should return
    a list with one item: the entity in question.
    An intersection requires two or more entities. If only a single
    entity is given then the function will return an empty list.
    It is possible for `intersection` to miss intersections that one
    knows exist because the required quantities were not fully
    simplified internally.
    Reals should be converted to Rationals, e.g. Rational(str(real_num))
    or else failures due to floating point issues may result.

    Case 1: When the keyword argument 'pairwise' is False (default value):
    In this case, the function returns a list of intersections common to
    all entities.

    Case 2: When the keyword argument 'pairwise' is True:
    In this case, the function returns a list of intersections that occur
    between any pair of entities.

    See Also
    ========

    sympy.geometry.entity.GeometryEntity.intersection

    Examples
    ========

    >>> from sympy.geometry import Ray, Circle, intersection
    >>> c = Circle((0, 1), 1)
    >>> intersection(c, c.center)
    []
    >>> right = Ray((0, 0), (1, 0))
    >>> up = Ray((0, 0), (0, 1))
    >>> intersection(c, right, up)
    [Point2D(0, 0)]
    >>> intersection(c, right, up, pairwise=True)
    [Point2D(0, 0), Point2D(0, 2)]
    >>> left = Ray((1, 0), (0, 0))
    >>> intersection(right, left)
    [Segment2D(Point2D(0, 0), Point2D(1, 0))]

    """

    from .entity import GeometryEntity
    from .point import Point

    pairwise = kwargs.pop('pairwise', False)

    # fewer than two entities: nothing to intersect
    if len(entities) <= 1:
        return []

    # work on a fresh list in which every non-entity has become a Point
    shapes = [
        item if isinstance(item, GeometryEntity) else Point(item)
        for item in entities
    ]

    if pairwise:
        # gather the intersections of every unordered pair, then drop
        # duplicates and put what is left in a deterministic order
        found = []
        for index, first in enumerate(shapes):
            for second in shapes[index + 1:]:
                found.extend(intersection(first, second))
        return list(ordered(set(found)))

    # intersection common to all entities: start from the first two and
    # keep intersecting the surviving pieces with each remaining entity
    common = shapes[0].intersection(shapes[1])
    for shape in shapes[2:]:
        common = [
            piece
            for part in common
            for piece in part.intersection(shape)
        ]
    return common
예제 #55
0
파일: heurisch.py 프로젝트: sympy/sympy
    def _integrate(field=None):
        """Solve for the undetermined coefficients of the candidate.

        ``field`` is forwarded as the ``filter`` argument of ``root_factors``
        when the reducible polynomials are split.  The candidate is
        ``poly_part/poly_denom`` plus a sum of ``log`` terms and a sum of
        ``atan`` terms, each multiplied by a fresh coefficient symbol that is
        also appended to ``poly_coeffs``.

        Returns the candidate with the solved coefficients substituted and
        every remaining coefficient set to zero, or ``None`` when the
        numerator contains a non-polynomial expression or the linear system
        has no solution.  ``PolynomialError`` is raised when the numerator
        cannot be converted into the polynomial ring.
        """
        irreducibles = set()
        atans = set()
        pairs = set()

        for poly in reducibles:
            # Factor with respect to the first free symbol of poly that is
            # in V; polynomials involving none of them are skipped.
            # NOTE(review): the original author asked whether the factoring
            # should instead be done for every such symbol rather than only
            # the first one found.
            z = next((sym for sym in poly.free_symbols if sym in V), None)
            if z is None:
                continue
            irreducibles |= set(root_factors(poly, z, filter=field))

        # pull out the factors of the form x + I*y with I-free x and y
        for poly in tuple(irreducibles):
            parts = collect(poly, I, evaluate=False)
            imag = parts.get(I, S.Zero)
            if not imag:
                continue
            real = parts.get(S.One, S.Zero)
            if real.has(I) or imag.has(I):
                continue  # nontrivial x + I*y
            pairs.add((real, imag))
            irreducibles.remove(poly)

        # a conjugate pair x +/- I*y turns into x**2 + y**2 and atan(x/y);
        # an unmatched x + I*y goes back in unchanged
        while pairs:
            real, imag = pairs.pop()
            conjugate = (real, -imag)
            if conjugate not in pairs:
                irreducibles.add(real + I*imag)
                continue
            pairs.remove(conjugate)
            # Choosing b with no minus sign
            if imag.could_extract_minus_sign():
                imag = -imag
            irreducibles.add(real*real + imag*imag)
            atans.add(atan(real/imag))

        B = _symbols('B', len(irreducibles))
        C = _symbols('C', len(atans))

        log_part, atan_part = [], []

        # Note: the ordering matters here
        for poly, log_coeff in list(ordered(zip(irreducibles, B)))[::-1]:
            if poly.has(*V):
                poly_coeffs.append(log_coeff)
                log_part.append(log_coeff * log(poly))

        for poly, atan_coeff in list(ordered(zip(atans, C)))[::-1]:
            if poly.has(*V):
                poly_coeffs.append(atan_coeff)
                atan_part.append(atan_coeff * poly)

        # TODO: Currently it's better to use symbolic expressions here instead
        # of rational functions, because it's simpler and FracElement doesn't
        # give big speed improvement yet. This is because cancellation is slow
        # due to slow polynomial GCD algorithms. If this gets improved then
        # revise this code.
        candidate = poly_part/poly_denom + Add(*log_part) + Add(*atan_part)
        residual = F - _derivation(candidate) / denom
        raw_numer = residual.as_numer_denom()[0]

        # Rewrite raw_numer as a polynomial in K[coeffs][V] where K is a field
        # that we have to determine. We can't use simply atoms() because log(3),
        # sqrt(y) and similar expressions can appear, leading to non-trivial
        # domains.
        syms = set(poly_coeffs) | set(V)
        non_syms = set()

        def find_non_syms(expr):
            # collect the symbol-free subexpressions that must live in the
            # ground domain
            if expr.is_Integer or expr.is_Rational:
                return  # ignore trivial numbers
            if expr in syms:
                return  # ignore variables
            if not expr.has(*syms):
                non_syms.add(expr)
                return
            if not (expr.is_Add or expr.is_Mul or expr.is_Pow):
                # TODO: Non-polynomial expression. This should have been
                # filtered out at an earlier stage.
                raise PolynomialError
            for arg in expr.args:
                find_non_syms(arg)

        try:
            find_non_syms(raw_numer)
        except PolynomialError:
            return None

        ground, _ = construct_domain(non_syms, field=True)

        coeff_ring = PolyRing(poly_coeffs, ground)
        ring = PolyRing(V, coeff_ring)
        try:
            numer = ring.from_expr(raw_numer)
        except ValueError:
            raise PolynomialError
        solution = solve_lin_sys(numer.coeffs(), coeff_ring, _raw=False)

        if solution is None:
            return None

        # coefficients the solution left free are set to zero
        leftovers = [(symbol, S.Zero) for symbol in poly_coeffs]
        return candidate.subs(solution).subs(leftovers)
예제 #56
0
def roots(f, *gens, **flags):
    """
    Computes symbolic roots of a univariate polynomial.

    Given a univariate polynomial f with symbolic coefficients (or
    a list of the polynomial's coefficients), returns a dictionary
    with its roots and their multiplicities.

    Only roots expressible via radicals will be returned.  To get
    a complete set of roots use RootOf class or numerical methods
    instead. By default cubic and quartic formulas are used in
    the algorithm. To disable them because of unreadable output
    set ``cubics=False`` or ``quartics=False`` respectively. If cubic
    roots are real but are expressed in terms of complex numbers
    (casus irreducibilis [1]) the ``trig`` flag can be set to True to
    have the solutions returned in terms of cosine and inverse cosine
    functions.

    To get roots from a specific domain set the ``filter`` flag with
    one of the following specifiers: Z, Q, R, I, C. By default all
    roots are returned (this is equivalent to setting ``filter='C'``).

    By default a dictionary is returned giving a compact result in
    case of multiple roots.  However to get a list containing all
    those roots set the ``multiple`` flag to True; the list will
    have identical roots appearing next to each other in the result.
    (For a given Poly, the all_roots method will give the roots in
    sorted numerical order.)

    Parameters
    ==========

    f : expression, Poly or list of coefficients
    gens : generators, forwarded to ``Poly``; must be empty when ``f``
        is a list
    flags : keyword arguments
        ``auto`` (default True), ``cubics`` (True), ``trig`` (False),
        ``quartics`` (True), ``quintics`` (False), ``multiple`` (False),
        ``filter`` (None) and ``predicate`` (None) are consumed here.
        ``predicate`` is a callable; roots for which it returns a false
        value are dropped.  Any other flag is forwarded to ``Poly``.

    Returns
    =======

    dict mapping each root to its multiplicity, or, when ``multiple`` is
    True, a list in which each root is repeated according to its
    multiplicity.

    Raises
    ======

    ValueError
        When generators are given together with a coefficient list, or
        when ``filter`` is not one of the recognized specifiers.
    PolynomialError
        When the polynomial is multivariate.

    Examples
    ========

    >>> from sympy import Poly, roots
    >>> from sympy.abc import x, y

    >>> roots(x**2 - 1, x)
    {-1: 1, 1: 1}

    >>> p = Poly(x**2-1, x)
    >>> roots(p)
    {-1: 1, 1: 1}

    >>> p = Poly(x**2-y, x, y)

    >>> roots(Poly(p, x))
    {-sqrt(y): 1, sqrt(y): 1}

    >>> roots(x**2 - y, x)
    {-sqrt(y): 1, sqrt(y): 1}

    >>> roots([1, 0, -1])
    {-1: 1, 1: 1}


    References
    ==========

    .. [1] https://en.wikipedia.org/wiki/Cubic_function#Trigonometric_.28and_hyperbolic.29_method

    """
    from sympy.polys.polytools import to_rational_coeffs
    flags = dict(flags)

    auto = flags.pop('auto', True)
    cubics = flags.pop('cubics', True)
    trig = flags.pop('trig', False)
    quartics = flags.pop('quartics', True)
    quintics = flags.pop('quintics', False)
    multiple = flags.pop('multiple', False)
    filter = flags.pop('filter', None)
    predicate = flags.pop('predicate', None)

    if isinstance(f, list):
        if gens:
            raise ValueError('redundant generators given')

        x = Dummy('x')

        # coefficients are listed from the highest degree down
        poly, i = {}, len(f) - 1

        for coeff in f:
            poly[i], i = sympify(coeff), i - 1

        f = Poly(poly, x, field=True)
    else:
        try:
            f = Poly(f, *gens, **flags)
            # NOTE(review): ``f.length`` is a method (it is called as
            # ``f.length()`` further down), so this comparison is always
            # False and the branch below never runs.  It is left untouched
            # because enabling it as written would return before the
            # ``multiple``, ``filter`` and ``predicate`` handling at the end
            # of this function.
            if f.length == 2 and f.degree() != 1:
                # check for foo**n factors in the constant
                n = f.degree()
                npow_bases = []
                others = []
                expr = f.as_expr()
                con = expr.as_independent(*gens)[0]
                for p in Mul.make_args(con):
                    if p.is_Pow and not p.exp % n:
                        npow_bases.append(p.base**(p.exp/n))
                    else:
                        others.append(p)
                    if npow_bases:
                        b = Mul(*npow_bases)
                        B = Dummy()
                        d = roots(Poly(expr - con + B**n*Mul(*others), *gens,
                            **flags), *gens, **flags)
                        rv = {}
                        for k, v in d.items():
                            rv[k.subs(B, b)] = v
                        return rv

        except GeneratorsNeeded:
            if multiple:
                return []
            else:
                return {}

        if f.is_multivariate:
            raise PolynomialError('multivariate polynomials are not supported')

    def _update_dict(result, currentroot, k):
        """Add ``k`` to the multiplicity recorded for ``currentroot``. """
        if currentroot in result:
            result[currentroot] += k
        else:
            result[currentroot] = k

    def _try_decompose(f):
        """Find roots using functional decomposition. """
        factors, roots = f.decompose(), []

        for currentroot in _try_heuristics(factors[0]):
            roots.append(currentroot)

        for currentfactor in factors[1:]:
            previous, roots = list(roots), []

            for currentroot in previous:
                g = currentfactor - Poly(currentroot, f.gen)

                for currentroot in _try_heuristics(g):
                    roots.append(currentroot)

        return roots

    def _try_heuristics(f):
        """Find roots using formulas and some tricks. """
        if f.is_ground:
            return []
        if f.is_monomial:
            return [S(0)]*f.degree()

        if f.length() == 2:
            if f.degree() == 1:
                return list(map(cancel, roots_linear(f)))
            else:
                return roots_binomial(f)

        result = []

        # divide out at most one of the trivial roots -1 and 1
        for i in [-1, 1]:
            if not f.eval(i):
                f = f.quo(Poly(f.gen - i, f.gen))
                result.append(i)
                break

        n = f.degree()

        if n == 1:
            result += list(map(cancel, roots_linear(f)))
        elif n == 2:
            result += list(map(cancel, roots_quadratic(f)))
        elif f.is_cyclotomic:
            result += roots_cyclotomic(f)
        elif n == 3 and cubics:
            result += roots_cubic(f, trig=trig)
        elif n == 4 and quartics:
            result += roots_quartic(f)
        elif n == 5 and quintics:
            result += roots_quintic(f)

        return result

    # split off the power of the generator dividing f; those are zero roots
    (k,), f = f.terms_gcd()

    if not k:
        zeros = {}
    else:
        zeros = {S(0): k}

    coeff, f = preprocess_roots(f)

    if auto and f.get_domain().is_Ring:
        f = f.to_field()

    rescale_x = None
    translate_x = None

    result = {}

    if not f.is_ground:
        dom = f.get_domain()
        if not dom.is_Exact and dom.is_Numerical:
            for r in f.nroots():
                _update_dict(result, r, 1)
        elif f.degree() == 1:
            result[roots_linear(f)[0]] = 1
        elif f.length() == 2:
            roots_fun = roots_quadratic if f.degree() == 2 else roots_binomial
            for r in roots_fun(f):
                _update_dict(result, r, 1)
        else:
            _, factors = Poly(f.as_expr()).factor_list()
            if len(factors) == 1 and f.degree() == 2:
                for r in roots_quadratic(f):
                    _update_dict(result, r, 1)
            else:
                if len(factors) == 1 and factors[0][1] == 1:
                    if f.get_domain().is_EX:
                        res = to_rational_coeffs(f)
                        if res:
                            # f is replaced by a rescaled or translated
                            # polynomial; the change of variable is undone
                            # on the roots further down
                            if res[0] is None:
                                translate_x, f = res[2:]
                            else:
                                rescale_x, f = res[1], res[-1]
                            result = roots(f)
                            if not result:
                                for currentroot in _try_decompose(f):
                                    _update_dict(result, currentroot, 1)
                        else:
                            for r in _try_heuristics(f):
                                _update_dict(result, r, 1)
                    else:
                        for currentroot in _try_decompose(f):
                            _update_dict(result, currentroot, 1)
                else:
                    for currentfactor, k in factors:
                        for r in _try_heuristics(Poly(currentfactor, f.gen, field=True)):
                            _update_dict(result, r, k)

    if coeff is not S.One:
        _result, result = result, {}

        for currentroot, k in _result.items():
            result[coeff*currentroot] = k

    # The zero roots are kept apart from the non-trivial ones until the
    # rescaling/translation below has been undone: they were removed by
    # terms_gcd() before f was rescaled or translated, so adding
    # ``translate_x`` to them would report them at the wrong location.
    # They are still subject to ``filter`` and ``predicate``.
    if filter not in [None, 'C']:
        handlers = {
            'Z': lambda r: r.is_Integer,
            'Q': lambda r: r.is_Rational,
            'R': lambda r: r.is_real,
            'I': lambda r: r.is_imaginary,
        }

        try:
            query = handlers[filter]
        except KeyError:
            raise ValueError("Invalid filter: %s" % filter)

        for found in (result, zeros):
            for zero in list(found):
                if not query(zero):
                    del found[zero]

    if predicate is not None:
        for found in (result, zeros):
            for zero in list(found):
                if not predicate(zero):
                    del found[zero]

    if rescale_x:
        result1 = {}
        for k, v in result.items():
            result1[k*rescale_x] = v
        result = result1
    if translate_x:
        result1 = {}
        for k, v in result.items():
            result1[k + translate_x] = v
        result = result1

    # add the zero roots only after the non-trivial roots are final
    result.update(zeros)

    if not multiple:
        return result
    else:
        expanded = []

        for zero in ordered(result):
            expanded.extend([zero]*result[zero])

        return expanded
예제 #57
0
 def args(self):
     """Return the elements of ``self._argset`` as a tuple sorted by ``ordered``."""
     sorted_members = ordered(self._argset)
     return tuple(sorted_members)