def test_index(self):
    """Index a constant vector with ``[]`` and check the resulting ``Index`` nodes.

    Covers an integer key, a start/stop slice and a stop-only slice, then the
    two error paths of the ``pybamm.Index`` constructor.
    """
    vec = pybamm.Vector(np.array([1, 2, 3, 4, 5]))

    # integer key
    single = vec[3]
    self.assertIsInstance(single, pybamm.Index)
    self.assertEqual(single.slice, slice(3, 4))
    self.assertEqual(single.evaluate(), 4)

    # start/stop slice
    window = vec[1:3]
    self.assertIsInstance(window, pybamm.Index)
    self.assertEqual(window.slice, slice(1, 3))
    expected_window = np.array([[2], [3]])
    np.testing.assert_array_equal(window.evaluate(), expected_window)

    # stop-only slice
    head = vec[:3]
    self.assertIsInstance(head, pybamm.Index)
    self.assertEqual(head.slice, slice(3))
    expected_head = np.array([[1], [2], [3]])
    np.testing.assert_array_equal(head.evaluate(), expected_head)

    # invalid key type
    with self.assertRaisesRegex(TypeError, "index must be integer or slice"):
        pybamm.Index(vec, 0.0)
    # key beyond the end of the child
    with self.assertRaisesRegex(ValueError, "slice size exceeds child size"):
        pybamm.Index(vec, 5)
def test_index(self):
    """Check the jacobian of ``pybamm.Index`` with respect to a state vector."""
    vec = pybamm.StateVector(slice(0, 5))

    # index of the state vector itself: a single one in column 3
    state_index = pybamm.Index(vec, 3)
    state_jac = state_index.jac(vec)
    dense_jac = state_jac.evaluate(y=np.linspace(0, 2, 5)).toarray()
    np.testing.assert_array_equal(dense_jac, np.array([[0, 0, 0, 1, 0]]))

    # jac of ind of something that isn't a StateVector should return zeros
    const_vec = pybamm.Vector(np.ones(3))
    const_index = pybamm.Index(const_vec, 2)
    const_jac = const_index.jac(vec)
    dense_jac = const_jac.evaluate(y=np.linspace(0, 2, 5)).toarray()
    np.testing.assert_array_equal(dense_jac, np.array([[0, 0, 0, 0, 0]]))
def test_evaluates_on_edges(self):
    """Check ``evaluates_on_edges("primary")`` for several unary operators.

    ``Index``, ``Laplacian``, ``GradientSquared`` and ``BoundaryIntegral`` are
    expected to report ``False``; ``Upwind`` and ``Downwind`` are expected to
    report ``True``.
    """
    a = pybamm.StateVector(slice(0, 10), domain="test")

    # operators built lazily so each one is constructed right before its check
    not_on_edges = (
        lambda child: pybamm.Index(child, slice(1)),
        lambda child: pybamm.Laplacian(child),
        lambda child: pybamm.GradientSquared(child),
        lambda child: pybamm.BoundaryIntegral(child),
    )
    for build in not_on_edges:
        self.assertFalse(build(a).evaluates_on_edges("primary"))

    on_edges = (
        lambda child: pybamm.Upwind(child),
        lambda child: pybamm.Downwind(child),
    )
    for build in on_edges:
        self.assertTrue(build(a).evaluates_on_edges("primary"))
def test_symbol_new_copy(self):
    """Check that ``new_copy()`` preserves the ``id`` of many symbol types."""
    a = pybamm.Parameter("a")
    b = pybamm.Parameter("b")
    v_n = pybamm.Variable("v", "negative electrode")
    x_n = pybamm.standard_spatial_vars.x_n
    v_s = pybamm.Variable("v", "separator")
    vec = pybamm.Vector([1, 2, 3, 4, 5])
    mat = pybamm.Matrix([[1, 2], [3, 4]])
    mesh = get_mesh_for_testing()

    # one representative of every symbol type under test
    symbols = [
        a + b,
        a - b,
        a * b,
        a / b,
        a**b,
        -a,
        abs(a),
        pybamm.Function(np.sin, a),
        pybamm.FunctionParameter("function", {"a": a}),
        pybamm.grad(v_n),
        pybamm.div(pybamm.grad(v_n)),
        pybamm.upwind(v_n),
        pybamm.IndefiniteIntegral(v_n, x_n),
        pybamm.BackwardIndefiniteIntegral(v_n, x_n),
        pybamm.BoundaryValue(v_n, "right"),
        pybamm.BoundaryGradient(v_n, "right"),
        pybamm.PrimaryBroadcast(a, "domain"),
        pybamm.SecondaryBroadcast(v_n, "current collector"),
        pybamm.FullBroadcast(a, "domain", {"secondary": "other domain"}),
        pybamm.concatenation(v_n, v_s),
        pybamm.NumpyConcatenation(a, b, v_s),
        pybamm.DomainConcatenation([v_n, v_s], mesh),
        pybamm.Parameter("param"),
        pybamm.InputParameter("param"),
        pybamm.StateVector(slice(0, 56)),
        pybamm.Matrix(np.ones((50, 40))),
        pybamm.SpatialVariable("x", ["negative electrode"]),
        pybamm.t,
        pybamm.Index(vec, 1),
        pybamm.NotConstant(a),
        pybamm.ExternalVariable(
            "external variable",
            20,
            domain="test",
            auxiliary_domains={"secondary": "test2"},
        ),
        pybamm.minimum(a, b),
        pybamm.maximum(a, b),
        pybamm.SparseStack(mat, mat),
    ]

    for original in symbols:
        original_id = original.id
        self.assertEqual(original_id, original.new_copy().id)
def test_index(self):
    """Check construction, naming and evaluation of ``pybamm.Index``.

    A five-entry state vector is indexed with an integer, with ``-1``, with a
    start/stop slice and with a stop-only slice. The constructor's error paths
    are then checked; the size check is run with ``pybamm.settings.debug_mode``
    switched on, and the previous value of that global setting is always
    restored afterwards.
    """
    vec = pybamm.StateVector(slice(0, 5))
    y_test = np.array([1, 2, 3, 4, 5])

    # with integer
    ind = pybamm.Index(vec, 3)
    self.assertIsInstance(ind, pybamm.Index)
    self.assertEqual(ind.slice, slice(3, 4))
    self.assertEqual(ind.evaluate(y=y_test), 4)

    # with -1
    ind = pybamm.Index(vec, -1)
    self.assertIsInstance(ind, pybamm.Index)
    self.assertEqual(ind.slice, slice(-1, None))
    self.assertEqual(ind.evaluate(y=y_test), 5)
    self.assertEqual(ind.name, "Index[-1]")

    # with slice
    ind = pybamm.Index(vec, slice(1, 3))
    self.assertIsInstance(ind, pybamm.Index)
    self.assertEqual(ind.slice, slice(1, 3))
    np.testing.assert_array_equal(ind.evaluate(y=y_test), np.array([[2], [3]]))

    # with only stop slice
    ind = pybamm.Index(vec, slice(3))
    self.assertIsInstance(ind, pybamm.Index)
    self.assertEqual(ind.slice, slice(3))
    np.testing.assert_array_equal(
        ind.evaluate(y=y_test), np.array([[1], [2], [3]])
    )

    # errors
    with self.assertRaisesRegex(TypeError, "index must be integer or slice"):
        pybamm.Index(vec, 0.0)

    # debug_mode is a process-wide setting: restore it in a ``finally`` so a
    # failing assertion here cannot leak debug mode into other tests
    debug_mode = pybamm.settings.debug_mode
    pybamm.settings.debug_mode = True
    try:
        with self.assertRaisesRegex(ValueError, "slice size exceeds child size"):
            pybamm.Index(vec, 5)
    finally:
        pybamm.settings.debug_mode = debug_mode
def _concatenation_jac(self, children_jacs):
    """
    Stack slices of the children's jacobians into one sparse jacobian.
    See :meth:`pybamm.Concatenation.concatenation_jac()`.

    For every secondary-dimension point, the matching slice of each child
    jacobian is wrapped in a :class:`pybamm.Index`, and all of them are
    combined with ``SparseStack``.

    Raises
    ------
    NotImplementedError
        If a child has slices for more than one domain.
    """
    # note that this assumes that the children are in the right order and only have
    # one domain each
    pieces = []
    for point in range(self.secondary_dimensions_npts):
        for jac, domain_slices in zip(children_jacs, self._children_slices):
            if len(domain_slices) > 1:
                raise NotImplementedError(
                    """jacobian only implemented for when each child has a single domain""")
            # the single domain's list of slices, one per secondary point
            only_domain = next(iter(domain_slices.values()))
            pieces.append(pybamm.Index(jac, only_domain[point]))
    return SparseStack(*pieces)
def test_symbol_new_copy(self):
    """Check that ``new_copy()`` preserves the ``id`` of many symbol types."""
    a = pybamm.Scalar(0)
    b = pybamm.Scalar(1)
    v_n = pybamm.Variable("v", "negative electrode")
    x_n = pybamm.standard_spatial_vars.x_n
    v_s = pybamm.Variable("v", "separator")
    vec = pybamm.Vector(np.array([1, 2, 3, 4, 5]))
    mesh = get_mesh_for_testing()

    # one representative of every symbol type under test
    symbols = [
        a + b,
        a - b,
        a * b,
        a / b,
        a**b,
        -a,
        abs(a),
        pybamm.Function(np.sin, a),
        pybamm.FunctionParameter("function", {"a": a}),
        pybamm.grad(v_n),
        pybamm.div(pybamm.grad(v_n)),
        pybamm.Integral(a, pybamm.t),
        pybamm.IndefiniteIntegral(v_n, x_n),
        pybamm.BackwardIndefiniteIntegral(v_n, x_n),
        pybamm.BoundaryValue(v_n, "right"),
        pybamm.BoundaryGradient(v_n, "right"),
        pybamm.PrimaryBroadcast(a, "domain"),
        pybamm.SecondaryBroadcast(v_n, "current collector"),
        pybamm.FullBroadcast(a, "domain", {"secondary": "other domain"}),
        pybamm.Concatenation(v_n, v_s),
        pybamm.NumpyConcatenation(a, b, v_s),
        pybamm.DomainConcatenation([v_n, v_s], mesh),
        pybamm.Parameter("param"),
        pybamm.InputParameter("param"),
        pybamm.StateVector(slice(0, 56)),
        pybamm.Matrix(np.ones((50, 40))),
        pybamm.SpatialVariable("x", ["negative electrode"]),
        pybamm.t,
        pybamm.Index(vec, 1),
    ]

    for original in symbols:
        original_id = original.id
        self.assertEqual(original_id, original.new_copy().id)
def _get_j_diffusion_limited_first_order(self, variables):
    """
    First-order correction to the interfacial current density due to
    diffusion-limited effects. For a general model the correction term is zero,
    since the reaction is not diffusion-limited

    Parameters
    ----------
    variables : mapping
        Looked up by name; must contain the leading-order x-averaged
        interfacial current density for this domain and reaction and, for the
        "Negative" domain, "Oxygen flux".

    Returns
    -------
    The difference between the flux-based current density and the
    leading-order one, divided by ``param.C_e``, when ``self.order`` is
    "leading"; otherwise ``pybamm.Scalar(0)``.
    """
    if self.order == "leading":
        # key is assembled from the lower-cased domain and the reaction name
        j_leading_order = variables["Leading-order x-averaged "
                                    + self.domain.lower()
                                    + " electrode"
                                    + self.reaction_name
                                    + " interfacial current density"]
        param = self.param
        if self.domain == "Negative":
            # second orphan of the oxygen flux
            # (presumably the separator part - TODO confirm)
            N_ox_s_p = variables["Oxygen flux"].orphans[1]
            # first entry of that flux; the name suggests the
            # negative-electrode/separator interface
            N_ox_neg_sep_interface = pybamm.Index(N_ox_s_p, slice(0, 1))
            j = -N_ox_neg_sep_interface / param.C_e / -param.s_ox_Ox / param.l_n
        # NOTE(review): `j` is only assigned in the "Negative" branch above;
        # confirm this method is never reached with another domain while
        # self.order == "leading"
        return (j - j_leading_order) / param.C_e
    else:
        return pybamm.Scalar(0)
def test_evaluator_python(self):
    """Compare ``pybamm.EvaluatorPython`` with direct evaluation of the symbols.

    Simple products are checked against hand-computed values; all other
    expressions are checked against ``expr.evaluate`` at the same ``t``/``y``.
    """
    a = pybamm.StateVector(slice(0, 1))
    b = pybamm.StateVector(slice(1, 2))

    y_tests = [np.array([[2], [3]]), np.array([[1], [3]])]
    t_tests = [1, 2]

    def check_over_time(expression, compare):
        # build one evaluator and compare it with expression.evaluate for
        # every (t, y) pair
        compiled = pybamm.EvaluatorPython(expression)
        for t, y in zip(t_tests, y_tests):
            outcome = compiled.evaluate(t=t, y=y)
            compare(outcome, expression.evaluate(t=t, y=y))

    # test a * b
    compiled = pybamm.EvaluatorPython(a * b)
    outcome = compiled.evaluate(t=None, y=np.array([[2], [3]]))
    self.assertEqual(outcome, 6)
    outcome = compiled.evaluate(t=None, y=np.array([[1], [3]]))
    self.assertEqual(outcome, 3)

    # test function(a*b)
    compiled = pybamm.EvaluatorPython(pybamm.Function(test_function, a * b))
    outcome = compiled.evaluate(t=None, y=np.array([[2], [3]]))
    self.assertEqual(outcome, 12)

    # test a constant expression
    compiled = pybamm.EvaluatorPython(pybamm.Scalar(2) * pybamm.Scalar(3))
    self.assertEqual(compiled.evaluate(), 6)

    # test a larger expression
    larger = a * b + b + a**2 / b + 2 * a + b / 2 + 4
    compiled = pybamm.EvaluatorPython(larger)
    for y in y_tests:
        outcome = compiled.evaluate(t=None, y=y)
        self.assertEqual(outcome, larger.evaluate(t=None, y=y))

    # test something with time
    check_over_time(a * pybamm.t, self.assertEqual)

    # test something with a matrix multiplication
    A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
    check_over_time(
        A @ pybamm.StateVector(slice(0, 2)), np.testing.assert_allclose
    )

    # test something with a heaviside
    a = pybamm.Vector(np.array([1, 2]))
    check_over_time(
        a <= pybamm.StateVector(slice(0, 2)), np.testing.assert_allclose
    )
    check_over_time(
        a > pybamm.StateVector(slice(0, 2)), np.testing.assert_allclose
    )

    # test something with a minimum or maximum
    a = pybamm.Vector(np.array([1, 2]))
    check_over_time(
        pybamm.minimum(a, pybamm.StateVector(slice(0, 2))),
        np.testing.assert_allclose,
    )
    check_over_time(
        pybamm.maximum(a, pybamm.StateVector(slice(0, 2))),
        np.testing.assert_allclose,
    )

    # test something with an index (A is still the dense matrix from above)
    check_over_time(
        pybamm.Index(A @ pybamm.StateVector(slice(0, 2)), 0), self.assertEqual
    )

    # test something with a sparse matrix multiplication
    A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
    B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    C = pybamm.Matrix(scipy.sparse.coo_matrix(np.array([[1, 0], [0, 4]])))
    check_over_time(
        A @ B @ C @ pybamm.StateVector(slice(0, 2)), np.testing.assert_allclose
    )

    # test numpy concatenation
    a = pybamm.Vector(np.array([[1], [2]]))
    b = pybamm.Vector(np.array([[3]]))
    check_over_time(pybamm.NumpyConcatenation(a, b), np.testing.assert_allclose)

    # test sparse stack (results are densified before comparing)
    A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
    stacked = pybamm.SparseStack(A, B)
    compiled = pybamm.EvaluatorPython(stacked)
    for t, y in zip(t_tests, y_tests):
        outcome = compiled.evaluate(t=t, y=y).toarray()
        np.testing.assert_allclose(outcome, stacked.evaluate(t=t, y=y).toarray())

    # test Inner
    v = pybamm.Vector(np.ones(5), domain="test")
    w = pybamm.Vector(2 * np.ones(5), domain="test")
    inner = pybamm.Inner(v, w)
    compiled = pybamm.EvaluatorPython(inner)
    outcome = compiled.evaluate()
    np.testing.assert_allclose(outcome, inner.evaluate())
def __getitem__(self, key):
    """
    Index this symbol with ``key``.

    A :class:`Index` of this symbol is built and then passed through
    ``pybamm.simplify_if_constant`` with ``keep_domains=True``.
    """
    indexed = pybamm.Index(self, key)
    return pybamm.simplify_if_constant(indexed, keep_domains=True)
def __getitem__(self, key):
    """Return a :class:`Index` object selecting ``key`` from this symbol."""
    selection = pybamm.Index(self, key)
    return selection
def _process_symbol(self, symbol):
    """
    Discretise a single symbol by dispatching on its type.
    See :meth:`Discretisation.process_symbol()`.

    Children are discretised recursively through ``self.process_symbol`` and
    the result is built either by the spatial method registered for the
    relevant domain or by a ``*_new_copy`` method of the symbol.

    Parameters
    ----------
    symbol
        The symbol to discretise.

    Returns
    -------
    The discretised symbol.

    Raises
    ------
    pybamm.ModelError
        If ``symbol`` is a (non-external) variable with no entry in
        ``self.y_slices``.
    NotImplementedError
        If no branch applies and ``symbol.new_copy()`` is not implemented.
    """
    if symbol.domain != []:
        # the spatial method is chosen from the first domain of the symbol
        spatial_method = self.spatial_methods[symbol.domain[0]]
        # If boundary conditions are provided, need to check for BCs on tabs
        if self.bcs:
            # only the first key of self.bcs is inspected here
            key_id = list(self.bcs.keys())[0]
            if any("tab" in side for side in list(self.bcs[key_id].keys())):
                self.bcs[key_id] = self.check_tab_conditions(
                    symbol, self.bcs[key_id])
    # NOTE(review): `spatial_method` is only bound when symbol.domain is
    # non-empty; the branches below that use it presumably are only reached
    # for symbols with a domain - confirm
    if isinstance(symbol, pybamm.BinaryOperator):
        # Pre-process children
        left, right = symbol.children
        disc_left = self.process_symbol(left)
        disc_right = self.process_symbol(right)
        if symbol.domain == []:
            return symbol._binary_new_copy(disc_left, disc_right)
        else:
            return spatial_method.process_binary_operators(
                symbol, left, right, disc_left, disc_right)
    elif isinstance(symbol, pybamm.UnaryOperator):
        child = symbol.child
        disc_child = self.process_symbol(child)
        # NOTE(review): `child_spatial_method` is likewise only bound when
        # the child has a domain
        if child.domain != []:
            child_spatial_method = self.spatial_methods[child.domain[0]]
        if isinstance(symbol, pybamm.Gradient):
            return child_spatial_method.gradient(child, disc_child, self.bcs)
        elif isinstance(symbol, pybamm.Divergence):
            return child_spatial_method.divergence(child, disc_child, self.bcs)
        elif isinstance(symbol, pybamm.Laplacian):
            return child_spatial_method.laplacian(child, disc_child, self.bcs)
        elif isinstance(symbol, pybamm.Gradient_Squared):
            return child_spatial_method.gradient_squared(
                child, disc_child, self.bcs)
        elif isinstance(symbol, pybamm.Mass):
            return child_spatial_method.mass_matrix(child, self.bcs)
        elif isinstance(symbol, pybamm.BoundaryMass):
            return child_spatial_method.boundary_mass_matrix(
                child, self.bcs)
        elif isinstance(symbol, pybamm.IndefiniteIntegral):
            return child_spatial_method.indefinite_integral(
                child, disc_child, "forward")
        elif isinstance(symbol, pybamm.BackwardIndefiniteIntegral):
            return child_spatial_method.indefinite_integral(
                child, disc_child, "backward")
        elif isinstance(symbol, pybamm.Integral):
            # the spatial method comes from the domain of the first
            # integration variable, not from the child
            integral_spatial_method = self.spatial_methods[
                symbol.integration_variable[0].domain[0]]
            out = integral_spatial_method.integral(
                child, disc_child, symbol._integration_dimension)
            out.copy_domains(symbol)
            return out
        elif isinstance(symbol, pybamm.DefiniteIntegralVector):
            return child_spatial_method.definite_integral_matrix(
                child, vector_type=symbol.vector_type)
        elif isinstance(symbol, pybamm.BoundaryIntegral):
            return child_spatial_method.boundary_integral(
                child, disc_child, symbol.region)
        elif isinstance(symbol, pybamm.Broadcast):
            # Broadcast new_child to the domain specified by symbol.domain
            # Different discretisations may broadcast differently
            if symbol.domain == []:
                symbol = disc_child * pybamm.Vector([1])
            else:
                symbol = spatial_method.broadcast(
                    disc_child,
                    symbol.domain,
                    symbol.auxiliary_domains,
                    symbol.broadcast_type,
                )
            return symbol
        elif isinstance(symbol, pybamm.DeltaFunction):
            return spatial_method.delta_function(symbol, disc_child)
        elif isinstance(symbol, pybamm.BoundaryOperator):
            # if boundary operator applied on "negative tab" or
            # "positive tab" *and* the mesh is 1D then change side to
            # "left" or "right" as appropriate
            if symbol.side in ["negative tab", "positive tab"]:
                mesh = self.mesh[symbol.children[0].domain[0]]
                if isinstance(mesh, pybamm.SubMesh1D):
                    # note: this mutates the incoming symbol in place
                    symbol.side = mesh.tabs[symbol.side]
            return child_spatial_method.boundary_value_or_flux(
                symbol, disc_child, self.bcs)
        elif isinstance(symbol, pybamm.UpwindDownwind):
            direction = symbol.name  # upwind or downwind
            return spatial_method.upwind_or_downwind(
                child, disc_child, self.bcs, direction)
        else:
            return symbol._unary_new_copy(disc_child)
    elif isinstance(symbol, pybamm.Function):
        disc_children = [
            self.process_symbol(child) for child in symbol.children
        ]
        return symbol._function_new_copy(disc_children)
    elif isinstance(symbol, pybamm.VariableDot):
        # uses the y_slices stored for the underlying variable
        return pybamm.StateVectorDot(
            *self.y_slices[symbol.get_variable().id],
            domain=symbol.domain,
            auxiliary_domains=symbol.auxiliary_domains,
        )
    elif isinstance(symbol, pybamm.Variable):
        # Check if variable is a standard variable or an external variable
        if any(symbol.id == var.id for var in self.external_variables.values()):
            # Look up dictionary key based on value
            idx = [x.id for x in self.external_variables.values()
                   ].index(symbol.id)
            # keys of external_variables are (name, parent_and_slice) pairs
            name, parent_and_slice = list(
                self.external_variables.keys())[idx]
            if parent_and_slice is None:
                # Variable didn't come from a concatenation so we can just create a
                # normal external variable using the symbol's name
                return pybamm.ExternalVariable(
                    symbol.name,
                    size=self._get_variable_size(symbol),
                    domain=symbol.domain,
                    auxiliary_domains=symbol.auxiliary_domains,
                )
            else:
                # We have to use a special name since the concatenation doesn't have
                # a very informative name. Needs improving
                parent, start, end = parent_and_slice
                ext = pybamm.ExternalVariable(
                    name,
                    size=self._get_variable_size(parent),
                    domain=parent.domain,
                    auxiliary_domains=parent.auxiliary_domains,
                )
                # select this variable's part of the parent external variable
                out = pybamm.Index(ext, slice(start, end))
                out.domain = symbol.domain
                return out
        else:
            # add a try except block for a more informative error if a variable
            # can't be found. This should usually be caught earlier by
            # model.check_well_posedness, but won't be if debug_mode is False
            try:
                y_slices = self.y_slices[symbol.id]
            except KeyError:
                raise pybamm.ModelError(""" No key set for variable '{}'. Make sure it is included in either model.rhs, model.algebraic, or model.external_variables in an unmodified form (e.g. 
not Broadcasted) """.format(symbol.name))
            return pybamm.StateVector(
                *y_slices,
                domain=symbol.domain,
                auxiliary_domains=symbol.auxiliary_domains,
            )
    elif isinstance(symbol, pybamm.SpatialVariable):
        return spatial_method.spatial_variable(symbol)
    elif isinstance(symbol, pybamm.Concatenation):
        new_children = [
            self.process_symbol(child) for child in symbol.children
        ]
        new_symbol = spatial_method.concatenation(new_children)
        return new_symbol
    elif isinstance(symbol, pybamm.InputParameter):
        # Return a new copy of the input parameter, but set the expected size
        # according to the domain of the input parameter
        expected_size = self._get_variable_size(symbol)
        new_input_parameter = symbol.new_copy()
        new_input_parameter.set_expected_size(expected_size)
        return new_input_parameter
    else:
        # Backup option: return new copy of the object
        try:
            return symbol.new_copy()
        except NotImplementedError:
            raise NotImplementedError(
                "Cannot discretise symbol of type '{}'".format(
                    type(symbol)))
def test_evaluator_jax(self):
    """Compare ``pybamm.EvaluatorJax`` with direct evaluation of the symbols.

    Simple products are checked against hand-computed values; other
    expressions are checked against ``expr.evaluate`` at the same ``t``/``y``.
    Two sparse constructions are expected to raise ``NotImplementedError``
    when the evaluator is built.
    """
    a = pybamm.StateVector(slice(0, 1))
    b = pybamm.StateVector(slice(1, 2))

    y_tests = [
        np.array([[2.0], [3.0]]),
        np.array([[1.0], [3.0]]),
        np.array([1.0, 3.0]),
    ]
    t_tests = [1.0, 2.0]

    def check_over_time(expression, compare):
        # build one evaluator and compare it with expression.evaluate for
        # every (t, y) pair
        compiled = pybamm.EvaluatorJax(expression)
        for t, y in zip(t_tests, y_tests):
            outcome = compiled.evaluate(t=t, y=y)
            compare(outcome, expression.evaluate(t=t, y=y))

    # test a * b
    compiled = pybamm.EvaluatorJax(a * b)
    outcome = compiled.evaluate(t=None, y=np.array([[2], [3]]))
    self.assertEqual(outcome, 6)
    outcome = compiled.evaluate(t=None, y=np.array([[1], [3]]))
    self.assertEqual(outcome, 3)

    # test function(a*b)
    compiled = pybamm.EvaluatorJax(pybamm.Function(test_function, a * b))
    outcome = compiled.evaluate(t=None, y=np.array([[2], [3]]))
    self.assertEqual(outcome, 12)

    # test exp
    compiled = pybamm.EvaluatorJax(pybamm.exp(a * b))
    outcome = compiled.evaluate(t=None, y=np.array([[2], [3]]))
    self.assertEqual(outcome, np.exp(6))

    # test a constant expression
    compiled = pybamm.EvaluatorJax(pybamm.Scalar(2) * pybamm.Scalar(3))
    self.assertEqual(compiled.evaluate(), 6)

    # test a larger expression
    larger = a * b + b + a**2 / b + 2 * a + b / 2 + 4
    compiled = pybamm.EvaluatorJax(larger)
    for y in y_tests:
        outcome = compiled.evaluate(t=None, y=y)
        np.testing.assert_allclose(outcome, larger.evaluate(t=None, y=y))

    # test something with time
    check_over_time(a * pybamm.t, self.assertEqual)

    # test something with a matrix multiplication
    A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
    check_over_time(
        A @ pybamm.StateVector(slice(0, 2)), np.testing.assert_allclose
    )

    # test something with a heaviside
    a = pybamm.Vector(np.array([1, 2]))
    check_over_time(
        a <= pybamm.StateVector(slice(0, 2)), np.testing.assert_allclose
    )
    check_over_time(
        a > pybamm.StateVector(slice(0, 2)), np.testing.assert_allclose
    )

    # test something with a minimum or maximum
    a = pybamm.Vector(np.array([1, 2]))
    check_over_time(
        pybamm.minimum(a, pybamm.StateVector(slice(0, 2))),
        np.testing.assert_allclose,
    )
    check_over_time(
        pybamm.maximum(a, pybamm.StateVector(slice(0, 2))),
        np.testing.assert_allclose,
    )

    # test something with an index (A is still the dense matrix from above)
    check_over_time(
        pybamm.Index(A @ pybamm.StateVector(slice(0, 2)), 0), self.assertEqual
    )

    # test something with a sparse matrix-vector multiplication
    A = pybamm.Matrix(np.array([[1, 2], [3, 4]]))
    B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    C = pybamm.Matrix(scipy.sparse.coo_matrix(np.array([[1, 0], [0, 4]])))
    check_over_time(
        A @ B @ C @ pybamm.StateVector(slice(0, 2)), np.testing.assert_allclose
    )

    # test the sparse-scalar multiplication
    A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    scaled = [
        A * pybamm.t @ pybamm.StateVector(slice(0, 2)),
        pybamm.t * A @ pybamm.StateVector(slice(0, 2)),
    ]
    for expression in scaled:
        check_over_time(expression, np.testing.assert_allclose)

    # test the sparse-scalar division
    A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    check_over_time(
        A / (1.0 + pybamm.t) @ pybamm.StateVector(slice(0, 2)),
        np.testing.assert_allclose,
    )

    # test sparse stack
    A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
    a = pybamm.StateVector(slice(0, 1))
    stacked = pybamm.SparseStack(A, a * B)
    with self.assertRaises(NotImplementedError):
        pybamm.EvaluatorJax(stacked)

    # test sparse mat-mat mult
    A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1, 0], [0, 4]])))
    B = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[2, 0], [5, 0]])))
    a = pybamm.StateVector(slice(0, 1))
    product = A @ (a * B)
    with self.assertRaises(NotImplementedError):
        pybamm.EvaluatorJax(product)

    # test numpy concatenation
    a = pybamm.Vector(np.array([[1], [2]]))
    b = pybamm.Vector(np.array([[3]]))
    check_over_time(pybamm.NumpyConcatenation(a, b), np.testing.assert_allclose)

    # test Inner
    A = pybamm.Matrix(scipy.sparse.csr_matrix(np.array([[1]])))
    v = pybamm.StateVector(slice(0, 1))
    inner_products = [
        pybamm.Inner(A, v) @ v,
        pybamm.Inner(v, A) @ v,
        pybamm.Inner(v, v) @ v,
    ]
    for expression in inner_products:
        check_over_time(expression, np.testing.assert_allclose)