def test_two_args():
    """Check a scalar expression of two variables, ``2 * a ** b``.

    Covers interpreted and compiled evaluation, partial derivatives,
    shape/rank metadata, the equality round-trip through the expression
    string, array arguments, the packed single-argument call forms and
    the gradient expression.
    """
    expr = ScalarExpression("2 * a ** b")
    assert expr.depends_on("b")
    assert not expr.constant

    # scalar evaluation: 2 * 4**2 == 32
    assert expr(4, 2) == 32
    assert expr.get_compiled()(4, 2) == 32

    # d/da = 2*b*a**(b-1) -> 16 and d/db = 2*a**b*ln(a) -> 32*ln(4) at (4, 2)
    assert expr.differentiate("a")(4, 2) == 16
    assert expr.differentiate("b")(4, 2) == pytest.approx(32 * np.log(4))
    # differentiating w.r.t. an absent variable yields a constant zero
    assert expr.differentiate("c").value == 0

    assert expr.shape == ()
    assert expr.rank == 0
    assert expr == ScalarExpression(expr.expression)

    samples = [np.random.random(2), np.random.random((2, 5))]
    for arr in samples:
        expected = 2 * arr[0] ** arr[1]
        np.testing.assert_allclose(expr(*arr), expected)
        np.testing.assert_allclose(expr.get_compiled()(*arr), expected)
        if arr.ndim != 1:
            continue
        # a 1d array may also be handed over as one packed argument
        packed = expr._get_function(single_arg=True)
        np.testing.assert_allclose(packed(arr), expected)
        packed = expr.get_compiled(single_arg=True)
        np.testing.assert_allclose(packed(arr), expected)

    grad = expr.derivatives
    assert grad.shape == (2,)
    assert grad.rank == 1
    assert not grad.constant
    # at (a, b) = (2, 3): [2*3*2**2, 2*2**3*ln(2)]
    expected_grad = [24, 16 * np.log(2)]
    np.testing.assert_allclose(grad(2, 3), expected_grad)
    np.testing.assert_allclose(grad.get_compiled()(2, 3), expected_grad)
def test_single_arg():
    """Check a scalar expression of a single variable, ``2 * a``."""
    expr = ScalarExpression("2 * a")
    assert not expr.constant
    assert expr.depends_on("a")
    assert expr(4) == 8
    assert expr.get_compiled()(4) == 8
    assert expr.differentiate("a").value == 2
    assert expr.differentiate("b").value == 0
    assert expr.shape == ()
    assert expr.rank == 0
    assert bool(expr)
    assert not expr.is_zero
    assert expr == ScalarExpression(expr.expression)

    # reading `value` of this non-constant expression must raise TypeError
    with pytest.raises(TypeError):
        _ = expr.value

    # array arguments are evaluated element-wise
    arr = np.random.random(5)
    doubled = 2 * arr
    np.testing.assert_allclose(expr(arr), doubled)
    np.testing.assert_allclose(expr.get_compiled()(arr), doubled)

    grad = expr.derivatives
    assert grad.shape == (1,)
    assert grad.constant
    assert grad(3) == [2]
    assert grad.get_compiled()(3) == [2]

    # passing a numpy ufunc instead of an expression must raise TypeError
    with pytest.raises(TypeError):
        ScalarExpression(np.exp)
def test_const():
    """Check constant scalar expressions built from several kinds of input.

    NOTE(review): this module defines ``test_const`` a second time further
    down, and that later definition rebinds the name, so pytest never
    collects this copy. The later version repeats all of these checks, so
    this one could be removed.
    """
    for source in (None, 1, "1", "a - a"):
        if source is None:
            expr = ScalarExpression()
        else:
            expr = ScalarExpression(source)
        # "a - a" simplifies to zero; the default expression is zero as well
        expected = 0 if source in (None, "a - a") else float(source)

        assert expr.constant
        assert expr.value == expected
        assert expr() == expected
        assert expr.get_compiled()() == expected
        assert not expr.depends_on("a")
        assert expr.differentiate("a").value == 0
        assert expr.shape == ()
        assert expr.rank == 0
        assert bool(expr) == (expected != 0)
        assert expr.is_zero == (expected == 0)
        assert not expr.complex

        grad = expr.derivatives
        assert grad.constant
        assert isinstance(str(grad), str)
        np.testing.assert_equal(grad.value, [])

        # rebuilding from the expression, its string, or its value gives an
        # equivalent but distinct object
        copies = (
            ScalarExpression(expr),
            ScalarExpression(expr.expression),
            ScalarExpression(expr.value),
        )
        for other in copies:
            assert expr is not other
            assert expr._sympy_expr == other._sympy_expr
def test_indexed():
    """Check an expression with indexed variables, ``2 * a[0] ** a[1]``."""
    expr = ScalarExpression("2 * a[0] ** a[1]", allow_indexed=True)
    assert not expr.constant
    assert expr.depends_on("a")

    vec = np.array([4, 2])
    assert expr(vec) == 32
    assert expr.get_compiled()(vec) == 32
    # derivatives with respect to individual elements of `a`
    assert expr.differentiate("a[0]")(vec) == 16
    assert expr.differentiate("a[1]")(vec) == pytest.approx(32 * np.log(4))

    # differentiating w.r.t. the whole indexed variable must raise RuntimeError
    with pytest.raises(RuntimeError):
        expr.differentiate("a")
    with pytest.raises(RuntimeError):
        _ = expr.derivatives
def test_const(caplog):
    """Check constant scalar expressions and the handling of field constants."""
    # constant expressions built from several kinds of input
    for source in (None, 1, "1", "a - a"):
        if source is None:
            expr = ScalarExpression()
        else:
            expr = ScalarExpression(source)
        # "a - a" simplifies to zero; the default expression is zero as well
        expected = 0 if source in (None, "a - a") else float(source)

        assert expr.constant
        assert expr.value == expected
        assert expr() == expected
        assert expr.get_compiled()() == expected
        assert not expr.depends_on("a")
        assert expr.differentiate("a").value == 0
        assert expr.shape == ()
        assert expr.rank == 0
        assert bool(expr) == (expected != 0)
        assert expr.is_zero == (expected == 0)
        assert not expr.complex

        grad = expr.derivatives
        assert grad.constant
        assert isinstance(str(grad), str)
        np.testing.assert_equal(grad.value, [])

        # rebuilding from the expression, its string, or its value gives an
        # equivalent but distinct object
        copies = (
            ScalarExpression(expr),
            ScalarExpression(expr.expression),
            ScalarExpression(expr.value),
        )
        for other in copies:
            assert expr is not other
            assert expr._sympy_expr == other._sympy_expr

    # a field used as a constant: plain evaluation returns it but logs a
    # warning mentioning "field"; the compiled version must fail when JIT is on
    field_const = ScalarField(UnitGrid([3]))
    expr = ScalarExpression("scalar_field", consts={"scalar_field": field_const})
    with caplog.at_level(logging.WARNING):
        assert expr() == field_const
    assert "field" in caplog.text
    if not nb.config.DISABLE_JIT:
        with pytest.raises(Exception):
            expr.get_compiled()()