Beispiel #1
0
    def test_evaluator_jax_jacobian(self):
        """The jax-derived jacobian of a**2 must agree with the analytic 2*a."""
        state = pybamm.StateVector(slice(0, 1))
        squared = state**2
        analytic_derivative = 2 * state

        # jacobian produced by the jax evaluator itself
        auto_jac = pybamm.EvaluatorJax(squared).get_jacobian()
        # jacobian written out by hand
        exact_jac = pybamm.EvaluatorJax(analytic_derivative)

        for y_val in (np.array([[2.0]]), np.array([[1.0]])):
            computed = auto_jac.evaluate(t=None, y=y_val)
            reference = exact_jac.evaluate(t=None, y=y_val)
            np.testing.assert_allclose(computed, reference)
    def evaluate_model(self,
                       use_known_evals=False,
                       to_python=False,
                       to_jax=False):
        """Evaluate the model's rhs and algebraic equations at the initial state.

        ``self.model.concatenated_rhs`` and ``self.model.concatenated_algebraic``
        are each evaluated at ``t=0`` with ``y`` taken from
        ``self.model.concatenated_initial_conditions``. The results are stacked
        vertically, rhs first and algebraic second.

        Parameters
        ----------
        use_known_evals : bool
            Evaluate each equation with ``known_evals={}``. This is checked first.
        to_python : bool
            Evaluate through ``pybamm.EvaluatorPython``. Used only when
            ``use_known_evals`` is False.
        to_jax : bool
            Evaluate through ``pybamm.EvaluatorJax``. Used only when neither
            flag above is set.

        Returns
        -------
        numpy.ndarray
            Array with a single column. Each equation must evaluate to an
            ``(n, 1)`` column or to an empty 1-D array, otherwise the
            concatenation fails.
        """
        # The initial state is identical for every equation, so compute it once.
        y = self.model.concatenated_initial_conditions.evaluate(t=0)

        result = np.empty((0, 1))
        for eqn in [
                self.model.concatenated_rhs, self.model.concatenated_algebraic
        ]:
            if use_known_evals:
                # the returned known_evals cache is not needed here
                eqn_eval, _ = eqn.evaluate(0, y, known_evals={})
            elif to_python:
                eqn_eval = pybamm.EvaluatorPython(eqn).evaluate(0, y)
            elif to_jax:
                eqn_eval = pybamm.EvaluatorJax(eqn).evaluate(0, y)
            else:
                eqn_eval = eqn.evaluate(0, y)

            # An empty system evaluates to shape (0,). Add a column axis so it
            # can be concatenated with the (n, 1) running result.
            if eqn_eval.shape == (0, ):
                eqn_eval = eqn_eval[:, np.newaxis]

            result = np.concatenate([result, eqn_eval])

        return result
Beispiel #3
0
    def test_solver_with_inputs(self):
        """Solve dy/dt = -rate * y where rate is supplied as an input parameter."""
        # exponential decay whose rate is an InputParameter
        model = pybamm.BaseModel()
        model.convert_to_format = "jax"
        whole_cell = ["negative electrode", "separator", "positive electrode"]
        state = pybamm.Variable("var", domain=whole_cell)
        model.rhs = {state: -pybamm.InputParameter("rate") * state}
        model.initial_conditions = {state: 1}

        # discretise on the standard test mesh
        discretisation = pybamm.Discretisation(
            get_mesh_for_testing(), {"macroscale": pybamm.FiniteVolume()}
        )
        discretisation.process_model(model)

        times = np.linspace(0, 10, 80)
        initial_state = model.concatenated_initial_conditions.evaluate().reshape(-1)
        rhs_evaluator = pybamm.EvaluatorJax(model.concatenated_rhs)

        def fun(y, t, inputs):
            return rhs_evaluator.evaluate(t=t, y=y, inputs=inputs).reshape(-1)

        solution = pybamm.jax_bdf_integrate(
            fun, initial_state, times, {"rate": 0.1}, rtol=1e-9, atol=1e-9
        )

        first_component = solution[:, 0].reshape(-1)
        np.testing.assert_allclose(first_component, np.exp(-0.1 * times))
    def test_solver(self):
        """Solve dy/dt = 0.1 y with the jax BDF solver twice.

        The first call includes jax compilation, so the repeated call is
        expected to be faster. Both results must match the analytic
        solution exp(0.1 t).
        """
        model = pybamm.BaseModel()
        model.convert_to_format = "jax"
        whole_cell = ["negative electrode", "separator", "positive electrode"]
        state = pybamm.Variable("var", domain=whole_cell)
        model.rhs = {state: 0.1 * state}
        model.initial_conditions = {state: 1}
        # parameters are not needed: the base discretisation suffices because
        # the model has no spatial operators

        discretisation = pybamm.Discretisation(
            get_mesh_for_testing(), {"macroscale": pybamm.FiniteVolume()}
        )
        discretisation.process_model(model)

        times = np.linspace(0.0, 1.0, 80)
        initial_state = model.concatenated_initial_conditions.evaluate().reshape(-1)
        rhs_evaluator = pybamm.EvaluatorJax(model.concatenated_rhs)

        def fun(y, t):
            return rhs_evaluator.evaluate(t=t, y=y).reshape(-1)

        def timed_solve():
            # returns the solution and the wall-clock time of the solve alone
            start = time.perf_counter()
            sol = pybamm.jax_bdf_integrate(
                fun, initial_state, times, rtol=1e-8, atol=1e-8
            )
            return sol, time.perf_counter() - start

        # first run (includes compilation): check accuracy
        solution, first_duration = timed_solve()
        np.testing.assert_allclose(
            solution[:, 0], np.exp(0.1 * times), rtol=1e-6, atol=1e-6
        )

        # second run should be much quicker ...
        solution, second_duration = timed_solve()
        self.assertLess(second_duration, first_duration)

        # ... and just as accurate
        np.testing.assert_allclose(
            solution[:, 0], np.exp(0.1 * times), rtol=1e-6, atol=1e-6
        )
    def test_solver_sensitivities(self):
        """Differentiate the summed solution with respect to the decay rate.

        The gradient from ``jax.grad`` is compared against a central
        finite-difference approximation.
        """
        model = pybamm.BaseModel()
        model.convert_to_format = "jax"
        whole_cell = ["negative electrode", "separator", "positive electrode"]
        state = pybamm.Variable("var", domain=whole_cell)
        model.rhs = {state: -pybamm.InputParameter("rate") * state}
        model.initial_conditions = {state: 1}

        discretisation = pybamm.Discretisation(
            get_mesh_for_testing(xpts=10), {"macroscale": pybamm.FiniteVolume()}
        )
        discretisation.process_model(model)

        times = np.linspace(0, 10, 4)
        initial_state = model.concatenated_initial_conditions.evaluate().reshape(-1)
        rhs_evaluator = pybamm.EvaluatorJax(model.concatenated_rhs)

        def fun(y, t, inputs):
            return rhs_evaluator.evaluate(t=t, y=y, inputs=inputs).reshape(-1)

        step = 0.0001
        rate = 0.1

        # scalar "model output": the sum over the whole solution time series
        @jax.jit
        def solve_bdf(rate):
            solution = pybamm.jax_bdf_integrate(
                fun, initial_state, times, {'rate': rate}, rtol=1e-9, atol=1e-9
            )
            return jax.numpy.sum(solution)

        # reference gradient: central finite difference
        upper = solve_bdf(rate + step)
        lower = solve_bdf(rate - step)
        grad_num = (upper - lower) / (2 * step)

        # gradient from automatic differentiation
        grad_bdf = jax.jit(jax.grad(solve_bdf))(rate)

        self.assertAlmostEqual(grad_bdf, grad_num, places=3)
Beispiel #6
0
 def test_evaluator_jax_debug(self):
     """Smoke test: EvaluatorJax.debug runs on a simple squared state."""
     state = pybamm.StateVector(slice(0, 1))
     squared = state**2
     sample_state = np.array([[2.0], [3.0]])
     pybamm.EvaluatorJax(squared).debug(y=sample_state)
Beispiel #7
0
    def test_evaluator_jax(self):
        """Compare EvaluatorJax results against direct symbolic evaluation.

        Covers elementwise arithmetic, functions, time dependence, dense and
        sparse matrix products, comparisons, min/max, indexing, sparse
        stacking, concatenation and inner products.
        """
        a = pybamm.StateVector(slice(0, 1))
        b = pybamm.StateVector(slice(1, 2))
        y_tests = [np.array([[2], [3]]), np.array([[1], [3]])]
        t_tests = [1, 2]
        timed_pairs = tuple(zip(t_tests, y_tests))

        def jax_vs_direct(expr, pairs=timed_pairs):
            # Yield (jax result, direct result) for every (t, y) in ``pairs``.
            evaluator = pybamm.EvaluatorJax(expr)
            for t, y in pairs:
                yield evaluator.evaluate(t=t, y=y), expr.evaluate(t=t, y=y)

        def assert_all_close(expr, pairs=timed_pairs):
            for got, want in jax_vs_direct(expr, pairs):
                np.testing.assert_allclose(got, want)

        def assert_all_equal(expr):
            for got, want in jax_vs_direct(expr):
                self.assertEqual(got, want)

        # a * b at two fixed states
        evaluator = pybamm.EvaluatorJax(a * b)
        self.assertEqual(evaluator.evaluate(t=None, y=np.array([[2], [3]])), 6)
        self.assertEqual(evaluator.evaluate(t=None, y=np.array([[1], [3]])), 3)

        # function and exponential of a * b, evaluated at y = (2, 3)
        y_23 = np.array([[2], [3]])
        evaluator = pybamm.EvaluatorJax(pybamm.Function(test_function, a * b))
        self.assertEqual(evaluator.evaluate(t=None, y=y_23), 12)
        evaluator = pybamm.EvaluatorJax(pybamm.exp(a * b))
        self.assertEqual(evaluator.evaluate(t=None, y=y_23), np.exp(6))

        # a constant expression needs neither t nor y
        evaluator = pybamm.EvaluatorJax(pybamm.Scalar(2) * pybamm.Scalar(3))
        self.assertEqual(evaluator.evaluate(), 6)

        # a larger expression, evaluated with t=None
        assert_all_close(a * b + b + a**2 / b + 2 * a + b / 2 + 4,
                         pairs=[(None, y) for y in y_tests])

        # explicit time dependence
        assert_all_equal(a * pybamm.t)

        # dense matrix-vector product
        A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
        assert_all_close(A @ pybamm.StateVector(slice(0, 2)))

        # comparisons (heaviside) against a constant vector
        bound = pybamm.Vector(np.array([1, 2]))
        assert_all_close(bound <= pybamm.StateVector(slice(0, 2)))
        assert_all_close(bound > pybamm.StateVector(slice(0, 2)))

        # elementwise minimum / maximum against a constant vector
        bound = pybamm.Vector(np.array([1, 2]))
        assert_all_close(pybamm.minimum(bound, pybamm.StateVector(slice(0, 2))))
        assert_all_close(pybamm.maximum(bound, pybamm.StateVector(slice(0, 2))))

        # index into the dense matrix-vector product
        assert_all_equal(pybamm.Index(A @ pybamm.StateVector(slice(0, 2)), 0))

        # dense @ csr @ coo matrix-vector chain
        dense = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
        csr = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        coo = pybamm.Matrix(scipy.sparse.coo_matrix(np.array([[1, 0], [0, 4]])))
        assert_all_close(dense @ csr @ coo @ pybamm.StateVector(slice(0, 2)))

        # sparse stack: the direct result is sparse, so densify it first
        upper = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        lower = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
        first = pybamm.StateVector(slice(0, 1))
        for got, want in jax_vs_direct(pybamm.SparseStack(upper, first * lower)):
            np.testing.assert_allclose(got, want.toarray())

        # numpy concatenation of constant vectors
        assert_all_close(pybamm.NumpyConcatenation(
            pybamm.Vector(np.array([[1], [2]])), pybamm.Vector(np.array([[3]]))))

        # inner product of two constant vectors
        expr = pybamm.Inner(pybamm.Vector(np.ones(5), domain="test"),
                            pybamm.Vector(2 * np.ones(5), domain="test"))
        evaluator = pybamm.EvaluatorJax(expr)
        np.testing.assert_allclose(evaluator.evaluate(), expr.evaluate())
Beispiel #8
0
        def process(func, name, use_jacobian=None):
            """Convert a model expression into solver-ready callables.

            Parameters
            ----------
            func : pybamm.Symbol
                Expression to convert. Presumably a discretised rhs, algebraic
                or event expression; confirm against the caller.
            name : str
                Label used in log messages and for the created callables.
                ``"residuals"`` selects the ``Residuals`` wrapper.
            use_jacobian : bool, optional
                Whether to build a jacobian callable. Defaults to
                ``model.use_jacobian``.

            Returns
            -------
            func
                The converted function. For every non-casadi format this is
                an ``evaluate`` bound method; for casadi it is a
                ``casadi.Function``.
            func_call
                ``Residuals`` (when ``name == "residuals"``) or
                ``SolverCallable`` wrapping ``func``.
            jac_call
                ``SolverCallable`` wrapping the jacobian, or ``None`` if no
                jacobian was built.
            """
            def report(string):
                # don't log event conversion
                if "event" not in string:
                    pybamm.logger.info(string)

            if use_jacobian is None:
                use_jacobian = model.use_jacobian

            if model.convert_to_format != "casadi":
                # Process with pybamm functions
                if model.use_simplify:
                    report(f"Simplifying {name}")
                    func = simp.simplify(func)

                if model.convert_to_format == "jax":
                    report(f"Converting {name} to jax")
                    jax_func = pybamm.EvaluatorJax(func)

                if not use_jacobian:
                    jac = None
                elif model.convert_to_format == "jax":
                    # The jax evaluator supplies its own jacobian. A symbolic
                    # jacobian would only be computed, simplified and discarded.
                    report(f"Converting jacobian for {name} to jax")
                    jac = jax_func.get_jacobian().evaluate
                else:
                    report(f"Calculating jacobian for {name}")
                    jac = jacobian.jac(func, y)
                    if model.use_simplify:
                        report(f"Simplifying jacobian for {name}")
                        jac = simp.simplify(jac)
                    if model.convert_to_format == "python":
                        report(f"Converting jacobian for {name} to python")
                        jac = pybamm.EvaluatorPython(jac)
                    jac = jac.evaluate

                if model.convert_to_format == "python":
                    report(f"Converting {name} to python")
                    func = pybamm.EvaluatorPython(func)
                elif model.convert_to_format == "jax":
                    # the conversion (and its log message) already happened above
                    func = jax_func

                func = func.evaluate

            else:
                # Process with CasADi
                report(f"Converting {name} to CasADi")
                func = func.to_casadi(t_casadi, y_casadi, inputs=p_casadi)
                if use_jacobian:
                    report(f"Calculating jacobian for {name} using CasADi")
                    jac_casadi = casadi.jacobian(func, y_casadi)
                    jac = casadi.Function(
                        name, [t_casadi, y_casadi, p_casadi_stacked],
                        [jac_casadi])
                else:
                    jac = None
                func = casadi.Function(name,
                                       [t_casadi, y_casadi, p_casadi_stacked],
                                       [func])
            if name == "residuals":
                func_call = Residuals(func, name, model)
            else:
                func_call = SolverCallable(func, name, model)
            if jac is not None:
                jac_call = SolverCallable(jac, name + "_jac", model)
            else:
                jac_call = None
            return func, func_call, jac_call
Beispiel #9
0
 def test_evaluator_jax_inputs(self):
     """An InputParameter is read from the ``inputs`` mapping at evaluation."""
     squared = pybamm.InputParameter('a') ** 2
     value = pybamm.EvaluatorJax(squared).evaluate(inputs={'a': 2})
     self.assertEqual(value, 4)
Beispiel #10
0
    def test_evaluator_jax(self):
        """Check EvaluatorJax results against direct symbolic evaluation.

        Covers elementwise arithmetic, functions, time dependence, dense and
        sparse matrix products, comparisons, min/max, indexing, concatenation
        and inner products. Also checks that the sparse operations the jax
        conversion rejects raise NotImplementedError.
        """
        a = pybamm.StateVector(slice(0, 1))
        b = pybamm.StateVector(slice(1, 2))
        y_tests = [
            np.array([[2.0], [3.0]]),
            np.array([[1.0], [3.0]]),
            np.array([1.0, 3.0]),
        ]
        t_tests = [1.0, 2.0]
        # zip stops at the shorter list, so only the first two states pair with a time
        timed_pairs = tuple(zip(t_tests, y_tests))

        def jax_vs_direct(expr, pairs=timed_pairs):
            # Yield (jax result, direct result) for every (t, y) in ``pairs``.
            evaluator = pybamm.EvaluatorJax(expr)
            for t, y in pairs:
                yield evaluator.evaluate(t=t, y=y), expr.evaluate(t=t, y=y)

        def assert_all_close(expr, pairs=timed_pairs):
            for got, want in jax_vs_direct(expr, pairs):
                np.testing.assert_allclose(got, want)

        def assert_all_equal(expr):
            for got, want in jax_vs_direct(expr):
                self.assertEqual(got, want)

        # a * b at two fixed integer-valued states
        evaluator = pybamm.EvaluatorJax(a * b)
        self.assertEqual(evaluator.evaluate(t=None, y=np.array([[2], [3]])), 6)
        self.assertEqual(evaluator.evaluate(t=None, y=np.array([[1], [3]])), 3)

        # function and exponential of a * b, evaluated at y = (2, 3)
        y_23 = np.array([[2], [3]])
        evaluator = pybamm.EvaluatorJax(pybamm.Function(test_function, a * b))
        self.assertEqual(evaluator.evaluate(t=None, y=y_23), 12)
        evaluator = pybamm.EvaluatorJax(pybamm.exp(a * b))
        self.assertEqual(evaluator.evaluate(t=None, y=y_23), np.exp(6))

        # a constant expression needs neither t nor y
        evaluator = pybamm.EvaluatorJax(pybamm.Scalar(2) * pybamm.Scalar(3))
        self.assertEqual(evaluator.evaluate(), 6)

        # a larger expression over every test state, including the 1-D one
        assert_all_close(a * b + b + a**2 / b + 2 * a + b / 2 + 4,
                         pairs=[(None, y) for y in y_tests])

        # explicit time dependence
        assert_all_equal(a * pybamm.t)

        # dense matrix-vector product
        A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
        assert_all_close(A @ pybamm.StateVector(slice(0, 2)))

        # comparisons (heaviside) against a constant vector
        bound = pybamm.Vector(np.array([1, 2]))
        assert_all_close(bound <= pybamm.StateVector(slice(0, 2)))
        assert_all_close(bound > pybamm.StateVector(slice(0, 2)))

        # elementwise minimum / maximum against a constant vector
        bound = pybamm.Vector(np.array([1, 2]))
        assert_all_close(pybamm.minimum(bound, pybamm.StateVector(slice(0, 2))))
        assert_all_close(pybamm.maximum(bound, pybamm.StateVector(slice(0, 2))))

        # index into the dense matrix-vector product
        assert_all_equal(pybamm.Index(A @ pybamm.StateVector(slice(0, 2)), 0))

        # dense @ csr @ coo matrix-vector chain
        dense = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
        csr = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        coo = pybamm.Matrix(scipy.sparse.coo_matrix(np.array([[1, 0], [0, 4]])))
        assert_all_close(dense @ csr @ coo @ pybamm.StateVector(slice(0, 2)))

        # sparse matrix scaled by time, multiplied from either side
        csr = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        scaled_products = [
            csr * pybamm.t @ pybamm.StateVector(slice(0, 2)),
            pybamm.t * csr @ pybamm.StateVector(slice(0, 2)),
        ]
        for expr in scaled_products:
            assert_all_close(expr)

        # sparse matrix divided by a time-dependent scalar
        csr = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        assert_all_close(csr / (1.0 + pybamm.t) @ pybamm.StateVector(slice(0, 2)))

        # sparse stacks are rejected by the jax evaluator
        upper = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        lower = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
        first = pybamm.StateVector(slice(0, 1))
        expr = pybamm.SparseStack(upper, first * lower)
        with self.assertRaises(NotImplementedError):
            pybamm.EvaluatorJax(expr)

        # so are sparse matrix-matrix products
        left = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        right = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
        first = pybamm.StateVector(slice(0, 1))
        expr = left @ (first * right)
        with self.assertRaises(NotImplementedError):
            pybamm.EvaluatorJax(expr)

        # numpy concatenation of constant vectors
        assert_all_close(pybamm.NumpyConcatenation(
            pybamm.Vector(np.array([[1], [2]])), pybamm.Vector(np.array([[3]]))))

        # inner products involving a 1x1 sparse matrix and a state vector
        unit = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1]])))
        v = pybamm.StateVector(slice(0, 1))
        inner_products = [
            pybamm.Inner(unit, v) @ v,
            pybamm.Inner(v, unit) @ v,
            pybamm.Inner(v, v) @ v,
        ]
        for expr in inner_products:
            assert_all_close(expr)