Пример #1
0
def test_log1p_opt():
    """Check that ``log1p_opt`` introduces ``log1p`` for two simple logs."""
    x = Symbol('x')

    # log(x + 1) is expected to become log1p(x).
    plain = optimize(log(x + 1), [log1p_opt])
    assert log1p(x) - plain == 0

    # log(3*x + 3) is expected to become log1p(x) + log(3).
    scaled = optimize(log(3 * x + 3), [log1p_opt])
    assert log1p(x) + log(3) == scaled
Пример #2
0
def test_subclass_print_method__ns():
    """A subclass that overrides ``_ns`` prints log1p with that prefix."""
    class MyPrinter(CXX11CodePrinter):
        _ns = "my_library::"

    printer_and_output = (
        (CXX11CodePrinter(), "std::log1p(x)"),
        (MyPrinter(), "my_library::log1p(x)"),
    )
    for printer, expected in printer_and_output:
        assert printer.doprint(log1p(x)) == expected
Пример #3
0
def test_subclass_print_method__ns():
    """Overriding ``_ns`` on a subclass changes the prefix in the output."""
    class MyPrinter(CXX11CodePrinter):
        _ns = 'my_library::'

    default_printer, custom_printer = CXX11CodePrinter(), MyPrinter()
    expr = log1p(x)

    # The stock printer uses std::, the subclass uses its own prefix.
    assert default_printer.doprint(expr) == 'std::log1p(x)'
    assert custom_printer.doprint(expr) == 'my_library::log1p(x)'
Пример #4
0
def test_optims_c99():
    """Run ``optimize`` with ``optims_c99`` on eight expressions.

    Each case checks the optimized form; most cases also check that
    rewriting the result back via exp/log/Pow recovers the input.
    """
    x = Symbol('x')

    # Case 1: exp2, log2, log1p and expm1 all appear in one sum.
    src = 2**x + log(x) / log(2) + log(x + 1) + exp(x) - 1
    got = optimize(src, optims_c99).simplify()
    assert got == exp2(x) + log2(x) + log1p(x) + expm1(x)
    assert got.rewrite(exp).rewrite(log).rewrite(Pow) == src

    # Case 2: log2 next to log1p.
    src = log(x) / log(2) + log(x + 1)
    got = optimize(src, optims_c99)
    assert got == log2(x) + log1p(x)
    assert got.rewrite(log) == src

    # Case 3: log(17*x + 17) is expected as log(17) + log1p(x).
    src = log(x) / log(2) + log(17 * x + 17)
    got = optimize(src, optims_c99)
    assert got - (log2(x) + log(17) + log1p(x)) == 0
    assert (got.rewrite(log) - src).simplify() == 0

    # Case 4: numeric coefficients in front of the rewritten terms.
    src = (2**x + 3 * log(5 * x + 7) / (13 * log(2)) + 11 * exp(x) - 11
           + log(17 * x + 17))
    got = optimize(src, optims_c99).simplify()
    wanted = (exp2(x) + 3 * log2(5 * x + 7) / 13 + 11 * expm1(x)
              + log(17) + log1p(x))
    assert got - wanted == 0
    assert (got.rewrite(exp).rewrite(log).rewrite(Pow) - src).simplify() == 0

    # Case 5: 3*exp(2*x) - 3 is expected as 3*expm1(2*x).
    src = 3 * exp(2 * x) - 3
    got = optimize(src, optims_c99)
    assert got - 3 * expm1(2 * x) == 0
    assert got.rewrite(exp) == src

    # Case 6: exp(2*x) - 3 is expected to stay as it is.
    src = exp(2 * x) - 3
    got = optimize(src, optims_c99)
    assert got - (exp(2 * x) - 3) == 0

    # Case 7: log(3*x + 3) is expected as log(3) + log1p(x).
    src = log(3 * x + 3)
    got = optimize(src, optims_c99)
    assert got - (log(3) + log1p(x)) == 0
    assert (got.rewrite(log) - src).simplify() == 0

    # Case 8: log(2*x + 3) is expected to come back unchanged.
    src = log(2 * x + 3)
    got = optimize(src, optims_c99)
    assert got == src
Пример #5
0
def test_subclass_print_method():
    """A ``_print_log1p`` method defined on a subclass is used by ``doprint``."""
    class MyPrinter(CXX11CodePrinter):
        def _print_log1p(self, expr):
            printed_args = [self._print(arg) for arg in expr.args]
            return 'my_library::log1p(%s)' % ', '.join(printed_args)

    printer = MyPrinter()
    assert printer.doprint(log1p(x)) == 'my_library::log1p(x)'
Пример #6
0
def test_optims_c99():
    """Check the output of ``optimize(..., optims_c99)`` for eight inputs."""
    x = Symbol('x')

    def c99(expr):
        # Shorthand for the call that every case below makes.
        return optimize(expr, optims_c99)

    # exp2, log2, log1p and expm1 together.
    expr1 = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1
    opt1 = c99(expr1).simplify()
    assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x)
    assert opt1.rewrite(exp).rewrite(log).rewrite(Pow) == expr1

    # log2 and log1p only.
    expr2 = log(x)/log(2) + log(x + 1)
    opt2 = c99(expr2)
    assert opt2 == log2(x) + log1p(x)
    assert opt2.rewrite(log) == expr2

    # A factor of 17 inside the logarithm.
    expr3 = log(x)/log(2) + log(17*x + 17)
    opt3 = c99(expr3)
    assert opt3 - (log2(x) + log(17) + log1p(x)) == 0
    assert (opt3.rewrite(log) - expr3).simplify() == 0

    # Numeric coefficients on the rewritten terms.
    expr4 = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17)
    opt4 = c99(expr4).simplify()
    ref4 = exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x)
    assert opt4 - ref4 == 0
    roundtrip4 = opt4.rewrite(exp).rewrite(log).rewrite(Pow)
    assert (roundtrip4 - expr4).simplify() == 0

    # 3*exp(2*x) - 3 is expected as 3*expm1(2*x).
    expr5 = 3*exp(2*x) - 3
    opt5 = c99(expr5)
    assert opt5 - 3*expm1(2*x) == 0
    assert opt5.rewrite(exp) == expr5

    # exp(2*x) - 3 is expected to be left alone.
    expr6 = exp(2*x) - 3
    opt6 = c99(expr6)
    assert opt6 - (exp(2*x) - 3) == 0

    # log(3*x + 3) is expected as log(3) + log1p(x).
    expr7 = log(3*x + 3)
    opt7 = c99(expr7)
    assert opt7 - (log(3) + log1p(x)) == 0
    assert (opt7.rewrite(log) - expr7).simplify() == 0

    # log(2*x + 3) is expected to come back unchanged.
    expr8 = log(2*x + 3)
    assert c99(expr8) == expr8
Пример #7
0
def test_CXX11CodePrinter():
    """Check log1p output and the metadata exposed by ``CXX11CodePrinter``."""
    assert CXX11CodePrinter().doprint(log1p(x)) == "std::log1p(x)"

    cxx11printer = CXX11CodePrinter()
    assert cxx11printer.language == "C++"
    assert cxx11printer.standard == "C++11"

    reserved = cxx11printer.reserved_words
    for keyword in ("operator", "noexcept"):
        assert keyword in reserved
    assert "concept" not in reserved
Пример #8
0
def test_C99CodePrinter__precision():
    """Check C99 output when ``real`` is aliased to float32/float64/float80.

    The expected strings in ``cases`` are templates: ``{s}`` is filled with
    the suffix paired with the printer ('f', '' or 'l') and ``{S}`` with the
    upper-case form of that suffix.
    """
    n = symbols('n', integer=True)
    f32_printer = C99CodePrinter({'type_aliases': {real: float32}})
    f64_printer = C99CodePrinter({'type_aliases': {real: float64}})
    f80_printer = C99CodePrinter({'type_aliases': {real: float80}})
    assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)'
    assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)'
    assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)'

    cases = [
        (Abs(n), 'abs(n)'),
        (Abs(x + 2.0), 'fabs{s}(x + 2.0{S})'),
        (sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))'),
        (exp(x*8.0), 'exp{s}(8.0{S}*x)'),
        (exp2(x), 'exp2{s}(x)'),
        (expm1(x*4.0), 'expm1{s}(4.0{S}*x)'),
        (Mod(n, 2), '((n) % (2))'),
        (Mod(2*n + 3, 3*n + 5), '((2*n + 3) % (3*n + 5))'),
        (Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})'),
        (Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})'),
        (log(x/2), 'log{s}((1.0{S}/2.0{S})*x)'),
        (log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)'),
        (log2(x*8.0), 'log2{s}(8.0{S}*x)'),
        (log1p(x), 'log1p{s}(x)'),
        (2**x, 'pow{s}(2, x)'),
        (2.0**x, 'pow{s}(2.0{S}, x)'),
        (x**3, 'pow{s}(x, 3)'),
        (x**4.0, 'pow{s}(x, 4.0{S})'),
        (sqrt(3+x), 'sqrt{s}(x + 3)'),
        (Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})'),
        (hypot(x, y), 'hypot{s}(x, y)'),
        (sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})'),
        (cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})'),
        (tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})'),
        (asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})'),
        (acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})'),
        (atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})'),
        (atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)'),

        (sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})'),
        (cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})'),
        (tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})'),
        (asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})'),
        (acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})'),
        (atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})'),
        (erf(42.*x), 'erf{s}(42.0{S}*x)'),
        (erfc(42.*x), 'erfc{s}(42.0{S}*x)'),
        (gamma(x), 'tgamma{s}(x)'),
        (loggamma(x), 'lgamma{s}(x)'),

        (ceiling(x + 2.), "ceil{s}(x + 2.0{S})"),
        (floor(x + 2.), "floor{s}(x + 2.0{S})"),
        (fma(x, y, -z), 'fma{s}(x, y, -z)'),
        (Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))'),
        (Min(x, 2.0), 'fmin{s}(2.0{S}, x)'),
    ]

    for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ['f', '', 'l']):
        for expr, template in cases:
            assert printer.doprint(expr) == template.format(s=suffix, S=suffix.upper())
Пример #9
0
def test_CXX11CodePrinter():
    """Check how log1p is printed, then the printer's metadata."""
    printed = CXX11CodePrinter().doprint(log1p(x))
    assert printed == 'std::log1p(x)'

    cxx11printer = CXX11CodePrinter()
    assert cxx11printer.language == 'C++'
    assert cxx11printer.standard == 'C++11'

    # Two words that must be reserved and one that must not be.
    words = cxx11printer.reserved_words
    assert 'operator' in words
    assert 'noexcept' in words
    assert 'concept' not in words
Пример #10
0
def test_C99CodePrinter__precision():
    """Compare printed C99 code for three ``real`` type aliases.

    Every template below is formatted with ``s`` set to the suffix paired
    with the printer and ``S`` set to that suffix in upper case.
    """
    n = symbols('n', integer=True)
    f32_printer = C99CodePrinter(dict(type_aliases={real: float32}))
    f64_printer = C99CodePrinter(dict(type_aliases={real: float64}))
    f80_printer = C99CodePrinter(dict(type_aliases={real: float80}))

    # Literal checks, one per printer.
    assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)'
    assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)'
    assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)'

    basic = (
        (Abs(n), 'abs(n)'),
        (Abs(x + 2.0), 'fabs{s}(x + 2.0{S})'),
        (sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))'),
        (exp(x*8.0), 'exp{s}(8.0{S}*x)'),
        (exp2(x), 'exp2{s}(x)'),
        (expm1(x*4.0), 'expm1{s}(4.0{S}*x)'),
        (Mod(n, 2), '((n) % (2))'),
        (Mod(2*n + 3, 3*n + 5), '((2*n + 3) % (3*n + 5))'),
        (Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})'),
        (Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})'),
        (log(x/2), 'log{s}((1.0{S}/2.0{S})*x)'),
        (log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)'),
        (log2(x*8.0), 'log2{s}(8.0{S}*x)'),
        (log1p(x), 'log1p{s}(x)'),
        (2**x, 'pow{s}(2, x)'),
        (2.0**x, 'pow{s}(2.0{S}, x)'),
        (x**3, 'pow{s}(x, 3)'),
        (x**4.0, 'pow{s}(x, 4.0{S})'),
        (sqrt(3+x), 'sqrt{s}(x + 3)'),
        (Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})'),
        (hypot(x, y), 'hypot{s}(x, y)'),
    )
    trigonometric = (
        (sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})'),
        (cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})'),
        (tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})'),
        (asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})'),
        (acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})'),
        (atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})'),
        (atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)'),
    )
    hyperbolic_and_special = (
        (sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})'),
        (cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})'),
        (tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})'),
        (asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})'),
        (acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})'),
        (atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})'),
        (erf(42.*x), 'erf{s}(42.0{S}*x)'),
        (erfc(42.*x), 'erfc{s}(42.0{S}*x)'),
        (gamma(x), 'tgamma{s}(x)'),
        (loggamma(x), 'lgamma{s}(x)'),
    )
    rounding_and_extrema = (
        (ceiling(x + 2.), "ceil{s}(x + 2.0{S})"),
        (floor(x + 2.), "floor{s}(x + 2.0{S})"),
        (fma(x, y, -z), 'fma{s}(x, y, -z)'),
        (Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))'),
        (Min(x, 2.0), 'fmin{s}(2.0{S}, x)'),
    )
    all_cases = basic + trigonometric + hyperbolic_and_special + rounding_and_extrema

    suffixed_printers = ((f32_printer, 'f'), (f64_printer, ''), (f80_printer, 'l'))
    for printer, suffix in suffixed_printers:
        upper = suffix.upper()
        for expr, ref in all_cases:
            assert printer.doprint(expr) == ref.format(s=suffix, S=upper)
Пример #11
0
def test_CXX11CodePrinter():
    """Check the C++11 printer: log1p output, metadata and reserved words."""
    assert CXX11CodePrinter().doprint(log1p(x)) == 'std::log1p(x)'

    cxx11printer = CXX11CodePrinter()
    for attribute, expected in (('language', 'C++'), ('standard', 'C++11')):
        assert getattr(cxx11printer, attribute) == expected

    assert 'operator' in cxx11printer.reserved_words
    assert 'noexcept' in cxx11printer.reserved_words
    assert not ('concept' in cxx11printer.reserved_words)
Пример #12
0
def test_optims_numpy_TODO():
    """Check two expressions against their expected ``optims_numpy`` form."""
    x, y = Symbol('x'), Symbol('y')

    # (input, expected result of optimize(input, optims_numpy))
    expected = [
        (log(x*y)*sin(x*y)*log(x*y+1)/(log(2)*x*y),
         log2(x*y)*sinc(x*y)*log1p(x*y)),
        (exp(x*sin(y)/y) - 1,
         expm1(x*sinc(y))),
    ]
    for source, target in expected:
        assert optimize(source, optims_numpy) == target
Пример #13
0
def test_optims_numpy():
    """``optims_numpy`` is expected to introduce sinc, expm1, log1p and log2."""
    x = Symbol('x')

    # Each pair is (input, expected optimized expression).
    cases = (
        (sin(2*x)/(2*x) + exp(2*x) - 1, sinc(2*x) + expm1(2*x)),
        (log(x+3)/log(2) + log(x**2 + 1), log1p(x**2) + log2(x+3)),
    )
    for original, optimized in cases:
        assert optimize(original, optims_numpy) == optimized
Пример #14
0
def test_log1p_opt():
    """Check ``log1p_opt`` on four logarithms, including one it must skip."""
    x = Symbol('x')

    def apply(expr):
        # Run the single optimization under test.
        return optimize(expr, [log1p_opt])

    # log(x + 1) -> log1p(x), and back again via rewrite(log).
    src = log(x + 1)
    res = apply(src)
    assert log1p(x) - res == 0
    assert res.rewrite(log) == src

    # log(3*x + 3) -> log1p(x) + log(3).
    src = log(3*x + 3)
    res = apply(src)
    assert log1p(x) + log(3) == res
    assert (res.rewrite(log) - src).simplify() == 0

    # log(2*x + 1) -> log1p(2*x).
    src = log(2*x + 1)
    res = apply(src)
    assert log1p(2*x) - res == 0
    assert res.rewrite(log) == src

    # log(x + 3) is expected to print unchanged.
    res = apply(log(x+3))
    assert str(res) == 'log(x + 3)'
Пример #15
0
def test_log1p_opt():
    """Check the result of ``log1p_opt`` for four logarithmic inputs."""
    x = Symbol('x')
    optims = [log1p_opt]

    # Constant term equal to one: plain log1p.
    unit_offset = log(x + 1)
    unit_result = optimize(unit_offset, optims)
    assert log1p(x) - unit_result == 0
    assert unit_result.rewrite(log) == unit_offset

    # A factor of three inside the log shows up as log(3).
    factored = log(3 * x + 3)
    factored_result = optimize(factored, optims)
    assert log1p(x) + log(3) == factored_result
    assert (factored_result.rewrite(log) - factored).simplify() == 0

    # Coefficient on x only: the argument of log1p becomes 2*x.
    doubled = log(2 * x + 1)
    doubled_result = optimize(doubled, optims)
    assert log1p(2 * x) - doubled_result == 0
    assert doubled_result.rewrite(log) == doubled

    # Constant term of three: the string form is expected to be unchanged.
    untouched = optimize(log(x + 3), optims)
    assert str(untouched) == 'log(x + 3)'
Пример #16
0
def test_optims_c99():
    """Check the output of ``optimize(..., optims_c99)`` for four sums.

    The expected results use exp2, log2, log1p and expm1 in place of the
    corresponding exp/log/Pow combinations in the inputs.
    """
    x = Symbol('x')

    # All four functions in one expression.
    expr1 = 2**x + log(x) / log(2) + log(x + 1) + exp(x) - 1
    opt1 = optimize(expr1, optims_c99).simplify()
    assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x)

    # log2 next to log1p.  (A stray debugging print() used to sit here; it
    # only wrote a blank line to stdout and has been removed.)
    expr2 = log(x) / log(2) + log(x + 1)
    opt2 = optimize(expr2, optims_c99)
    assert opt2 == log2(x) + log1p(x)

    # log(17*x + 17) is expected as log(17) + log1p(x).
    expr3 = log(x) / log(2) + log(17 * x + 17)
    opt3 = optimize(expr3, optims_c99)
    delta3 = opt3 - (log2(x) + log(17) + log1p(x))
    assert delta3 == 0

    # Numeric coefficients in front of the rewritten terms.
    expr4 = 2**x + 3 * log(5 * x + 7) / (13 * log(2)) + 11 * exp(x) - 11 + log(
        17 * x + 17)
    opt4 = optimize(expr4, optims_c99).simplify()
    delta4 = opt4 - (exp2(x) + 3 * log2(5 * x + 7) / 13 + 11 * expm1(x) +
                     log(17) + log1p(x))
    assert delta4 == 0
Пример #17
0
 def as_FunctionDefinition(variables=None, name="logsumexp"):
     """Build the ``FunctionDefinition`` node for a log-sum-exp routine.

     ``variables`` may map the names ``'n'``, ``'i'``, ``'s'`` and ``'x'`` to
     replacement symbols; any name that is missing gets a default.  ``name``
     is the name of the generated function, which takes ``(x, n)`` and is
     declared with return type ``"real"``.

     The generated body calls ``sympy_quicksort(x, n)``, accumulates
     ``exp(x[i] - x[n - 1])`` into ``s`` for every ``i`` below ``n - 1`` and
     returns ``log1p(s) + x[n - 1]``.
     """
     variables = variables or {}
     n = variables.get('n', Symbol('n', integer=True))
     i = variables.get('i', Symbol('i', integer=True))
     s = variables.get('s', Symbol('s'))
     x = variables.get('x', IndexedBase('x', shape=(n, )))
     # presumably sympy_quicksort sorts ascending, making x[n - 1] the
     # largest element -- TODO confirm against its definition.
     last = x[n - 1]
     body = CodeBlock(
         FunctionCall('sympy_quicksort', (x, n), statement=True),
         Declaration(i, 0), Declaration(s, 0.0),
         While(
             i < n - 1,
             # Every term is shifted by the last element:
             #   log(sum_i exp(x[i])) == last + log1p(sum_{i<n-1} exp(x[i] - last))
             # Summing the unshifted exp(x[i]) does not satisfy this identity.
             CodeBlock(AddAugmentedAssignment(s, exp(x[i] - last)),
                       AddAugmentedAssignment(i, 1))),
         ReturnStatement(log1p(s) + last))
     return FunctionDefinition("real", name, (x, n), body)
Пример #18
0
def test_C99CodePrinter():
    """Check printed C99 function names and the printer's metadata."""
    # (expression, expected C code); each one is printed by a fresh printer.
    expected_code = [
        (expm1(x), 'expm1(x)'),
        (log1p(x), 'log1p(x)'),
        (exp2(x), 'exp2(x)'),
        (log2(x), 'log2(x)'),
        (fma(x, y, -z), 'fma(x, y, -z)'),
        (log10(x), 'log10(x)'),
        (Cbrt(x), 'cbrt(x)'),  # note Cbrt due to cbrt already taken.
        (hypot(x, y), 'hypot(x, y)'),
        (loggamma(x), 'lgamma(x)'),
        (Max(x, 3, x**2), 'fmax(3, fmax(x, pow(x, 2)))'),
        (Min(x, 3), 'fmin(3, x)'),
    ]
    for expr, code in expected_code:
        assert C99CodePrinter().doprint(expr) == code

    c99printer = C99CodePrinter()
    assert c99printer.language == 'C'
    assert c99printer.standard == 'C99'
    assert 'restrict' in c99printer.reserved_words
    assert 'using' not in c99printer.reserved_words
Пример #19
0
def test_C99CodePrinter():
    """Print a set of functions with ``C99CodePrinter`` and check the output."""
    def c99(expr):
        # A new printer instance is used for every expression.
        return C99CodePrinter().doprint(expr)

    assert c99(expm1(x)) == 'expm1(x)'
    assert c99(log1p(x)) == 'log1p(x)'
    assert c99(exp2(x)) == 'exp2(x)'
    assert c99(log2(x)) == 'log2(x)'
    assert c99(fma(x, y, -z)) == 'fma(x, y, -z)'
    assert c99(log10(x)) == 'log10(x)'
    assert c99(Cbrt(x)) == 'cbrt(x)'  # note Cbrt due to cbrt already taken.
    assert c99(hypot(x, y)) == 'hypot(x, y)'
    assert c99(loggamma(x)) == 'lgamma(x)'
    assert c99(Max(x, 3, x**2)) == 'fmax(3, fmax(x, pow(x, 2)))'
    assert c99(Min(x, 3)) == 'fmin(3, x)'

    # Metadata and reserved words of the printer itself.
    c99printer = C99CodePrinter()
    assert c99printer.language == 'C'
    assert c99printer.standard == 'C99'
    reserved = c99printer.reserved_words
    assert 'restrict' in reserved
    assert 'using' not in reserved
Пример #20
0
def test_C99CodePrinter():
    """Check C99 output for several functions, then the printer metadata."""
    # Pairs of (expression, expected code).
    outputs = (
        (expm1(x), "expm1(x)"),
        (log1p(x), "log1p(x)"),
        (exp2(x), "exp2(x)"),
        (log2(x), "log2(x)"),
        (fma(x, y, -z), "fma(x, y, -z)"),
        (log10(x), "log10(x)"),
        # note Cbrt due to cbrt already taken.
        (Cbrt(x), "cbrt(x)"),
        (hypot(x, y), "hypot(x, y)"),
        (loggamma(x), "lgamma(x)"),
        (Max(x, 3, x ** 2), "fmax(3, fmax(x, pow(x, 2)))"),
        (Min(x, 3), "fmin(3, x)"),
    )
    for expression, printed in outputs:
        assert C99CodePrinter().doprint(expression) == printed

    c99printer = C99CodePrinter()
    assert c99printer.language == "C"
    assert c99printer.standard == "C99"
    assert "restrict" in c99printer.reserved_words
    assert "using" not in c99printer.reserved_words
Пример #21
0
def test_log1p():
    """The numpy-lambdified log1p must be accurate for a tiny argument."""
    if not np:
        skip("NumPy not installed")

    tiny = 1e-99
    numpy_log1p = lambdify((a,), log1p(a), 'numpy')
    # For an argument this small the result must match it to within 1e-100.
    assert abs(numpy_log1p(tiny) - tiny) < 1e-100
Пример #22
0
# Wildcard used by the replacement patterns below.
w = Wild('w')

# v*log(w)/log(2) -> v*log2(w).  The cost function counts Pow nodes with a
# negative exponent plus log/log2 calls whose argument is not a number.
log2_opt = ReplaceOptim(
    v * log(w) / log(2),
    v * log2(w),
    cost_function=lambda expr: expr.count(
        lambda e: ((isinstance(e, Pow) and e.exp.is_negative)
                   or (isinstance(e, (log, log2))
                       and not e.args[0].is_number))))

# log(2)*log2(w) -> log(w).
log2const_opt = ReplaceOptim(log(2) * log2(w), log(w))

# log(exp(a) + exp(b)) -> Max(a, b) + log1p(exp(Min(a, b) - Max(a, b))).
# The smaller exponent has to be shifted by the larger one inside exp();
# without the shift the replacement is not equal to the matched expression
# (for a=1, b=0 it would give 1 + log(2) instead of log(E + 1)).
logsumexp_2terms_opt = ReplaceOptim(
    lambda l: (isinstance(l, log)
               and isinstance(l.args[0], Add)
               and len(l.args[0].args) == 2
               and all(isinstance(t, exp) for t in l.args[0].args)),
    lambda l: (
        Max(*[e.args[0] for e in l.args[0].args])
        + log1p(exp(Min(*[e.args[0] for e in l.args[0].args])
                    - Max(*[e.args[0] for e in l.args[0].args])))))


def _partition(predicate, iterable):
    iter_a, iter_b = tee(iterable)
    return tuple(filter(predicate,
                        iter_a)), tuple(filterfalse(predicate, iter_b))


def _try_expm1(e):
    """Factor ``e``, then replace each ``exp(w) - 1`` match by ``expm1(w)``."""
    factored = e.factor()
    return factored.replace(exp(w) - 1, expm1(w))


def _expm1_value(e):
    def isnum(arg):
        return arg.is_number
Пример #23
0
    _v * log(_w) / log(2),
    _v * log2(_w),
    cost_function=lambda expr: expr.count(
        lambda e:
        (  # division & eval of transcendentals are expensive floating point operations...
            e.is_Pow and e.exp.is_negative  # division
            or (isinstance(e, (log, log2)) and not e.args[0].is_number)
        )  # transcendental
    ))

# log(2)*log2(_w) -> log(_w).
log2const_opt = ReplaceOptim(log(2) * log2(_w), log(_w))

# log(exp(a) + exp(b)) -> Max(a, b) + log1p(exp(Min(a, b) - Max(a, b))).
# The smaller exponent has to be shifted by the larger one inside exp();
# without the shift the replacement is not equal to the matched expression
# (for a=1, b=0 it would give 1 + log(2) instead of log(E + 1)).
logsumexp_2terms_opt = ReplaceOptim(
    lambda l: (isinstance(l, log) and l.args[0].is_Add and len(l.args[0].args)
               == 2 and all(isinstance(t, exp) for t in l.args[0].args)),
    lambda l: (Max(*[e.args[0] for e in l.args[0].args]) + log1p(
        exp(Min(*[e.args[0] for e in l.args[0].args])
            - Max(*[e.args[0] for e in l.args[0].args])))))


def _try_expm1(expr):
    """Factor ``expr`` with its exp() calls masked, then introduce expm1.

    Every ``exp(...)`` is first swapped for a fresh ``Dummy``; after
    factoring, matches of ``_d - 1`` are turned into ``expm1`` of the masked
    exponent and any remaining dummies are swapped back.
    """
    masked, exp_to_dummy = expr.replace(exp, lambda arg: Dummy(), map=True)
    dummy_to_exp = dict(
        (dummy, original) for original, dummy in exp_to_dummy.items())
    # NOTE(review): the lambda parameter is kept as ``d``; SymPy's ``replace``
    # passes pattern matches by keyword, presumably named after ``_d``.
    rewritten = masked.factor().replace(
        _d - 1, lambda d: expm1(dummy_to_exp[d].args[0]))
    return rewritten.xreplace(dummy_to_exp)


def _expm1_value(e):
    numbers, non_num = sift(e.args, lambda arg: arg.is_number, binary=True)
    non_num_exp, non_num_other = sift(non_num,
                                      lambda arg: arg.has(exp),
                                      binary=True)
Пример #24
0
def test_subclass_print_method():
    """``doprint`` must pick up a ``_print_log1p`` defined on a subclass."""
    class MyPrinter(CXX11CodePrinter):
        def _print_log1p(self, expr):
            # Print every argument with this printer, then join them.
            joined = ', '.join(self._print(arg) for arg in expr.args)
            return 'my_library::log1p(%s)' % joined

    output = MyPrinter().doprint(log1p(x))
    assert output == 'my_library::log1p(x)'
Пример #25
0
def test_log1p():
    """Evaluation, rewriting, precision, assumptions and derivative of log1p."""
    # Eval
    assert log1p(0) == 0
    ten = S(10)
    assert expand_log(log1p(ten**-1000) - log(ten**1000 + 1) + log(ten**1000)) == 0

    x = Symbol('x', real=True)
    plain_log = log(x + 1)

    # Expand and rewrite: each route must give log(x + 1).
    assert log1p(x).expand(func=True) - plain_log == 0
    for target in ('tractable', 'log'):
        assert log1p(x).rewrite(target) - plain_log == 0

    # Precision
    tiny = 1e-99
    assert not abs(log(tiny + 1).evalf() - tiny) < 1e-100  # for comparison
    assert abs(expand_log(log1p(tiny)).evalf() - tiny) < 1e-100

    # Properties
    assert log1p(-2**Rational(-1, 2)).is_real

    assert not log1p(-1).is_finite
    assert log1p(pi).is_finite

    # For each assumption: not established for the plain real x, but
    # established when the argument carries that assumption itself.
    for assumption, name in (('positive', 'y'), ('zero', 'z'), ('nonnegative', 'o')):
        flag = 'is_' + assumption
        assert not getattr(log1p(x), flag)
        assert getattr(log1p(Symbol(name, **{assumption: True})), flag)

    # Diff
    derivative = log1p(42 * x).diff(x)
    assert derivative - 42 / (42 * x + 1) == 0
    assert derivative - log1p(42 * x).expand(func=True).diff(x) == 0
Пример #26
0
    lambda e: (  # division & eval of transcendentals are expensive floating point operations...
        e.is_Pow and e.exp.is_negative  # division
        or (isinstance(e, (log, log2)) and not e.args[0].is_number))  # transcendental
    )
)

# log(2)*log2(_w) -> log(_w).
log2const_opt = ReplaceOptim(log(2)*log2(_w), log(_w))

# log(exp(a) + exp(b)) -> Max(a, b) + log1p(exp(Min(a, b) - Max(a, b))).
# The smaller exponent has to be shifted by the larger one inside exp();
# without the shift the replacement is not equal to the matched expression
# (for a=1, b=0 it would give 1 + log(2) instead of log(E + 1)).
logsumexp_2terms_opt = ReplaceOptim(
    lambda l: (isinstance(l, log)
               and l.args[0].is_Add
               and len(l.args[0].args) == 2
               and all(isinstance(t, exp) for t in l.args[0].args)),
    lambda l: (
        Max(*[e.args[0] for e in l.args[0].args]) +
        log1p(exp(Min(*[e.args[0] for e in l.args[0].args]) -
                  Max(*[e.args[0] for e in l.args[0].args])))
    )
)


def _try_expm1(expr):
    """Replace ``exp(...) - 1`` factors of ``expr`` with ``expm1(...)``.

    The exp() calls are swapped for ``Dummy`` symbols before ``factor`` is
    called; afterwards ``_d - 1`` matches become ``expm1`` of the original
    exponent and the remaining dummies are mapped back with ``xreplace``.
    """
    hidden, replaced = expr.replace(exp, lambda arg: Dummy(), map=True)
    restore = {}
    for original, placeholder in replaced.items():
        restore[placeholder] = original

    def with_expm1(candidate):
        # NOTE(review): the callable's parameter stays ``d``; SymPy's
        # ``replace`` passes pattern matches by keyword, presumably named
        # after ``_d``.
        return candidate.replace(_d - 1, lambda d: expm1(restore[d].args[0]))

    return with_expm1(hidden.factor()).xreplace(restore)


def _expm1_value(e):
    numbers, non_num = sift(e.args, lambda arg: arg.is_number, binary=True)
    non_num_exp, non_num_other = sift(non_num, lambda arg: arg.has(exp),
        binary=True)
Пример #27
0
def test_C99CodePrinter__precision():
    """Check C99 output for ``real`` aliased to float32, float64 and float80.

    ``expectations`` pairs each expression with a template in which ``{s}``
    is replaced by the printer's suffix ("f", "" or "l") and ``{S}`` by the
    same suffix in upper case.
    """
    n = symbols("n", integer=True)
    f32_printer = C99CodePrinter({"type_aliases": {real: float32}})
    f64_printer = C99CodePrinter({"type_aliases": {real: float64}})
    f80_printer = C99CodePrinter({"type_aliases": {real: float80}})
    assert f32_printer.doprint(sin(x + 2.1)) == "sinf(x + 2.1F)"
    assert f64_printer.doprint(sin(x + 2.1)) == "sin(x + 2.1000000000000001)"
    assert f80_printer.doprint(sin(x + Float("2.0"))) == "sinl(x + 2.0L)"

    expectations = [
        (Abs(n), "abs(n)"),
        (Abs(x + 2.0), "fabs{s}(x + 2.0{S})"),
        (
            sin(x + 4.0) ** cos(x - 2.0),
            "pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))",
        ),
        (exp(x * 8.0), "exp{s}(8.0{S}*x)"),
        (exp2(x), "exp2{s}(x)"),
        (expm1(x * 4.0), "expm1{s}(4.0{S}*x)"),
        (Mod(n, 2), "((n) % (2))"),
        (Mod(2 * n + 3, 3 * n + 5), "((2*n + 3) % (3*n + 5))"),
        (Mod(x + 2.0, 3.0), "fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})"),
        (Mod(x, 2.0 * x + 3.0), "fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})"),
        (log(x / 2), "log{s}((1.0{S}/2.0{S})*x)"),
        (log10(3 * x / 2), "log10{s}((3.0{S}/2.0{S})*x)"),
        (log2(x * 8.0), "log2{s}(8.0{S}*x)"),
        (log1p(x), "log1p{s}(x)"),
        (2 ** x, "pow{s}(2, x)"),
        (2.0 ** x, "pow{s}(2.0{S}, x)"),
        (x ** 3, "pow{s}(x, 3)"),
        (x ** 4.0, "pow{s}(x, 4.0{S})"),
        (sqrt(3 + x), "sqrt{s}(x + 3)"),
        (Cbrt(x - 2.0), "cbrt{s}(x - 2.0{S})"),
        (hypot(x, y), "hypot{s}(x, y)"),
        (sin(3.0 * x + 2.0), "sin{s}(3.0{S}*x + 2.0{S})"),
        (cos(3.0 * x - 1.0), "cos{s}(3.0{S}*x - 1.0{S})"),
        (tan(4.0 * y + 2.0), "tan{s}(4.0{S}*y + 2.0{S})"),
        (asin(3.0 * x + 2.0), "asin{s}(3.0{S}*x + 2.0{S})"),
        (acos(3.0 * x + 2.0), "acos{s}(3.0{S}*x + 2.0{S})"),
        (atan(3.0 * x + 2.0), "atan{s}(3.0{S}*x + 2.0{S})"),
        (atan2(3.0 * x, 2.0 * y), "atan2{s}(3.0{S}*x, 2.0{S}*y)"),
        (sinh(3.0 * x + 2.0), "sinh{s}(3.0{S}*x + 2.0{S})"),
        (cosh(3.0 * x - 1.0), "cosh{s}(3.0{S}*x - 1.0{S})"),
        (tanh(4.0 * y + 2.0), "tanh{s}(4.0{S}*y + 2.0{S})"),
        (asinh(3.0 * x + 2.0), "asinh{s}(3.0{S}*x + 2.0{S})"),
        (acosh(3.0 * x + 2.0), "acosh{s}(3.0{S}*x + 2.0{S})"),
        (atanh(3.0 * x + 2.0), "atanh{s}(3.0{S}*x + 2.0{S})"),
        (erf(42.0 * x), "erf{s}(42.0{S}*x)"),
        (erfc(42.0 * x), "erfc{s}(42.0{S}*x)"),
        (gamma(x), "tgamma{s}(x)"),
        (loggamma(x), "lgamma{s}(x)"),
        (ceiling(x + 2.0), "ceil{s}(x + 2.0{S})"),
        (floor(x + 2.0), "floor{s}(x + 2.0{S})"),
        (fma(x, y, -z), "fma{s}(x, y, -z)"),
        (Max(x, 8.0, x ** 4.0), "fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))"),
        (Min(x, 2.0), "fmin{s}(2.0{S}, x)"),
    ]

    for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ["f", "", "l"]):
        for expr, template in expectations:
            assert printer.doprint(expr) == template.format(s=suffix, S=suffix.upper())
Пример #28
0
def test_numerical_accuracy_functions():
    """``SciPyPrinter`` must print expm1, log1p and cosm1 with module prefixes."""
    prntr = SciPyPrinter()
    expected = (
        (expm1(x), 'numpy.expm1(x)'),
        (log1p(x), 'numpy.log1p(x)'),
        (cosm1(x), 'scipy.special.cosm1(x)'),
    )
    for expr, code in expected:
        assert prntr.doprint(expr) == code
Пример #29
0
def test_log1p():
    """Evaluate a numpy-lambdified log1p at 1e-99 and check its accuracy."""
    if not np:
        skip("NumPy not installed")

    evaluate = lambdify((a,), log1p(a), 'numpy')
    error = abs(evaluate(1e-99) - 1e-99)
    # The result may differ from the argument by less than 1e-100.
    assert error < 1e-100
Пример #30
0
def test_log1p():
    """Check evaluation, rewriting, precision, assumptions and diff of log1p."""
    # Eval
    assert log1p(0) == 0
    d = S(10)
    big = d**1000
    assert expand_log(log1p(d**-1000) - log(big + 1) + log(big)) == 0

    x = Symbol('x', real=True, finite=True)
    of_x = log1p(x)

    # Expand and rewrite
    reference = log(x + 1)
    assert of_x.expand(func=True) - reference == 0
    assert of_x.rewrite('tractable') - reference == 0
    assert of_x.rewrite('log') - reference == 0

    # Precision
    naive_error = abs(log(1e-99 + 1).evalf() - 1e-99)
    assert not naive_error < 1e-100  # for comparison
    log1p_error = abs(expand_log(log1p(1e-99)).evalf() - 1e-99)
    assert log1p_error < 1e-100

    # Properties
    assert log1p(-2**(-S(1)/2)).is_real

    assert not log1p(-1).is_finite
    assert log1p(pi).is_finite

    # Not established for the plain real x ...
    assert not of_x.is_positive
    assert not of_x.is_zero
    assert not of_x.is_nonnegative
    # ... but established when the argument carries the assumption itself.
    assert log1p(Symbol('y', positive=True)).is_positive
    assert log1p(Symbol('z', zero=True)).is_zero
    assert log1p(Symbol('o', nonnegative=True)).is_nonnegative

    # Diff
    scaled = log1p(42*x)
    assert scaled.diff(x) - 42/(42*x + 1) == 0
    assert scaled.diff(x) - scaled.expand(func=True).diff(x) == 0