示例#1
0
def test_loops():
    """Indexed products assigned to ``y[i]`` expand into an initializer loop
    followed by a nested accumulation loop over the contracted index ``j``."""
    m, n = symbols('m n', integer=True)
    A, x, y, z = [IndexedBase(label) for label in ('A', 'x', 'y', 'z')]
    i = Idx('i', m)
    j = Idx('j', n)

    # Both expressions share the same accumulation loop; only the initializer
    # (the non-contracted part of the expression) differs.
    accumulation = (
        "for i in 0..m {\n"
        "    for j in 0..n {\n"
        "        y[i] = A[n*i + j]*x[j] + y[i];\n"
        "    }\n"
        "}")

    zero_init = (
        "for i in 0..m {\n"
        "    y[i] = 0;\n"
        "}\n")
    assert rust_code(A[i, j]*x[j], assign_to=y[i]) == zero_init + accumulation

    sum_init = (
        "for i in 0..m {\n"
        "    y[i] = x[i] + z[i];\n"
        "}\n")
    assert rust_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == sum_init + accumulation
示例#2
0
def test_printmethod():
    """A ``_rust_code`` method on an expression class overrides the printer.
    An element of a 1x3 MatrixSymbol prints with a flat index."""
    class fabs(Abs):
        def _rust_code(self, printer):
            inner = printer._print(self.args[0])
            return "%s.fabs()" % inner

    assert rust_code(fabs(x)) == "x.fabs()"

    row = MatrixSymbol("a", 1, 3)
    assert rust_code(row[0, 0]) == 'a[0]'
示例#3
0
def test_user_functions():
    """``user_functions`` can rename a function outright, or pick the name
    with the first (predicate, name, style) entry whose predicate accepts
    the argument."""
    x = symbols('x', integer=False)
    n = symbols('n', integer=True)

    def _is_int(expr):
        return expr.is_integer

    # The trailing 4 selects a print style. The assertions below show it
    # yields plain call syntax, e.g. ``fabs(x)``.
    custom_functions = {
        "ceiling": "ceil",
        "Abs": [
            (lambda expr: not _is_int(expr), "fabs", 4),
            (_is_int, "abs", 4),
        ],
    }

    expectations = (
        (ceiling(x), "x.ceil()"),
        (Abs(x), "fabs(x)"),
        (Abs(n), "abs(n)"),
    )
    for expr, expected in expectations:
        assert rust_code(expr, user_functions=custom_functions) == expected
示例#4
0
def test_reserved_words():
    """Symbols named after Rust keywords are renamed with a suffix, or
    rejected when ``error_on_reserved`` is set."""
    # A single name passed to ``symbols`` yields a single Symbol.
    y = symbols("if")

    expr = sin(y)
    # Default behaviour: an underscore suffix is appended to the keyword.
    assert rust_code(expr) == "if_.sin()"
    # The renamed identifier is also used inside a dereference.
    assert rust_code(expr, dereference=[y]) == "(*if_).sin()"
    # The suffix is configurable.
    assert rust_code(expr, reserved_word_suffix='_unreserved') == "if_unreserved.sin()"

    # With error_on_reserved the printer raises instead of renaming.
    with raises(ValueError):
        rust_code(expr, error_on_reserved=True)
示例#5
0
def test_sign():
    """``sign`` prints as Rust's ``.signum()`` method.
    A compound receiver is parenthesized."""
    # (expression, plain output, output when assigned to ``r``)
    cases = (
        (sign(x) * y, "y*x.signum()", "r = y*x.signum();"),
        (sign(x + y) + 42, "(x + y).signum() + 42", "r = (x + y).signum() + 42;"),
    )
    for expr, plain, assigned in cases:
        assert rust_code(expr) == plain
        assert rust_code(expr, assign_to='r') == assigned

    nested = sign(cos(x))
    assert rust_code(nested) == "x.cos().signum()"
示例#6
0
def test_Indexed():
    """Multi-dimensional ``Indexed`` objects print as a single flattened,
    row-major offset."""
    n, m, o = symbols('n m o', integer=True)
    i = Idx('i', n)
    j = Idx('j', m)
    k = Idx('k', o)

    expectations = (
        (IndexedBase('x')[j], "x[j]"),
        (IndexedBase('A')[i, j], "A[m*i + j]"),
        (IndexedBase('B')[i, j, k], "B[m*o*i + o*j + k]"),
    )
    for indexed, expected in expectations:
        assert rust_code(indexed) == expected
示例#7
0
def test_inline_function():
    """``implemented_function`` bodies are inlined into the generated code.
    Named constants they reference are declared as ``const`` first."""
    x = symbols('x')

    doubled = implemented_function('g', Lambda(x, 2*x))
    assert rust_code(doubled(x)) == "2*x"

    over_catalan = implemented_function('g', Lambda(x, 2*x/Catalan))
    expected = "const Catalan: f64 = %s;\n2*x/Catalan" % Catalan.evalf(17)
    assert rust_code(over_catalan(x)) == expected

    A = IndexedBase('A')
    bound = symbols('n', integer=True)
    i = Idx('i', bound)
    cubic = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
    assert rust_code(cubic(A[i]), assign_to=A[i]) == (
        "for i in 0..n {\n"
        "    A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
        "}")
示例#8
0
def test_loops_addfactor():
    """(a + b)*c contracted over j, k, l is accumulated into y[i]
    inside four nested loops."""
    m, n, o, p = symbols('m n o p', integer=True)
    a, b, c, y = (IndexedBase(label) for label in 'abcy')
    i, j, k, l = (Idx(label, bound) for label, bound in zip('ijkl', (m, n, o, p)))

    # Row-major flat offsets: a and b are indexed by (i, j, k, l) and
    # c by (j, k, l).
    offset_ijkl = i*n*o*p + j*o*p + k*p + l
    offset_jkl = j*o*p + k*p + l

    code = rust_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])

    header = (
        "for i in 0..m {\n"
        "    y[i] = 0;\n"
        "}\n"
        "for i in 0..m {\n"
        "    for j in 0..n {\n"
        "        for k in 0..o {\n"
        "            for l in 0..p {\n")
    body = "                y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n" % (
        offset_ijkl, offset_ijkl, offset_jkl)
    footer = (
        "            }\n"
        "        }\n"
        "    }\n"
        "}")
    assert code == header + body + footer
示例#9
0
def test_ITE():
    """``ITE`` prints as a multi-line Rust ``if``/``else`` expression."""
    expected = "\n".join([
        "if (x < 1) {",
        "    y",
        "} else {",
        "    z",
        "}",
    ])
    assert rust_code(ITE(x < 1, y, z)) == expected
示例#10
0
def test_dummy_loops():
    """A Dummy loop index and a Dummy bound are printed by their names
    in the generated loop."""
    dummy_i, m = symbols('i m', integer=True, cls=Dummy)
    x, y = IndexedBase('x'), IndexedBase('y')
    i = Idx(dummy_i, m)

    expected = (
        "for i in 0..m {\n"
        "    y[i] = x[i];\n"
        "}")
    assert rust_code(x[i], assign_to=y[i]) == expected
示例#11
0
def test_Rational():
    """Non-integral rationals print as an ``f64`` division.
    Integral ones collapse to a plain integer."""
    table = (
        (Rational(3, 7), "3_f64/7.0"),
        (Rational(18, 9), "2"),
        (Rational(3, -7), "-3_f64/7.0"),
        (Rational(-3, -7), "3_f64/7.0"),
        (x + Rational(3, 7), "x + 3_f64/7.0"),
        (Rational(3, 7)*x, "(3_f64/7.0)*x"),
    )
    for expr, expected in table:
        assert rust_code(expr) == expected
示例#12
0
def test_Relational():
    """Each SymPy relational class maps to the matching Rust comparison operator."""
    table = (
        (Eq, "x == y"),
        (Ne, "x != y"),
        (Le, "x <= y"),
        (Lt, "x < y"),
        (Gt, "x > y"),
        (Ge, "x >= y"),
    )
    for relation, expected in table:
        assert rust_code(relation(x, y)) == expected
示例#13
0
def test_Piecewise():
    """Piecewise prints as a Rust if / else-if / else expression, in either
    multi-line (block) or single-line (inline) form."""
    # Two branches with a True default.
    two_way = Piecewise((x, x < 1), (x + 2, True))
    assert rust_code(two_way) == (
            "if (x < 1) {\n"
            "    x\n"
            "} else {\n"
            "    x + 2\n"
            "}")
    assert rust_code(two_way, assign_to="r") == (
        "r = if (x < 1) {\n"
        "    x\n"
        "} else {\n"
        "    x + 2\n"
        "};")
    assert rust_code(two_way, assign_to="r", inline=True) == (
        "r = if (x < 1) { x } else { x + 2 };")

    # Three branches: the middle condition becomes an ``else if``.
    chain = Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True))
    inline_cases = (
        (chain, {},
         "if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 }"),
        (chain, {"assign_to": "r"},
         "r = if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 };"),
        (2*chain, {},
         "2*if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 }"),
        (2*chain - 42, {},
         "2*if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 } - 42"),
    )
    for expr, extra_settings, expected in inline_cases:
        assert rust_code(expr, inline=True, **extra_settings) == expected

    assert rust_code(chain, assign_to="r") == (
        "r = if (x < 1) {\n"
        "    x\n"
        "} else if (x < 5) {\n"
        "    x + 1\n"
        "} else {\n"
        "    x + 2\n"
        "};")

    # A Piecewise that lacks a True (default) condition must raise ValueError.
    no_default = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
    with raises(ValueError):
        rust_code(no_default)
示例#14
0
def test_loops_multiple_contractions():
    """Contracting b[j, k, l]*a[i, j, k, l] over j, k, l accumulates into y[i]
    inside four nested loops."""
    n, m, o, p = symbols('n m o p', integer=True)
    a, b, y = IndexedBase('a'), IndexedBase('b'), IndexedBase('y')
    i, j, k, l = Idx('i', m), Idx('j', n), Idx('k', o), Idx('l', p)

    # Row-major flat offsets for the 4-index base `a` and the 3-index base `b`.
    a_offset = i*n*o*p + j*o*p + k*p + l
    b_offset = j*o*p + k*p + l
    innermost = "                y[i] = a[%s]*b[%s] + y[i];\n" % (a_offset, b_offset)

    expected = (
        "for i in 0..m {\n"
        "    y[i] = 0;\n"
        "}\n"
        "for i in 0..m {\n"
        "    for j in 0..n {\n"
        "        for k in 0..o {\n"
        "            for l in 0..p {\n"
        + innermost +
        "            }\n"
        "        }\n"
        "    }\n"
        "}")
    assert rust_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == expected
示例#15
0
def test_constants_other():
    """Named numeric constants are declared as ``const NAME: f64`` with a
    17-digit value before the expression that uses them."""
    constants = (
        (GoldenRatio, "GoldenRatio"),
        (Catalan, "Catalan"),
        (EulerGamma, "EulerGamma"),
    )
    for constant, name in constants:
        expected = "const %s: f64 = %s;\n2*%s" % (name, constant.evalf(17), name)
        assert rust_code(2*constant) == expected
示例#16
0
def test_sparse_matrix():
    """Printing a SparseMatrix produces a 'Not supported in Rust' marker
    (regression test for gh-15791)."""
    printed = rust_code(SparseMatrix([[1, 2, 3]]))
    assert 'Not supported in Rust' in printed
示例#17
0
def test_basic_ops():
    """Elementary arithmetic prints with Rust infix operators."""
    table = (
        (x + y, "x + y"),
        (x - y, "x - y"),
        (x * y, "x*y"),
        (x / y, "x/y"),
        (-x, "-x"),
    )
    for expr, expected in table:
        assert rust_code(expr) == expected
示例#18
0
def test_matrix():
    """A column Matrix prints as a Rust array literal.
    A row Matrix raises ValueError."""
    column = Matrix([1, 2, 3])
    assert rust_code(column) == '[1, 2, 3]'

    row = Matrix([[1, 2, 3]])
    with raises(ValueError):
        rust_code(row)
示例#19
0
def test_dereference_printing():
    """Symbols listed in ``dereference`` are printed as ``(*name)``,
    including when they are a method receiver."""
    total = x + y + sin(z) + z
    printed = rust_code(total, dereference=[z])
    assert printed == "x + y + (*z) + (*z).sin()"
示例#20
0
def test_boolean():
    """Booleans and logical connectives print as Rust literals and operators.
    Parentheses are added only where precedence requires them."""
    table = (
        (True, "true"),
        (S.true, "true"),
        (False, "false"),
        (S.false, "false"),
        (x & y, "x && y"),
        (x | y, "x || y"),
        (~x, "!x"),
        (x & y & z, "x && y && z"),
        (x | y | z, "x || y || z"),
        ((x & y) | z, "z || x && y"),
        ((x | y) & z, "z && (x || y)"),
    )
    for expr, expected in table:
        assert rust_code(expr) == expected
示例#21
0
def test_settings():
    """Passing an unknown printer setting raises TypeError."""
    with raises(TypeError):
        rust_code(sin(x), method="garbage")
示例#22
0
def test_Integer():
    """SymPy Integers, including negative ones, print as plain decimal literals."""
    for value, expected in ((42, "42"), (-56, "-56")):
        assert rust_code(Integer(value)) == expected