Beispiel #1
0
    def test_solver_doesnt_support_events(self):
        """JaxSolver must refuse to solve a model that defines events.

        Two events are registered so that the rejection is exercised with more
        than one entry in ``model.events`` (guarding against a bug where only
        the last event was taken into account).
        """
        model = pybamm.BaseModel()
        model.convert_to_format = "jax"
        whole_cell = ["negative electrode", "separator", "positive electrode"]
        u = pybamm.Variable("var", domain=whole_cell)
        model.rhs = {u: -0.1 * u}
        model.initial_conditions = {u: 1}
        model.events = [
            pybamm.Event("var=0.5", pybamm.min(u - 0.5)),
            pybamm.Event("var=-0.5", pybamm.min(u + 0.5)),
        ]

        # Finite-volume discretisation on the shared test mesh; the model
        # needs no parameter values.
        disc = pybamm.Discretisation(
            get_mesh_for_testing(), {"macroscale": pybamm.FiniteVolume()}
        )
        disc.process_model(model)

        jax_solver = pybamm.JaxSolver()
        times = np.linspace(0, 10, 100)
        with self.assertRaisesRegex(
            RuntimeError, "Terminate events not supported"
        ):
            jax_solver.solve(model, times)
Beispiel #2
0
    def test_model_solver_with_inputs(self):
        """Solve a decay model whose rate is supplied as an input parameter.

        Each solve is compared with the analytic solution exp(-rate * t). The
        second solve, with a different input value, is expected to take less
        wall-clock time than the first.
        """
        model = pybamm.BaseModel()
        model.convert_to_format = "jax"
        whole_cell = ["negative electrode", "separator", "positive electrode"]
        u = pybamm.Variable("var", domain=whole_cell)
        model.rhs = {u: -pybamm.InputParameter("rate") * u}
        model.initial_conditions = {u: 1}

        # Finite-volume discretisation on the shared test mesh; the model
        # needs no parameter values.
        disc = pybamm.Discretisation(
            get_mesh_for_testing(), {"macroscale": pybamm.FiniteVolume()}
        )
        disc.process_model(model)

        jax_solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8)
        times = np.linspace(0, 5, 80)

        durations = []
        for k in (0.1, 0.2):
            start = time.perf_counter()
            sol = jax_solver.solve(model, times, inputs={"rate": k})
            durations.append(time.perf_counter() - start)

            np.testing.assert_allclose(
                sol.y[0], np.exp(-k * sol.t), rtol=1e-6, atol=1e-6
            )

        first_duration, second_duration = durations
        self.assertLess(second_duration, first_duration)
Beispiel #3
0
    def test_get_solve(self):
        """Check ``JaxSolver.get_solve`` and rejection of unknown methods.

        - Constructing a JaxSolver with an unknown ``method`` raises ValueError.
        - ``get_solve`` raises RuntimeError before the model has been solved.
        - After one ``solve``, the function returned by ``get_solve`` takes an
          inputs dict and returns an array whose first row matches
          exp(-rate * t_eval) for several rates.
        """
        # Create model: d(var)/dt = -rate * var, var(0) = 1
        model = pybamm.BaseModel()
        model.convert_to_format = "jax"
        domain = ["negative electrode", "separator", "positive electrode"]
        var = pybamm.Variable("var", domain=domain)
        model.rhs = {var: -pybamm.InputParameter("rate") * var}
        model.initial_conditions = {var: 1}
        # No need to set parameters; can use base discretisation (no spatial
        # operators)

        # create discretisation
        mesh = get_mesh_for_testing()
        spatial_methods = {"macroscale": pybamm.FiniteVolume()}
        disc = pybamm.Discretisation(mesh, spatial_methods)
        disc.process_model(model)

        # an unknown integration method must be rejected at construction
        with self.assertRaises(ValueError):
            pybamm.JaxSolver(method="not_real")

        solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8)
        t_eval = np.linspace(0, 5, 80)

        # get_solve is only available once solve() has set the model up
        with self.assertRaisesRegex(RuntimeError,
                                    "Model is not set up for solving"):
            solver.get_solve(model, t_eval)

        solver.solve(model, t_eval, inputs={"rate": 0.1})
        # Use a separate name so ``solver`` keeps referring to the JaxSolver
        # instead of being rebound to the returned solve function.
        solve = solver.get_solve(model, t_eval)

        # the same solve function must be reusable with different inputs
        for rate in (0.1, 0.2):
            y = solve({"rate": rate})
            np.testing.assert_allclose(y[0],
                                       np.exp(-rate * t_eval),
                                       rtol=1e-6,
                                       atol=1e-6)
Beispiel #4
0
    def test_semi_explicit_model(self):
        """Solve a semi-explicit DAE with the BDF method of JaxSolver.

        ``var`` grows as exp(0.1 t) and the algebraic equation enforces
        ``var2 = 2 * var``. The initial value given for ``var2`` is
        deliberately inconsistent, and the solver is expected to correct it.
        A repeat solve must reproduce the same states and run faster than the
        first.
        """
        model = pybamm.BaseModel()
        model.convert_to_format = "jax"
        whole_cell = ["negative electrode", "separator", "positive electrode"]
        u = pybamm.Variable("var", domain=whole_cell)
        w = pybamm.Variable("var2", domain=whole_cell)
        model.rhs = {u: 0.1 * u}
        model.algebraic = {w: w - 2.0 * u}
        model.initial_conditions = {u: 1.0, w: 1.0}

        # Finite-volume discretisation on the shared test mesh; the model
        # needs no parameter values.
        disc = pybamm.Discretisation(
            get_mesh_for_testing(), {"macroscale": pybamm.FiniteVolume()}
        )
        disc.process_model(model)

        jax_solver = pybamm.JaxSolver(method='BDF', rtol=1e-8, atol=1e-8)
        times = np.linspace(0, 1, 80)

        start = time.perf_counter()
        sol = jax_solver.solve(model, times)
        first_duration = time.perf_counter() - start

        np.testing.assert_array_equal(sol.t, times)
        expected = np.exp(0.1 * sol.t)
        # y[0] is checked against var's analytic solution, y[-1] against var2's
        np.testing.assert_allclose(sol.y[0], expected, rtol=1e-7, atol=1e-7)
        np.testing.assert_allclose(
            sol.y[-1], 2 * expected, rtol=1e-7, atol=1e-7
        )

        self.assertEqual(sol.total_time, sol.solve_time + sol.set_up_time)
        self.assertEqual(sol.termination, "final time")

        start = time.perf_counter()
        repeat = jax_solver.solve(model, times)
        second_duration = time.perf_counter() - start

        self.assertLess(second_duration, first_duration)
        np.testing.assert_array_equal(repeat.y, sol.y)
Beispiel #5
0
    def test_solver_sensitivities(self):
        """Compare JAX autodiff gradients with a central finite difference.

        For each method, the objective is the sum of the array returned by
        the ``get_solve`` function, viewed as a function of the ``rate``
        input. ``jax.grad`` (under ``jax.jit``) must agree with a central
        difference to one decimal place.
        """
        model = pybamm.BaseModel()
        model.convert_to_format = "jax"
        whole_cell = ["negative electrode", "separator", "positive electrode"]
        u = pybamm.Variable("var", domain=whole_cell)
        model.rhs = {u: -pybamm.InputParameter("rate") * u}
        model.initial_conditions = {u: 1.0}

        # Finite-volume discretisation on the shared test mesh; the model
        # needs no parameter values.
        disc = pybamm.Discretisation(
            get_mesh_for_testing(), {"macroscale": pybamm.FiniteVolume()}
        )
        disc.process_model(model)

        for method in ['RK45', 'BDF']:
            jax_solver = pybamm.JaxSolver(method=method, rtol=1e-8, atol=1e-8)
            times = np.linspace(0, 1, 80)
            step = 0.0001
            base_rate = 0.1

            # get_solve only works after an initial solve has set the model up
            jax_solver.solve(model, times, inputs={'rate': base_rate})
            solve = jax_solver.get_solve(model, times)

            def objective(rate):
                # scalar summary of the whole solution, differentiable in rate
                return jax.numpy.sum(solve({'rate': rate}))

            # central finite-difference estimate of d(objective)/d(rate)
            fd_grad = (
                objective(base_rate + step) - objective(base_rate - step)
            ) / (2 * step)

            ad_grad = jax.jit(jax.grad(objective))(base_rate)

            self.assertAlmostEqual(ad_grad, fd_grad, places=1)
Beispiel #6
0
    def test_solver_only_works_with_jax(self):
        """JaxSolver must reject any model not converted to the "jax" format."""
        model = pybamm.BaseModel()
        v = pybamm.Variable("var")
        model.rhs = {v: -pybamm.sqrt(v)}
        model.initial_conditions = {v: 1}

        # Default discretisation suffices: the model has no spatial operators
        pybamm.Discretisation().process_model(model)

        times = np.linspace(0, 3, 100)

        for fmt in ("casadi", "python", "something_else"):
            model.convert_to_format = fmt
            # a fresh solver for every format, built outside the assertion
            jax_solver = pybamm.JaxSolver()
            with self.assertRaisesRegex(
                RuntimeError, "must be converted to JAX"
            ):
                jax_solver.solve(model, times)
Beispiel #7
0
    def test_model_solver(self):
        """Solve an exponential-growth model with each JaxSolver method.

        For 'RK45' and 'BDF', the test checks that:
        - the solution is returned exactly at the requested times;
        - it matches exp(0.1 t);
        - total time equals solve time plus set-up time;
        - the run terminates at "final time";
        - a repeat solve is faster and reproduces the same states.
        """
        model = pybamm.BaseModel()
        model.convert_to_format = "jax"
        whole_cell = ["negative electrode", "separator", "positive electrode"]
        u = pybamm.Variable("var", domain=whole_cell)
        model.rhs = {u: 0.1 * u}
        model.initial_conditions = {u: 1.0}

        # Finite-volume discretisation on the shared test mesh; the model
        # needs no parameter values.
        disc = pybamm.Discretisation(
            get_mesh_for_testing(), {"macroscale": pybamm.FiniteVolume()}
        )
        disc.process_model(model)

        for method in ['RK45', 'BDF']:
            jax_solver = pybamm.JaxSolver(method=method, rtol=1e-8, atol=1e-8)
            times = np.linspace(0, 1, 80)

            start = time.perf_counter()
            sol = jax_solver.solve(model, times)
            first_duration = time.perf_counter() - start

            np.testing.assert_array_equal(sol.t, times)
            np.testing.assert_allclose(
                sol.y[0], np.exp(0.1 * sol.t), rtol=1e-6, atol=1e-6
            )

            self.assertEqual(
                sol.total_time, sol.solve_time + sol.set_up_time
            )
            self.assertEqual(sol.termination, "final time")

            start = time.perf_counter()
            repeat = jax_solver.solve(model, times)
            second_duration = time.perf_counter() - start

            self.assertLess(second_duration, first_duration)
            np.testing.assert_array_equal(repeat.y, sol.y)
Beispiel #8
0
 def test_jax_citations(self):
     """Constructing a JaxSolver must register the JAX paper for citation."""
     key = "jax2018"
     registry = pybamm.citations
     registry._reset()
     # A freshly reset registry must not already contain the JAX reference
     self.assertNotIn(key, registry._papers_to_cite)
     pybamm.JaxSolver()
     self.assertIn(key, registry._papers_to_cite)