def test_log1p_opt():
    """Check that ``log1p_opt`` turns ``log(x + 1)``-style expressions into ``log1p``."""
    x = Symbol('x')

    # log(x + 1) must optimize to log1p(x).
    plain = optimize(log(x + 1), [log1p_opt])
    assert log1p(x) - plain == 0

    # log(3*x + 3) must optimize to log1p(x) + log(3).
    scaled = optimize(log(3 * x + 3), [log1p_opt])
    assert log1p(x) + log(3) == scaled
def test_subclass_print_method__ns():
    """A subclass that overrides ``_ns`` prints ``log1p`` with that prefix."""

    class MyPrinter(CXX11CodePrinter):
        _ns = "my_library::"

    # Each printer is paired with the exact text it must emit for log1p(x).
    printer_and_output = (
        (CXX11CodePrinter(), "std::log1p(x)"),
        (MyPrinter(), "my_library::log1p(x)"),
    )
    for printer, output in printer_and_output:
        assert printer.doprint(log1p(x)) == output
def test_subclass_print_method__ns():
    """Overriding ``_ns`` in a printer subclass changes the prefix of ``log1p``."""

    class MyPrinter(CXX11CodePrinter):
        _ns = 'my_library::'

    stock_printer = CXX11CodePrinter()
    custom_printer = MyPrinter()
    call = log1p(x)

    # The stock printer uses the std:: prefix, the subclass its own one.
    assert stock_printer.doprint(call) == 'std::log1p(x)'
    assert custom_printer.doprint(call) == 'my_library::log1p(x)'
def test_optims_c99():
    """Run ``optims_c99`` over several expressions and check the rewritten forms.

    For most cases the optimized result is additionally rewritten back and
    compared against the expression it was derived from.
    """
    x = Symbol('x')

    # 2**x, log(x)/log(2), log(x + 1) and exp(x) - 1 in a single expression.
    src = 2**x + log(x) / log(2) + log(x + 1) + exp(x) - 1
    out = optimize(src, optims_c99).simplify()
    assert out == exp2(x) + log2(x) + log1p(x) + expm1(x)
    assert out.rewrite(exp).rewrite(log).rewrite(Pow) == src

    # Only the two logarithm rewrites.
    src = log(x) / log(2) + log(x + 1)
    out = optimize(src, optims_c99)
    assert out == log2(x) + log1p(x)
    assert out.rewrite(log) == src

    # A common factor inside the logarithm is expected to split off as log(17).
    src = log(x) / log(2) + log(17 * x + 17)
    out = optimize(src, optims_c99)
    expected = log2(x) + log(17) + log1p(x)
    assert out - expected == 0
    assert (out.rewrite(log) - src).simplify() == 0

    # Rational and integer coefficients in front of the rewritten terms.
    src = (2**x + 3 * log(5 * x + 7) / (13 * log(2)) + 11 * exp(x) - 11
           + log(17 * x + 17))
    out = optimize(src, optims_c99).simplify()
    expected = (exp2(x) + 3 * log2(5 * x + 7) / 13 + 11 * expm1(x)
                + log(17) + log1p(x))
    assert out - expected == 0
    assert (out.rewrite(exp).rewrite(log).rewrite(Pow) - src).simplify() == 0

    # 3*exp(2*x) - 3 must become 3*expm1(2*x).
    src = 3 * exp(2 * x) - 3
    out = optimize(src, optims_c99)
    assert out - 3 * expm1(2 * x) == 0
    assert out.rewrite(exp) == src

    # exp(2*x) - 3 does not match exp(...) - 1 and must stay as it is.
    src = exp(2 * x) - 3
    out = optimize(src, optims_c99)
    assert out - (exp(2 * x) - 3) == 0

    # log(3*x + 3) must become log(3) + log1p(x).
    src = log(3 * x + 3)
    out = optimize(src, optims_c99)
    assert out - (log(3) + log1p(x)) == 0
    assert (out.rewrite(log) - src).simplify() == 0

    # log(2*x + 3) must be returned unchanged.
    src = log(2 * x + 3)
    out = optimize(src, optims_c99)
    assert out == src
def test_subclass_print_method():
    """A subclass may override ``_print_log1p`` to control how ``log1p`` is printed."""

    class MyPrinter(CXX11CodePrinter):
        def _print_log1p(self, expr):
            # Print every argument with this printer, then join them.
            printed_args = [self._print(arg) for arg in expr.args]
            return 'my_library::log1p(%s)' % ', '.join(printed_args)

    rendered = MyPrinter().doprint(log1p(x))
    assert rendered == 'my_library::log1p(x)'
def test_optims_c99():
    """Check the output of ``optimize(..., optims_c99)`` for eight expressions.

    Cases 1-5 and 7 also verify that rewriting the optimized result back
    reproduces the input expression.
    """
    x = Symbol('x')

    # Case 1: all four special functions appear at once.
    original = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1
    optimized = optimize(original, optims_c99).simplify()
    assert optimized == exp2(x) + log2(x) + log1p(x) + expm1(x)
    assert optimized.rewrite(exp).rewrite(log).rewrite(Pow) == original

    # Case 2: log2 and log1p only.
    original = log(x)/log(2) + log(x + 1)
    optimized = optimize(original, optims_c99)
    assert optimized == log2(x) + log1p(x)
    assert optimized.rewrite(log) == original

    # Case 3: log(17*x + 17) must yield log(17) + log1p(x).
    original = log(x)/log(2) + log(17*x + 17)
    optimized = optimize(original, optims_c99)
    difference = optimized - (log2(x) + log(17) + log1p(x))
    assert difference == 0
    assert (optimized.rewrite(log) - original).simplify() == 0

    # Case 4: coefficients in front of every rewritten term.
    original = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17)
    optimized = optimize(original, optims_c99).simplify()
    difference = optimized - (
        exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x))
    assert difference == 0
    assert (optimized.rewrite(exp).rewrite(log).rewrite(Pow) - original).simplify() == 0

    # Case 5: a common factor in front of exp(...) - 1.
    original = 3*exp(2*x) - 3
    optimized = optimize(original, optims_c99)
    difference = optimized - 3*expm1(2*x)
    assert difference == 0
    assert optimized.rewrite(exp) == original

    # Case 6: exp(2*x) - 3 must be left alone.
    original = exp(2*x) - 3
    optimized = optimize(original, optims_c99)
    difference = optimized - (exp(2*x) - 3)
    assert difference == 0

    # Case 7: log(3*x + 3) alone.
    original = log(3*x + 3)
    optimized = optimize(original, optims_c99)
    difference = optimized - (log(3) + log1p(x))
    assert difference == 0
    assert (optimized.rewrite(log) - original).simplify() == 0

    # Case 8: log(2*x + 3) must be left alone.
    original = log(2*x + 3)
    optimized = optimize(original, optims_c99)
    assert optimized == original
def test_CXX11CodePrinter():
    """Check ``log1p`` printing, metadata and reserved words of ``CXX11CodePrinter``."""
    assert CXX11CodePrinter().doprint(log1p(x)) == "std::log1p(x)"

    printer = CXX11CodePrinter()
    assert printer.language == "C++"
    assert printer.standard == "C++11"

    # Words that must be reserved for this printer ...
    for keyword in ("operator", "noexcept"):
        assert keyword in printer.reserved_words
    # ... and one that must not be.
    assert "concept" not in printer.reserved_words
def test_C99CodePrinter__precision():
    """Check C99 output for ``float32``, ``float64`` and ``float80`` type aliases.

    Every reference string is a template: ``{s}`` is replaced by the function
    suffix of the printer under test ('f', '' or 'l') and ``{S}`` by the same
    suffix in upper case.
    """
    n = symbols('n', integer=True)
    f32_printer = C99CodePrinter(dict(type_aliases={real: float32}))
    f64_printer = C99CodePrinter(dict(type_aliases={real: float64}))
    f80_printer = C99CodePrinter(dict(type_aliases={real: float80}))

    assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)'
    assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)'
    assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)'

    # (expression, reference template) pairs, checked in this order.
    cases = [
        (Abs(n), 'abs(n)'),
        (Abs(x + 2.0), 'fabs{s}(x + 2.0{S})'),
        (sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))'),
        (exp(x*8.0), 'exp{s}(8.0{S}*x)'),
        (exp2(x), 'exp2{s}(x)'),
        (expm1(x*4.0), 'expm1{s}(4.0{S}*x)'),
        (Mod(n, 2), '((n) % (2))'),
        (Mod(2*n + 3, 3*n + 5), '((2*n + 3) % (3*n + 5))'),
        (Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})'),
        (Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})'),
        (log(x/2), 'log{s}((1.0{S}/2.0{S})*x)'),
        (log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)'),
        (log2(x*8.0), 'log2{s}(8.0{S}*x)'),
        (log1p(x), 'log1p{s}(x)'),
        (2**x, 'pow{s}(2, x)'),
        (2.0**x, 'pow{s}(2.0{S}, x)'),
        (x**3, 'pow{s}(x, 3)'),
        (x**4.0, 'pow{s}(x, 4.0{S})'),
        (sqrt(3+x), 'sqrt{s}(x + 3)'),
        (Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})'),
        (hypot(x, y), 'hypot{s}(x, y)'),
        (sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})'),
        (cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})'),
        (tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})'),
        (asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})'),
        (acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})'),
        (atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})'),
        (atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)'),
        (sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})'),
        (cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})'),
        (tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})'),
        (asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})'),
        (acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})'),
        (atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})'),
        (erf(42.*x), 'erf{s}(42.0{S}*x)'),
        (erfc(42.*x), 'erfc{s}(42.0{S}*x)'),
        (gamma(x), 'tgamma{s}(x)'),
        (loggamma(x), 'lgamma{s}(x)'),
        (ceiling(x + 2.), "ceil{s}(x + 2.0{S})"),
        (floor(x + 2.), "floor{s}(x + 2.0{S})"),
        (fma(x, y, -z), 'fma{s}(x, y, -z)'),
        (Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))'),
        (Min(x, 2.0), 'fmin{s}(2.0{S}, x)'),
    ]

    for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ['f', '', 'l']):
        for expr, ref in cases:
            assert printer.doprint(expr) == ref.format(s=suffix, S=suffix.upper())
def test_CXX11CodePrinter():
    """Check ``log1p`` printing, metadata and reserved words of ``CXX11CodePrinter``."""
    assert CXX11CodePrinter().doprint(log1p(x)) == 'std::log1p(x)'

    printer = CXX11CodePrinter()
    assert printer.language == 'C++'
    assert printer.standard == 'C++11'

    # Words that must be reserved for this printer ...
    for keyword in ('operator', 'noexcept'):
        assert keyword in printer.reserved_words
    # ... and one that must not be.
    assert 'concept' not in printer.reserved_words
def test_C99CodePrinter__precision():
    """Verify suffixes of functions and literals for three floating point types.

    The templates below use ``{s}`` for the lower-case function suffix and
    ``{S}`` for the upper-case literal suffix of the printer being checked.
    """
    n = symbols('n', integer=True)
    f32_printer = C99CodePrinter(dict(type_aliases={real: float32}))
    f64_printer = C99CodePrinter(dict(type_aliases={real: float64}))
    f80_printer = C99CodePrinter(dict(type_aliases={real: float80}))

    assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)'
    assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)'
    assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)'

    # Expression / template table; the order matches the order of the checks.
    templates = (
        (Abs(n), 'abs(n)'),
        (Abs(x + 2.0), 'fabs{s}(x + 2.0{S})'),
        (sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))'),
        (exp(x*8.0), 'exp{s}(8.0{S}*x)'),
        (exp2(x), 'exp2{s}(x)'),
        (expm1(x*4.0), 'expm1{s}(4.0{S}*x)'),
        (Mod(n, 2), '((n) % (2))'),
        (Mod(2*n + 3, 3*n + 5), '((2*n + 3) % (3*n + 5))'),
        (Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})'),
        (Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})'),
        (log(x/2), 'log{s}((1.0{S}/2.0{S})*x)'),
        (log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)'),
        (log2(x*8.0), 'log2{s}(8.0{S}*x)'),
        (log1p(x), 'log1p{s}(x)'),
        (2**x, 'pow{s}(2, x)'),
        (2.0**x, 'pow{s}(2.0{S}, x)'),
        (x**3, 'pow{s}(x, 3)'),
        (x**4.0, 'pow{s}(x, 4.0{S})'),
        (sqrt(3+x), 'sqrt{s}(x + 3)'),
        (Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})'),
        (hypot(x, y), 'hypot{s}(x, y)'),
        (sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})'),
        (cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})'),
        (tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})'),
        (asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})'),
        (acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})'),
        (atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})'),
        (atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)'),
        (sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})'),
        (cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})'),
        (tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})'),
        (asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})'),
        (acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})'),
        (atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})'),
        (erf(42.*x), 'erf{s}(42.0{S}*x)'),
        (erfc(42.*x), 'erfc{s}(42.0{S}*x)'),
        (gamma(x), 'tgamma{s}(x)'),
        (loggamma(x), 'lgamma{s}(x)'),
        (ceiling(x + 2.), "ceil{s}(x + 2.0{S})"),
        (floor(x + 2.), "floor{s}(x + 2.0{S})"),
        (fma(x, y, -z), 'fma{s}(x, y, -z)'),
        (Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))'),
        (Min(x, 2.0), 'fmin{s}(2.0{S}, x)'),
    )

    for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ['f', '', 'l']):
        upper_suffix = suffix.upper()
        for expr, template in templates:
            assert printer.doprint(expr) == template.format(s=suffix, S=upper_suffix)
def test_CXX11CodePrinter():
    """Verify the C++11 printer: ``log1p`` output, language tags, reserved words."""
    rendered = CXX11CodePrinter().doprint(log1p(x))
    assert rendered == 'std::log1p(x)'

    cxx = CXX11CodePrinter()
    assert cxx.language == 'C++'
    assert cxx.standard == 'C++11'

    reserved = cxx.reserved_words
    assert 'operator' in reserved
    assert 'noexcept' in reserved
    assert 'concept' not in reserved
def test_optims_numpy_TODO():
    """Check two nested rewrites expected from ``optims_numpy``.

    Each key of the mapping is optimized and compared with its value.
    """
    x, y = map(Symbol, 'x y'.split())

    expectations = {
        log(x*y)*sin(x*y)*log(x*y+1)/(log(2)*x*y): log2(x*y)*sinc(x*y)*log1p(x*y),
        exp(x*sin(y)/y) - 1: expm1(x*sinc(y))
    }
    for unoptimized, expected in expectations.items():
        assert optimize(unoptimized, optims_numpy) == expected
def test_optims_numpy():
    """Check that ``optims_numpy`` introduces ``sinc``, ``expm1``, ``log1p`` and ``log2``.

    Each key of the mapping is optimized and compared with its value.
    """
    x = Symbol('x')

    expectations = {
        sin(2*x)/(2*x) + exp(2*x) - 1: sinc(2*x) + expm1(2*x),
        log(x+3)/log(2) + log(x**2 + 1): log1p(x**2) + log2(x+3)
    }
    for unoptimized, expected in expectations.items():
        assert optimize(unoptimized, optims_numpy) == expected
def test_log1p_opt():
    """Check ``log1p_opt`` on matching and non-matching logarithms.

    Matching cases are also rewritten back to ``log`` and compared with the
    input expression.
    """
    x = Symbol('x')

    # log(x + 1) -> log1p(x)
    src = log(x + 1)
    out = optimize(src, [log1p_opt])
    assert log1p(x) - out == 0
    assert out.rewrite(log) == src

    # log(3*x + 3) -> log1p(x) + log(3)
    src = log(3*x + 3)
    out = optimize(src, [log1p_opt])
    assert log1p(x) + log(3) == out
    assert (out.rewrite(log) - src).simplify() == 0

    # log(2*x + 1) -> log1p(2*x)
    src = log(2*x + 1)
    out = optimize(src, [log1p_opt])
    assert log1p(2*x) - out == 0
    assert out.rewrite(log) == src

    # log(x + 3) has no "+ 1" term and must print unchanged.
    src = log(x+3)
    out = optimize(src, [log1p_opt])
    assert str(out) == 'log(x + 3)'
def test_log1p_opt():
    """Exercise ``log1p_opt`` alone on four logarithms.

    The first three are expected to be rewritten to ``log1p`` and to
    round-trip through ``rewrite(log)``; the last one must stay a ``log``.
    """
    x = Symbol('x')

    # Plain log(x + 1).
    original = log(x + 1)
    optimized = optimize(original, [log1p_opt])
    assert log1p(x) - optimized == 0
    assert optimized.rewrite(log) == original

    # Common factor 3 in the argument.
    original = log(3 * x + 3)
    optimized = optimize(original, [log1p_opt])
    assert log1p(x) + log(3) == optimized
    assert (optimized.rewrite(log) - original).simplify() == 0

    # Coefficient on x only.
    original = log(2 * x + 1)
    optimized = optimize(original, [log1p_opt])
    assert log1p(2 * x) - optimized == 0
    assert optimized.rewrite(log) == original

    # Constant term other than 1: nothing to rewrite.
    original = log(x + 3)
    optimized = optimize(original, [log1p_opt])
    assert str(optimized) == 'log(x + 3)'
def test_optims_c99():
    """Check the result of ``optimize(..., optims_c99)`` for four expressions.

    The expected forms use ``exp2``, ``log2``, ``log1p`` and ``expm1``.
    """
    x = Symbol('x')

    # All four special functions in one expression.
    expr1 = 2**x + log(x) / log(2) + log(x + 1) + exp(x) - 1
    opt1 = optimize(expr1, optims_c99).simplify()
    assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x)

    # log2 and log1p only (a leftover debugging print() was removed here).
    expr2 = log(x) / log(2) + log(x + 1)
    opt2 = optimize(expr2, optims_c99)
    assert opt2 == log2(x) + log1p(x)

    # A common factor in the log argument must be split off as log(17).
    expr3 = log(x) / log(2) + log(17 * x + 17)
    opt3 = optimize(expr3, optims_c99)
    delta3 = opt3 - (log2(x) + log(17) + log1p(x))
    assert delta3 == 0

    # Coefficients in front of every rewritten term.
    expr4 = (2**x + 3 * log(5 * x + 7) / (13 * log(2)) + 11 * exp(x) - 11
             + log(17 * x + 17))
    opt4 = optimize(expr4, optims_c99).simplify()
    delta4 = opt4 - (exp2(x) + 3 * log2(5 * x + 7) / 13 + 11 * expm1(x) +
                     log(17) + log1p(x))
    assert delta4 == 0
def as_FunctionDefinition(variables=None, name="logsumexp"):
    """Build a ``FunctionDefinition`` AST node for a log-sum-exp routine.

    The generated function calls ``sympy_quicksort`` on the array, sums the
    exponentials of all but the last element relative to the last element,
    and returns ``log1p(sum) + x[n - 1]``.

    Parameters
    ==========

    variables : dict, optional
        Overrides for the symbols used in the generated code. Recognised
        keys are ``'n'`` (length), ``'i'`` (loop index), ``'s'``
        (accumulator) and ``'x'`` (the indexed array). Missing keys get
        default symbols.
    name : str, optional
        Name of the generated function (default ``"logsumexp"``).

    Returns
    =======

    FunctionDefinition
        Node with return type ``"real"`` and parameters ``(x, n)``.
    """
    variables = variables or {}
    n = variables.get('n', Symbol('n', integer=True))
    i = variables.get('i', Symbol('i', integer=True))
    s = variables.get('s', Symbol('s'))
    x = variables.get('x', IndexedBase('x', shape=(n, )))

    # NOTE(review): assumes sympy_quicksort sorts x ascending in place, so
    # that x[n - 1] is the largest element - confirm against its definition.
    largest = x[n - 1]

    # log(sum_i exp(x_i)) == x_max + log1p(sum_{i != max} exp(x_i - x_max)).
    # Every term has to be shifted by x_max; accumulating the unshifted
    # exp(x[i]) is not equal to log-sum-exp once x_max is added back.
    accumulate = CodeBlock(
        AddAugmentedAssignment(s, exp(x[i] - largest)),
        AddAugmentedAssignment(i, 1))

    body = CodeBlock(
        FunctionCall('sympy_quicksort', (x, n), statement=True),
        Declaration(i, 0),
        Declaration(s, 0.0),
        While(i < n - 1, accumulate),
        ReturnStatement(log1p(s) + largest))
    return FunctionDefinition("real", name, (x, n), body)
def test_C99CodePrinter():
    """Check the C99 printer: function output, language tags, reserved words."""
    # (expression, expected C code); a fresh printer is used for every pair.
    printed_forms = [
        (expm1(x), 'expm1(x)'),
        (log1p(x), 'log1p(x)'),
        (exp2(x), 'exp2(x)'),
        (log2(x), 'log2(x)'),
        (fma(x, y, -z), 'fma(x, y, -z)'),
        (log10(x), 'log10(x)'),
        (Cbrt(x), 'cbrt(x)'),  # note Cbrt due to cbrt already taken.
        (hypot(x, y), 'hypot(x, y)'),
        (loggamma(x), 'lgamma(x)'),
        (Max(x, 3, x**2), 'fmax(3, fmax(x, pow(x, 2)))'),
        (Min(x, 3), 'fmin(3, x)'),
    ]
    for expr, expected in printed_forms:
        assert C99CodePrinter().doprint(expr) == expected

    c99printer = C99CodePrinter()
    assert c99printer.language == 'C'
    assert c99printer.standard == 'C99'
    reserved = c99printer.reserved_words
    assert 'restrict' in reserved
    assert 'using' not in reserved
def test_C99CodePrinter():
    """Verify C99 output for the special functions and the printer metadata."""

    def printed(expr):
        # Use a new printer for every expression.
        return C99CodePrinter().doprint(expr)

    assert printed(expm1(x)) == 'expm1(x)'
    assert printed(log1p(x)) == 'log1p(x)'
    assert printed(exp2(x)) == 'exp2(x)'
    assert printed(log2(x)) == 'log2(x)'
    assert printed(fma(x, y, -z)) == 'fma(x, y, -z)'
    assert printed(log10(x)) == 'log10(x)'
    assert printed(Cbrt(x)) == 'cbrt(x)'  # note Cbrt due to cbrt already taken.
    assert printed(hypot(x, y)) == 'hypot(x, y)'
    assert printed(loggamma(x)) == 'lgamma(x)'
    assert printed(Max(x, 3, x**2)) == 'fmax(3, fmax(x, pow(x, 2)))'
    assert printed(Min(x, 3)) == 'fmin(3, x)'

    printer = C99CodePrinter()
    assert printer.language == 'C'
    assert printer.standard == 'C99'
    assert 'restrict' in printer.reserved_words
    assert 'using' not in printer.reserved_words
def test_C99CodePrinter():
    """Check the C99 printer: function output, language tags, reserved words."""
    # (expression, expected C code); a fresh printer is used for every pair.
    expected_output = (
        (expm1(x), "expm1(x)"),
        (log1p(x), "log1p(x)"),
        (exp2(x), "exp2(x)"),
        (log2(x), "log2(x)"),
        (fma(x, y, -z), "fma(x, y, -z)"),
        (log10(x), "log10(x)"),
        (Cbrt(x), "cbrt(x)"),  # note Cbrt due to cbrt already taken.
        (hypot(x, y), "hypot(x, y)"),
        (loggamma(x), "lgamma(x)"),
        (Max(x, 3, x ** 2), "fmax(3, fmax(x, pow(x, 2)))"),
        (Min(x, 3), "fmin(3, x)"),
    )
    for expr, code in expected_output:
        assert C99CodePrinter().doprint(expr) == code

    c99printer = C99CodePrinter()
    assert c99printer.language == "C"
    assert c99printer.standard == "C99"
    words = c99printer.reserved_words
    assert "restrict" in words
    assert "using" not in words
def test_log1p():
    """``log1p`` lambdified with the 'numpy' module must be accurate at 1e-99."""
    if not np:
        skip("NumPy not installed")

    numeric_log1p = lambdify((a,), log1p(a), 'numpy')
    # For such a tiny argument the result must agree with the argument itself.
    error = abs(numeric_log1p(1e-99) - 1e-99)
    assert error < 1e-100
w = Wild('w')

# v*log(w)/log(2) -> v*log2(w). The cost function counts powers with a
# negative exponent and log/log2 calls whose argument is not a number.
log2_opt = ReplaceOptim(
    v * log(w) / log(2),
    v * log2(w),
    cost_function=lambda expr: expr.count(lambda e: (
        (isinstance(e, Pow) and e.exp.is_negative)
        or (isinstance(e, (log, log2)) and not e.args[0].is_number))))

# log(2)*log2(w) -> log(w)
log2const_opt = ReplaceOptim(log(2) * log2(w), log(w))


def _logsumexp_2terms(l):
    """Rewrite ``log(exp(a) + exp(b))`` using ``Max``, ``Min`` and ``log1p``.

    Uses the identity
    ``log(exp(a) + exp(b)) == Max(a, b) + log1p(exp(Min(a, b) - Max(a, b)))``.
    The smaller exponent has to be taken relative to the larger one; without
    the subtraction the rewritten expression is not equal to the input.
    """
    a, b = (term.args[0] for term in l.args[0].args)
    larger = Max(a, b)
    return larger + log1p(exp(Min(a, b) - larger))


# Matches a log whose argument is a sum of exactly two exp(...) terms.
logsumexp_2terms_opt = ReplaceOptim(
    lambda l: (isinstance(l, log)
               and isinstance(l.args[0], Add)
               and len(l.args[0].args) == 2
               and all(isinstance(t, exp) for t in l.args[0].args)),
    _logsumexp_2terms)


def _partition(predicate, iterable):
    """Split ``iterable`` into two tuples: items passing ``predicate`` and the rest."""
    iter_a, iter_b = tee(iterable)
    return tuple(filter(predicate, iter_a)), tuple(filterfalse(predicate, iter_b))


def _try_expm1(e):
    """Factor ``e`` and replace every ``exp(w) - 1`` in the result by ``expm1(w)``."""
    return e.factor().replace(exp(w) - 1, expm1(w))


# NOTE(review): only the beginning of this helper is visible in this view;
# it is reproduced unchanged.
def _expm1_value(e):
    def isnum(arg):
        return arg.is_number
_v * log(_w) / log(2), _v * log2(_w), cost_function=lambda expr: expr.count( lambda e: ( # division & eval of transcendentals are expensive floating point operations... e.is_Pow and e.exp.is_negative # division or (isinstance(e, (log, log2)) and not e.args[0].is_number) ) # transcendental )) log2const_opt = ReplaceOptim(log(2) * log2(_w), log(_w)) logsumexp_2terms_opt = ReplaceOptim( lambda l: (isinstance(l, log) and l.args[0].is_Add and len(l.args[0].args) == 2 and all(isinstance(t, exp) for t in l.args[0].args)), lambda l: (Max(*[e.args[0] for e in l.args[0].args]) + log1p( exp(Min(*[e.args[0] for e in l.args[0].args]))))) def _try_expm1(expr): protected, old_new = expr.replace(exp, lambda arg: Dummy(), map=True) factored = protected.factor() new_old = {v: k for k, v in old_new.items()} return factored.replace( _d - 1, lambda d: expm1(new_old[d].args[0])).xreplace(new_old) def _expm1_value(e): numbers, non_num = sift(e.args, lambda arg: arg.is_number, binary=True) non_num_exp, non_num_other = sift(non_num, lambda arg: arg.has(exp), binary=True)
def test_subclass_print_method():
    """Overriding ``_print_log1p`` in a subclass changes how ``log1p`` is printed."""

    class MyPrinter(CXX11CodePrinter):
        def _print_log1p(self, expr):
            # Render each argument with this printer and join with commas.
            rendered_args = ', '.join(self._print(arg) for arg in expr.args)
            return 'my_library::log1p(%s)' % rendered_args

    custom_printer = MyPrinter()
    assert custom_printer.doprint(log1p(x)) == 'my_library::log1p(x)'
def test_log1p():
    """Check evaluation, rewriting, precision, assumptions and derivative of ``log1p``."""
    # Eval
    assert log1p(0) == 0
    d = S(10)
    assert expand_log(log1p(d**-1000) - log(d**1000 + 1) + log(d**1000)) == 0

    x = Symbol('x', real=True)
    plain_log = log(x + 1)

    # Expand and rewrite
    assert log1p(x).expand(func=True) - plain_log == 0
    for target in ('tractable', 'log'):
        assert log1p(x).rewrite(target) - plain_log == 0

    # Precision
    tiny = 1e-99
    assert not abs(log(tiny + 1).evalf() - tiny) < 1e-100  # for comparison
    assert abs(expand_log(log1p(tiny)).evalf() - tiny) < 1e-100

    # Properties
    assert log1p(-2**Rational(-1, 2)).is_real

    assert not log1p(-1).is_finite
    assert log1p(pi).is_finite

    # Without the assumption on the argument the property must not hold;
    # with a symbol carrying the assumption it must.
    for assumption, symbol_name in (('positive', 'y'),
                                    ('zero', 'z'),
                                    ('nonnegative', 'o')):
        query = 'is_' + assumption
        assert not getattr(log1p(x), query)
        assumed = Symbol(symbol_name, **{assumption: True})
        assert getattr(log1p(assumed), query)

    # Diff
    scaled = log1p(42 * x)
    assert scaled.diff(x) - 42 / (42 * x + 1) == 0
    assert scaled.diff(x) - scaled.expand(func=True).diff(x) == 0
lambda e: ( # division & eval of transcendentals are expensive floating point operations... e.is_Pow and e.exp.is_negative # division or (isinstance(e, (log, log2)) and not e.args[0].is_number)) # transcendental ) ) log2const_opt = ReplaceOptim(log(2)*log2(_w), log(_w)) logsumexp_2terms_opt = ReplaceOptim( lambda l: (isinstance(l, log) and l.args[0].is_Add and len(l.args[0].args) == 2 and all(isinstance(t, exp) for t in l.args[0].args)), lambda l: ( Max(*[e.args[0] for e in l.args[0].args]) + log1p(exp(Min(*[e.args[0] for e in l.args[0].args]))) ) ) def _try_expm1(expr): protected, old_new = expr.replace(exp, lambda arg: Dummy(), map=True) factored = protected.factor() new_old = {v: k for k, v in old_new.items()} return factored.replace(_d - 1, lambda d: expm1(new_old[d].args[0])).xreplace(new_old) def _expm1_value(e): numbers, non_num = sift(e.args, lambda arg: arg.is_number, binary=True) non_num_exp, non_num_other = sift(non_num, lambda arg: arg.has(exp), binary=True)
def test_C99CodePrinter__precision():
    """Check C99 output for ``float32``, ``float64`` and ``float80`` type aliases.

    Reference strings are templates in which ``{s}`` stands for the function
    suffix of the printer ("f", "" or "l") and ``{S}`` for its upper-case form.
    """
    n = symbols("n", integer=True)
    f32_printer = C99CodePrinter(dict(type_aliases={real: float32}))
    f64_printer = C99CodePrinter(dict(type_aliases={real: float64}))
    f80_printer = C99CodePrinter(dict(type_aliases={real: float80}))

    assert f32_printer.doprint(sin(x + 2.1)) == "sinf(x + 2.1F)"
    assert f64_printer.doprint(sin(x + 2.1)) == "sin(x + 2.1000000000000001)"
    assert f80_printer.doprint(sin(x + Float("2.0"))) == "sinl(x + 2.0L)"

    # (expression, reference template) pairs, checked in this order.
    reference = [
        (Abs(n), "abs(n)"),
        (Abs(x + 2.0), "fabs{s}(x + 2.0{S})"),
        (
            sin(x + 4.0) ** cos(x - 2.0),
            "pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))",
        ),
        (exp(x * 8.0), "exp{s}(8.0{S}*x)"),
        (exp2(x), "exp2{s}(x)"),
        (expm1(x * 4.0), "expm1{s}(4.0{S}*x)"),
        (Mod(n, 2), "((n) % (2))"),
        (Mod(2 * n + 3, 3 * n + 5), "((2*n + 3) % (3*n + 5))"),
        (Mod(x + 2.0, 3.0), "fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})"),
        (Mod(x, 2.0 * x + 3.0), "fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})"),
        (log(x / 2), "log{s}((1.0{S}/2.0{S})*x)"),
        (log10(3 * x / 2), "log10{s}((3.0{S}/2.0{S})*x)"),
        (log2(x * 8.0), "log2{s}(8.0{S}*x)"),
        (log1p(x), "log1p{s}(x)"),
        (2 ** x, "pow{s}(2, x)"),
        (2.0 ** x, "pow{s}(2.0{S}, x)"),
        (x ** 3, "pow{s}(x, 3)"),
        (x ** 4.0, "pow{s}(x, 4.0{S})"),
        (sqrt(3 + x), "sqrt{s}(x + 3)"),
        (Cbrt(x - 2.0), "cbrt{s}(x - 2.0{S})"),
        (hypot(x, y), "hypot{s}(x, y)"),
        (sin(3.0 * x + 2.0), "sin{s}(3.0{S}*x + 2.0{S})"),
        (cos(3.0 * x - 1.0), "cos{s}(3.0{S}*x - 1.0{S})"),
        (tan(4.0 * y + 2.0), "tan{s}(4.0{S}*y + 2.0{S})"),
        (asin(3.0 * x + 2.0), "asin{s}(3.0{S}*x + 2.0{S})"),
        (acos(3.0 * x + 2.0), "acos{s}(3.0{S}*x + 2.0{S})"),
        (atan(3.0 * x + 2.0), "atan{s}(3.0{S}*x + 2.0{S})"),
        (atan2(3.0 * x, 2.0 * y), "atan2{s}(3.0{S}*x, 2.0{S}*y)"),
        (sinh(3.0 * x + 2.0), "sinh{s}(3.0{S}*x + 2.0{S})"),
        (cosh(3.0 * x - 1.0), "cosh{s}(3.0{S}*x - 1.0{S})"),
        (tanh(4.0 * y + 2.0), "tanh{s}(4.0{S}*y + 2.0{S})"),
        (asinh(3.0 * x + 2.0), "asinh{s}(3.0{S}*x + 2.0{S})"),
        (acosh(3.0 * x + 2.0), "acosh{s}(3.0{S}*x + 2.0{S})"),
        (atanh(3.0 * x + 2.0), "atanh{s}(3.0{S}*x + 2.0{S})"),
        (erf(42.0 * x), "erf{s}(42.0{S}*x)"),
        (erfc(42.0 * x), "erfc{s}(42.0{S}*x)"),
        (gamma(x), "tgamma{s}(x)"),
        (loggamma(x), "lgamma{s}(x)"),
        (ceiling(x + 2.0), "ceil{s}(x + 2.0{S})"),
        (floor(x + 2.0), "floor{s}(x + 2.0{S})"),
        (fma(x, y, -z), "fma{s}(x, y, -z)"),
        (Max(x, 8.0, x ** 4.0), "fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))"),
        (Min(x, 2.0), "fmin{s}(2.0{S}, x)"),
    ]

    for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ["f", "", "l"]):
        for expr, ref in reference:
            assert printer.doprint(expr) == ref.format(s=suffix, S=suffix.upper())
def test_numerical_accuracy_functions():
    """``SciPyPrinter`` must print ``expm1``, ``log1p`` and ``cosm1`` fully qualified."""
    printer = SciPyPrinter()

    # (expression, expected output); one printer instance serves all checks.
    qualified_names = (
        (expm1(x), 'numpy.expm1(x)'),
        (log1p(x), 'numpy.log1p(x)'),
        (cosm1(x), 'scipy.special.cosm1(x)'),
    )
    for expr, expected in qualified_names:
        assert printer.doprint(expr) == expected
def test_log1p():
    """Check the accuracy of NumPy-lambdified ``log1p`` for the argument 1e-99."""
    if not np:
        skip("NumPy not installed")

    tiny = 1e-99
    func = lambdify((a,), log1p(a), 'numpy')
    # The result must agree with the argument to within 1e-100.
    assert abs(func(tiny) - tiny) < 1e-100
def test_log1p():
    """Check evaluation, rewriting, precision, assumptions and derivative of ``log1p``."""
    # Eval
    assert log1p(0) == 0
    ten = S(10)
    assert expand_log(log1p(ten**-1000) - log(ten**1000 + 1) + log(ten**1000)) == 0

    x = Symbol('x', real=True, finite=True)
    of_x = log1p(x)

    # Expand and rewrite
    assert of_x.expand(func=True) - log(x + 1) == 0
    assert of_x.rewrite('tractable') - log(x + 1) == 0
    assert of_x.rewrite('log') - log(x + 1) == 0

    # Precision
    assert not abs(log(1e-99 + 1).evalf() - 1e-99) < 1e-100  # for comparison
    assert abs(expand_log(log1p(1e-99)).evalf() - 1e-99) < 1e-100

    # Properties
    assert log1p(-2**(-S(1)/2)).is_real

    assert not log1p(-1).is_finite
    assert log1p(pi).is_finite

    assert not of_x.is_positive
    assert log1p(Symbol('y', positive=True)).is_positive

    assert not of_x.is_zero
    assert log1p(Symbol('z', zero=True)).is_zero

    assert not of_x.is_nonnegative
    assert log1p(Symbol('o', nonnegative=True)).is_nonnegative

    # Diff
    derivative = log1p(42*x).diff(x)
    assert derivative - 42/(42*x + 1) == 0
    assert derivative - log1p(42*x).expand(func=True).diff(x) == 0