Esempio n. 1
0
def roots_linear(f):
    """Return the root of a linear polynomial as a one-element list.

    The root is ``-f.nth(0)/f.nth(1)``.  When the domain reported by
    ``f.get_domain()`` is not numerical, the root is post-processed with
    ``factor`` (composite domains) or ``simplify`` (all other domains).
    """
    root = -f.nth(0)/f.nth(1)
    domain = f.get_domain()

    # Numerical domains need no symbolic clean-up.
    if domain.is_Numerical:
        return [root]

    tidy = factor if domain.is_Composite else simplify
    return [tidy(root)]
Esempio n. 2
0
    def _rational_case(cls, poly, func):
        """Handle the rational function case.

        ``func.expr`` is evaluated at one formal symbol per root of ``poly``
        and the results are summed.  The numerator and denominator of that
        sum have their coefficients rewritten through ``symmetrize`` and the
        formulas from ``viete(poly, ...)``; the factored quotient is returned.
        """
        formal_roots = symbols('r:%d' % poly.degree())
        var, expr = func.variables[0], func.expr

        total = sum(expr.subs(var, root) for root in formal_roots)
        numer, denom = together(total).as_numer_denom()

        domain = QQ[formal_roots]

        numer = numer.expand()
        denom = denom.expand()

        # When no generators can be found the expression is kept as a single
        # "coefficient" and the polynomial slot is set to None.
        try:
            numer = Poly(numer, domain=domain, expand=False)
        except GeneratorsNeeded:
            numer, numer_coeffs = None, (numer,)
        else:
            numer_monoms, numer_coeffs = zip(*numer.terms())

        try:
            denom = Poly(denom, domain=domain, expand=False)
        except GeneratorsNeeded:
            denom, denom_coeffs = None, (denom,)
        else:
            denom_monoms, denom_coeffs = zip(*denom.terms())

        coeffs, mapping = symmetrize(numer_coeffs + denom_coeffs, formal=True)
        formulas = viete(poly, formal_roots)

        # Pair each formal symbol from ``mapping`` with the matching value
        # from ``formulas`` and substitute into every coefficient.
        values = [(sym, val) for (sym, _), (_, val) in zip(mapping, formulas)]
        coeffs = [coeff.subs(values) for coeff, _ in coeffs]

        split = len(numer_coeffs)
        numer_coeffs, denom_coeffs = coeffs[:split], coeffs[split:]

        if numer is None:
            (numer,) = numer_coeffs
        else:
            numer = Poly(dict(zip(numer_monoms, numer_coeffs)),
                         *numer.gens).as_expr()

        if denom is None:
            (denom,) = denom_coeffs
        else:
            denom = Poly(dict(zip(denom_monoms, denom_coeffs)),
                         *denom.gens).as_expr()

        return factor(numer/denom)
Esempio n. 3
0
def _eval_sum_hyper(f, i, a):
    """Return ``(res, cond)`` for the sum of ``f`` over ``i`` from ``a`` to oo.

    Returns ``None`` when ``hypersimp`` yields nothing, or when a factor of
    the simplified term ratio is not an integer power of a polynomial of
    degree one in ``i``.
    """
    from sympy.functions import hyper
    from sympy.simplify import hyperexpand, hypersimp, fraction, simplify
    from sympy.polys.polytools import Poly, factor
    from sympy.core.numbers import Float

    # Shift the summation so that it always starts at zero.
    if a != 0:
        return _eval_sum_hyper(f.subs(i, i + a), i, 0)

    # A vanishing first term: either everything vanishes, or skip that term.
    if f.subs(i, 0) == 0:
        if simplify(f.subs(i, Dummy('i', integer=True, positive=True))) == 0:
            return S(0), True
        return _eval_sum_hyper(f.subs(i, i + 1), i, 0)

    ratio = hypersimp(f, i)
    if ratio is None:
        return None

    if isinstance(ratio, Float):
        from sympy.simplify.simplify import nsimplify
        ratio = nsimplify(ratio)

    numer, denom = fraction(factor(ratio))
    top, top_factors = numer.as_coeff_mul(i)
    bot, bot_factors = denom.as_coeff_mul(i)

    scales = []
    params = []
    for scale, linear_factors in ((top, top_factors), (bot, bot_factors)):
        collected = []
        for fac in linear_factors:
            mul = 1
            if fac.is_Pow:
                mul = fac.exp
                fac = fac.base
                if not mul.is_Integer:
                    return None
            p = Poly(fac, i)
            if p.degree() != 1:
                return None
            m, n = p.all_coeffs()
            scale *= m**mul
            collected.extend([n/m]*mul)
        scales.append(scale)
        params.append(collected)

    # Add "1" to numerator parameters, to account for implicit n! in
    # hypergeometric series.
    ap = params[0] + [1]
    bq = params[1]
    x = scales[0]/scales[1]
    h = hyper(ap, bq, x)

    return f.subs(i, 0)*hyperexpand(h), h.convergence_statement
Esempio n. 4
0
def roots_linear(f):
    """Return the root of a linear polynomial as a one-element list.

    Progress is reported through ``add_comment`` and the final equation
    ``f.gen = root`` through ``add_eq``.  For non-numerical domains the root
    is passed through ``factor`` (composite domain) or ``simplify``.
    """
    add_comment('This equation is linear')

    root = -f.nth(0)/f.nth(1)
    domain = f.get_domain()

    if not domain.is_Numerical:
        cleanup = factor if domain.is_Composite else simplify
        root = cleanup(root)

    add_comment("The root of this equation is")
    add_eq(f.gen, root)
    return [root]
Esempio n. 5
0
def roots_linear(f):
    """Returns a list of roots of a linear polynomial.

    The equation ``f = 0`` is reported through ``add_eq`` before solving, and
    the resulting root is reported through a second ``add_eq`` call followed
    by an empty ``add_comment``.  For non-numerical domains the root is passed
    through ``factor`` (composite domain) or ``simplify``.
    """
    add_eq(f.as_expr(), 0)

    r = -f.nth(0)/f.nth(1)
    dom = f.get_domain()

    if not dom.is_Numerical:
        if dom.is_Composite:
            r = factor(r)
        else:
            r = simplify(r)

    # Left-hand side label for the reported equation: the third
    # comma-separated field of the printed factor list, minus its first
    # character.  NOTE(review): presumably this is the generator's name --
    # confirm against the printed form of ``factor_list()``.
    tmp = str(Poly(f.as_expr()).factor_list()).split(',')[2][1:]
    add_eq(tmp, r)
    add_comment('')
    return [r]
Esempio n. 6
0
def _solve_lambert(f, symbol, gens):
    """Return solution to ``f`` if it is a Lambert-type expression
    else raise NotImplementedError.

    The equality, ``f(x, a..f) = a*log(b*X + c) + d*X - f = 0`` has the
    solution,  `X = -c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(f/a))`. There
    are a variety of forms for `f(X, a..f)` as enumerated below:

    1a1)
      if B**B = R for R not [0, 1] then
      log(B) + log(log(B)) = log(log(R))
      X = log(B), a = 1, b = 1, c = 0, d = 1, f = log(log(R))
    1a2)
      if B*(b*log(B) + c)**a = R then
      log(B) + a*log(b*log(B) + c) = log(R)
      X = log(B); d=1, f=log(R)
    1b)
      if a*log(b*B + c) + d*B = R then
      X = B, f = R
    2a)
      if (b*B + c)*exp(d*B + g) = R then
      log(b*B + c) + d*B + g = log(R)
      a = 1, f = log(R) - g, X = B
    2b)
      if -b*B + g*exp(d*B + h) = c then
      log(g) + d*B + h - log(b*B + c) = 0
      a = -1, f = -h - log(g), X = B
    3)
      if d*p**(a*B + g) - b*B = c then
      log(d) + (a*B + g)*log(p) - log(c + b*B) = 0
      a = -1, d = a*log(p), f = -log(d) - g*log(p)

    Parameters: ``f`` is the expression taken to equal zero, ``symbol`` is
    the unknown, and ``gens`` is an iterable of expressions that must contain
    at least one ``exp``, ``log`` or power whose exponent involves ``symbol``.

    Returns the solutions found by ``_lambert`` as ``list(ordered(soln))``.

    Raises NotImplementedError when ``gens`` has no suitable entry or when
    none of the three stages below produces a solution.
    """

    # Split off the part of ``f`` that is free of ``symbol``; it is moved to
    # the right-hand side with its sign flipped.
    nrhs, lhs = f.as_independent(symbol, as_Add=True)
    rhs = -nrhs

    # Quick rejection: nothing in ``gens`` looks Lambert-like.
    lamcheck = [
        tmp for tmp in gens if (tmp.func in [exp, log] or (
            tmp.is_Pow and symbol in tmp.exp.free_symbols))
    ]
    if not lamcheck:
        raise NotImplementedError()

    # A product on the left is turned into a sum by taking logs of both sides.
    if lhs.is_Mul:
        lhs = expand_log(log(lhs))
        rhs = log(rhs)

    lhs = factor(lhs, deep=True)
    # make sure we are inverted as completely as possible
    r = Dummy()
    i, lhs = _invert(lhs - r, symbol)
    rhs = i.xreplace({r: rhs})

    # For the first ones:
    # 1a1) B**B = R != 0 (when 0, there is only a solution if the base is 0,
    #                     but if it is, the exp is 0 and 0**0=1
    #                     comes back as B*log(B) = log(R)
    # 1a2) B*(a + b*log(B))**p = R or with monomial expanded or with whole
    #                              thing expanded comes back unchanged
    #     log(B) + p*log(a + b*log(B)) = log(R)
    #     lhs is Mul:
    #         expand log of both sides to give:
    #         log(B) + log(log(B)) = log(log(R))
    # 1b) d*log(a*B + b) + c*B = R
    #     lhs is Add:
    #         isolate c*B and expand log of both sides:
    #         log(c) + log(B) = log(R - d*log(a*B + b))

    soln = []
    # Always true at this point (``soln`` was just initialised); the guard
    # mirrors the ones on the later stages.
    if not soln:
        mainlog = _mostfunc(lhs, log, symbol)
        if mainlog:
            if lhs.is_Mul and rhs != 0:
                soln = _lambert(log(lhs) - log(rhs), symbol)
            elif lhs.is_Add:
                other = lhs.subs(mainlog, 0)
                if other and not other.is_Add and [
                        tmp for tmp in other.atoms(Pow)
                        if symbol in tmp.free_symbols
                ]:
                    if not rhs:
                        diff = log(other) - log(other - lhs)
                    else:
                        diff = log(lhs - other) - log(rhs - other)
                    soln = _lambert(expand_log(diff), symbol)
                else:
                    #it's ready to go
                    soln = _lambert(lhs - rhs, symbol)

    # For the next two,
    #     collect on main exp
    #     2a) (b*B + c)*exp(d*B + g) = R
    #         lhs is mul:
    #             log to give
    #             log(b*B + c) + d*B = log(R) - g
    #     2b) -b*B + g*exp(d*B + h) = R
    #         lhs is add:
    #             add b*B
    #             log and rearrange
    #             log(R + b*B) - d*B = log(g) + h

    if not soln:
        mainexp = _mostfunc(lhs, exp, symbol)
        if mainexp:
            lhs = collect(lhs, mainexp)
            if lhs.is_Mul and rhs != 0:
                soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol)
            elif lhs.is_Add:
                # move all but mainexp-containing term to rhs
                other = lhs.subs(mainexp, 0)
                mainterm = lhs - other
                # NOTE: ``rhs`` is rebound here; if this stage yields no
                # solution, stage 3 below sees the modified value.
                rhs = rhs - other
                if (mainterm.could_extract_minus_sign()
                        and rhs.could_extract_minus_sign()):
                    mainterm *= -1
                    rhs *= -1
                diff = log(mainterm) - log(rhs)
                soln = _lambert(expand_log(diff), symbol)

    # 3) d*p**(a*B + b) + c*B = R
    #     collect on main pow
    #     log(R - c*B) - a*B*log(p) = log(d) + b*log(p)

    if not soln:
        mainpow = _mostfunc(lhs, Pow, symbol)
        if mainpow and symbol in mainpow.exp.free_symbols:
            lhs = collect(lhs, mainpow)
            if lhs.is_Mul and rhs != 0:
                soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol)
            elif lhs.is_Add:
                # move all but mainpow-containing term to rhs
                other = lhs.subs(mainpow, 0)
                mainterm = lhs - other
                rhs = rhs - other
                diff = log(mainterm) - log(rhs)
                soln = _lambert(expand_log(diff), symbol)

    if not soln:
        raise NotImplementedError('%s does not appear to have a solution in '
                                  'terms of LambertW' % f)

    return list(ordered(soln))
Esempio n. 7
0
def test_factor_nc():
    """Check ``factor_nc`` (plus ``factor`` and ``.factor()``) on expressions
    that mix commutative symbols with non-commutative ones."""
    x, y = symbols('x,y')
    k = symbols('k', integer=True)
    n, m, o = symbols('n,m,o', commutative=False)

    # mul and multinomial expansion is needed
    from sympy.core.function import _mexpand
    e = x * (1 + y)**2
    assert _mexpand(e) == x + x * 2 * y + x * y**2

    def factor_nc_test(e):
        # Round trip: expand ``e`` into a sum, factor it back, and require
        # that the result is no longer a sum and re-expands to the same sum.
        ex = _mexpand(e)
        assert ex.is_Add
        f = factor_nc(ex)
        assert not f.is_Add and _mexpand(f) == ex

    factor_nc_test(x * (1 + y))
    factor_nc_test(n * (x + 1))
    factor_nc_test(n * (x + m))
    factor_nc_test((x + m) * n)
    factor_nc_test(n * m * (x * o + n * o * m) * n)
    s = Sum(x, (x, 1, 2))
    factor_nc_test(x * (1 + s))
    factor_nc_test(x * (1 + s) * s)
    factor_nc_test(x * (1 + sin(s)))
    factor_nc_test((1 + n)**2)

    factor_nc_test((x + n) * (x + m) * (x + y))
    factor_nc_test(x * (n * m + 1))
    factor_nc_test(x * (n * m + x))
    factor_nc_test(x * (x * n * m + 1))
    factor_nc_test(x * n * (x * m + 1))
    factor_nc_test(x * (m * n + x * n * m))
    factor_nc_test(n * (1 - m) * n**2)

    factor_nc_test((n + m)**2)
    factor_nc_test((n - m) * (n + m)**2)
    factor_nc_test((n + m)**2 * (n - m))
    factor_nc_test((m - n) * (n + m)**2 * (n - m))

    # Exact expected forms (operand order matters for non-commutative symbols).
    assert factor_nc(n * (n + n * m)) == n**2 * (1 + m)
    assert factor_nc(m * (m * n + n * m * n**2)) == m * (m + n * m * n) * n
    # Nothing to factor: the expression must come back unchanged.
    eq = m * sin(n) - sin(n) * m
    assert factor_nc(eq) == eq

    # for coverage:
    from sympy.physics.secondquant import Commutator
    from sympy.polys.polytools import factor
    eq = 1 + x * Commutator(m, n)
    assert factor_nc(eq) == eq
    eq = x * Commutator(m, n) + x * Commutator(m, o) * Commutator(m, n)
    assert factor(eq) == x * (1 + Commutator(m, o)) * Commutator(m, n)

    # issue 6534
    assert (2 * n + 2 * m).factor() == 2 * (n + m)

    # issue 6701
    assert factor_nc(n**k + n**(k + 1)) == n**k * (1 + n)
    assert factor_nc((m * n)**k + (m * n)**(k + 1)) == (1 + m * n) * (m * n)**k

    # issue 6918
    assert factor_nc(-n * (2 * x**2 + 2 * x)) == -2 * n * x * (x + 1)
Esempio n. 8
0
def roots_quintic(f):
    """
    Calculate exact roots of a solvable quintic

    ``f`` must provide six coefficients through ``f.all_coeffs()``.  A list
    with five root expressions is returned on success.  An empty list is
    returned -- so that the caller can fall back to another method -- when:

    - the x**4 coefficient is non-zero,
    - the coefficients divided by the leading one are not all rational,
    - ``f`` is reducible or ``quintic.f20`` is irreducible,
    - ``f20`` has no linear factor,
    - no consistent choice of the intermediate ``r`` values is found,
    - the computed roots are not distinct at two digits of precision.
    """
    result = []
    coeff_5, coeff_4, p, q, r, s = f.all_coeffs()

    # Eqn must be of the form x^5 + px^3 + qx^2 + rx + s
    if coeff_4:
        return result

    if coeff_5 != 1:
        normalized = [p / coeff_5, q / coeff_5, r / coeff_5, s / coeff_5]
        if not all(coeff.is_Rational for coeff in normalized):
            return result
        f = Poly(f / coeff_5)
    quintic = PolyQuintic(f)

    # Eqn standardized. Algo for solving starts here
    if not f.is_irreducible:
        return result

    f20 = quintic.f20
    # Check if f20 has linear factors over domain Z
    if f20.is_irreducible:
        return result

    # theta is the root of the first linear factor of f20
    for _factor in f20.factor_list()[1]:
        if _factor[0].is_linear:
            theta = _factor[0].root(0)
            break
    else:
        # f20 is reducible but has no linear factor: theta cannot be
        # determined, so give up instead of hitting an unbound name below.
        return result
    d = discriminant(f)
    delta = sqrt(d)
    # zeta = a fifth root of unity
    zeta1, zeta2, zeta3, zeta4 = quintic.zeta
    T = quintic.T(theta, d)
    tol = S(1e-10)
    alpha = T[1] + T[2] * delta
    alpha_bar = T[1] - T[2] * delta
    beta = T[3] + T[4] * delta
    beta_bar = T[3] - T[4] * delta

    disc = alpha**2 - 4 * beta
    disc_bar = alpha_bar**2 - 4 * beta_bar

    l0 = quintic.l0(theta)

    l1 = _quintic_simplify((-alpha + sqrt(disc)) / S(2))
    l4 = _quintic_simplify((-alpha - sqrt(disc)) / S(2))

    l2 = _quintic_simplify((-alpha_bar + sqrt(disc_bar)) / S(2))
    l3 = _quintic_simplify((-alpha_bar - sqrt(disc_bar)) / S(2))

    order = quintic.order(theta, d)
    test = (order * delta.n()) - ((l1.n() - l4.n()) * (l2.n() - l3.n()))
    # Comparing floats
    if not comp(test, 0, tol):
        l2, l3 = l3, l2

    # Now we have correct order of l's
    R1 = l0 + l1 * zeta1 + l2 * zeta2 + l3 * zeta3 + l4 * zeta4
    R2 = l0 + l3 * zeta1 + l1 * zeta2 + l4 * zeta3 + l2 * zeta4
    R3 = l0 + l2 * zeta1 + l4 * zeta2 + l1 * zeta3 + l3 * zeta4
    R4 = l0 + l4 * zeta1 + l3 * zeta2 + l2 * zeta3 + l1 * zeta4

    # Index 0 is a placeholder so that Res[k] corresponds to Rk.
    Res = [None, [None] * 5, [None] * 5, [None] * 5, [None] * 5]
    Res_n = [None, [None] * 5, [None] * 5, [None] * 5, [None] * 5]
    sol = Symbol('sol')

    # Simplifying improves performance a lot for exact expressions
    R1 = _quintic_simplify(R1)
    R2 = _quintic_simplify(R2)
    R3 = _quintic_simplify(R3)
    R4 = _quintic_simplify(R4)

    # Solve imported here. Causing problems if imported as 'solve'
    # and hence the changed name
    from sympy.solvers.solvers import solve as _solve
    a, b = symbols('a b', cls=Dummy)
    _sol = _solve(sol**5 - a - I * b, sol)
    for i in range(5):
        _sol[i] = factor(_sol[i])
    R1 = R1.as_real_imag()
    R2 = R2.as_real_imag()
    R3 = R3.as_real_imag()
    R4 = R4.as_real_imag()

    for i, root in enumerate(_sol):
        Res[1][i] = _quintic_simplify(root.subs({a: R1[0], b: R1[1]}))
        Res[2][i] = _quintic_simplify(root.subs({a: R2[0], b: R2[1]}))
        Res[3][i] = _quintic_simplify(root.subs({a: R3[0], b: R3[1]}))
        Res[4][i] = _quintic_simplify(root.subs({a: R4[0], b: R4[1]}))

    for i in range(1, 5):
        for j in range(5):
            Res_n[i][j] = Res[i][j].n()
            Res[i][j] = _quintic_simplify(Res[i][j])
    r1 = Res[1][0]
    r1_n = Res_n[1][0]

    # r4 is the candidate whose product with r1 has (numerically) zero
    # imaginary part.
    for i in range(5):
        if comp(im(r1_n * Res_n[4][i]), 0, tol):
            r4 = Res[4][i]
            break
    else:
        # No candidate matched: fall back to the usual solve.
        return []

    # Now we have various Res values. Each will be a list of five
    # values. We have to pick one r value from those five for each Res
    u, v = quintic.uv(theta, d)
    testplus = (u + v * delta * sqrt(5)).n()
    testminus = (u - v * delta * sqrt(5)).n()

    # Evaluated numbers suffixed with _n
    # We will use evaluated numbers for calculation. Much faster.
    r4_n = r4.n()
    r2 = r3 = None

    for i in range(5):
        r2temp_n = Res_n[2][i]
        for j in range(5):
            # Again storing away the exact number and using
            # evaluated numbers in computations
            r3temp_n = Res_n[3][j]
            if (comp((r1_n * r2temp_n**2 + r4_n * r3temp_n**2 - testplus).n(),
                     0, tol) and comp(
                         (r3temp_n * r1_n**2 + r2temp_n * r4_n**2 -
                          testminus).n(), 0, tol)):
                r2 = Res[2][i]
                r3 = Res[3][j]
                break
        # Compare against the sentinel explicitly: a matching r2 may be an
        # expression that is falsy (e.g. zero).
        if r2 is not None:
            break
    else:
        # No (r2, r3) pair satisfied both tests: fall back to the usual solve.
        return []

    # Now, we have r's so we can get roots
    x1 = (r1 + r2 + r3 + r4) / 5
    x2 = (r1 * zeta4 + r2 * zeta3 + r3 * zeta2 + r4 * zeta1) / 5
    x3 = (r1 * zeta3 + r2 * zeta1 + r3 * zeta4 + r4 * zeta2) / 5
    x4 = (r1 * zeta2 + r2 * zeta4 + r3 * zeta1 + r4 * zeta3) / 5
    x5 = (r1 * zeta1 + r2 * zeta2 + r3 * zeta3 + r4 * zeta4) / 5
    result = [x1, x2, x3, x4, x5]

    # Now check if solutions are distinct

    saw = set()
    for r in result:
        r = r.n(2)
        if r in saw:
            # Roots were identical. Abort, return []
            # and fall back to usual solve
            return []
        saw.add(r)
    return result
Esempio n. 9
0
def roots_quintic(f):
    """
    Calculate exact roots of a solvable quintic

    ``f`` must provide six coefficients through ``f.all_coeffs()``.  A list
    with five root expressions is returned on success.  An empty list is
    returned -- so that the caller can fall back to another method -- when:

    - the x**4 coefficient is non-zero,
    - the coefficients divided by the leading one are not all rational,
    - ``f`` is reducible or ``quintic.f20`` is irreducible,
    - ``f20`` has no linear factor,
    - no consistent choice of the intermediate ``r`` values is found,
    - the computed roots are not distinct at two digits of precision.
    """
    result = []
    coeff_5, coeff_4, p, q, r, s = f.all_coeffs()

    # Eqn must be of the form x^5 + px^3 + qx^2 + rx + s
    if coeff_4:
        return result

    if coeff_5 != 1:
        normalized = [p/coeff_5, q/coeff_5, r/coeff_5, s/coeff_5]
        if not all(coeff.is_Rational for coeff in normalized):
            return result
        f = Poly(f/coeff_5)
    quintic = PolyQuintic(f)

    # Eqn standardized. Algo for solving starts here
    if not f.is_irreducible:
        return result

    f20 = quintic.f20
    # Check if f20 has linear factors over domain Z
    if f20.is_irreducible:
        return result

    # theta is the root of the first linear factor of f20
    for _factor in f20.factor_list()[1]:
        if _factor[0].is_linear:
            theta = _factor[0].root(0)
            break
    else:
        # f20 is reducible but has no linear factor: theta cannot be
        # determined, so give up instead of hitting an unbound name below.
        return result
    d = discriminant(f)
    delta = sqrt(d)
    # zeta = a fifth root of unity
    zeta1, zeta2, zeta3, zeta4 = quintic.zeta
    T = quintic.T(theta, d)
    tol = S(1e-10)
    alpha = T[1] + T[2]*delta
    alpha_bar = T[1] - T[2]*delta
    beta = T[3] + T[4]*delta
    beta_bar = T[3] - T[4]*delta

    disc = alpha**2 - 4*beta
    disc_bar = alpha_bar**2 - 4*beta_bar

    l0 = quintic.l0(theta)

    l1 = _quintic_simplify((-alpha + sqrt(disc)) / S(2))
    l4 = _quintic_simplify((-alpha - sqrt(disc)) / S(2))

    l2 = _quintic_simplify((-alpha_bar + sqrt(disc_bar)) / S(2))
    l3 = _quintic_simplify((-alpha_bar - sqrt(disc_bar)) / S(2))

    order = quintic.order(theta, d)
    test = (order*delta.n()) - ( (l1.n() - l4.n())*(l2.n() - l3.n()) )
    # Comparing floats
    if not comp(test, 0, tol):
        l2, l3 = l3, l2

    # Now we have correct order of l's
    R1 = l0 + l1*zeta1 + l2*zeta2 + l3*zeta3 + l4*zeta4
    R2 = l0 + l3*zeta1 + l1*zeta2 + l4*zeta3 + l2*zeta4
    R3 = l0 + l2*zeta1 + l4*zeta2 + l1*zeta3 + l3*zeta4
    R4 = l0 + l4*zeta1 + l3*zeta2 + l2*zeta3 + l1*zeta4

    # Index 0 is a placeholder so that Res[k] corresponds to Rk.
    Res = [None, [None]*5, [None]*5, [None]*5, [None]*5]
    Res_n = [None, [None]*5, [None]*5, [None]*5, [None]*5]
    sol = Symbol('sol')

    # Simplifying improves performance a lot for exact expressions
    R1 = _quintic_simplify(R1)
    R2 = _quintic_simplify(R2)
    R3 = _quintic_simplify(R3)
    R4 = _quintic_simplify(R4)

    # Solve imported here. Causing problems if imported as 'solve'
    # and hence the changed name
    from sympy.solvers.solvers import solve as _solve
    a, b = symbols('a b', cls=Dummy)
    _sol = _solve( sol**5 - a - I*b, sol)
    for i in range(5):
        _sol[i] = factor(_sol[i])
    R1 = R1.as_real_imag()
    R2 = R2.as_real_imag()
    R3 = R3.as_real_imag()
    R4 = R4.as_real_imag()

    for i, currentroot in enumerate(_sol):
        Res[1][i] = _quintic_simplify(currentroot.subs({ a: R1[0], b: R1[1] }))
        Res[2][i] = _quintic_simplify(currentroot.subs({ a: R2[0], b: R2[1] }))
        Res[3][i] = _quintic_simplify(currentroot.subs({ a: R3[0], b: R3[1] }))
        Res[4][i] = _quintic_simplify(currentroot.subs({ a: R4[0], b: R4[1] }))

    for i in range(1, 5):
        for j in range(5):
            Res_n[i][j] = Res[i][j].n()
            Res[i][j] = _quintic_simplify(Res[i][j])
    r1 = Res[1][0]
    r1_n = Res_n[1][0]

    # r4 is the candidate whose product with r1 has (numerically) zero
    # imaginary part.
    for i in range(5):
        if comp(im(r1_n*Res_n[4][i]), 0, tol):
            r4 = Res[4][i]
            break
    else:
        # No candidate matched: fall back to the usual solve.
        return []

    # Now we have various Res values. Each will be a list of five
    # values. We have to pick one r value from those five for each Res
    u, v = quintic.uv(theta, d)
    testplus = (u + v*delta*sqrt(5)).n()
    testminus = (u - v*delta*sqrt(5)).n()

    # Evaluated numbers suffixed with _n
    # We will use evaluated numbers for calculation. Much faster.
    r4_n = r4.n()
    r2 = r3 = None

    for i in range(5):
        r2temp_n = Res_n[2][i]
        for j in range(5):
            # Again storing away the exact number and using
            # evaluated numbers in computations
            r3temp_n = Res_n[3][j]
            if (comp((r1_n*r2temp_n**2 + r4_n*r3temp_n**2 - testplus).n(), 0, tol) and
                comp((r3temp_n*r1_n**2 + r2temp_n*r4_n**2 - testminus).n(), 0, tol)):
                r2 = Res[2][i]
                r3 = Res[3][j]
                break
        # Compare against the sentinel explicitly: a matching r2 may be an
        # expression that is falsy (e.g. zero).
        if r2 is not None:
            break
    else:
        # No (r2, r3) pair satisfied both tests: fall back to the usual solve.
        return []

    # Now, we have r's so we can get roots
    x1 = (r1 + r2 + r3 + r4)/5
    x2 = (r1*zeta4 + r2*zeta3 + r3*zeta2 + r4*zeta1)/5
    x3 = (r1*zeta3 + r2*zeta1 + r3*zeta4 + r4*zeta2)/5
    x4 = (r1*zeta2 + r2*zeta4 + r3*zeta1 + r4*zeta3)/5
    x5 = (r1*zeta1 + r2*zeta2 + r3*zeta3 + r4*zeta4)/5
    result = [x1, x2, x3, x4, x5]

    # Now check if solutions are distinct

    saw = set()
    for r in result:
        r = r.n(2)
        if r in saw:
            # Roots were identical. Abort, return []
            # and fall back to usual solve
            return []
        saw.add(r)
    return result
Esempio n. 10
0
def test_integrate_hyperexponential():
    """Check ``integrate_hyperexponential(a, d, DE)`` against expected
    3-tuples for several hand-built ``DifferentialExtension`` towers.

    The last tuple element is asserted ``True`` in the cases with a zero or
    plain-expression middle element, and ``False`` in the cases whose middle
    element is a ``NonElementaryIntegral``.
    """
    # TODO: Add tests for integrate_hyperexponential() from the book
    a = Poly((1 + 2*t1 + t1**2 + 2*t1**3)*t**2 + (1 + t1**2)*t + 1 + t1**2, t)
    d = Poly(1, t)
    # Tower: t1 = tan(x), t = exp(tan(x))
    DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(1 + t1**2, t1),
        Poly(t*(1 + t1**2), t)], 'Tfuncs': [tan, Lambda(i, exp(tan(i)))]})
    assert integrate_hyperexponential(a, d, DE) == \
        (exp(2*tan(x))*tan(x) + exp(tan(x)), 1 + t1**2, True)
    a = Poly((t1**3 + (x + 1)*t1**2 + t1 + x + 2)*t, t)
    assert integrate_hyperexponential(a, d, DE) == \
        ((x + tan(x))*exp(tan(x)), 0, True)

    a = Poly(t, t)
    d = Poly(1, t)
    # Tower: t = exp(x**2); the integral is reported as non-elementary
    DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(2*x*t, t)],
        'Tfuncs': [Lambda(i, exp(x**2))]})

    assert integrate_hyperexponential(a, d, DE) == \
        (0, NonElementaryIntegral(exp(x**2), x), False)

    # Tower: t = exp(x)
    DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t, t)], 'Tfuncs': [exp]})
    assert integrate_hyperexponential(a, d, DE) == (exp(x), 0, True)

    a = Poly(25*t**6 - 10*t**5 + 7*t**4 - 8*t**3 + 13*t**2 + 2*t - 1, t)
    d = Poly(25*t**6 + 35*t**4 + 11*t**2 + 1, t)
    assert integrate_hyperexponential(a, d, DE) == \
        (-(11 - 10*exp(x))/(5 + 25*exp(2*x)) + log(1 + exp(2*x)), -1, True)
    # Tower: t0 = exp(x), t = exp(exp(x))
    DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t0, t0), Poly(t0*t, t)],
        'Tfuncs': [exp, Lambda(i, exp(exp(i)))]})
    assert integrate_hyperexponential(Poly(2*t0*t**2, t), Poly(1, t), DE) == (exp(2*exp(x)), 0, True)

    # Tower: t0 = exp(x), t = exp(-exp(x))
    DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t0, t0), Poly(-t0*t, t)],
        'Tfuncs': [exp, Lambda(i, exp(-exp(i)))]})
    assert integrate_hyperexponential(Poly(-27*exp(9) - 162*t0*exp(9) +
    27*x*t0*exp(9), t), Poly((36*exp(18) + x**2*exp(18) - 12*x*exp(18))*t, t), DE) == \
        (27*exp(exp(x))/(-6*exp(9) + x*exp(9)), 0, True)

    # Tower: t = exp(x)
    DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t, t)], 'Tfuncs': [exp]})
    assert integrate_hyperexponential(Poly(x**2/2*t, t), Poly(1, t), DE) == \
        ((2 - 2*x + x**2)*exp(x)/2, 0, True)
    assert integrate_hyperexponential(Poly(1 + t, t), Poly(t, t), DE) == \
        (-exp(-x), 1, True)  # x - exp(-x)
    assert integrate_hyperexponential(Poly(x, t), Poly(t + 1, t), DE) == \
        (0, NonElementaryIntegral(x/(1 + exp(x)), x), False)

    # Tower: t0 = log(x), t1 = exp(x**2)
    DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(1/x, t0), Poly(2*x*t1, t1)],
        'Tfuncs': [log, Lambda(i, exp(i**2))]})

    elem, nonelem, b = integrate_hyperexponential(Poly((8*x**7 - 12*x**5 + 6*x**3 - x)*t1**4 +
        (8*t0*x**7 - 8*t0*x**6 - 4*t0*x**5 + 2*t0*x**3 + 2*t0*x**2 - t0*x +
        24*x**8 - 36*x**6 - 4*x**5 + 22*x**4 + 4*x**3 - 7*x**2 - x + 1)*t1**3
        + (8*t0*x**8 - 4*t0*x**6 - 16*t0*x**5 - 2*t0*x**4 + 12*t0*x**3 +
        t0*x**2 - 2*t0*x + 24*x**9 - 36*x**7 - 8*x**6 + 22*x**5 + 12*x**4 -
        7*x**3 - 6*x**2 + x + 1)*t1**2 + (8*t0*x**8 - 8*t0*x**6 - 16*t0*x**5 +
        6*t0*x**4 + 10*t0*x**3 - 2*t0*x**2 - t0*x + 8*x**10 - 12*x**8 - 4*x**7
        + 2*x**6 + 12*x**5 + 3*x**4 - 9*x**3 - x**2 + 2*x)*t1 + 8*t0*x**7 -
        12*t0*x**6 - 4*t0*x**5 + 8*t0*x**4 - t0*x**2 - 4*x**7 + 4*x**6 +
        4*x**5 - 4*x**4 - x**3 + x**2, t1), Poly((8*x**7 - 12*x**5 + 6*x**3 -
        x)*t1**4 + (24*x**8 + 8*x**7 - 36*x**6 - 12*x**5 + 18*x**4 + 6*x**3 -
        3*x**2 - x)*t1**3 + (24*x**9 + 24*x**8 - 36*x**7 - 36*x**6 + 18*x**5 +
        18*x**4 - 3*x**3 - 3*x**2)*t1**2 + (8*x**10 + 24*x**9 - 12*x**8 -
        36*x**7 + 6*x**6 + 18*x**5 - x**4 - 3*x**3)*t1 + 8*x**10 - 12*x**8 +
        6*x**6 - x**4, t1), DE)

    # The elementary part is compared after factoring so that the check does
    # not depend on the exact form in which it is returned.
    assert factor(elem) == -((x - 1)*log(x)/((x + exp(x**2))*(2*x**2 - 1)))
    assert (nonelem, b) == (NonElementaryIntegral(exp(x**2)/(exp(x**2) + 1), x), False)
Esempio n. 11
0
 def r(a, b, c):
     """Factor the quadratic ``a*x**2 + b*x + c``; ``x`` and ``factor`` come
     from the surrounding scope."""
     quadratic = a * x**2 + b * x + c
     return factor(quadratic)
Esempio n. 12
0
def _solve_lambert(f, symbol, gens):
    """Return solution to ``f`` if it is a Lambert-type expression
    else raise NotImplementedError.

    The equality, ``f(x, a..f) = a*log(b*X + c) + d*X - f = 0`` has the
    solution,  `X = -c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(f/a))`. There
    are a variety of forms for `f(X, a..f)` as enumerated below:

    1a1)
      if B**B = R for R not [0, 1] then
      log(B) + log(log(B)) = log(log(R))
      X = log(B), a = 1, b = 1, c = 0, d = 1, f = log(log(R))
    1a2)
      if B*(b*log(B) + c)**a = R then
      log(B) + a*log(b*log(B) + c) = log(R)
      X = log(B); d=1, f=log(R)
    1b)
      if a*log(b*B + c) + d*B = R then
      X = B, f = R
    2a)
      if (b*B + c)*exp(d*B + g) = R then
      log(b*B + c) + d*B + g = log(R)
      a = 1, f = log(R) - g, X = B
    2b)
      if -b*B + g*exp(d*B + h) = c then
      log(g) + d*B + h - log(b*B + c) = 0
      a = -1, f = -h - log(g), X = B
    3)
      if d*p**(a*B + g) - b*B = c then
      log(d) + (a*B + g)*log(p) - log(c + b*B) = 0
      a = -1, d = a*log(p), f = -log(d) - g*log(p)

    Parameters: ``f`` is the expression taken to equal zero, ``symbol`` is
    the unknown, and ``gens`` is an iterable of expressions that must contain
    at least one ``exp``, ``log`` or power whose exponent involves ``symbol``.

    Returns the solutions found by ``_lambert`` as ``list(ordered(soln))``.

    Raises NotImplementedError when ``gens`` has no suitable entry or when
    none of the three stages below produces a solution.
    """

    # Split off the part of ``f`` that is free of ``symbol``; it is moved to
    # the right-hand side with its sign flipped.
    nrhs, lhs = f.as_independent(symbol, as_Add=True)
    rhs = -nrhs

    # Quick rejection: nothing in ``gens`` looks Lambert-like.
    lamcheck = [tmp for tmp in gens
                if (tmp.func in [exp, log] or
                (tmp.is_Pow and symbol in tmp.exp.free_symbols))]
    if not lamcheck:
        raise NotImplementedError()

    # A product on the left is turned into a sum by taking logs of both sides.
    if lhs.is_Mul:
        lhs = expand_log(log(lhs))
        rhs = log(rhs)

    lhs = factor(lhs, deep=True)
    # make sure we are inverted as completely as possible
    r = Dummy()
    i, lhs = _invert(lhs - r, symbol)
    rhs = i.xreplace({r: rhs})

    # For the first ones:
    # 1a1) B**B = R != 0 (when 0, there is only a solution if the base is 0,
    #                     but if it is, the exp is 0 and 0**0=1
    #                     comes back as B*log(B) = log(R)
    # 1a2) B*(a + b*log(B))**p = R or with monomial expanded or with whole
    #                              thing expanded comes back unchanged
    #     log(B) + p*log(a + b*log(B)) = log(R)
    #     lhs is Mul:
    #         expand log of both sides to give:
    #         log(B) + log(log(B)) = log(log(R))
    # 1b) d*log(a*B + b) + c*B = R
    #     lhs is Add:
    #         isolate c*B and expand log of both sides:
    #         log(c) + log(B) = log(R - d*log(a*B + b))

    soln = []
    # Always true at this point (``soln`` was just initialised); the guard
    # mirrors the ones on the later stages.
    if not soln:
        mainlog = _mostfunc(lhs, log, symbol)
        if mainlog:
            if lhs.is_Mul and rhs != 0:
                soln = _lambert(log(lhs) - log(rhs), symbol)
            elif lhs.is_Add:
                other = lhs.subs(mainlog, 0)
                if other and not other.is_Add and [
                        tmp for tmp in other.atoms(Pow)
                        if symbol in tmp.free_symbols]:
                    if not rhs:
                        diff = log(other) - log(other - lhs)
                    else:
                        diff = log(lhs - other) - log(rhs - other)
                    soln = _lambert(expand_log(diff), symbol)
                else:
                    #it's ready to go
                    soln = _lambert(lhs - rhs, symbol)

    # For the next two,
    #     collect on main exp
    #     2a) (b*B + c)*exp(d*B + g) = R
    #         lhs is mul:
    #             log to give
    #             log(b*B + c) + d*B = log(R) - g
    #     2b) -b*B + g*exp(d*B + h) = R
    #         lhs is add:
    #             add b*B
    #             log and rearrange
    #             log(R + b*B) - d*B = log(g) + h

    if not soln:
        mainexp = _mostfunc(lhs, exp, symbol)
        if mainexp:
            lhs = collect(lhs, mainexp)
            if lhs.is_Mul and rhs != 0:
                soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol)
            elif lhs.is_Add:
                # move all but mainexp-containing term to rhs
                other = lhs.subs(mainexp, 0)
                mainterm = lhs - other
                # NOTE: ``rhs`` is rebound here; if this stage yields no
                # solution, stage 3 below sees the modified value.
                rhs=rhs - other
                if (mainterm.could_extract_minus_sign() and
                    rhs.could_extract_minus_sign()):
                    mainterm *= -1
                    rhs *= -1
                diff = log(mainterm) - log(rhs)
                soln = _lambert(expand_log(diff), symbol)

    # 3) d*p**(a*B + b) + c*B = R
    #     collect on main pow
    #     log(R - c*B) - a*B*log(p) = log(d) + b*log(p)

    if not soln:
        mainpow = _mostfunc(lhs, Pow, symbol)
        if mainpow and symbol in mainpow.exp.free_symbols:
            lhs = collect(lhs, mainpow)
            if lhs.is_Mul and rhs != 0:
                soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol)
            elif lhs.is_Add:
                # move all but mainpow-containing term to rhs
                other = lhs.subs(mainpow, 0)
                mainterm = lhs - other
                rhs = rhs - other
                diff = log(mainterm) - log(rhs)
                soln = _lambert(expand_log(diff), symbol)

    if not soln:
        raise NotImplementedError('%s does not appear to have a solution in '
            'terms of LambertW' % f)

    return list(ordered(soln))
Esempio n. 13
0
def equivalence_hypergeometric(A, B, func):
    """Test whether the equation described by ``A``, ``B`` and ``func``
    matches the 2F1 type.

    Returns a dict with keys ``'I0'``, ``'k'``, ``'sing_point'`` and
    ``'type'`` (always ``"2F1"``) when ``equivalence`` reports ``"2F1"``,
    otherwise ``None``.
    """
    # This method for finding the equivalence is only for 2F1 type.
    # We can extend it for 1F1 and 0F1 type also.
    x = func.args[0]

    # Normal form of the given equation.
    I1 = factor(cancel(A.diff(x) / 2 + A**2 / 4 - B))

    # Shifted invariant J1 of the equation.
    J1 = factor(cancel(x**2 * I1 + S(1) / 4))
    j1_num, j1_den = J1.as_numer_denom()
    j1_num = powdenest(expand(j1_num))
    j1_den = powdenest(expand(j1_den))

    # Collect the different powers of x occurring in the given expressions.
    # They determine k, the power of x such that J1 = x**k * J0(x**k) has
    # only integer powers in J0.
    def _power_counting(exprs):
        found = {0}
        for term in exprs:
            if not term.has(x):
                continue
            is_power_of_x = isinstance(term, Pow) and term.as_base_exp()[0] == x
            if is_power_of_x or term == x:
                found.add(term.as_base_exp()[1])
            else:
                found.update(_power_counting(term.args))
        return found

    num_powers = _power_counting((j1_num, ))
    all_powers = _power_counting((j1_den, ))
    all_powers |= num_powers

    k = gcd(all_powers)

    # Compute I0 of the given equation.
    shifted = ((J1 / k**2) - S(1) / 4) / ((x**k)**2)
    I0 = powdenest(simplify(factor(shifted)), force=True)
    I0 = factor(cancel(powdenest(I0.subs(x, x**(S(1) / k)), force=True)))
    i0_num, i0_den = I0.as_numer_denom()

    max_num_pow = max(_power_counting((i0_num, )))

    # Singular points of I0 and the powers of the matching denominator
    # factors.
    sing_point = []
    dem_pow = []
    for arg in i0_den.args:
        if not arg.has(x):
            continue
        base, exponent = arg.as_base_exp()
        dem_pow.append(exponent)
        # (x-a)**n contributes the root of its base, (x-a) the root of itself
        target = base if isinstance(arg, Pow) else arg
        sing_point.append(list(roots(target, x).keys())[0])

    dem_pow.sort()

    # Check whether the equivalence exists.
    if equivalence(max_num_pow, dem_pow) == "2F1":
        return {'I0': I0, 'k': k, 'sing_point': sing_point, 'type': "2F1"}
    return None
Esempio n. 14
0
def match_2nd_2F1_hypergeometric(I, k, sing_point, func):
    """Recover the 2F1 parameters ``a``, ``b``, ``c`` from the expression ``I``.

    The singular points listed in ``sing_point`` are first moved by a Mobius
    transformation (the identity when they already are ``[0, 1]``).  The
    transformed ``I`` is then compared, power by power in ``x``, with the
    ``I0`` of the standard hypergeometric equation

        x*(x-1)*f'' + ((a+b+1)*x - c)*f' + a*b*f = 0

    and the resulting relations are solved for ``a``, ``b`` and ``c``.

    Parameters
    ==========

    I : Expr
        Expression in ``x = func.args[0]`` that is to be brought into the
        form of the standard ``I0``.
    k
        Passed through unchanged into the result under the key ``'k'``.
    sing_point : list
        Singular points of ``I``.  Unless it equals ``[0, 1]`` it must hold
        at least two entries; a third entry is used when present.
    func : applied function
        Only ``func.args[0]`` (the independent variable) is read.

    Returns
    =======

    dict
        Keys ``'a'``, ``'b'``, ``'c'`` (simplified parameters), ``'k'``,
        ``'mobius'`` (the transformation that was applied) and ``'type'``
        (always ``"2F1"``).
    """
    x = func.args[0]
    a = Wild("a")
    b = Wild("b")
    c = Wild("c")
    s = Wild("s")
    r = Wild("r")
    alpha = Wild("alpha")
    beta = Wild("beta")
    gamma = Wild("gamma")
    delta = Wild("delta")

    # I0 of the standard 2F1 equation.
    I0 = ((a - b + 1) * (a - b - 1) * x**2 + 2 *
          ((1 - a - b) * c + 2 * a * b) * x + c * (c - 2)) / (4 * x**2 *
                                                              (x - 1)**2)
    if sing_point != [0, 1]:
        # If the singular points are [0, 1] we already have the standard
        # equation; otherwise build the Mobius map
        #     mob(x) = (alpha*x + beta)/(gamma*x + delta)
        # that sends sing_point[0] -> 0 and sing_point[1] -> oo, i.e.
        #     -beta/alpha == sing_point[0],  -delta/gamma == sing_point[1].
        # A third singular point is sent to 1,
        #     (delta - beta)/(alpha - gamma) == sing_point[2];
        # with only two points the third one is taken at infinity, which
        # forces gamma == alpha.
        _beta = -alpha * sing_point[0]
        _delta = -gamma * sing_point[1]
        _gamma = alpha
        if len(sing_point) == 3:
            _gamma = (_beta + sing_point[2] * alpha) / (sing_point[2] -
                                                        sing_point[1])
        # The substitution order matters: _delta still contains gamma, so
        # gamma has to be replaced last.
        mob = (alpha * x + beta) / (gamma * x + delta)
        mob = mob.subs(beta, _beta)
        mob = mob.subs(delta, _delta)
        mob = mob.subs(gamma, _gamma)
        mob = cancel(mob)
        # t is the inverse of the Mobius map above.
        t = (beta - delta * x) / (gamma * x - alpha)
        t = cancel(((t.subs(beta, _beta)).subs(delta,
                                               _delta)).subs(gamma, _gamma))
    else:
        mob = x
        t = x

    # applying the Mobius transformation to I to turn it into I0.
    I = I.subs(x, t)
    I = I * (t.diff(x))**2
    I = factor(I)
    dict_I = {x**2: 0, x: 0, 1: 0}
    _, I0_dem = I0.as_numer_denom()
    # collecting the coefficients of (x**2, x, 1) of the standard equation,
    # substituting (a-b) = s, (a+b) = r
    dict_I0 = {
        x**2: s**2 - 1,
        x: (2 * (1 - r) * c + (r + s) * (r - s)),
        1: c * (c - 2)
    }
    # collecting the coefficients of (x**2, x) from I0 of the given equation.
    dict_I.update(
        collect(expand(cancel(I * I0_dem)), [x**2, x], evaluate=False))
    eqs = []
    # We are comparing the coefficients of the powers of x to find the values
    # of the parameters of the standard equation.
    for key in [x**2, x, 1]:
        eqs.append(Eq(dict_I[key], dict_I0[key]))

    # We can have many possible roots for the equation.
    # The root is selected on the basis that when we have the
    # standard equation eq = x*(x-1)*f(x).diff(x, 2) + ((a+b+1)*x-c)*f(x).diff(x) + a*b*f(x)
    # then the roots should be a, b, c.

    _c = 1 - factor(sqrt(1 + eqs[2].lhs))
    if not _c.has(Symbol):
        _c = min(list(roots(eqs[2], c)))
    _s = factor(sqrt(eqs[0].lhs + 1))
    _r = _c - factor(sqrt(_c**2 + _s**2 + eqs[1].lhs - 2 * _c))
    _a = (_r + _s) / 2
    _b = (_r - _s) / 2

    rn = {
        'a': simplify(_a),
        'b': simplify(_b),
        'c': simplify(_c),
        'k': k,
        'mobius': mob,
        'type': "2F1"
    }
    return rn
Esempio n. 15
0
def _bench_R9():
    "factor(x^20 - pi^5*y^20)"
    # Only the factorisation work matters here; the result is discarded.
    expr = x**20 - pi**5*y**20
    factor(expr)
Esempio n. 16
0
def test_factor():
    """``factor`` must hand back ``A*B - B*A`` unchanged."""
    commutator = A*B - B*A
    assert factor(commutator) == commutator
Esempio n. 17
0
def test_rsolve():
    """Check ``rsolve`` on a series of linear recurrences.

    For each recurrence the general solution and/or the solution for given
    initial conditions is compared with a closed form, and in most cases
    the general solution is substituted back into the recurrence to verify
    that the residual simplifies to zero.
    """

    def residual(rec):
        # Plug the general solution of ``rec`` back into ``rec``.
        general = rsolve(rec, y(n))
        return rec.subs(y, Lambda(k, general.subs(n, k))).simplify()

    # --- Fibonacci-type recurrence ---------------------------------------
    fib = y(n + 2) - y(n + 1) - y(n)
    root_minus = S.Half - S.Half*sqrt(5)
    root_plus = S.Half + S.Half*sqrt(5)
    fib_closed = sqrt(5)*root_plus**n - sqrt(5)*root_minus**n

    # Either labelling of the two arbitrary constants is acceptable.
    accepted = [
        C0*root_minus**n + C1*root_plus**n,
        C1*root_minus**n + C0*root_plus**n,
    ]
    assert rsolve(fib, y(n)) in accepted

    # The same initial conditions given as list, index dict and y(i) dict.
    for init in ([0, 5], {0: 0, 1: 5}, {y(0): 0, y(1): 5}):
        assert rsolve(fib, y(n), init) == fib_closed
    assert rsolve(y(n) - y(n - 1) - y(n - 2), y(n), [0, 5]) == fib_closed
    assert rsolve(Eq(y(n), y(n - 1) + y(n - 2)), y(n), [0, 5]) == fib_closed

    assert residual(fib) == 0

    # --- second order, polynomial coefficients ---------------------------
    rec = (n - 1)*y(n + 2) - (n**2 + 3*n - 2)*y(n + 1) + 2*n*(n + 1)*y(n)
    general = C1*factorial(n) + C0*2**n
    particular = -3*factorial(n) + 3*2**n

    assert rsolve(rec, y(n)) == general
    for empty_init in ([], {}):
        assert rsolve(rec, y(n), empty_init) == general

    for init in ([0, 3], {0: 0, 1: 3}, {y(0): 0, y(1): 3}):
        assert rsolve(rec, y(n), init) == particular

    assert residual(rec) == 0

    # --- first order, constant inhomogeneity -----------------------------
    rec = y(n) - y(n - 1) - 2

    for start, expected in ((0, 2*n), (1, 2*n + 1)):
        assert rsolve(rec, y(n), {y(0): start}) == expected
    assert rsolve(rec, y(n), {y(0): 0, y(1): 1}) is None

    assert residual(rec) == 0

    rec = 3*y(n - 1) - y(n) - 1

    expectations = (
        (0, -3**n/2 + S.Half),
        (1, 3**n/2 + S.Half),
        (2, 3*3**n/2 + S.Half),
    )
    for start, expected in expectations:
        assert rsolve(rec, y(n), {y(0): start}) == expected

    assert residual(rec) == 0

    # --- rational coefficients -------------------------------------------
    rec = y(n) - 1/n*y(n - 1)
    assert rsolve(rec, y(n)) == C0/factorial(n)
    assert residual(rec) == 0

    rec = y(n) - 1/n*y(n - 1) - 1
    assert rsolve(rec, y(n)) is None

    rec = 2*y(n - 1) + (1 - n)*y(n)/n

    for start in (1, 2, 3):
        assert rsolve(rec, y(n), {y(1): start}) == 2**(n - 1)*n*start

    assert residual(rec) == 0

    rec = (n - 1)*(n - 2)*y(n + 2) - (n + 1)*(n + 2)*y(n)

    assert rsolve(rec, y(n), {y(3): 6, y(4): 24}) == n*(n - 1)*(n - 2)
    alternating = rsolve(rec, y(n), {y(3): 6, y(4): -24})
    assert alternating == -n*(n - 1)*(n - 2)*(-1)**(n)

    assert residual(rec) == 0

    # --- symbolic coefficients -------------------------------------------
    geometric = rsolve(Eq(y(n + 1), a*y(n)), y(n), {y(1): a})
    assert geometric.simplify() == a**n

    two_step = rsolve(y(n) - a*y(n - 2), y(n),
                      {y(1): sqrt(a)*(a + b), y(2): a*(a - b)})
    assert two_step.simplify() == a**(n/2)*(-(-1)**n*b + a)

    rec = (-16*n**2 + 32*n - 12)*y(n - 1) + (4*n**2 - 12*n + 9)*y(n)

    yn = rsolve(rec, y(n), {y(1): binomial(2*n + 1, 3)})
    expected = 2**(2*n)*n*(2*n - 1)**2*(2*n + 1)/12
    assert factor(expand(yn, func=True)) == expected

    sol = rsolve(y(n) + a*(y(n + 1) + y(n - 1))/2, y(n))

    def shifted(index):
        # The solution evaluated at ``index`` instead of ``n``.
        return sol.subs(n, index)

    check = shifted(n) + a*(shifted(n + 1) + shifted(n - 1))/2
    assert check.expand().cancel() == 0

    # --- recurrences rsolve cannot handle --------------------------------
    assert rsolve((k + 1)*y(k), y(k)) is None
    unsolved = rsolve((k + 1)*y(k) + (k + 3)*y(k + 1) + (k + 5)*y(k + 2),
                      y(k))
    assert (unsolved is None)

    # --- exponential inhomogeneity ---------------------------------------
    mixed = rsolve(y(n) + y(n + 1) + 2**n + 3**n, y(n))
    assert mixed == (-1)**n*C0 - 2**n/3 - 3**n/4
Esempio n. 18
0
def _solve_lambert(f, symbol, gens):
    """Return solution to ``f`` if it is a Lambert-type expression
    else raise NotImplementedError.

    For ``f(X, a..f) = a*log(b*X + c) + d*X - f = 0`` the solution
    for ``X`` is ``X = -c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(f/a))``.
    There are a variety of forms for ``f(X, a..f)`` as enumerated below:

    1a1)
      if B**B = R for R not in [0, 1] (since those cases would already
      be solved before getting here) then log of both sides gives
      log(B) + log(log(B)) = log(log(R)) and
      X = log(B), a = 1, b = 1, c = 0, d = 1, f = log(log(R))
    1a2)
      if B*(b*log(B) + c)**a = R then log of both sides gives
      log(B) + a*log(b*log(B) + c) = log(R) and
      X = log(B), d=1, f=log(R)
    1b)
      if a*log(b*B + c) + d*B = R then
      X = B, f = R
    2a)
      if (b*B + c)*exp(d*B + g) = R then log of both sides gives
      log(b*B + c) + d*B + g = log(R) and
      X = B, a = 1, f = log(R) - g
    2b)
      if g*exp(d*B + h) - b*B = c then the log form is
      log(g) + d*B + h - log(b*B + c) = 0 and
      X = B, a = -1, f = -h - log(g)
    3)
      if d*p**(a*B + g) - b*B = c then the log form is
      log(d) + (a*B + g)*log(p) - log(b*B + c) = 0 and
      X = B, a = -1, d = a*log(p), f = -log(d) - g*log(p)

    Parameters
    ==========

    f : Expr
        The expression to solve; it is split into a part depending on
        ``symbol`` and a part independent of it, and ``f = 0`` is solved.

    symbol : Symbol
        The symbol for which a solution is being sought.

    gens : iterable
        Generators of ``f``. At least one of them must be an ``exp``, a
        ``log`` or a ``Pow`` with ``symbol`` in its exponent.

    Returns
    =======

    List of solutions, in the order produced by ``ordered``.

    Raises
    ======

    NotImplementedError
        If no generator qualifies or none of the attempted forms yields a
        solution.
    """
    def _solve_even_degree_expr(expr, t, symbol):
        """Return the unique solutions of equations derived from
        ``expr`` by replacing ``t`` with ``+/- symbol``.

        Parameters
        ==========

        expr : Expr
            The expression which includes a dummy variable t to be
            replaced with +symbol and -symbol.

        t : Dummy
            The placeholder in ``expr`` that marks where ``+/- symbol``
            is to be inserted.

        symbol : Symbol
            The symbol for which a solution is being sought.

        Returns
        =======

        List of unique solution of the two equations generated by
        replacing ``t`` with positive and negative ``symbol``.

        Notes
        =====

        If ``expr = 2*log(t) + x/2`` then solutions for
        ``2*log(x) + x/2 = 0`` and ``2*log(-x) + x/2 = 0`` are
        returned by this function. Though this may seem
        counter-intuitive, one must note that the ``expr`` being
        solved here has been derived from a different expression. For
        an expression like ``eq = x**2*g(x) = 1``, if we take the
        log of both sides we obtain ``log(x**2) + log(g(x)) = 0``. If
        x is positive then this simplifies to
        ``2*log(x) + log(g(x)) = 0``; the Lambert-solving routines will
        return solutions for this, but we must also consider the
        solutions for  ``2*log(-x) + log(g(x))`` since those must also
        be a solution of ``eq`` which has the same value when the ``x``
        in ``x**2`` is negated. If ``g(x)`` does not have even powers of
        symbol then we don't want to replace the ``x`` there with
        ``-x``. So the role of the ``t`` in the expression received by
        this function is to mark where ``+/-x`` should be inserted
        before obtaining the Lambert solutions.

        """
        nlhs, plhs = [expr.xreplace({t: sgn * symbol}) for sgn in (-1, 1)]
        sols = _solve_lambert(nlhs, symbol, gens)
        # only solve the second equation when it really differs
        if plhs != nlhs:
            sols.extend(_solve_lambert(plhs, symbol, gens))
        # uniq is needed for a case like
        # 2*log(t) - log(-z**2) + log(z + log(x) + log(z))
        # where substituting t with +/-x gives all the same solution;
        # uniq, rather than list(set()), is used to maintain canonical
        # order
        return list(uniq(sols))

    # split f into the symbol-free part (moved to the rhs) and the rest
    nrhs, lhs = f.as_independent(symbol, as_Add=True)
    rhs = -nrhs

    # generators that make a Lambert-type solution conceivable
    lamcheck = [
        tmp for tmp in gens if (tmp.func in [exp, log] or (
            tmp.is_Pow and symbol in tmp.exp.free_symbols))
    ]
    if not lamcheck:
        raise NotImplementedError()

    if lhs.is_Add or lhs.is_Mul:
        # replacing all even_degrees of symbol with dummy variable t
        # since these will need special handling; non-Add/Mul do not
        # need this handling
        t = Dummy('t', **symbol.assumptions0)
        lhs = lhs.replace(
            lambda i:  # find symbol**even
            i.is_Pow and i.base == symbol and i.exp.is_even,
            lambda i:  # replace t**even
            t**i.exp)

        if lhs.is_Add and lhs.has(t):
            # isolate the single t-dependent term and take logs of both sides
            t_indep = lhs.subs(t, 0)
            t_term = lhs - t_indep
            _rhs = rhs - t_indep
            if not t_term.is_Add and _rhs and not (t_term.has(
                    S.ComplexInfinity, S.NaN)):
                eq = expand_log(log(t_term) - log(_rhs))
                return _solve_even_degree_expr(eq, t, symbol)
        elif lhs.is_Mul and rhs:
            # this needs to happen whether t is present or not
            lhs = expand_log(log(lhs), force=True)
            rhs = log(rhs)
            if lhs.has(t) and lhs.is_Add:
                # it expanded from Mul to Add
                eq = lhs - rhs
                return _solve_even_degree_expr(eq, t, symbol)

        # restore symbol in lhs
        lhs = lhs.xreplace({t: symbol})

    lhs = powsimp(factor(lhs, deep=True))

    # make sure we have inverted as completely as possible
    r = Dummy()
    i, lhs = _invert(lhs - r, symbol)
    rhs = i.xreplace({r: rhs})

    # For the first forms:
    #
    # 1a1) B**B = R will arrive here as B*log(B) = log(R)
    #      lhs is Mul so take log of both sides:
    #        log(B) + log(log(B)) = log(log(R))
    # 1a2) B*(b*log(B) + c)**a = R will arrive unchanged so
    #      lhs is Mul, so take log of both sides:
    #        log(B) + a*log(b*log(B) + c) = log(R)
    # 1b) d*log(a*B + b) + c*B = R will arrive unchanged so
    #      lhs is Add, so isolate c*B and expand log of both sides:
    #        log(c) + log(B) = log(R - d*log(a*B + b))

    soln = []
    # soln is still empty here, so this guard always passes; it mirrors the
    # guards of the two later attempts
    if not soln:
        mainlog = _mostfunc(lhs, log, symbol)
        if mainlog:
            if lhs.is_Mul and rhs != 0:
                soln = _lambert(log(lhs) - log(rhs), symbol)
            elif lhs.is_Add:
                # what remains of lhs once the main log is set to zero
                other = lhs.subs(mainlog, 0)
                if other and not other.is_Add and [
                        tmp for tmp in other.atoms(Pow)
                        if symbol in tmp.free_symbols
                ]:
                    if not rhs:
                        diff = log(other) - log(other - lhs)
                    else:
                        diff = log(lhs - other) - log(rhs - other)
                    soln = _lambert(expand_log(diff), symbol)
                else:
                    #it's ready to go
                    soln = _lambert(lhs - rhs, symbol)

    # For the next forms,
    #
    #     collect on main exp
    #     2a) (b*B + c)*exp(d*B + g) = R
    #         lhs is mul, so take log of both sides:
    #           log(b*B + c) + d*B = log(R) - g
    #     2b) g*exp(d*B + h) - b*B = R
    #         lhs is add, so add b*B to both sides,
    #         take the log of both sides and rearrange to give
    #           log(R + b*B) - d*B = log(g) + h

    if not soln:
        mainexp = _mostfunc(lhs, exp, symbol)
        if mainexp:
            lhs = collect(lhs, mainexp)
            if lhs.is_Mul and rhs != 0:
                soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol)
            elif lhs.is_Add:
                # move all but mainexp-containing term to rhs
                other = lhs.subs(mainexp, 0)
                mainterm = lhs - other
                rhs = rhs - other
                # flip the sign of both sides when both could extract a
                # minus sign, before taking logs
                if (mainterm.could_extract_minus_sign()
                        and rhs.could_extract_minus_sign()):
                    mainterm *= -1
                    rhs *= -1
                diff = log(mainterm) - log(rhs)
                soln = _lambert(expand_log(diff), symbol)

    # For the last form:
    #
    #  3) d*p**(a*B + g) - b*B = c
    #     collect on main pow, add b*B to both sides,
    #     take log of both sides and rearrange to give
    #       a*B*log(p) - log(b*B + c) = -log(d) - g*log(p)
    if not soln:
        mainpow = _mostfunc(lhs, Pow, symbol)
        if mainpow and symbol in mainpow.exp.free_symbols:
            lhs = collect(lhs, mainpow)
            if lhs.is_Mul and rhs != 0:
                # b*B = 0
                soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol)
            elif lhs.is_Add:
                # move all but mainpow-containing term to rhs
                other = lhs.subs(mainpow, 0)
                mainterm = lhs - other
                rhs = rhs - other
                diff = log(mainterm) - log(rhs)
                soln = _lambert(expand_log(diff), symbol)

    if not soln:
        raise NotImplementedError('%s does not appear to have a solution in '
                                  'terms of LambertW' % f)

    return list(ordered(soln))
Esempio n. 19
0
 def _simplify(expr):
     """Factor ``expr`` over a composite domain, otherwise simplify it."""
     return factor(expr) if dom.is_Composite else simplify(expr)
Esempio n. 20
0
 def _simplify(expr):
     """Apply ``factor`` when ``dom`` is composite and ``simplify`` otherwise."""
     # Choose the reducer from the enclosing domain, then apply it once.
     reducer = factor if dom.is_Composite else simplify
     return reducer(expr)
Esempio n. 21
0
def test_gauss_opt():
    """Check the ray-matrix, geometric-ray and beam-parameter helpers.

    The assertions compare the objects built by the optics helpers with
    explicit matrices and closed-form expressions.
    """
    # A RayTransferMatrix can be built from four entries or from a Matrix
    # and exposes its entries as A, B, C, D.
    mat = RayTransferMatrix(1, 2, 3, 4)
    assert mat == Matrix([[1, 2], [3, 4]])
    assert mat == RayTransferMatrix(Matrix([[1, 2], [3, 4]]))
    assert [mat.A, mat.B, mat.C, mat.D] == [1, 2, 3, 4]

    # Each optical element must equal its explicit 2x2 matrix.
    d, f, h, n1, n2, R = symbols('d f h n1 n2 R')
    lens = ThinLens(f)
    assert lens == Matrix([[1, 0], [-1 / f, 1]])
    assert lens.C == -1 / f
    assert FreeSpace(d) == Matrix([[1, d], [0, 1]])
    assert FlatRefraction(n1, n2) == Matrix([[1, 0], [0, n1 / n2]])
    assert CurvedRefraction(R, n1, n2) == Matrix([[1, 0],
                                                  [(n1 - n2) / (R * n2),
                                                   n1 / n2]])
    assert FlatMirror() == Matrix([[1, 0], [0, 1]])
    assert CurvedMirror(R) == Matrix([[1, 0], [-2 / R, 1]])
    assert ThinLens(f) == Matrix([[1, 0], [-1 / f, 1]])

    # The product of two elements must agree entry-wise with the plain
    # matrix product.
    mul = CurvedMirror(R) * FreeSpace(d)
    mul_mat = Matrix([[1, 0], [-2 / R, 1]]) * Matrix([[1, d], [0, 1]])
    assert mul.A == mul_mat[0, 0]
    assert mul.B == mul_mat[0, 1]
    assert mul.C == mul_mat[1, 0]
    assert mul.D == mul_mat[1, 1]

    # Geometric rays: construction, propagation through free space, and
    # the height/angle accessors.
    angle = symbols('angle')
    assert GeometricRay(h, angle) == Matrix([[h], [angle]])
    assert FreeSpace(d) * GeometricRay(h, angle) == Matrix([[angle * d + h],
                                                            [angle]])
    assert GeometricRay(Matrix(((h, ), (angle, )))) == Matrix([[h], [angle]])
    assert (FreeSpace(d) * GeometricRay(h, angle)).height == angle * d + h
    assert (FreeSpace(d) * GeometricRay(h, angle)).angle == angle

    # Numeric beam parameter; values are compared through ``streq``.
    # NOTE(review): presumably the positional arguments are wavelength and
    # distance from the waist -- confirm against BeamParameter.
    p = BeamParameter(530e-9, 1, w=1e-3)
    assert streq(p.q, 1 + 1.88679245283019 * I * pi)
    assert streq(N(p.q), 1.0 + 5.92753330865999 * I)
    assert streq(N(p.w_0), Float(0.00100000000000000))
    assert streq(N(p.z_r), Float(5.92753330865999))
    fs = FreeSpace(10)
    p1 = fs * p
    assert streq(N(p.w), Float(0.00101413072159615))
    assert streq(N(p1.w), Float(0.00210803120913829))

    # Conversions between the two quantities named by the helpers.
    w, wavelen = symbols('w wavelen')
    assert waist2rayleigh(w, wavelen) == pi * w**2 / wavelen
    z_r, wavelen = symbols('z_r wavelen')
    assert rayleigh2waist(z_r, wavelen) == sqrt(wavelen * z_r) / sqrt(pi)

    # Geometric conjugation relations, including the limits at infinity.
    a, b, f = symbols('a b f')
    assert geometric_conj_ab(a, b) == a * b / (a + b)
    assert geometric_conj_af(a, f) == a * f / (a - f)
    assert geometric_conj_bf(b, f) == b * f / (b - f)
    assert geometric_conj_ab(oo, b) == b
    assert geometric_conj_ab(a, oo) == a

    # gaussian_conj returns a 3-tuple; each entry is checked separately.
    s_in, z_r_in, f = symbols('s_in z_r_in f')
    assert gaussian_conj(s_in, z_r_in,
                         f)[0] == 1 / (-1 / (s_in + z_r_in**2 /
                                             (-f + s_in)) + 1 / f)
    assert gaussian_conj(
        s_in, z_r_in, f)[1] == z_r_in / (1 - s_in**2 / f**2 + z_r_in**2 / f**2)
    assert gaussian_conj(
        s_in, z_r_in, f)[2] == 1 / sqrt(1 - s_in**2 / f**2 + z_r_in**2 / f**2)

    # conjugate_gauss_beams also returns a 3-tuple; the second entry is
    # factored before the comparison.
    l, w_i, w_o, f = symbols('l w_i w_o f')
    assert conjugate_gauss_beams(
        l, w_i, w_o, f=f)[0] == f * (-sqrt(w_i**2 / w_o**2 - pi**2 * w_i**4 /
                                           (f**2 * l**2)) + 1)
    assert factor(conjugate_gauss_beams(
        l, w_i, w_o,
        f=f)[1]) == f * w_o**2 * (w_i**2 / w_o**2 -
                                  sqrt(w_i**2 / w_o**2 - pi**2 * w_i**4 /
                                       (f**2 * l**2))) / w_i**2
    assert conjugate_gauss_beams(l, w_i, w_o, f=f)[2] == f

    # Symbolic beam parameter with positive symbols.
    # NOTE(review): the third symbol is created with the name 'r' but bound
    # to ``w``; this looks unintentional -- confirm.
    z, l, w = symbols('z l r', positive=True)
    p = BeamParameter(l, z, w=w)
    assert p.radius == z * (pi**2 * w**4 / (l**2 * z**2) + 1)
    assert p.w == w * sqrt(l**2 * z**2 / (pi**2 * w**4) + 1)
    assert p.w_0 == w
    assert p.divergence == l / (pi * w)
    assert p.gouy == atan2(z, pi * w**2 / l)
    assert p.waist_approximation_limit == 2 * l / pi