Beispiel #1
0
    def create_jacobian(self, model):
        """Creates Jacobian of the discretised model.
        Note that the model is assumed to be of the form M*y_dot = f(t,y), where
        M is the (possibly singular) mass matrix. The Jacobian is df/dy.

        Note: At present, calculation of the Jacobian is deferred until after
        simplification, since it is much faster to compute the Jacobian of the
        simplified model. However, in some use cases (e.g. running the same
        model multiple times but with different parameters) it may be more
        efficient to compute the Jacobian once, before simplification, so that
        parameters in the Jacobian can be updated (see PR #670).

        Parameters
        ----------
        model : :class:`pybamm.BaseModel`
            Discretised model. Must have attributes ``rhs`` and ``algebraic``
            (dicts of {variable: equation}) and
            ``concatenated_initial_conditions`` (whose size fixes the length of
            the state vector the equations are differentiated against).

        Returns
        -------
        jac : :class:`pybamm.Symbol`
            Jacobian of the full model. This is a :class:`pybamm.SparseStack` of
            ``jac_rhs`` above ``jac_algebraic`` when the model has both rhs and
            algebraic equations. Otherwise it is whichever part is present
            (``jac_rhs`` when the model has no algebraic equations).
        jac_rhs : :class:`pybamm.Symbol`
            Jacobian of the rhs equations, assembled with
            ``self._concatenate_in_order(..., sparse=True)``.
        jac_algebraic : :class:`pybamm.Symbol`
            Jacobian of the algebraic equations, assembled the same way.
        """
        # create state vector to differentiate with respect to
        y = pybamm.StateVector(
            slice(0, np.size(model.concatenated_initial_conditions)))
        # set up a single Jacobian object, shared by rhs and algebraic, for
        # re-use of its dict
        jacobian = pybamm.Jacobian()

        def jac_by_equation(equations):
            # Differentiate each equation w.r.t. y, then assemble the blocks
            # with _concatenate_in_order as a sparse stack.
            jac_eqn_dict = {}
            for eqn_key, eqn in equations.items():
                pybamm.logger.debug(
                    "Calculating block of Jacobian for {!r}".format(eqn_key.name))
                jac_eqn_dict[eqn_key] = jacobian.jac(eqn, y)
            return self._concatenate_in_order(jac_eqn_dict, sparse=True)

        # calculate Jacobian of rhs, then of algebraic, equation by equation
        jac_rhs = jac_by_equation(model.rhs)
        jac_algebraic = jac_by_equation(model.algebraic)

        # full Jacobian: only stack when both parts are present
        if model.rhs and model.algebraic:
            jac = pybamm.SparseStack(jac_rhs, jac_algebraic)
        elif not model.algebraic:
            jac = jac_rhs
        else:
            jac = jac_algebraic

        return jac, jac_rhs, jac_algebraic
Beispiel #2
0
    def test_symbol_new_copy(self):
        # Every kind of symbol must survive new_copy() with its id unchanged.
        a = pybamm.Parameter("a")
        b = pybamm.Parameter("b")
        v_n = pybamm.Variable("v", "negative electrode")
        x_n = pybamm.standard_spatial_vars.x_n
        v_s = pybamm.Variable("v", "separator")
        vec = pybamm.Vector([1, 2, 3, 4, 5])
        mat = pybamm.Matrix([[1, 2], [3, 4]])
        mesh = get_mesh_for_testing()

        symbols_to_copy = [
            a + b,
            a - b,
            a * b,
            a / b,
            a**b,
            -a,
            abs(a),
            pybamm.Function(np.sin, a),
            pybamm.FunctionParameter("function", {"a": a}),
            pybamm.grad(v_n),
            pybamm.div(pybamm.grad(v_n)),
            pybamm.upwind(v_n),
            pybamm.IndefiniteIntegral(v_n, x_n),
            pybamm.BackwardIndefiniteIntegral(v_n, x_n),
            pybamm.BoundaryValue(v_n, "right"),
            pybamm.BoundaryGradient(v_n, "right"),
            pybamm.PrimaryBroadcast(a, "domain"),
            pybamm.SecondaryBroadcast(v_n, "current collector"),
            pybamm.FullBroadcast(a, "domain", {"secondary": "other domain"}),
            pybamm.concatenation(v_n, v_s),
            pybamm.NumpyConcatenation(a, b, v_s),
            pybamm.DomainConcatenation([v_n, v_s], mesh),
            pybamm.Parameter("param"),
            pybamm.InputParameter("param"),
            pybamm.StateVector(slice(0, 56)),
            pybamm.Matrix(np.ones((50, 40))),
            pybamm.SpatialVariable("x", ["negative electrode"]),
            pybamm.t,
            pybamm.Index(vec, 1),
            pybamm.NotConstant(a),
            pybamm.ExternalVariable(
                "external variable",
                20,
                domain="test",
                auxiliary_domains={"secondary": "test2"},
            ),
            pybamm.minimum(a, b),
            pybamm.maximum(a, b),
            pybamm.SparseStack(mat, mat),
        ]

        for original in symbols_to_copy:
            # read the original id first, then copy, as the assert would
            original_id = original.id
            copied_id = original.new_copy().id
            self.assertEqual(original_id, copied_id)
Beispiel #3
0
    def set_up(self, model):
        """Unpack model, perform checks, simplify and calculate jacobian.

        Builds the callables a DAE solver needs: consistent initial conditions,
        residuals, event functions and (optionally) the Jacobian. The model's
        ``use_simplify``, ``use_jacobian`` and ``use_to_python`` flags control
        whether expressions are simplified, whether a Jacobian is built, and
        whether expressions are wrapped in :class:`pybamm.EvaluatorPython`.

        Parameters
        ----------
        model : :class:`pybamm.BaseModel`
            The model whose solution to calculate. Must have attributes rhs,
            algebraic, concatenated_rhs, concatenated_algebraic,
            concatenated_initial_conditions, events (a dict of
            {name: expression}), mass_matrix and name, plus the flags
            use_simplify, use_jacobian and use_to_python.

        Notes
        -----
        Algebraic equations are supported. When ``model.algebraic`` is
        non-empty, consistent initial conditions are computed with
        ``self.calculate_consistent_initial_conditions``, and any exception it
        raises propagates. Otherwise the model's initial conditions are used
        as-is, i.e. an ODE model solved with a DAE solver.

        When ``model.use_jacobian`` is set, ``model.jacobian`` is set to the
        Jacobian before simplification or python conversion. The method always
        sets the solver attributes ``y0``, ``rhs``, ``algebraic``,
        ``residuals``, ``events``, ``event_funs`` and ``jacobian``.
        """
        # create simplified rhs, algebraic and event expressions
        concatenated_rhs = model.concatenated_rhs
        concatenated_algebraic = model.concatenated_algebraic
        events = model.events

        if model.use_simplify:
            # set up simplification object, for re-use of dict
            simp = pybamm.Simplification()
            pybamm.logger.info("Simplifying RHS")
            concatenated_rhs = simp.simplify(concatenated_rhs)
            pybamm.logger.info("Simplifying algebraic")
            concatenated_algebraic = simp.simplify(concatenated_algebraic)
            pybamm.logger.info("Simplifying events")
            events = {
                name: simp.simplify(event)
                for name, event in events.items()
            }

        if model.use_jacobian:
            # Create Jacobian from simplified rhs
            y = pybamm.StateVector(
                slice(0, np.size(model.concatenated_initial_conditions)))
            pybamm.logger.info("Calculating jacobian")
            jac_rhs = concatenated_rhs.jac(y)
            jac_algebraic = concatenated_algebraic.jac(y)
            jac = pybamm.SparseStack(jac_rhs, jac_algebraic)
            # stored on the model before the simplification / conversion below
            model.jacobian = jac

            if model.use_simplify:
                pybamm.logger.info("Simplifying jacobian")
                jac_algebraic = simp.simplify(jac_algebraic)
                jac = simp.simplify(jac)

            if model.use_to_python:
                pybamm.logger.info("Converting jacobian to python")
                jac_algebraic = pybamm.EvaluatorPython(jac_algebraic)
                jac = pybamm.EvaluatorPython(jac)

            def jac_alg_fn(t, y):
                # Jacobian of the algebraic part only; passed to
                # calculate_consistent_initial_conditions below
                return jac_algebraic.evaluate(t, y)

        else:
            jac = None
            jac_alg_fn = None

        if model.use_to_python:
            pybamm.logger.info("Converting RHS to python")
            concatenated_rhs = pybamm.EvaluatorPython(concatenated_rhs)
            pybamm.logger.info("Converting algebraic to python")
            concatenated_algebraic = pybamm.EvaluatorPython(
                concatenated_algebraic)
            pybamm.logger.info("Converting events to python")
            events = {
                name: pybamm.EvaluatorPython(event)
                for name, event in events.items()
            }

        # Calculate consistent initial conditions for the algebraic equations
        def rhs(t, y):
            return concatenated_rhs.evaluate(t, y, known_evals={})[0][:, 0]

        def algebraic(t, y):
            result = concatenated_algebraic.evaluate(t, y, known_evals={})[0]
            return result[:, 0]

        if model.algebraic:
            y0 = self.calculate_consistent_initial_conditions(
                rhs, algebraic, model.concatenated_initial_conditions[:, 0],
                jac_alg_fn)
        else:
            # can use DAE solver to solve ODE model
            y0 = model.concatenated_initial_conditions[:, 0]

        # Create functions to evaluate residuals
        def residuals(t, y, ydot):
            pybamm.logger.debug("Evaluating residuals for {} at t={}".format(
                model.name, t))
            y = y[:, np.newaxis]
            rhs_eval, known_evals = concatenated_rhs.evaluate(t,
                                                              y,
                                                              known_evals={})
            # reuse known_evals
            alg_eval = concatenated_algebraic.evaluate(
                t, y, known_evals=known_evals)[0]
            # turn into 1D arrays
            rhs_eval = rhs_eval[:, 0]
            alg_eval = alg_eval[:, 0]
            return (np.concatenate(
                (rhs_eval, alg_eval)) - model.mass_matrix.entries @ ydot)

        # Create event-dependent function to evaluate events; the factory binds
        # each event to its own closure (avoids late binding in the list below)
        def event_fun(event):
            def eval_event(t, y):
                return event.evaluate(t, y)

            return eval_event

        event_funs = [event_fun(event) for event in events.values()]

        # Create function to evaluate jacobian
        if jac is not None:

            def jacobian(t, y):
                return jac.evaluate(t, y, known_evals={})[0]

        else:
            jacobian = None

        # Add the solver attributes
        self.y0 = y0
        self.rhs = rhs
        self.algebraic = algebraic
        self.residuals = residuals
        self.events = events
        self.event_funs = event_funs
        self.jacobian = jacobian
Beispiel #4
0
    def test_evaluator_python(self):
        # EvaluatorPython must agree with Symbol.evaluate across a broad set of
        # expression types.
        a = pybamm.StateVector(slice(0, 1))
        b = pybamm.StateVector(slice(1, 2))

        y_tests = [np.array([[2], [3]]), np.array([[1], [3]])]
        t_tests = [1, 2]

        def check_against_symbol(expr, compare):
            # Evaluate expr through EvaluatorPython and through the symbol tree
            # at every (t, y) test point, and hand both results to compare.
            evaluator = pybamm.EvaluatorPython(expr)
            for t, y in zip(t_tests, y_tests):
                compare(evaluator.evaluate(t=t, y=y), expr.evaluate(t=t, y=y))

        # test a * b
        evaluator = pybamm.EvaluatorPython(a * b)
        self.assertEqual(evaluator.evaluate(t=None, y=np.array([[2], [3]])), 6)
        self.assertEqual(evaluator.evaluate(t=None, y=np.array([[1], [3]])), 3)

        # test function(a*b)
        evaluator = pybamm.EvaluatorPython(pybamm.Function(test_function, a * b))
        self.assertEqual(evaluator.evaluate(t=None, y=np.array([[2], [3]])), 12)

        # test a constant expression
        evaluator = pybamm.EvaluatorPython(pybamm.Scalar(2) * pybamm.Scalar(3))
        self.assertEqual(evaluator.evaluate(), 6)

        # test a larger expression (time-independent, so evaluated at t=None)
        expr = a * b + b + a**2 / b + 2 * a + b / 2 + 4
        evaluator = pybamm.EvaluatorPython(expr)
        for y in y_tests:
            self.assertEqual(
                evaluator.evaluate(t=None, y=y), expr.evaluate(t=None, y=y))

        # test something with time
        check_against_symbol(a * pybamm.t, self.assertEqual)

        # test something with a matrix multiplication
        A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
        check_against_symbol(
            A @ pybamm.StateVector(slice(0, 2)), np.testing.assert_allclose)

        # test something with a heaviside
        a = pybamm.Vector(np.array([1, 2]))
        check_against_symbol(
            a <= pybamm.StateVector(slice(0, 2)), np.testing.assert_allclose)
        check_against_symbol(
            a > pybamm.StateVector(slice(0, 2)), np.testing.assert_allclose)

        # test something with a minimum or maximum
        a = pybamm.Vector(np.array([1, 2]))
        check_against_symbol(
            pybamm.minimum(a, pybamm.StateVector(slice(0, 2))),
            np.testing.assert_allclose)
        check_against_symbol(
            pybamm.maximum(a, pybamm.StateVector(slice(0, 2))),
            np.testing.assert_allclose)

        # test something with an index (A is still the dense matrix above)
        check_against_symbol(
            pybamm.Index(A @ pybamm.StateVector(slice(0, 2)), 0),
            self.assertEqual)

        # test something with a sparse matrix multiplication
        A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
        B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        C = pybamm.Matrix(scipy.sparse.coo_matrix(np.array([[1, 0], [0, 4]])))
        check_against_symbol(
            A @ B @ C @ pybamm.StateVector(slice(0, 2)),
            np.testing.assert_allclose)

        # test numpy concatenation
        a = pybamm.Vector(np.array([[1], [2]]))
        b = pybamm.Vector(np.array([[3]]))
        check_against_symbol(
            pybamm.NumpyConcatenation(a, b), np.testing.assert_allclose)

        # test sparse stack (both results are sparse: densify, then compare)
        A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
        check_against_symbol(
            pybamm.SparseStack(A, B),
            lambda got, want: np.testing.assert_allclose(
                got.toarray(), want.toarray()))

        # test Inner
        v = pybamm.Vector(np.ones(5), domain="test")
        w = pybamm.Vector(2 * np.ones(5), domain="test")
        expr = pybamm.Inner(v, w)
        evaluator = pybamm.EvaluatorPython(expr)
        np.testing.assert_allclose(evaluator.evaluate(), expr.evaluate())
Beispiel #5
0
    def test_process_model_dae(self):
        # one rhs equation (for c) and one algebraic equation (for d)
        whole_cell = ["negative electrode", "separator", "positive electrode"]
        c = pybamm.Variable("c", domain=whole_cell)
        d = pybamm.Variable("d", domain=whole_cell)
        N = pybamm.grad(c)
        model = pybamm.BaseModel()
        model.rhs = {c: pybamm.div(N)}
        model.algebraic = {d: d - 2 * c}
        model.initial_conditions = {d: pybamm.Scalar(6), c: pybamm.Scalar(3)}
        model.boundary_conditions = {
            c: {"left": (0, "Neumann"), "right": (0, "Neumann")}
        }
        model.variables = {"c": c, "N": N, "d": d}

        # create discretisation
        disc = get_discretisation_for_testing()
        mesh = disc.mesh

        disc.process_model(model)
        combined_submesh = mesh.combine_submeshes(*whole_cell)
        nodes = combined_submesh.nodes
        n_nodes = np.size(nodes)
        identity = np.eye(n_nodes)
        zero_block = np.zeros((n_nodes, n_nodes))

        # initial conditions: the 3-valued block followed by the 6-valued block
        y0 = model.concatenated_initial_conditions.evaluate()
        expected_y0 = np.concatenate(
            [3 * np.ones_like(nodes), 6 * np.ones_like(nodes)]
        )[:, np.newaxis]
        np.testing.assert_array_equal(y0, expected_y0)

        # grad and div are identity operators here
        np.testing.assert_array_equal(
            y0[: combined_submesh.npts], model.concatenated_rhs.evaluate(None, y0)
        )
        np.testing.assert_array_equal(
            model.concatenated_algebraic.evaluate(None, y0),
            np.zeros_like(nodes[:, np.newaxis]),
        )

        # mass matrix is identity upper left, zeros elsewhere
        expected_mass = block_diag((identity, zero_block))
        np.testing.assert_array_equal(
            expected_mass.toarray(), model.mass_matrix.entries.toarray()
        )

        # jacobian of the concatenated rhs stacked on the concatenated algebraic
        y = pybamm.StateVector(slice(0, np.size(y0)))
        jac_rhs = model.concatenated_rhs.jac(y)
        jac_algebraic = model.concatenated_algebraic.jac(y)
        jacobian = pybamm.SparseStack(jac_rhs, jac_algebraic).evaluate(0, y0)

        jacobian_actual = np.block(
            [[identity, zero_block], [-2 * identity, identity]]
        )
        np.testing.assert_array_equal(jacobian_actual, jacobian.toarray())

        # test jacobian by eqn gives same as jacobian of concatenated rhs & algebraic
        model.jacobian, _, _ = disc.create_jacobian(model)
        model_jacobian = model.jacobian.evaluate(0, y0)
        np.testing.assert_array_equal(model_jacobian.toarray(), jacobian.toarray())

        # test known_evals: a fresh dict, then re-using the returned one
        stacked = pybamm.SparseStack(jac_rhs, jac_algebraic)
        jacobian, known_evals = stacked.evaluate(0, y0, known_evals={})
        np.testing.assert_array_equal(jacobian_actual, jacobian.toarray())
        jacobian = stacked.evaluate(0, y0, known_evals=known_evals)[0]
        np.testing.assert_array_equal(jacobian_actual, jacobian.toarray())

        # check that any time derivatives of variables in algebraic raises an
        # error
        bad_model = pybamm.BaseModel()
        bad_model.rhs = {c: pybamm.div(N)}
        bad_model.algebraic = {d: d - 2 * c.diff(pybamm.t)}
        bad_model.initial_conditions = {d: pybamm.Scalar(6), c: pybamm.Scalar(3)}
        bad_model.boundary_conditions = {
            c: {"left": (0, "Neumann"), "right": (0, "Neumann")}
        }
        bad_model.variables = {"c": c, "N": N, "d": d}

        with self.assertRaises(pybamm.ModelError):
            disc.process_model(bad_model)
Beispiel #6
0
 def concatenate(self, *symbols, sparse=False):
     """Join ``symbols`` into a single symbol.

     A :class:`pybamm.SparseStack` is built when ``sparse`` is true, and a
     :class:`pybamm.NumpyConcatenation` otherwise.
     """
     concatenation_class = (
         pybamm.SparseStack if sparse else pybamm.NumpyConcatenation
     )
     return concatenation_class(*symbols)
Beispiel #7
0
    def test_evaluator_jax(self):
        # Check that EvaluatorJax matches Symbol.evaluate for a range of
        # expression types, and that it raises NotImplementedError for the
        # sparse constructs exercised below that it does not support.
        a = pybamm.StateVector(slice(0, 1))
        b = pybamm.StateVector(slice(1, 2))

        # NOTE(review): y_tests has three entries but t_tests only two, so every
        # zip(t_tests, y_tests) loop below stops after two pairs and never uses
        # the 1-D y (np.array([1.0, 3.0])). Only the "larger expression" loop,
        # which iterates y_tests directly, exercises it. Confirm this is intended.
        y_tests = [
            np.array([[2.0], [3.0]]),
            np.array([[1.0], [3.0]]),
            np.array([1.0, 3.0]),
        ]
        t_tests = [1.0, 2.0]

        # test a * b
        expr = a * b
        evaluator = pybamm.EvaluatorJax(expr)
        result = evaluator.evaluate(t=None, y=np.array([[2], [3]]))
        self.assertEqual(result, 6)
        result = evaluator.evaluate(t=None, y=np.array([[1], [3]]))
        self.assertEqual(result, 3)

        # test function(a*b)
        expr = pybamm.Function(test_function, a * b)
        evaluator = pybamm.EvaluatorJax(expr)
        result = evaluator.evaluate(t=None, y=np.array([[2], [3]]))
        self.assertEqual(result, 12)

        # test exp
        # NOTE(review): exact float equality against np.exp(6); presumably
        # relies on JAX running in 64-bit mode — confirm.
        expr = pybamm.exp(a * b)
        evaluator = pybamm.EvaluatorJax(expr)
        result = evaluator.evaluate(t=None, y=np.array([[2], [3]]))
        self.assertEqual(result, np.exp(6))

        # test a constant expression
        expr = pybamm.Scalar(2) * pybamm.Scalar(3)
        evaluator = pybamm.EvaluatorJax(expr)
        result = evaluator.evaluate()
        self.assertEqual(result, 6)

        # test a larger expression
        expr = a * b + b + a**2 / b + 2 * a + b / 2 + 4
        evaluator = pybamm.EvaluatorJax(expr)
        for y in y_tests:
            result = evaluator.evaluate(t=None, y=y)
            np.testing.assert_allclose(result, expr.evaluate(t=None, y=y))

        # test something with time
        expr = a * pybamm.t
        evaluator = pybamm.EvaluatorJax(expr)
        for t, y in zip(t_tests, y_tests):
            result = evaluator.evaluate(t=t, y=y)
            self.assertEqual(result, expr.evaluate(t=t, y=y))

        # test something with a matrix multiplication
        A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
        expr = A @ pybamm.StateVector(slice(0, 2))
        evaluator = pybamm.EvaluatorJax(expr)
        for t, y in zip(t_tests, y_tests):
            result = evaluator.evaluate(t=t, y=y)
            np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

        # test something with a heaviside
        a = pybamm.Vector(np.array([1, 2]))
        expr = a <= pybamm.StateVector(slice(0, 2))
        evaluator = pybamm.EvaluatorJax(expr)
        for t, y in zip(t_tests, y_tests):
            result = evaluator.evaluate(t=t, y=y)
            np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

        expr = a > pybamm.StateVector(slice(0, 2))
        evaluator = pybamm.EvaluatorJax(expr)
        for t, y in zip(t_tests, y_tests):
            result = evaluator.evaluate(t=t, y=y)
            np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

        # test something with a minimum or maximum
        a = pybamm.Vector(np.array([1, 2]))
        expr = pybamm.minimum(a, pybamm.StateVector(slice(0, 2)))
        evaluator = pybamm.EvaluatorJax(expr)
        for t, y in zip(t_tests, y_tests):
            result = evaluator.evaluate(t=t, y=y)
            np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

        expr = pybamm.maximum(a, pybamm.StateVector(slice(0, 2)))
        evaluator = pybamm.EvaluatorJax(expr)
        for t, y in zip(t_tests, y_tests):
            result = evaluator.evaluate(t=t, y=y)
            np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

        # test something with an index (A is still the dense matrix above)
        expr = pybamm.Index(A @ pybamm.StateVector(slice(0, 2)), 0)
        evaluator = pybamm.EvaluatorJax(expr)
        for t, y in zip(t_tests, y_tests):
            result = evaluator.evaluate(t=t, y=y)
            self.assertEqual(result, expr.evaluate(t=t, y=y))

        # test something with a sparse matrix-vector multiplication
        A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
        B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        C = pybamm.Matrix(scipy.sparse.coo_matrix(np.array([[1, 0], [0, 4]])))
        expr = A @ B @ C @ pybamm.StateVector(slice(0, 2))
        evaluator = pybamm.EvaluatorJax(expr)
        for t, y in zip(t_tests, y_tests):
            result = evaluator.evaluate(t=t, y=y)
            np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

        # test the sparse-scalar multiplication
        A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        for expr in [
                A * pybamm.t @ pybamm.StateVector(slice(0, 2)),
                pybamm.t * A @ pybamm.StateVector(slice(0, 2)),
        ]:
            evaluator = pybamm.EvaluatorJax(expr)
            for t, y in zip(t_tests, y_tests):
                result = evaluator.evaluate(t=t, y=y)
                np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

        # test the sparse-scalar division
        A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        expr = A / (1.0 + pybamm.t) @ pybamm.StateVector(slice(0, 2))
        evaluator = pybamm.EvaluatorJax(expr)
        for t, y in zip(t_tests, y_tests):
            result = evaluator.evaluate(t=t, y=y)
            np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

        # test sparse stack: a SparseStack with a state-dependent block is
        # expected to be rejected with NotImplementedError
        A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
        a = pybamm.StateVector(slice(0, 1))
        expr = pybamm.SparseStack(A, a * B)
        with self.assertRaises(NotImplementedError):
            evaluator = pybamm.EvaluatorJax(expr)

        # test sparse mat-mat mult: sparse matrix @ state-dependent sparse
        # matrix is expected to be rejected with NotImplementedError
        A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
        a = pybamm.StateVector(slice(0, 1))
        expr = A @ (a * B)
        with self.assertRaises(NotImplementedError):
            evaluator = pybamm.EvaluatorJax(expr)

        # test numpy concatenation
        a = pybamm.Vector(np.array([[1], [2]]))
        b = pybamm.Vector(np.array([[3]]))
        expr = pybamm.NumpyConcatenation(a, b)
        evaluator = pybamm.EvaluatorJax(expr)
        for t, y in zip(t_tests, y_tests):
            result = evaluator.evaluate(t=t, y=y)
            np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

        # test Inner
        A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1]])))
        v = pybamm.StateVector(slice(0, 1))
        for expr in [
                pybamm.Inner(A, v) @ v,
                pybamm.Inner(v, A) @ v,
                pybamm.Inner(v, v) @ v
        ]:
            evaluator = pybamm.EvaluatorJax(expr)
            for t, y in zip(t_tests, y_tests):
                result = evaluator.evaluate(t=t, y=y)
                np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))