예제 #1
0
def test_two_args():
    """Check a scalar expression of two variables, ``2 * a ** b``.

    Covers point evaluation (interpreted and compiled), differentiation with
    respect to each variable and to an absent one, shape/rank metadata,
    equality after a round trip through ``expression``, evaluation on array
    arguments, the single-argument calling convention, and the gradient
    returned by ``derivatives``.
    """
    expr = ScalarExpression("2 * a ** b")
    assert expr.depends_on("b")
    assert not expr.constant

    # 2 * 4**2 == 32, both interpreted and compiled
    for evaluate in (expr, expr.get_compiled()):
        assert evaluate(4, 2) == 32

    # d/da = 2*b*a**(b-1) -> 16 at (4, 2); d/db = 2*a**b*ln(a) -> 32*ln(4)
    assert expr.differentiate("a")(4, 2) == 16
    assert expr.differentiate("b")(4, 2) == pytest.approx(32 * np.log(4))
    # a variable that does not occur has a zero derivative
    assert expr.differentiate("c").value == 0

    assert expr.shape == ()
    assert expr.rank == 0
    assert expr == ScalarExpression(expr.expression)

    # array arguments: the leading axis carries the values of ``a`` and ``b``
    samples = (np.random.random(2), np.random.random((2, 5)))
    for vals in samples:
        expected = 2 * vals[0] ** vals[1]
        np.testing.assert_allclose(expr(*vals), expected)
        np.testing.assert_allclose(expr.get_compiled()(*vals), expected)
        if vals.ndim != 1:
            continue
        # single_arg=True: all variables are passed packed in one array
        for packed in (
            expr._get_function(single_arg=True),
            expr.get_compiled(single_arg=True),
        ):
            np.testing.assert_allclose(packed(vals), expected)

    grad = expr.derivatives
    assert grad.shape == (2,)
    assert grad.rank == 1
    assert not grad.constant
    # at (a, b) = (2, 3): [2*3*2**2, 2*2**3*ln(2)] == [24, 16*ln(2)]
    expected_grad = [24, 16 * np.log(2)]
    np.testing.assert_allclose(grad(2, 3), expected_grad)
    np.testing.assert_allclose(grad.get_compiled()(2, 3), expected_grad)
예제 #2
0
def test_single_arg():
    """Check a scalar expression of one variable, ``2 * a``.

    Covers evaluation (interpreted and compiled, scalar and array input),
    differentiation, shape/rank metadata, truthiness, equality after a round
    trip through ``expression``, the gradient returned by ``derivatives``, and
    rejection of a numpy function as expression source.
    """
    e = ScalarExpression("2 * a")
    assert not e.constant
    assert e.depends_on("a")
    assert e(4) == 8
    assert e.get_compiled()(4) == 8
    assert e.differentiate("a").value == 2
    # a variable that does not occur has a zero derivative
    assert e.differentiate("b").value == 0
    assert e.shape == tuple()
    assert e.rank == 0
    assert bool(e)
    assert not e.is_zero

    assert e == ScalarExpression(e.expression)
    # accessing ``value`` on this non-constant expression raises TypeError
    with pytest.raises(TypeError):
        e.value

    # evaluation also works element-wise on array input
    arr = np.random.random(5)
    np.testing.assert_allclose(e(arr), 2 * arr)
    np.testing.assert_allclose(e.get_compiled()(arr), 2 * arr)

    g = e.derivatives
    assert g.shape == (1, )
    assert g.constant
    # Compare array results explicitly: ``assert array == [2]`` only passes
    # because a one-element boolean array happens to be truthy.
    np.testing.assert_allclose(g(3), [2])
    np.testing.assert_allclose(g.get_compiled()(3), [2])

    # a numpy function such as ``np.exp`` is rejected as expression source
    with pytest.raises(TypeError):
        ScalarExpression(np.exp)
예제 #3
0
def test_const():
    """Check scalar expressions that evaluate to a constant.

    Each case pairs an expression source with the value it should evaluate
    to. The sources are: no argument (``None``), an int, a numeric string,
    and an expression that cancels symbolically.
    """
    cases = [(None, 0), (1, 1.0), ("1", 1.0), ("a - a", 0)]
    for source, expected in cases:
        if source is None:
            expr = ScalarExpression()
        else:
            expr = ScalarExpression(source)

        assert expr.constant
        assert expr.value == expected
        assert expr() == expected
        assert expr.get_compiled()() == expected
        assert not expr.depends_on("a")
        assert expr.differentiate("a").value == 0
        assert expr.shape == ()
        assert expr.rank == 0
        assert bool(expr) == (expected != 0)
        assert expr.is_zero == (expected == 0)
        assert not expr.complex

        # a constant expression has an empty gradient
        grad = expr.derivatives
        assert grad.constant
        assert isinstance(str(grad), str)
        np.testing.assert_equal(grad.value, [])

        # rebuilding from the expression, its string form, or its value
        # must yield a distinct object with the same sympy expression
        copies = (
            ScalarExpression(expr),
            ScalarExpression(expr.expression),
            ScalarExpression(expr.value),
        )
        for other in copies:
            assert expr is not other
            assert expr._sympy_expr == other._sympy_expr
예제 #4
0
def test_indexed():
    """Check an expression that indexes into an array variable ``a``.

    With ``allow_indexed=True`` the expression ``2 * a[0] ** a[1]`` is
    evaluated on a single array argument. Derivatives can be taken with
    respect to the individual components ``a[0]`` and ``a[1]``, but requesting
    them for the whole variable ``a`` (or the full gradient) raises
    RuntimeError.
    """
    expr = ScalarExpression("2 * a[0] ** a[1]", allow_indexed=True)
    assert not expr.constant
    assert expr.depends_on("a")

    point = np.array([4, 2])
    # 2 * 4**2 == 32, both interpreted and compiled
    for evaluate in (expr, expr.get_compiled()):
        assert evaluate(point) == 32

    # component-wise partial derivatives at a = [4, 2]
    assert expr.differentiate("a[0]")(point) == 16
    assert expr.differentiate("a[1]")(point) == pytest.approx(32 * np.log(4))

    # the indexed variable as a whole cannot be differentiated
    with pytest.raises(RuntimeError):
        expr.differentiate("a")
    with pytest.raises(RuntimeError):
        _ = expr.derivatives
예제 #5
0
def test_const(caplog):
    """Check constant scalar expressions and the handling of a field constant.

    The first part checks expressions that evaluate to a constant: no
    argument, a number, a numeric string, and an expression that cancels
    symbolically. The second part passes a ``ScalarField`` through
    ``consts``. It checks that plain evaluation returns the field while
    logging a message that mentions "field". When numba's JIT is enabled, it
    also checks that calling the compiled function raises.

    Args:
        caplog: pytest fixture that captures log output.
    """
    # test scalar expressions with constants
    for expr in [None, 1, "1", "a - a"]:
        e = ScalarExpression() if expr is None else ScalarExpression(expr)
        val = 0 if expr is None or expr == "a - a" else float(expr)
        assert e.constant
        assert e.value == val
        assert e() == val
        assert e.get_compiled()() == val
        assert not e.depends_on("a")
        assert e.differentiate("a").value == 0
        assert e.shape == tuple()
        assert e.rank == 0
        assert bool(e) == (val != 0)
        assert e.is_zero == (val == 0)
        assert not e.complex

        # a constant expression has an empty gradient
        g = e.derivatives
        assert g.constant
        assert isinstance(str(g), str)
        np.testing.assert_equal(g.value, [])

        # rebuilding from the expression, its string form, or its value must
        # yield a distinct object with the same sympy expression
        for f in [
            ScalarExpression(e),
            ScalarExpression(e.expression),
            ScalarExpression(e.value),
        ]:
            assert e is not f
            assert e._sympy_expr == f._sympy_expr

    # test whether wrong constants are checked for: here a field is passed
    # as a constant
    field = ScalarField(UnitGrid([3]))
    e = ScalarExpression("scalar_field", consts={"scalar_field": field})
    # plain evaluation returns the field, but a message mentioning "field"
    # must be logged (captured at WARNING level)
    with caplog.at_level(logging.WARNING):
        assert e() == field
    assert "field" in caplog.text
    # the compiled path is only checked when numba's JIT is active; the
    # concrete exception type is deliberately not pinned here
    if not nb.config.DISABLE_JIT:
        with pytest.raises(Exception):
            e.get_compiled()()