示例#1
0
def test_log2_opt():
    x = Symbol('x')

    # Each pair is (input expression, expected result of log2_opt); every
    # optimized form must also turn back into its input under rewrite(log).
    roundtrip_cases = [
        (7 * log(3 * x + 5) / log(2), 7 * log2(3 * x + 5)),
        (3 * log(5 * x + 7) / (13 * log(2)), 3 * log2(5 * x + 7) / 13),
        (log(x) / log(2), log2(x)),
        (log(x) / log(2) + log(x + 1), log2(x) + log(2) * log2(x + 1)),
    ]
    for source, expected in roundtrip_cases:
        optimized = optimize(source, [log2_opt])
        assert optimized == expected
        assert optimized.rewrite(log) == source

    # A logarithm of a plain number is returned unchanged.
    constant = log(17)
    assert optimize(constant, [log2_opt]) == constant

    # The printed form of the optimized expression is checked as well.
    shifted = log(x + 3) / log(2)
    shifted_opt = optimize(shifted, [log2_opt])
    assert str(shifted_opt) == 'log2(x + 3)'
    assert shifted_opt.rewrite(log) == shifted
示例#2
0
def test_numpy_special_math():
    if not numpy:
        skip("numpy not installed")

    # Compare each lambdified special function against SymPy's own evalf.
    for fn in (expm1, log1p, exp2, log2, log10, hypot, logaddexp, logaddexp2):
        if 2 in fn.nargs:
            variables, values = (x, y), (0.3, 0.4)
        elif 1 in fn.nargs:
            variables, values = (x, ), (0.3, )
        else:
            raise NotImplementedError(
                "Need to handle other than unary & binary functions in test")
        expr = fn(*variables)
        numeric = lambdify(variables, expr)(*values)
        expected = float(expr.subs(dict(zip(variables, values))).evalf())
        assert numpy.allclose(numeric, expected)

    # Adding tiny numbers in log2 space; reference values from NumPy's docstring.
    combine_in_log2 = lambdify((x, y), logaddexp2(log2(x), log2(y)))
    recovered = 2.0**combine_in_log2(1e-50, 2.5e-50)
    assert abs(recovered - 3.5e-50) < 1e-62
示例#3
0
def test_log2_opt():
    x = Symbol('x')

    def assert_log2(expr, expected):
        # log2_opt must produce `expected`, and rewrite(log) must undo it.
        result = optimize(expr, [log2_opt])
        assert result == expected
        assert result.rewrite(log) == expr

    assert_log2(7*log(3*x + 5)/log(2), 7*log2(3*x + 5))
    assert_log2(3*log(5*x + 7)/(13*log(2)), 3*log2(5*x + 7)/13)
    assert_log2(log(x)/log(2), log2(x))
    assert_log2(log(x)/log(2) + log(x+1), log2(x) + log(2)*log2(x+1))

    # Nothing to optimize in a numeric logarithm.
    assert optimize(log(17), [log2_opt]) == log(17)

    res = optimize(log(x + 3)/log(2), [log2_opt])
    assert str(res) == 'log2(x + 3)'
    assert res.rewrite(log) == log(x + 3)/log(2)
示例#4
0
def test_optims_c99():
    x = Symbol('x')

    def c99(expr):
        # Run the whole C99 optimization set on *expr*.
        return optimize(expr, optims_c99)

    # exp2, log2, log1p and expm1 are all introduced in a single pass.
    e1 = 2**x + log(x) / log(2) + log(x + 1) + exp(x) - 1
    r1 = c99(e1).simplify()
    assert r1 == exp2(x) + log2(x) + log1p(x) + expm1(x)
    assert r1.rewrite(exp).rewrite(log).rewrite(Pow) == e1

    e2 = log(x) / log(2) + log(x + 1)
    r2 = c99(e2)
    assert r2 == log2(x) + log1p(x)
    assert r2.rewrite(log) == e2

    # The common factor 17 inside the logarithm is split off as log(17).
    e3 = log(x) / log(2) + log(17 * x + 17)
    r3 = c99(e3)
    assert r3 - (log2(x) + log(17) + log1p(x)) == 0
    assert (r3.rewrite(log) - e3).simplify() == 0

    e4 = (2**x + 3 * log(5 * x + 7) / (13 * log(2)) + 11 * exp(x) - 11
          + log(17 * x + 17))
    r4 = c99(e4).simplify()
    expected4 = (exp2(x) + 3 * log2(5 * x + 7) / 13 + 11 * expm1(x)
                 + log(17) + log1p(x))
    assert r4 - expected4 == 0
    assert (r4.rewrite(exp).rewrite(log).rewrite(Pow) - e4).simplify() == 0

    e5 = 3 * exp(2 * x) - 3
    r5 = c99(e5)
    assert r5 - 3 * expm1(2 * x) == 0
    assert r5.rewrite(exp) == e5

    # exp(2*x) - 3 comes back unchanged.
    e6 = exp(2 * x) - 3
    assert c99(e6) - (exp(2 * x) - 3) == 0

    e7 = log(3 * x + 3)
    r7 = c99(e7)
    assert r7 - (log(3) + log1p(x)) == 0
    assert (r7.rewrite(log) - e7).simplify() == 0

    # log(2*x + 3) comes back unchanged.
    e8 = log(2 * x + 3)
    assert c99(e8) == e8
示例#5
0
def get_math_macros():
    """ Map SymPy constant expressions to math.h/cmath macro names.

    The macros are not mandated by the C/C++ standards. MSVC only exposes
    them when "_USE_MATH_DEFINES" is defined, preferably through a
    compilation flag.

    Returns
    =======

    Dictionary mapping sympy expressions to strings (macro names)

    """
    from sympy.codegen.cfunctions import log2, Sqrt
    from sympy.functions.elementary.exponential import log
    from sympy.functions.elementary.miscellaneous import sqrt

    # Several expressions share one macro. The core ``sqrt`` and the codegen
    # ``Sqrt`` spellings both appear, and log2(E) and 1/log(2) both map to
    # M_LOG2E.
    expr_macro_pairs = [
        (S.Exp1, 'M_E'),
        (log2(S.Exp1), 'M_LOG2E'),
        (1/log(2), 'M_LOG2E'),
        (log(2), 'M_LN2'),
        (log(10), 'M_LN10'),
        (S.Pi, 'M_PI'),
        (S.Pi/2, 'M_PI_2'),
        (S.Pi/4, 'M_PI_4'),
        (1/S.Pi, 'M_1_PI'),
        (2/S.Pi, 'M_2_PI'),
        (2/sqrt(S.Pi), 'M_2_SQRTPI'),
        (2/Sqrt(S.Pi), 'M_2_SQRTPI'),
        (sqrt(2), 'M_SQRT2'),
        (Sqrt(2), 'M_SQRT2'),
        (1/sqrt(2), 'M_SQRT1_2'),
        (1/Sqrt(2), 'M_SQRT1_2'),
    ]
    return dict(expr_macro_pairs)
示例#6
0
def test_optims_c99():
    x = Symbol('x')

    # Sum exercising exp2, log2, log1p and expm1 at once.
    expr = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1
    res = optimize(expr, optims_c99).simplify()
    assert res == exp2(x) + log2(x) + log1p(x) + expm1(x)
    assert res.rewrite(exp).rewrite(log).rewrite(Pow) == expr

    expr = log(x)/log(2) + log(x + 1)
    res = optimize(expr, optims_c99)
    assert res == log2(x) + log1p(x)
    assert res.rewrite(log) == expr

    expr = log(x)/log(2) + log(17*x + 17)
    res = optimize(expr, optims_c99)
    assert res - (log2(x) + log(17) + log1p(x)) == 0
    assert (res.rewrite(log) - expr).simplify() == 0

    expr = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17)
    res = optimize(expr, optims_c99).simplify()
    target = exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x)
    assert res - target == 0
    assert (res.rewrite(exp).rewrite(log).rewrite(Pow) - expr).simplify() == 0

    expr = 3*exp(2*x) - 3
    res = optimize(expr, optims_c99)
    assert res - 3*expm1(2*x) == 0
    assert res.rewrite(exp) == expr

    # Left unchanged by the optimizations.
    expr = exp(2*x) - 3
    res = optimize(expr, optims_c99)
    assert res - (exp(2*x) - 3) == 0

    expr = log(3*x + 3)
    res = optimize(expr, optims_c99)
    assert res - (log(3) + log1p(x)) == 0
    assert (res.rewrite(log) - expr).simplify() == 0

    # Also left unchanged.
    expr = log(2*x + 3)
    res = optimize(expr, optims_c99)
    assert res == expr
示例#7
0
def test_C99CodePrinter__precision():
    n = symbols('n', integer=True)
    f32_printer = C99CodePrinter({'type_aliases': {real: float32}})
    f64_printer = C99CodePrinter({'type_aliases': {real: float64}})
    f80_printer = C99CodePrinter({'type_aliases': {real: float80}})
    assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)'
    assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)'
    assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)'

    # Output templates: "{s}" is filled with the function-name suffix and
    # "{S}" with its upper-case form, used on float literals.
    templates = [
        (Abs(n), 'abs(n)'),
        (Abs(x + 2.0), 'fabs{s}(x + 2.0{S})'),
        (sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))'),
        (exp(x*8.0), 'exp{s}(8.0{S}*x)'),
        (exp2(x), 'exp2{s}(x)'),
        (expm1(x*4.0), 'expm1{s}(4.0{S}*x)'),
        (Mod(n, 2), '((n) % (2))'),
        (Mod(2*n + 3, 3*n + 5), '((2*n + 3) % (3*n + 5))'),
        (Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})'),
        (Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})'),
        (log(x/2), 'log{s}((1.0{S}/2.0{S})*x)'),
        (log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)'),
        (log2(x*8.0), 'log2{s}(8.0{S}*x)'),
        (log1p(x), 'log1p{s}(x)'),
        (2**x, 'pow{s}(2, x)'),
        (2.0**x, 'pow{s}(2.0{S}, x)'),
        (x**3, 'pow{s}(x, 3)'),
        (x**4.0, 'pow{s}(x, 4.0{S})'),
        (sqrt(3+x), 'sqrt{s}(x + 3)'),
        (Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})'),
        (hypot(x, y), 'hypot{s}(x, y)'),
        (sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})'),
        (cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})'),
        (tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})'),
        (asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})'),
        (acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})'),
        (atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})'),
        (atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)'),

        (sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})'),
        (cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})'),
        (tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})'),
        (asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})'),
        (acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})'),
        (atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})'),
        (erf(42.*x), 'erf{s}(42.0{S}*x)'),
        (erfc(42.*x), 'erfc{s}(42.0{S}*x)'),
        (gamma(x), 'tgamma{s}(x)'),
        (loggamma(x), 'lgamma{s}(x)'),

        (ceiling(x + 2.), "ceil{s}(x + 2.0{S})"),
        (floor(x + 2.), "floor{s}(x + 2.0{S})"),
        (fma(x, y, -z), 'fma{s}(x, y, -z)'),
        (Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))'),
        (Min(x, 2.0), 'fmin{s}(2.0{S}, x)'),
    ]

    for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ['f', '', 'l']):
        for expr, ref in templates:
            assert printer.doprint(expr) == ref.format(s=suffix, S=suffix.upper())
示例#8
0
def test_C99CodePrinter__precision():
    n = symbols('n', integer=True)
    f32_printer = C99CodePrinter(dict(type_aliases={real: float32}))
    f64_printer = C99CodePrinter(dict(type_aliases={real: float64}))
    f80_printer = C99CodePrinter(dict(type_aliases={real: float80}))
    assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)'
    assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)'
    assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)'

    # Printers keyed by the suffix they should produce (insertion order kept).
    printers_by_suffix = {'f': f32_printer, '': f64_printer, 'l': f80_printer}

    expectations = (
        (Abs(n), 'abs(n)'),
        (Abs(x + 2.0), 'fabs{s}(x + 2.0{S})'),
        (sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))'),
        (exp(x*8.0), 'exp{s}(8.0{S}*x)'),
        (exp2(x), 'exp2{s}(x)'),
        (expm1(x*4.0), 'expm1{s}(4.0{S}*x)'),
        (Mod(n, 2), '((n) % (2))'),
        (Mod(2*n + 3, 3*n + 5), '((2*n + 3) % (3*n + 5))'),
        (Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})'),
        (Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})'),
        (log(x/2), 'log{s}((1.0{S}/2.0{S})*x)'),
        (log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)'),
        (log2(x*8.0), 'log2{s}(8.0{S}*x)'),
        (log1p(x), 'log1p{s}(x)'),
        (2**x, 'pow{s}(2, x)'),
        (2.0**x, 'pow{s}(2.0{S}, x)'),
        (x**3, 'pow{s}(x, 3)'),
        (x**4.0, 'pow{s}(x, 4.0{S})'),
        (sqrt(3+x), 'sqrt{s}(x + 3)'),
        (Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})'),
        (hypot(x, y), 'hypot{s}(x, y)'),
        (sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})'),
        (cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})'),
        (tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})'),
        (asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})'),
        (acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})'),
        (atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})'),
        (atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)'),
        # hyperbolic and special functions
        (sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})'),
        (cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})'),
        (tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})'),
        (asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})'),
        (acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})'),
        (atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})'),
        (erf(42.*x), 'erf{s}(42.0{S}*x)'),
        (erfc(42.*x), 'erfc{s}(42.0{S}*x)'),
        (gamma(x), 'tgamma{s}(x)'),
        (loggamma(x), 'lgamma{s}(x)'),
        # rounding, fused multiply-add, min/max
        (ceiling(x + 2.), "ceil{s}(x + 2.0{S})"),
        (floor(x + 2.), "floor{s}(x + 2.0{S})"),
        (fma(x, y, -z), 'fma{s}(x, y, -z)'),
        (Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))'),
        (Min(x, 2.0), 'fmin{s}(2.0{S}, x)'),
    )

    for suffix, printer in printers_by_suffix.items():
        lower, upper = suffix, suffix.upper()
        for expr, template in expectations:
            assert printer.doprint(expr) == template.format(s=lower, S=upper)
示例#9
0
def test_optims_numpy_TODO():
    x, y = Symbol('x'), Symbol('y')

    # Target rewrites under optims_numpy. The _TODO suffix suggests these may
    # not all be achieved yet; confirm against the decorator used upstream.
    targets = {
        log(x*y)*sin(x*y)*log(x*y+1)/(log(2)*x*y): log2(x*y)*sinc(x*y)*log1p(x*y),
        exp(x*sin(y)/y) - 1: expm1(x*sinc(y)),
    }
    for expr, expected in targets.items():
        assert optimize(expr, optims_numpy) == expected
示例#10
0
def test_optims_numpy():
    x = Symbol('x')

    # input expression -> expected result of optimize(..., optims_numpy)
    targets = {
        sin(2*x)/(2*x) + exp(2*x) - 1: sinc(2*x) + expm1(2*x),
        log(x+3)/log(2) + log(x**2 + 1): log1p(x**2) + log2(x+3),
    }
    for expr, expected in targets.items():
        assert optimize(expr, optims_numpy) == expected
示例#11
0
def test_log2_opt():
    x = Symbol('x')

    # (input, expected) pairs for a single application of log2_opt.
    expectations = (
        (7 * log(3 * x + 5) / (log(2)), 7 * log2(3 * x + 5)),
        (3 * log(5 * x + 7) / (13 * log(2)), 3 * log2(5 * x + 7) / 13),
        (log(x) / log(2), log2(x)),
        (log(x) / log(2) + log(x + 1), log2(x) + log(2) * log2(x + 1)),
        # A numeric logarithm is returned unchanged.
        (log(17), log(17)),
    )
    for expr, expected in expectations:
        assert optimize(expr, [log2_opt]) == expected
示例#12
0
def test_log2():
    # Evaluation: log2 of an exact power of two simplifies. log2(pi) must not
    # become log(pi)/log(2), because log2 should *save* (CPU) instructions.
    assert log2(8) == 3
    assert log2(pi) != log(pi)/log(2)

    x = Symbol('x', real=True, finite=True)
    quotient = log(x)/log(2)
    assert log2(x) != quotient
    assert log2(2**x) == x

    # expand(func=True) turns log2 into the log quotient.
    assert log2(x).expand(func=True) - quotient == 0

    # Derivative, both directly and via the expanded form.
    g = log2(42*x)
    dg = g.diff()
    assert dg - 1/(log(2)*x) == 0
    assert dg - g.expand(func=True).diff(x) == 0
示例#13
0
def test_log2():
    # Evaluation. log2 is kept as one call because it should *save* (CPU)
    # instructions compared with a log quotient.
    assert log2(8) == 3
    assert log2(pi) != log(pi) / log(2)

    x = Symbol('x', real=True)
    log_ratio = log(x) / log(2)
    assert log2(x) != log_ratio
    assert log2(2**x) == x

    # Expansion with func=True produces the log ratio.
    assert log2(x).expand(func=True) - log_ratio == 0

    # Differentiation.
    scaled = log2(42 * x)
    derivative = scaled.diff()
    assert derivative - 1 / (log(2) * x) == 0
    assert derivative - scaled.expand(func=True).diff(x) == 0
示例#14
0
def test_optims_c99():
    x = Symbol('x')

    # All four C99 helpers (exp2, log2, log1p, expm1) in one expression.
    expr1 = 2**x + log(x) / log(2) + log(x + 1) + exp(x) - 1
    opt1 = optimize(expr1, optims_c99).simplify()
    assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x)

    # log(x)/log(2) -> log2(x) and log(x + 1) -> log1p(x).
    # (A stray debugging print() used to sit here; it only emitted a blank
    # line to stdout and has been removed.)
    expr2 = log(x) / log(2) + log(x + 1)
    opt2 = optimize(expr2, optims_c99)
    assert opt2 == log2(x) + log1p(x)

    # The common factor 17 inside the logarithm is split off as log(17).
    expr3 = log(x) / log(2) + log(17 * x + 17)
    opt3 = optimize(expr3, optims_c99)
    delta3 = opt3 - (log2(x) + log(17) + log1p(x))
    assert delta3 == 0

    expr4 = 2**x + 3 * log(5 * x + 7) / (13 * log(2)) + 11 * exp(x) - 11 + log(
        17 * x + 17)
    opt4 = optimize(expr4, optims_c99).simplify()
    delta4 = opt4 - (exp2(x) + 3 * log2(5 * x + 7) / 13 + 11 * expm1(x) +
                     log(17) + log1p(x))
    assert delta4 == 0
示例#15
0
def test_ccode_math_macros():
    # Each constant, added to z, should print using its math.h macro name.
    expected = [
        (exp(1), "z + M_E"),
        (log2(exp(1)), "z + M_LOG2E"),
        (1 / log(2), "z + M_LOG2E"),
        (log(2), "z + M_LN2"),
        (log(10), "z + M_LN10"),
        (pi, "z + M_PI"),
        (pi / 2, "z + M_PI_2"),
        (pi / 4, "z + M_PI_4"),
        (1 / pi, "z + M_1_PI"),
        (2 / pi, "z + M_2_PI"),
        (2 / sqrt(pi), "z + M_2_SQRTPI"),
        (2 / Sqrt(pi), "z + M_2_SQRTPI"),
        (sqrt(2), "z + M_SQRT2"),
        (Sqrt(2), "z + M_SQRT2"),
        (1 / sqrt(2), "z + M_SQRT1_2"),
        (1 / Sqrt(2), "z + M_SQRT1_2"),
    ]
    for constant, code in expected:
        assert ccode(z + constant) == code
示例#16
0
def test_ccode_math_macros():
    def check(constant, code):
        # z + constant must print with the math.h macro in `code`.
        assert ccode(z + constant) == code

    check(exp(1), 'z + M_E')
    check(log2(exp(1)), 'z + M_LOG2E')
    check(1 / log(2), 'z + M_LOG2E')
    check(log(2), 'z + M_LN2')
    check(log(10), 'z + M_LN10')
    check(pi, 'z + M_PI')
    check(pi / 2, 'z + M_PI_2')
    check(pi / 4, 'z + M_PI_4')
    check(1 / pi, 'z + M_1_PI')
    check(2 / pi, 'z + M_2_PI')
    check(2 / sqrt(pi), 'z + M_2_SQRTPI')
    check(2 / Sqrt(pi), 'z + M_2_SQRTPI')
    check(sqrt(2), 'z + M_SQRT2')
    check(Sqrt(2), 'z + M_SQRT2')
    check(1 / sqrt(2), 'z + M_SQRT1_2')
    check(1 / Sqrt(2), 'z + M_SQRT1_2')
示例#17
0
def test_C99CodePrinter():
    # A fresh printer is used for every expression.
    cases = [
        (expm1(x), 'expm1(x)'),
        (log1p(x), 'log1p(x)'),
        (exp2(x), 'exp2(x)'),
        (log2(x), 'log2(x)'),
        (fma(x, y, -z), 'fma(x, y, -z)'),
        (log10(x), 'log10(x)'),
        (Cbrt(x), 'cbrt(x)'),  # note Cbrt due to cbrt already taken.
        (hypot(x, y), 'hypot(x, y)'),
        (loggamma(x), 'lgamma(x)'),
        (Max(x, 3, x**2), 'fmax(3, fmax(x, pow(x, 2)))'),
        (Min(x, 3), 'fmin(3, x)'),
    ]
    for expr, code in cases:
        assert C99CodePrinter().doprint(expr) == code

    # Printer metadata.
    c99printer = C99CodePrinter()
    assert c99printer.language == 'C'
    assert c99printer.standard == 'C99'
    assert 'restrict' in c99printer.reserved_words
    assert 'using' not in c99printer.reserved_words
示例#18
0
def test_ccode_math_macros():
    # (constant, expected C code of z + constant)
    for const, code in (
            (exp(1), 'z + M_E'),
            (log2(exp(1)), 'z + M_LOG2E'),
            (1/log(2), 'z + M_LOG2E'),
            (log(2), 'z + M_LN2'),
            (log(10), 'z + M_LN10'),
            (pi, 'z + M_PI'),
            (pi/2, 'z + M_PI_2'),
            (pi/4, 'z + M_PI_4'),
            (1/pi, 'z + M_1_PI'),
            (2/pi, 'z + M_2_PI'),
            (2/sqrt(pi), 'z + M_2_SQRTPI'),
            (2/Sqrt(pi), 'z + M_2_SQRTPI'),
            (sqrt(2), 'z + M_SQRT2'),
            (Sqrt(2), 'z + M_SQRT2'),
            (1/sqrt(2), 'z + M_SQRT1_2'),
            (1/Sqrt(2), 'z + M_SQRT1_2'),
    ):
        assert ccode(z + const) == code
示例#19
0
def test_C99CodePrinter():
    def printed(expr):
        # Print with a brand-new C99 printer each time.
        return C99CodePrinter().doprint(expr)

    assert printed(expm1(x)) == 'expm1(x)'
    assert printed(log1p(x)) == 'log1p(x)'
    assert printed(exp2(x)) == 'exp2(x)'
    assert printed(log2(x)) == 'log2(x)'
    assert printed(fma(x, y, -z)) == 'fma(x, y, -z)'
    assert printed(log10(x)) == 'log10(x)'
    assert printed(Cbrt(x)) == 'cbrt(x)'  # note Cbrt due to cbrt already taken.
    assert printed(hypot(x, y)) == 'hypot(x, y)'
    assert printed(loggamma(x)) == 'lgamma(x)'
    assert printed(Max(x, 3, x**2)) == 'fmax(3, fmax(x, pow(x, 2)))'
    assert printed(Min(x, 3)) == 'fmin(3, x)'

    printer = C99CodePrinter()
    assert printer.language == 'C'
    assert printer.standard == 'C99'
    assert 'restrict' in printer.reserved_words
    assert 'using' not in printer.reserved_words
示例#20
0
def test_C99CodePrinter():
    expected_output = (
        (expm1(x), "expm1(x)"),
        (log1p(x), "log1p(x)"),
        (exp2(x), "exp2(x)"),
        (log2(x), "log2(x)"),
        (fma(x, y, -z), "fma(x, y, -z)"),
        (log10(x), "log10(x)"),
        (Cbrt(x), "cbrt(x)"),  # note Cbrt due to cbrt already taken.
        (hypot(x, y), "hypot(x, y)"),
        (loggamma(x), "lgamma(x)"),
        (Max(x, 3, x ** 2), "fmax(3, fmax(x, pow(x, 2)))"),
        (Min(x, 3), "fmin(3, x)"),
    )
    for expression, code in expected_output:
        assert C99CodePrinter().doprint(expression) == code

    c99printer = C99CodePrinter()
    assert c99printer.language == "C"
    assert c99printer.standard == "C99"
    assert "restrict" in c99printer.reserved_words
    assert "using" not in c99printer.reserved_words
示例#21
0
def test_log2():
    if not np:
        skip("NumPy not installed")
    numeric_log2 = lambdify((a,), log2(a), 'numpy')
    # 256 == 2**8
    assert abs(numeric_log2(256) - 8) < 1e-16
示例#22
0
            if before > after:
                expr = new_expr
    return expr


# Rewrite every power whose base is 2 as exp2 of its exponent.
exp2_opt = ReplaceOptim(
    lambda expr: expr.is_Pow and expr.base == 2,
    lambda expr: exp2(expr.exp),
)

# Pattern wildcards for the ReplaceOptim rules:
#   _d matches only Dummy symbols,
#   _u matches expressions that are neither numbers nor Adds,
#   _v and _w are unrestricted.
_d = Wild('d', properties=[lambda arg: arg.is_Dummy])
_u = Wild('u', properties=[lambda arg: not (arg.is_number or arg.is_Add)])
_v = Wild('v')
_w = Wild('w')

# Replace v*log(w)/log(2) with v*log2(w). The cost counts negative powers
# (divisions) and logarithms of non-numeric arguments (transcendental
# evaluations); optimize() appears to keep the replacement only when this
# count drops.
log2_opt = ReplaceOptim(
    _v * log(_w) / log(2),
    _v * log2(_w),
    cost_function=lambda expr: expr.count(
        lambda node: (node.is_Pow and node.exp.is_negative)  # division
        or (isinstance(node, (log, log2))
            and not node.args[0].is_number)  # transcendental
    ),
)

# Fold log(2)*log2(w) back into log(w).
log2const_opt = ReplaceOptim(
    log(2) * log2(_w),
    log(_w),
)

# Rewrite log(exp(a) + exp(b)) using the identity
#     log(exp(a) + exp(b)) == Max(a, b) + log1p(exp(Min(a, b) - Max(a, b)))
# The identity comes from factoring exp(Max(a, b)) out of the sum. The
# previous version omitted the "- Max(a, b)" term and returned
# Max + log1p(exp(Min)), which is wrong. For example, a == b == 1 gives
# 1 + log(1 + e) instead of 1 + log(2). For real a, b the shifted exponent
# Min - Max is <= 0, so its exp cannot overflow.
logsumexp_2terms_opt = ReplaceOptim(
    lambda node: (isinstance(node, log) and node.args[0].is_Add
                  and len(node.args[0].args) == 2
                  and all(isinstance(t, exp) for t in node.args[0].args)),
    lambda node: (
        Max(*[t.args[0] for t in node.args[0].args]) +
        log1p(exp(Min(*[t.args[0] for t in node.args[0].args]) -
                  Max(*[t.args[0] for t in node.args[0].args])))
    ),
)
示例#23
0
def test_log2():
    if not np:
        skip("NumPy not installed")
    log2_numpy = lambdify((a,), log2(a), 'numpy')
    error = abs(log2_numpy(256) - 8)  # log2(2**8) should be 8
    assert error < 1e-16
示例#24
0
# Wildcards for the pattern-based rewrites:
#   _d: Dummy symbols only; _u: neither a number nor an Add;
#   _v, _w: unrestricted; _n: numbers only.
_d = Wild('d', properties=[lambda e: e.is_Dummy])
_u = Wild('u', properties=[lambda e: not (e.is_number or e.is_Add)])
_v = Wild('v')
_w = Wild('w')
_n = Wild('n', properties=[lambda e: e.is_number])

# sin(w)/w -> sinc(w)
sinc_opt1 = ReplaceOptim(sin(_w)/_w, sinc(_w))
# sin(n*w)/w -> n*sinc(n*w) for a numeric n, since sin(n*w)/w equals
# n*sin(n*w)/(n*w).
sinc_opt2 = ReplaceOptim(sin(_n*_w)/_w, _n*sinc(_n*_w))
sinc_opts = sinc_opt1, sinc_opt2

# v*log(w)/log(2) -> v*log2(w), costed by the number of divisions (negative
# powers) plus logarithms of non-numeric arguments. Both are expensive
# floating point operations.
log2_opt = ReplaceOptim(
    _v*log(_w)/log(2),
    _v*log2(_w),
    cost_function=lambda expr: expr.count(
        lambda sub: (sub.is_Pow and sub.exp.is_negative)  # division
        or (isinstance(sub, (log, log2))
            and not sub.args[0].is_number)  # transcendental
    ),
)

# log(2)*log2(w) simplifies back to log(w).
log2const_opt = ReplaceOptim(
    log(2)*log2(_w),
    log(_w)
)

logsumexp_2terms_opt = ReplaceOptim(
    lambda l: (isinstance(l, log)
               and l.args[0].is_Add
               and len(l.args[0].args) == 2
               and all(isinstance(t, exp) for t in l.args[0].args)),
    lambda l: (
        Max(*[e.args[0] for e in l.args[0].args]) +
示例#25
0
def test_C99CodePrinter__precision():
    n = symbols("n", integer=True)
    f32_printer = C99CodePrinter(dict(type_aliases={real: float32}))
    f64_printer = C99CodePrinter(dict(type_aliases={real: float64}))
    f80_printer = C99CodePrinter(dict(type_aliases={real: float80}))
    assert f32_printer.doprint(sin(x + 2.1)) == "sinf(x + 2.1F)"
    assert f64_printer.doprint(sin(x + 2.1)) == "sin(x + 2.1000000000000001)"
    assert f80_printer.doprint(sin(x + Float("2.0"))) == "sinl(x + 2.0L)"

    precisions = (
        (f32_printer, "f"),
        (f64_printer, ""),
        (f80_printer, "l"),
    )
    # "{s}" is replaced by the function-name suffix and "{S}" by the
    # upper-case literal suffix for the printer's precision.
    expected_code = (
        (Abs(n), "abs(n)"),
        (Abs(x + 2.0), "fabs{s}(x + 2.0{S})"),
        (
            sin(x + 4.0) ** cos(x - 2.0),
            "pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))",
        ),
        (exp(x * 8.0), "exp{s}(8.0{S}*x)"),
        (exp2(x), "exp2{s}(x)"),
        (expm1(x * 4.0), "expm1{s}(4.0{S}*x)"),
        (Mod(n, 2), "((n) % (2))"),
        (Mod(2 * n + 3, 3 * n + 5), "((2*n + 3) % (3*n + 5))"),
        (Mod(x + 2.0, 3.0), "fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})"),
        (Mod(x, 2.0 * x + 3.0), "fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})"),
        (log(x / 2), "log{s}((1.0{S}/2.0{S})*x)"),
        (log10(3 * x / 2), "log10{s}((3.0{S}/2.0{S})*x)"),
        (log2(x * 8.0), "log2{s}(8.0{S}*x)"),
        (log1p(x), "log1p{s}(x)"),
        (2 ** x, "pow{s}(2, x)"),
        (2.0 ** x, "pow{s}(2.0{S}, x)"),
        (x ** 3, "pow{s}(x, 3)"),
        (x ** 4.0, "pow{s}(x, 4.0{S})"),
        (sqrt(3 + x), "sqrt{s}(x + 3)"),
        (Cbrt(x - 2.0), "cbrt{s}(x - 2.0{S})"),
        (hypot(x, y), "hypot{s}(x, y)"),
        (sin(3.0 * x + 2.0), "sin{s}(3.0{S}*x + 2.0{S})"),
        (cos(3.0 * x - 1.0), "cos{s}(3.0{S}*x - 1.0{S})"),
        (tan(4.0 * y + 2.0), "tan{s}(4.0{S}*y + 2.0{S})"),
        (asin(3.0 * x + 2.0), "asin{s}(3.0{S}*x + 2.0{S})"),
        (acos(3.0 * x + 2.0), "acos{s}(3.0{S}*x + 2.0{S})"),
        (atan(3.0 * x + 2.0), "atan{s}(3.0{S}*x + 2.0{S})"),
        (atan2(3.0 * x, 2.0 * y), "atan2{s}(3.0{S}*x, 2.0{S}*y)"),
        (sinh(3.0 * x + 2.0), "sinh{s}(3.0{S}*x + 2.0{S})"),
        (cosh(3.0 * x - 1.0), "cosh{s}(3.0{S}*x - 1.0{S})"),
        (tanh(4.0 * y + 2.0), "tanh{s}(4.0{S}*y + 2.0{S})"),
        (asinh(3.0 * x + 2.0), "asinh{s}(3.0{S}*x + 2.0{S})"),
        (acosh(3.0 * x + 2.0), "acosh{s}(3.0{S}*x + 2.0{S})"),
        (atanh(3.0 * x + 2.0), "atanh{s}(3.0{S}*x + 2.0{S})"),
        (erf(42.0 * x), "erf{s}(42.0{S}*x)"),
        (erfc(42.0 * x), "erfc{s}(42.0{S}*x)"),
        (gamma(x), "tgamma{s}(x)"),
        (loggamma(x), "lgamma{s}(x)"),
        (ceiling(x + 2.0), "ceil{s}(x + 2.0{S})"),
        (floor(x + 2.0), "floor{s}(x + 2.0{S})"),
        (fma(x, y, -z), "fma{s}(x, y, -z)"),
        (Max(x, 8.0, x ** 4.0), "fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))"),
        (Min(x, 2.0), "fmin{s}(2.0{S}, x)"),
    )

    for printer, suffix in precisions:
        fmt = {"s": suffix, "S": suffix.upper()}
        for expr, template in expected_code:
            assert printer.doprint(expr) == template.format(**fmt)
示例#26
0
                expr = new_expr
    return expr


# 2**e -> exp2(e) for any power with base 2.
exp2_opt = ReplaceOptim(lambda pw: pw.is_Pow and pw.base == 2, lambda pw: exp2(pw.exp))

# Wildcards for the ReplaceOptim rules. _d only matches Dummy symbols and _u
# only matches non-numeric, non-Add expressions; _v and _w are unrestricted.
_d, _u = (
    Wild('d', properties=[lambda t: t.is_Dummy]),
    Wild('u', properties=[lambda t: not t.is_number and not t.is_Add]),
)
_v, _w = Wild('v'), Wild('w')


# Turn v*log(w)/log(2) into v*log2(w). The cost is the number of negative
# powers (divisions) plus logs of non-numeric arguments (transcendental
# evaluations), i.e. the expensive floating point work.
log2_opt = ReplaceOptim(
    _v*log(_w)/log(2),
    _v*log2(_w),
    cost_function=lambda expr: expr.count(
        lambda part: (
            (part.is_Pow and part.exp.is_negative)  # division
            or (isinstance(part, (log, log2))
                and not part.args[0].is_number)  # transcendental
        )
    ),
)

# Undo a log2 that is immediately scaled back by log(2): log(2)*log2(w) -> log(w).
log2const_opt = ReplaceOptim(
    log(2) * log2(_w),
    log(_w),
)

logsumexp_2terms_opt = ReplaceOptim(
    lambda l: (isinstance(l, log)
               and l.args[0].is_Add
               and len(l.args[0].args) == 2
               and all(isinstance(t, exp) for t in l.args[0].args)),
    lambda l: (
        Max(*[e.args[0] for e in l.args[0].args]) +