def test_evaluator_jax_jacobian(self):
    """Compare EvaluatorJax.get_jacobian with a hand-written derivative.

    For ``a**2`` the derivative with respect to ``a`` is ``2 * a``. The
    evaluator returned by ``get_jacobian`` and an evaluator built directly
    from ``2 * a`` are both evaluated at a few states and must agree.
    """
    state = pybamm.StateVector(slice(0, 1))
    squared = state**2
    derivative = 2 * state

    jacobian_evaluator = pybamm.EvaluatorJax(squared).get_jacobian()
    reference_evaluator = pybamm.EvaluatorJax(derivative)

    for y in (np.array([[2.0]]), np.array([[1.0]])):
        np.testing.assert_allclose(
            jacobian_evaluator.evaluate(t=None, y=y),
            reference_evaluator.evaluate(t=None, y=y),
        )
def evaluate_model(self, use_known_evals=False, to_python=False, to_jax=False):
    """Evaluate the model's rhs and algebraic equations at the initial state.

    ``self.model.concatenated_rhs`` and ``self.model.concatenated_algebraic``
    are each evaluated at ``t=0``. The state ``y`` is the evaluated
    ``self.model.concatenated_initial_conditions``. The two results are
    stacked vertically (rhs first).

    Parameters
    ----------
    use_known_evals : bool, optional
        If True, call ``eqn.evaluate(0, y, known_evals={})``. The returned
        known-evals dict is discarded.
    to_python : bool, optional
        If True (and ``use_known_evals`` is False), evaluate through
        ``pybamm.EvaluatorPython``.
    to_jax : bool, optional
        If True (and neither flag above is set), evaluate through
        ``pybamm.EvaluatorJax``.

    Returns
    -------
    numpy.ndarray
        The rhs values followed by the algebraic values. The accumulator
        starts with shape (0, 1), so each equation is assumed to evaluate
        to an (n, 1) column (or an empty (0,) array, handled below).
    """
    # The initial state does not depend on the equation: evaluate it once
    # instead of once per loop iteration.
    y = self.model.concatenated_initial_conditions.evaluate(t=0)
    result = np.empty((0, 1))
    for eqn in [
        self.model.concatenated_rhs,
        self.model.concatenated_algebraic,
    ]:
        if use_known_evals:
            eqn_eval, _ = eqn.evaluate(0, y, known_evals={})
        elif to_python:
            eqn_eval = pybamm.EvaluatorPython(eqn).evaluate(0, y)
        elif to_jax:
            eqn_eval = pybamm.EvaluatorJax(eqn).evaluate(0, y)
        else:
            eqn_eval = eqn.evaluate(0, y)
        # An empty 1-D result must become (0, 1) to concatenate with the
        # column accumulator.
        if eqn_eval.shape == (0,):
            eqn_eval = eqn_eval[:, np.newaxis]
        result = np.concatenate([result, eqn_eval])
    return result
def test_solver_with_inputs(self):
    """Solve dy/dt = -rate * y with jax_bdf_integrate, passing rate as input.

    With y(0) = 1 the exact solution is exp(-rate * t). The first state
    component is compared against it for rate = 0.1.
    """
    # Single-variable model whose decay rate is an input parameter
    model = pybamm.BaseModel()
    model.convert_to_format = "jax"
    whole_cell = ["negative electrode", "separator", "positive electrode"]
    var = pybamm.Variable("var", domain=whole_cell)
    model.rhs = {var: -pybamm.InputParameter("rate") * var}
    model.initial_conditions = {var: 1}

    # Discretise on the shared test mesh with finite volumes
    disc = pybamm.Discretisation(
        get_mesh_for_testing(), {"macroscale": pybamm.FiniteVolume()}
    )
    disc.process_model(model)

    t_eval = np.linspace(0, 10, 80)
    y0 = model.concatenated_initial_conditions.evaluate().reshape(-1)
    rhs_evaluator = pybamm.EvaluatorJax(model.concatenated_rhs)

    def fun(y, t, inputs):
        return rhs_evaluator.evaluate(t=t, y=y, inputs=inputs).reshape(-1)

    solution = pybamm.jax_bdf_integrate(
        fun, y0, t_eval, {"rate": 0.1}, rtol=1e-9, atol=1e-9
    )
    np.testing.assert_allclose(
        solution[:, 0].reshape(-1), np.exp(-0.1 * t_eval)
    )
def test_solver(self):
    """Integrate dy/dt = 0.1 * y twice and compare with exp(0.1 * t).

    Both runs must be accurate. The second, identical run must also take
    less wall-clock time than the first (presumably because compiled code
    is reused).
    """
    model = pybamm.BaseModel()
    model.convert_to_format = "jax"
    var = pybamm.Variable(
        "var", domain=["negative electrode", "separator", "positive electrode"]
    )
    model.rhs = {var: 0.1 * var}
    model.initial_conditions = {var: 1}

    # No parameters needed: the base discretisation suffices because the
    # model has no spatial operators.
    disc = pybamm.Discretisation(
        get_mesh_for_testing(), {"macroscale": pybamm.FiniteVolume()}
    )
    disc.process_model(model)

    t_eval = np.linspace(0.0, 1.0, 80)
    y0 = model.concatenated_initial_conditions.evaluate().reshape(-1)
    rhs_evaluator = pybamm.EvaluatorJax(model.concatenated_rhs)

    def fun(y, t):
        return rhs_evaluator.evaluate(t=t, y=y).reshape(-1)

    def timed_solve():
        # Run the integrator once and report how long the call took
        start = time.perf_counter()
        sol = pybamm.jax_bdf_integrate(fun, y0, t_eval, rtol=1e-8, atol=1e-8)
        return sol, time.perf_counter() - start

    expected = np.exp(0.1 * t_eval)

    y, first_duration = timed_solve()
    np.testing.assert_allclose(y[:, 0], expected, rtol=1e-6, atol=1e-6)

    y, second_duration = timed_solve()
    # the repeated run should be much quicker than the first one
    self.assertLess(second_duration, first_duration)
    # and it must be just as accurate
    np.testing.assert_allclose(y[:, 0], expected, rtol=1e-6, atol=1e-6)
def test_solver_sensitivities(self):
    """Compare jax.grad of a jitted BDF solve with a central finite difference.

    The scalar objective is the sum of the solution time series of
    dy/dt = -rate * y. Its derivative with respect to ``rate`` from
    ``jax.grad`` must match a central difference to 3 decimal places.
    """
    model = pybamm.BaseModel()
    model.convert_to_format = "jax"
    var = pybamm.Variable(
        "var", domain=["negative electrode", "separator", "positive electrode"]
    )
    model.rhs = {var: -pybamm.InputParameter("rate") * var}
    model.initial_conditions = {var: 1}

    disc = pybamm.Discretisation(
        get_mesh_for_testing(xpts=10), {"macroscale": pybamm.FiniteVolume()}
    )
    disc.process_model(model)

    t_eval = np.linspace(0, 10, 4)
    y0 = model.concatenated_initial_conditions.evaluate().reshape(-1)
    rhs_evaluator = pybamm.EvaluatorJax(model.concatenated_rhs)

    def fun(y, t, inputs):
        return rhs_evaluator.evaluate(t=t, y=y, inputs=inputs).reshape(-1)

    step = 0.0001
    rate = 0.1

    @jax.jit
    def summed_solution(rate):
        # dummy scalar "model": the sum of the whole solution time series
        return jax.numpy.sum(
            pybamm.jax_bdf_integrate(
                fun, y0, t_eval, {'rate': rate}, rtol=1e-9, atol=1e-9
            )
        )

    # central finite-difference estimate of d(sum) / d(rate)
    grad_num = (
        summed_solution(rate + step) - summed_solution(rate - step)
    ) / (2 * step)

    grad_bdf = jax.jit(jax.grad(summed_solution))(rate)
    self.assertAlmostEqual(grad_bdf, grad_num, places=3)
def test_evaluator_jax_debug(self):
    """Smoke-test EvaluatorJax.debug on a squared state (no assertions)."""
    squared = pybamm.StateVector(slice(0, 1)) ** 2
    pybamm.EvaluatorJax(squared).debug(y=np.array([[2.0], [3.0]]))
def test_evaluator_jax(self):
    """Check EvaluatorJax against pybamm's own ``evaluate`` for many node types.

    Each expression is evaluated through ``pybamm.EvaluatorJax`` and through
    ``Symbol.evaluate``, and the two results are compared.
    """
    a = pybamm.StateVector(slice(0, 1))
    b = pybamm.StateVector(slice(1, 2))

    y_tests = [np.array([[2], [3]]), np.array([[1], [3]])]
    t_tests = [1, 2]

    def check_against_pybamm(
        expr, compare=np.testing.assert_allclose, densify=False
    ):
        # Build the jax evaluator once, then compare with Symbol.evaluate
        # at every (t, y) pair; ``densify`` converts a sparse reference.
        evaluator = pybamm.EvaluatorJax(expr)
        for t, y in zip(t_tests, y_tests):
            result = evaluator.evaluate(t=t, y=y)
            expected = expr.evaluate(t=t, y=y)
            if densify:
                expected = expected.toarray()
            compare(result, expected)

    # product of two states
    evaluator = pybamm.EvaluatorJax(a * b)
    self.assertEqual(evaluator.evaluate(t=None, y=np.array([[2], [3]])), 6)
    self.assertEqual(evaluator.evaluate(t=None, y=np.array([[1], [3]])), 3)

    # a Function wrapping the product
    evaluator = pybamm.EvaluatorJax(pybamm.Function(test_function, a * b))
    self.assertEqual(evaluator.evaluate(t=None, y=np.array([[2], [3]])), 12)

    # exponential
    evaluator = pybamm.EvaluatorJax(pybamm.exp(a * b))
    self.assertEqual(
        evaluator.evaluate(t=None, y=np.array([[2], [3]])), np.exp(6)
    )

    # constant-only expression
    evaluator = pybamm.EvaluatorJax(pybamm.Scalar(2) * pybamm.Scalar(3))
    self.assertEqual(evaluator.evaluate(), 6)

    # a larger expression, evaluated without time
    expr = a * b + b + a**2 / b + 2 * a + b / 2 + 4
    evaluator = pybamm.EvaluatorJax(expr)
    for y in y_tests:
        np.testing.assert_allclose(
            evaluator.evaluate(t=None, y=y), expr.evaluate(t=None, y=y)
        )

    # dependence on time
    check_against_pybamm(a * pybamm.t, compare=self.assertEqual)

    # dense matrix-vector product
    A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
    check_against_pybamm(A @ pybamm.StateVector(slice(0, 2)))

    # heaviside comparisons
    vec = pybamm.Vector(np.array([1, 2]))
    check_against_pybamm(vec <= pybamm.StateVector(slice(0, 2)))
    check_against_pybamm(vec > pybamm.StateVector(slice(0, 2)))

    # minimum and maximum
    vec = pybamm.Vector(np.array([1, 2]))
    check_against_pybamm(pybamm.minimum(vec, pybamm.StateVector(slice(0, 2))))
    check_against_pybamm(pybamm.maximum(vec, pybamm.StateVector(slice(0, 2))))

    # indexing into the dense matrix-vector product
    check_against_pybamm(
        pybamm.Index(A @ pybamm.StateVector(slice(0, 2)), 0),
        compare=self.assertEqual,
    )

    # chained products with csr and coo sparse matrices
    A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
    B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    C = pybamm.Matrix(scipy.sparse.coo_matrix(np.array([[1, 0], [0, 4]])))
    check_against_pybamm(A @ B @ C @ pybamm.StateVector(slice(0, 2)))

    # sparse stack (the pybamm reference is sparse, so densify it)
    A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
    state = pybamm.StateVector(slice(0, 1))
    check_against_pybamm(pybamm.SparseStack(A, state * B), densify=True)

    # numpy concatenation of constant vectors
    check_against_pybamm(
        pybamm.NumpyConcatenation(
            pybamm.Vector(np.array([[1], [2]])), pybamm.Vector(np.array([[3]]))
        )
    )

    # inner product of two constant vectors
    v = pybamm.Vector(np.ones(5), domain="test")
    w = pybamm.Vector(2 * np.ones(5), domain="test")
    expr = pybamm.Inner(v, w)
    np.testing.assert_allclose(
        pybamm.EvaluatorJax(expr).evaluate(), expr.evaluate()
    )
def process(func, name, use_jacobian=None):
    """Convert ``func`` to the model's evaluation format, plus its Jacobian.

    This is a nested helper. It reads ``model``, ``simp``, ``jacobian``,
    ``y``, ``t_casadi``, ``y_casadi``, ``p_casadi``, ``p_casadi_stacked``,
    ``casadi``, ``Residuals`` and ``SolverCallable`` from the enclosing
    scope, which is not visible here.

    Parameters
    ----------
    func : pybamm.Symbol
        Expression to process (presumably already discretised; confirm
        against the caller).
    name : str
        Label used in log messages and as the CasADi function name. The
        value "residuals" wraps the result in ``Residuals`` instead of
        ``SolverCallable``.
    use_jacobian : bool, optional
        Whether to also build a Jacobian. Defaults to ``model.use_jacobian``.

    Returns
    -------
    tuple
        ``(func, func_call, jac_call)``:

        - ``func`` is the converted function. It is the bound ``evaluate``
          method for non-CasADi formats, or a ``casadi.Function``.
        - ``func_call`` is its callable wrapper.
        - ``jac_call`` is a ``SolverCallable`` around the Jacobian, or None
          when no Jacobian was built.
    """
    def report(string):
        # don't log event conversion
        if "event" not in string:
            pybamm.logger.info(string)

    if use_jacobian is None:
        use_jacobian = model.use_jacobian

    if model.convert_to_format != "casadi":
        # Process with pybamm functions
        if model.use_simplify:
            report(f"Simplifying {name}")
            func = simp.simplify(func)

        if model.convert_to_format == "jax":
            # Built from the (possibly simplified) func. It is reused both
            # for the jax Jacobian and as the final func below.
            report(f"Converting {name} to jax")
            jax_func = pybamm.EvaluatorJax(func)

        if use_jacobian:
            report(f"Calculating jacobian for {name}")
            jac = jacobian.jac(func, y)
            if model.use_simplify:
                report(f"Simplifying jacobian for {name}")
                jac = simp.simplify(jac)
            if model.convert_to_format == "python":
                report(f"Converting jacobian for {name} to python")
                jac = pybamm.EvaluatorPython(jac)
            elif model.convert_to_format == "jax":
                # NOTE(review): this overwrites the symbolic jac computed
                # (and possibly simplified) above, so that work is unused
                # in the jax format.
                report(f"Converting jacobian for {name} to jax")
                jac = jax_func.get_jacobian()
            # Keep only the bound evaluate method. For formats other than
            # python/jax this is the symbolic Jacobian's own evaluate.
            jac = jac.evaluate
        else:
            jac = None

        if model.convert_to_format == "python":
            report(f"Converting {name} to python")
            func = pybamm.EvaluatorPython(func)
        if model.convert_to_format == "jax":
            # NOTE(review): the same "Converting ... to jax" message was
            # already reported when jax_func was built, so it is logged twice.
            report(f"Converting {name} to jax")
            func = jax_func
        func = func.evaluate
    else:
        # Process with CasADi
        report(f"Converting {name} to CasADi")
        func = func.to_casadi(t_casadi, y_casadi, inputs=p_casadi)
        if use_jacobian:
            # Jacobian of the CasADi expression with respect to the state
            report(f"Calculating jacobian for {name} using CasADi")
            jac_casadi = casadi.jacobian(func, y_casadi)
            jac = casadi.Function(
                name, [t_casadi, y_casadi, p_casadi_stacked], [jac_casadi]
            )
        else:
            jac = None
        func = casadi.Function(name, [t_casadi, y_casadi, p_casadi_stacked], [func])
    # Wrap the converted function (and Jacobian, if any) in solver callables
    if name == "residuals":
        func_call = Residuals(func, name, model)
    else:
        func_call = SolverCallable(func, name, model)
    if jac is not None:
        jac_call = SolverCallable(jac, name + "_jac", model)
    else:
        jac_call = None
    return func, func_call, jac_call
def test_evaluator_jax_inputs(self):
    """EvaluatorJax should take InputParameter values from ``inputs``."""
    squared = pybamm.InputParameter('a') ** 2
    value = pybamm.EvaluatorJax(squared).evaluate(inputs={'a': 2})
    self.assertEqual(value, 4)
def test_evaluator_jax(self):
    """Check EvaluatorJax against pybamm's own ``evaluate`` for many node types.

    Each expression is evaluated through ``pybamm.EvaluatorJax`` and through
    ``Symbol.evaluate`` for every (t, y) pair, and the results are compared.
    The pairs include a 1-D state vector. Sparse stacks and sparse
    matrix-matrix products are expected to raise ``NotImplementedError``
    when the evaluator is built.
    """
    a = pybamm.StateVector(slice(0, 1))
    b = pybamm.StateVector(slice(1, 2))

    y_tests = [
        np.array([[2.0], [3.0]]),
        np.array([[1.0], [3.0]]),
        np.array([1.0, 3.0]),  # 1-D state vector
    ]
    # One time per state. zip() stops at the shorter sequence, so with only
    # two times the 1-D state above was silently skipped by every (t, y)
    # check below.
    t_tests = [1.0, 2.0, 3.0]
    self.assertEqual(len(t_tests), len(y_tests))

    def check_against_pybamm(expr, compare=np.testing.assert_allclose):
        # Build the jax evaluator once, then compare with Symbol.evaluate
        # at every (t, y) pair.
        evaluator = pybamm.EvaluatorJax(expr)
        for t, y in zip(t_tests, y_tests):
            result = evaluator.evaluate(t=t, y=y)
            compare(result, expr.evaluate(t=t, y=y))

    # test a * b
    evaluator = pybamm.EvaluatorJax(a * b)
    result = evaluator.evaluate(t=None, y=np.array([[2], [3]]))
    self.assertEqual(result, 6)
    result = evaluator.evaluate(t=None, y=np.array([[1], [3]]))
    self.assertEqual(result, 3)

    # test function(a*b)
    evaluator = pybamm.EvaluatorJax(pybamm.Function(test_function, a * b))
    result = evaluator.evaluate(t=None, y=np.array([[2], [3]]))
    self.assertEqual(result, 12)

    # test exp
    evaluator = pybamm.EvaluatorJax(pybamm.exp(a * b))
    result = evaluator.evaluate(t=None, y=np.array([[2], [3]]))
    self.assertEqual(result, np.exp(6))

    # test a constant expression
    evaluator = pybamm.EvaluatorJax(pybamm.Scalar(2) * pybamm.Scalar(3))
    result = evaluator.evaluate()
    self.assertEqual(result, 6)

    # test a larger expression
    expr = a * b + b + a**2 / b + 2 * a + b / 2 + 4
    evaluator = pybamm.EvaluatorJax(expr)
    for y in y_tests:
        result = evaluator.evaluate(t=None, y=y)
        np.testing.assert_allclose(result, expr.evaluate(t=None, y=y))

    # test something with time
    check_against_pybamm(a * pybamm.t, compare=self.assertEqual)

    # test something with a matrix multiplication
    A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
    check_against_pybamm(A @ pybamm.StateVector(slice(0, 2)))

    # test something with a heaviside
    vec = pybamm.Vector(np.array([1, 2]))
    check_against_pybamm(vec <= pybamm.StateVector(slice(0, 2)))
    check_against_pybamm(vec > pybamm.StateVector(slice(0, 2)))

    # test something with a minimum or maximum
    check_against_pybamm(pybamm.minimum(vec, pybamm.StateVector(slice(0, 2))))
    check_against_pybamm(pybamm.maximum(vec, pybamm.StateVector(slice(0, 2))))

    # test something with an index (uses the dense A from above)
    check_against_pybamm(
        pybamm.Index(A @ pybamm.StateVector(slice(0, 2)), 0),
        compare=self.assertEqual,
    )

    # test something with a sparse matrix-vector multiplication
    A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
    B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    C = pybamm.Matrix(scipy.sparse.coo_matrix(np.array([[1, 0], [0, 4]])))
    check_against_pybamm(A @ B @ C @ pybamm.StateVector(slice(0, 2)))

    # test the sparse-scalar multiplication
    A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    for expr in [
        A * pybamm.t @ pybamm.StateVector(slice(0, 2)),
        pybamm.t * A @ pybamm.StateVector(slice(0, 2)),
    ]:
        check_against_pybamm(expr)

    # test the sparse-scalar division
    A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    check_against_pybamm(A / (1.0 + pybamm.t) @ pybamm.StateVector(slice(0, 2)))

    # test sparse stack (not supported by the jax evaluator)
    A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
    state = pybamm.StateVector(slice(0, 1))
    expr = pybamm.SparseStack(A, state * B)
    with self.assertRaises(NotImplementedError):
        evaluator = pybamm.EvaluatorJax(expr)

    # test sparse mat-mat mult (not supported by the jax evaluator)
    A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
    state = pybamm.StateVector(slice(0, 1))
    expr = A @ (state * B)
    with self.assertRaises(NotImplementedError):
        evaluator = pybamm.EvaluatorJax(expr)

    # test numpy concatenation
    check_against_pybamm(
        pybamm.NumpyConcatenation(
            pybamm.Vector(np.array([[1], [2]])), pybamm.Vector(np.array([[3]]))
        )
    )

    # test Inner
    A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1]])))
    v = pybamm.StateVector(slice(0, 1))
    for expr in [
        pybamm.Inner(A, v) @ v,
        pybamm.Inner(v, A) @ v,
        pybamm.Inner(v, v) @ v,
    ]:
        check_against_pybamm(expr)