Ejemplo n.º 1
0
def _solve_abs(f, symbol, domain):
    """Solve ``f = 0`` for ``symbol`` when ``f`` looks like ``p*Abs(q) + r``.

    The equation is split on the sign of ``q``: it is solved once with
    ``Abs(q)`` replaced by ``q`` (kept only where ``q >= 0``) and once with
    ``-q`` (kept only where ``q < 0``); the union of both parts is returned.

    When no decomposition with a non-zero ``p`` is found, a ``ConditionSet``
    over ``domain`` is returned instead.

    Raises ``ValueError`` if ``domain`` is not a subset of the reals.
    """
    if not domain.is_subset(S.Reals):
        raise ValueError(
            filldedent('''
            Absolute values cannot be inverted in the
            complex domain.'''))

    coeff, inner, rest = Wild('p'), Wild('q'), Wild('r')
    found = f.match(coeff * Abs(inner) + rest) or {}

    # A missing match falls back to S.Zero, which takes this exit as well.
    if found.get(coeff, S.Zero).is_zero:
        return ConditionSet(symbol, Eq(f, 0), domain)

    f_p, f_q, f_r = found[coeff], found[inner], found[rest]

    # Regions of the real line on which q is non-negative / negative.
    nonneg_region = solve_univariate_inequality(
        f_q >= 0, symbol, relational=False)
    neg_region = solve_univariate_inequality(
        f_q < 0, symbol, relational=False)

    # Solve each sign case and keep only roots lying in its own region.
    nonneg_roots = solveset_real(f_p * f_q + f_r, symbol).intersect(
        nonneg_region)
    neg_roots = solveset_real(f_p * (-f_q) + f_r, symbol).intersect(
        neg_region)
    return Union(nonneg_roots, neg_roots)
Ejemplo n.º 2
0
def _has_rational_power(expr, symbol):
    """
    Report whether ``expr`` contains a non-integer rational power of a
    term involving ``symbol``.

    Returns a pair ``(flag, den)``: ``flag`` is True when such a power is
    found and ``den`` is then the denominator of that exponent; otherwise
    the pair is ``(False, 1)``.

    Examples
    ========

    >>> from sympy.solvers.solveset import _has_rational_power
    >>> from sympy import sqrt
    >>> from sympy.abc import x
    >>> _has_rational_power(sqrt(x), x)
    (True, 2)
    >>> _has_rational_power(x**2, x)
    (False, 1)
    """
    coeff, base, exponent = Wild('a'), Wild('p'), Wild('q')
    found = expr.match(coeff*base**exponent) or {}

    # No match, a vanishing coefficient, or no base captured: nothing to do.
    if found.get(coeff, S.Zero) is S.Zero or base not in found:
        return (False, S.One)

    power = found[exponent]
    if isinstance(power, Rational) and found[base].has(symbol):
        if power.q != S.One:
            return (True, power.q)

    # Only a coefficient that is itself a power is inspected further.
    leftover = found[coeff]
    if isinstance(leftover, Pow) and not isinstance(leftover, Mul):
        return _has_rational_power(leftover, symbol)
    return (False, S.One)
Ejemplo n.º 3
0
def test_relational_simplification_patterns_numerically():
    """Check each (pattern, replacement) pair on small integer samples.

    For And/Or/Xor, the arguments of the pattern are re-wrapped in the
    logic function, the three wildcards are replaced by integer triples,
    and the result must equal the replacement under the same substitution.
    """
    from sympy.core import Wild
    from sympy.logic.boolalg import _simplify_patterns_and, \
        _simplify_patterns_or, _simplify_patterns_xor
    wilds = [Wild('a'), Wild('b'), Wild('c')]
    cases = [[And, _simplify_patterns_and()],
             [Or, _simplify_patterns_or()],
             [Xor, _simplify_patterns_xor()]]
    triples = set(combinations(list(range(-2, 3)) * 3, 3))
    # Skip combinations of +/-2 and 0, except for all 0
    samples = [t for t in triples if any(w % 2 for w in t) or not any(t)]
    for func, patternlist in cases:
        for pattern in patternlist:
            before = func(*pattern[0].args)
            after = pattern[1]
            for values in samples:
                mapping = dict(zip(wilds, values))
                before_value = before.xreplace(mapping)
                after_value = after.xreplace(mapping)
                assert before_value == after_value, "Original: {}\nand"\
                    " simplified: {}\ndo not evaluate to the same value for"\
                    "{}".format(pattern[0], after, mapping)
Ejemplo n.º 4
0
def test_relational_simplification_patterns_numerically():
    """Check each (pattern, replacement) pair on small integer samples.

    Every pattern from the and/or/xor pattern lists is evaluated with the
    wildcards ``a``, ``b``, ``c`` substituted by integer triples drawn from
    ``range(-2, 2)``; the substituted pattern must equal the substituted
    replacement.
    """
    from sympy.core import Wild
    from sympy.logic.boolalg import simplify_patterns_and, \
        simplify_patterns_or, simplify_patterns_xor
    a = Wild('a')
    b = Wild('b')
    c = Wild('c')
    symb = [a, b, c]
    patternlists = [
        simplify_patterns_and(),
        simplify_patterns_or(),
        simplify_patterns_xor()
    ]
    # The sample triples do not depend on the pattern under test, so build
    # them once instead of once per pattern.
    valuelist = list(set(combinations(list(range(-2, 2)) * 3, 3)))
    for patternlist in patternlists:
        for pattern in patternlist:
            original = pattern[0]
            simplified = pattern[1]
            for values in valuelist:
                sublist = dict(zip(symb, values))
                originalvalue = original.subs(sublist)
                simplifiedvalue = simplified.subs(sublist)
                assert originalvalue == simplifiedvalue, "Original: {}\nand"\
                    " simplified: {}\ndo not evaluate to the same value for"\
                    "{}".format(original, simplified, sublist)
Ejemplo n.º 5
0
def _pat_gen(x):
    """Return the pattern ``s**q * t**r`` followed by its four wildcards.

    ``x`` is accepted but not used when building the pattern.
    """
    s, t, q, r = Wild('s'), Wild('t'), Wild('q'), Wild('r')
    return s**q * t**r, s, t, q, r
Ejemplo n.º 6
0
def _pat_sincos(x):
    """Return the pattern ``sin(a*x)**n * cos(a*x)**m`` and its wildcards.

    None of the wildcards may match ``x``; ``n`` and ``m`` additionally
    only match ``Integer`` instances.
    """
    def _integer_wild(label):
        # Wildcard restricted to Integer instances that do not contain x.
        return Wild(label, exclude=[x],
                    properties=[lambda n: isinstance(n, Integer)])

    a = Wild('a', exclude=[x])
    n = _integer_wild('n')
    m = _integer_wild('m')
    angle = a * x
    return sin(angle)**n * cos(angle)**m, a, n, m
Ejemplo n.º 7
0
    def eval(self):
        """Try to evaluate the sum of ``self.f`` for ``self.i`` from
        ``self.a`` to ``self.b`` in closed form.

        The strategies are tried in order: linearity, Faulhaber's formula
        for ``i**p``, geometric terms ``r**i``, and finally term-by-term
        addition when both limits are integers. ``self`` is returned when
        none of them applies.
        """
        f, i, a, b = self.f, self.i, self.a, self.b
        bounds = (i, a, b)

        # Exploit the linearity of the sum
        if not f.has(i):
            return f*(b-a+1)
        if isinstance(f, Mul):
            left, right = getab(f)
            if not left.has(i):
                return left*Sum2(right, bounds)
            if not right.has(i):
                return right*Sum2(left, bounds)
        if isinstance(f, Add):
            left, right = getab(f)
            left_sum = Sum2(left, bounds)
            right_sum = Sum2(right, bounds)
            # Only combine when both halves were actually evaluated.
            if not (isinstance(left_sum, Sum2) or isinstance(right_sum, Sum2)):
                return left_sum + right_sum

        # Polynomial terms with Faulhaber's formula
        if f == i:
            f = Pow(i, 1, evaluate=False) # TODO: match should handle this
        p = Wild('p')
        found = f.match(i**p)
        if found is not None:
            c = p.subs_dict(found)
            B = Basic.bernoulli
            if c.is_integer and c >= 0:
                return ((B(c+1, b+1) - B(c+1, a))/(c+1)).expand()

        # Geometric terms
        if isinstance(f, Pow):
            r, k = f[:]
            if not r.has(i) and k == i:
                # TODO: Pow should be able to simplify x**oo depending
                # on whether |x| < 1 or |x| > 1 for non-rational x
                if b == oo and isinstance(r, Rational) and abs(r) < 1:
                    return r**a / (1-r)
                return (r**a - r**(b+1)) / (1-r)

        # If nothing else worked, use brute force when possible
        if isinstance(a, Rational) and a.is_integer and \
           isinstance(b, Rational) and b.is_integer:
            return sum(f.subs(i, j) for j in range(a, b+1))

        return self
Ejemplo n.º 8
0
    def eval(self):
        """Try to evaluate the sum of ``self.f`` for ``self.i`` from
        ``self.a`` to ``self.b`` in closed form.

        The strategies are tried in order: linearity, Faulhaber's formula
        for ``i**p``, geometric terms ``r**i``, and finally term-by-term
        addition when both limits are Integers. ``self`` is returned when
        none of them applies.
        """
        f, i, a, b = self.f, self.i, self.a, self.b
        bounds = (i, a, b)

        # Exploit the linearity of the sum
        if not f.has(i):
            return f * (b - a + 1)
        if f.is_Mul:
            left, right = getab(f)
            if not left.has(i):
                return left * Sum2(right, bounds)
            if not right.has(i):
                return right * Sum2(left, bounds)
        if f.is_Add:
            left, right = getab(f)
            left_sum = Sum2(left, bounds)
            right_sum = Sum2(right, bounds)
            # Only combine when both halves were actually evaluated.
            if not (isinstance(left_sum, Sum2) or isinstance(right_sum, Sum2)):
                return left_sum + right_sum

        # Polynomial terms with Faulhaber's formula
        if f == i:
            f = Pow(i, 1, evaluate=False)  # TODO: match should handle this
        p = Wild('p')
        found = f.match(i**p)
        if found is not None:
            c = p.subs(found)
            B = C.bernoulli
            if c.is_integer and c >= 0:
                return ((B(c + 1, b + 1) - B(c + 1, a)) / (c + 1)).expand()

        # Geometric terms
        if f.is_Pow:
            r, k = f.args[:]
            if not r.has(i) and k == i:
                # TODO: Pow should be able to simplify x**oo depending
                # on whether |x| < 1 or |x| > 1 for non-rational x
                if b == oo and isinstance(r, Rational) and abs(r) < 1:
                    return r**a / (1 - r)
                return (r**a - r**(b + 1)) / (1 - r)

        # If nothing else worked, use brute force when possible
        if a.is_Integer and b.is_Integer:
            return sum(f.subs(i, j) for j in range(a, b + 1))

        return self
Ejemplo n.º 9
0
def limitinf(e, x):
    """Compute the limit of ``e`` as ``x`` tends to ``oo``.

    The expression is rewritten to 'tractable' form, its leading term
    ``c0 * w**e0`` is obtained from ``mrv_leadterm`` and the result is
    decided by the sign of ``e0``.

    Raises ``ValueError`` when the leading coefficient has sign 0 in the
    divergent case.
    """
    # rewrite e in terms of tractable functions only
    e = e.rewrite('tractable', deep=True)

    if not e.has(x):
        return e  # constant with respect to x
    if e.has(Order):
        e = e.expand().removeO()
    if not x.is_positive:
        # Swap in a fresh positive, finite variable so the expression gets
        # the correct mathematical behavior.
        fresh = Dummy('p', positive=True, finite=True)
        e, x = e.subs(x, fresh), fresh

    coeff, exponent = mrv_leadterm(e, x)
    exponent_sign = sign(exponent, x)

    if exponent_sign == 1:
        # e0 > 0: the limit is 0
        return S.Zero
    if exponent_sign == -1:
        # e0 < 0: the limit is +-oo, with the sign taken from c0
        if coeff.match(I * Wild("a", exclude=[I])):
            return coeff * oo
        coeff_sign = sign(coeff, x)
        # the leading term shouldn't be 0:
        if coeff_sign == 0:
            raise ValueError("Leading term should not be 0")
        return coeff_sign * oo
    if exponent_sign == 0:
        # e0 = 0: the limit equals the limit of c0
        return limitinf(coeff, x)
Ejemplo n.º 10
0
def _solve_abs(f, symbol):
    """Solve ``f = 0`` for ``symbol`` when ``f`` looks like ``p*Abs(q) + r``.

    The equation is solved separately for ``q >= 0`` and ``q < 0`` and the
    two solution sets are united. When no decomposition with a non-zero
    ``p`` is found, a ``ConditionSet`` over the complexes is returned.
    """
    coeff, inner, rest = Wild('p'), Wild('q'), Wild('r')
    found = f.match(coeff*Abs(inner) + rest) or {}

    # A missing match falls back to S.Zero, which takes this exit as well.
    if found.get(coeff, S.Zero).is_zero:
        return ConditionSet(symbol, Eq(f, 0), S.Complexes)

    f_p, f_q, f_r = found[coeff], found[inner], found[rest]

    # Regions on which q is non-negative / negative.
    nonneg_region = solve_univariate_inequality(
        f_q >= 0, symbol, relational=False)
    neg_region = solve_univariate_inequality(
        f_q < 0, symbol, relational=False)

    # Solve each sign case and keep only roots lying in its own region.
    nonneg_roots = solveset_real(f_p*f_q + f_r, symbol).intersect(
        nonneg_region)
    neg_roots = solveset_real(f_p*(-f_q) + f_r, symbol).intersect(
        neg_region)
    return Union(nonneg_roots, neg_roots)
Ejemplo n.º 11
0
 def compare_exponents(a, b):
     """Comparison function ordering terms by their power of ``mainvar``.

     Each term is matched against ``p1 * mainvar**p2``. Two matching terms
     are ordered by ``Basic.compare`` on their exponents, with
     ``Basic._compare_pretty`` as tie-breaker; a matching term sorts after
     a non-matching one; two non-matching terms use
     ``Basic._compare_pretty``.
     """
     p1, p2 = Wild("p1"), Wild("p2")
     template = p1 * mainvar**p2
     match_a = a.match(template)
     match_b = b.match(template)

     if match_a is None:
         if match_b is None:
             return Basic._compare_pretty(a, b)
         return -1
     if match_b is None:
         return 1

     by_exponent = Basic.compare(match_a[p2], match_b[p2])
     if by_exponent != 0:
         return by_exponent
     return Basic._compare_pretty(a, b)
Ejemplo n.º 12
0
def num_terms(expr):
    """Return ``expr.count(Wild('a'))`` as a size measure of ``expr``.

    Open question kept from the original author: perhaps the measure of
    complexity/length should instead be the number of variables, or the
    depth of the tree. An alternative that was considered::

        from sympy.core import count_ops
        count_ops(Or(a, b, And(d, b)))

    which gives the number of predicates.
    """
    wildcard = Wild('a')
    return expr.count(wildcard)
Ejemplo n.º 13
0
def _solve_abs(f, symbol):
    """Solve ``f = 0`` for ``symbol`` when ``f`` looks like ``p*Abs(q) + r``.

    The equation is solved separately for ``q >= 0`` and ``q < 0`` and the
    union of both solution sets is returned.

    Raises ``NotImplementedError`` when ``f`` cannot be decomposed as
    ``p*Abs(q) + r`` with a non-zero ``p``.
    """
    from sympy.solvers.inequalities import solve_univariate_inequality
    assert f.has(Abs)
    p, q, r = Wild('p'), Wild('q'), Wild('r')
    pattern_match = f.match(p*Abs(q) + r)
    # match() gives None when the expression does not fit the pattern;
    # report that the same way as the "p is zero" case instead of failing
    # with a TypeError on the subscript below.
    if pattern_match is None or p not in pattern_match:
        raise NotImplementedError
    if not pattern_match[p].is_zero:
        f_p, f_q, f_r = pattern_match[p], pattern_match[q], pattern_match[r]
        # Regions on which q is non-negative / negative.
        q_pos_cond = solve_univariate_inequality(f_q >= 0, symbol,
                                                 relational=False)
        q_neg_cond = solve_univariate_inequality(f_q < 0, symbol,
                                                 relational=False)

        # Solve each sign case and keep only roots lying in its own region.
        sols_q_pos = solveset_real(f_p*f_q + f_r,
                                   symbol).intersect(q_pos_cond)
        sols_q_neg = solveset_real(f_p*(-f_q) + f_r,
                                   symbol).intersect(q_neg_cond)
        return Union(sols_q_pos, sols_q_neg)
    else:
        raise NotImplementedError
Ejemplo n.º 14
0
def limitinf(e, x, leadsimp=False):
    """Compute the limit of ``e`` as ``x`` tends to ``oo``.

    Explanation
    ===========

    With ``leadsimp`` set to True, the leading coefficient of the series
    expansion of ``e`` is simplified before recursing on it. That may
    succeed even if ``e`` itself cannot be simplified.

    Raises ``ValueError`` when the leading coefficient has sign 0 in the
    divergent case, or when the sign of the leading exponent is none of
    1, -1 or 0.
    """
    if not e.has(x):
        return e  # constant with respect to x
    if e.has(Order):
        e = e.expand().removeO()
    if not x.is_positive or x.is_integer:
        # Swap in a fresh variable with x.is_positive True and
        # x.is_integer None so the expression gets the correct
        # mathematical behavior.
        fresh = Dummy('p', positive=True)
        e, x = e.subs(x, fresh), fresh

    # rewrite e in terms of tractable functions only
    e = powdenest(e.rewrite('tractable', deep=True, limitvar=x))

    coeff, exponent = mrv_leadterm(e, x)
    exponent_sign = sign(exponent, x)

    if exponent_sign == 1:
        # e0 > 0: the limit is 0
        return S.Zero
    if exponent_sign == -1:
        # e0 < 0: the limit is +-oo, with the sign taken from c0
        if coeff.match(I * Wild("a", exclude=[I])):
            return coeff * oo
        coeff_sign = sign(coeff, x)
        # the leading term shouldn't be 0:
        if coeff_sign == 0:
            raise ValueError("Leading term should not be 0")
        return coeff_sign * oo
    if exponent_sign == 0:
        # e0 = 0: the limit equals the limit of c0
        if leadsimp:
            coeff = coeff.simplify()
        return limitinf(coeff, x, leadsimp)
    raise ValueError("{} could not be evaluated".format(exponent_sign))
Ejemplo n.º 15
0
def telescopic(L, R, limits):
    '''Try to sum ``L + R`` over ``limits`` using the telescopic property.

    ``limits`` is the triple ``(i, a, b)``. An integer shift ``s`` with
    ``L.subs(i, i + s) == -R`` is searched for, first by pattern matching
    and then with ``solve``; the summation is then delegated to
    ``telescopic_direct``.

    Returns None if this is not possible.
    '''
    i, a, b = limits
    if L.is_Add or R.is_Add:
        return None

    # We want to solve(L.subs(i, i + m) + R, m)
    # A plain match is tried first because it handles things solve does
    # not, e.g. solve(f(k+m)-f(k), m) fails

    shift_wild = Wild("k")
    matched = (-R).match(L.subs(i, i + shift_wild))
    shift = None
    if matched and shift_wild in matched:
        candidate = matched[shift_wild]
        # match sometimes yields a bogus answer, e.g.
        # f(x+2).match(-f(x+k)) -> {k: -2 - 2x}, so verify it
        if candidate.is_Integer and L.subs(i, i + candidate) == -R:
            shift = candidate

    # But there are things that match doesn't do that solve
    # can do, e.g. determine that 1/(x + m) = 1/(1 - x) when m = 1

    if shift is None:
        m = Dummy('m')
        try:
            roots = solve(L.subs(i, i + m) + R, m) or []
        except NotImplementedError:
            return None
        verified = [
            root for root in roots
            if root.is_Integer and (L.subs(i, i + root) + R).expand().is_zero
        ]
        # Exactly one integer shift must survive.
        if len(verified) != 1:
            return None
        shift = verified[0]

    if shift < 0:
        return telescopic_direct(R, L, abs(shift), (i, a, b))
    if shift > 0:
        return telescopic_direct(L, R, shift, (i, a, b))
    return None
Ejemplo n.º 16
0
def limitinf(e, x):
    """Compute the limit of ``e`` as ``x`` tends to ``oo``.

    The leading term ``c0 * w**e0`` of ``e`` is obtained from
    ``mrv_leadterm`` and the result is decided by the sign of ``e0``.

    Raises ``ValueError`` when the leading coefficient has sign 0 in the
    divergent case.
    """
    if not e.has(x):
        return e  # e is a constant
    if not x.is_positive:
        # We make sure that x.is_positive is True so we
        # get all the correct mathematical behavior from the expression.
        # We need a fresh variable.
        p = Dummy('p', positive=True)
        e = e.subs(x, p)
        x = p
    c0, e0 = mrv_leadterm(e, x)
    sig = sign(e0, x)
    if sig == 1:
        return S.Zero  # e0>0: lim f = 0
    elif sig == -1:  # e0<0: lim f = +-oo (the sign depends on the sign of c0)
        if c0.match(I * Wild("a", exclude=[I])):
            return c0 * oo
        s = sign(c0, x)
        # The leading term shouldn't be 0. Raise explicitly rather than
        # assert, so the check is not stripped when running with -O.
        if s == 0:
            raise ValueError("Leading term should not be 0")
        return s * oo
    elif sig == 0:
        return limitinf(c0, x)  # e0=0: lim f = lim c0
Ejemplo n.º 17
0
def __trigsimp(expr, deep=False):
    """Recursive helper for trigsimp.

    Works in stages on ``expr``, using the pattern tables cached in the
    module-level ``_trigpat`` (built on first use by ``_trigpats()``):

    1. a ``Mul`` is run through the division matchers
       (e.g. sin/cos -> tan);
    2. an ``Add`` has each term simplified, then the identity matchers,
       the addition matchers (after ``TR10i``) and finally the artifact
       patterns applied;
    3. otherwise the arguments of a ``Mul``/``Pow`` (or of any expression
       with args when ``deep`` is set) are simplified recursively;
    4. last, a rewrite in terms of ``exp`` is tried and accepted only if
       it introduces no new ``exp`` atoms.

    Returns the (possibly unchanged) expression.
    """
    from sympy.simplify.fu import TR10i

    # Build the pattern tables lazily on first use.
    if _trigpat is None:
        _trigpats()
    a, b, c, d, matchers_division, matchers_add, \
    matchers_identity, artifacts = _trigpat

    if expr.is_Mul:
        # do some simplifications like sin/cos -> tan:
        if not expr.is_commutative:
            # Simplify only the commutative factors and re-attach the
            # non-commutative ones afterwards.
            com, nc = expr.args_cnc()
            expr = _trigsimp(Mul._from_args(com), deep) * Mul._from_args(nc)
        else:
            for i, (pattern, simp, ok1, ok2) in enumerate(matchers_division):
                if not _dotrig(expr, pattern):
                    continue

                # A non-None result from the helper is authoritative: either
                # it changed the expression (done) or this pattern is skipped.
                newexpr = _match_div_rewrite(expr, i)
                if newexpr is not None:
                    if newexpr != expr:
                        expr = newexpr
                        break
                    else:
                        continue

                # use SymPy matching instead
                res = expr.match(pattern)
                if res and res.get(c, 0):
                    # For a non-integer exponent both guard expressions
                    # of the pattern must be positive.
                    if not res[c].is_integer:
                        ok = ok1.subs(res)
                        if not ok.is_positive:
                            continue
                        ok = ok2.subs(res)
                        if not ok.is_positive:
                            continue
                    # if "a" contains any of trig or hyperbolic funcs with
                    # argument "b" then skip the simplification
                    if any(w.args[0] == res[b] for w in res[a].atoms(
                            TrigonometricFunction, HyperbolicFunction)):
                        continue
                    # simplify and finish:
                    expr = simp.subs(res)
                    break  # process below

    if expr.is_Add:
        args = []
        for term in expr.args:
            # Split off non-commutative factors so that only the
            # commutative part of each term is simplified.
            if not term.is_commutative:
                com, nc = term.args_cnc()
                nc = Mul._from_args(nc)
                term = Mul._from_args(com)
            else:
                nc = S.One
            term = _trigsimp(term, deep)
            # Apply the first identity pattern that matches the term.
            for pattern, result in matchers_identity:
                res = term.match(pattern)
                if res is not None:
                    term = result.subs(res)
                    break
            args.append(term * nc)
        if args != expr.args:
            expr = Add(*args)
            # Keep whichever of the plain and expanded forms has fewer
            # operations.
            expr = min(expr, expand(expr), key=count_ops)
        if expr.is_Add:
            for pattern, result in matchers_add:
                if not _dotrig(expr, pattern):
                    continue
                expr = TR10i(expr)
                if expr.has(HyperbolicFunction):
                    res = expr.match(pattern)
                    # if "d" contains any trig or hyperbolic funcs with
                    # argument "a" or "b" then skip the simplification;
                    # this isn't perfect -- see tests
                    if res is None or not (a in res and b in res) or any(
                            w.args[0] in (res[a], res[b])
                            for w in res[d].atoms(TrigonometricFunction,
                                                  HyperbolicFunction)):
                        continue
                    expr = result.subs(res)
                    break

        # Reduce any lingering artifacts, such as sin(x)**2 changing
        # to 1 - cos(x)**2 when sin(x)**2 was "simpler"
        for pattern, result, ex in artifacts:
            if not _dotrig(expr, pattern):
                continue
            # Substitute a new wild that excludes some function(s)
            # to help influence a better match. This is because
            # sometimes, for example, 'a' would match sec(x)**2
            a_t = Wild('a', exclude=[ex])
            pattern = pattern.subs(a, a_t)
            result = result.subs(a, a_t)

            m = expr.match(pattern)
            was = None
            # Re-apply the pattern until it no longer matches or the
            # expression stops changing.
            while m and was != expr:
                was = expr
                if m[a_t] == 0 or \
                        -m[a_t] in m[c].args or m[a_t] + m[c] == 0:
                    break
                if d in m and m[a_t] * m[d] + m[c] == 0:
                    break
                expr = result.subs(m)
                m = expr.match(pattern)
                # NOTE(review): match() returns None elsewhere in this
                # function when nothing matches; if that can happen here,
                # setdefault would raise AttributeError -- confirm.
                m.setdefault(c, S.Zero)

    # Precedence: this reads as is_Mul or is_Pow or (deep and expr.args).
    elif expr.is_Mul or expr.is_Pow or deep and expr.args:
        expr = expr.func(*[_trigsimp(a, deep) for a in expr.args])

    # Try a rewrite via exp; TypeError is used as an internal "give up"
    # signal and is swallowed below.
    try:
        if not expr.has(*_trigs):
            raise TypeError
        e = expr.atoms(exp)
        new = expr.rewrite(exp, deep=deep)
        # NOTE(review): this compares the rewritten expression with the set
        # of exp atoms rather than with expr -- confirm this is intended.
        if new == e:
            raise TypeError
        fnew = factor(new)
        if fnew != new:
            new = sorted([new, factor(new)], key=count_ops)[0]
        # if all exp that were introduced disappeared then accept it
        if not (new.atoms(exp) - e):
            expr = new
    except TypeError:
        pass

    return expr
Ejemplo n.º 18
0
def _trigpats():
    """Build the trigonometric/hyperbolic rewrite tables.

    The result is stored in the module-level ``_trigpat`` and also
    returned. It is the tuple ``(a, b, c, d, matchers_division,
    matchers_add, matchers_identity, artifacts)`` where ``a``, ``b``, ``c``
    are commutative wildcards, ``d`` is a non-commutative wildcard, and the
    remaining entries are the pattern tables defined below.
    """
    global _trigpat
    a, b, c = symbols('a b c', cls=Wild)
    d = Wild('d', commutative=False)

    # for the simplifications like sinh/cosh -> tanh:
    # DO NOT REORDER THE FIRST 14 since these are assumed to be in this
    # order in _match_div_rewrite.
    # Each entry is (pattern, replacement, guard1, guard2).
    matchers_division = (
        (a * sin(b)**c / cos(b)**c, a * tan(b)**c, sin(b), cos(b)),
        (a * tan(b)**c * cos(b)**c, a * sin(b)**c, sin(b), cos(b)),
        (a * cot(b)**c * sin(b)**c, a * cos(b)**c, sin(b), cos(b)),
        (a * tan(b)**c / sin(b)**c, a / cos(b)**c, sin(b), cos(b)),
        (a * cot(b)**c / cos(b)**c, a / sin(b)**c, sin(b), cos(b)),
        (a * cot(b)**c * tan(b)**c, a, sin(b), cos(b)),
        (a * (cos(b) + 1)**c * (cos(b) - 1)**c, a * (-sin(b)**2)**c,
         cos(b) + 1, cos(b) - 1),
        (a * (sin(b) + 1)**c * (sin(b) - 1)**c, a * (-cos(b)**2)**c,
         sin(b) + 1, sin(b) - 1),
        (a * sinh(b)**c / cosh(b)**c, a * tanh(b)**c, S.One, S.One),
        (a * tanh(b)**c * cosh(b)**c, a * sinh(b)**c, S.One, S.One),
        (a * coth(b)**c * sinh(b)**c, a * cosh(b)**c, S.One, S.One),
        (a * tanh(b)**c / sinh(b)**c, a / cosh(b)**c, S.One, S.One),
        (a * coth(b)**c / cosh(b)**c, a / sinh(b)**c, S.One, S.One),
        (a * coth(b)**c * tanh(b)**c, a, S.One, S.One),
        (c * (tanh(a) + tanh(b)) / (1 + tanh(a) * tanh(b)), tanh(a + b) * c,
         S.One, S.One),
    )

    # Angle-sum style rules; each entry is (pattern, replacement) and "d"
    # collects the remaining terms of the sum.
    matchers_add = (
        (c * sin(a) * cos(b) + c * cos(a) * sin(b) + d, sin(a + b) * c + d),
        (c * cos(a) * cos(b) - c * sin(a) * sin(b) + d, cos(a + b) * c + d),
        (c * sin(a) * cos(b) - c * cos(a) * sin(b) + d, sin(a - b) * c + d),
        (c * cos(a) * cos(b) + c * sin(a) * sin(b) + d, cos(a - b) * c + d),
        (c * sinh(a) * cosh(b) + c * sinh(b) * cosh(a) + d,
         sinh(a + b) * c + d),
        (c * cosh(a) * cosh(b) + c * sinh(a) * sinh(b) + d,
         cosh(a + b) * c + d),
    )

    # for cos(x)**2 + sin(x)**2 -> 1
    # Each entry is (pattern, replacement).
    matchers_identity = (
        (a * sin(b)**2, a - a * cos(b)**2),
        (a * tan(b)**2, a * (1 / cos(b))**2 - a),
        (a * cot(b)**2, a * (1 / sin(b))**2 - a),
        (a * sin(b + c), a * (sin(b) * cos(c) + sin(c) * cos(b))),
        (a * cos(b + c), a * (cos(b) * cos(c) - sin(b) * sin(c))),
        (a * tan(b + c), a * ((tan(b) + tan(c)) / (1 - tan(b) * tan(c)))),
        (a * sinh(b)**2, a * cosh(b)**2 - a),
        (a * tanh(b)**2, a - a * (1 / cosh(b))**2),
        (a * coth(b)**2, a + a * (1 / sinh(b))**2),
        (a * sinh(b + c), a * (sinh(b) * cosh(c) + sinh(c) * cosh(b))),
        (a * cosh(b + c), a * (cosh(b) * cosh(c) + sinh(b) * sinh(c))),
        (a * tanh(b + c), a * ((tanh(b) + tanh(c)) / (1 + tanh(b) * tanh(c)))),
    )

    # Reduce any lingering artifacts, such as sin(x)**2 changing
    # to 1-cos(x)**2 when sin(x)**2 was "simpler"
    # Each entry is (pattern, replacement, function); the third element is
    # a function class such as cos or sinh.
    artifacts = (
        (a - a * cos(b)**2 + c, a * sin(b)**2 + c, cos),
        (a - a * (1 / cos(b))**2 + c, -a * tan(b)**2 + c, cos),
        (a - a * (1 / sin(b))**2 + c, -a * cot(b)**2 + c, sin),
        (a - a * cosh(b)**2 + c, -a * sinh(b)**2 + c, cosh),
        (a - a * (1 / cosh(b))**2 + c, a * tanh(b)**2 + c, cosh),
        (a + a * (1 / sinh(b))**2 + c, a * coth(b)**2 + c, sinh),

        # same as above but with noncommutative prefactor
        (a * d - a * d * cos(b)**2 + c, a * d * sin(b)**2 + c, cos),
        (a * d - a * d * (1 / cos(b))**2 + c, -a * d * tan(b)**2 + c, cos),
        (a * d - a * d * (1 / sin(b))**2 + c, -a * d * cot(b)**2 + c, sin),
        (a * d - a * d * cosh(b)**2 + c, -a * d * sinh(b)**2 + c, cosh),
        (a * d - a * d * (1 / cosh(b))**2 + c, a * d * tanh(b)**2 + c, cosh),
        (a * d + a * d * (1 / sinh(b))**2 + c, a * d * coth(b)**2 + c, sinh),
    )

    # Cache the tables at module level and hand them back to the caller.
    _trigpat = (a, b, c, d, matchers_division, matchers_add, matchers_identity,
                artifacts)
    return _trigpat
Ejemplo n.º 19
0
 def _eval_nseries(self, x, x0, n):
     """Series expansion of ``self`` (a log of ``self.args[0]``) in ``x``.

     An argument of the form ``k * x**l`` with ``k`` and ``l`` free of
     ``x`` and ``l != 0`` is returned directly as ``log(k) + l*log(x)``.
     Otherwise the argument is either factored through its leading term
     (when it is singular or vanishes at 0) or expanded as
     ``log(arg0) + log(1 + z)`` using the Taylor terms of the log.

     Returns the expansion plus an ``Order(x**n, x)`` term, or ``self``
     when no progress was made.
     """
     from sympy import powsimp
     arg = self.args[0]
     k, l = Wild("k"), Wild("l")
     r = arg.match(k * x**l)
     if r is not None:
         k, l = r[k], r[l]
         if l != 0 and not l.has(x) and not k.has(x):
             r = log(k) + l * log(x)
             return r
     order = C.Order(x**n, x)
     x = order.symbols[0]
     ln = C.log
     use_lt = not C.Order(1, x).contains(arg)
     if not use_lt:
         arg0 = arg.limit(x, 0)
         use_lt = (arg0 is S.Zero)
     if use_lt:  # singularity, #example: self = log(sin(x))
         # arg = (arg / lt) * lt
         lt = arg.as_leading_term(x)  # arg = sin(x); lt = x
         a = powsimp((arg / lt).expand(), deep=True,
                     combine='exp')  # a = sin(x)/x
         # the idea is to recursively call ln(a).series(), but one needs to
         # make sure that ln(sin(x)/x) doesn't get "simplified" to
         # -log(x)+ln(sin(x)) and an infinite recursion occurs, see also the
         # issue 252.
         obj = ln(lt) + ln(a)._eval_nseries(x, x0, n)
     else:
         # arg -> arg0 + (arg - arg0) -> arg0 * (1 + (arg/arg0 - 1))
         # arg0 is always bound here: this branch is only reached when the
         # limit above was computed.
         z = (arg / arg0 - 1)
         o = C.Order(z, x)
         if o is S.Zero:
             return ln(1 + z) + ln(arg0)
         if o.expr.is_number:
             e = ln(order.expr * x) / ln(x)
         else:
             e = ln(order.expr) / ln(o.expr)
         n = e.limit(x, 0) + 1
         if n.is_unbounded:
             # requested accuracy gives infinite series,
             # order is probably nonpolynomial e.g. O(exp(-1/x), x).
             return ln(1 + z) + ln(arg0)
         try:
             n = int(n)
         except TypeError:
             # well, the n is something more complicated (like 1+log(2))
             n = int(n.evalf()) + 1
         # repr() replaces the Python-2-only backtick syntax.
         assert n >= 0, repr(n)
         terms = []
         g = None
         # Each Taylor term is built from the previous one.
         for i in range(n + 2):
             g = ln.taylor_term(i, z, g)
             g = g.nseries(x, x0, n)
             terms.append(g)
         obj = C.Add(*terms) + ln(arg0)
     obj2 = expand_log(powsimp(obj, deep=True, combine='exp'))
     if obj2 != obj:
         r = obj2.nseries(x, x0, n)
     else:
         r = obj
     if r == self:
         return self
     return r + order
Ejemplo n.º 20
0
    def eval(cls, arg):
        """Try to evaluate ``cls(arg)`` to a simpler expression.

        Handles, in order: a few special numbers, a ``log`` argument
        (returns the log's own argument), products that are a coefficient
        times ``pi*I`` or ``I*oo``, and sums/products whose terms contain
        an infinite coefficient or a single ``log`` factor, which are
        pulled out as powers.

        Falls off the end (implicitly returning ``None``) when none of the
        rules produces a result.
        NOTE(review): the returned constants (1 at 0, ``S.Exp1`` at 1)
        suggest this belongs to the exponential function -- confirm against
        the enclosing class.
        """
        if arg.is_Number:
            # Fixed values at special points.
            if arg is S.NaN:
                return S.NaN
            elif arg is S.Zero:
                return S.One
            elif arg is S.One:
                return S.Exp1
            elif arg is S.Infinity:
                return S.Infinity
            elif arg is S.NegativeInfinity:
                return S.Zero
        elif arg.func is log:
            return arg.args[0]
        elif arg.is_Mul:
            # coeff*pi*I with 2*coeff an integer maps to 1, -1, -I or I.
            coeff = arg.as_coefficient(S.Pi * S.ImaginaryUnit)

            if coeff is not None:
                if (2 * coeff).is_integer:
                    if coeff.is_even:
                        return S.One
                    elif coeff.is_odd:
                        return S.NegativeOne
                    elif (coeff + S.Half).is_even:
                        return -S.ImaginaryUnit
                    elif (coeff + S.Half).is_odd:
                        return S.ImaginaryUnit
            I = S.ImaginaryUnit
            oo = S.Infinity
            # A non-zero multiple of I*oo gives NaN.
            a = Wild("a", exclude=[I, oo])
            r = arg.match(I * a * oo)
            if r and r[a] != 0:
                return S.NaN

        # From here on, work term by term on the summands of arg.
        if arg.is_Add:
            args = arg.args
        else:
            args = [arg]

        # included: summands left inside cls(...); excluded: factors
        # pulled out in front of it.
        included, excluded = [], []

        # Note: the loop variable rebinds the ``arg`` parameter.
        for arg in args:
            coeff, terms = arg.as_coeff_terms()

            if coeff is S.Infinity:
                excluded.append(coeff**C.Mul(*terms))
            else:
                coeffs, log_term = [coeff], None

                # Accept the summand only if it is comparable factors
                # times exactly one log factor.
                for term in terms:
                    if term.func is log:
                        if log_term is None:
                            log_term = term.args[0]
                        else:
                            # a second log factor: give up on this summand
                            log_term = None
                            break
                    elif term.is_comparable:
                        coeffs.append(term)
                    else:
                        log_term = None
                        break

                if log_term is not None:
                    excluded.append(log_term**C.Mul(*coeffs))
                else:
                    included.append(arg)

        # Only return a result when at least one factor was pulled out.
        if excluded:
            return C.Mul(*(excluded + [cls(C.Add(*included))]))
Ejemplo n.º 21
0
def _pat_sincos(x):
    """Return the pattern ``sin(a*x)**n * cos(a*x)**m`` and its wildcards.

    None of the wildcards ``a``, ``n``, ``m`` may match ``x``.
    """
    a = Wild('a', exclude=[x])
    n = Wild('n', exclude=[x])
    m = Wild('m', exclude=[x])
    angle = a * x
    return sin(angle)**n * cos(angle)**m, a, n, m
Ejemplo n.º 22
0
    return s


def telescopic(L, R, (i, a, b)):
    '''Tries to perform the summation using the telescopic property

    return None if not possible
    '''
    if L.is_Add or R.is_Add:
        return None

    # We want to solve(L.subs(i, i + m) + R, m)
    # First we try a simple match since this does things that
    # solve doesn't do, e.g. solve(f(k+m)-f(k), m) fails

    k = Wild("k")
    sol = (-R).match(L.subs(i, i + k))
    s = None
    if sol and k in sol:
        s = sol[k]
        if not (s.is_Integer and L.subs(i, i + s) == -R):
            #sometimes match fail(f(x+2).match(-f(x+k))->{k: -2 - 2x}))
            s = None

    # But there are things that match doesn't do that solve
    # can do, e.g. determine that 1/(x + m) = 1/(1 - x) when m = 1

    if s is None:
        m = Dummy('m')
        try:
            sol = solve(L.subs(i, i + m) + R, m) or []
Ejemplo n.º 23
0
def eval_sum_symbolic(f, limits):
    """Try to find a closed form for the sum of ``f`` over ``limits``.

    ``limits`` is the triple ``(i, a, b)``. The strategies are tried in
    order: constant summand, linearity over products and sums (including
    a telescoping attempt), Faulhaber's formula and harmonic numbers for
    ``i**n``, geometric terms and Gosper summation for finite limits, and
    finally ``eval_sum_hyper``.
    """
    i, a, b = limits
    bounds = (i, a, b)
    if not f.has(i):
        return f * (b - a + 1)

    # Linearity
    if f.is_Mul:
        left, right = f.as_two_terms()

        # Pull a factor free of i out of the sum.
        if not left.has(i):
            right_sum = eval_sum_symbolic(right, bounds)
            if right_sum:
                return left * right_sum

        if not right.has(i):
            left_sum = eval_sum_symbolic(left, bounds)
            if left_sum:
                return right * left_sum

        try:
            f = apart(f, i)  # see if it becomes an Add
        except PolynomialError:
            pass

    if f.is_Add:
        left, right = f.as_two_terms()
        telescoped = telescopic(left, right, bounds)

        if telescoped:
            return telescoped

        left_sum = eval_sum_symbolic(left, bounds)
        right_sum = eval_sum_symbolic(right, bounds)

        if None not in (left_sum, right_sum):
            return left_sum + right_sum

    # Polynomial terms with Faulhaber's formula
    n = Wild('n')
    power_match = f.match(i**n)

    if power_match is not None:
        n = power_match[n]

        if n.is_Integer:
            if n >= 0:
                faulhaber = (C.bernoulli(n + 1, b + 1) -
                             C.bernoulli(n + 1, a)) / (n + 1)
                return faulhaber.expand()
            elif a.is_Integer and a >= 1:
                # Negative powers are expressed with harmonic numbers.
                if n == -1:
                    return C.harmonic(b) - C.harmonic(a - 1)
                return C.harmonic(b, abs(n)) - C.harmonic(a - 1, abs(n))

    if not a.has(S.Infinity, S.NegativeInfinity) \
            and not b.has(S.Infinity, S.NegativeInfinity):
        # Geometric terms
        c1 = C.Wild('c1', exclude=[i])
        c2 = C.Wild('c2', exclude=[i])
        c3 = C.Wild('c3', exclude=[i])

        geometric = f.match(c1**(c2 * i + c3))

        if geometric is not None:
            scale = (c1**c3).subs(geometric)
            ratio = (c1**c2).subs(geometric)

            general = scale * (ratio**a - ratio**(b + 1)) / (1 - ratio)
            # With ratio 1 every term equals the scale factor.
            degenerate = scale * (b - a + 1)

            return Piecewise((degenerate, Eq(ratio, S.One)), (general, True))

        gosper = gosper_sum(f, bounds)

        if gosper not in (None, S.NaN):
            return gosper

    return eval_sum_hyper(f, bounds)
Ejemplo n.º 24
0
def _pat_sincos(x):
    """Build the wildcard pattern ``sin(a*x)**n * cos(a*x)**m``.

    Returns the tuple ``(pattern, a, n, m)``.  None of the three wildcards
    may match an expression containing ``x``; ``n`` and ``m`` additionally
    carry the ``_integer_instance`` property filter.
    """
    coeff = Wild("a", exclude=[x])
    sin_exp = Wild("n", exclude=[x], properties=[_integer_instance])
    cos_exp = Wild("m", exclude=[x], properties=[_integer_instance])
    pattern = sin(coeff * x)**sin_exp * cos(coeff * x)**cos_exp
    return pattern, coeff, sin_exp, cos_exp
Ejemplo n.º 25
0

def solve_linear_system_LU(matrix, syms):
    """Solve an augmented linear system using LU decomposition.

    ``matrix`` is the augmented matrix ``[A | b]``: it must have exactly one
    more column than rows, the last column being the right-hand side.  The
    LU solver only works when the square part ``A`` is invertible.

    ``syms`` lists the unknowns, one per row, in column order.

    Returns a dict mapping each symbol to its solved value.

    Raises ``ValueError`` when the matrix does not have ``rows + 1`` columns.
    """
    # Explicit check instead of ``assert`` so the validation is not stripped
    # when Python runs with optimizations enabled.
    if matrix.rows != matrix.cols - 1:
        raise ValueError("Rows should be equal to columns - 1")
    A = matrix[:matrix.rows, :matrix.rows]
    b = matrix[:, matrix.cols - 1:]
    soln = A.LUsolve(b)
    return {syms[i]: soln[i, 0] for i in range(soln.rows)}


# Placeholder symbol: tsolve() rewrites the equation in terms of this Dummy
# before matching it against the patterns.
x = Dummy('x')
# Wildcards used for the pattern coefficients; none of them may match an
# expression containing ``x``.
a, b, c, d, e, f, g, h = [Wild(t, exclude=[x]) for t in 'abcdefgh']
# Pattern table; stays None until _generate_patterns() fills it in.
patterns = None


def _generate_patterns():
    """
    Generates patterns for transcendental equations.

    This is lazily calculated (called) in the tsolve() function and stored in
    the patterns global variable.
    """

    tmp1 = f**(h - (c * g / b))
    tmp2 = (-e * tmp1 / a)**(1 / d)
    global patterns
    patterns = [
Ejemplo n.º 26
0
def denester(nested):
    """
    Denest a list of expressions that contain nested square roots.

    This function should not be called directly - use 'denest' instead.
    The algorithm is based on
    <http://www.almaden.ibm.com/cs/people/fagin/symb85.pdf>.

    It is assumed that all of the elements of 'nested' share the same
    bottom-level radicand. (This is stated in the paper, on page 177, in
    the paragraph immediately preceding the algorithm.)

    When evaluating all of the arguments in parallel, the bottom-level
    radicand only needs to be denested once. This means that calling
    denester with x arguments results in a recursive invocation with x+1
    arguments; hence denester has polynomial complexity.

    However, if the arguments were evaluated separately, each call would
    result in two recursive invocations, and the algorithm would have
    exponential complexity.

    This is discussed in the paper in the middle paragraph of page 179.

    Returns a pair ``(expr, f)``: the resulting expression and a sequence
    of 0/1 flags selecting which elements were multiplied together to
    obtain it (all zeros when no perfect square was found).
    """
    # Base case: none of the arguments are nested (every square is a Number).
    if all((n**2).is_Number
           for n in nested):
        # Test each subset 'f' of nested.
        # NOTE(review): subsets() is defined elsewhere; presumably it yields
        # every 0/1 selection list of the given length - confirm.
        for f in subsets(len(nested)):
            p = prod(nested[i]**2 for i in range(len(f)) if f[i]).expand()
            # Flip the sign when more than one element is selected and the
            # last element is among them.
            if 1 in f and f.count(1) > 1 and f[-1]: p = -p
            # If we got a perfect square, return its square root.
            if sqrt(p).is_Number:
                return sqrt(
                    p), f
        # Otherwise, return the radicand from the previous invocation.
        return nested[-1], [0] * len(
            nested
        )
    else:
        # Recursive case: write every element as sqrt(a + b*sqrt(r)).
        a, b, r, R = Wild('a'), Wild('b'), Wild('r'), None
        values = [expr.match(sqrt(a + b * sqrt(r))) for expr in nested]
        for v in values:
            # r may be absent from the match: if b=0, r is not defined.
            if r in v:
                if R is not None:
                    # All the 'r's should be the same.
                    assert R == v[r]
                else:
                    R = v[r]
        # Recurse on sqrt(a**2 - R*b**2) for every element, with sqrt(R)
        # appended as one extra argument.
        d, f = denester([
            sqrt((v[a]**2).expand() - (R * v[b]**2).expand()) for v in values
        ] + [sqrt(R)])
        # If f[i]=0 for all i < len(nested), only the appended sqrt(R)
        # (if anything) was selected.
        if not any([f[i] for i in range(len(nested))
                    ]):
            v = values[-1]
            return sqrt(v[a] + v[b] * d), f
        else:
            # Rebuild a + b*sqrt(r) from the product of the selected squares.
            v = prod(nested[i]**2 for i in range(len(nested))
                     if f[i]).expand().match(a + b * sqrt(r))
            # Negate both coefficients when the last original element was
            # selected together with an earlier one.
            if 1 in f and f.index(1) < len(nested) - 1 and f[len(nested) - 1]:
                v[a] = -1 * v[a]
                v[b] = -1 * v[b]
            # The appended sqrt(R) was not selected: the solution denests
            # with square roots only.
            if not f[len(nested)]:
                return (sqrt((v[a] + d).expand() / 2) + sign(v[b]) * sqrt(
                    (v[b]**2 * R / (2 * (v[a] + d))).expand())).expand(), f
            # Otherwise the solution requires a fourth root.
            else:
                FR, s = (R.expand()**Rational(
                    1, 4)), sqrt((v[b] * R).expand() + d)
                return (s / (sqrt(2) * FR) + v[a] * FR /
                        (sqrt(2) * s)).expand(), f
Ejemplo n.º 27
0
def eval_sum_symbolic(f, limits):
    """Try to evaluate the sum of ``f`` over ``limits`` in closed form.

    ``limits`` is the triple ``(i, a, b)``: summation variable, lower bound
    and upper bound (both inclusive).

    The following strategies are tried in order:

    1. a summand independent of ``i`` is multiplied by the number of terms;
    2. linearity: constant factors are pulled out of products, and sums are
       evaluated term by term (after an attempt at telescoping);
    3. integer powers of ``i`` via Faulhaber's formula / harmonic numbers;
    4. geometric terms ``c1**(c2*i + c3)``;
    5. Gosper's algorithm;
    6. whatever ``eval_sum_hyper`` returns as the final fallback.

    Sub-results that are falsy (``None`` or zero) are treated as "not
    solved" by the linearity step and the next strategy is tried.
    """
    (i, a, b) = limits
    if not f.has(i):
        return f * (b - a + 1)

    # Linearity
    if f.is_Mul:
        L, R = f.as_two_terms()

        if not L.has(i):
            sR = eval_sum_symbolic(R, (i, a, b))
            if sR:
                return L * sR

        if not R.has(i):
            sL = eval_sum_symbolic(L, (i, a, b))
            if sL:
                return R * sL

        try:
            f = apart(f, i)  # see if it becomes an Add
        except PolynomialError:
            pass

    if f.is_Add:
        L, R = f.as_two_terms()
        lrsum = telescopic(L, R, (i, a, b))

        if lrsum:
            return lrsum

        lsum = eval_sum_symbolic(L, (i, a, b))
        rsum = eval_sum_symbolic(R, (i, a, b))

        if None not in (lsum, rsum):
            return lsum + rsum

    # Polynomial terms with Faulhaber's formula
    n = Wild('n')
    result = f.match(i**n)

    if result is not None:
        n = result[n]

        if n.is_Integer:
            if n >= 0:
                return ((C.bernoulli(n + 1, b + 1) - C.bernoulli(n + 1, a)) /
                        (n + 1)).expand()
            elif a.is_Integer and a >= 1:
                if n == -1:
                    return C.harmonic(b) - C.harmonic(a - 1)
                else:
                    return C.harmonic(b, abs(n)) - C.harmonic(a - 1, abs(n))

    # Geometric terms
    c1 = C.Wild('c1', exclude=[i])
    c2 = C.Wild('c2', exclude=[i])
    c3 = C.Wild('c3', exclude=[i])

    e = f.match(c1**(c2 * i + c3))

    if e is not None:
        base = c1.subs(e)
        step = c2.subs(e)
        offset = c3.subs(e)
        ratio = base**step

        # A common ratio of exactly 1 makes every term equal to
        # base**offset; the closed form below would divide by zero.
        if ratio == S.One:
            return base**offset * (b - a + 1)

        # TODO: more general limit handling (a ratio that is only
        # symbolically equal to 1 is still not detected here)
        return (base**offset * (base**(a * step) - base**(step + b * step)) /
                (1 - ratio))

    r = gosper_sum(f, (i, a, b))
    if r not in (None, S.NaN):
        return r

    return eval_sum_hyper(f, (i, a, b))
Ejemplo n.º 28
0
    for m in xrange(n):
        s += L.subs(i, a + m) + R.subs(i, b - m)
    return s


def telescopic(L, R, (i, a, b)):
    '''Try to perform the summation of ``L + R`` using the telescopic
    property.

    ``L`` and ``R`` are the two terms of the summand and ``(i, a, b)`` is
    the summation variable with its lower and upper limits.  The function
    looks for a shift ``s`` such that ``L.subs(i, i + s) == -R``.

    Returns None if this is not possible (in particular when either term is
    itself an Add).

    NOTE(review): the tuple-parameter signature is Python 2 only syntax.
    NOTE(review): as shown here only a negative integer shift produces a
    result; the listing may be cut off after that branch - confirm against
    the full source.
    '''
    if L.is_Add or R.is_Add:
        return None
    s = None
    # First we try to find the shift using match.
    # Maybe this should go inside solve.
    k = Wild("k")
    sol = (-R).match(L.subs(i, i + k))
    if sol and k in sol:
        # Verify the candidate by substitution before accepting it.
        if L.subs(i, i + sol[k]) == -R:
            # match sometimes fails, e.g.
            # f(x+2).match(-f(x+k)) -> {k: -2 - 2x}
            s = sol[k]
    # Then we try to find the shift using solve.
    if not s or not s.is_Integer:
        m = Symbol("m")
        try:
            # Only the first solution is used; only ValueError is handled.
            s = solve(L.subs(i, i + m) + R, m)[0]
        except ValueError:
            pass
    if s and s.is_Integer:
        if s < 0:
            # Negative shift: swap the roles of the two terms.
            return telescopic_direct(R, L, abs(s), (i, a, b))
Ejemplo n.º 29
0
def tsolve(eq, sym):
    """
    Solves a transcendental equation with respect to the given
    symbol. Various equations containing mixed linear terms, powers,
    and logarithms, can be solved.

    ``eq`` may be an expression (taken to be equal to zero) or an
    ``Equality``. The result is a list of solutions. In most cases it
    holds a single solution, which is generally not unique. In some
    cases, a complex solution may be returned even though a real
    solution exists.

    Raises ``NotImplementedError`` when none of the strategies applies.

        >>> from sympy import tsolve, log
        >>> from sympy.abc import x

        >>> tsolve(3**(2*x+5)-4, x)
        [(-5*log(3) + log(4))/(2*log(3))]

        >>> tsolve(log(x) + 2*x, x)
        [LambertW(2)/2]

    """
    # The pattern table is built lazily on first use.
    if patterns is None:
        _generate_patterns()
    eq = sympify(eq)
    if isinstance(eq, Equality):
        eq = eq.lhs - eq.rhs
    sym = sympify(sym)
    # Work on a copy expressed in the module-level placeholder ``x`` so the
    # module-level wildcards (which exclude ``x``) can be used for matching.
    eq2 = eq.subs(sym, x)
    # First see if the equation has a linear factor
    # In that case, the other factor can contain x in any way (as long as it
    # is finite), and we have a direct solution to which we add others that
    # may be found for the remaining portion.
    r = Wild('r')
    m = eq2.match((a * x + b) * r)
    if m and m[a]:
        return [(-b / a).subs(m).subs(x, sym)] + solve(m[r], x)
    for p, sol in patterns:
        m = eq2.match(p)
        if m:
            return [sol.subs(m).subs(x, sym)]

    # let's also try to inverse the equation
    lhs = eq
    rhs = S.Zero

    # Repeatedly move everything independent of ``sym`` to the right-hand
    # side until nothing more can be separated.
    while True:
        indep, dep = lhs.as_independent(sym)

        # dep + indep == rhs
        if lhs.is_Add:
            # this indicates we have done it all
            if indep is S.Zero:
                break

            lhs = dep
            rhs -= indep

        # dep * indep == rhs
        else:
            # this indicates we have done it all
            if indep is S.One:
                break

            lhs = dep
            rhs /= indep

    #                    -1
    # f(x) = g  ->  x = f  (g)
    if lhs.is_Function and lhs.nargs == 1 and hasattr(lhs, 'inverse'):
        rhs = lhs.inverse()(rhs)
        lhs = lhs.args[0]

        return solve(lhs - rhs, sym)

    elif lhs.is_Add:
        # just a simple case - we do variable substitution for first function,
        # and if it removes all functions - let's call solve.
        #      x    -x                   -1
        # UC: e  + e   = y      ->  t + t   = y
        t = Dummy('t')

        # find first term which is Function
        for f1 in lhs.args:
            if f1.is_Function:
                break
        else:
            raise NotImplementedError(
                "Unable to solve the equation "
                "(tsolve: at least one Function expected at this point)")

        # perform the substitution
        lhs_ = lhs.subs(f1, t)

        # if no Functions left, we can proceed with usual solve
        if not (lhs_.is_Function or any(term.is_Function
                                        for term in lhs_.args)):
            cv_sols = solve(lhs_ - rhs, t)
            # The substitution only helps if the solutions for ``t`` no
            # longer depend on the original symbol.
            if any(cv_sol.has(sym) for cv_sol in cv_sols):
                raise NotImplementedError("Unable to solve the equation")
            # Invert the substitution t = f1(sym) and map every solution back.
            cv_inv = solve(t - f1, sym)[0]
            return [cv_inv.subs(t, cv_sol) for cv_sol in cv_sols]

    raise NotImplementedError("Unable to solve the equation.")
Ejemplo n.º 30
0
 def _eval_nseries(self, x, x0, n):
     """Compute the series of this function of ``self.args[0]`` in ``x``
     up to the order term ``O(x**n)``.

     NOTE(review): judging by the inline examples this is presumably the
     method of the ``log`` class - confirm.  ``x0`` is only forwarded to the
     recursive series calls; it is not otherwise used here.
     """
     from sympy import powsimp
     arg = self.args[0]
     # Shortcut: an argument of the form k*x**l (k, l free of x, l != 0)
     # expands exactly to log(k) + l*log(x).
     k, l = Wild("k"), Wild("l")
     r = arg.match(k*x**l)
     if r is not None:
         k, l = r[k], r[l]
         if l != 0 and not l.has(x) and not k.has(x):
             r = log(k) + l*log(x)
             return r
     order = C.Order(x**n, x)
     arg = self.args[0]
     x = order.symbols[0]
     ln = C.log
     # Use the leading-term approach when arg is not O(1) or when it tends
     # to zero at x = 0.
     use_lt = not C.Order(1,x).contains(arg)
     if not use_lt:
         arg0 = arg.limit(x, 0)
         use_lt = (arg0 is S.Zero)
     if use_lt: # singularity; example: self = log(sin(x))
         # arg = (arg / lt) * lt
         lt = arg.as_leading_term(x) # arg = sin(x); lt = x
         a = powsimp((arg/lt).expand(), deep=True, combine='exp') # a = sin(x)/x
         # the idea is to recursively call ln(a).series(), but one needs to
         # make sure that ln(sin(x)/x) doesn't get "simplified" to
         # -log(x)+ln(sin(x)) and an infinite recursion occurs, see also the
         # issue 252.
         obj = ln(lt) + ln(a)._eval_nseries(x, x0, n)
     else:
         # arg -> arg0 + (arg - arg0) -> arg0 * (1 + (arg/arg0 - 1))
         z = (arg/arg0 - 1)
         x = order.symbols[0]
         ln = C.log
         o = C.Order(z, x)
         # z has vanishing order: return the unexpanded closed form.
         if o is S.Zero:
             return ln(1+z)+ ln(arg0)
         # Estimate how many Taylor terms of ln(1+z) are needed to reach
         # the requested order.
         if o.expr.is_number:
             e = ln(order.expr*x)/ln(x)
         else:
             e = ln(order.expr)/ln(o.expr)
         n = e.limit(x,0) + 1
         if n.is_unbounded:
             # requested accuracy gives infinite series,
             # order is probably nonpolynomial e.g. O(exp(-1/x), x).
             return ln(1+z)+ ln(arg0)
         try:
             n = int(n)
         except TypeError:
             # well, the n is something more complicated (like 1+log(2))
             n = int(n.evalf()) + 1
         # NOTE(review): backtick repr is Python 2 only syntax.
         assert n>=0,`n`
         # Accumulate the Taylor terms of ln(1+z), each expanded in x.
         l = []
         g = None
         for i in xrange(n+2):
             g = ln.taylor_term(i, z, g)
             g = g.nseries(x, x0, n)
             l.append(g)
         obj = C.Add(*l) + ln(arg0)
     # Simplify the result; if that changed anything, expand once more.
     obj2 = expand_log(powsimp(obj, deep=True, combine='exp'))
     if obj2 != obj:
         r = obj2.nseries(x, x0, n)
     else:
         r = obj
     # Nothing was expanded: return self without attaching an order term.
     if r == self:
         return self
     return r + order
Ejemplo n.º 31
0
    def _eval_integral(self, f, x):
        """Compute an antiderivative of ``f`` with respect to ``x``.

        Returns the antiderivative, or ``None`` when one of the additive
        terms of ``f`` cannot be integrated by any of the methods below.

        The long-term plan for this routine is to be able to integrate
        everything that can be integrated; anything it cannot handle yet
        should be easy to add.  The intended layers are:

        (1) Simple heuristics (based on pattern matching and integral table):

         - most frequently used functions (e.g. polynomials)
         - functions non-integrable by any of the following algorithms (e.g.
           exp(-x**2))

        (2) Integration of rational functions:

         (a) using apart() - apart() is a full partial fraction decomposition
         procedure based on the Bronstein-Salvy algorithm. It gives a formal
         decomposition with no polynomial factorization at all (so it's fast
         and gives the most general results). However it needs a much better
         implementation of the RootsOf class (in fact any implementation).
         (b) using Trager's algorithm - possibly faster than (a) but needs
         implementation :)

        (3) Whichever implementation of pmInt (Mateusz, Kirill's or a
        combination of both).

          - this way we can efficiently handle a huge class of elementary and
            special functions

        (4) Recursive Risch algorithm as described in Bronstein's integration
        tutorial.

          - this way we can handle those integrable functions for which (3)
            fails

        (5) Powerful heuristics based mostly on user defined rules.

         - handle complicated, rarely used cases
        """

        # A Poly integrates itself (fast).  This check has to come first,
        # otherwise the code below would return a sympy expression instead
        # of a Polynomial; see Polynomial for details.
        if isinstance(f, Poly):
            return f.integrate(x)

        # Piecewise antiderivatives go through their own special integrate.
        if f.func is Piecewise:
            return f._eval_integral(x)

        # Short cut: `f` does not depend on `x` at all.
        if not f.has(x):
            return f*x

        # If `f` converts to a poly in `x`, integrate that (fast).
        as_poly = f.as_poly(x)
        if as_poly is not None:
            return as_poly.integrate().as_basic()

        # Integral(g1 + g2 + ...) == Integral(g1) + Integral(g2) + ...,
        # so every Add term is handled on its own; a non-Add `f` simply
        # yields a single term.
        summands = Add.make_args(f)
        single_term = len(summands) == 1
        antiderivatives = []

        for summand in summands:
            coeff, g = summand.as_independent(x)

            # g(x) = const
            if g is S.One:
                antiderivatives.append(coeff*x)
                continue

            #               c
            # g(x) = (a*x+b)
            if g.is_Pow and not g.exp.has(x):
                a = Wild('a', exclude=[x])
                b = Wild('b', exclude=[x])
                M = g.base.match(a*x + b)

                if M is not None:
                    if g.exp == -1:
                        h = C.log(g.base)
                    else:
                        h = g.base**(g.exp + 1) / (g.exp + 1)
                    antiderivatives.append(coeff * h / M[a])
                    continue

            #        poly(x)
            # g(x) = -------
            #        poly(x)
            if g.is_rational_function(x):
                antiderivatives.append(coeff * ratint(g, x))
                continue

            # g(x) = Mul(trig)
            h = trigintegrate(g, x)

            # g(x) has at least a DiracDelta term
            if h is None:
                h = deltaintegrate(g, x)

            if h is None:
                # Fall back to the more general algorithm.
                h = heurisch(g, x, hints=[])

                # A failure may be due to a product that could have been
                # expanded, so try expanding the whole thing before giving
                # up.  This is not done at the outset because some
                # integrands can only be solved when they are NOT expanded,
                # e.g. x**x*(1+log(x)).  There should probably be a checker
                # somewhere in this routine that looks for such cases and
                # tries collecting the expression if it is already expanded.
                if not h and single_term:
                    expanded = f.expand(mul=True, deep=False)
                    if expanded.is_Add:
                        return self._eval_integral(expanded, x)

                if h is None:
                    return None

            antiderivatives.append(coeff * h)

        return Add(*antiderivatives)
Ejemplo n.º 32
0
def trigsimp_nonrecursive(expr, deep=False):
    """
    A nonrecursive trig simplifier, used from trigsimp.

    == Usage ==
        trigsimp_nonrecursive(expr) -> reduces expression by using known trig
                                       identities

    == Notes ==

    deep ........ apply trigsimp inside functions; every argument of a
                  function is simplified, not only the first one

    == Examples ==
        >>> from sympy import *
        >>> x = Symbol('x')
        >>> y = Symbol('y')
        >>> e = 2*sin(x)**2 + 2*cos(x)**2
        >>> trigsimp(e)
        2
        >>> trigsimp_nonrecursive(log(e))
        log(2*cos(x)**2 + 2*sin(x)**2)
        >>> trigsimp_nonrecursive(log(e), deep=True)
        log(2)
    """
    from sympy.core.basic import S
    sin, cos, tan, cot = C.sin, C.cos, C.tan, C.cot

    if expr.is_Function:
        if deep:
            # Rebuild the function from ALL of its simplified arguments so
            # that multi-argument functions keep their full argument list.
            return expr.func(*[trigsimp_nonrecursive(arg, deep)
                               for arg in expr.args])
    elif expr.is_Mul:
        # Simplify factor by factor.
        ret = S.One
        for factor in expr.args:
            ret *= trigsimp_nonrecursive(factor, deep)

        return ret
    elif expr.is_Pow:
        return Pow(trigsimp_nonrecursive(expr.base, deep),
                   trigsimp_nonrecursive(expr.exp, deep))
    elif expr.is_Add:
        # TODO this needs to be faster

        # The types of trig functions we are looking for
        a, b, c = map(Wild, 'abc')
        matchers = ((a * sin(b)**2, a - a * cos(b)**2),
                    (a * tan(b)**2, a * (1 / cos(b))**2 - a),
                    (a * cot(b)**2, a * (1 / sin(b))**2 - a))

        # Scan for the terms we need: the first matching pattern rewrites
        # the term, unmatched terms are added back unchanged.
        ret = S.Zero
        for term in expr.args:
            term = trigsimp_nonrecursive(term, deep)
            res = None
            for pattern, result in matchers:
                res = term.match(pattern)
                if res is not None:
                    ret += result.subs(res)
                    break
            if res is None:
                ret += term

        # Reduce any lingering artifacts, such as sin(x)**2 changing
        # to 1-cos(x)**2 when sin(x)**2 was "simpler"
        artifacts = ((a - a * cos(b)**2 + c, a * sin(b)**2 + c, cos),
                     (a - a * (1 / cos(b))**2 + c, -a * tan(b)**2 + c, cos),
                     (a - a * (1 / sin(b))**2 + c, -a * cot(b)**2 + c, sin))

        expr = ret
        for pattern, result, ex in artifacts:
            # Substitute a new wild that excludes some function(s)
            # to help influence a better match. This is because
            # sometimes, for example, 'a' would match sec(x)**2
            a_t = Wild('a', exclude=[ex])
            pattern = pattern.subs(a, a_t)
            result = result.subs(a, a_t)
            if expr.is_number:
                continue

            m = expr.match(pattern)
            while m is not None:
                # Stop on degenerate matches that would not make progress.
                if m[a_t] == 0 or -m[a_t] in m[c].args or m[a_t] + m[c] == 0:
                    break
                expr = result.subs(m)
                m = expr.match(pattern)

        return expr
    return expr