コード例 #1
0
ファイル: rk_solver_test.py プロジェクト: siavashadpey/EpiMod
    def test_interface(self):
        """Check that RKSolver's public properties read back what was assigned.

        Covers the time settings (initial_time, final_time, n_steps), the
        attached equation object and the output frequency.
        """
        solver = RKSolver(0, 1)

        # Time-integration settings: assign all three, then read each back.
        solver.initial_time = 0.5
        solver.final_time = 5
        solver.n_steps = 10
        expected_settings = (
            ("initial_time", 0.5),
            ("final_time", 5),
            ("n_steps", 10),
        )
        for attr_name, expected in expected_settings:
            self.assertEqual(getattr(solver, attr_name), expected)

        # The equation read back from the solver compares equal to the one set.
        seir_equation = Seir()
        solver.equation = seir_equation
        self.assertEqual(solver.equation, seir_equation)

        # Output cadence.
        solver.output_frequency = 20
        self.assertEqual(solver.output_frequency, 20)
コード例 #2
0
    def test_correct_caching(self):
        """Check that CachedSEIRSimulation reuses and invalidates results correctly.

        A repeated call with the same parameters must reproduce the first
        result (value and gradient). Changing the model parameters, or changing
        the solver's final time while keeping the parameters fixed, must yield
        a different result, i.e. stale cached values must not be returned.
        """
        # setup equation
        eqn = Seir(population=1)

        # setup ode solver
        ti = 0.
        tf = 2.
        n_steps = 3
        rk = RKSolver(ti, tf, n_steps)
        rk.output_frequency = 1
        rk.set_output_storing_flag(True)
        rk.equation = eqn
        u0 = np.array([100., 0., 10., 0.])
        # du0_dp is all zeros: the initial state is treated as independent of
        # the model parameters.
        du0_dp = np.zeros((eqn.n_components(), eqn.n_parameters()))
        rk.set_initial_condition(u0, du0_dp)
        rk.set_output_gradient_flag(True)

        # setup cached simulation object
        cached_sim = CachedSEIRSimulation(rk)

        params = np.array([2.3, 0.2, 1. / 3., 1. / 4.])

        (f, df) = cached_sim(params)

        # Take copies: the cached object may hand back (and later overwrite)
        # its own internal arrays, so keep an independent reference result.
        f1 = np.copy(f)
        df1 = np.copy(df)

        # Same parameters again: the result must match the first call.
        (f_hit, df_hit) = cached_sim(params)
        np.testing.assert_allclose(f_hit, f1)
        np.testing.assert_allclose(df_hit, df1)

        # Different parameters: the cache must not return the old result.
        params2 = np.array([2.32, 0.2, 1. / 3., 1. / 4.])
        (f2, df2) = cached_sim(params2)

        self.assertFalse(np.allclose(f1, f2))
        self.assertFalse(np.allclose(df1, df2))

        # Original parameters but a changed solver final time: the cache must
        # notice the solver change and recompute.
        rk.final_time = 3
        (f3, df3) = cached_sim(params)

        self.assertFalse(np.allclose(f1, f3))
        self.assertFalse(np.allclose(df1, df3))