def test_addition(self):
    """Addition keeps its operands as ordered children; scalar sums fold."""
    lhs = pybamm.Symbol("a")
    rhs = pybamm.Symbol("b")
    addition = pybamm.Addition(lhs, rhs)
    # Child i of the node must be the i-th operand passed to the constructor
    for index, operand in enumerate((lhs, rhs)):
        self.assertEqual(addition.children[index].name, operand.name)

    # test simplifying: ``+`` on two scalars should produce the folded scalar
    folded = pybamm.Scalar(1) + pybamm.Scalar(3)
    expected = pybamm.Scalar(4)
    self.assertEqual(folded.id, expected.id)
def test_convert_scalar_symbols(self):
    """Check that scalar pybamm expressions convert to the expected casadi MX.

    NOTE(review): another ``test_convert_scalar_symbols`` is defined nearby; if
    both end up in the same TestCase class, only the later definition runs —
    confirm.
    """
    a = pybamm.Scalar(0)
    b = pybamm.Scalar(1)
    c = pybamm.Scalar(-1)
    d = pybamm.Scalar(2)
    e = pybamm.Scalar(3)
    g = pybamm.Scalar(3.3)

    self.assertEqual(a.to_casadi(), casadi.MX(0))
    self.assertEqual(d.to_casadi(), casadi.MX(2))

    # negate
    self.assertEqual((-b).to_casadi(), casadi.MX(-1))
    # absolute value
    self.assertEqual(abs(c).to_casadi(), casadi.MX(1))
    # floor
    self.assertEqual(pybamm.Floor(g).to_casadi(), casadi.MX(3))
    # ceiling
    self.assertEqual(pybamm.Ceiling(g).to_casadi(), casadi.MX(4))

    # function
    def square_plus_one(x):
        return x**2 + 1

    f = pybamm.Function(square_plus_one, b)
    # Compare against casadi.MX like every other assertion in this test,
    # rather than against a bare int.
    self.assertEqual(f.to_casadi(), casadi.MX(2))

    def myfunction(x, y):
        return x + y

    f = pybamm.Function(myfunction, b, d)
    self.assertEqual(f.to_casadi(), casadi.MX(3))

    # use classes to avoid simplification
    # addition
    self.assertEqual((pybamm.Addition(a, b)).to_casadi(), casadi.MX(1))
    # subtraction
    self.assertEqual(pybamm.Subtraction(c, d).to_casadi(), casadi.MX(-3))
    # multiplication
    self.assertEqual(pybamm.Multiplication(c, d).to_casadi(), casadi.MX(-2))
    # power
    self.assertEqual(pybamm.Power(c, d).to_casadi(), casadi.MX(1))
    # division
    self.assertEqual(pybamm.Division(b, d).to_casadi(), casadi.MX(1 / 2))
    # modulo
    self.assertEqual(pybamm.Modulo(e, d).to_casadi(), casadi.MX(1))
    # minimum and maximum
    self.assertEqual(pybamm.Minimum(a, b).to_casadi(), casadi.MX(0))
    self.assertEqual(pybamm.Maximum(a, b).to_casadi(), casadi.MX(1))
def test_convert_scalar_symbols(self):
    """Each scalar expression should convert to the matching casadi MX value."""
    zero = pybamm.Scalar(0)
    one = pybamm.Scalar(1)
    minus_one = pybamm.Scalar(-1)
    two = pybamm.Scalar(2)

    def check(expression, expected):
        # Convert the pybamm expression and compare with the MX constant
        self.assertEqual(expression.to_casadi(), casadi.MX(expected))

    check(zero, 0)
    check(two, 2)
    check(-one, -1)  # negate
    check(abs(minus_one), 1)  # absolute value

    # functions of one and of two arguments
    def sin(x):
        return np.sin(x)

    check(pybamm.Function(sin, one), np.sin(1))

    def myfunction(x, y):
        return x + y

    check(pybamm.Function(myfunction, one, two), 3)

    # Binary operators are built from their classes directly so pybamm does not
    # simplify them before conversion.
    check(pybamm.Addition(zero, one), 1)
    check(pybamm.Subtraction(minus_one, two), -3)
    check(pybamm.Multiplication(minus_one, two), -2)
    check(pybamm.Power(minus_one, two), 1)
    check(pybamm.Division(one, two), 1 / 2)
    check(pybamm.Minimum(zero, one), 0)
    check(pybamm.Maximum(zero, one), 1)
def test_to_equation(self):
    """Check ``to_equation`` for several binary operators.

    Overriding ``print_name`` on :class:`pybamm.Addition` changes the class
    itself, not one instance. The override is therefore undone in a ``finally``
    block so it cannot leak into other tests.
    """
    # Test print_name
    missing = object()
    # Only look at Addition's own namespace, so that restoring never turns an
    # inherited attribute into one defined directly on Addition.
    saved_print_name = pybamm.Addition.__dict__.get("print_name", missing)
    pybamm.Addition.print_name = "test"
    try:
        self.assertEqual(
            pybamm.Addition(1, 2).to_equation(), sympy.symbols("test")
        )
    finally:
        if saved_print_name is missing:
            # Addition did not define it itself: drop the override so lookup
            # falls back to whatever Addition inherits.
            del pybamm.Addition.print_name
        else:
            pybamm.Addition.print_name = saved_print_name

    # Test Power
    self.assertEqual(pybamm.Power(7, 2).to_equation(), 49)
    # Test Division
    self.assertEqual(pybamm.Division(10, 2).to_equation(), 5)
    # Test Matrix Multiplication
    arr1 = pybamm.Array([[1, 0], [0, 1]])
    arr2 = pybamm.Array([[4, 1], [2, 2]])
    self.assertEqual(
        pybamm.MatrixMultiplication(arr1, arr2).to_equation(),
        sympy.Matrix([[4.0, 1.0], [2.0, 2.0]]),
    )
    # Test EqualHeaviside
    self.assertEqual(pybamm.EqualHeaviside(1, 0).to_equation(), False)
    # Test NotEqualHeaviside
    self.assertEqual(pybamm.NotEqualHeaviside(2, 4).to_equation(), True)
def test_addition_printing(self):
    """An Addition is named "+" and prints in infix form."""
    expr = pybamm.Addition(pybamm.Symbol("a"), pybamm.Symbol("b"))
    self.assertEqual(expr.name, "+")
    self.assertEqual(str(expr), "a + b")
def __radd__(self, other):
    """Reflected ``+``: build ``other + self`` as a :class:`pybamm.Addition`.

    The new node goes through ``pybamm.simplify_if_constant`` with
    ``keep_domains=True`` before it is returned.
    """
    reflected_sum = pybamm.Addition(other, self)
    return pybamm.simplify_if_constant(reflected_sum, keep_domains=True)
def __radd__(self, other):
    """Reflected ``+``: return ``pybamm.Addition(other, self)``.

    Raises
    ------
    NotImplementedError
        If ``other`` is neither a ``Symbol`` nor a ``numbers.Number``.

    NOTE(review): Python's convention for binary dunders is to *return*
    ``NotImplemented``. This method raises instead, and that is kept because
    callers may rely on catching ``NotImplementedError``.
    """
    if not isinstance(other, (Symbol, numbers.Number)):
        raise NotImplementedError
    return pybamm.Addition(other, self)
def simplified_addition(left, right):
    """
    Simplify ``left + right`` and return the resulting expression.

    The rewrite rules are tried in order:

    1. Broadcast/concatenation handling.
    2. Zero scalars.
    3. Zero matrices.
    4. Both sides constant.
    5. ``A @ c + B @ c``.
    6. Re-association with constant terms.

    If none of them applies, the result is
    ``pybamm.simplify_if_constant(pybamm.Addition(left, right))``.

    Note
    ----
    We check for scalars first, then matrices. This is because
    (Zero Matrix) + (Zero Scalar)
    should return (Zero Matrix), not (Zero Scalar).

    Parameters
    ----------
    left, right
        The two operands. Presumably pybamm symbols (the code calls
        ``is_constant``, ``evaluates_to_number``, ``orphans`` and similar
        methods on them) — confirm against callers.
    """
    left, right = simplify_elementwise_binary_broadcasts(left, right)

    # Check for Concatenations and Broadcasts
    out = simplified_binary_broadcast_concatenation(left, right, simplified_addition)
    if out is not None:
        return out

    # anything added by a scalar zero returns the other child
    elif pybamm.is_scalar_zero(left):
        return right
    elif pybamm.is_scalar_zero(right):
        return left
    # Check matrices after checking scalars
    elif pybamm.is_matrix_zero(left):
        if right.evaluates_to_number():
            # Multiply by ones_like(left), presumably so the result takes the
            # zero matrix's shape rather than collapsing to a number.
            return right * pybamm.ones_like(left)
        # If left object is zero and has size smaller than or equal to right object in
        # all dimensions, we can safely return the right object. For example, when
        # adding a zero vector to a matrix, we can just return the matrix.
        # Both sides must also agree on evaluates_on_edges for every domain level.
        # NOTE(review): zip stops at the shorter of the two shapes, so shapes of
        # different lengths are only compared on their common prefix — confirm.
        elif all(
            left_dim_size <= right_dim_size
            for left_dim_size, right_dim_size in zip(
                left.shape_for_testing, right.shape_for_testing
            )
        ) and all(
            left.evaluates_on_edges(dim) == right.evaluates_on_edges(dim)
            for dim in ["primary", "secondary", "tertiary"]
        ):
            return right
        # Neither sub-condition held: fall through to the rules below.
    elif pybamm.is_matrix_zero(right):
        if left.evaluates_to_number():
            return left * pybamm.ones_like(right)
        # See comment above (mirror image: right is the zero operand)
        elif all(
            left_dim_size >= right_dim_size
            for left_dim_size, right_dim_size in zip(
                left.shape_for_testing, right.shape_for_testing
            )
        ) and all(
            left.evaluates_on_edges(dim) == right.evaluates_on_edges(dim)
            for dim in ["primary", "secondary", "tertiary"]
        ):
            return left
        # Neither sub-condition held: fall through to the rules below.

    # Return constant if both sides are constant
    if left.is_constant() and right.is_constant():
        return pybamm.simplify_if_constant(pybamm.Addition(left, right))

    # Simplify A @ c + B @ c to (A + B) @ c if (A + B) is constant
    # This is a common construction that appears from discretisation of spatial
    # operators
    elif (
        isinstance(left, MatrixMultiplication)
        and isinstance(right, MatrixMultiplication)
        and left.right.id == right.right.id
    ):
        l_left, l_right = left.orphans
        r_left = right.orphans[0]
        new_left = l_left + r_left
        # Only rewrite when the combined matrix is constant; otherwise fall
        # through to the remaining rules.
        if new_left.is_constant():
            new_sum = new_left @ l_right
            # Presumably carries the domains of the unsimplified sum over to the
            # rewritten product.
            new_sum.copy_domains(pybamm.Addition(left, right))
            return new_sum

    if isinstance(right, pybamm.Addition) and left.is_constant():
        # Simplify a + (b + c) to (a + b) + c if (a + b) is constant
        if right.left.is_constant():
            r_left, r_right = right.orphans
            return (left + r_left) + r_right
        # Simplify a + (b + c) to (a + c) + b if (a + c) is constant
        elif right.right.is_constant():
            r_left, r_right = right.orphans
            return (left + r_right) + r_left
    if isinstance(left, pybamm.Addition) and right.is_constant():
        # Simplify (a + b) + c to a + (b + c) if (b + c) is constant
        if left.right.is_constant():
            l_left, l_right = left.orphans
            return l_left + (l_right + right)
        # Simplify (a + b) + c to (a + c) + b if (a + c) is constant
        elif left.left.is_constant():
            l_left, l_right = left.orphans
            return (l_left + right) + l_right

    # No rule applied: build the plain node (folded if it is constant)
    return pybamm.simplify_if_constant(pybamm.Addition(left, right))
def test_addition(self):
    """An Addition node keeps its two operands, in order, as its children."""
    a = pybamm.Symbol("a")
    b = pybamm.Symbol("b")
    # Named ``summ`` rather than ``sum`` to avoid shadowing the builtin.
    summ = pybamm.Addition(a, b)
    self.assertEqual(summ.children[0].name, a.name)
    self.assertEqual(summ.children[1].name, b.name)