Esempio n. 1
0
def test_two_args():
    """Check the scalar expression `2 * a ** b`, which has two variables.

    Covers direct and compiled evaluation, differentiation with respect to
    present and absent variables, shape metadata, array inputs, and the
    `derivatives` property.
    """
    expr = ScalarExpression("2 * a ** b")

    # basic properties of the expression
    assert expr.depends_on("b")
    assert not expr.constant

    # scalar evaluation: 2 * 4 ** 2 == 32
    assert expr(4, 2) == 32
    compiled = expr.get_compiled()
    assert compiled(4, 2) == 32

    # d/da = 2 * b * a ** (b - 1) and d/db = 2 * a ** b * log(a)
    assert expr.differentiate("a")(4, 2) == 16
    assert expr.differentiate("b")(4, 2) == pytest.approx(32 * np.log(4))
    # "c" does not appear in the expression
    assert expr.differentiate("c").value == 0

    assert expr.shape == tuple()
    assert expr.rank == 0
    assert expr == ScalarExpression(expr.expression)

    # array-valued arguments; the first axis separates `a` from `b`
    samples = [np.random.random(2), np.random.random((2, 5))]
    for sample in samples:
        expected = 2 * sample[0] ** sample[1]
        np.testing.assert_allclose(expr(*sample), expected)
        np.testing.assert_allclose(expr.get_compiled()(*sample), expected)
        if sample.ndim != 1:
            continue
        # a flat array may alternatively be handed over as one single argument
        for factory in (expr._get_function, expr.get_compiled):
            func = factory(single_arg=True)
            np.testing.assert_allclose(func(sample), expected)

    # the collection of all derivatives is an expression of rank 1
    grad = expr.derivatives
    assert grad.shape == (2,)
    assert grad.rank == 1
    assert not grad.constant
    expected_grad = [24, 16 * np.log(2)]
    np.testing.assert_allclose(grad(2, 3), expected_grad)
    np.testing.assert_allclose(grad.get_compiled()(2, 3), expected_grad)
Esempio n. 2
0
def test_single_arg():
    """Check the scalar expression `2 * a`, which has a single variable.

    Covers evaluation, differentiation, metadata, array inputs, the
    `derivatives` property, and rejection of a non-expression input.
    """
    expr = ScalarExpression("2 * a")
    assert not expr.constant
    assert expr.depends_on("a")

    # direct and compiled evaluation
    assert expr(4) == 8
    assert expr.get_compiled()(4) == 8

    # derivative with respect to the present variable and an absent one
    for variable, slope in (("a", 2), ("b", 0)):
        assert expr.differentiate(variable).value == slope

    assert expr.shape == tuple()
    assert expr.rank == 0
    assert bool(expr)
    assert not expr.is_zero

    assert expr == ScalarExpression(expr.expression)
    # the expression is not constant; reading `.value` is expected to raise
    with pytest.raises(TypeError):
        expr.value

    # array-valued argument
    values = np.random.random(5)
    doubled = 2 * values
    np.testing.assert_allclose(expr(values), doubled)
    np.testing.assert_allclose(expr.get_compiled()(values), doubled)

    # the derivatives of a linear expression are constant
    grad = expr.derivatives
    assert grad.shape == (1,)
    assert grad.constant
    assert grad(3) == [2]
    assert grad.get_compiled()(3) == [2]

    # a plain callable is not accepted as an expression
    with pytest.raises(TypeError):
        ScalarExpression(np.exp)
Esempio n. 3
0
def test_const():
    """Check scalar expressions that evaluate to a constant value."""
    # pairs of (input given to the constructor, expected constant value)
    cases = [(None, 0), (1, float(1)), ("1", float("1")), ("a - a", 0)]

    for raw, expected in cases:
        if raw is None:
            # no argument at all is also a valid way to build an expression
            const = ScalarExpression()
        else:
            const = ScalarExpression(raw)

        assert const.constant
        assert const.value == expected
        assert const() == expected
        assert const.get_compiled()() == expected
        assert not const.depends_on("a")
        assert const.differentiate("a").value == 0
        assert const.shape == tuple()
        assert const.rank == 0
        assert bool(const) == (expected != 0)
        assert const.is_zero == (expected == 0)
        assert not const.complex

        # a constant has no variables, so its derivatives have no entries
        grad = const.derivatives
        assert grad.constant
        assert isinstance(str(grad), str)
        np.testing.assert_equal(grad.value, [])

        # rebuilding from the object, its string, or its value gives new,
        # equivalent expressions
        rebuilt = [
            ScalarExpression(const),
            ScalarExpression(const.expression),
            ScalarExpression(const.value),
        ]
        for other in rebuilt:
            assert const is not other
            assert const._sympy_expr == other._sympy_expr
Esempio n. 4
0
def test_expression_user_funcs():
    """Check that expressions can use functions supplied via `user_funcs`."""

    def assert_zero_vector(actual):
        """Assert that `actual` is close to the vector (0, 0)."""
        np.testing.assert_allclose(actual, np.array([0, 0]), atol=1e-14)

    # a user function without any arguments
    one = ScalarExpression("func()", user_funcs={"func": lambda: 1})
    assert one() == 1
    assert one.get_compiled()() == 1
    assert one.value == 1

    # a user function applied to a symbolic constant; sin(pi) is zero up to
    # floating point error
    scalar = ScalarExpression("f(pi)", user_funcs={"f": np.sin})
    assert scalar.constant
    assert scalar() == pytest.approx(0)
    assert scalar.get_compiled()() == pytest.approx(0)
    assert scalar.value == pytest.approx(0)

    # the same user function inside a tensor expression
    tensor = TensorExpression("[0, f(pi)]", user_funcs={"f": np.sin})
    assert tensor.constant
    assert_zero_vector(tensor())
    assert_zero_vector(tensor.get_compiled()())
    assert_zero_vector(tensor.value)
Esempio n. 5
0
def test_expression_consts():
    """Check that symbols can be bound to fixed values via `consts`."""
    # binding the only symbol turns the expression into a constant
    bound = ScalarExpression("a", consts={"a": 1})
    assert bound.constant
    assert not bound.depends_on("a")
    assert bound() == 1
    assert bound.get_compiled()() == 1
    assert bound.value == 1

    # binding one of two symbols leaves the other one as a variable
    partial = ScalarExpression("a + b", consts={"a": 1})
    assert not partial.constant
    assert not partial.depends_on("a")
    assert partial.depends_on("b")
    assert partial(2) == 3
    assert partial.get_compiled()(2) == 3

    # constants may also be arrays
    with_array = ScalarExpression("a + b", consts={"a": np.array([1, 2])})
    assert not with_array.constant
    direct = with_array(np.array([2, 3]))
    np.testing.assert_allclose(direct, np.array([3, 5]))
    jitted = with_array.get_compiled()(np.array([2, 3]))
    np.testing.assert_allclose(jitted, np.array([3, 5]))
Esempio n. 6
0
def test_expression_special():
    """Check the Heaviside step function in direct and compiled form."""
    step = ScalarExpression("Heaviside(x)")
    assert not step.constant

    # (argument, expected value); the value at the origin is expected to be 0.5
    table = ((-1, 0), (0, 0.5), (1, 1))

    for arg, expected in table:
        assert step(arg) == expected

    compiled = step.get_compiled()
    for arg, expected in table:
        assert compiled(arg) == expected
Esempio n. 7
0
def test_const(caplog):
    """Check constant scalar expressions and the handling of a field constant.

    Args:
        caplog: pytest fixture used to inspect the emitted log messages.
    """
    # scalar expressions that evaluate to a constant value
    for raw in [None, 1, "1", "a - a"]:
        if raw is None:
            const = ScalarExpression()
        else:
            const = ScalarExpression(raw)

        if raw is None or raw == "a - a":
            expected = 0
        else:
            expected = float(raw)

        assert const.constant
        assert const.value == expected
        assert const() == expected
        assert const.get_compiled()() == expected
        assert not const.depends_on("a")
        assert const.differentiate("a").value == 0
        assert const.shape == tuple()
        assert const.rank == 0
        assert bool(const) == (expected != 0)
        assert const.is_zero == (expected == 0)
        assert not const.complex

        # a constant has no variables, so its derivatives have no entries
        grad = const.derivatives
        assert grad.constant
        assert isinstance(str(grad), str)
        np.testing.assert_equal(grad.value, [])

        # rebuilding from the object, its string, or its value gives new,
        # equivalent expressions
        rebuilt = [
            ScalarExpression(const),
            ScalarExpression(const.expression),
            ScalarExpression(const.value),
        ]
        for other in rebuilt:
            assert const is not other
            assert const._sympy_expr == other._sympy_expr

    # check that an unsuitable constant (a field object) is detected
    field = ScalarField(UnitGrid([3]))
    with_field = ScalarExpression("scalar_field", consts={"scalar_field": field})
    with caplog.at_level(logging.WARNING):
        assert with_field() == field
    # a log message mentioning "field" must have been captured
    assert "field" in caplog.text

    if nb.config.DISABLE_JIT:
        # the compilation failure is only checked when numba's JIT is enabled
        return
    with pytest.raises(Exception):
        with_field.get_compiled()()
Esempio n. 8
0
def test_indexed():
    """Check a scalar expression whose variable is accessed by index."""
    expr = ScalarExpression("2 * a[0] ** a[1]", allow_indexed=True)
    assert not expr.constant
    assert expr.depends_on("a")

    # evaluation takes the whole array as one argument: 2 * 4 ** 2 == 32
    arg = np.array([4, 2])
    assert expr(arg) == 32
    assert expr.get_compiled()(arg) == 32

    # derivatives with respect to individual components
    d_first = expr.differentiate("a[0]")
    assert d_first(arg) == 16
    d_second = expr.differentiate("a[1]")
    assert d_second(arg) == pytest.approx(32 * np.log(4))

    # differentiating with respect to the whole array is expected to fail
    with pytest.raises(RuntimeError):
        expr.differentiate("a")
    with pytest.raises(RuntimeError):
        expr.derivatives