Exemplo n.º 1
0
def test_initialize_with_taylormode(any_order):
    """Compare the Taylor-mode initialization with tabulated three-body values.

    The mean of the random variable returned by
    ``pnde.initialize_odefilter_with_taylormode`` is checked against
    ``THREEBODY_INITS`` (passed through ``_convert_derivwise_to_coordwise``),
    and its standard deviation is checked against zero.
    """
    threebody = diffeq_zoo.threebody_jax()
    spatial_dimension = 4

    # Take the first ``spatial_dimension * (any_order + 1)`` tabulated entries
    # and convert them to the ordering used for the comparison below.
    num_entries = spatial_dimension * (any_order + 1)
    reference_mean = pnss.Integrator._convert_derivwise_to_coordwise(
        THREEBODY_INITS[:num_entries],
        ordint=any_order,
        spatialdim=spatial_dimension,
    )

    ibm_prior = pnss.IBM(
        ordint=any_order,
        spatialdim=spatial_dimension,
        forward_implementation="sqrt",
        backward_implementation="sqrt",
    )

    # Zero-mean, identity-covariance random variable handed to the routine.
    state_dim = ibm_prior.dimension
    standard_normal_rv = pnrv.Normal(np.zeros(state_dim), np.eye(state_dim))

    initial_rv = pnde.initialize_odefilter_with_taylormode(
        threebody.f,
        threebody.y0,
        threebody.t0,
        prior=ibm_prior,
        initrv=standard_normal_rv,
    )

    np.testing.assert_allclose(initial_rv.mean, reference_mean)
    np.testing.assert_allclose(initial_rv.std, 0.0)
Exemplo n.º 2
0
    def test_call(self, any_order):
        """Check ``self.taylor_init`` against tabulated three-body values.

        The result must be a ``randvars.Normal`` whose mean matches
        ``THREEBODY_INITS`` (passed through ``convert_derivwise_to_coordwise``)
        and whose standard deviation is zero.
        """
        ivp = diffeq_zoo.threebody_jax()
        dim = ivp.dimension

        # Take the first ``dim * (any_order + 1)`` tabulated entries and
        # convert them to the ordering used for the comparison below.
        num_entries = dim * (any_order + 1)
        tabulated = _known_initial_derivatives.THREEBODY_INITS[:num_entries]
        reference_mean = (
            randprocs.markov.integrator.convert.convert_derivwise_to_coordwise(
                tabulated,
                num_derivatives=any_order,
                wiener_process_dimension=dim,
            )
        )

        prior = self._construct_prior_process(
            order=any_order, spatialdim=dim, t0=ivp.t0
        )

        initial_rv = self.taylor_init(ivp=ivp, prior_process=prior)

        assert isinstance(initial_rv, randvars.Normal)
        np.testing.assert_allclose(initial_rv.mean, reference_mean)
        np.testing.assert_allclose(initial_rv.std, 0.0)
Exemplo n.º 3
0
def problem_threebody():
    """Return the jax three-body IVP and the full ``THREEBODY_INITS`` table.

    Returns
    -------
    tuple
        ``(ivp, inits)`` where ``ivp`` comes from ``diffeq_zoo.threebody_jax()``
        and ``inits`` is ``known_initial_derivatives.THREEBODY_INITS``,
        unsliced.
    """
    return diffeq_zoo.threebody_jax(), known_initial_derivatives.THREEBODY_INITS
def test_threebody():
    """Assert that ``diffeq_zoo.threebody_jax()`` raises ``ImportError``.

    NOTE(review): presumably only meaningful in an environment without jax;
    confirm that a matching skip marker is applied where this test lives.
    """
    expect_import_error = pytest.raises(ImportError)
    with expect_import_error:
        diffeq_zoo.threebody_jax()
import probnum.diffeq as pnd
import probnum.problems as pnprob
import probnum.problems.zoo.diffeq as diffeq_zoo

# Jax dependency handling: detect whether jax can be imported and, if so,
# build the list of jax-based IVPs used by the tests.
# pylint: disable=unused-import
try:
    import jax
    import jax.numpy as jnp

    # Import the config object from the top-level package. The ``jax.config``
    # submodule was removed in newer jax releases, so
    # ``from jax.config import config`` raises ImportError there and would
    # wrongly mark an installed jax as unavailable (silently skipping tests).
    from jax import config

    # Enable 64-bit types in jax.
    config.update("jax_enable_x64", True)

    JAX_AVAILABLE = True

    # Kept inside the ``try`` so that an ImportError raised while constructing
    # the jax-based problems is handled like a missing jax installation.
    IVPs = [diffeq_zoo.threebody_jax(), diffeq_zoo.vanderpol_jax()]

except ImportError:
    JAX_AVAILABLE = False
    IVPs = []


# Pytest decorators to select tests for each case
# Skips the decorated test when jax could not be imported.
only_if_jax_available = pytest.mark.skipif(not JAX_AVAILABLE, reason="requires jax")
# Skips the decorated test when jax IS importable (for tests that expect the
# ImportError raised without jax).
only_if_jax_is_not_available = pytest.mark.skipif(
    JAX_AVAILABLE,
    reason="Imports will be successful, thus catching the ImportError will fail",
)