Ejemplo n.º 1
def mrv(e, x):
    """Return a python set of most rapidly varying (mrv) subexpressions of 'e'.

    ``e`` is first normalised with ``powsimp(e, deep=True, combine='exp')``
    and must then be a ``Basic`` instance.  The set is built recursively:

    * an expression without ``x`` gives the empty set;
    * ``x`` itself gives ``set([x])``;
    * ``Mul``/``Add`` are split with ``as_two_terms()`` and the two partial
      results are merged with ``mrv_max``;
    * a power whose exponent contains ``x`` is rewritten as
      ``exp(exponent*log(base))``; otherwise only the base is examined;
    * ``log(arg)`` defers to ``arg``;
    * ``exp(arg)`` is itself a candidate when ``limitinf(arg, x)`` is ``oo``
      or ``-oo``; otherwise only ``arg`` is examined;
    * any other function of exactly one argument defers to that argument.

    Raises NotImplementedError for functions with several arguments and for
    any expression kind not listed above.
    """
    e = powsimp(e, deep=True, combine='exp')
    assert isinstance(e, Basic)
    if not e.has(x):
        return set([])
    elif e == x:
        return set([x])
    elif e.is_Mul or e.is_Add:
        # products and sums are handled identically: split in two and merge
        a, b = e.as_two_terms()
        return mrv_max(mrv(a, x), mrv(b, x), x)
    elif e.is_Pow:
        if e.exp.has(x):
            return mrv(exp(e.exp * log(e.base)), x)
        else:
            # constant exponent: only the base can vary with x
            return mrv(e.base, x)
    elif e.func is log:
        return mrv(e.args[0], x)
    elif e.func is exp:
        if limitinf(e.args[0], x) in [oo, -oo]:
            return mrv_max(set([e]), mrv(e.args[0], x), x)
        else:
            # exp of something without an infinite limit is not a candidate
            return mrv(e.args[0], x)
    elif e.is_Function:
        if len(e.args) == 1:
            return mrv(e.args[0], x)
        #only functions of 1 argument currently implemented
        raise NotImplementedError("Functions with more arguments: '%s'" % e)
    raise NotImplementedError("Don't know how to calculate the mrv of '%s'" % e)
Ejemplo n.º 2
def mrv(e, x):
    """Returns a SubsSet of most rapidly varying (mrv) subexpressions of 'e',
       and e rewritten in terms of these.

    The return value is a pair ``(subs_set, rewritten_expr)``.  ``e`` is
    first normalised with ``powsimp(e, deep=True, combine='exp')`` and must
    then be a ``Basic`` instance.

    Raises NotImplementedError for functions where the number of arguments
    with a non-empty mrv set is not exactly one, for derivatives, and for
    any expression kind that is not handled.
    """
    e = powsimp(e, deep=True, combine='exp')
    assert isinstance(e, Basic)
    if not e.has(x):
        return SubsSet(), e
    elif e == x:
        s = SubsSet()
        return s, s[x]
    elif e.is_Mul or e.is_Add:
        i, d = e.as_independent(x)  # throw away x-independent terms
        if d.func != e.func:
            s, expr = mrv(d, x)
            return s, e.func(i, expr)
        a, b = d.as_two_terms()
        s1, e1 = mrv(a, x)
        s2, e2 = mrv(b, x)
        return mrv_max1(s1, s2, e.func(i, e1, e2), x)
    elif e.is_Pow:
        base, expo = e.as_base_exp()
        if expo.has(x):
            return mrv(exp(expo * log(base)), x)
        else:
            # constant exponent: recurse into the base only
            s, expr = mrv(base, x)
            return s, expr**expo
    elif e.func is log:
        s, expr = mrv(e.args[0], x)
        return s, log(expr)
    elif e.func is exp:
        # We know from the theory of this algorithm that exp(log(...)) may always
        # be simplified here, and doing so is vital for termination.
        if e.args[0].func is log:
            return mrv(e.args[0].args[0], x)
        if limitinf(e.args[0], x).is_unbounded:
            s1 = SubsSet()
            e1 = s1[e]
            s2, e2 = mrv(e.args[0], x)
            su = s1.union(s2)[0]
            su.rewrites[e1] = exp(e2)
            return mrv_max3(s1, e1, s2, exp(e2), su, e1, x)
        else:
            s, expr = mrv(e.args[0], x)
            return s, exp(expr)
    elif e.is_Function:
        l = [mrv(a, x) for a in e.args]
        l2 = [s for (s, _) in l if s != SubsSet()]
        if len(l2) != 1:
            # e.g. something like BesselJ(x, x)
            raise NotImplementedError("MRV set computation for functions in"
                                      " several variables not implemented.")
        s, ss = l2[0], SubsSet()
        # the loop variable must not be called ``x``: that would shadow
        # (and, on Python 2, overwrite) the limit variable
        args = [ss.do_subs(pair[1]) for pair in l]
        return s, e.func(*args)
    elif e.is_Derivative:
        raise NotImplementedError("MRV set computation for derviatives"
                                  " not implemented yet.")
    raise NotImplementedError("Don't know how to calculate the mrv of '%s'" % e)
Ejemplo n.º 3
    def doit(self):
        """Evaluate the product where possible.

        Returns ``powsimp(prod)`` when ``self._eval_product()`` produces a
        result, and the unevaluated object itself when it produces ``None``.
        """
        prod = self._eval_product()

        if prod is not None:
            return powsimp(prod)
        else:
            # nothing could be computed; hand back the unevaluated object
            return self
Ejemplo n.º 4
def _solsimp(e, t):
    """Simplify ``e`` separately in its ``t``-free and ``t``-dependent parts.

    ``e`` is expanded with ``expand_mul``, passed through ``powsimp`` and
    split with ``as_independent(t)``.  The ``t``-free part goes through
    ``ratsimp``; in the other part every ``exp(a)`` is replaced by
    ``exp(factor_terms(a))``.  The sum of both parts is returned.
    """
    def _factor_exp_argument(arg):
        return exp(factor_terms(arg))

    simplified = powsimp(expand_mul(e))
    independent_part, dependent_part = simplified.as_independent(t)

    independent_part = ratsimp(independent_part)
    dependent_part = dependent_part.replace(exp, _factor_exp_argument)

    return independent_part + dependent_part
Ejemplo n.º 5
    def doit(self):
        """Evaluate the product where possible.

        Returns ``powsimp(prod)`` when ``self._eval_product()`` produces a
        result, and the unevaluated object itself when it produces ``None``.
        """
        prod = self._eval_product()

        if prod is not None:
            return powsimp(prod)
        else:
            # nothing could be computed; hand back the unevaluated object
            return self
Ejemplo n.º 6
def mrv(e, x):
    """Return a python set of most rapidly varying (mrv) subexpressions of 'e'.

    ``e`` is first normalised with ``powsimp(e, deep=True, combine='exp')``
    and must then be a ``Basic`` instance.  The set is built recursively:

    * an expression without ``x`` gives the empty set;
    * ``x`` itself gives ``set([x])``;
    * ``Mul``/``Add`` are split with ``as_two_terms()`` and the two partial
      results are merged with ``mrv_max``;
    * a power whose exponent contains ``x`` is rewritten as
      ``exp(exponent*log(base))``; otherwise only the base is examined;
    * ``log(arg)`` defers to ``arg``;
    * ``exp(arg)`` is itself a candidate when ``limitinf(arg, x)`` is ``oo``
      or ``-oo``; otherwise only ``arg`` is examined;
    * any other function of exactly one argument defers to that argument.

    Raises NotImplementedError for functions with several arguments and for
    any expression kind not listed above.
    """
    e = powsimp(e, deep=True, combine='exp')
    assert isinstance(e, Basic)
    if not e.has(x):
        return set([])
    elif e == x:
        return set([x])
    elif e.is_Mul or e.is_Add:
        # products and sums are handled identically: split in two and merge
        a, b = e.as_two_terms()
        return mrv_max(mrv(a, x), mrv(b, x), x)
    elif e.is_Pow:
        if e.exp.has(x):
            return mrv(exp(e.exp * log(e.base)), x)
        else:
            # constant exponent: only the base can vary with x
            return mrv(e.base, x)
    elif e.func is log:
        return mrv(e.args[0], x)
    elif e.func is exp:
        if limitinf(e.args[0], x) in [oo, -oo]:
            return mrv_max(set([e]), mrv(e.args[0], x), x)
        else:
            # exp of something without an infinite limit is not a candidate
            return mrv(e.args[0], x)
    elif e.is_Function:
        if len(e.args) == 1:
            return mrv(e.args[0], x)
        #only functions of 1 argument currently implemented
        raise NotImplementedError("Functions with more arguments: '%s'" % e)
    raise NotImplementedError("Don't know how to calculate the mrv of '%s'" % e)
Ejemplo n.º 7
def mrv(e, x):
    """Returns a SubsSet of most rapidly varying (mrv) subexpressions of 'e',
       and e rewritten in terms of these.

    The return value is a pair ``(subs_set, rewritten_expr)``.  ``e`` is
    first normalised with ``powsimp(e, deep=True, combine='exp')`` and must
    then be a ``Basic`` instance.

    Raises NotImplementedError for functions where the number of arguments
    with a non-empty mrv set is not exactly one, for derivatives, and for
    any expression kind that is not handled.
    """
    e = powsimp(e, deep=True, combine='exp')
    assert isinstance(e, Basic)
    if not e.has(x):
        return SubsSet(), e
    elif e == x:
        s = SubsSet()
        return s, s[x]
    elif e.is_Mul or e.is_Add:
        i, d = e.as_independent(x)  # throw away x-independent terms
        if d.func != e.func:
            s, expr = mrv(d, x)
            return s, e.func(i, expr)
        a, b = d.as_two_terms()
        s1, e1 = mrv(a, x)
        s2, e2 = mrv(b, x)
        return mrv_max1(s1, s2, e.func(i, e1, e2), x)
    elif e.is_Pow:
        base, expo = e.as_base_exp()
        if expo.has(x):
            return mrv(exp(expo * log(base)), x)
        else:
            # constant exponent: recurse into the base only
            s, expr = mrv(base, x)
            return s, expr**expo
    elif e.func is log:
        s, expr = mrv(e.args[0], x)
        return s, log(expr)
    elif e.func is exp:
        # We know from the theory of this algorithm that exp(log(...)) may always
        # be simplified here, and doing so is vital for termination.
        if e.args[0].func is log:
            return mrv(e.args[0].args[0], x)
        if limitinf(e.args[0], x).is_unbounded:
            s1 = SubsSet()
            e1 = s1[e]
            s2, e2 = mrv(e.args[0], x)
            su = s1.union(s2)[0]
            su.rewrites[e1] = exp(e2)
            return mrv_max3(s1, e1, s2, exp(e2), su, e1, x)
        else:
            s, expr = mrv(e.args[0], x)
            return s, exp(expr)
    elif e.is_Function:
        l = [mrv(a, x) for a in e.args]
        l2 = [s for (s, _) in l if s != SubsSet()]
        if len(l2) != 1:
            # e.g. something like BesselJ(x, x)
            raise NotImplementedError("MRV set computation for functions in"
                                      " several variables not implemented.")
        s, ss = l2[0], SubsSet()
        # the loop variable must not be called ``x``: that would shadow
        # (and, on Python 2, overwrite) the limit variable
        args = [ss.do_subs(pair[1]) for pair in l]
        return s, e.func(*args)
    elif e.is_Derivative:
        raise NotImplementedError("MRV set computation for derviatives"
                                  " not implemented yet.")
    raise NotImplementedError(
        "Don't know how to calculate the mrv of '%s'" % e)
Ejemplo n.º 8
    def doit(self, **hints):
        """Evaluate the product one limit at a time.

        For each ``(i, a, b)`` in ``self.limits``:

        * if ``b - a`` is a negative Integer, ``1`` is returned immediately;
        * otherwise ``self._eval_product`` is applied to the current
          function; if it yields ``None`` a ``Product`` over the remaining
          limits is returned, else its result is carried to the next limit.

        After all limits are processed the result is evaluated further with
        ``doit(**hints)`` when the ``deep`` hint is true (the default), and
        passed through ``powsimp`` otherwise.
        """
        f = self.function
        for index, limit in enumerate(self.limits):
            i, a, b = limit
            dif = b - a
            if dif.is_Integer and dif < 0:
                return 1
            g = self._eval_product(f, (i, a, b))
            if g is None:
                return Product(powsimp(f), *self.limits[index:])
            else:
                # carry the evaluated result into the next limit
                f = g

        if hints.get('deep', True):
            return f.doit(**hints)
        else:
            return powsimp(f)
Ejemplo n.º 9
    def doit(self, **hints):
        """Evaluate the product one limit at a time.

        For each ``(i, a, b)`` in ``self.limits`` the bounds are swapped when
        ``b - a`` is a negative Integer, then ``self._eval_product`` is
        applied to the current function.  If it yields ``None`` a ``Product``
        over the remaining limits is returned; otherwise its result is
        carried to the next limit.

        After all limits are processed the result is evaluated further with
        ``doit(**hints)`` when the ``deep`` hint is true (the default), and
        passed through ``powsimp`` otherwise.
        """
        f = self.function
        for index, limit in enumerate(self.limits):
            i, a, b = limit
            dif = b - a
            if dif.is_Integer and dif < 0:
                a, b = b, a

            g = self._eval_product(f, (i, a, b))
            if g is None:
                return Product(powsimp(f), *self.limits[index:])
            else:
                # carry the evaluated result into the next limit
                f = g

        if hints.get('deep', True):
            return f.doit(**hints)
        else:
            return powsimp(f)
Ejemplo n.º 10
def rewrite(e, Omega, x, wsym):
    """e(x) ... the function
    Omega ... the mrv set
    wsym ... the symbol which is going to be used for w

    Returns the rewritten e in terms of w and log(w). See test_rewrite1()
    for examples and correct results.

    Raises NotImplementedError when the sign of the exponent of the chosen
    "w" is neither 1 nor -1.
    """
    assert isinstance(Omega, SubsSet)
    assert len(Omega) != 0
    #all items in Omega must be exponentials
    for t in Omega.keys():
        assert t.func is exp
    rewrites = Omega.rewrites
    # build a real list so the pairs can be sorted in place even where
    # items() returns a view rather than a list
    Omega = list(Omega.items())

    nodes = build_expression_tree(Omega, rewrites)
    Omega.sort(key=lambda item: nodes[item[1]].ht(), reverse=True)

    #g is going to be the "w" - the simplest one in the mrv set
    g, _ = Omega[-1]
    sig = sign(g.args[0], x)
    if sig == 1:
        wsym = 1 / wsym  #if g goes to oo, substitute 1/w
    elif sig != -1:
        raise NotImplementedError('Result depends on the sign of %s' % sig)
    #O2 is a list, which results by rewriting each item in Omega using "w"
    O2 = []
    for f, var in Omega:
        c = limitinf(f.args[0] / g.args[0], x)
        arg = f.args[0]
        if var in rewrites:
            assert rewrites[var].func is exp
            arg = rewrites[var].args[0]
        O2.append((var, exp((arg - c * g.args[0]).expand()) * wsym**c))

    #Remember that Omega contains subexpressions of "e". So now we find
    #them in "e" and substitute them for our rewriting, stored in O2

    # the following powsimp is necessary to automatically combine exponentials,
    # so that the .subs() below succeeds:
    # TODO this should not be necessary
    f = powsimp(e, deep=True, combine='exp')
    for a, b in O2:
        f = f.subs(a, b)

    for _, var in Omega:
        assert not f.has(var)

    #finally compute the logarithm of w (logw).
    logw = g.args[0]
    if sig == 1:
        logw = -logw  #log(w)->log(1/w)=-log(w)

    return f, logw
Ejemplo n.º 11
Archivo: gruntz.py Proyecto: fxkr/sympy
def rewrite(e, Omega, x, wsym):
    """e(x) ... the function
    Omega ... the mrv set
    wsym ... the symbol which is going to be used for w

    Returns the rewritten e in terms of w and log(w). See test_rewrite1()
    for examples and correct results.

    Raises NotImplementedError when the sign of the exponent of the chosen
    "w" is neither 1 nor -1.
    """
    assert isinstance(Omega, SubsSet)
    assert len(Omega) != 0
    #all items in Omega must be exponentials
    for t in Omega.keys():
        assert t.func is exp
    rewrites = Omega.rewrites
    # build a real list so the pairs can be sorted in place even where
    # items() returns a view rather than a list
    Omega = list(Omega.items())

    nodes = build_expression_tree(Omega, rewrites)
    Omega.sort(key=lambda item: nodes[item[1]].ht(), reverse=True)

    #g is going to be the "w" - the simplest one in the mrv set
    g, _ = Omega[-1]
    sig = sign(g.args[0], x)
    if sig == 1:
        wsym = 1/wsym #if g goes to oo, substitute 1/w
    elif sig != -1:
        raise NotImplementedError('Result depends on the sign of %s' % sig)
    #O2 is a list, which results by rewriting each item in Omega using "w"
    O2 = []
    for f, var in Omega:
        c = limitinf(f.args[0]/g.args[0], x)
        arg = f.args[0]
        if var in rewrites:
            assert rewrites[var].func is exp
            arg = rewrites[var].args[0]
        O2.append((var, exp((arg - c*g.args[0]).expand())*wsym**c))

    #Remember that Omega contains subexpressions of "e". So now we find
    #them in "e" and substitute them for our rewriting, stored in O2

    # the following powsimp is necessary to automatically combine exponentials,
    # so that the .subs() below succeeds:
    # TODO this should not be necessary
    f = powsimp(e, deep=True, combine='exp')
    for a, b in O2:
        f = f.subs(a, b)

    for _, var in Omega:
        assert not f.has(var)

    #finally compute the logarithm of w (logw).
    logw = g.args[0]
    if sig == 1:
        logw = -logw     #log(w)->log(1/w)=-log(w)

    return f, logw
Ejemplo n.º 12
    def doit(self, **hints):
        """Evaluate the product one limit at a time.

        For each ``(i, a, b)`` in ``self.limits``:

        * when ``b - a`` is a negative Integer the bounds become
          ``(b + 1, a - 1)`` and the function is replaced by its reciprocal;
        * an ``Idx`` index is replaced by its ``label``;
        * ``self._eval_product`` is applied; if it yields ``None`` or
          ``S.NaN`` an unevaluated ``self.func`` over the remaining limits is
          returned, otherwise its result is carried to the next limit.

        After all limits are processed the result is evaluated further with
        ``doit(**hints)`` when the ``deep`` hint is true (the default), and
        passed through ``powsimp`` otherwise.
        """
        f = self.function

        for index, limit in enumerate(self.limits):
            i, a, b = limit
            dif = b - a
            if dif.is_Integer and dif < 0:
                a, b = b + 1, a - 1
                f = 1 / f
            if isinstance(i, Idx):
                i = i.label

            g = self._eval_product(f, (i, a, b))
            if g in (None, S.NaN):
                return self.func(powsimp(f), *self.limits[index:])
            else:
                # carry the evaluated result into the next limit
                f = g

        if hints.get('deep', True):
            return f.doit(**hints)
        else:
            return powsimp(f)
Ejemplo n.º 13
    def doit(self, **hints):
        """Evaluate the product one limit at a time.

        For each ``(i, a, b)`` in ``self.limits``:

        * when ``b - a`` is a negative Integer the bounds become
          ``(b + 1, a - 1)`` and the function is replaced by its reciprocal;
        * an ``Idx`` index is replaced by its ``label``;
        * ``self._eval_product`` is applied; if it yields ``None`` or
          ``S.NaN`` an unevaluated ``self.func`` over the remaining limits is
          returned, otherwise its result is carried to the next limit.

        After all limits are processed the result is evaluated further with
        ``doit(**hints)`` when the ``deep`` hint is true (the default), and
        passed through ``powsimp`` otherwise.
        """
        f = self.function

        for index, limit in enumerate(self.limits):
            i, a, b = limit
            dif = b - a
            if dif.is_Integer and dif < 0:
                a, b = b + 1, a - 1
                f = 1 / f
            if isinstance(i, Idx):
                i = i.label

            g = self._eval_product(f, (i, a, b))
            if g in (None, S.NaN):
                return self.func(powsimp(f), *self.limits[index:])
            else:
                # carry the evaluated result into the next limit
                f = g

        if hints.get('deep', True):
            return f.doit(**hints)
        else:
            return powsimp(f)
Ejemplo n.º 14
    def doit(self, **hints):
        """Evaluate the product one limit at a time.

        If ``_dummy_with_inherited_properties_concrete`` supplies a
        replacement for any limit variable, the product is evaluated with
        those replacements applied and they are undone in the result (element
        by element when the result is a tuple).

        Otherwise, for each ``(i, a, b)`` in ``self.limits``: when ``b - a``
        is a negative integer the bounds become ``(b + 1, a - 1)`` and the
        function is replaced by its reciprocal; then ``self._eval_product``
        is applied.  If it yields ``None`` or ``S.NaN`` an unevaluated
        ``self.func`` over the remaining limits is returned, otherwise its
        result is carried to the next limit.

        The final result is evaluated further with ``doit(**hints)`` when the
        ``deep`` hint is true (the default), and passed through ``powsimp``
        otherwise.
        """
        # Must be imported here to avoid circular imports
        from .summations import _dummy_with_inherited_properties_concrete

        # first make sure any definite limits have product
        # variables with matching assumptions
        reps = {}
        for xab in self.limits:
            d = _dummy_with_inherited_properties_concrete(xab)
            if d:
                reps[xab[0]] = d
        if reps:
            undo = dict([(v, k) for k, v in reps.items()])
            did = self.xreplace(reps).doit(**hints)
            if type(did) is tuple:  # when separate=True
                did = tuple([i.xreplace(undo) for i in did])
            else:
                # a plain tuple has no xreplace, so this must be exclusive
                did = did.xreplace(undo)
            return did

        f = self.function
        for index, limit in enumerate(self.limits):
            i, a, b = limit
            dif = b - a
            if dif.is_integer and dif.is_negative:
                a, b = b + 1, a - 1
                f = 1 / f

            g = self._eval_product(f, (i, a, b))
            if g in (None, S.NaN):
                return self.func(powsimp(f), *self.limits[index:])
            else:
                # carry the evaluated result into the next limit
                f = g

        if hints.get("deep", True):
            return f.doit(**hints)
        else:
            return powsimp(f)
Ejemplo n.º 15
    def doit(self, **hints):
        """Evaluate the product where possible.

        When the ``deep`` hint is true (the default) the term and both
        bounds are evaluated with ``doit(**hints)`` first.  Returns
        ``powsimp(prod)`` when ``self._eval_product(lower, upper, term)``
        produces a result, and the unevaluated object itself when it
        produces ``None``.
        """
        term = self.term
        lower = self.lower
        upper = self.upper
        if hints.get('deep', True):
            term = term.doit(**hints)
            lower = lower.doit(**hints)
            upper = upper.doit(**hints)

        prod = self._eval_product(lower, upper, term)

        if prod is not None:
            return powsimp(prod)
        else:
            # nothing could be computed; hand back the unevaluated object
            return self
Ejemplo n.º 16
    def doit(self, **hints):
        """Evaluate the product where possible.

        When the ``deep`` hint is true (the default) the term and both
        bounds are evaluated with ``doit(**hints)`` first.  Returns
        ``powsimp(prod)`` when ``self._eval_product(lower, upper, term)``
        produces a result, and the unevaluated object itself when it
        produces ``None``.
        """
        term = self.term
        lower = self.lower
        upper = self.upper
        if hints.get('deep', True):
            term = term.doit(**hints)
            lower = lower.doit(**hints)
            upper = upper.doit(**hints)

        prod = self._eval_product(lower, upper, term)

        if prod is not None:
            return powsimp(prod)
        else:
            # nothing could be computed; hand back the unevaluated object
            return self
Ejemplo n.º 17
def rewrite(e, Omega, x, wsym):
    """e(x) ... the function
    Omega ... the mrv set
    wsym ... the symbol which is going to be used for w

    Returns the rewritten e in terms of w and log(w). See test_rewrite1()
    for examples and correct results.
    """
    assert isinstance(Omega, set)
    assert len(Omega) != 0
    #all items in Omega must be exponentials
    for t in Omega:
        assert t.func is exp

    #sort Omega (mrv set) from the most complicated to the simplest ones
    #the complexity of "a" from Omega: the length of the mrv set of "a"
    # NOTE(review): the block defined a comparator for this ordering but
    # never applied it, leaving "the simplest one" below arbitrary; the sort
    # is performed here with an equivalent key.
    Omega = list(Omega)
    Omega.sort(key=lambda t: len(mrv(t, x)), reverse=True)
    g = Omega[-1]  #g is going to be the "w" - the simplest one in the mrv set
    sig = (sign(g.args[0], x) == 1)
    if sig:
        wsym = 1 / wsym  #if g goes to oo, substitute 1/w
    #O2 is a list, which results by rewriting each item in Omega using "w"
    O2 = []
    for f in Omega:
        c = mrv_leadterm(f.args[0] / g.args[0], x)
        #the c is a constant, because both f and g are from Omega:
        assert c[1] == 0
        O2.append(exp((f.args[0] - c[0] * g.args[0]).expand()) * wsym**c[0])

    #Remember that Omega contains subexpressions of "e". So now we find
    #them in "e" and substitute them for our rewriting, stored in O2

    # the following powsimp is necessary to automatically combine exponentials,
    # so that the .subs() below succeeds:
    f = powsimp(e, deep=True, combine='exp')
    for a, b in zip(Omega, O2):
        f = f.subs(a, b)

    #finally compute the logarithm of w (logw).
    logw = g.args[0]
    if sig:
        logw = -logw  #log(w)->log(1/w)=-log(w)

    return f, logw
Ejemplo n.º 18
def rewrite(e, Omega, x, wsym):
    """e(x) ... the function
    Omega ... the mrv set
    wsym ... the symbol which is going to be used for w

    Returns the rewritten e in terms of w and log(w). See test_rewrite1()
    for examples and correct results.
    """
    assert isinstance(Omega, set)
    assert len(Omega) != 0
    # all items in Omega must be exponentials
    for t in Omega:
        assert t.func is exp

    # sort Omega (mrv set) from the most complicated to the simplest ones
    # the complexity of "a" from Omega: the length of the mrv set of "a"
    # NOTE(review): the block defined a comparator for this ordering but
    # never applied it, leaving "the simplest one" below arbitrary; the sort
    # is performed here with an equivalent key.
    Omega = list(Omega)
    Omega.sort(key=lambda t: len(mrv(t, x)), reverse=True)
    g = Omega[-1]  # g is going to be the "w" - the simplest one in the mrv set
    sig = sign(g.args[0], x) == 1
    if sig:
        wsym = 1 / wsym  # if g goes to oo, substitute 1/w
    # O2 is a list, which results by rewriting each item in Omega using "w"
    O2 = []
    for f in Omega:
        c = mrv_leadterm(f.args[0] / g.args[0], x)
        # the c is a constant, because both f and g are from Omega:
        assert c[1] == 0
        O2.append(exp((f.args[0] - c[0] * g.args[0]).expand()) * wsym ** c[0])

    # Remember that Omega contains subexpressions of "e". So now we find
    # them in "e" and substitute them for our rewriting, stored in O2

    # the following powsimp is necessary to automatically combine exponentials,
    # so that the .subs() below succeeds:
    f = powsimp(e, deep=True, combine="exp")
    for a, b in zip(Omega, O2):
        f = f.subs(a, b)

    # finally compute the logarithm of w (logw).
    logw = g.args[0]
    if sig:
        logw = -logw  # log(w)->log(1/w)=-log(w)

    return f, logw
Ejemplo n.º 19
def _simpsol(soleq):
    """Return ``soleq`` with its right-hand side simplified term by term.

    The right-hand side is passed through ``powsimp`` and viewed as a
    polynomial in its ``exp`` atoms.  Each coefficient is folded with
    ``piecewise_fold``, simplified with ``ratsimp`` and collected on the
    symbols ``C1`` and ``C2`` (branch by branch for a ``Piecewise``), then
    multiplied by its monomial rebuilt from the ``factor_terms``-processed
    generators.  The result is ``Eq(lhs, sum_of_terms)``.
    """
    lhs = soleq.lhs
    sol = powsimp(soleq.rhs)
    gens = list(sol.atoms(exp))
    p = Poly(sol, *gens, expand=False)
    gens = [factor_terms(g) for g in gens]
    if not gens:
        gens = p.gens
    syms = [Symbol('C1'), Symbol('C2')]
    terms = []
    for coeff, monom in zip(p.coeffs(), p.monoms()):
        coeff = piecewise_fold(coeff)
        if type(coeff) is Piecewise:
            coeff = Piecewise(*((ratsimp(coef).collect(syms), cond) for coef, cond in coeff.args))
        else:
            # plain coefficients are simplified directly; without this
            # branch they would be left untouched
            coeff = ratsimp(coeff).collect(syms)
        monom = Mul(*(g ** i for g, i in zip(gens, monom)))
        terms.append(coeff * monom)
    return Eq(lhs, Add(*terms))
Ejemplo n.º 20
def mrv(e, x):
    """Return a python set of most rapidly varying (mrv) subexpressions of 'e'.

    ``e`` is first normalised with ``powsimp(e, deep=True, combine='exp')``
    and must then be a ``Basic`` instance.  Sums and products first discard
    their ``x``-independent part, powers with an ``x``-dependent exponent are
    rewritten through ``exp``/``log``, and ``exp(arg)`` is itself a candidate
    when ``limitinf(arg, x)`` is unbounded.  Other functions merge the mrv
    sets of all their arguments with ``mrv_max``; derivatives defer to their
    first argument.

    Raises NotImplementedError for expression kinds that are not handled.
    """
    e = powsimp(e, deep=True, combine='exp')
    assert isinstance(e, Basic)
    if not e.has(x):
        return set([])
    elif e == x:
        return set([x])
    elif e.is_Mul or e.is_Add:
        while True:
            i, d = e.as_independent(x)  # throw away x-independent terms
            if d.func != e.func and (d.is_Add or d.is_Mul):
                e = d
            else:
                # nothing more to strip; without this the loop never ends
                break
        if d.func != e.func:
            return mrv(d, x)
        a, b = d.as_two_terms()
        return mrv_max(mrv(a, x), mrv(b, x), x)
    elif e.is_Pow:
        base, expo = e.as_base_exp()
        if expo.has(x):
            return mrv(exp(expo * log(base)), x)
        else:
            return mrv(base, x)
    elif e.func is log:
        return mrv(e.args[0], x)
    elif e.func is exp:
        if limitinf(e.args[0], x).is_unbounded:
            return mrv_max(set([e]), mrv(e.args[0], x), x)
        else:
            return mrv(e.args[0], x)
    elif e.is_Function:
        # reduce is only a builtin on Python 2
        from functools import reduce
        return reduce(lambda a, b: mrv_max(a, b, x),
                      [mrv(a, x) for a in e.args])
    elif e.is_Derivative:
        return mrv(e.args[0], x)
    raise NotImplementedError("Don't know how to calculate the mrv of '%s'" % e)
Ejemplo n.º 21
    def doit(self, **hints):
        """Evaluate the product where possible.

        A term equal to 0 or 1 gives ``S.Zero`` or ``S.One`` directly.  When
        the ``deep`` hint is true (the default) the term and both bounds are
        evaluated with ``doit(**hints)``.  Bounds whose difference is a
        negative Number are swapped.  Returns ``powsimp(prod)`` when
        ``self._eval_product(lower, upper, term)`` produces a result, and the
        unevaluated object itself when it produces ``None``.
        """
        term = self.term
        if term == 0:
            return S.Zero
        elif term == 1:
            return S.One
        lower = self.lower
        upper = self.upper
        if hints.get('deep', True):
            term = term.doit(**hints)
            lower = lower.doit(**hints)
            upper = upper.doit(**hints)
        dif = upper - lower
        if dif.is_Number and dif < 0:
            upper, lower = lower, upper

        prod = self._eval_product(lower, upper, term)

        if prod is not None:
            return powsimp(prod)
        else:
            # nothing could be computed; hand back the unevaluated object
            return self
Ejemplo n.º 22
    def doit(self, **hints):
        """Evaluate the product where possible.

        A term equal to 0 or 1 gives ``S.Zero`` or ``S.One`` directly.  When
        the ``deep`` hint is true (the default) the term and both bounds are
        evaluated with ``doit(**hints)``.  Bounds whose difference is a
        negative Number are swapped.  Returns ``powsimp(prod)`` when
        ``self._eval_product(lower, upper, term)`` produces a result, and the
        unevaluated object itself when it produces ``None``.
        """
        term = self.term
        if term == 0:
            return S.Zero
        elif term == 1:
            return S.One
        lower = self.lower
        upper = self.upper
        if hints.get('deep', True):
            term = term.doit(**hints)
            lower = lower.doit(**hints)
            upper = upper.doit(**hints)
        dif = upper - lower
        if dif.is_Number and dif < 0:
            upper, lower = lower, upper

        prod = self._eval_product(lower, upper, term)

        if prod is not None:
            return powsimp(prod)
        else:
            # nothing could be computed; hand back the unevaluated object
            return self
Ejemplo n.º 23
def mrv(e, x):
    """Return a python set of most rapidly varying (mrv) subexpressions of 'e'.

    ``e`` is first normalised with ``powsimp(e, deep=True, combine="exp")``
    and must then be a ``Basic`` instance.  Sums and products first discard
    their ``x``-independent part, powers with an ``x``-dependent exponent are
    rewritten through ``exp``/``log``, and ``exp(arg)`` is itself a candidate
    when ``limitinf(arg, x)`` is unbounded.  Other functions merge the mrv
    sets of all their arguments with ``mrv_max``; derivatives defer to their
    first argument.

    Raises NotImplementedError for expression kinds that are not handled.
    """
    e = powsimp(e, deep=True, combine="exp")
    assert isinstance(e, Basic)
    if not e.has(x):
        return set([])
    elif e == x:
        return set([x])
    elif e.is_Mul or e.is_Add:
        while True:
            i, d = e.as_independent(x)  # throw away x-independent terms
            if d.func != e.func and (d.is_Add or d.is_Mul):
                e = d
            else:
                # nothing more to strip; without this the loop never ends
                break
        if d.func != e.func:
            return mrv(d, x)
        a, b = d.as_two_terms()
        return mrv_max(mrv(a, x), mrv(b, x), x)
    elif e.is_Pow:
        base, expo = e.as_base_exp()
        if expo.has(x):
            return mrv(exp(expo * log(base)), x)
        else:
            return mrv(base, x)
    elif e.func is log:
        return mrv(e.args[0], x)
    elif e.func is exp:
        if limitinf(e.args[0], x).is_unbounded:
            return mrv_max(set([e]), mrv(e.args[0], x), x)
        else:
            return mrv(e.args[0], x)
    elif e.is_Function:
        # reduce is only a builtin on Python 2
        from functools import reduce
        return reduce(lambda a, b: mrv_max(a, b, x), [mrv(a, x) for a in e.args])
    elif e.is_Derivative:
        return mrv(e.args[0], x)
    raise NotImplementedError("Don't know how to calculate the mrv of '%s'" % e)
Ejemplo n.º 24
def _get_simplified_sol(sol, func, collectterms):
    r"""
    Helper function which collects the solution on
    collectterms. Ideally this should be handled by odesimp. It is used
    only when the simplify is set to True in dsolve.

    The parameter ``collectterms`` is a list of tuple (i, reroot, imroot) where `i` is
    the multiplicity of the root, reroot is real part and imroot being the imaginary part.

    ``sol`` must be a one-element sequence whose single equation has
    ``func`` as its left-hand side; an ``Eq`` with the same left-hand side
    and the collected, ``powsimp``-ed right-hand side is returned.
    """
    f = func.func
    x = func.args[0]
    assert len(sol) == 1 and sol[0].lhs == f(x)
    sol = sol[0].rhs
    sol = expand_mul(sol)
    # collect on the trigonometric terms first ...
    for i, reroot, imroot in collectterms:
        sol = collect(sol, x**i * exp(reroot * x) * sin(abs(imroot) * x))
        sol = collect(sol, x**i * exp(reroot * x) * cos(imroot * x))
    # ... and only afterwards on the bare exponential terms
    for i, reroot, imroot in collectterms:
        sol = collect(sol, x**i * exp(reroot * x))
    sol = powsimp(sol)
    return Eq(f(x), sol)
Ejemplo n.º 25
def _quintic_simplify(expr):
    """Apply ``powsimp``, then ``cancel``, then ``together`` to ``expr``."""
    return together(cancel(powsimp(expr)))
Ejemplo n.º 26
# NOTE(review): the original signature was cut off after ``expr,``.  ``x`` is
# taken from the two-argument doctest call; ``func`` and ``eq_homogeneous``
# are used as free names in the body, and their defaults here are an
# assumption - confirm against the callers.
def _undetermined_coefficients_match(expr, x, func=None, eq_homogeneous=S.Zero):
    r"""
    Test whether undetermined coefficients can be applied to ``expr`` and,
    if so, build the set of trial terms.

    Returns a dict with the key ``'test'`` (a bool) and, when the test
    succeeds, the key ``'trialset'`` (a set of trial terms).

    A trial expression can be found for an expression for use with the method
    of undetermined coefficients if the expression is an
    additive/multiplicative combination of constants, polynomials in `x` (the
    independent variable of expr), `\sin(a x + b)`, `\cos(a x + b)`, and
    `e^{a x}` terms (in other words, it has a finite number of linearly
    independent derivatives).

    Note that you may still need to multiply each term returned here by
    sufficient `x` to make it linearly independent with the solutions to the
    homogeneous equation.

    This is intended for internal use by ``undetermined_coefficients`` hints.

    SymPy currently has no way to convert `\sin^n(x) \cos^m(y)` into a sum of
    only `\sin(a x)` and `\cos(b x)` terms, so these are not implemented.  So,
    for example, you will need to manually convert `\sin^2(x)` into `[1 +
    \cos(2 x)]/2` to properly apply the method of undetermined coefficients on
    it.

    Examples
    ========

    >>> from sympy import log, exp
    >>> from sympy.solvers.ode.nonhomogeneous import _undetermined_coefficients_match
    >>> from sympy.abc import x
    >>> _undetermined_coefficients_match(9*x*exp(x) + exp(-x), x)
    {'test': True, 'trialset': {x*exp(x), exp(-x), exp(x)}}
    >>> _undetermined_coefficients_match(log(x), x)
    {'test': False}

    """
    a = Wild('a', exclude=[x])
    b = Wild('b', exclude=[x])
    expr = powsimp(expr, combine='exp')  # exp(x)*exp(2*x + 1) => exp(3*x + 1)
    retdict = {}

    def _test_term(expr, x):
        r"""
        Test if ``expr`` fits the proper form for undetermined coefficients.
        """
        if not expr.has(x):
            return True
        elif expr.is_Add:
            return all(_test_term(i, x) for i in expr.args)
        elif expr.is_Mul:
            if expr.has(sin, cos):
                foundtrig = False
                # Make sure that there is only one trig function in the args.
                # See the docstring.
                for i in expr.args:
                    if i.has(sin, cos):
                        if foundtrig:
                            return False
                        else:
                            foundtrig = True
            return all(_test_term(i, x) for i in expr.args)
        elif expr.is_Function:
            if expr.func in (sin, cos, exp, sinh, cosh):
                if expr.args[0].match(a * x + b):
                    return True
                else:
                    return False
            else:
                return False
        elif expr.is_Pow and expr.base.is_Symbol and expr.exp.is_Integer and \
                expr.exp >= 0:
            return True
        elif expr.is_Pow and expr.base.is_number:
            if expr.exp.match(a * x + b):
                return True
            else:
                return False
        elif expr.is_Symbol or expr.is_number:
            return True
        else:
            return False

    def _get_trial_set(expr, x, exprs=None):
        r"""
        Returns a set of trial terms for undetermined coefficients.

        The idea behind undetermined coefficients is that the terms expression
        repeat themselves after a finite number of derivatives, except for the
        coefficients (they are linearly dependent).  So if we collect these,
        we should have the terms of our trial function.
        """
        # a fresh set per call; a ``set()`` default would be shared (and
        # mutated) across unrelated calls
        if exprs is None:
            exprs = set()

        def _remove_coefficient(expr, x):
            r"""
            Returns the expression without a coefficient.

            Similar to expr.as_independent(x)[1], except that for a ``Mul``
            it only keeps the factors containing ``x``.
            """
            term = S.One
            if expr.is_Mul:
                for i in expr.args:
                    if i.has(x):
                        term *= i
            elif expr.has(x):
                term = expr
            return term

        expr = expand_mul(expr)
        if expr.is_Add:
            for term in expr.args:
                stripped = _remove_coefficient(term, x)
                # only terms not seen before are added and expanded further
                if stripped not in exprs:
                    exprs.add(stripped)
                    exprs = exprs.union(_get_trial_set(term, x, exprs))
        else:
            term = _remove_coefficient(expr, x)
            tmpset = exprs.union({term})
            oldset = set()
            while tmpset != oldset:
                # If you get stuck in this loop, then _test_term is probably
                # broken
                oldset = tmpset.copy()
                expr = expr.diff(x)
                term = _remove_coefficient(expr, x)
                if term.is_Add:
                    tmpset = tmpset.union(_get_trial_set(term, x, tmpset))
                else:
                    # NOTE(review): reconstructed branch - a single-term
                    # derivative is itself a trial term; confirm.
                    tmpset.add(term)
            exprs = tmpset
        return exprs

    def is_homogeneous_solution(term):
        r""" This function checks whether the given trialset contains any root
            of homogenous equation"""
        return expand(sub_func_doit(eq_homogeneous, func, term)).is_zero

    retdict['test'] = _test_term(expr, x)
    if retdict['test']:
        # Try to generate a list of trial solutions that will have the
        # undetermined coefficients. Note that if any of these are not linearly
        # independent with any of the solutions to the homogeneous equation,
        # then they will need to be multiplied by sufficient x to make them so.
        # This function DOES NOT do that (it doesn't even look at the
        # homogeneous equation).
        temp_set = set()
        for i in Add.make_args(expr):
            act = _get_trial_set(i, x)
            if eq_homogeneous is not S.Zero:
                while any(is_homogeneous_solution(ts) for ts in act):
                    act = {x * ts for ts in act}
            temp_set = temp_set.union(act)

        retdict['trialset'] = temp_set
    return retdict
Ejemplo n.º 27
def checksol(f, symbol, sol=None, **flags):
    """Checks whether sol is a solution of equation f == 0.

    Input can be either a single symbol and corresponding value
    or a dictionary of symbols and values.

    Examples
    ========

       >>> from sympy import symbols
       >>> from sympy.solvers import checksol
       >>> x, y = symbols('x,y')
       >>> checksol(x**4-1, x, 1)
       True
       >>> checksol(x**4-1, x, 0)
       False
       >>> checksol(x**2 + y**2 - 5**2, {x:3, y: 4})
       True

       None is returned if checksol() could not conclude.

       flags:
           'numerical=True (default)'
               do a fast numerical check if f has only one symbol.
           'minimal=True (default is False)'
               a very fast, minimal testing.
           'warning=True (default is False)'
               print a warning if checksol() could not conclude.
           'simplified=True (default)'
               solution should be simplified before substituting into function
               and function should be simplified after making substitution.
           'force=True (default is False)'
               make positive all symbols without assumptions regarding sign.

    Raises ValueError when the arguments are neither ``sym, val`` nor a
    dictionary, and when ``f`` is an empty sequence.
    """

    if sol is not None:
        sol = {symbol: sol}
    elif isinstance(symbol, dict):
        sol = symbol
    else:
        msg = 'Expecting sym, val or {sym: val}, None but got %s, %s'
        raise ValueError(msg % (symbol, sol))

    if hasattr(f, '__iter__') and hasattr(f, '__len__'):
        if not f:
            raise ValueError('no functions to check')
        rv = set()
        for fi in f:
            check = checksol(fi, sol, **flags)
            if check is False:
                return False
            # remember the outcome so None (inconclusive) can be reported
            rv.add(check)
        if None in rv: # rv might contain True and/or None
            return None
        assert len(rv) == 1 # True
        return True

    if isinstance(f, Poly):
        f = f.as_expr()
    elif isinstance(f, Equality):
        f = f.lhs - f.rhs

    if not f:
        return True

    if not f.has(*sol.keys()):
        return False

    attempt = -1
    numerical = flags.get('numerical', True)
    while True:
        attempt += 1
        if attempt == 0:
            val = f.subs(sol)
        elif attempt == 1:
            if not val.atoms(Symbol) and numerical:
                # val is a constant, so a fast numerical test may suffice
                if val not in [S.Infinity, S.NegativeInfinity]:
                    # issue 2088 shows that +/-oo chops to 0
                    val = val.evalf(36).n(30, chop=True)
        elif attempt == 2:
            if flags.get('minimal', False):
                # NOTE(review): reconstructed - minimal testing stops here
                # without a verdict (None); confirm.
                return
            # the flag 'simplified=False' is used in solve to avoid
            # simplifying the solution. So if it is set to False there
            # the simplification will not be attempted here, either. But
            # if the simplification is done here then the flag should be
            # set to False so it isn't done again there.
            # FIXME: this can't work, since `flags` is not passed to
            # `checksol()` as a dict, but as keywords.
            # So, any modification to `flags` here will be lost when returning
            # from `checksol()`.
            if flags.get('simplified', True):
                for k in sol:
                    sol[k] = simplify(sympify(sol[k]))
                flags['simplified'] = False
                val = simplify(f.subs(sol))
            if flags.get('force', False):
                val = posify(val)[0]
        elif attempt == 3:
            val = powsimp(val)
        elif attempt == 4:
            val = cancel(val)
        elif attempt == 5:
            val = val.expand()
        elif attempt == 6:
            val = together(val)
        elif attempt == 7:
            val = powsimp(val)
        else:
            # every strategy has been tried; without this the loop would
            # spin forever and the warning below could never be reached
            break
        if val.is_zero:
            return True
        elif attempt > 0 and numerical and val.is_nonzero:
            return False

    if flags.get('warning', False):
        print("\n\tWarning: could not verify solution %s." % sol)
    # falls through and returns None when no conclusion was reached
Ejemplo n.º 28
def rewrite(e, Omega, x, wsym):
    """Rewrite ``e`` in terms of ``w`` and return it together with ``log(w)``.

    e(x) ... the function
    Omega ... the mrv set (a SubsSet whose keys are all exp() terms)
    x ... the variable; it is passed to sign() and limitinf()
    wsym ... the symbol which is going to be used for w

    Returns the pair (f, logw): e rewritten in terms of w, and log(w).
    See test_rewrite1() for examples and correct results.

    Raises TypeError if Omega is not a SubsSet, ValueError if Omega is
    empty or holds a term that is not an exponential, and
    NotImplementedError if the sign of an exponent is neither 1 nor -1.
    """
    from functools import reduce
    from sympy import ilcm

    # Validate with explicit exceptions instead of assert statements, which
    # are stripped when Python runs with -O.
    if not isinstance(Omega, SubsSet):
        raise TypeError('Omega should be an instance of SubsSet')
    if len(Omega) == 0:
        raise ValueError('Length can not be 0')
    #all items in Omega must be exponentials
    for t in Omega.keys():
        if t.func is not exp:
            raise ValueError('Value should be exp')
    rewrites = Omega.rewrites
    Omega = list(Omega.items())

    # order the (term, variable) pairs by the height of the variable's node
    # in the expression tree, highest first
    nodes = build_expression_tree(Omega, rewrites)
    Omega.sort(key=lambda item: nodes[item[1]].ht(), reverse=True)

    # make sure we know the sign of each exp() term; after the loop,
    # g is going to be the "w" - the simplest one in the mrv set
    for g, _ in Omega:
        sig = sign(g.args[0], x)
        if sig != 1 and sig != -1:
            raise NotImplementedError('Result depends on the sign of %s' % sig)
    if sig == 1:
        wsym = 1 / wsym  # if g goes to oo, substitute 1/w
    #O2 is a list, which results by rewriting each item in Omega using "w"
    O2 = []
    denominators = []
    for f, var in Omega:
        c = limitinf(f.args[0] / g.args[0], x)
        if c.is_Rational:
            # remember the denominator of the exponent of w so that the
            # exponents can be made integral further down
            denominators.append(c.q)
        arg = f.args[0]
        if var in rewrites:
            if rewrites[var].func is not exp:
                raise ValueError('Value should be exp')
            arg = rewrites[var].args[0]
        O2.append((var, exp((arg - c * g.args[0]).expand()) * wsym**c))

    #Remember that Omega contains subexpressions of "e". So now we find
    #them in "e" and substitute them for our rewriting, stored in O2

    # the following powsimp is necessary to automatically combine exponentials,
    # so that the .subs() below succeeds:
    # TODO this should not be necessary
    f = powsimp(e, deep=True, combine='exp')
    for a, b in O2:
        f = f.subs(a, b)

    # sanity check: every variable of the mrv set has been substituted away
    for _, var in Omega:
        assert not f.has(var)

    #finally compute the logarithm of w (logw).
    logw = g.args[0]
    if sig == 1:
        logw = -logw  # log(w)->log(1/w)=-log(w)

    # Some parts of sympy have difficulty computing series expansions with
    # non-integral exponents. The following heuristic improves the situation:
    exponent = reduce(ilcm, denominators, 1)
    f = f.subs(wsym, wsym**exponent)
    logw /= exponent

    return f, logw
Ejemplo n.º 29
def rewrite(e, Omega, x, wsym):
    """Rewrite ``e`` in terms of ``w`` and return it together with ``log(w)``.

    e(x) ... the function
    Omega ... the mrv set (a SubsSet whose keys are all exp() terms)
    x ... the variable; it is passed to sign() and limitinf()
    wsym ... the symbol which is going to be used for w

    Returns the pair (f, logw): e rewritten in terms of w, and log(w).
    See test_rewrite1() for examples and correct results.

    Raises TypeError if Omega is not a SubsSet, ValueError if Omega is
    empty or holds a term that is not an exponential, and
    NotImplementedError if the sign of an exponent is neither 1 nor -1.
    """
    from functools import reduce
    from sympy import ilcm

    if not isinstance(Omega, SubsSet):
        raise TypeError("Omega should be an instance of SubsSet")
    if len(Omega) == 0:
        raise ValueError("Length can not be 0")
    # all items in Omega must be exponentials
    for t in Omega.keys():
        if t.func is not exp:
            raise ValueError("Value should be exp")
    rewrites = Omega.rewrites
    Omega = list(Omega.items())

    # order the (term, variable) pairs by the height of the variable's node
    # in the expression tree, highest first
    nodes = build_expression_tree(Omega, rewrites)
    Omega.sort(key=lambda item: nodes[item[1]].ht(), reverse=True)

    # make sure we know the sign of each exp() term; after the loop,
    # g is going to be the "w" - the simplest one in the mrv set
    for g, _ in Omega:
        sig = sign(g.args[0], x)
        if sig != 1 and sig != -1:
            raise NotImplementedError("Result depends on the sign of %s" % sig)
    if sig == 1:
        wsym = 1 / wsym  # if g goes to oo, substitute 1/w
    # O2 is a list, which results by rewriting each item in Omega using "w"
    O2 = []
    denominators = []
    for f, var in Omega:
        c = limitinf(f.args[0] / g.args[0], x)
        if c.is_Rational:
            # remember the denominator of the exponent of w so that the
            # exponents can be made integral further down
            denominators.append(c.q)
        arg = f.args[0]
        if var in rewrites:
            if rewrites[var].func is not exp:
                raise ValueError("Value should be exp")
            arg = rewrites[var].args[0]
        O2.append((var, exp((arg - c * g.args[0]).expand()) * wsym ** c))

    # Remember that Omega contains subexpressions of "e". So now we find
    # them in "e" and substitute them for our rewriting, stored in O2

    # the following powsimp is necessary to automatically combine exponentials,
    # so that the .subs() below succeeds:
    # TODO this should not be necessary
    f = powsimp(e, deep=True, combine="exp")
    for a, b in O2:
        f = f.subs(a, b)

    # sanity check: every variable of the mrv set has been substituted away
    for _, var in Omega:
        assert not f.has(var)

    # finally compute the logarithm of w (logw).
    logw = g.args[0]
    if sig == 1:
        logw = -logw  # log(w)->log(1/w)=-log(w)

    # Some parts of sympy have difficulty computing series expansions with
    # non-integral exponents. The following heuristic improves the situation:
    exponent = reduce(ilcm, denominators, 1)
    f = f.subs(wsym, wsym ** exponent)
    logw /= exponent

    return f, logw
Ejemplo n.º 30
def checksol(f, symbol, sol=None, **flags):
    """Checks whether sol is a solution of equation f == 0.

    Input can be either a single symbol and corresponding value
    or a dictionary of symbols and values.

    f may also be a non-empty sequence of equations, in which case every
    one of them is checked against the same solution.

    Examples:

       >>> from sympy import symbols
       >>> from sympy.solvers import checksol
       >>> x, y = symbols('x,y')
       >>> checksol(x**4-1, x, 1)
       True
       >>> checksol(x**4-1, x, 0)
       False
       >>> checksol(x**2 + y**2 - 5**2, {x:3, y: 4})
       True

       None is returned if checksol() could not conclude.

       flags:
           'numerical=True (default)'
               do a fast numerical check if f has only one symbol.
           'minimal=True (default is False)'
               a very fast, minimal testing.
           'warning=True (default is False)'
               print a warning if checksol() could not conclude.
           'simplified=True (default)'
               solution should be simplified before substituting into function
               and function should be simplified after making substitution.
           'force=True (default is False)'
               make positive all symbols without assumptions regarding sign.

    Raises ValueError if the solution is given neither as (symbol, value)
    nor as a dictionary, or if f is an empty sequence.
    """
    # normalise both calling conventions to a {symbol: value} dictionary
    if sol is not None:
        sol = {symbol: sol}
    elif isinstance(symbol, dict):
        sol = symbol
    else:
        msg = 'Expecting sym, val or {sym: val}, None but got %s, %s'
        raise ValueError(msg % (symbol, sol))

    if hasattr(f, '__iter__') and hasattr(f, '__len__'):
        if not f:
            raise ValueError('no functions to check')
        rv = set()
        for fi in f:
            check = checksol(fi, sol, **flags)
            if check is False:
                # a single failing equation settles the whole system
                return False
            rv.add(check)
        if None in rv:  # rv might contain True and/or None
            return None
        assert len(rv) == 1  # True
        return True

    if isinstance(f, Poly):
        f = f.as_expr()
    elif isinstance(f, Equality):
        f = f.lhs - f.rhs

    if not f:
        return True

    # f is not falsy here and contains none of the symbols being
    # substituted, so the substitution cannot change it
    if not f.has(*sol.keys()):
        return False

    # Try increasingly expensive rewrites of the substituted expression and
    # stop as soon as one of them decides whether it is zero.
    attempt = -1
    numerical = flags.get('numerical', True)
    while 1:
        attempt += 1
        if attempt == 0:
            val = f.subs(sol)
        elif attempt == 1:
            if not val.atoms(Symbol) and numerical:
                # val is a constant, so a fast numerical test may suffice
                if val not in [S.Infinity, S.NegativeInfinity]:
                    # issue 2088 shows that +/-oo chops to 0
                    val = val.evalf(36).n(30, chop=True)
        elif attempt == 2:
            if flags.get('minimal', False):
                # minimal testing stops after the substitution and the
                # numerical check; neither was conclusive
                return None
            # the flag 'simplified=False' is used in solve to avoid
            # simplifying the solution. So if it is set to False there
            # the simplification will not be attempted here, either. But
            # if the simplification is done here then the flag should be
            # set to False so it isn't done again there.
            # NOTE: `flags` is this call's own keyword dictionary, so the
            # assignment to flags['simplified'] below is not visible to the
            # caller. `sol`, however, is updated in place, which the caller
            # does see when it passed a dictionary.
            if flags.get('simplified', True):
                for k in sol:
                    sol[k] = simplify(sympify(sol[k]))
                flags['simplified'] = False
                val = simplify(f.subs(sol))
            if flags.get('force', False):
                val = posify(val)[0]
        elif attempt == 3:
            val = powsimp(val)
        elif attempt == 4:
            val = cancel(val)
        elif attempt == 5:
            val = val.expand()
        elif attempt == 6:
            val = together(val)
        elif attempt == 7:
            val = powsimp(val)
        else:
            # every strategy has been tried without reaching a conclusion
            break
        if val.is_zero:
            return True
        elif attempt > 0 and numerical and val.is_nonzero:
            return False

    if flags.get('warning', False):
        print("Warning: could not verify solution %s." % sol)
    # falls through and returns None when no conclusion was reached
    return None