コード例 #1
0
    def test_call(self):
        """Check that ``__call__`` applies the wrapped discrete callback.

        The callback under test leaves states with ``t > 0`` untouched and
        doubles ``rv`` for states with ``t < 0``.
        """
        # Pairs of (probe time, expected rv after the callback) for rv=3.0.
        cases = ((1.0, 3.0), (-1.0, 6.0))
        for t_probe, expected_rv in cases:
            probe = diffeq.ODESolverState(ivp=None, rv=3.0, t=t_probe)
            result = self.discrete_callbacks(state=probe)
            assert result.rv == expected_rv
コード例 #2
0
ファイル: test_odesolver.py プロジェクト: tskarvone/probnum
 def initialize(self, ivp):
     """Build the starting solver state for ``ivp``.

     The initial value ``ivp.y0`` is wrapped in a ``Constant`` at time
     ``ivp.t0``. No error estimate exists yet, so ``np.nan`` is stored, and
     there is no reference state.
     """
     initial_rv = Constant(ivp.y0)
     return diffeq.ODESolverState(
         ivp=ivp,
         rv=initial_rv,
         t=ivp.t0,
         error_estimate=np.nan,
         reference_state=None,
     )
コード例 #3
0
def _two_perturbed_steps(perturbedsolver, ode, y, start_point, stop_point):
    """Re-initialize ``perturbedsolver`` and take two equal steps starting at ``y``.

    Both steps use ``dt = stop_point - start_point``. Returns the state after
    the second step.
    """
    dt = stop_point - start_point
    perturbedsolver.initialize(ode)
    start_state = diffeq.ODESolverState(
        ivp=ode, rv=y, t=start_point, error_estimate=None, reference_state=None
    )
    # The first step has no error estimate available, so it is deterministic;
    # only the second step can show the solver's perturbation.
    first_step = perturbedsolver.attempt_step(start_state, dt=dt)
    return perturbedsolver.attempt_step(first_step, dt=dt)


def test_step(solvers, start_point, stop_point, y):
    """When performing two small similar steps, their output should be similar.

    For the first step no error estimation is available, the first step is therefore
    deterministic and to check for non-determinism, two steps have to be performed.
    The two-step run is repeated from the same starting point. The results must
    agree up to tolerance but must not be exactly equal.
    """
    _, perturbedsolver, ode = solvers

    perturbed_y_1 = _two_perturbed_steps(perturbedsolver, ode, y, start_point, stop_point)
    perturbed_y_2 = _two_perturbed_steps(perturbedsolver, ode, y, start_point, stop_point)

    np.testing.assert_allclose(
        perturbed_y_1.rv.mean, perturbed_y_2.rv.mean, atol=1e-4, rtol=1e-4
    )

    np.testing.assert_allclose(
        perturbed_y_1.error_estimate,
        perturbed_y_2.error_estimate,
        atol=1e-4,
        rtol=1e-4,
    )
    # Exact inequality in every component shows the perturbation actually
    # introduced randomness between the two runs.
    assert np.all(np.not_equal(perturbed_y_1.rv.mean, perturbed_y_2.rv.mean))
コード例 #4
0
ファイル: test_odesolver.py プロジェクト: tskarvone/probnum
    def attempt_step(self, state, dt):
        """Take one explicit Euler step of size ``dt`` from ``state``.

        Returns a new state at time ``state.t + dt``. Its ``rv`` is a
        ``Constant`` holding ``x + dt * f(t, x)``, and the same updated value is
        also stored as ``reference_state``.
        """
        t_now = state.t
        x_now = state.rv.mean
        slope = state.ivp.f(t_now, x_now)
        x_next = x_now + dt * slope

        # NaN is deliberately used as the error estimate so that it is not
        # used: any arithmetic on it would propagate NaN.
        return diffeq.ODESolverState(
            ivp=state.ivp,
            rv=Constant(x_next),
            t=t_now + dt,
            error_estimate=np.nan,
            reference_state=x_next,
        )
コード例 #5
0
def test_step_variables(solvers, y, start_point, stop_point):
    """One wrapped-solver step must match scipy's ``rk_step`` bookkeeping.

    The test compares the error estimate, the stored time points and step
    sizes, and the stored values/derivatives of the underlying scipy solver
    against a reference step computed directly with ``rk.rk_step``.
    """
    testsolver, scipysolver, ode = solvers
    step = stop_point - start_point
    tight = dict(atol=1e-13, rtol=1e-13)

    initial = diffeq.ODESolverState(
        ivp=ode,
        rv=randvars.Constant(y),
        t=start_point,
        error_estimate=None,
        reference_state=None,
    )
    testsolver.initialize(ode)
    stepped = testsolver.attempt_step(initial, dt=step)
    y_ref, f_ref = rk.rk_step(
        scipysolver.fun,
        start_point,
        y,
        scipysolver.f,
        step,
        scipysolver.A,
        scipysolver.B,
        scipysolver.C,
        scipysolver.K,
    )

    # The reference error estimate is taken from scipysolver.K after the
    # reference step above; keep this ordering (rk_step presumably fills K
    # in place).
    reference_error = scipysolver._estimate_error(scipysolver.K, step)
    np.testing.assert_allclose(stepped.error_estimate, reference_error, **tight)

    inner = testsolver.solver

    # Time points and step sizes recorded by the wrapped scipy solver.
    np.testing.assert_allclose(inner.t_old, start_point, **tight)
    np.testing.assert_allclose(inner.t, stop_point, **tight)
    np.testing.assert_allclose(inner.h_previous, step, **tight)

    # Stored values and derivative evaluations.
    np.testing.assert_allclose(inner.y_old, y, **tight)
    np.testing.assert_allclose(inner.y, y_ref, **tight)
    np.testing.assert_allclose(inner.h_abs, step, **tight)
    np.testing.assert_allclose(inner.f, f_ref, **tight)
コード例 #6
0
def test_step_execution(solvers):
    """The wrapped solver reproduces one step taken by scipy.

    scipy steps first. The wrapped solver then starts from scipy's previous
    point and uses the same step size, and its result must match scipy's new
    value.
    """
    testsolver, scipysolver, ode = solvers
    scipysolver.step()

    start_state = diffeq.ODESolverState(
        ivp=ode,
        rv=randvars.Constant(scipysolver.y_old),
        t=scipysolver.t_old,
        error_estimate=None,
        reference_state=None,
    )
    testsolver.initialize(ode)
    step_size = scipysolver.t - scipysolver.t_old
    result = testsolver.attempt_step(start_state, dt=step_size)
    np.testing.assert_allclose(scipysolver.y, result.rv.mean)
コード例 #7
0
def test_dense_output(solvers):
    """Dense output after one step must agree with scipy's own interpolant.

    Both solvers take the same step. The two dense outputs are then compared
    at the left endpoint, the right endpoint and the midpoint of that step.
    """
    testsolver, scipysolver, ode = solvers

    # Mirror scipy's step with the wrapped solver.
    testsolver.initialize(ode)
    scipysolver.step()
    start_state = diffeq.ODESolverState(
        ivp=ode,
        rv=randvars.Constant(scipysolver.y_old),
        t=scipysolver.t_old,
        error_estimate=None,
        reference_state=None,
    )
    stepped = testsolver.attempt_step(
        state=start_state, dt=scipysolver.t - scipysolver.t_old
    )

    # Precondition: the steps themselves agree. This is also tested
    # elsewhere, but without it the dense-output comparison would be meaningless.
    np.testing.assert_allclose(scipysolver.y, stepped.rv.mean)

    ours = testsolver.dense_output()
    reference = scipysolver._dense_output_impl()

    left, right = scipysolver.t_old, scipysolver.t
    midpoint = (left + right) / 2.0

    for t_eval in (left, right, midpoint):
        np.testing.assert_allclose(
            ours(t_eval),
            reference(t_eval),
            atol=1e-13,
            rtol=1e-13,
        )