def test_mass_matrix(self):
    """Solve a two-state system with a singular mass matrix via the JAX BDF solver.

    The system is ``2 * y0' = 0.1 * y0`` together with the algebraic
    constraint ``0 = y1 - 2 * y0``, so the reference solution is
    ``y0 = exp(0.05 * t)`` and ``y1 = 2 * y0``. The integration is run
    twice: both results are compared with the reference, and the repeat
    call must take less wall-clock time than the first.
    """
    times = np.linspace(0.0, 1.0, 80)

    def residual(y, t):
        return jax.numpy.stack([0.1 * y[0], y[1] - 2.0 * y[0]])

    # The zero second row makes the second equation purely algebraic.
    mass = jax.numpy.array([[2.0, 0.0], [0.0, 0.0]])

    # Deliberately inconsistent starting point (y1 != 2 * y0): the solver
    # should calculate correct initial conditions using this as a guess.
    initial_guess = jax.numpy.array([1.0, 1.5])

    def timed_solve():
        # Returns the solution together with the elapsed wall-clock time.
        start = time.perf_counter()
        result = pybamm.jax_bdf_integrate(
            residual, initial_guess, times, mass=mass, rtol=1e-8, atol=1e-8
        )
        return result, time.perf_counter() - start

    expected = np.exp(0.05 * times)

    # First run: check the accuracy of both states.
    first_solution, first_elapsed = timed_solve()
    np.testing.assert_allclose(first_solution[:, 0], expected, rtol=1e-7, atol=1e-7)
    np.testing.assert_allclose(
        first_solution[:, 1], 2.0 * expected, rtol=1e-7, atol=1e-7
    )

    # Second run: should be much quicker
    # (presumably the first call pays a one-off compilation cost - TODO confirm).
    second_solution, second_elapsed = timed_solve()
    self.assertLess(second_elapsed, first_elapsed)

    # The repeat run must still be accurate.
    np.testing.assert_allclose(
        second_solution[:, 0], expected, rtol=1e-7, atol=1e-7
    )
def test_solver_with_inputs(self):
    """Integrate an exponential-decay model whose rate is passed as an input.

    The model is ``d(var)/dt = -rate * var`` with ``var(0) = 1``; it is
    discretised, evaluated through ``pybamm.EvaluatorJax`` and integrated
    with ``rate = 0.1``. The first column of the result is compared with
    ``exp(-0.1 * t)``.
    """
    rate = 0.1

    # Build the model.
    model = pybamm.BaseModel()
    model.convert_to_format = "jax"
    var = pybamm.Variable(
        "var",
        domain=["negative electrode", "separator", "positive electrode"],
    )
    model.rhs = {var: -pybamm.InputParameter("rate") * var}
    model.initial_conditions = {var: 1}

    # Discretise on the shared test mesh.
    discretisation = pybamm.Discretisation(
        get_mesh_for_testing(), {"macroscale": pybamm.FiniteVolume()}
    )
    discretisation.process_model(model)

    # Set up the flattened initial state and right-hand side.
    times = np.linspace(0, 10, 80)
    initial_state = model.concatenated_initial_conditions.evaluate().reshape(-1)
    evaluator = pybamm.EvaluatorJax(model.concatenated_rhs)

    def rhs_fun(y, t, inputs):
        return evaluator.evaluate(t=t, y=y, inputs=inputs).reshape(-1)

    # Solve, passing the input dictionary positionally as the original call did.
    solution = pybamm.jax_bdf_integrate(
        rhs_fun, initial_state, times, {"rate": rate}, rtol=1e-9, atol=1e-9
    )

    first_column = solution[:, 0].reshape(-1)
    np.testing.assert_allclose(first_column, np.exp(-rate * times))
def solve_bdf(rate):
    """Integrate with the given ``rate`` input and return the summed solution.

    ``fun``, ``y0`` and ``t_eval`` are taken from the enclosing scope. The
    result is reduced with ``jax.numpy.sum`` over the whole output array.
    """
    inputs = {"rate": rate}
    solution = pybamm.jax_bdf_integrate(
        fun, y0, t_eval, inputs, rtol=1e-9, atol=1e-9
    )
    return jax.numpy.sum(solution)
def test_solver(self):
    """Integrate a discretised exponential-growth model with the JAX BDF solver.

    The model is ``d(var)/dt = 0.1 * var`` with ``var(0) = 1``. The
    integration is run twice: the first column of each result is compared
    with ``exp(0.1 * t)``, and the repeat call must take less wall-clock
    time than the first.
    """
    # Build the model.
    model = pybamm.BaseModel()
    model.convert_to_format = "jax"
    var = pybamm.Variable(
        "var",
        domain=["negative electrode", "separator", "positive electrode"],
    )
    model.rhs = {var: 0.1 * var}
    model.initial_conditions = {var: 1}

    # No need to set parameters; can use base discretisation (no spatial
    # operators).
    discretisation = pybamm.Discretisation(
        get_mesh_for_testing(), {"macroscale": pybamm.FiniteVolume()}
    )
    discretisation.process_model(model)

    # Set up the flattened initial state and right-hand side.
    times = np.linspace(0.0, 1.0, 80)
    initial_state = model.concatenated_initial_conditions.evaluate().reshape(-1)
    evaluator = pybamm.EvaluatorJax(model.concatenated_rhs)

    def rhs_fun(y, t):
        return evaluator.evaluate(t=t, y=y).reshape(-1)

    def timed_solve():
        # Returns the solution together with the elapsed wall-clock time.
        start = time.perf_counter()
        result = pybamm.jax_bdf_integrate(
            rhs_fun, initial_state, times, rtol=1e-8, atol=1e-8
        )
        return result, time.perf_counter() - start

    expected = np.exp(0.1 * times)

    # First run: check accuracy.
    first_solution, first_elapsed = timed_solve()
    np.testing.assert_allclose(first_solution[:, 0], expected, rtol=1e-6, atol=1e-6)

    # Second run: should be much quicker
    # (presumably the first call pays a one-off compilation cost - TODO confirm).
    second_solution, second_elapsed = timed_solve()
    self.assertLess(second_elapsed, first_elapsed)

    # The repeat run must still be accurate.
    np.testing.assert_allclose(
        second_solution[:, 0], expected, rtol=1e-6, atol=1e-6
    )
def solve_model_bdf(inputs):
    """Integrate the model for ``inputs`` and return the transposed solution.

    ``rhs_dae``, ``y0``, ``t_eval``, ``mass`` and ``self`` come from the
    enclosing scope; tolerances and any extra solver options are read from
    ``self``. The integrator output is passed through ``jnp.transpose``
    before being returned.
    """
    solution = pybamm.jax_bdf_integrate(
        rhs_dae,
        y0,
        t_eval,
        inputs,
        mass=mass,
        rtol=self.rtol,
        atol=self.atol,
        **self.extra_options,
    )
    return jnp.transpose(solution)