Esempio n. 1
0
    def test_index(self):
        """Indexing a Vector with ints and slices yields pybamm.Index nodes."""
        vector = pybamm.Vector(np.array([1, 2, 3, 4, 5]))

        # integer key -> one-element Index
        single = vector[3]
        self.assertIsInstance(single, pybamm.Index)
        self.assertEqual(single.slice, slice(3, 4))
        self.assertEqual(single.evaluate(), 4)

        # start:stop slice key
        middle = vector[1:3]
        self.assertIsInstance(middle, pybamm.Index)
        self.assertEqual(middle.slice, slice(1, 3))
        expected_middle = np.array([[2], [3]])
        np.testing.assert_array_equal(middle.evaluate(), expected_middle)

        # stop-only slice key
        head = vector[:3]
        self.assertIsInstance(head, pybamm.Index)
        self.assertEqual(head.slice, slice(3))
        expected_head = np.array([[1], [2], [3]])
        np.testing.assert_array_equal(head.evaluate(), expected_head)

        # invalid keys
        with self.assertRaisesRegex(TypeError, "index must be integer or slice"):
            pybamm.Index(vector, 0.0)
        with self.assertRaisesRegex(ValueError, "slice size exceeds child size"):
            pybamm.Index(vector, 5)
Esempio n. 2
0
    def test_index(self):
        """Jacobian of an Index node w.r.t. a StateVector."""
        state = pybamm.StateVector(slice(0, 5))

        # Index of the StateVector itself: a single 1 at the indexed column
        indexed_state = pybamm.Index(state, 3)
        jac_state = indexed_state.jac(state).evaluate(
            y=np.linspace(0, 2, 5)).toarray()
        np.testing.assert_array_equal(jac_state, np.array([[0, 0, 0, 1, 0]]))

        # jac of ind of something that isn't a StateVector should return zeros
        constant = pybamm.Vector(np.ones(3))
        indexed_constant = pybamm.Index(constant, 2)
        jac_constant = indexed_constant.jac(state).evaluate(
            y=np.linspace(0, 2, 5)).toarray()
        np.testing.assert_array_equal(jac_constant, np.zeros((1, 5)))
Esempio n. 3
0
 def test_evaluates_on_edges(self):
     """Check which unary operators report evaluating on primary edges."""
     child = pybamm.StateVector(slice(0, 10), domain="test")

     # these operators are expected NOT to evaluate on edges
     for node_symbol in [
         pybamm.Index(child, slice(1)),
         pybamm.Laplacian(child),
         pybamm.GradientSquared(child),
         pybamm.BoundaryIntegral(child),
     ]:
         self.assertFalse(node_symbol.evaluates_on_edges("primary"))

     # upwind / downwind are expected to evaluate on edges
     for edge_symbol in [pybamm.Upwind(child), pybamm.Downwind(child)]:
         self.assertTrue(edge_symbol.evaluates_on_edges("primary"))
Esempio n. 4
0
    def test_symbol_new_copy(self):
        """new_copy() must preserve the id of every kind of symbol listed."""
        param_a = pybamm.Parameter("a")
        param_b = pybamm.Parameter("b")
        var_n = pybamm.Variable("v", "negative electrode")
        x_n = pybamm.standard_spatial_vars.x_n
        var_s = pybamm.Variable("v", "separator")
        vector = pybamm.Vector([1, 2, 3, 4, 5])
        matrix = pybamm.Matrix([[1, 2], [3, 4]])
        mesh = get_mesh_for_testing()

        candidates = [
            param_a + param_b,
            param_a - param_b,
            param_a * param_b,
            param_a / param_b,
            param_a**param_b,
            -param_a,
            abs(param_a),
            pybamm.Function(np.sin, param_a),
            pybamm.FunctionParameter("function", {"a": param_a}),
            pybamm.grad(var_n),
            pybamm.div(pybamm.grad(var_n)),
            pybamm.upwind(var_n),
            pybamm.IndefiniteIntegral(var_n, x_n),
            pybamm.BackwardIndefiniteIntegral(var_n, x_n),
            pybamm.BoundaryValue(var_n, "right"),
            pybamm.BoundaryGradient(var_n, "right"),
            pybamm.PrimaryBroadcast(param_a, "domain"),
            pybamm.SecondaryBroadcast(var_n, "current collector"),
            pybamm.FullBroadcast(param_a, "domain",
                                 {"secondary": "other domain"}),
            pybamm.concatenation(var_n, var_s),
            pybamm.NumpyConcatenation(param_a, param_b, var_s),
            pybamm.DomainConcatenation([var_n, var_s], mesh),
            pybamm.Parameter("param"),
            pybamm.InputParameter("param"),
            pybamm.StateVector(slice(0, 56)),
            pybamm.Matrix(np.ones((50, 40))),
            pybamm.SpatialVariable("x", ["negative electrode"]),
            pybamm.t,
            pybamm.Index(vector, 1),
            pybamm.NotConstant(param_a),
            pybamm.ExternalVariable(
                "external variable",
                20,
                domain="test",
                auxiliary_domains={"secondary": "test2"},
            ),
            pybamm.minimum(param_a, param_b),
            pybamm.maximum(param_a, param_b),
            pybamm.SparseStack(matrix, matrix),
        ]
        for candidate in candidates:
            self.assertEqual(candidate.id, candidate.new_copy().id)
Esempio n. 5
0
    def test_index(self):
        """Index of a StateVector with ints, -1 and slices, plus error cases.

        The global ``pybamm.settings.debug_mode`` flag is switched on for the
        size check. It is restored in a ``finally`` block so that a failing
        assertion cannot leak debug mode into later tests.
        """
        vec = pybamm.StateVector(slice(0, 5))
        y_test = np.array([1, 2, 3, 4, 5])
        # with integer
        ind = pybamm.Index(vec, 3)
        self.assertIsInstance(ind, pybamm.Index)
        self.assertEqual(ind.slice, slice(3, 4))
        self.assertEqual(ind.evaluate(y=y_test), 4)
        # with -1
        ind = pybamm.Index(vec, -1)
        self.assertIsInstance(ind, pybamm.Index)
        self.assertEqual(ind.slice, slice(-1, None))
        self.assertEqual(ind.evaluate(y=y_test), 5)
        self.assertEqual(ind.name, "Index[-1]")
        # with slice
        ind = pybamm.Index(vec, slice(1, 3))
        self.assertIsInstance(ind, pybamm.Index)
        self.assertEqual(ind.slice, slice(1, 3))
        np.testing.assert_array_equal(ind.evaluate(y=y_test), np.array([[2], [3]]))
        # with only stop slice
        ind = pybamm.Index(vec, slice(3))
        self.assertIsInstance(ind, pybamm.Index)
        self.assertEqual(ind.slice, slice(3))
        np.testing.assert_array_equal(ind.evaluate(y=y_test), np.array([[1], [2], [3]]))

        # errors
        with self.assertRaisesRegex(TypeError, "index must be integer or slice"):
            pybamm.Index(vec, 0.0)
        # The size check is exercised with debug_mode on (presumably it is only
        # performed in debug mode). Always restore the previous global value,
        # even if the assertion below fails.
        debug_mode = pybamm.settings.debug_mode
        pybamm.settings.debug_mode = True
        try:
            with self.assertRaisesRegex(ValueError, "slice size exceeds child size"):
                pybamm.Index(vec, 5)
        finally:
            pybamm.settings.debug_mode = debug_mode
Esempio n. 6
0
 def _concatenation_jac(self, children_jacs):
     """
     See :meth:`pybamm.Concatenation.concatenation_jac()`.

     For each of the ``self.secondary_dimensions_npts`` secondary points, the
     jacobian of every child is indexed with that child's slice for the point.
     All pieces are collected point-major (outer loop over points, inner loop
     over children) and combined with ``SparseStack``.
     """
     # note that this assumes that the children are in the right order and only have
     # one domain each
     pieces = []
     n_points = self.secondary_dimensions_npts
     for point in range(n_points):
         for jac_of_child, domain_slices in zip(children_jacs,
                                                self._children_slices):
             if len(domain_slices) > 1:
                 raise NotImplementedError(
                     """jacobian only implemented for when each child has
                     a single domain""")
             # the single domain's list of slices, one per secondary point
             slices_for_child = next(iter(domain_slices.values()))
             pieces.append(pybamm.Index(jac_of_child, slices_for_child[point]))
     return SparseStack(*pieces)
Esempio n. 7
0
    def test_symbol_new_copy(self):
        """Every symbol in the list keeps its id when copied with new_copy()."""
        zero = pybamm.Scalar(0)
        one = pybamm.Scalar(1)
        var_n = pybamm.Variable("v", "negative electrode")
        x_n = pybamm.standard_spatial_vars.x_n
        var_s = pybamm.Variable("v", "separator")
        vector = pybamm.Vector(np.array([1, 2, 3, 4, 5]))
        mesh = get_mesh_for_testing()

        candidates = [
            zero + one,
            zero - one,
            zero * one,
            zero / one,
            zero**one,
            -zero,
            abs(zero),
            pybamm.Function(np.sin, zero),
            pybamm.FunctionParameter("function", {"a": zero}),
            pybamm.grad(var_n),
            pybamm.div(pybamm.grad(var_n)),
            pybamm.Integral(zero, pybamm.t),
            pybamm.IndefiniteIntegral(var_n, x_n),
            pybamm.BackwardIndefiniteIntegral(var_n, x_n),
            pybamm.BoundaryValue(var_n, "right"),
            pybamm.BoundaryGradient(var_n, "right"),
            pybamm.PrimaryBroadcast(zero, "domain"),
            pybamm.SecondaryBroadcast(var_n, "current collector"),
            pybamm.FullBroadcast(zero, "domain",
                                 {"secondary": "other domain"}),
            pybamm.Concatenation(var_n, var_s),
            pybamm.NumpyConcatenation(zero, one, var_s),
            pybamm.DomainConcatenation([var_n, var_s], mesh),
            pybamm.Parameter("param"),
            pybamm.InputParameter("param"),
            pybamm.StateVector(slice(0, 56)),
            pybamm.Matrix(np.ones((50, 40))),
            pybamm.SpatialVariable("x", ["negative electrode"]),
            pybamm.t,
            pybamm.Index(vector, 1),
        ]
        for candidate in candidates:
            self.assertEqual(candidate.id, candidate.new_copy().id)
Esempio n. 8
0
    def _get_j_diffusion_limited_first_order(self, variables):
        """
        First-order correction to the interfacial current density due to
        diffusion-limited effects. For a general model the correction term is zero,
        since the reaction is not diffusion-limited.

        Parameters
        ----------
        variables : dict
            Model variables. When ``self.order == "leading"`` this must contain
            the leading-order x-averaged interfacial current density for this
            domain and reaction, and ``"Oxygen flux"``.

        Returns
        -------
        :class:`pybamm.Symbol`
            ``(j - j_leading_order) / param.C_e`` when ``self.order`` is
            ``"leading"``; ``pybamm.Scalar(0)`` otherwise.

        Raises
        ------
        NotImplementedError
            If ``self.order == "leading"`` and ``self.domain`` is not
            ``"Negative"``. Only the negative-electrode correction is defined;
            previously this case failed with an opaque ``UnboundLocalError``.
        """
        if self.order != "leading":
            return pybamm.Scalar(0)

        j_leading_order = variables["Leading-order x-averaged " +
                                    self.domain.lower() + " electrode" +
                                    self.reaction_name +
                                    " interfacial current density"]
        if self.domain != "Negative":
            # Without this guard `j` below would never be bound.
            raise NotImplementedError(
                "First-order diffusion-limited correction is only implemented "
                "for the negative electrode, not '{}'".format(self.domain))

        param = self.param
        # Second orphan of the oxygen flux; its first entry is used as the flux
        # at the negative electrode / separator interface (presumably the
        # separator-positive part of the flux -- confirm against the oxygen
        # diffusion submodel).
        N_ox_s_p = variables["Oxygen flux"].orphans[1]
        N_ox_neg_sep_interface = pybamm.Index(N_ox_s_p, slice(0, 1))

        j = -N_ox_neg_sep_interface / param.C_e / -param.s_ox_Ox / param.l_n

        return (j - j_leading_order) / param.C_e
Esempio n. 9
0
    def test_evaluator_python(self):
        """EvaluatorPython must agree with direct symbolic evaluation."""
        a = pybamm.StateVector(slice(0, 1))
        b = pybamm.StateVector(slice(1, 2))

        y_tests = [np.array([[2], [3]]), np.array([[1], [3]])]
        t_tests = [1, 2]

        def check_pairs(expr, compare):
            # Compile `expr` once, then compare the compiled result with the
            # symbolic one for every (t, y) fixture pair.
            compiled = pybamm.EvaluatorPython(expr)
            for t, y in zip(t_tests, y_tests):
                compare(compiled.evaluate(t=t, y=y), expr.evaluate(t=t, y=y))

        # test a * b
        product = pybamm.EvaluatorPython(a * b)
        self.assertEqual(product.evaluate(t=None, y=np.array([[2], [3]])), 6)
        self.assertEqual(product.evaluate(t=None, y=np.array([[1], [3]])), 3)

        # test function(a*b)
        wrapped = pybamm.EvaluatorPython(pybamm.Function(test_function, a * b))
        self.assertEqual(wrapped.evaluate(t=None, y=np.array([[2], [3]])), 12)

        # test a constant expression
        constant = pybamm.EvaluatorPython(pybamm.Scalar(2) * pybamm.Scalar(3))
        self.assertEqual(constant.evaluate(), 6)

        # test a larger expression (t=None for every y fixture)
        big_expr = a * b + b + a**2 / b + 2 * a + b / 2 + 4
        big_compiled = pybamm.EvaluatorPython(big_expr)
        for y in y_tests:
            self.assertEqual(big_compiled.evaluate(t=None, y=y),
                             big_expr.evaluate(t=None, y=y))

        # test something with time
        check_pairs(a * pybamm.t, self.assertEqual)

        # test something with a matrix multiplication
        dense = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
        check_pairs(dense @ pybamm.StateVector(slice(0, 2)),
                    np.testing.assert_allclose)

        # test something with a heaviside
        threshold = pybamm.Vector(np.array([1, 2]))
        check_pairs(threshold <= pybamm.StateVector(slice(0, 2)),
                    np.testing.assert_allclose)
        check_pairs(threshold > pybamm.StateVector(slice(0, 2)),
                    np.testing.assert_allclose)

        # test something with a minimum or maximum
        bound = pybamm.Vector(np.array([1, 2]))
        check_pairs(pybamm.minimum(bound, pybamm.StateVector(slice(0, 2))),
                    np.testing.assert_allclose)
        check_pairs(pybamm.maximum(bound, pybamm.StateVector(slice(0, 2))),
                    np.testing.assert_allclose)

        # test something with an index
        check_pairs(pybamm.Index(dense @ pybamm.StateVector(slice(0, 2)), 0),
                    self.assertEqual)

        # test something with a sparse matrix multiplication
        dense = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
        csr = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        coo = pybamm.Matrix(scipy.sparse.coo_matrix(np.array([[1, 0], [0, 4]])))
        check_pairs(dense @ csr @ coo @ pybamm.StateVector(slice(0, 2)),
                    np.testing.assert_allclose)

        # test numpy concatenation
        column = pybamm.Vector(np.array([[1], [2]]))
        tail = pybamm.Vector(np.array([[3]]))
        check_pairs(pybamm.NumpyConcatenation(column, tail),
                    np.testing.assert_allclose)

        # test sparse stack (compare dense representations)
        top = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        bottom = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
        check_pairs(
            pybamm.SparseStack(top, bottom),
            lambda got, want: np.testing.assert_allclose(got.toarray(),
                                                         want.toarray()),
        )

        # test Inner
        ones = pybamm.Vector(np.ones(5), domain="test")
        twos = pybamm.Vector(2 * np.ones(5), domain="test")
        inner = pybamm.Inner(ones, twos)
        inner_compiled = pybamm.EvaluatorPython(inner)
        np.testing.assert_allclose(inner_compiled.evaluate(), inner.evaluate())
Esempio n. 10
0
 def __getitem__(self, key):
     """
     Index or slice this symbol by ``key``.

     Wraps ``self`` in a :class:`Index` and passes it through
     :func:`pybamm.simplify_if_constant` with ``keep_domains=True``, so a
     constant result is (presumably) returned in simplified form instead of as
     an :class:`Index`.
     """
     indexed = pybamm.Index(self, key)
     return pybamm.simplify_if_constant(indexed, keep_domains=True)
Esempio n. 11
0
 def __getitem__(self, key):
     """Return a :class:`Index` of this symbol selected by ``key``."""
     indexed = pybamm.Index(self, key)
     return indexed
Esempio n. 12
0
    def _process_symbol(self, symbol):
        """ See :meth:`Discretisation.process_symbol()`.

        Discretise ``symbol`` by dispatching on its type:

        * binary / unary operators: children are discretised first
          (via ``self.process_symbol``), then the operator is rebuilt or handed
          to the relevant spatial method;
        * functions: children are discretised and the function is re-created;
        * ``VariableDot`` / ``Variable``: replaced by ``StateVectorDot`` /
          ``StateVector`` using ``self.y_slices``, or by an
          ``ExternalVariable`` if the variable is listed in
          ``self.external_variables``;
        * spatial variables, concatenations and input parameters are handled
          by dedicated branches;
        * anything else falls back to ``symbol.new_copy()``.

        Raises
        ------
        :class:`pybamm.ModelError`
            If a non-external variable has no entry in ``self.y_slices``.
        NotImplementedError
            If the fallback ``symbol.new_copy()`` is not implemented for the
            symbol's type.
        """

        # NOTE(review): `spatial_method` is only bound when the symbol has a
        # domain; branches below that use it assume a non-empty domain.
        if symbol.domain != []:
            spatial_method = self.spatial_methods[symbol.domain[0]]
            # If boundary conditions are provided, need to check for BCs on tabs
            # NOTE(review): only the first key of self.bcs is inspected here --
            # confirm this is intended.
            if self.bcs:
                key_id = list(self.bcs.keys())[0]
                if any("tab" in side
                       for side in list(self.bcs[key_id].keys())):
                    self.bcs[key_id] = self.check_tab_conditions(
                        symbol, self.bcs[key_id])

        if isinstance(symbol, pybamm.BinaryOperator):
            # Pre-process children
            left, right = symbol.children
            disc_left = self.process_symbol(left)
            disc_right = self.process_symbol(right)
            # domain-less operators are rebuilt directly; otherwise the spatial
            # method decides how to combine the discretised children
            if symbol.domain == []:
                return symbol._binary_new_copy(disc_left, disc_right)
            else:
                return spatial_method.process_binary_operators(
                    symbol, left, right, disc_left, disc_right)
        elif isinstance(symbol, pybamm.UnaryOperator):
            child = symbol.child
            disc_child = self.process_symbol(child)
            # NOTE(review): `child_spatial_method` is only bound when the child
            # has a domain; the spatial-operator branches below assume it does.
            if child.domain != []:
                child_spatial_method = self.spatial_methods[child.domain[0]]

            if isinstance(symbol, pybamm.Gradient):
                return child_spatial_method.gradient(child, disc_child,
                                                     self.bcs)

            elif isinstance(symbol, pybamm.Divergence):
                return child_spatial_method.divergence(child, disc_child,
                                                       self.bcs)

            elif isinstance(symbol, pybamm.Laplacian):
                return child_spatial_method.laplacian(child, disc_child,
                                                      self.bcs)

            elif isinstance(symbol, pybamm.Gradient_Squared):
                return child_spatial_method.gradient_squared(
                    child, disc_child, self.bcs)

            elif isinstance(symbol, pybamm.Mass):
                return child_spatial_method.mass_matrix(child, self.bcs)

            elif isinstance(symbol, pybamm.BoundaryMass):
                return child_spatial_method.boundary_mass_matrix(
                    child, self.bcs)

            elif isinstance(symbol, pybamm.IndefiniteIntegral):
                return child_spatial_method.indefinite_integral(
                    child, disc_child, "forward")
            elif isinstance(symbol, pybamm.BackwardIndefiniteIntegral):
                return child_spatial_method.indefinite_integral(
                    child, disc_child, "backward")

            elif isinstance(symbol, pybamm.Integral):
                # the integral is discretised by the spatial method of the
                # integration variable's domain, not the child's
                integral_spatial_method = self.spatial_methods[
                    symbol.integration_variable[0].domain[0]]
                out = integral_spatial_method.integral(
                    child, disc_child, symbol._integration_dimension)
                out.copy_domains(symbol)
                return out

            elif isinstance(symbol, pybamm.DefiniteIntegralVector):
                return child_spatial_method.definite_integral_matrix(
                    child, vector_type=symbol.vector_type)

            elif isinstance(symbol, pybamm.BoundaryIntegral):
                return child_spatial_method.boundary_integral(
                    child, disc_child, symbol.region)

            elif isinstance(symbol, pybamm.Broadcast):
                # Broadcast new_child to the domain specified by symbol.domain
                # Different discretisations may broadcast differently
                # (a domain-less broadcast is a multiplication by a
                # length-1 vector)
                if symbol.domain == []:
                    symbol = disc_child * pybamm.Vector([1])
                else:
                    symbol = spatial_method.broadcast(
                        disc_child,
                        symbol.domain,
                        symbol.auxiliary_domains,
                        symbol.broadcast_type,
                    )
                return symbol

            elif isinstance(symbol, pybamm.DeltaFunction):
                return spatial_method.delta_function(symbol, disc_child)

            elif isinstance(symbol, pybamm.BoundaryOperator):
                # if boundary operator applied on "negative tab" or
                # "positive tab" *and* the mesh is 1D then change side to
                # "left" or "right" as appropriate
                # NOTE: this assigns to `symbol.side`, i.e. it mutates the
                # symbol that was passed in.
                if symbol.side in ["negative tab", "positive tab"]:
                    mesh = self.mesh[symbol.children[0].domain[0]]
                    if isinstance(mesh, pybamm.SubMesh1D):
                        symbol.side = mesh.tabs[symbol.side]
                return child_spatial_method.boundary_value_or_flux(
                    symbol, disc_child, self.bcs)
            elif isinstance(symbol, pybamm.UpwindDownwind):
                direction = symbol.name  # upwind or downwind
                return spatial_method.upwind_or_downwind(
                    child, disc_child, self.bcs, direction)
            else:
                # any other unary operator: rebuild it around the
                # discretised child
                return symbol._unary_new_copy(disc_child)

        elif isinstance(symbol, pybamm.Function):
            disc_children = [
                self.process_symbol(child) for child in symbol.children
            ]
            return symbol._function_new_copy(disc_children)

        elif isinstance(symbol, pybamm.VariableDot):
            # time derivative of a variable -> StateVectorDot over the same
            # slices as the underlying variable
            return pybamm.StateVectorDot(
                *self.y_slices[symbol.get_variable().id],
                domain=symbol.domain,
                auxiliary_domains=symbol.auxiliary_domains,
            )

        elif isinstance(symbol, pybamm.Variable):
            # Check if variable is a standard variable or an external variable
            if any(symbol.id == var.id
                   for var in self.external_variables.values()):
                # Look up dictionary key based on value
                # (keys are (name, parent_and_slice) pairs)
                idx = [x.id for x in self.external_variables.values()
                       ].index(symbol.id)
                name, parent_and_slice = list(
                    self.external_variables.keys())[idx]
                if parent_and_slice is None:
                    # Variable didn't come from a concatenation so we can just create a
                    # normal external variable using the symbol's name
                    return pybamm.ExternalVariable(
                        symbol.name,
                        size=self._get_variable_size(symbol),
                        domain=symbol.domain,
                        auxiliary_domains=symbol.auxiliary_domains,
                    )
                else:
                    # We have to use a special name since the concatenation doesn't have
                    # a very informative name. Needs improving
                    # The external variable spans the whole parent; index out
                    # the [start, end) part belonging to this variable and
                    # give it the variable's own domain.
                    parent, start, end = parent_and_slice
                    ext = pybamm.ExternalVariable(
                        name,
                        size=self._get_variable_size(parent),
                        domain=parent.domain,
                        auxiliary_domains=parent.auxiliary_domains,
                    )
                    out = pybamm.Index(ext, slice(start, end))
                    out.domain = symbol.domain
                    return out

            else:
                # add a try except block for a more informative error if a variable
                # can't be found. This should usually be caught earlier by
                # model.check_well_posedness, but won't be if debug_mode is False
                try:
                    y_slices = self.y_slices[symbol.id]
                except KeyError:
                    raise pybamm.ModelError("""
                        No key set for variable '{}'. Make sure it is included in either
                        model.rhs, model.algebraic, or model.external_variables in an
                        unmodified form (e.g. not Broadcasted)
                        """.format(symbol.name))
                return pybamm.StateVector(
                    *y_slices,
                    domain=symbol.domain,
                    auxiliary_domains=symbol.auxiliary_domains,
                )

        elif isinstance(symbol, pybamm.SpatialVariable):
            return spatial_method.spatial_variable(symbol)

        elif isinstance(symbol, pybamm.Concatenation):
            new_children = [
                self.process_symbol(child) for child in symbol.children
            ]
            new_symbol = spatial_method.concatenation(new_children)

            return new_symbol

        elif isinstance(symbol, pybamm.InputParameter):
            # Return a new copy of the input parameter, but set the expected size
            # according to the domain of the input parameter
            expected_size = self._get_variable_size(symbol)
            new_input_parameter = symbol.new_copy()
            new_input_parameter.set_expected_size(expected_size)
            return new_input_parameter

        else:
            # Backup option: return new copy of the object
            try:
                return symbol.new_copy()
            except NotImplementedError:
                raise NotImplementedError(
                    "Cannot discretise symbol of type '{}'".format(
                        type(symbol)))
Esempio n. 13
0
    def test_evaluator_jax(self):
        """EvaluatorJax must agree with direct symbolic evaluation, and reject
        sparse constructs it does not support."""
        a = pybamm.StateVector(slice(0, 1))
        b = pybamm.StateVector(slice(1, 2))

        y_tests = [
            np.array([[2.0], [3.0]]),
            np.array([[1.0], [3.0]]),
            np.array([1.0, 3.0]),
        ]
        t_tests = [1.0, 2.0]

        def check_pairs(expr, compare):
            # Compile `expr` once with jax, then compare against the symbolic
            # evaluation for each (t, y) pair (zip stops after the t fixtures).
            compiled = pybamm.EvaluatorJax(expr)
            for t, y in zip(t_tests, y_tests):
                compare(compiled.evaluate(t=t, y=y), expr.evaluate(t=t, y=y))

        allclose = np.testing.assert_allclose

        # test a * b
        product = pybamm.EvaluatorJax(a * b)
        self.assertEqual(product.evaluate(t=None, y=np.array([[2], [3]])), 6)
        self.assertEqual(product.evaluate(t=None, y=np.array([[1], [3]])), 3)

        # test function(a*b)
        wrapped = pybamm.EvaluatorJax(pybamm.Function(test_function, a * b))
        self.assertEqual(wrapped.evaluate(t=None, y=np.array([[2], [3]])), 12)

        # test exp
        exponential = pybamm.EvaluatorJax(pybamm.exp(a * b))
        self.assertEqual(exponential.evaluate(t=None, y=np.array([[2], [3]])),
                         np.exp(6))

        # test a constant expression
        constant = pybamm.EvaluatorJax(pybamm.Scalar(2) * pybamm.Scalar(3))
        self.assertEqual(constant.evaluate(), 6)

        # test a larger expression (t=None for all three y fixtures)
        big_expr = a * b + b + a**2 / b + 2 * a + b / 2 + 4
        big_compiled = pybamm.EvaluatorJax(big_expr)
        for y in y_tests:
            allclose(big_compiled.evaluate(t=None, y=y),
                     big_expr.evaluate(t=None, y=y))

        # test something with time
        check_pairs(a * pybamm.t, self.assertEqual)

        # test something with a matrix multiplication
        dense = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
        check_pairs(dense @ pybamm.StateVector(slice(0, 2)), allclose)

        # test something with a heaviside
        threshold = pybamm.Vector(np.array([1, 2]))
        check_pairs(threshold <= pybamm.StateVector(slice(0, 2)), allclose)
        check_pairs(threshold > pybamm.StateVector(slice(0, 2)), allclose)

        # test something with a minimum or maximum
        bound = pybamm.Vector(np.array([1, 2]))
        check_pairs(pybamm.minimum(bound, pybamm.StateVector(slice(0, 2))),
                    allclose)
        check_pairs(pybamm.maximum(bound, pybamm.StateVector(slice(0, 2))),
                    allclose)

        # test something with an index
        check_pairs(pybamm.Index(dense @ pybamm.StateVector(slice(0, 2)), 0),
                    self.assertEqual)

        # test something with a sparse matrix-vector multiplication
        dense = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
        csr = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        coo = pybamm.Matrix(scipy.sparse.coo_matrix(np.array([[1, 0], [0, 4]])))
        check_pairs(dense @ csr @ coo @ pybamm.StateVector(slice(0, 2)), allclose)

        # test the sparse-scalar multiplication
        sparse_diag = pybamm.Matrix(
            scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        for scaled in [
                sparse_diag * pybamm.t @ pybamm.StateVector(slice(0, 2)),
                pybamm.t * sparse_diag @ pybamm.StateVector(slice(0, 2)),
        ]:
            check_pairs(scaled, allclose)

        # test the sparse-scalar division
        sparse_diag = pybamm.Matrix(
            scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        check_pairs(
            sparse_diag / (1.0 + pybamm.t) @ pybamm.StateVector(slice(0, 2)),
            allclose)

        # test sparse stack (unsupported by the jax evaluator)
        top = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        bottom = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
        first_state = pybamm.StateVector(slice(0, 1))
        with self.assertRaises(NotImplementedError):
            pybamm.EvaluatorJax(pybamm.SparseStack(top, first_state * bottom))

        # test sparse mat-mat mult (unsupported by the jax evaluator)
        left = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
        right = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
        first_state = pybamm.StateVector(slice(0, 1))
        with self.assertRaises(NotImplementedError):
            pybamm.EvaluatorJax(left @ (first_state * right))

        # test numpy concatenation
        column = pybamm.Vector(np.array([[1], [2]]))
        tail = pybamm.Vector(np.array([[3]]))
        check_pairs(pybamm.NumpyConcatenation(column, tail), allclose)

        # test Inner
        unit = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1]])))
        state = pybamm.StateVector(slice(0, 1))
        for inner_expr in [
                pybamm.Inner(unit, state) @ state,
                pybamm.Inner(state, unit) @ state,
                pybamm.Inner(state, state) @ state
        ]:
            check_pairs(inner_expr, allclose)