コード例 #1
0
def test_tensor_expression():
    """test constant and variable-dependent instances of TensorExpression"""
    # a constant tensor of rank 2
    matrix = [[0, 1], [2, 3]]
    const = TensorExpression("[[0, 1], [2, 3]]")
    assert isinstance(str(const), str)
    assert const.shape == (2, 2)
    assert const.rank == 2
    assert const.constant
    # the compiled function is called without arguments and with an empty tuple
    np.testing.assert_allclose(const.get_compiled_array()(), matrix)
    np.testing.assert_allclose(const.get_compiled_array()(()), matrix)
    assert const.differentiate("a") == TensorExpression("[[0, 0], [0, 0]]")
    np.testing.assert_allclose(const.value, np.arange(4).reshape(2, 2))

    # a vector whose entries depend on the variable `a`
    vec = TensorExpression("[a, 2*a]")
    assert isinstance(str(vec), str)
    assert vec.shape == (2,)
    assert vec.rank == 1
    assert vec.depends_on("a")
    assert not vec.constant
    np.testing.assert_allclose(vec.differentiate("a").value, np.array([1, 2]))
    # accessing the value of a non-constant expression raises a TypeError
    with pytest.raises(TypeError):
        vec.value
    # indexing yields scalar expressions, slicing yields tensor expressions
    assert vec[0] == ScalarExpression("a")
    assert vec[1] == ScalarExpression("2*a")
    assert vec[0:1] == TensorExpression("[a]")
    for arg, expected in [(1.0, [1.0, 2.0]), (2.0, [2.0, 4.0])]:
        np.testing.assert_allclose(vec.get_compiled_array()(arg), expected)

    # an expression created from another one compares equal but is a new object
    clone = TensorExpression(vec)
    assert isinstance(str(clone), str)
    assert vec == clone
    assert vec is not clone
コード例 #2
0
def test_two_args():
    """test a scalar expression that depends on two variables"""
    expr = ScalarExpression("2 * a ** b")
    assert expr.depends_on("b")
    assert not expr.constant
    assert expr(4, 2) == 32
    assert expr.get_compiled()(4, 2) == 32
    assert expr.differentiate("a")(4, 2) == 16
    assert expr.differentiate("b")(4, 2) == pytest.approx(32 * np.log(4))
    # `c` does not appear in the expression, so this derivative is zero
    assert expr.differentiate("c").value == 0
    assert expr.shape == ()
    assert expr.rank == 0
    assert expr == ScalarExpression(expr.expression)

    # evaluate with random scalar arguments and with random array arguments
    for x in [np.random.random(2), np.random.random((2, 5))]:
        base, exponent = x[0], x[1]
        expected = 2 * base**exponent
        np.testing.assert_allclose(expr(*x), expected)
        np.testing.assert_allclose(expr.get_compiled()(*x), expected)
        if x.ndim != 1:
            continue
        # for the one-dimensional input, also pass all values as one argument
        func = expr._get_function(single_arg=True)
        np.testing.assert_allclose(func(x), expected)
        func = expr.get_compiled(single_arg=True)
        np.testing.assert_allclose(func(x), expected)

    # the collection of all first derivatives
    grad = expr.derivatives
    assert grad.shape == (2,)
    assert grad.rank == 1
    assert not grad.constant
    grad_expected = [24, 16 * np.log(2)]
    np.testing.assert_allclose(grad(2, 3), grad_expected)
    np.testing.assert_allclose(grad.get_compiled()(2, 3), grad_expected)
コード例 #3
0
def test_const():
    """test scalar expressions that are constant"""
    for source in [None, 1, "1", "a - a"]:
        # `None` stands for creating the expression without any argument
        if source is None:
            const = ScalarExpression()
        else:
            const = ScalarExpression(source)
        # the empty expression and `a - a` are both expected to be zero
        if source is None or source == "a - a":
            expected = 0
        else:
            expected = float(source)

        assert const.constant
        assert const.value == expected
        assert const() == expected
        assert const.get_compiled()() == expected
        assert not const.depends_on("a")
        assert const.differentiate("a").value == 0
        assert const.shape == ()
        assert const.rank == 0
        assert bool(const) == (expected != 0)
        assert const.is_zero == (expected == 0)
        assert not const.complex

        # the derivatives of a constant form an empty constant expression
        grad = const.derivatives
        assert grad.constant
        assert isinstance(str(grad), str)
        np.testing.assert_equal(grad.value, [])

        # re-creating the expression in several ways gives equal, distinct objects
        copies = [
            ScalarExpression(const),
            ScalarExpression(const.expression),
            ScalarExpression(const.value),
        ]
        for other in copies:
            assert const is not other
            assert const._sympy_expr == other._sympy_expr
コード例 #4
0
def test_expression_special():
    """test an expression using the Heaviside function"""
    # pairs of (argument, expected value); the value at zero is one half
    samples = [(-1, 0), (0, 0.5), (1, 1)]

    heaviside = ScalarExpression("Heaviside(x)")
    assert not heaviside.constant
    for arg, expected in samples:
        assert heaviside(arg) == expected

    # the compiled version needs to give the same values
    compiled = heaviside.get_compiled()
    for arg, expected in samples:
        assert compiled(arg) == expected
コード例 #5
0
def test_single_arg():
    """test a scalar expression that depends on a single variable"""
    expr = ScalarExpression("2 * a")
    assert not expr.constant
    assert expr.depends_on("a")
    assert expr(4) == 8
    assert expr.get_compiled()(4) == 8
    assert expr.differentiate("a").value == 2
    assert expr.differentiate("b").value == 0
    assert expr.shape == ()
    assert expr.rank == 0
    assert bool(expr)
    assert not expr.is_zero

    assert expr == ScalarExpression(expr.expression)
    # accessing the value of a non-constant expression raises a TypeError
    with pytest.raises(TypeError):
        expr.value

    # evaluation with an array argument
    samples = np.random.random(5)
    np.testing.assert_allclose(expr(samples), 2 * samples)
    np.testing.assert_allclose(expr.get_compiled()(samples), 2 * samples)

    # the derivative with respect to the only variable is constant
    grad = expr.derivatives
    assert grad.shape == (1,)
    assert grad.constant
    assert grad(3) == [2]
    assert grad.get_compiled()(3) == [2]

    # a function object is rejected as an expression
    with pytest.raises(TypeError):
        ScalarExpression(np.exp)
コード例 #6
0
def test_expression_from_expression():
    """test creating expressions from existing expressions"""
    # start from an expression created without a second argument
    base = ScalarExpression("sin(a)")
    assert base == ScalarExpression(base)
    assert base != ScalarExpression(base, ["a", "b"])
    # passing only "b" for an expression using `a` raises a RuntimeError
    with pytest.raises(RuntimeError):
        ScalarExpression(base, "b")

    # start from an expression created with ["a", "b"] as second argument
    base = ScalarExpression("sin(a)", ["a", "b"])
    assert base == ScalarExpression(base)
    assert base != ScalarExpression(base, "a")
    with pytest.raises(RuntimeError):
        ScalarExpression(base, "b")
コード例 #7
0
def test_complex_expression():
    """test expressions containing complex numbers"""
    # two ways of writing the imaginary unit
    for source in ["sqrt(-1)", "I"]:
        scalar = ScalarExpression(source)
        assert scalar.complex
        assert scalar.constant
        assert scalar.value == pytest.approx(1j)

    # a tensor with one real and one imaginary entry
    tensor = TensorExpression("[1, I]")
    assert tensor.complex
    assert tensor.constant
    expected = np.array([1, 1j])
    np.testing.assert_allclose(tensor.value, expected)
コード例 #8
0
def test_derivatives():
    """test taking derivatives of an expression repeatedly"""
    expr = ScalarExpression("a * b**2")
    assert expr.depends_on("a")
    assert expr.depends_on("b")
    assert not expr.constant
    assert expr.rank == 0

    # first derivatives, checked at a=2, b=3
    first = expr.derivatives
    assert first.shape == (2,)
    first_expected = [9, 12]
    np.testing.assert_allclose(first(2, 3), first_expected)
    np.testing.assert_allclose(first.get_compiled()(2, 3), first_expected)

    # second derivatives
    second = first.derivatives
    assert second.shape == (2, 2)
    second_expected = [[0, 6], [6, 4]]
    np.testing.assert_allclose(second(2, 3), second_expected)
    np.testing.assert_allclose(second.get_compiled()(2, 3), second_expected)

    # third derivatives; only the shape is checked
    third = second.derivatives
    assert third.shape == (2, 2, 2)

    # fourth derivatives are expected to vanish
    fourth = third.derivatives
    assert fourth.shape == (2, 2, 2, 2)
    fourth_expected = np.zeros((2, 2, 2, 2))
    np.testing.assert_allclose(fourth(2, 3), fourth_expected)
    np.testing.assert_allclose(fourth.get_compiled()(2, 3), fourth_expected)
コード例 #9
0
ファイル: test_expressions.py プロジェクト: tefavidal/py-pde
def test_expression_user_funcs():
    """test expressions that call functions supplied via `user_funcs`"""
    # a user function without arguments
    one = ScalarExpression("func()", user_funcs={"func": lambda: 1})
    assert one() == 1
    assert one.get_compiled()() == 1
    assert one.value == 1

    # a user function applied to a constant argument
    scalar = ScalarExpression("f(pi)", user_funcs={"f": np.sin})
    assert scalar.constant
    assert scalar() == pytest.approx(0)
    assert scalar.get_compiled()() == pytest.approx(0)
    assert scalar.value == pytest.approx(0)

    # a user function inside a tensor expression
    tensor = TensorExpression("[0, f(pi)]", user_funcs={"f": np.sin})
    assert tensor.constant
    # an absolute tolerance is used since the expected values are zero
    tol = 1e-14
    np.testing.assert_allclose(tensor(), np.array([0, 0]), atol=tol)
    np.testing.assert_allclose(tensor.get_compiled()(), np.array([0, 0]), atol=tol)
    np.testing.assert_allclose(tensor.value, np.array([0, 0]), atol=tol)
コード例 #10
0
def test_indexed():
    """test an expression that uses indexed variables"""
    expr = ScalarExpression("2 * a[0] ** a[1]", allow_indexed=True)
    assert not expr.constant
    assert expr.depends_on("a")

    # both entries are passed in a single array argument
    values = np.array([4, 2])
    assert expr(values) == 32
    assert expr.get_compiled()(values) == 32

    # derivatives with respect to individual entries
    assert expr.differentiate("a[0]")(values) == 16
    assert expr.differentiate("a[1]")(values) == pytest.approx(32 * np.log(4))

    # differentiating with respect to the whole variable raises a RuntimeError
    with pytest.raises(RuntimeError):
        expr.differentiate("a")
    # accessing all derivatives raises a RuntimeError as well
    with pytest.raises(RuntimeError):
        expr.derivatives
コード例 #11
0
def test_expression_consts():
    """test expressions that receive values via `consts`"""
    # the only symbol is supplied as a constant
    only_const = ScalarExpression("a", consts={"a": 1})
    assert only_const.constant
    assert not only_const.depends_on("a")
    assert only_const() == 1
    assert only_const.get_compiled()() == 1
    assert only_const.value == 1

    # one symbol is a constant while the other one stays a variable
    mixed = ScalarExpression("a + b", consts={"a": 1})
    assert not mixed.constant
    assert not mixed.depends_on("a")
    assert mixed.depends_on("b")
    assert mixed(2) == 3
    assert mixed.get_compiled()(2) == 3

    # the constant is an array
    with_array = ScalarExpression("a + b", consts={"a": np.array([1, 2])})
    assert not with_array.constant
    np.testing.assert_allclose(with_array(np.array([2, 3])), np.array([3, 5]))
    compiled = with_array.get_compiled()
    np.testing.assert_allclose(compiled(np.array([2, 3])), np.array([3, 5]))
コード例 #12
0
def test_synonyms():
    """test an expression written with a synonym of a variable"""
    # the second argument groups the names "a" and "all" together
    expr = ScalarExpression("2 * all", [["a", "all"]])
    # the expression reports a dependence on "a" but not on "all"
    assert expr.depends_on("a")
    assert not expr.depends_on("all")
コード例 #13
0
ファイル: test_expressions.py プロジェクト: tefavidal/py-pde
def test_const(caplog):
    """test constant scalar expressions and a field passed as a constant"""
    # pairs of (source of the expression, value it should have);
    # `None` stands for creating the expression without any argument
    cases = [(None, 0), (1, 1.0), ("1", 1.0), ("a - a", 0)]
    for source, expected in cases:
        if source is None:
            const = ScalarExpression()
        else:
            const = ScalarExpression(source)

        assert const.constant
        assert const.value == expected
        assert const() == expected
        assert const.get_compiled()() == expected
        assert not const.depends_on("a")
        assert const.differentiate("a").value == 0
        assert const.shape == ()
        assert const.rank == 0
        assert bool(const) == (expected != 0)
        assert const.is_zero == (expected == 0)
        assert not const.complex

        # the derivatives of a constant form an empty constant expression
        grad = const.derivatives
        assert grad.constant
        assert isinstance(str(grad), str)
        np.testing.assert_equal(grad.value, [])

        # re-creating the expression in several ways gives equal, distinct objects
        copies = [
            ScalarExpression(const),
            ScalarExpression(const.expression),
            ScalarExpression(const.value),
        ]
        for other in copies:
            assert const is not other
            assert const._sympy_expr == other._sympy_expr

    # test whether unsuitable constants are checked for: a field is used here
    field = ScalarField(UnitGrid([3]))
    expr = ScalarExpression("scalar_field", consts={"scalar_field": field})
    # evaluating returns the field and logs a message mentioning "field"
    with caplog.at_level(logging.WARNING):
        assert expr() == field
    assert "field" in caplog.text
    # calling the compiled function must fail unless JIT is disabled
    if not nb.config.DISABLE_JIT:
        with pytest.raises(Exception):
            expr.get_compiled()()