def test_find_symbols(self):
    """Test that ``pybamm.find_symbols`` records each node of an expression tree.

    ``find_symbols`` fills two ordered dicts keyed by symbol id. Constant
    data (matrix entries, the Python callable wrapped by a ``Function``) goes
    into ``constant_symbols``. The Python source string that computes every
    other node goes into ``variable_symbols``, with children recorded before
    their parents. Each section builds a small expression and pins the exact
    entries produced.
    """

    def _find(expr):
        # Run find_symbols with fresh dicts so sections cannot leak into
        # each other.
        constant_symbols = OrderedDict()
        variable_symbols = OrderedDict()
        pybamm.find_symbols(expr, constant_symbols, variable_symbols)
        return constant_symbols, variable_symbols

    a = pybamm.StateVector(slice(0, 1))
    b = pybamm.StateVector(slice(1, 2))
    var_a = pybamm.id_to_python_variable(a.id)
    var_b = pybamm.id_to_python_variable(b.id)
    # Expected generated code that reads each state vector out of y
    a_str = "y[:1][[True]]"
    b_str = "y[:2][[False, True]]"

    # test a + b
    expr = a + b
    constant_symbols, variable_symbols = _find(expr)
    self.assertEqual(len(constant_symbols), 0)
    # test keys of variable_symbols
    self.assertEqual(list(variable_symbols.keys()), [a.id, b.id, expr.id])
    # test values of variable_symbols
    self.assertEqual(
        list(variable_symbols.values()),
        [a_str, b_str, "{} + {}".format(var_a, var_b)],
    )

    # test identical subtree: b appears twice but gets a single entry
    expr = a + b + b
    a_plus_b = expr.children[0]
    constant_symbols, variable_symbols = _find(expr)
    self.assertEqual(len(constant_symbols), 0)
    self.assertEqual(
        list(variable_symbols.keys()), [a.id, b.id, a_plus_b.id, expr.id]
    )
    var_child = pybamm.id_to_python_variable(a_plus_b.id)
    self.assertEqual(
        list(variable_symbols.values()),
        [
            a_str,
            b_str,
            "{} + {}".format(var_a, var_b),
            "{} + {}".format(var_child, var_b),
        ],
    )

    # test unary op
    expr = a + (-b)
    negative_b = expr.children[1]
    constant_symbols, variable_symbols = _find(expr)
    self.assertEqual(len(constant_symbols), 0)
    self.assertEqual(
        list(variable_symbols.keys()), [a.id, b.id, negative_b.id, expr.id]
    )
    var_child = pybamm.id_to_python_variable(negative_b.id)
    self.assertEqual(
        list(variable_symbols.values()),
        [
            a_str,
            b_str,
            "-{}".format(var_b),
            "{} + {}".format(var_a, var_child),
        ],
    )

    # test function: the wrapped callable is stored as a constant under the
    # Function node's id, and the variable entry calls it
    expr = pybamm.Function(test_function, a)
    constant_symbols, variable_symbols = _find(expr)
    self.assertEqual(list(constant_symbols.keys())[0], expr.id)
    self.assertEqual(list(constant_symbols.values())[0], test_function)
    self.assertEqual(list(variable_symbols.keys()), [a.id, expr.id])
    var_funct = pybamm.id_to_python_variable(expr.id, True)
    self.assertEqual(
        list(variable_symbols.values()),
        [a_str, "{}({})".format(var_funct, var_a)],
    )

    # test matrix
    dense = np.array([[1, 2], [3, 4]])
    A = pybamm.Matrix(dense)
    constant_symbols, variable_symbols = _find(A)
    self.assertEqual(len(variable_symbols), 0)
    self.assertEqual(list(constant_symbols.keys())[0], A.id)
    np.testing.assert_allclose(list(constant_symbols.values())[0], dense)

    # test sparse matrix: compare against the literal input rather than
    # A.entries, which is the object under test
    dense = np.array([[0, 2], [0, 4]])
    A = pybamm.Matrix(scipy.sparse.csr_matrix(dense))
    constant_symbols, variable_symbols = _find(A)
    self.assertEqual(len(variable_symbols), 0)
    self.assertEqual(list(constant_symbols.keys())[0], A.id)
    np.testing.assert_allclose(
        list(constant_symbols.values())[0].toarray(), dense
    )

    # test numpy concatenate
    # (DomainConcatenation needs a mesh, so it is covered separately in
    # test_domain_concatenation)
    expr = pybamm.NumpyConcatenation(a, b)
    constant_symbols, variable_symbols = _find(expr)
    self.assertEqual(len(constant_symbols), 0)
    self.assertEqual(list(variable_symbols.keys()), [a.id, b.id, expr.id])
    self.assertEqual(
        list(variable_symbols.values())[2],
        "np.concatenate(({},{}))".format(var_a, var_b),
    )

    # test that a generic Concatenation throws. The node is built outside
    # assertRaises so only find_symbols itself can satisfy the check.
    expr = pybamm.Concatenation(a, b)
    with self.assertRaises(NotImplementedError):
        _find(expr)

    # test that Variable and Parameter nodes also throw
    for expr in (pybamm.Variable("a"), pybamm.Parameter("a")):
        with self.assertRaises(NotImplementedError):
            _find(expr)
def test_domain_concatenation(self):
    """Test code generation for ``pybamm.DomainConcatenation``.

    Checks the source string that ``find_symbols`` emits for the node, where
    children are re-sliced into the order the mesh requires. Also checks that
    ``EvaluatorPython`` agrees with ``expr.evaluate`` in three cases:
    two children passed in reverse order, a single child, and a child
    covering two domains that must be split around another child.
    """
    disc = get_discretisation_for_testing()
    mesh = disc.mesh

    def _find(expr):
        # Run find_symbols with fresh dicts and return them.
        constant_symbols = OrderedDict()
        variable_symbols = OrderedDict()
        pybamm.find_symbols(expr, constant_symbols, variable_symbols)
        return constant_symbols, variable_symbols

    def _check_evaluator(expr, y):
        # The generated Python code must match direct tree evaluation.
        evaluator = pybamm.EvaluatorPython(expr)
        np.testing.assert_allclose(evaluator.evaluate(y=y), expr.evaluate(y=y))

    a_dom = ["negative electrode"]
    b_dom = ["positive electrode"]
    a_pts = mesh[a_dom[0]].npts
    b_pts = mesh[b_dom[0]].npts
    a = pybamm.StateVector(slice(0, a_pts), domain=a_dom)
    b = pybamm.StateVector(slice(a_pts, a_pts + b_pts), domain=b_dom)
    # Column vector whose entries equal their row index, so any misordering
    # of the concatenated pieces shows up in the comparison.
    y = np.arange(a_pts + b_pts, dtype=float).reshape(-1, 1)

    # concatenate them the "wrong" way round to check they get reordered
    # correctly
    expr = pybamm.DomainConcatenation([b, a], mesh)
    constant_symbols, variable_symbols = _find(expr)
    # children are recorded in the order they were passed in, then the node
    self.assertEqual(list(variable_symbols.keys()), [b.id, a.id, expr.id])
    var_a = pybamm.id_to_python_variable(a.id)
    var_b = pybamm.id_to_python_variable(b.id)
    self.assertEqual(len(constant_symbols), 0)
    self.assertEqual(
        list(variable_symbols.values())[2],
        "np.concatenate(({}[0:{}],{}[0:{}]))".format(var_a, a_pts, var_b, b_pts),
    )
    _check_evaluator(expr, y)

    # check that concatenating a single domain is consistent
    _check_evaluator(pybamm.DomainConcatenation([a], mesh), y)

    # check the reordering in case a child vector has to be split up
    a_dom = ["separator"]
    b_dom = ["negative electrode", "positive electrode"]
    b0_pts = mesh[b_dom[0]].npts
    a0_pts = mesh[a_dom[0]].npts
    b1_pts = mesh[b_dom[1]].npts
    a = pybamm.StateVector(slice(0, a0_pts), domain=a_dom)
    b = pybamm.StateVector(slice(a0_pts, a0_pts + b0_pts + b1_pts), domain=b_dom)
    y = np.arange(a0_pts + b0_pts + b1_pts, dtype=float).reshape(-1, 1)
    var_a = pybamm.id_to_python_variable(a.id)
    var_b = pybamm.id_to_python_variable(b.id)
    expr = pybamm.DomainConcatenation([a, b], mesh)
    constant_symbols, variable_symbols = _find(expr)
    # expected order: b's negative-electrode part, then a, then b's
    # positive-electrode part
    b0_str = "{}[0:{}]".format(var_b, b0_pts)
    a0_str = "{}[0:{}]".format(var_a, a0_pts)
    b1_str = "{}[{}:{}]".format(var_b, b0_pts, b0_pts + b1_pts)
    self.assertEqual(len(constant_symbols), 0)
    self.assertEqual(
        list(variable_symbols.values())[2],
        "np.concatenate(({},{},{}))".format(b0_str, a0_str, b1_str),
    )
    _check_evaluator(expr, y)