Ejemplo n.º 1
0
    def __print_set(self, set_):
        """Pretty-print a set-like object as ``TypeName(e1, e2, ...)``.

        The elements are sorted with ``Basic.compare_pretty`` instead of being
        printed in the container's own iteration order.
        """
        ordered = sorted(set_, key=cmp_to_key(Basic.compare_pretty))
        body = self._print_seq(ordered, '(', ')')
        type_name = type(set_).__name__
        return prettyForm(*stringPict.next(type_name, body))
Ejemplo n.º 2
0
    def __print_set(self, set_):
        """Render ``set_`` as its type name followed by its parenthesized elements.

        Elements are ordered with ``Basic.compare_pretty``.
        """
        elements = list(set_)
        elements.sort(key=cmp_to_key(Basic.compare_pretty))
        seq_form = self._print_seq(elements, '(', ')')
        return prettyForm(*stringPict.next(type(set_).__name__, seq_form))
Ejemplo n.º 3
0
    def _as_ordered_terms(self, expr, order=None):
        """A compatibility function for ordering terms in Add. """
        order = order or self.order

        if order == 'old':
            return sorted(Add.make_args(expr), key=cmp_to_key(Basic._compare_pretty))
        else:
            return expr.as_ordered_terms(order=order)
Ejemplo n.º 4
0
Archivo: str.py Proyecto: Ingwar/sympy
    def _print_dict(self, expr):
        """Return ``expr`` as ``{key: value, ...}`` with keys in pretty order.

        Keys are sorted with ``Basic.compare_pretty``; keys and values are
        each rendered via ``self._print``.
        """
        ordered_keys = sorted(expr.keys(), key=cmp_to_key(Basic.compare_pretty))
        pairs = ["%s: %s" % (self._print(k), self._print(expr[k]))
                 for k in ordered_keys]
        return "{%s}" % ", ".join(pairs)
Ejemplo n.º 5
0
    def _as_ordered_terms(self, expr, order=None):
        """A compatibility function for ordering terms in Add. """
        order = order or self.order

        if order == 'old':
            return sorted(Add.make_args(expr), key=cmp_to_key(Basic._compare_pretty))
        else:
            return expr.as_ordered_terms(order=order)
Ejemplo n.º 6
0
Archivo: str.py Proyecto: Jerryy/sympy
    def __print_set(self, expr):
        """Return ``TypeName([e1, e2, ...])``, or ``TypeName()`` when empty.

        Elements are sorted with ``Basic.compare_pretty`` and printed with
        ``self._print``.
        """
        ordered = sorted(expr, key=cmp_to_key(Basic.compare_pretty))
        inner = ', '.join(self._print(element) for element in ordered)
        # Only a non-empty element list is wrapped in square brackets.
        wrapped = '[%s]' % inner if inner else inner
        return '%s(%s)' % (type(expr).__name__, wrapped)
Ejemplo n.º 7
0
    def __print_set(self, expr):
        """Print a set-like ``expr`` as ``TypeName([e1, e2, ...])``.

        An empty container prints as ``TypeName()``.  Elements are ordered
        with ``Basic.compare_pretty``.
        """
        members = list(expr)
        members.sort(key=cmp_to_key(Basic.compare_pretty))

        rendered = ', '.join([self._print(m) for m in members])
        if not rendered:
            return '%s(%s)' % (type(expr).__name__, rendered)
        return '%s(%s)' % (type(expr).__name__, '[%s]' % rendered)
Ejemplo n.º 8
0
    def _print_dict(self, expr):
        """Render a dictionary as a LaTeX ``Bmatrix`` of ``key : value`` entries.

        Keys are ordered with ``Basic.compare_pretty``; entries are
        separated by ``, &``.
        """
        ordered_keys = sorted(expr.keys(), key=cmp_to_key(Basic.compare_pretty))
        entries = ["%s : %s" % (self._print(key), self._print(expr[key]))
                   for key in ordered_keys]
        return r"\begin{Bmatrix}%s\end{Bmatrix}" % r", & ".join(entries)
Ejemplo n.º 9
0
    def _print_dict(self, expr):
        """Return ``expr`` as ``{key: value, ...}`` with keys in pretty order.

        Keys are sorted with ``Basic.compare_pretty``; keys and values are
        each rendered via ``self._print``.
        """
        # sorted() accepts any iterable of keys.  The previous
        # ``keys = expr.keys(); keys.sort(...)`` fails on Python 3, where
        # dict.keys() returns a view object that has no sort() method, and
        # would otherwise reorder in place any list that keys() hands back.
        keys = sorted(expr.keys(), key=cmp_to_key(Basic.compare_pretty))

        items = []
        for key in keys:
            item = "%s: %s" % (self._print(key), self._print(expr[key]))
            items.append(item)

        return "{%s}" % ", ".join(items)
Ejemplo n.º 10
0
    def _print_dict(self, expr):
        """Render a dictionary as a LaTeX ``Bmatrix`` of ``key : value`` entries.

        Keys are ordered with ``Basic.compare_pretty``; entries are
        separated by ``, &``.
        """
        items = []

        # sorted() accepts any iterable of keys.  Calling .sort() on
        # expr.keys() fails on Python 3, where keys() returns a view object
        # with no sort() method.
        keys = sorted(expr.keys(), key=cmp_to_key(Basic.compare_pretty))
        for key in keys:
            val = expr[key]
            items.append("%s : %s" % (self._print(key), self._print(val)))

        return r"\begin{Bmatrix}%s\end{Bmatrix}" % r", & ".join(items)
Ejemplo n.º 11
0
    def _print_dict(self, d):
        """Pretty-print a dictionary as ``{key: value, ...}``.

        Keys are ordered with ``Basic.compare_pretty``; each entry is a
        ``prettyForm`` of the key, ``': '`` and the value laid out side by
        side, and the entries are joined by ``self._print_seq``.
        """
        ordering = cmp_to_key(Basic.compare_pretty)
        entries = []
        for key in sorted(d.keys(), key=ordering):
            key_form = self._print(key)
            value_form = self._print(d[key])
            entries.append(prettyForm(*stringPict.next(key_form, ': ', value_form)))
        return self._print_seq(entries, '{', '}')
Ejemplo n.º 12
0
    def __new__(cls, expr, *symbols, **assumptions):
        """Create an Order term ``O(expr)`` in ``symbols`` (at the point 0).

        ``expr`` is sympified and expanded.  If no symbols are given, the
        free symbols of ``expr`` are used.  For an ``expr`` that is not
        already an Order, only its leading term(s) in ``symbols`` are kept;
        an ``expr`` that ends up containing none of the symbols becomes 1.

        Returns S.NaN for a NaN ``expr``, S.Zero when the leading term is
        zero, and ``expr`` itself when it is already an Order whose
        variables include every requested symbol.

        Raises NotImplementedError if any symbol argument is not a Symbol
        (i.e. an expansion point other than 0 was requested).
        """

        expr = sympify(expr).expand()
        if expr is S.NaN:
            return S.NaN

        if symbols:
            # Materialize as a list.  On Python 3, map() returns a one-shot
            # iterator.  all() below would exhaust it, so the later loop would
            # see no symbols, and the iterator supports neither indexing
            # (symbols[0]) nor sort().
            symbols = list(map(sympify, symbols))
            if not all(isinstance(s, Symbol) for s in symbols):
                raise NotImplementedError(
                    'Order at points other than 0 not supported.')
        else:
            symbols = list(expr.free_symbols)

        if expr.is_Order:

            # Merge the requested symbols into the existing Order's variables;
            # if nothing new is added the existing Order is returned as-is.
            new_symbols = list(expr.variables)
            for s in symbols:
                if s not in new_symbols:
                    new_symbols.append(s)
            if len(new_symbols) == len(expr.variables):
                return expr
            symbols = new_symbols

        elif symbols:

            if expr.is_Add:
                lst = expr.extract_leading_order(*symbols)
                expr = Add(*[f.expr for (e, f) in lst])
            elif expr:
                if len(symbols) > 1 or expr.is_commutative is False:
                    # TODO
                    # We cannot use compute_leading_term because that only
                    # works in one symbol.
                    expr = expr.as_leading_term(*symbols)
                else:
                    expr = expr.compute_leading_term(symbols[0])
                # Keep only the factors that involve at least one symbol.
                terms = expr.as_coeff_mul(*symbols)[1]
                s = set(symbols)
                expr = Mul(*[t for t in terms if s & t.free_symbols])

        if expr is S.Zero:
            return expr
        elif not expr.has(*symbols):
            expr = S.One

        # create Order instance:
        symbols.sort(key=cmp_to_key(Basic.compare))
        obj = Expr.__new__(cls, expr, *symbols, **assumptions)

        return obj
Ejemplo n.º 13
0
    def __new__(cls, expr, *symbols, **assumptions):
        """Create an Order term ``O(expr)`` in ``symbols`` (at the point 0).

        ``expr`` is sympified and expanded.  If no symbols are given, the
        free symbols of ``expr`` are used.  For an ``expr`` that is not
        already an Order, only its leading term(s) in ``symbols`` are kept;
        an ``expr`` that ends up containing none of the symbols becomes 1.

        Returns S.NaN for a NaN ``expr``, S.Zero when the leading term is
        zero, and ``expr`` itself when it is already an Order whose
        variables include every requested symbol.

        Raises NotImplementedError if any symbol argument is not a Symbol
        (i.e. an expansion point other than 0 was requested).
        """

        expr = sympify(expr).expand()
        if expr is S.NaN:
            return S.NaN

        if symbols:
            # Materialize as a list.  On Python 3, map() returns a one-shot
            # iterator.  all() below would exhaust it, so the later loop would
            # see no symbols, and the iterator supports neither indexing
            # (symbols[0]) nor sort().
            symbols = list(map(sympify, symbols))
            if not all(isinstance(s, Symbol) for s in symbols):
                raise NotImplementedError(
                    'Order at points other than 0 not supported.')
        else:
            symbols = list(expr.free_symbols)

        if expr.is_Order:

            # Merge the requested symbols into the existing Order's variables;
            # if nothing new is added the existing Order is returned as-is.
            new_symbols = list(expr.variables)
            for s in symbols:
                if s not in new_symbols:
                    new_symbols.append(s)
            if len(new_symbols) == len(expr.variables):
                return expr
            symbols = new_symbols

        elif symbols:

            if expr.is_Add:
                lst = expr.extract_leading_order(*symbols)
                expr = Add(*[f.expr for (e, f) in lst])
            elif expr:
                if len(symbols) > 1 or expr.is_commutative is False:
                    # TODO
                    # We cannot use compute_leading_term because that only
                    # works in one symbol.
                    expr = expr.as_leading_term(*symbols)
                else:
                    expr = expr.compute_leading_term(symbols[0])
                # Keep only the factors that involve at least one symbol.
                terms = expr.as_coeff_mul(*symbols)[1]
                s = set(symbols)
                expr = Mul(*[t for t in terms if s & t.free_symbols])

        if expr is S.Zero:
            return expr
        elif not expr.has(*symbols):
            expr = S.One

        # create Order instance:
        symbols.sort(key=cmp_to_key(Basic.compare))
        obj = Expr.__new__(cls, expr, *symbols, **assumptions)

        return obj
Ejemplo n.º 14
0
    def _print_dict(self, d):
        """Pretty-print a dictionary as ``{key: value, ...}``.

        Keys are ordered with ``Basic.compare_pretty``; each entry is a
        ``prettyForm`` of the key, ``': '`` and the value laid out side by
        side, and the entries are joined by ``self._print_seq``.
        """
        items = []

        # sorted() accepts any iterable of keys.  Calling .sort() on d.keys()
        # fails on Python 3 (keys() is a view with no sort() method) and
        # would otherwise reorder in place any list that keys() returns.
        keys = sorted(d.keys(), key=cmp_to_key(Basic.compare_pretty))

        for k in keys:
            K = self._print(k)
            V = self._print(d[k])
            s = prettyForm(*stringPict.next(K, ': ', V))

            items.append(s)

        return self._print_seq(items, '{', '}')
Ejemplo n.º 15
0
class Mul(Expr, AssocOp):

    __slots__ = []

    is_Mul = True

    #identity = S.One
    # cyclic import, so defined in numbers.py

    # Key for sorting commutative args in canonical order
    _args_sortkey = cmp_to_key(Basic.compare)

    # Canonicalize the factors of a product.
    #
    # Returns a 3-tuple (c_part, nc_part, order_symbols):
    #   c_part        -- list of commutative factors; when a numeric
    #                    coefficient other than 1 is kept as a separate
    #                    factor it is placed in slot 0
    #   nc_part       -- list of non-commutative factors; adjacent factors
    #                    with equal bases are combined unless the summed
    #                    exponent would be an Add
    #   order_symbols -- symbols gathered from Order (O(...)) factors via
    #                    as_expr_variables, or None
    @classmethod
    def flatten(cls, seq):
        """Return commutative, noncommutative and order arguments by
        combining related terms.

        Notes
        =====
            * In an expression like ``a*b*c``, python processes this through sympy
              as ``Mul(Mul(a, b), c)``. This can have undesirable consequences.

              -  Sometimes terms are not combined as one would like:
                 {c.f. http://code.google.com/p/sympy/issues/detail?id=1497}

                >>> from sympy import Mul, sqrt
                >>> from sympy.abc import x, y, z
                >>> 2*(x + 1) # this is the 2-arg Mul behavior
                2*x + 2
                >>> y*(x + 1)*2
                2*y*(x + 1)
                >>> 2*(x + 1)*y # 2-arg result will be obtained first
                y*(2*x + 2)
                >>> Mul(2, x + 1, y) # all 3 args simultaneously processed
                2*y*(x + 1)
                >>> 2*((x + 1)*y) # parentheses can control this behavior
                2*y*(x + 1)

                Powers with compound bases may not find a single base to
                combine with unless all arguments are processed at once.
                Post-processing may be necessary in such cases.
                {c.f. http://code.google.com/p/sympy/issues/detail?id=2629}

                >>> a = sqrt(x*sqrt(y))
                >>> a**3
                (x*sqrt(y))**(3/2)
                >>> Mul(a,a,a)
                (x*sqrt(y))**(3/2)
                >>> a*a*a
                x*sqrt(y)*sqrt(x*sqrt(y))
                >>> _.subs(a.base, z).subs(z, a.base)
                (x*sqrt(y))**(3/2)

              -  If more than two terms are being multiplied then all the
                 previous terms will be re-processed for each new argument.
                 So if each of ``a``, ``b`` and ``c`` were :class:`Mul`
                 expressions, then ``a*b*c`` (or building up the product
                 with ``*=``) will process all the arguments of ``a`` and
                 ``b`` twice: once when ``a*b`` is computed and again when
                 ``c`` is multiplied.

                 Using ``Mul(a, b, c)`` will process all arguments once.

            * The results of Mul are cached according to arguments, so flatten
              will only be called once for ``Mul(a, b, c)``. If you can
              structure a calculation so the arguments are most likely to be
              repeats then this can save time in computing the answer. For
              example, say you had a Mul, M, that you wished to divide by ``d[i]``
              and multiply by ``n[i]`` and you suspect there are many repeats
              in ``n``. It would be better to compute ``M*n[i]/d[i]`` rather
              than ``M/d[i]*n[i]`` since every time n[i] is a repeat, the
              product, ``M*n[i]`` will be returned without flattening -- the
              cached value will be returned. If you divide by the ``d[i]``
              first (and those are more unique than the ``n[i]``) then that will
              create a new Mul, ``M/d[i]`` the args of which will be traversed
              again when it is multiplied by ``n[i]``.

              {c.f. http://code.google.com/p/sympy/issues/detail?id=2607}

              This consideration is moot if the cache is turned off.

            NB
            --
              The validity of the above notes depends on the implementation
              details of Mul and flatten which may change at any time. Therefore,
              you should only consider them when your code is highly performance
              sensitive.

              Removal of 1 from the sequence is already handled by AssocOp.__new__.
        """
        # Fast path for exactly two factors where one of them is a nonzero
        # Rational: fold the Rational into the other factor's coefficient
        # and, for a commutative Add, distribute it over the Add's terms.
        rv = None
        if len(seq) == 2:
            a, b = seq
            if b.is_Rational:
                a, b = b, a
            # 1 should already have been removed by AssocOp.__new__ (see the
            # NB at the end of the docstring).
            assert not a is S.One
            if a and a.is_Rational:
                r, b = b.as_coeff_Mul()
                a *= r
                if b.is_Mul:
                    bargs, nc = b.args_cnc()
                    rv = bargs, nc, None
                    if a is not S.One:
                        bargs.insert(0, a)

                elif b.is_Add and b.is_commutative:
                    if a is S.One:
                        rv = [b], [], None
                    else:
                        r, b = b.as_coeff_Add()
                        bargs = [_keep_coeff(a, bi) for bi in Add.make_args(b)]
                        bargs.sort(key=hash)
                        ar = a*r
                        if ar:
                            bargs.insert(0, ar)
                        bargs = [Add._from_args(bargs)]
                        rv = bargs, [], None
            if rv:
                return rv

        # apply associativity, separate commutative part of seq
        c_part = []         # out: commutative factors
        nc_part = []        # out: non-commutative factors

        nc_seq = []

        coeff = S.One       # standalone term
                            # e.g. 3 * ...

        iu = []             # ImaginaryUnits, I

        c_powers = []       # (base,exp)      n
                            # e.g. (x,n) for x

        num_exp = []        # (num-base, exp)           y
                            # e.g.  (3, y)  for  ... * 3  * ...

        neg1e = 0           # exponent on -1 extracted from Number-based Pow

        pnum_rat = {}       # (num-base, Rat-exp)          1/2
                            # e.g.  (3, 1/2)  for  ... * 3     * ...

        order_symbols = None

        # --- PART 1 ---
        #
        # "collect powers and coeff":
        #
        # o coeff
        # o c_powers
        # o num_exp
        # o neg1e
        # o pnum_rat
        #
        # NOTE: this is optimized for all-objects-are-commutative case
        #
        # NB: seq is extended/appended to inside this loop (flattening nested
        # Muls, re-queuing rebuilt factors); a for-loop over a list also
        # visits items appended during iteration, so those are processed by
        # later passes of this same loop.  This mutates the caller's seq.
        for o in seq:
            # O(x)
            if o.is_Order:
                o, order_symbols = o.as_expr_variables(order_symbols)

            # Mul([...])
            if o.is_Mul:
                if o.is_commutative:
                    seq.extend(o.args)    # XXX zerocopy?

                else:
                    # NCMul can have commutative parts as well
                    for q in o.args:
                        if q.is_commutative:
                            seq.append(q)
                        else:
                            nc_seq.append(q)

                    # append non-commutative marker, so we don't forget to
                    # process scheduled non-commutative objects
                    seq.append(NC_Marker)

                continue

            # 3
            elif o.is_Number:
                if o is S.NaN or coeff is S.ComplexInfinity and o is S.Zero:
                    # we know for sure the result will be nan
                    return [S.NaN], [], None
                elif coeff.is_Number:  # it could be zoo
                    coeff *= o
                    if coeff is S.NaN:
                        # we know for sure the result will be nan
                        return [S.NaN], [], None
                continue

            elif o is S.ComplexInfinity:
                if not coeff:
                    # 0 * zoo = NaN
                    return [S.NaN], [], None
                if coeff is S.ComplexInfinity:
                    # zoo * zoo = zoo
                    return [S.ComplexInfinity], [], None
                coeff = S.ComplexInfinity
                continue

            elif o is S.ImaginaryUnit:
                iu.append(o)
                continue

            elif o.is_commutative:
                #      e
                # o = b
                b, e = o.as_base_exp()

                #  y
                # 3
                if o.is_Pow and b.is_Number:

                    # get all the factors with numeric base so they can be
                    # combined below, but don't combine negatives unless
                    # the exponent is an integer
                    if e.is_Rational:
                        if e.is_Integer:
                            coeff *= Pow(b, e)  # it is an unevaluated power
                            continue
                        elif e.is_negative:    # also a sign of an unevaluated power
                            seq.append(Pow(b, e))
                            continue
                        elif b.is_negative:
                            neg1e += e
                            b = -b
                        if b is not S.One:
                            pnum_rat.setdefault(b, []).append(e)
                        continue
                    elif b.is_positive or e.is_integer:
                        num_exp.append((b, e))
                        continue
                c_powers.append((b, e))

            # NON-COMMUTATIVE
            # TODO: Make non-commutative exponents not combine automatically
            # (NC_Marker, queued above, presumably lands here as a
            # non-commutative object; it is not stored, it only triggers the
            # draining of nc_seq below.)
            else:
                if o is not NC_Marker:
                    nc_seq.append(o)

                # process nc_seq (if any)
                while nc_seq:
                    o = nc_seq.pop(0)
                    if not nc_part:
                        nc_part.append(o)
                        continue

                    #                             b    c       b+c
                    # try to combine last terms: a  * a   ->  a
                    o1 = nc_part.pop()
                    b1, e1 = o1.as_base_exp()
                    b2, e2 = o.as_base_exp()
                    new_exp = e1 + e2
                    # Only allow powers to combine if the new exponent is
                    # not an Add. This allow things like a**2*b**3 == a**5
                    # if a.is_commutative == False, but prohibits
                    # a**x*a**y and x**a*x**b from combining (x,y commute).
                    if b1 == b2 and (not new_exp.is_Add):
                        o12 = b1 ** new_exp

                        # now o12 could be a commutative object
                        if o12.is_commutative:
                            seq.append(o12)
                            continue
                        else:
                            nc_seq.insert(0, o12)

                    else:
                        nc_part.append(o1)
                        nc_part.append(o)

        # handle the ImaginaryUnits
        if iu:
            if len(iu) == 1:
                c_powers.append((iu[0], S.One))
            else:
                # a product of I's has one of 4 values; select that value
                # based on the length of iu:
                # len(iu) % 4 of (0, 1, 2, 3) has a corresponding value of
                #                (1, I,-1,-I)
                niu = len(iu) % 4
                if niu % 2:
                    c_powers.append((S.ImaginaryUnit, S.One))
                if niu in (2, 3):
                    coeff = -coeff

        # We do want a combined exponent if it would not be an Add, such as
        #  y    2y     3y
        # x  * x   -> x
        # We determine if two exponents have the same term by using
        # as_coeff_Mul.
        #
        # Unfortunately, this isn't smart enough to consider combining into
        # exponents that might already be adds, so things like:
        #  z - y    y
        # x      * x  will be left alone.  This is because checking every possible
        # combination can slow things down.

        # gather exponents of common bases...
        # Returns a new list of (base, exp) pairs in which the numeric
        # coefficients (per as_coeff_Mul) of exponents that share both base
        # and non-coefficient part have been summed.
        def _gather(c_powers):
            new_c_powers = []
            common_b = {}  # b:e
            for b, e in c_powers:
                co = e.as_coeff_Mul()
                common_b.setdefault(b, {}).setdefault(co[1], []).append(co[0])
            for b, d in common_b.items():
                for di, li in d.items():
                    d[di] = Add(*li)
            for b, e in common_b.items():
                for t, c in e.items():
                    new_c_powers.append((b, c*t))
            return new_c_powers

        # in c_powers
        c_powers = _gather(c_powers)

        # and in num_exp
        num_exp = _gather(num_exp)

        # --- PART 2 ---
        #
        # o process collected powers  (x**0 -> 1; x**1 -> x; otherwise Pow)
        # o combine collected powers  (2**x * 3**x -> 6**x)
        #   with numeric base

        # ................................
        # now we have:
        # - coeff:
        # - c_powers:    (b, e)
        # - num_exp:     (2, e)
        # - pnum_rat:    {(1/3, [1/3, 2/3, 1/4])}

        #  0             1
        # x  -> 1       x  -> x
        for b, e in c_powers:
            if e is S.One:
                if b.is_Number:
                    coeff *= b
                else:
                    c_part.append(b)
            elif e is not S.Zero:
                c_part.append(Pow(b, e))

        #  x    x     x
        # 2  * 3  -> 6
        inv_exp_dict = {}   # exp:Mul(num-bases)     x    x
                            # e.g.  x:6  for  ... * 2  * 3  * ...
        for b, e in num_exp:
            inv_exp_dict.setdefault(e, []).append(b)
        for e, b in inv_exp_dict.items():
            inv_exp_dict[e] = Mul(*b)
        # NOTE(review): dict.iteritems() (used here and below) exists only on
        # Python 2; under Python 3 these calls need .items().
        c_part.extend([Pow(b, e) for e, b in inv_exp_dict.iteritems() if e])

        # b, e -> e' = sum(e), b
        # {(1/5, [1/3]), (1/2, [1/12, 1/4]} -> {(1/3, [1/5, 1/2])}
        comb_e = {}
        for b, e in pnum_rat.iteritems():
            comb_e.setdefault(Add(*e), []).append(b)
        del pnum_rat
        # process them, reducing exponents to values less than 1
        # and updating coeff if necessary else adding them to
        # num_rat for further processing
        num_rat = []
        for e, b in comb_e.iteritems():
            b = Mul(*b)
            if e.q == 1:
                coeff *= Pow(b, e)
                continue
            if e.p > e.q:
                e_i, ep = divmod(e.p, e.q)
                coeff *= Pow(b, e_i)
                e = Rational(ep, e.q)
            num_rat.append((b, e))
        del comb_e

        # extract gcd of bases in num_rat
        # 2**(1/3)*6**(1/4) -> 2**(1/3+1/4)*3**(1/4)
        pnew = defaultdict(list)
        i = 0  # steps through num_rat which may grow
        while i < len(num_rat):
            bi, ei = num_rat[i]
            grow = []
            for j in range(i + 1, len(num_rat)):
                bj, ej = num_rat[j]
                g = _rgcd(bi, bj)
                if g is not S.One:
                    # 4**r1*6**r2 -> 2**(r1+r2)  *  2**r1 *  3**r2
                    # this might have a gcd with something else
                    e = ei + ej
                    if e.q == 1:
                        coeff *= Pow(g, e)
                    else:
                        if e.p > e.q:
                            e_i, ep = divmod(e.p, e.q)  # change e in place
                            coeff *= Pow(g, e_i)
                            e = Rational(ep, e.q)
                        grow.append((g, e))
                    # update the jth item
                    num_rat[j] = (bj/g, ej)
                    # update bi that we are checking with
                    bi = bi/g
                    if bi is S.One:
                        break
            if bi is not S.One:
                obj = Pow(bi, ei)
                if obj.is_Number:
                    coeff *= obj
                else:
                    # changes like sqrt(12) -> 2*sqrt(3)
                    for obj in Mul.make_args(obj):
                        if obj.is_Number:
                            coeff *= obj
                        else:
                            assert obj.is_Pow
                            bi, ei = obj.args
                            pnew[ei].append(bi)

            num_rat.extend(grow)
            i += 1

        # combine bases of the new powers
        for e, b in pnew.iteritems():
            pnew[e] = Mul(*b)

        # see if there is a base with matching coefficient
        # that the -1 can be joined with
        if neg1e:
            p = Pow(S.NegativeOne, neg1e)
            if p.is_Number:
                coeff *= p
            else:
                c, p = p.as_coeff_Mul()
                coeff *= c
                if p.is_Pow and p.base is S.NegativeOne:
                    neg1e = p.exp
                for e, b in pnew.iteritems():
                    if e == neg1e and b.is_positive:
                        pnew[e] = -b
                        break
                else:
                    c_part.append(p)

        # add all the pnew powers
        c_part.extend([Pow(b, e) for e, b in pnew.iteritems()])

        # oo, -oo
        if (coeff is S.Infinity) or (coeff is S.NegativeInfinity):
            # Drop factors of known sign, folding their sign into coeff.
            def _handle_for_oo(c_part, coeff_sign):
                new_c_part = []
                for t in c_part:
                    if t.is_positive:
                        continue
                    if t.is_negative:
                        coeff_sign *= -1
                        continue
                    new_c_part.append(t)
                return new_c_part, coeff_sign
            c_part, coeff_sign = _handle_for_oo(c_part, 1)
            nc_part, coeff_sign = _handle_for_oo(nc_part, coeff_sign)
            coeff *= coeff_sign

        # zoo
        if coeff is S.ComplexInfinity:
            # zoo might be
            #   unbounded_real + bounded_im
            #   bounded_real + unbounded_im
            #   unbounded_real + unbounded_im
            # and non-zero real or imaginary will not change that status.
            c_part = [c for c in c_part if not (c.is_nonzero and
                                                c.is_real is not None)]
            nc_part = [c for c in nc_part if not (c.is_nonzero and
                                                  c.is_real is not None)]

        # 0
        elif coeff is S.Zero:
            # we know for sure the result will be 0
            return [coeff], [], order_symbols

        # order commutative part canonically
        # (cls._args_sortkey is cmp_to_key(Basic.compare), set on the class)
        c_part.sort(key=cls._args_sortkey)

        # current code expects coeff to be always in slot-0
        if coeff is not S.One:
            c_part.insert(0, coeff)

        # we are done
        if len(c_part) == 2 and c_part[0].is_Number and c_part[1].is_Add:
            # 2*(1+a) -> 2 + 2 * a
            coeff = c_part[0]
            c_part = [Add(*[coeff*f for f in c_part[1].args])]

        return c_part, nc_part, order_symbols

    def _eval_power(b, e):
        """Raise the product ``b`` to the power ``e``.

        For an Integer exponent each commutative factor is raised separately
        (unevaluated Pows), while the non-commutative factors stay grouped
        under a single power, since (A*B)**3 is A*B*A*B*A*B, not A**3*B**3.
        Otherwise an unevaluated ``b**e`` is built and, for a Rational or
        Float exponent, its base is expanded via ``_eval_expand_power_base``.
        """
        commutative, noncommutative = b.args_cnc(split_1=False)

        if e.is_Integer:
            powered = Mul(*[Pow(factor, e, evaluate=False)
                            for factor in commutative])
            grouped = Pow(Mul._from_args(noncommutative), e, evaluate=False)
            return powered * grouped

        unevaluated = Pow(b, e, evaluate=False)
        if not (e.is_Rational or e.is_Float):
            return unevaluated
        return unevaluated._eval_expand_power_base()

    @classmethod
    def class_key(cls):
        """Return the class-level sort key ``(3, 0, class name)``."""
        key = (3, 0, cls.__name__)
        return key

    def _eval_evalf(self, prec):
        """Numerically evaluate the product to precision ``prec``.

        A leading -1 coefficient is split off and the remainder is evaluated
        on its own, then negated.  A numeric result is expanded before it is
        returned.
        """
        coeff, rest = self.as_coeff_Mul()
        if coeff is not S.NegativeOne:
            result = AssocOp._eval_evalf(self, prec)
        elif rest.is_Mul:
            result = -AssocOp._eval_evalf(rest, prec)
        else:
            evaluated = rest._eval_evalf(prec)
            # fall back to the unevaluated remainder if evalf gave nothing
            result = -(rest if evaluated is None else evaluated)
        return result.expand() if result.is_number else result

    @cacheit
    def as_two_terms(self):
        """Return head and tail of self.

        This is the most efficient way to get the head and tail of an
        expression.

        - if you want only the head, use self.args[0];
        - if you want to process the arguments of the tail then use
          self.as_coeff_mul() which gives the head and a tuple containing
          the arguments of the tail when treated as a Mul.
        - if you want the coefficient when self is treated as an Add
          then use self.as_coeff_add()[0]

        >>> from sympy.abc import x, y
        >>> (3*x*y).as_two_terms()
        (3, x*y)
        """
        args = self.args
        count = len(args)

        if count == 1:
            return S.One, self
        if count == 2:
            return args

        head, tail = args[0], args[1:]
        return head, self._new_rawargs(*tail)

    @cacheit
    def as_coeff_mul(self, *deps):
        """Split the product into a coefficient and a tuple of factors.

        With ``deps``: return (product of the factors not containing any of
        ``deps``, tuple of the factors that do).  Without ``deps``: a leading
        Rational is returned as the coefficient; a leading -oo is returned
        as coefficient -1 with +oo kept among the factors; otherwise the
        coefficient is 1 and all args are returned.
        """
        if deps:
            independent, dependent = [], []
            for factor in self.args:
                bucket = dependent if factor.has(*deps) else independent
                bucket.append(factor)
            return self._new_rawargs(*independent), tuple(dependent)

        args = self.args
        head = args[0]
        if head.is_Rational:
            return head, args[1:]
        if head is S.NegativeInfinity:
            return S.NegativeOne, (-head,) + args[1:]
        return S.One, args

    def as_coeff_Mul(self, rational=False):
        """Efficiently extract the coefficient of a product.

        The first argument is the coefficient if it is a Number (a Rational
        when ``rational`` is true); otherwise ``(S.One, self)`` is returned.
        """
        head, rest = self.args[0], self.args[1:]

        if not head.is_Number or (rational and not head.is_Rational):
            return S.One, self

        if len(rest) == 1:
            return head, rest[0]
        return head, self._new_rawargs(*rest)

    def as_real_imag(self, deep=True, **hints):
        """Return ``(re, im)`` of the product.

        Factors known to be real are pulled out as a common multiplier of
        ``re`` and ``im`` of the remaining product.  Returns None when the
        ``ignore`` hint equals that remaining product.
        """
        real_factor = S(1)
        remaining = []
        for arg in self.args:
            if not arg.is_real:
                remaining.append(arg)
                continue
            real_factor *= arg

        m = Mul(*remaining)
        if hints.get('ignore') == m:
            return None
        return (real_factor*C.re(m), real_factor*C.im(m))

    @staticmethod
    def _expandsums(sums):
        """
        Helper function for _eval_expand_mul.

        sums must be a list of instances of Basic.

        Recursively splits ``sums`` in half, multiplies every term of the
        left expansion by every term of the right one, and returns the
        resulting Add's arguments.
        """
        size = len(sums)
        if size == 1:
            return sums[0].args

        middle = size//2
        left_terms = Mul._expandsums(sums[:middle])
        right_terms = Mul._expandsums(sums[middle:])

        products = [Mul(lhs, rhs) for lhs in left_terms for rhs in right_terms]
        # the Add may have collapsed down to one term
        return Add.make_args(Add(*products))

    def _eval_expand_mul(self, **hints):
        """Distribute this product over any Add factors.

        A Mul denominator is expanded first.  Commutative non-Add factors are
        collected into a plain multiplier; Add factors (and wrapped
        non-commutative factors) are expanded against each other via
        ``Mul._expandsums``.
        """
        from sympy import fraction, expand_mul

        # Handle things like 1/(x*(x + 1)), which are automatically converted
        # to 1/x*1/(x + 1)
        expr = self
        numer, denom = fraction(expr)
        if denom.is_Mul:
            expr = numer/denom._eval_expand_mul(**hints)
            if not expr.is_Mul:
                return expand_mul(expr, deep=False)

        plain = []
        sums = []
        found_add = False
        for factor in expr.args:
            if factor.is_Add:
                found_add = True
                sums.append(factor)
            elif factor.is_commutative:
                plain.append(factor)
            else:
                sums.append(Basic(factor))  # Wrapper

        if not found_add:
            return expr

        plain = Mul(*plain)
        if not sums:
            return plain

        expanded = []
        for term in Mul._expandsums(sums):
            product = Mul(plain, term)
            if product.is_Mul and any(arg.is_Add for arg in product.args):
                product = product._eval_expand_mul()
            expanded.append(product)
        return Add(*expanded)

    def _eval_derivative(self, s):
        """Differentiate the product with respect to ``s`` by the product rule.

        Each factor is differentiated in turn (keeping the others in place);
        zero derivatives are skipped and the remaining products are summed.
        """
        factors = list(self.args)
        pieces = []
        for idx, factor in enumerate(factors):
            d_factor = factor.diff(s)
            if d_factor is S.Zero:
                continue
            pieces.append(Mul(*(factors[:idx] + [d_factor] + factors[idx + 1:])))
        return Add(*pieces)

    def _matches_simple(self, expr, repl_dict):
        """Match a pattern of the form ``coeff*term`` with a single term.

        e.g. (w*3).matches('x*5') -> {w: x*5/3}.  Returns None when the
        pattern has more than one non-coefficient factor.
        """
        coeff, rest = self.as_coeff_Mul()
        factors = Mul.make_args(rest)
        if len(factors) != 1:
            return None
        target = self.__class__._combine_inverse(expr, coeff)
        return factors[0].matches(target, repl_dict)

    def matches(self, expr, repl_dict={}):
        """Match this pattern against ``expr``; return a dict or None.

        Fully commutative pairs are delegated to
        ``AssocOp._matches_commutative``; a commutative/non-commutative
        mismatch fails.  Otherwise the commutative parts are matched first
        and, on success, the non-commutative parts.

        NOTE(review): a falsy (empty) dict after the commutative step is
        treated as failure, so a successful match with no bindings returns
        None -- confirm this is intended.
        """
        expr = sympify(expr)
        if self.is_commutative and expr.is_commutative:
            return AssocOp._matches_commutative(self, expr, repl_dict)
        if self.is_commutative is not expr.is_commutative:
            return None

        comm_self, nc_self = self.args_cnc()
        comm_expr, nc_expr = expr.args_cnc()
        result = repl_dict.copy()

        if comm_self:
            comm_pattern = Mul(*comm_self)
            comm_target = Mul(*(comm_expr or [1]))
            if isinstance(comm_pattern, AssocOp):
                result = comm_pattern._matches_commutative(comm_target, result)
            else:
                result = comm_pattern.matches(comm_target, result)

        if not result:
            return None

        nc_pattern = Mul(*nc_self)
        nc_target = Mul(*nc_expr)
        if isinstance(nc_pattern, Mul):
            result = nc_pattern._matches(nc_target, result)
        else:
            result = nc_pattern.matches(nc_target, result)
        return result or None

    def _matches(self, expr, repl_dict={}):
        """Match this (non-commutative) product pattern against ``expr``.

        A leading -1 in the pattern or in ``expr`` is stripped and tracked
        in ``sign``.  Factors present in both are discarded.  If a single
        Wild remains in the pattern it absorbs whatever is left of ``expr``;
        otherwise the remaining factors are matched pairwise in order.
        Returns the replacement dictionary, or None when there is no match.
        """
        # sympify first: the early-return branch below negates ``expr``
        expr = sympify(expr)

        # weed out negative one prefixes
        sign = 1
        a, b = self.as_two_terms()
        if a is S.NegativeOne:
            if b.is_Mul:
                sign = -sign
            else:
                # the remainder, b, is not a Mul anymore
                return b.matches(-expr, repl_dict)
        if expr.is_Mul and expr.args[0] is S.NegativeOne:
            expr = -expr
            sign = -sign

        if not expr.is_Mul:
            # expr can only match if it matches b and a matches +/- 1
            if len(self.args) == 2:
                # quickly test for equality
                if b == expr:
                    return a.matches(Rational(sign), repl_dict)
                # do more expensive match
                dd = b.matches(expr, repl_dict)
                if dd is None:
                    return None
                dd = a.matches(Rational(sign), dd)
                return dd
            return None

        d = repl_dict.copy()

        # weed out identical terms; membership is tested against the
        # shrinking ``ee`` (not ``expr.args``) so a factor that occurs more
        # often in the pattern than in expr (possible for non-commutative
        # products such as A*B*A) is not removed from ``ee`` twice, which
        # would raise ValueError
        pp = list(self.args)
        ee = list(expr.args)
        for p in self.args:
            if p in ee:
                ee.remove(p)
                pp.remove(p)

        # only one symbol left in pattern -> match the remaining expression
        if len(pp) == 1 and isinstance(pp[0], C.Wild):
            if len(ee) == 1:
                d[pp[0]] = sign * ee[0]
            else:
                d[pp[0]] = sign * expr.func(*ee)
            return d

        if len(ee) != len(pp):
            return None

        for p, e in zip(pp, ee):
            d = p.xreplace(d).matches(e, d)
            if d is None:
                return None
        return d

    @staticmethod
    def _combine_inverse(lhs, rhs):
        """
        Compute lhs/rhs while treating the operands symbolically, so that
        e.g. oo/oo gives 1 rather than nan.
        """
        if lhs == rhs:
            return S.One

        def _float_equal(flt, other):
            # Adding 0 to both sides gives them the same "normalization", so
            # equal values are more likely to compare equal.  Add(foo, 0)
            # would drop the 0, hence the direct __add__ call.
            if not (flt.is_Float and other.is_comparable):
                return False
            return flt.__add__(0) == other.evalf().__add__(0)

        if _float_equal(lhs, rhs) or _float_equal(rhs, lhs):
            return S.One
        if not (lhs.is_Mul and rhs.is_Mul):
            return lhs/rhs

        # cancel factors common to both products before dividing
        numer = list(lhs.args)
        denom = [1]
        for factor in rhs.args:
            if factor in numer:
                numer.remove(factor)
            else:
                denom.append(factor)
        return Mul(*numer)/Mul(*denom)

    def as_powers_dict(self):
        """Map each base of this product to its total exponent.

        Every factor is split with ``as_base_exp``.  A base seen once keeps
        its exponent as-is; exponents of a repeated base are summed with
        ``Add``.  The mapping returned is the ``defaultdict(list)`` used to
        collect the exponents.
        """
        d = defaultdict(list)
        for term in self.args:
            b, e = term.as_base_exp()
            d[b].append(e)
        # ``items()`` exists on Python 2 and 3 (``iteritems`` is Python-2
        # only); iterate over a snapshot because values are rebound below
        for b, e in list(d.items()):
            d[b] = e[0] if len(e) == 1 else Add(*e)
        return d

    def as_numer_denom(self):
        """Return ``(numerator, denominator)`` of the product.

        Each factor is split with its own ``as_numer_denom``.  The halves are
        rebuilt with ``Mul`` rather than ``_from_args``: once separated, the
        numerator and denominator factors need not be in canonical order.
        """
        halves = [factor.as_numer_denom() for factor in self.args]
        numers, denoms = zip(*halves)
        return Mul(*numers), Mul(*denoms)

    def as_base_exp(self):
        """Return ``(base, exp)`` when all factors share one exponent.

        For instance ``x**2*y**2`` gives ``(x*y, 2)``.  If the exponents
        differ, or more than one factor has a non-commutative base,
        ``(self, 1)`` is returned.
        """
        common = None
        bases = []
        noncomm = 0
        for factor in self.args:
            base, exp = factor.as_base_exp()
            if not base.is_commutative:
                noncomm += 1
            if common is None:
                common = exp
            elif exp != common or noncomm > 1:
                return self, S.One
            bases.append(base)
        return Mul(*bases), common

    def _eval_is_polynomial(self, syms):
        return all(term._eval_is_polynomial(syms) for term in self.args)

    def _eval_is_rational_function(self, syms):
        return all(term._eval_is_rational_function(syms) for term in self.args)

    _eval_is_bounded = lambda self: self._eval_template_is_attr('is_bounded')
    _eval_is_integer = lambda self: self._eval_template_is_attr(
        'is_integer', when_multiple=None)
    _eval_is_commutative = lambda self: self._eval_template_is_attr(
        'is_commutative')

    def _eval_is_polar(self):
        has_polar = any(arg.is_polar for arg in self.args)
        return has_polar and \
            all(arg.is_polar or arg.is_positive for arg in self.args)

    # I*I -> R,  I*I*I -> -I
    def _eval_is_real(self):
        im_count = 0
        is_neither = False
        for t in self.args:
            if t.is_imaginary:
                im_count += 1
                continue
            t_real = t.is_real
            if t_real:
                continue
            elif t_real is False:
                if is_neither:
                    return None
                else:
                    is_neither = True
            else:
                return None
        if is_neither:
            return False

        return (im_count % 2 == 0)

    def _eval_is_imaginary(self):
        im_count = 0
        is_neither = False
        for t in self.args:
            if t.is_imaginary:
                im_count += 1
                continue
            t_real = t.is_real
            if t_real:
                continue
            elif t_real is False:
                if is_neither:
                    return None
                else:
                    is_neither = True
            else:
                return None
        if is_neither:
            return False

        return (im_count % 2 == 1)

    def _eval_is_hermitian(self):
        nc_count = 0
        im_count = 0
        is_neither = False
        for t in self.args:
            if not t.is_commutative:
                nc_count += 1
                if nc_count > 1:
                    return None
            if t.is_antihermitian:
                im_count += 1
                continue
            t_real = t.is_hermitian
            if t_real:
                continue
            elif t_real is False:
                if is_neither:
                    return None
                else:
                    is_neither = True
            else:
                return None
        if is_neither:
            return False

        return (im_count % 2 == 0)

    def _eval_is_antihermitian(self):
        nc_count = 0
        im_count = 0
        is_neither = False
        for t in self.args:
            if not t.is_commutative:
                nc_count += 1
                if nc_count > 1:
                    return None
            if t.is_antihermitian:
                im_count += 1
                continue
            t_real = t.is_hermitian
            if t_real:
                continue
            elif t_real is False:
                if is_neither:
                    return None
                else:
                    is_neither = True
            else:
                return None
        if is_neither:
            return False

        return (im_count % 2 == 1)

    def _eval_is_irrational(self):
        for t in self.args:
            a = t.is_irrational
            if a:
                others = list(self.args)
                others.remove(t)
                if all(x.is_rational is True for x in others):
                    return True
                return None
            if a is None:
                return
        return False

    def _eval_is_zero(self):
        zero = None
        for a in self.args:
            if a.is_zero:
                zero = True
                continue
            bound = a.is_bounded
            if not bound:
                return bound
        if zero:
            return True

    def _eval_is_positive(self):
        """Return True if self is positive, False if not, and None if it
        cannot be determined.

        This algorithm is non-recursive and works by keeping track of the
        sign which changes when a negative or nonpositive is encountered.
        Whether a nonpositive or nonnegative is seen is also tracked since
        the presence of these makes it impossible to return True, but
        possible to return False if the end result is nonpositive. e.g.

            pos * neg * nonpositive -> pos or zero -> None is returned
            pos * neg * nonnegative -> neg or zero -> False is returned
        """

        sign = 1
        saw_NON = False
        for t in self.args:
            if t.is_positive:
                continue
            elif t.is_negative:
                sign = -sign
            elif t.is_zero:
                return False
            elif t.is_nonpositive:
                sign = -sign
                saw_NON = True
            elif t.is_nonnegative:
                saw_NON = True
            else:
                return
        if sign == 1 and saw_NON is False:
            return True
        if sign < 0:
            return False

    def _eval_is_negative(self):
        """Return True if self is negative, False if not, and None if it
        cannot be determined.

        This algorithm is non-recursive and works by keeping track of the
        sign which changes when a negative or nonpositive is encountered.
        Whether a nonpositive or nonnegative is seen is also tracked since
        the presence of these makes it impossible to return True, but
        possible to return False if the end result is nonnegative. e.g.

            pos * neg * nonpositive -> pos or zero -> False is returned
            pos * neg * nonnegative -> neg or zero -> None is returned
        """

        sign = 1
        saw_NON = False
        for t in self.args:
            if t.is_positive:
                continue
            elif t.is_negative:
                sign = -sign
            elif t.is_zero:
                return False
            elif t.is_nonpositive:
                sign = -sign
                saw_NON = True
            elif t.is_nonnegative:
                saw_NON = True
            else:
                return
        if sign == -1 and saw_NON is False:
            return True
        if sign > 0:
            return False

    def _eval_is_odd(self):
        is_integer = self.is_integer

        if is_integer:
            r = True
            for t in self.args:
                if t.is_even:
                    return False
                if t.is_odd is None:
                    r = None
            return r

        # !integer -> !odd
        elif is_integer is False:
            return False

    def _eval_is_even(self):
        is_integer = self.is_integer

        if is_integer:
            return fuzzy_not(self._eval_is_odd())

        elif is_integer is False:
            return False

    def _eval_subs(self, old, new):
        """Replace occurrences of the product ``old`` in this Mul by ``new``.

        ``old`` may be extracted from ``self`` more than once and only
        partially: e.g. with ``old = x*y`` the product ``x**2*y**2*z``
        becomes ``new**2*z``.  Commutative factors are compared as
        base -> rational-exponent maps (see ``breakup``); non-commutative
        factors of ``old`` must appear consecutively, in the same order,
        in ``self``.

        Returns None at once when ``old`` is not a Mul.  When ``old`` cannot
        be extracted, ``rv`` is returned: None, unless substituting into the
        numerator and denominator separately already changed ``self``, in
        which case it is that changed expression.
        """
        from sympy import sign, multiplicity
        from sympy.simplify.simplify import powdenest, fraction

        if not old.is_Mul:
            return None

        # negate both old and new and retry with the negated pair
        if old.args[0] == -1:
            return self._subs(-old, -new)

        def base_exp(a):
            # if I and -1 are in a Mul, they both end up with
            # a -1 base (see issue 3322); all we want here are the
            # true Pow or exp separated into base and exponent
            if a.is_Pow or a.func is C.exp:
                return a.as_base_exp()
            return a, S.One

        def breakup(eq):
            """break up powers of eq when treated as a Mul:
                   b**(Rational*e) -> b**e, Rational
                commutatives come back as a dictionary {b**e: Rational}
                noncommutatives come back as a list [(b**e, Rational)]
            """

            (c, nc) = (defaultdict(int), list())
            for a in Mul.make_args(eq):
                a = powdenest(a)
                (b, e) = base_exp(a)
                if e is not S.One:
                    (co, _) = e.as_coeff_mul()
                    b = Pow(b, e/co)
                    e = co
                if a.is_commutative:
                    c[b] += e
                else:
                    nc.append([b, e])
            return (c, nc)

        def rejoin(b, co):
            """
            Put rational back with exponent; in general this is not ok, but
            since we took it from the exponent for analysis, it's ok to put
            it back.
            """

            (b, e) = base_exp(b)
            return Pow(b, e*co)

        def ndiv(a, b):
            """if b divides a in an extractive way (like 1/4 divides 1/2
            but not vice versa, and 2/5 does not divide 1/3) then return
            the integer number of times it divides, else return 0.
            """
            # one denominator must divide the other for extraction to work
            if not b.q % a.q or not a.q % b.q:
                return int(a/b)
            return 0

        # give Muls in the denominator a chance to be changed (see issue 2552)
        # rv will be the default return value
        rv = None
        n, d = fraction(self)
        if d is not S.One:
            self2 = n._subs(old, new)/d._subs(old, new)
            if not self2.is_Mul:
                return self2._subs(old, new)
            if self2 != self:
                self = rv = self2

        # Now continue with regular substitution.

        # handle the leading coefficient and use it to decide if anything
        # should even be started; we always know where to find the Rational
        # so it's a quick test

        co_self = self.args[0]
        co_old = old.args[0]
        co_xmul = None
        if co_old.is_Rational and co_self.is_Rational:
            # if coeffs are the same there will be no updating to do
            # below after breakup() step; so skip (and keep co_xmul=None)
            if co_old != co_self:
                co_xmul = co_self.extract_multiplicatively(co_old)
        elif co_old.is_Rational:
            # old has a Rational coefficient but self does not
            return rv

        # break self and old into factors

        (c, nc) = breakup(self)
        (old_c, old_nc) = breakup(old)

        # update the coefficients if we had an extraction
        # e.g. if co_self were 2*(3/35*x)**2 and co_old = 3/5
        # then co_self in c is replaced by (3/5)**2 and co_residual
        # is 2*(1/7)**2

        if co_xmul and co_xmul.is_Rational:
            n_old, d_old = co_old.as_numer_denom()
            n_self, d_self = co_self.as_numer_denom()

            def _multiplicity(p, n):
                # a factor of 1 imposes no limit on the multiplicity
                p = abs(p)
                if p is S.One:
                    return S.Infinity
                return multiplicity(p, abs(n))
            mult = S(min(_multiplicity(n_old, n_self),
                         _multiplicity(d_old, d_self)))
            c.pop(co_self)
            c[co_old] = mult
            co_residual = co_self/co_old**mult
        else:
            co_residual = 1

        # do quick tests to see if we can't succeed

        ok = True
        if len(old_nc) > len(nc):
            # more non-commutative terms
            ok = False
        elif len(old_c) > len(c):
            # more commutative terms
            ok = False
        elif set(i[0] for i in old_nc).difference(set(i[0] for i in nc)):
            # unmatched non-commutative bases
            ok = False
        elif set(old_c).difference(set(c)):
            # unmatched commutative terms
            ok = False
        elif any(sign(c[b]) != sign(old_c[b]) for b in old_c):
            # differences in sign
            ok = False
        if not ok:
            return rv

        # cdid: how many times old's commutative part fits into self's
        # (None when old has no commutative part)
        if not old_c:
            cdid = None
        else:
            rat = []
            for (b, old_e) in old_c.items():
                c_e = c[b]
                rat.append(ndiv(c_e, old_e))
                if not rat[-1]:
                    return rv
            cdid = min(rat)

        if not old_nc:
            ncdid = None
            for i in range(len(nc)):
                nc[i] = rejoin(*nc[i])
        else:
            ncdid = 0  # number of nc replacements we did
            take = len(old_nc)  # how much to look at each time
            limit = cdid or S.Infinity  # max number that we can take
            failed = []  # failed terms will need subs if other terms pass
            i = 0
            while limit and i + take <= len(nc):
                hit = False

                # the bases must be equivalent in succession, and
                # the powers must be extractively compatible on the
                # first and last factor but equal inbetween.

                rat = []
                for j in range(take):
                    if nc[i + j][0] != old_nc[j][0]:
                        break
                    elif j == 0:
                        rat.append(ndiv(nc[i + j][1], old_nc[j][1]))
                    elif j == take - 1:
                        rat.append(ndiv(nc[i + j][1], old_nc[j][1]))
                    elif nc[i + j][1] != old_nc[j][1]:
                        break
                    else:
                        rat.append(1)
                    # NOTE(review): no effect -- the for statement rebinds j
                    j += 1
                else:
                    # reached only when the for loop above did not break,
                    # i.e. the window nc[i:i + take] lines up with old_nc
                    ndo = min(rat)
                    if ndo:
                        if take == 1:
                            if cdid:
                                ndo = min(cdid, ndo)
                            nc[i] = Pow(new, ndo)*rejoin(nc[i][0],
                                    nc[i][1] - ndo*old_nc[0][1])
                        else:
                            ndo = 1

                            # the left residual

                            l = rejoin(nc[i][0], nc[i][1] - ndo*
                                    old_nc[0][1])

                            # eliminate all middle terms

                            mid = new

                            # the right residual (which may be the same as the middle if take == 2)

                            ir = i + take - 1
                            r = (nc[ir][0], nc[ir][1] - ndo*
                                 old_nc[-1][1])
                            if r[1]:
                                if i + take < len(nc):
                                    nc[i:i + take] = [l*mid, r]
                                else:
                                    r = rejoin(*r)
                                    nc[i:i + take] = [l*mid*r]
                            else:

                                # there was nothing left on the right

                                nc[i:i + take] = [l*mid]

                        limit -= ndo
                        ncdid += ndo
                        hit = True
                if not hit:

                    # do the subs on this failing factor

                    failed.append(i)
                i += 1
            else:
                # this while loop contains no break, so this else runs
                # whenever the loop condition becomes false

                if not ncdid:
                    return rv

                # although we didn't fail, certain nc terms may have
                # failed so we rebuild them after attempting a partial
                # subs on them

                failed.extend(range(i, len(nc)))
                for i in failed:
                    nc[i] = rejoin(*nc[i]).subs(old, new)

        # rebuild the expression

        # ``do`` is the number of whole copies of old that were removed
        if cdid is None:
            do = ncdid
        elif ncdid is None:
            do = cdid
        else:
            do = min(ncdid, cdid)

        margs = []
        for b in c:
            if b in old_c:

                # calculate the new exponent

                e = c[b] - old_c[b]*do
                margs.append(rejoin(b, e))
            else:
                margs.append(rejoin(b.subs(old, new), c[b]))
        if cdid and not ncdid:

            # in case we are replacing commutative with non-commutative,
            # we want the new term to come at the front just like the
            # rest of this routine

            margs = [Pow(new, cdid)] + margs
        return co_residual*Mul(*margs)*Mul(*nc)

    def _eval_nseries(self, x, n, logx):
        # multiply the factors' series, expand, then recombine exponents
        from sympy import powsimp
        series = [factor.nseries(x, n=n, logx=logx) for factor in self.args]
        expanded = Mul(*series).expand()
        return powsimp(expanded, combine='exp', deep=True)

    def _eval_as_leading_term(self, x):
        # product of each factor's leading term in x
        leading = [factor.as_leading_term(x) for factor in self.args]
        return Mul(*leading)

    def _eval_conjugate(self):
        # conjugate factor by factor, keeping the original order
        conjugated = [factor.conjugate() for factor in self.args]
        return Mul(*conjugated)

    def _eval_transpose(self):
        # (A*B)^T == B^T*A^T: transpose each factor in reverse order
        flipped = [factor.transpose() for factor in reversed(self.args)]
        return Mul(*flipped)

    def _eval_adjoint(self):
        # adjoint each factor, taking the factors in reverse order
        flipped = [factor.adjoint() for factor in reversed(self.args)]
        return Mul(*flipped)

    def _sage_(self):
        s = 1
        for x in self.args:
            s *= x._sage_()
        return s

    def as_content_primitive(self, radical=False):
        """Split self into ``(R, self/R)``, where R is the positive Rational
        content collected from the factors.

        Examples
        ========

        >>> from sympy import sqrt
        >>> (-3*sqrt(2)*(2 - 2*sqrt(2))).as_content_primitive()
        (6, -sqrt(2)*(-sqrt(2) + 1))

        Expr.as_content_primitive documents further examples.
        """

        content = S.One
        primitives = []
        for factor in self.args:
            c, p = factor.as_content_primitive(radical=radical)
            content *= c
            if p is not S.One:
                primitives.append(p)
        # Rebuild with Mul, not self._from_args: identical primitive parts
        # must be combined, e.g. (2+2*x)*(3+3*x) gives (6, (1 + x)**2)
        # rather than (6, (1+x)*(1+x)).
        return content, Mul(*primitives)

    def as_ordered_factors(self, order=None):
        """Return the factors as a list: commutative ones sorted by their
        ``sort_key(order=order)``, followed by the non-commutative ones in
        their original order.

        Examples
        ========

        >>> from sympy import sin, cos
        >>> from sympy.abc import x, y

        >>> (2*x*y*sin(x)*cos(x)).as_ordered_factors()
        [2, x, y, sin(x), cos(x)]

        """
        commutative, noncommutative = self.args_cnc()

        def by_key(factor):
            return factor.sort_key(order=order)

        commutative.sort(key=by_key)
        return commutative + noncommutative

    @property
    def _sorted_args(self):
        return self.as_ordered_factors()
Ejemplo n.º 16
0
    def __new__(cls, expr, *symbols, **assumptions):
        """Build the order term ``O(expr, *symbols)`` at the point 0.

        ``symbols`` default to the free symbols of ``expr``; a non-``Symbol``
        entry raises NotImplementedError.  NaN is returned unchanged.  An
        Order ``expr`` whose variables already cover ``symbols`` is returned
        as is.  Otherwise ``expr`` is reduced to its leading term(s) in
        ``symbols``; a reduced zero is returned directly, and an ``expr`` not
        containing any of ``symbols`` becomes ``O(1)``.
        """

        expr = sympify(expr)
        if expr is S.NaN:
            return S.NaN

        if symbols:
            # materialize: on Python 3 ``map`` is a one-shot iterator, and
            # ``all()`` below would exhaust it, leaving ``set(symbols)``
            # further down empty
            symbols = list(map(sympify, symbols))
            if not all(isinstance(s, Symbol) for s in symbols):
                raise NotImplementedError(
                    'Order at points other than 0 not supported.')
        else:
            symbols = list(expr.free_symbols)

        if expr.is_Order:
            v = set(expr.variables)
            symbols = v | set(symbols)
            if symbols == v:
                return expr
            symbols = list(symbols)

        elif symbols:

            symbols = list(set(symbols))

            if len(symbols) > 1:
                # XXX: better way?  We need this expand() to
                # workaround e.g: expr = x*(x + y).
                # (x*(x + y)).as_leading_term(x, y) currently returns
                # x*y (wrong order term!).  That's why we want to deal with
                # expand()'ed expr (handled in "if expr.is_Add" branch below).
                expr = expr.expand()

            if expr.is_Add:
                lst = expr.extract_leading_order(*symbols)
                expr = Add(*[f.expr for (e, f) in lst])

            elif expr:
                expr = expr.as_leading_term(*symbols)
                expr = expr.as_independent(*symbols, as_Add=False)[1]

                expr = expand_power_base(expr)
                expr = expand_log(expr)

                if len(symbols) == 1:
                    # The definition of O(f(x)) symbol explicitly stated that
                    # the argument of f(x) is irrelevant.  That's why we can
                    # combine some power exponents (only "on top" of the
                    # expression tree for f(x)), e.g.:
                    # x**p * (-x)**q -> x**(p+q) for real p, q.
                    x = symbols[0]
                    margs = list(
                        Mul.make_args(expr.as_independent(x, as_Add=False)[1]))

                    for i, t in enumerate(margs):
                        if t.is_Pow:
                            b, q = t.args
                            if b in (x, -x) and q.is_real and not q.has(x):
                                margs[i] = x**q
                            elif b.is_Pow and not b.exp.has(x):
                                b, r = b.args
                                if b in (x, -x) and r.is_real:
                                    margs[i] = x**(r * q)
                            elif b.is_Mul and b.args[0] is S.NegativeOne:
                                b = -b
                                if b.is_Pow and not b.exp.has(x):
                                    b, r = b.args
                                    if b in (x, -x) and r.is_real:
                                        margs[i] = x**(r * q)

                    expr = Mul(*margs)

        if expr is S.Zero:
            return expr

        if not expr.has(*symbols):
            expr = S.One

        # create Order instance:
        symbols.sort(key=cmp_to_key(Basic.compare))
        obj = Expr.__new__(cls, expr, *symbols, **assumptions)

        return obj
Ejemplo n.º 17
0
    def __new__(cls, expr, *symbols, **assumptions):
        """Build the order term ``O(expr, *symbols)`` at the point 0.

        ``symbols`` default to the free symbols of ``expr``; a non-``Symbol``
        entry raises NotImplementedError.  NaN is returned unchanged.  An
        Order ``expr`` whose variables already cover ``symbols`` is returned
        as is.  Otherwise ``expr`` is reduced to its leading term(s) in
        ``symbols``; a reduced zero is returned directly, and an ``expr`` not
        containing any of ``symbols`` becomes ``O(1)``.
        """

        expr = sympify(expr)
        if expr is S.NaN:
            return S.NaN

        if symbols:
            # materialize: on Python 3 ``map`` is a one-shot iterator, and
            # ``all()`` below would exhaust it, leaving ``set(symbols)``
            # further down empty
            symbols = list(map(sympify, symbols))
            if not all(isinstance(s, Symbol) for s in symbols):
                raise NotImplementedError(
                    'Order at points other than 0 not supported.')
        else:
            symbols = list(expr.free_symbols)

        if expr.is_Order:
            v = set(expr.variables)
            symbols = v | set(symbols)
            if symbols == v:
                return expr
            symbols = list(symbols)

        elif symbols:

            symbols = list(set(symbols))

            if len(symbols) > 1:
                # XXX: better way?  We need this expand() to
                # workaround e.g: expr = x*(x + y).
                # (x*(x + y)).as_leading_term(x, y) currently returns
                # x*y (wrong order term!).  That's why we want to deal with
                # expand()'ed expr (handled in "if expr.is_Add" branch below).
                expr = expr.expand()

            if expr.is_Add:
                lst = expr.extract_leading_order(*symbols)
                expr = Add(*[f.expr for (e, f) in lst])

            elif expr:
                expr = expr.as_leading_term(*symbols)
                expr = expr.as_independent(*symbols, **dict(as_Add=False))[1]

                expr = expand_power_base(expr)
                expr = expand_log(expr)

                if len(symbols) == 1:
                    # The definition of O(f(x)) symbol explicitly stated that
                    # the argument of f(x) is irrelevant.  That's why we can
                    # combine some power exponents (only "on top" of the
                    # expression tree for f(x)), e.g.:
                    # x**p * (-x)**q -> x**(p+q) for real p, q.
                    x = symbols[0]
                    margs = list(Mul.make_args(
                        expr.as_independent(x, **dict(as_Add=False))[1]))

                    for i, t in enumerate(margs):
                        if t.is_Pow:
                            b, q = t.args
                            if b in (x, -x) and q.is_real and not q.has(x):
                                margs[i] = x**q
                            elif b.is_Pow and not b.exp.has(x):
                                b, r = b.args
                                if b in (x, -x) and r.is_real:
                                    margs[i] = x**(r*q)
                            elif b.is_Mul and b.args[0] is S.NegativeOne:
                                b = -b
                                if b.is_Pow and not b.exp.has(x):
                                    b, r = b.args
                                    if b in (x, -x) and r.is_real:
                                        margs[i] = x**(r*q)

                    expr = Mul(*margs)

        if expr is S.Zero:
            return expr

        if not expr.has(*symbols):
            expr = S.One

        # create Order instance:
        symbols.sort(key=cmp_to_key(Basic.compare))
        obj = Expr.__new__(cls, expr, *symbols, **assumptions)

        return obj
Ejemplo n.º 18
0
from __future__ import print_function, division

from collections import defaultdict

from sympy.core.basic import C, Basic
from sympy.core.compatibility import cmp_to_key, reduce, is_sequence
from sympy.core.singleton import S
from sympy.core.operations import AssocOp
from sympy.core.cache import cacheit
from sympy.core.numbers import ilcm, igcd
from sympy.core.expr import Expr


# Sort key that orders commutative args canonically via Basic.compare
_args_sortkey = cmp_to_key(Basic.compare)


def _addsort(args):
    """Sort the list ``args`` in place into canonical order (returns None)."""
    args.sort(key=_args_sortkey)


def _unevaluated_Add(*args):
    """Return a well-formed unevaluated Add: Numbers are collected and
    put in slot 0 and args are sorted. Use this when args have changed
    but you still want to return an unevaluated Add.

    Examples
    ========

    >>> from sympy.core.add import _unevaluated_Add as uAdd
    >>> from sympy import S, Add
    >>> from sympy.abc import x, y
Ejemplo n.º 19
0
class LatexPrinter(Printer):
    r"""
    A printer class which converts an expression into its LaTeX equivalent. This
    class extends the LatexPrinter class currently in sympy in the following ways:

        1. Variable and function names can now encode multiple Greek symbols,
           number, Greek, and roman super and subscripts and accents plus bold
           math in an alphanumeric ASCII string consisting of ``[A-Za-z0-9_]``
           symbols

            1 - Accents and bold math are implemented in reverse notation. For
                example if you wished the LaTeX output to be ``\bm{\hat{\sigma}}``
                you would give the variable the name sigmahatbm.
            2 - Each '_' starts a new script segment.  Segments alternate
                between subscript and superscript, starting with a subscript,
                and an empty segment (i.e. '__') flips the mode once more.
                Thus ``A_{\rho\beta}^{25}`` is input as A_rhobeta_25 and
                ``A^{25}`` as A__25.

        2. Some standard function names have been improved such as asin is now
           denoted by Sin^{-1} and log by ln.

        3. Several LaTeX formats for multivectors are available:
            1 - Print multivector on one line

            2 - Print each grade of multivector on one line

            3 - Print each base of multivector on one line

        4. A LaTeX output for numpy arrays containing sympy expressions is
           implemented for up to a three dimensional array.

        5. LaTeX formatting for raw LaTeX, eqnarray, and array is available
           in simple output strings.

            1 - The delimiter for raw LaTeX input is '%'.  The raw input starts
                on the line where '%' is first encountered and continues until
                the next line where '%' is encountered. It does not matter where
                '%' is in the line.
            2 - The delimiter for eqnarray input is '@'. The rules are the same
                as for raw input except that '=' in the first line is replaced
                by '&=&', '\begin{eqnarray*}' is added before the first line
                and '\end{eqnarray*}' after the last line in the group of
                lines.
            3 - The delimiter for array input is '#'. The rules are the same
                as for raw input except that '\begin{equation*}' is added before
                the first line and '\end{equation*}' after the last line in
                the group of lines.

        6. Additional formats for partial derivatives:

            0 - Same as sympy latex module

            1 - Use subscript notation with partial symbol to indicate which
                variable the differentiation is with respect to.  Symbol is of
                form \partial_{differentiation variable}
    """

    #printmethod ='_latex_ex'
    sym_fmt = 0
    fct_fmt = 0
    pdiff_fmt = 0
    mv_fmt = 0
    str_fmt = 1
    LaTeX_flg = False

    # Script markers used by extended_symbol: 0 -> subscript, 1 -> superscript.
    mode = ('_', '^')

    # Active formatting options; updated by format() and format_str().
    fmt_dict = {'sym': 0, 'fct': 0, 'pdiff': 0, 'mv': 0, 'str': 1}

    # SymPy function name -> LaTeX name, used when fmt_dict['fct'] == 1.
    fct_dict = {'sin': 'sin', 'cos': 'cos', 'tan': 'tan', 'cot': 'cot',
                'asin': 'Sin^{-1}', 'acos': 'Cos^{-1}',
                'atan': 'Tan^{-1}', 'acot': 'Cot^{-1}',
                'sinh': 'sinh', 'cosh': 'cosh', 'tanh': 'tanh', 'coth': 'coth',
                'asinh': 'Sinh^{-1}', 'acosh': 'Cosh^{-1}',
                'atanh': 'Tanh^{-1}', 'acoth': 'Coth^{-1}',
                'sqrt': 'sqrt', 'exp': 'exp', 'log': 'ln'}

    fct_dict_keys = fct_dict.keys()

    # NOTE(review): the order comes from len_cmp (defined elsewhere).  It is
    # presumably longest-first, so that e.g. 'beta' is tokenized before 'eta'.
    greek_keys = sorted(('alpha', 'beta', 'gamma', 'delta', 'varepsilon',
                         'epsilon', 'zeta', 'vartheta', 'theta', 'iota',
                         'kappa', 'lambda', 'mu', 'nu', 'xi', 'varpi', 'pi',
                         'rho', 'varrho', 'varsigma', 'sigma', 'tau',
                         'upsilon', 'varphi', 'phi', 'chi', 'psi', 'omega',
                         'Gamma', 'Delta', 'Theta', 'Lambda', 'Xi', 'Pi',
                         'Sigma', 'Upsilon', 'Phi', 'Psi', 'Omega', 'partial',
                         'nabla', 'eta'), key=cmp_to_key(len_cmp))

    accent_keys = sorted(('hat', 'check', 'dot', 'breve', 'acute', 'ddot',
                          'grave', 'tilde', 'mathring', 'bar', 'vec', 'bm',
                          'prm', 'abs'), key=cmp_to_key(len_cmp))

    # Placeholder bookkeeping shared by the tokenize_*/replace_* helpers.
    greek_cnt = 0
    greek_dict = {}
    accent_cnt = 0
    accent_dict = {}

    preamble = '\\documentclass[10pt,letter,fleqn]{report}\n'+\
               '\\pagestyle{empty}\n'+\
               '\\usepackage[latin1]{inputenc}\n'+\
               '\\usepackage[dvips,landscape,top=1cm,nohead,nofoot]{geometry}\n'+\
               '\\usepackage{amsmath}\n'+\
               '\\usepackage{bm}\n'+\
               '\\usepackage{amsfonts}\n'+\
               '\\usepackage{amssymb}\n'+\
               '\\setlength{\\parindent}{0pt}\n'+\
               '\\newcommand{\\bfrac}[2]{\\displaystyle\\frac{#1}{#2}}\n'+\
               '\\newcommand{\\lp}{\\left (}\n'+\
               '\\newcommand{\\rp}{\\right )}\n'+\
               '\\newcommand{\\half}{\\frac{1}{2}}\n'+\
               '\\newcommand{\\llt}{\\left <}\n'+\
               '\\newcommand{\\rgt}{\\right >}\n'+\
               '\\newcommand{\\abs}[1]{\\left |{#1}\\right | }\n'+\
               '\\newcommand{\\pdiff}[2]{\\bfrac{\\partial {#1}}{\\partial {#2}}}\n'+\
               '\\newcommand{\\lbrc}{\\left \\{}\n'+\
               '\\newcommand{\\rbrc}{\\right \\}}\n'+\
               '\\newcommand{\\W}{\\wedge}\n'+\
               "\\newcommand{\\prm}[1]{{#1}'}\n"+\
               '\\newcommand{\\ddt}[1]{\\bfrac{d{#1}}{dt}}\n'+\
               '\\newcommand{\\R}{\\dagger}\n'+\
               '\\begin{document}\n'
    postscript = '\\end{document}\n'

    @staticmethod
    def latex_bases():
        """
        Generate LaTeX strings for the multivector bases.

        Fills ``LatexPrinter.latexbasis_lst`` from
        ``sympy.galgebra.GA.MV.basislabel_lst``.  Entry 0 is ``['']``.  Each
        later entry holds, per blade, the list of LaTeX base names.

        The method writes an error and exits the process if ``MV.setup()``
        has not run, i.e. ``basislabel_lst`` is still an int.
        """
        if type(sympy.galgebra.GA.MV.basislabel_lst) == types.IntType:
            sys.stderr.write(
                'MV.setup() must be executed before LatexPrinter.format()!\n')
            sys.exit(1)
        LatexPrinter.latexbasis_lst = [['']]
        for grades in sympy.galgebra.GA.MV.basislabel_lst[1:]:
            LatexPrinter.latexbasis_lst.append(
                [[LatexPrinter.extended_symbol(base) for base in grade]
                 for grade in grades])
        return

    @staticmethod
    def build_base(igrade, iblade, bld_flg):
        """
        Return the LaTeX string for blade ``iblade`` of grade ``igrade``.

        The base names are joined with the preamble's wedge macro when
        ``bld_flg`` is true, and concatenated directly otherwise.  Grade 0
        yields the empty string.
        """
        if igrade == 0:
            return ''
        base_lst = LatexPrinter.latexbasis_lst[igrade][iblade]
        separator = '\\W ' if bld_flg else ''
        return separator.join(base_lst)

    @staticmethod
    def format(sym=0, fct=0, pdiff=0, mv=0):
        """
        Turn on LaTeX output with the given symbol/function/derivative/
        multivector format options.

        ``fmt_dict['str']`` is always reset to 1.  Basis strings are built if
        the multivector algebra is set up.  ``str()`` and ``sys.stdout`` are
        then redirected through ``redirect()``.
        """
        LatexPrinter.LaTeX_flg = True
        LatexPrinter.fmt_dict['sym'] = sym
        LatexPrinter.fmt_dict['fct'] = fct
        LatexPrinter.fmt_dict['pdiff'] = pdiff
        LatexPrinter.fmt_dict['mv'] = mv
        LatexPrinter.fmt_dict['str'] = 1
        if sympy.galgebra.GA.MV.is_setup:
            LatexPrinter.latex_bases()
        LatexPrinter.redirect()
        return

    @staticmethod
    def str_basic(in_str):
        """
        Return ``str(in_str)`` produced by the original ``Basic.__str__``.

        While LaTeX output is active, ``Basic.__str__`` is temporarily set to
        the saved original.  It is set back to ``LaTeX`` afterwards, even if
        ``str()`` raises.
        """
        if not LatexPrinter.LaTeX_flg:
            return str(in_str)
        Basic.__str__ = LatexPrinter.Basic__str__
        try:
            return str(in_str)
        finally:
            Basic.__str__ = LaTeX

    @staticmethod
    def redirect():
        """
        Save the current ``__str__`` hooks and ``sys.stdout``.

        Then route ``str()`` of ``Basic`` and ``MV`` through ``LaTeX`` and
        capture stdout in a ``StringIO`` buffer.
        """
        LatexPrinter.Basic__str__ = Basic.__str__
        LatexPrinter.MV__str__ = sympy.galgebra.GA.MV.__str__
        LatexPrinter.stdout = sys.stdout
        sys.stdout = StringIO.StringIO()
        Basic.__str__ = LaTeX
        sympy.galgebra.GA.MV.__str__ = LaTeX
        return

    @staticmethod
    def restore():
        """
        Swap the live ``sys.stdout`` / ``__str__`` hooks with the saved ones.

        The values that were live are stored as the new "saved" values.  A
        second ``restore()`` therefore re-applies the previous state.
        """
        LatexPrinter_stdout = sys.stdout
        LatexPrinter_Basic__str__ = Basic.__str__
        LatexPrinter_MV__str__ = sympy.galgebra.GA.MV.__str__

        sys.stdout = LatexPrinter.stdout
        Basic.__str__ = LatexPrinter.Basic__str__
        sympy.galgebra.GA.MV.__str__ = LatexPrinter.MV__str__

        LatexPrinter.stdout = LatexPrinter_stdout
        LatexPrinter.Basic__str__ = LatexPrinter_Basic__str__
        LatexPrinter.MV__str__ = LatexPrinter_MV__str__
        return

    @staticmethod
    def format_str(fmt='0 0 0 0'):
        """
        Set format options from a string.

        ``fmt`` is either four whitespace-separated ints (sym fct pdiff mv) or
        whitespace-separated ``key=int`` pairs.  LaTeX output is switched on
        if it is not already active.
        """
        fmt_lst = fmt.split()
        if '=' not in fmt:
            LatexPrinter.fmt_dict['sym'] = int(fmt_lst[0])
            LatexPrinter.fmt_dict['fct'] = int(fmt_lst[1])
            LatexPrinter.fmt_dict['pdiff'] = int(fmt_lst[2])
            LatexPrinter.fmt_dict['mv'] = int(fmt_lst[3])
        else:
            for item in fmt_lst:
                key, value = item.split('=')[:2]
                LatexPrinter.fmt_dict[key] = int(value)

        if not LatexPrinter.LaTeX_flg:
            if sympy.galgebra.GA.MV.is_setup:
                LatexPrinter.latex_bases()
            LatexPrinter.redirect()
            LatexPrinter.LaTeX_flg = True
        return

    @staticmethod
    def append_body(xstr):
        """
        Append ``xstr`` to the document body, or return it without its last char.

        NOTE(review): ``body_flg`` and ``body`` are not defined in this class.
        They must be set on LatexPrinter elsewhere before this is called.
        """
        if LatexPrinter.body_flg:
            LatexPrinter.body += xstr
            return ''
        else:
            return xstr[:-1]

    @staticmethod
    def tokenize_greek(name_str):
        """
        Replace every Greek name found in ``name_str`` with an ``@N`` token.

        The token-to-name mapping is recorded in ``greek_dict`` for
        ``replace_greek_tokens``.
        """
        for sym in LatexPrinter.greek_keys:
            if sym in name_str:
                keystr = '@' + str(LatexPrinter.greek_cnt)
                LatexPrinter.greek_cnt += 1
                LatexPrinter.greek_dict[keystr] = sym
                name_str = name_str.replace(sym, keystr)
        return name_str

    @staticmethod
    def tokenize_accents(name_str):
        r"""
        Replace every accent name in ``name_str`` with a ``#N#`` token.

        The mapping to the LaTeX command (e.g. ``\hat``) is recorded in
        ``accent_dict``.
        """
        for sym in LatexPrinter.accent_keys:
            if sym in name_str:
                keystr = '#' + str(LatexPrinter.accent_cnt) + '#'
                LatexPrinter.accent_cnt += 1
                LatexPrinter.accent_dict[keystr] = '\\' + sym
                name_str = name_str.replace(sym, keystr)
        return name_str

    @staticmethod
    def replace_greek_tokens(name_str):
        r"""
        Substitute ``@N`` tokens with ``{\name}``.

        The Greek bookkeeping is reset only if a token was present.
        """
        if '@' not in name_str:
            return name_str
        for token in LatexPrinter.greek_dict.keys():
            name_str = name_str.replace(
                token, '{\\' + LatexPrinter.greek_dict[token] + '}')
        LatexPrinter.greek_cnt = 0
        LatexPrinter.greek_dict = {}
        return name_str

    @staticmethod
    def replace_accent_tokens(name_str):
        """
        Wrap the text before the first ``#`` token in its accents.

        Accents are applied in left-to-right order, so the first accent is
        innermost (reverse notation).  The accent bookkeeping is then reset.
        """
        tmp_lst = name_str.split('#')
        name_str = tmp_lst[0]
        if len(tmp_lst) == 1:
            return name_str
        for x in tmp_lst[1:]:
            if x != '':
                name_str = '{}' + LatexPrinter.accent_dict[
                    '#' + x + '#'] + '{' + name_str + '}'
        LatexPrinter.accent_cnt = 0
        LatexPrinter.accent_dict = {}
        return name_str

    @staticmethod
    def extended_symbol(name_str):
        """
        Convert an encoded ASCII name into LaTeX.

        Greek names, accents and '_'-separated scripts are all translated.
        """
        name_str = LatexPrinter.tokenize_greek(name_str)
        tmp_lst = name_str.split('_')
        subsup_str = ''
        sym_str = tmp_lst[0]
        sym_str = LatexPrinter.tokenize_accents(sym_str)
        sym_str = LatexPrinter.replace_accent_tokens(sym_str)
        if len(tmp_lst) > 1:
            # The script mode toggles after every segment.  An empty segment
            # (from '__') toggles it once more without emitting anything.
            imode = 0
            for x in tmp_lst[1:]:
                if x == '':
                    imode = (imode + 1) % 2
                else:
                    subsup_str += LatexPrinter.mode[imode] + '{' + x + '}'
                    imode = (imode + 1) % 2
        name_str = sym_str + subsup_str
        name_str = LatexPrinter.replace_greek_tokens(name_str)
        return name_str

    def coefficient(self, coef, first_flg):
        r"""
        Format ``coef`` as the coefficient of a term.

        If both ``coef`` and ``-coef`` are AssocOp instances, the coefficient
        is wrapped in ``\lp ... \rp``.  The first term drops a leading '+';
        later terms always carry an explicit sign.  A unit coefficient is
        reduced to '' or to its sign.

        Returns ``(coef_str, first_flg)``, with ``first_flg`` cleared.
        """
        if isinstance(coef, C.AssocOp) and isinstance(-coef, C.AssocOp):
            coef_str = r"\lp %s\rp " % self._print(coef)
        else:
            coef_str = self._print(coef)
        if first_flg:
            first_flg = False
            if coef_str[0] == '+':
                coef_str = coef_str[1:]
        else:
            if coef_str[0] != '-':
                if coef_str[0] != '+':
                    coef_str = '+' + coef_str
        if coef_str in ('1', '+1', '-1'):
            if coef_str == '1':
                coef_str = ''
            else:
                coef_str = coef_str[0]
        return (coef_str, first_flg)

    def __init__(self, inline=True):
        Printer.__init__(self)
        self._inline = inline

    def doprint(self, expr):
        """
        Print ``expr`` and wrap it for output.

        Inline output is wrapped in ``$...$``, except that no wrapper is used
        when fct format 1 is active.  Non-inline output uses an
        ``equation*`` environment.
        """
        tex = Printer.doprint(self, expr)
        xstr = ''

        if self._inline:
            if LatexPrinter.fmt_dict['fct'] == 1:
                xstr = r"%s" % tex
            else:
                xstr = r"$%s$" % tex
        else:
            xstr = r"\begin{equation*}%s\end{equation*}" % tex
        return xstr

    def _needs_brackets(self, expr):
        """True unless ``expr`` is a non-negative Integer or an Atom."""
        return not ((expr.is_Integer and expr.is_nonnegative) or expr.is_Atom)

    def _do_exponent(self, expr, exp):
        """Raise the LaTeX string ``expr`` to ``exp`` (in parentheses) if given."""
        if exp is not None:
            return r"\left(%s\right)^{%s}" % (expr, exp)
        else:
            return expr

    def _print_Add(self, expr):
        tex = str(self._print(expr.args[0]))

        for term in expr.args[1:]:
            coeff = term.as_coeff_mul()[0]

            # A negative term already prints with its own leading minus.
            if coeff.is_negative:
                tex += r" %s" % self._print(term)
            else:
                tex += r" + %s" % self._print(term)

        return tex

    def _print_Mul(self, expr):
        coeff, tail = expr.as_two_terms()

        if not coeff.is_negative:
            tex = ""
        else:
            coeff = -coeff
            tex = "- "

        numer, denom = fraction(tail)

        def convert(terms):
            # Join the factors of a Mul with spaces, parenthesizing Add factors.
            product = []

            if not terms.is_Mul:
                return str(self._print(terms))
            else:
                for term in terms.args:
                    pretty = self._print(term)

                    if term.is_Add:
                        product.append(r"\left(%s\right)" % pretty)
                    else:
                        product.append(str(pretty))

                return r" ".join(product)

        if denom is S.One:
            if coeff is not S.One:
                tex += str(self._print(coeff)) + " "

            if numer.is_Add:
                tex += r"\left(%s\right)" % convert(numer)
            else:
                tex += r"%s" % convert(numer)
        else:
            if numer is S.One:
                if coeff.is_Integer:
                    numer *= coeff.p
                elif coeff.is_Rational:
                    if coeff.p != 1:
                        numer *= coeff.p

                    denom *= coeff.q
                elif coeff is not S.One:
                    tex += str(self._print(coeff)) + " "
            else:
                if coeff.is_Rational and coeff.p == 1:
                    denom *= coeff.q
                elif coeff is not S.One:
                    tex += str(self._print(coeff)) + " "

            tex += r"\frac{%s}{%s}" % \
                (convert(numer), convert(denom))

        return tex

    def _print_Pow(self, expr):
        r"""
        Print a power.

        A half-integer exponent ``p/2`` prints as ``\sqrt{b}`` raised to
        ``|p|``, placed in a fraction when ``p`` is negative.
        """
        if expr.exp.is_Rational and expr.exp.q == 2:
            base, exp = self._print(expr.base), abs(expr.exp.p)

            if exp == 1:
                tex = r"\sqrt{%s}" % base
            else:
                # b**(p/2) == (sqrt(b))**p.  "\sqrt[p]{b}" would denote the
                # p-th root, which is a different quantity.
                tex = r"\sqrt{%s}^{%s}" % (base, exp)

            if expr.exp.is_negative:
                return r"\frac{1}{%s}" % tex
            else:
                return tex
        else:
            if expr.base.is_Function:
                return self._print(expr.base, self._print(expr.exp))
            else:
                if expr.exp == S.NegativeOne:
                    # Mul always turns 1/x into x**-1.  To print it as a
                    # fraction, print -1*expr (a Mul, rendered "- \frac{1}{x}")
                    # and strip the leading minus.
                    tex = self._print(S.NegativeOne * expr).strip()
                    if tex[:1] == "-":
                        return tex[1:].strip()
                if self._needs_brackets(expr.base):
                    tex = r"\left(%s\right)^{%s}"
                else:
                    tex = r"{%s}^{%s}"

                return tex % (self._print(expr.base), self._print(expr.exp))

    def _print_Derivative(self, expr):
        r"""
        Print a derivative.

        With ``fmt_dict['pdiff'] == 1`` the ``\partial_{x}`` subscript form is
        used; otherwise a ``\frac{\partial}{\partial x}`` form is used.
        """
        dim = len(expr.variables)

        if dim == 1:
            if LatexPrinter.fmt_dict['pdiff'] == 1:
                tex = r'\partial_{%s}' % self._print(expr.variables[0])
            else:
                tex = r"\frac{\partial}{\partial %s}" % self._print(
                    expr.variables[0])
        else:
            # Group consecutive repeats of the same variable as (var, count).
            multiplicity, i, tex = [], 1, ""
            current = expr.variables[0]
            for symbol in expr.variables[1:]:
                if symbol == current:
                    i = i + 1
                else:
                    multiplicity.append((current, i))
                    current, i = symbol, 1
            else:
                multiplicity.append((current, i))

            if LatexPrinter.fmt_dict['pdiff'] == 1:
                for x, i in multiplicity:
                    if i == 1:
                        tex += r"\partial_{%s}" % self._print(x)
                    else:
                        tex += r"\partial^{%s}_{%s}" % (i, self._print(x))
            else:
                for x, i in multiplicity:
                    if i == 1:
                        tex += r"\partial %s" % self._print(x)
                    else:
                        tex += r"\partial^{%s} %s" % (i, self._print(x))

                tex = r"\frac{\partial^{%s}}{%s} " % (dim, tex)

        if isinstance(expr.expr, C.AssocOp):
            return r"%s\left(%s\right)" % (tex, self._print(expr.expr))
        else:
            return r"%s %s" % (tex, self._print(expr.expr))

    def _print_Integral(self, expr):
        tex, symbols = "", []

        for symbol, limits in reversed(expr.limits):
            tex += r"\int"

            if limits is not None:
                if not self._inline:
                    tex += r"\limits"

                tex += "_{%s}^{%s}" % (self._print(
                    limits[0]), self._print(limits[1]))

            symbols.insert(0, "d%s" % self._print(symbol))

        return r"%s %s\,%s" % (tex, str(self._print(
            expr.function)), " ".join(symbols))

    def _print_Limit(self, expr):
        tex = r"\lim_{%s \to %s}" % (self._print(
            expr.var), self._print(expr.varlim))

        if isinstance(expr.expr, C.AssocOp):
            return r"%s\left(%s\right)" % (tex, self._print(expr.expr))
        else:
            return r"%s %s" % (tex, self._print(expr.expr))

    @staticmethod
    def _function_head(name, exp):
        r"""
        Return ``\name`` or ``\operatorname{name}``.

        The head is followed by ``^{exp}`` when ``exp`` is not None.
        """
        if name in accepted_latex_functions:
            head = r"\%s" % name
        else:
            head = r"\operatorname{%s}" % name
        if exp is not None:
            head += r"^{%s}" % exp
        return head

    def _print_Function(self, expr, exp=None):
        """
        Print a function application, optionally raised to ``exp``.

        A dedicated ``_print_<name>`` method is used when one exists.  With
        ``fmt_dict['fct'] == 1``, functions in ``fct_dict`` use their mapped
        LaTeX name.  Other functions print only their (extended) symbol name,
        without arguments.
        """
        func = expr.func.__name__

        if hasattr(self, '_print_' + func):
            return getattr(self, '_print_' + func)(expr, exp)

        args = [str(self._print(arg)) for arg in expr.args]

        if LatexPrinter.fmt_dict['fct'] == 1:
            if func in LatexPrinter.fct_dict_keys:
                # Test the name actually emitted (e.g. 'ln' for log), both
                # with and without an exponent.  This avoids emitting
                # non-commands such as "\Sin^{-1}".
                name = LatexPrinter._function_head(
                    LatexPrinter.fct_dict[func], exp)
                return name + r"\left(%s\right)" % ",".join(args)
            func = self.print_Symbol_name(func)
            if exp is not None:
                return r"{%s}^{%s}" % (func, exp)
            return r"{%s}" % func

        return (LatexPrinter._function_head(func, exp) +
                r"\left(%s\right)" % ",".join(args))

    def _print_floor(self, expr, exp=None):
        tex = r"\lfloor{%s}\rfloor" % self._print(expr.args[0])

        if exp is not None:
            return r"%s^{%s}" % (tex, exp)
        else:
            return tex

    def _print_ceiling(self, expr, exp=None):
        tex = r"\lceil{%s}\rceil" % self._print(expr.args[0])

        if exp is not None:
            return r"%s^{%s}" % (tex, exp)
        else:
            return tex

    def _print_abs(self, expr, exp=None):
        tex = r"\lvert{%s}\rvert" % self._print(expr.args[0])

        if exp is not None:
            return r"%s^{%s}" % (tex, exp)
        else:
            return tex

    def _print_re(self, expr, exp=None):
        if self._needs_brackets(expr.args[0]):
            tex = r"\Re\left(%s\right)" % self._print(expr.args[0])
        else:
            tex = r"\Re{%s}" % self._print(expr.args[0])

        return self._do_exponent(tex, exp)

    def _print_im(self, expr, exp=None):
        if self._needs_brackets(expr.args[0]):
            tex = r"\Im\left(%s\right)" % self._print(expr.args[0])
        else:
            tex = r"\Im{%s}" % self._print(expr.args[0])

        return self._do_exponent(tex, exp)

    def _print_conjugate(self, expr, exp=None):
        tex = r"\overline{%s}" % self._print(expr.args[0])

        if exp is not None:
            return r"%s^{%s}" % (tex, exp)
        else:
            return tex

    def _print_exp(self, expr, exp=None):
        tex = r"{e}^{%s}" % self._print(expr.args[0])
        return self._do_exponent(tex, exp)

    def _print_gamma(self, expr, exp=None):
        tex = r"\left(%s\right)" % self._print(expr.args[0])

        if exp is not None:
            return r"\Gamma^{%s}%s" % (exp, tex)
        else:
            return r"\Gamma%s" % tex

    def _print_factorial(self, expr, exp=None):
        x = expr.args[0]
        if self._needs_brackets(x):
            tex = r"\left(%s\right)!" % self._print(x)
        else:
            tex = self._print(x) + "!"

        if exp is not None:
            return r"%s^{%s}" % (tex, exp)
        else:
            return tex

    def _print_binomial(self, expr, exp=None):
        tex = r"{{%s}\choose{%s}}" % (self._print(expr[0]), self._print(
            expr[1]))

        if exp is not None:
            return r"%s^{%s}" % (tex, exp)
        else:
            return tex

    def _print_RisingFactorial(self, expr, exp=None):
        tex = r"{\left(%s\right)}^{\left(%s\right)}" % \
            (self._print(expr[0]), self._print(expr[1]))

        return self._do_exponent(tex, exp)

    def _print_FallingFactorial(self, expr, exp=None):
        tex = r"{\left(%s\right)}_{\left(%s\right)}" % \
            (self._print(expr[0]), self._print(expr[1]))

        return self._do_exponent(tex, exp)

    def _print_Rational(self, expr):
        if expr.q != 1:
            sign = ""
            p = expr.p
            if expr.p < 0:
                sign = "- "
                p = -p
            return r"%s\frac{%d}{%d}" % (sign, p, expr.q)
        else:
            return self._print(expr.p)

    def _print_Infinity(self, expr):
        return r"\infty"

    def _print_NegativeInfinity(self, expr):
        return r"-\infty"

    def _print_ComplexInfinity(self, expr):
        return r"\tilde{\infty}"

    def _print_ImaginaryUnit(self, expr):
        return r"\mathbf{\imath}"

    def _print_NaN(self, expr):
        return r"\bot"

    def _print_Pi(self, expr):
        return r"\pi"

    def _print_Exp1(self, expr):
        return r"e"

    def _print_EulerGamma(self, expr):
        return r"\gamma"

    def _print_Order(self, expr):
        # A single backslash: "\\mathcal" in a raw string would emit a LaTeX
        # line break (\\) followed by the literal word "mathcal".
        return r"\mathcal{O}\left(%s\right)" % \
            self._print(expr.args[0])

    @staticmethod
    def print_Symbol_name(name_str):
        """
        Return the LaTeX form of a symbol name.

        Names longer than one character go through ``extended_symbol`` when
        ``fmt_dict['sym'] == 1``.  All other names are returned unchanged.
        """
        if len(name_str) > 1 and LatexPrinter.fmt_dict['sym'] == 1:
            return LatexPrinter.extended_symbol(name_str)
        return name_str

    def _print_Symbol(self, expr):
        return LatexPrinter.print_Symbol_name(expr.name)

    def _print_str(self, expr):
        r"""
        Translate operator characters in plain strings when ``fmt_dict['str']``
        is positive.

        The mappings are: '^' -> wedge, '|' -> cdot, '__' -> superscript.
        """
        if LatexPrinter.fmt_dict['str'] > 0:
            expr = expr.replace('^', '{\\wedge}')
            expr = expr.replace('|', '{\\cdot}')
            expr = expr.replace('__', '^')
        return expr

    def _print_ndarray(self, expr):
        """
        Print a 1-, 2- or 3-dimensional numpy array as a LaTeX array.

        The array is wrapped in '#' delimiters.  A 3-D array prints as a
        brace-delimited row of its 2-D slices.
        """
        shape = numpy.shape(expr)
        ndim = len(shape)
        expr_str = ''

        if ndim == 1:
            expr_str += '#\\left [ \\begin{array}{' + shape[0] * 'c' + '}  \n'
            for col in expr:
                expr_str += self._print(col) + ' & '
            expr_str = expr_str[:-2] + '\n\\end{array}\\right ]#\n'
            return expr_str

        if ndim == 2:
            expr_str += '#\\left [ \\begin{array}{' + shape[1] * 'c' + '}  \n'
            for row in expr[:-1]:
                for xij in row[:-1]:
                    expr_str += self._print(xij) + ' & '
                expr_str += self._print(row[-1]) + ' \\\\ \n'
            for xij in expr[-1][:-1]:
                expr_str += self._print(xij) + ' & '
            expr_str += self._print(
                expr[-1][-1]) + '\n \\end{array} \\right ] #\n'
            return expr_str

        if ndim == 3:
            expr_str = '#\\left \\{ \\begin{array}{' + shape[0] * 'c' + '} \n'
            for x in expr[:-1]:
                xstr = self._print(x).replace('#', '')
                expr_str += xstr + ' , & '
            xstr = self._print(expr[-1]).replace('#', '')
            expr_str += xstr + '\n\\end{array} \\right \\}#\n'
        return expr_str

    def _format_multivector(self, expr, has_grade):
        """
        Shared LaTeX formatter for ``_print_MV`` and ``_print_OMV``.

        Entries of ``expr.mv`` for which ``has_grade(grade)`` is False are
        skipped, but they still advance the grade index.

        Layout is controlled by ``fmt_dict['mv']``:
        - 2 collects one line per grade.
        - 3 collects one line per base.
        - Any other value keeps everything on one line.

        Two or more collected lines are wrapped in '@' eqnarray delimiters.
        An all-zero multivector prints as '0'.
        """
        placeholder = Symbol('XYZW')
        MV_str = ''
        line_lst = []
        for igrade, grade in enumerate(expr.mv):
            if not has_grade(grade):
                continue
            for ibase, base in enumerate(grade):
                if base != 0:
                    # Multiply by a placeholder symbol so the coefficient
                    # prints with its sign, then strip the placeholder.
                    base_str = str(base * placeholder)
                    if base_str[0] != '-':
                        base_str = '+' + base_str
                    base_str = base_str.replace('- ', '-')
                    if base_str[1:5] == 'XYZW':
                        base_str = base_str.replace('XYZW', '')
                    else:
                        # NOTE(review): a non-unit coefficient keeps a literal
                        # '1' where the placeholder was.  Confirm this is
                        # intended.
                        base_str = base_str.replace('XYZW', '1')
                    MV_str += base_str + \
                        LatexPrinter.build_base(igrade, ibase, expr.bladeflg)
                    if LatexPrinter.fmt_dict['mv'] == 3:
                        line_lst.append(MV_str)
                        MV_str = ''
            if LatexPrinter.fmt_dict['mv'] == 2 and MV_str != '':
                line_lst.append(MV_str)
                MV_str = ''

        n_lines = len(line_lst)
        if MV_str == '':
            if n_lines > 0 and line_lst[0][0] == '+':
                line_lst[0] = line_lst[0][1:]
        elif MV_str[0] == '+':
            MV_str = MV_str[1:]

        if n_lines == 1:
            MV_str = line_lst[0]
        elif n_lines > 1 and LatexPrinter.fmt_dict['mv'] >= 2:
            # Only real multi-line output gets eqnarray rows.  A single line
            # was previously duplicated, and no lines raised IndexError.
            MV_str = '@' + line_lst[0] + ' \\\\ \n'
            for line in line_lst[1:-1]:
                MV_str += '& ' + line + ' \\\\ \n'
            MV_str += '& ' + line_lst[-1] + '@\n'

        if MV_str == '':
            MV_str = '0'
        if expr.name != '':
            MV_str = LatexPrinter.extended_symbol(expr.name) + ' = ' + MV_str
        return MV_str

    def _print_MV(self, expr):
        """Print an ``MV``; integer entries of ``expr.mv`` are skipped."""
        return self._format_multivector(
            expr, lambda grade: type(grade) != types.IntType)

    def _print_OMV(self, expr):
        """Print an ``OMV``; ``None`` entries of ``expr.mv`` are skipped."""
        # The old test ``type(grade) is not None`` was always true, so a None
        # grade crashed when iterated.  Test the value itself instead.
        return self._format_multivector(expr, lambda grade: grade is not None)

    def _print_Relational(self, expr):
        charmap = {
            "==": "=",
            "<": "<",
            "<=": r"\leq",
            ">": ">",
            ">=": r"\geq",
            "!=": r"\neq",
        }

        return "%s %s %s" % (self._print(
            expr.lhs), charmap[expr.rel_op], self._print(expr.rhs))

    def _print_Matrix(self, expr):
        lines = []

        for line in range(expr.lines):  # horrible, should be 'rows'
            lines.append(" & ".join([self._print(i) for i in expr[line, :]]))

        if self._inline:
            tex = r"\left(\begin{smallmatrix}%s\end{smallmatrix}\right)"
        else:
            tex = r"\begin{pmatrix}%s\end{pmatrix}"

        return tex % r"\\".join(lines)

    def _print_tuple(self, expr):
        return r"\begin{pmatrix}%s\end{pmatrix}" % \
            r", & ".join([self._print(i) for i in expr])

    def _print_list(self, expr):
        return r"\begin{bmatrix}%s\end{bmatrix}" % \
            r", & ".join([self._print(i) for i in expr])

    def _print_dict(self, expr):
        """Print a dict as a Bmatrix of ``key : value`` entries in sorted key order."""
        items = []

        # sorted() works on both Python 2 lists and Python 3 dict views;
        # dict.keys().sort() fails on Python 3.
        for key in sorted(expr.keys(), key=default_sort_key):
            val = expr[key]
            items.append("%s : %s" % (self._print(key), self._print(val)))

        return r"\begin{Bmatrix}%s\end{Bmatrix}" % r", & ".join(items)

    def _print_DiracDelta(self, expr):
        if len(expr.args) == 1 or expr.args[1] == 0:
            tex = r"\delta\left(%s\right)" % self._print(expr.args[0])
        else:
            tex = r"\delta^{\left( %s \right)}\left( %s \right)" % (
                self._print(expr.args[1]), self._print(expr.args[0]))
        return tex
Ejemplo n.º 20
0
Archivo: str.py Proyecto: Jerryy/sympy
 def _print_LatticeOp(self, expr):
     args = sorted(expr.args, key=cmp_to_key(expr._compare_pretty))
     return expr.func.__name__ + "(%s)"%", ".join(self._print(arg) for arg in args)
Ejemplo n.º 21
0
 def _print_LatticeOp(self, expr):
     args = sorted(expr.args, key=cmp_to_key(expr._compare_pretty))
     return expr.func.__name__ + "(%s)" % ", ".join(
         self._print(arg) for arg in args)
Ejemplo n.º 22
0
from __future__ import print_function, division

from collections import defaultdict

from sympy.core.basic import C, Basic
from sympy.core.compatibility import cmp_to_key, reduce, is_sequence
from sympy.core.singleton import S
from sympy.core.operations import AssocOp
from sympy.core.cache import cacheit
from sympy.core.numbers import ilcm, igcd
from sympy.core.expr import Expr


# Sort key that places commutative args in canonical order: it adapts the
# old-style three-way comparison ``Basic.compare`` for use as ``key=``.
_args_sortkey = cmp_to_key(Basic.compare)


def _addsort(args):
    """Sort the list ``args`` in place into canonical order.

    The list is mutated; like ``list.sort`` this returns ``None``.
    """
    args.sort(key=_args_sortkey)


def _unevaluated_Add(*args):
    """Return a well-formed unevaluated Add: Numbers are collected and
    put in slot 0 and args are sorted. Use this when args have changed
    but you still want to return an unevaluated Add.

    Examples
    ========

    >>> from sympy.core.add import _unevaluated_Add as uAdd
    >>> from sympy import S, Add
    >>> from sympy.abc import x, y
Ejemplo n.º 23
0
    def __new__(cls, expr, *symbols, **assumptions):
        """Create the Order term ``O(expr)`` in the given ``symbols``.

        ``expr`` is sympified and expanded first.  When no ``symbols`` are
        given, every free symbol of ``expr`` is used.

        Returns:
          * ``S.NaN`` if ``expr`` is NaN;
          * ``expr`` unchanged if it is already an Order whose variables
            include every requested symbol;
          * ``S.Zero`` if the reduced expression is zero;
          * otherwise a new instance whose args are the reduced expression
            (replaced by ``S.One`` if it no longer contains any of the
            symbols) followed by the symbols sorted with ``Basic.compare``.

        Raises:
          NotImplementedError: if an explicitly given symbol is not a
            ``Symbol`` (Order at points other than 0 is not supported).
        """
        expr = sympify(expr).expand()
        if expr is S.NaN:
            return S.NaN

        if symbols:
            # Materialize a list: ``symbols`` is iterated more than once below
            # (the validity check, ``set(symbols)``, ``.sort``).  A bare
            # ``map(...)`` is a one-shot iterator on Python 3, so the
            # ``all(...)`` check would exhaust it and leave no symbols.
            symbols = [sympify(s) for s in symbols]
            if not all(isinstance(s, Symbol) for s in symbols):
                raise NotImplementedError(
                    'Order at points other than 0 not supported.')
        else:
            symbols = list(expr.free_symbols)

        if expr.is_Order:
            # Already an Order: merge its variables with the requested ones.
            v = set(expr.variables)
            symbols = v | set(symbols)
            if symbols == v:
                return expr
            # NOTE(review): ``expr`` is still the Order object here, so the
            # instance built below wraps it; confirm this nesting is intended.
            symbols = list(symbols)

        elif symbols:

            # Drop duplicate symbols.
            symbols = list(set(symbols))

            if expr.is_Add:
                # Keep only the leading-order terms of the sum.
                lst = expr.extract_leading_order(*symbols)
                expr = Add(*[f.expr for (e, f) in lst])

            elif expr:
                if len(symbols) > 1 or expr.is_commutative is False:
                    # TODO
                    # We cannot use compute_leading_term because that only
                    # works in one symbol.
                    expr = expr.as_leading_term(*symbols)
                else:
                    expr = expr.compute_leading_term(symbols[0])

                # Presumably the symbol-dependent part of the leading term
                # (second item of ``as_independent``), split into factors.
                margs = list(Mul.make_args(expr.as_independent(*symbols)[1]))

                if len(symbols) == 1:
                    # The definition of O(f(x)) symbol explicitly stated that
                    # the argument of f(x) is irrelevant.  That's why we can
                    # combine some power exponents (only "on top" of the
                    # expression tree for f(x)), e.g.:
                    # x**p * (-x)**q -> x**(p+q) for real p, q.
                    x = symbols[0]

                    for i, t in enumerate(margs):
                        if t.is_Pow:
                            b, q = t.args
                            if b in (x, -x) and q.is_real and not q.has(x):
                                # (+-x)**q -> x**q
                                margs[i] = x**q
                            elif b.is_Pow and not b.exp.has(x):
                                # ((+-x)**r)**q -> x**(r*q)
                                b, r = b.args
                                if b in (x, -x) and r.is_real:
                                    margs[i] = x**(r*q)
                            elif b.is_Mul and b.args[0] is S.NegativeOne:
                                # (-(+-x)**r)**q -> x**(r*q)
                                b = -b
                                if b.is_Pow and not b.exp.has(x):
                                    b, r = b.args
                                    if b in (x, -x) and r.is_real:
                                        margs[i] = x**(r*q)

                expr = Mul(*margs)

        if expr is S.Zero:
            return expr

        if not expr.has(*symbols):
            expr = S.One

        # create Order instance:
        symbols.sort(key=cmp_to_key(Basic.compare))
        obj = Expr.__new__(cls, expr, *symbols, **assumptions)

        return obj