def test_vanderpol():
    """Check that ``vanderpol_jax`` raises ``ImportError`` while jax is missing.

    The expectation can only hold when jax is not importable, so the test
    skips itself instead of failing when jax is installed.
    """
    # Looked up at call time: JAX_AVAILABLE is assigned further down in this
    # module, after this function has been defined.
    if JAX_AVAILABLE:
        pytest.skip(
            "Imports will be successful, thus catching the ImportError will fail"
        )
    with pytest.raises(ImportError):
        diffeq_zoo.vanderpol_jax()
import probnum.diffeq as pnd
import probnum.problems as pnprob
import probnum.problems.zoo.diffeq as diffeq_zoo

# Jax dependency handling: probe for jax once at import time and record the
# outcome in JAX_AVAILABLE; IVPs holds the jax-based problems (empty without jax).
# pylint: disable=unused-import
try:
    import jax
    import jax.numpy as jnp

    # Take the config object from the top-level package. Newer jax releases no
    # longer ship a ``jax.config`` submodule, so ``from jax.config import config``
    # raises ImportError there and jax would wrongly be reported as unavailable.
    from jax import config

    config.update("jax_enable_x64", True)

    JAX_AVAILABLE = True

    # Built inside the ``try`` so that an ImportError raised while constructing
    # the problems is handled by the fallback below as well.
    IVPs = [diffeq_zoo.threebody_jax(), diffeq_zoo.vanderpol_jax()]

except ImportError:
    JAX_AVAILABLE = False
    IVPs = []


# Skip markers that route each test to the environment it was written for.

# For tests that assert an ImportError: skipped whenever jax can be imported.
only_if_jax_is_not_available = pytest.mark.skipif(
    JAX_AVAILABLE,
    reason="Imports will be successful, thus catching the ImportError will fail",
)

# For tests that exercise the jax-based problems: skipped when jax is missing.
only_if_jax_available = pytest.mark.skipif(
    not JAX_AVAILABLE,
    reason="requires jax",
)