def inner(left, right):
    """
    Return the inner product of two symbols, applying cheap simplifications.

    Both operands are first normalised with ``preprocess_binary``. The
    following shortcuts are then tried, in this order:

    * a scalar zero on either side gives zeros shaped like the other operand;
    * a zero matrix on either side gives zeros shaped like the full
      ``pybamm.Inner(left, right)`` node;
    * a scalar one on either side gives the other operand unchanged.

    If none apply, a ``pybamm.Inner`` node is built and passed through
    ``pybamm.simplify_if_constant``.
    """
    left, right = preprocess_binary(left, right)

    # Scalar zero annihilates the product; match the *other* operand's shape.
    if pybamm.is_scalar_zero(left):
        return pybamm.zeros_like(right)
    if pybamm.is_scalar_zero(right):
        return pybamm.zeros_like(left)

    # With a zero matrix involved, let the Inner node determine the result
    # shape before replacing it by zeros.
    if any(pybamm.is_matrix_zero(operand) for operand in (left, right)):
        return pybamm.zeros_like(pybamm.Inner(left, right))

    # Scalar one is the identity: hand back the other operand.
    if pybamm.is_scalar_one(left):
        return right
    if pybamm.is_scalar_one(right):
        return left

    product = pybamm.Inner(left, right)
    return pybamm.simplify_if_constant(product)
def simplified_matrix_multiplication(left, right):
    """
    Build ``left @ right``, applying algebraic simplifications where possible.

    Both operands are first normalised with ``preprocess_binary``. Rewrites
    are then attempted in this order (the first one that returns wins):

    1. a zero matrix on either side gives zeros shaped like the full
       ``MatrixMultiplication`` node;
    2. ``A @ (b * c)`` with constant ``A`` and ``b`` (or ``c``) evaluating to a
       constant number is rewritten to fold that number into ``A``;
    3. ``A @ (b / c)`` with constant ``A`` and ``c`` evaluating to a constant
       number is rewritten to ``(A / c) @ b``;
    4. ``A @ (B @ c)`` with constant ``A`` and ``B`` is rewritten to
       ``(A @ B) @ c``;
    5. ``A @ (b + c)`` is distributed when ``b`` or ``c`` is constant and
       neither has size 1.

    Otherwise returns
    ``pybamm.simplify_if_constant(pybamm.MatrixMultiplication(left, right))``.

    The rewrites are expressed with the ``*``, ``/``, ``@`` and ``+``
    operators on symbols; presumably those dispatch back into the
    ``simplified_*`` helpers, so simplification recurses (TODO confirm).
    """
    left, right = preprocess_binary(left, right)
    if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right):
        return pybamm.zeros_like(pybamm.MatrixMultiplication(left, right))

    if isinstance(right, Multiplication) and left.is_constant():
        # Simplify A @ (b * c) to (A * b) @ c if (A * b) is constant
        if right.left.evaluates_to_constant_number():
            r_left, r_right = right.orphans
            new_left = left * r_left
            return new_left @ r_right
        # Simplify A @ (b * c) to (A * c) @ b if (A * c) is constant
        elif right.right.evaluates_to_constant_number():
            r_left, r_right = right.orphans
            new_left = left * r_right
            return new_left @ r_left
    elif isinstance(right, Division) and left.is_constant():
        # Simplify A @ (b / c) to (A / c) @ b if (A / c) is constant
        if right.right.evaluates_to_constant_number():
            r_left, r_right = right.orphans
            new_left = left / r_right
            new_mul = new_left @ r_left
            # Keep the domain of the old left
            new_mul.copy_domains(left)
            return new_mul

    # Simplify A @ (B @ c) to (A @ B) @ c if (A @ B) is constant
    # This is a common construction that appears from discretisation of spatial
    # operators
    if (isinstance(right, MatrixMultiplication) and right.left.is_constant()
            and left.is_constant()):
        r_left, r_right = right.orphans
        new_left = left @ r_left
        # be careful about domains to avoid weird errors
        # (the folded constant A @ B is stripped of domains; the final product
        # then takes its domains from the old right operand)
        new_left.clear_domains()
        new_mul = new_left @ r_right
        # Keep the domain of the old right
        new_mul.copy_domains(right)
        return new_mul

    # Simplify A @ (b + c) to (A @ b) + (A @ c) if (A @ b) or (A @ c) is constant
    # This is a common construction that appears from discretisation of spatial
    # operators
    # Don't do this if either b or c is a number as this will lead to matmul errors
    # NOTE(review): despite the comment above, the condition below only checks
    # that b or c is constant; it never checks left.is_constant(), so A itself
    # may be non-constant here. Confirm whether that is intended.
    elif isinstance(right, Addition):
        if (right.left.is_constant() or right.right.is_constant()
                ) and not (right.left.size_for_testing == 1
                           or right.right.size_for_testing == 1):
            r_left, r_right = right.orphans
            return (left @ r_left) + (left @ r_right)

    # No rewrite applied: build the node and fold it if it is constant.
    return pybamm.simplify_if_constant(pybamm.MatrixMultiplication(
        left, right))
def simplified_subtraction(left, right):
    """
    Build ``left - right``, applying algebraic simplifications where possible.

    Rewrites are attempted in this order:

    1. broadcast/concatenation handling via
       ``simplified_binary_broadcast_concatenation``;
    2. scalar zero: ``0 - right`` gives ``-right`` and ``left - 0`` gives
       ``left``;
    3. zero matrix: dropped when the other operand is a number (expanded with
       ``ones_like``) or when its shape and edge/node evaluation make the zero
       matrix redundant;
    4. ``x - x`` (same ``id``) gives zeros shaped like ``left``.

    Otherwise returns
    ``pybamm.simplify_if_constant(pybamm.Subtraction(left, right))``.

    Note
    ----
    We check for scalars first, then matrices. This is because
    (Zero Matrix) - (Zero Scalar)
    should return (Zero Matrix), not -(Zero Scalar).
    """
    left, right = simplify_elementwise_binary_broadcasts(left, right)

    # Check for Concatenations and Broadcasts
    out = simplified_binary_broadcast_concatenation(left, right,
                                                    simplified_subtraction)
    if out is not None:
        return out

    # A scalar zero on the left leaves the negated right operand; a scalar
    # zero on the right leaves the left operand unchanged.
    if pybamm.is_scalar_zero(left):
        return -right
    if pybamm.is_scalar_zero(right):
        return left

    # Check matrices after checking scalars
    if pybamm.is_matrix_zero(left):
        if right.evaluates_to_number():
            # Expand the number to the zero matrix's shape.
            return -right * pybamm.ones_like(left)
        # See comments in simplified_addition
        # The zero matrix may only be dropped if it is no larger than `right`
        # in every dimension and both operands agree on edge/node evaluation
        # in every domain level, so discarding it should not change the shape.
        elif all(left_dim_size <= right_dim_size
                 for left_dim_size, right_dim_size in zip(
                     left.shape_for_testing, right.shape_for_testing)) and all(
                         left.evaluates_on_edges(dim) ==
                         right.evaluates_on_edges(dim)
                         for dim in ["primary", "secondary", "tertiary"]):
            return -right
    if pybamm.is_matrix_zero(right):
        if left.evaluates_to_number():
            # Expand the number to the zero matrix's shape.
            return left * pybamm.ones_like(right)
        # See comments in simplified_addition
        # Mirror of the check above with the roles of left and right swapped.
        elif all(left_dim_size >= right_dim_size
                 for left_dim_size, right_dim_size in zip(
                     left.shape_for_testing, right.shape_for_testing)) and all(
                         left.evaluates_on_edges(dim) ==
                         right.evaluates_on_edges(dim)
                         for dim in ["primary", "secondary", "tertiary"]):
            return left

    # a symbol minus itself is 0s of the same shape
    if left.id == right.id:
        return pybamm.zeros_like(left)

    # No rewrite applied: build the node and fold it if it is constant.
    return pybamm.simplify_if_constant(pybamm.Subtraction(left, right))
def simplified_multiplication(left, right):
    """
    Build the elementwise product ``left * right`` with algebraic simplification.

    Rewrites are attempted in this order (the first one that returns wins):

    1. broadcast/concatenation handling via
       ``simplified_binary_broadcast_concatenation``;
    2. scalar zero or zero matrix on either side gives zeros;
    3. scalar one / minus one on either side gives the other operand
       (negated for minus one);
    4. matrix one / minus one, only when shapes and edge/node evaluation
       agree;
    5. both operands constant gives ``simplify_if_constant`` of the product;
    6. re-associations that gather constants together: products with a
       ``MatrixMultiplication``, ``Multiplication`` or ``Division`` whose
       relevant child is constant, and distribution over an ``Addition``;
    7. a ``Negate`` is moved onto the constant operand.

    Otherwise returns a plain ``pybamm.Multiplication(left, right)``.
    ``simplify_if_constant`` is not needed there because the both-constant
    case is handled in step 5.

    The rewrites use the ``*``, ``/``, ``@``, ``+`` and unary ``-`` operators
    on symbols; presumably these dispatch back into the ``simplified_*``
    helpers, so simplification recurses (TODO confirm).
    """
    left, right = simplify_elementwise_binary_broadcasts(left, right)

    # Check for Concatenations and Broadcasts
    out = simplified_binary_broadcast_concatenation(left, right,
                                                    simplified_multiplication)
    if out is not None:
        return out

    # simplify multiply by scalar zero, being careful about shape
    if pybamm.is_scalar_zero(left):
        return pybamm.zeros_like(right)
    if pybamm.is_scalar_zero(right):
        return pybamm.zeros_like(left)

    # if one of the children is a zero matrix, we have to be careful about shapes
    if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right):
        return pybamm.zeros_like(pybamm.Multiplication(left, right))

    # anything multiplied by a scalar one returns itself
    if pybamm.is_scalar_one(left):
        return right
    if pybamm.is_scalar_one(right):
        return left

    # anything multiplied by a scalar negative one returns negative itself
    if pybamm.is_scalar_minus_one(left):
        return -right
    if pybamm.is_scalar_minus_one(right):
        return -left

    # anything multiplied by a matrix one returns itself if
    # - the shapes are the same
    # - both left and right evaluate on edges, or both evaluate on nodes, in all
    # dimensions
    # (and possibly more generally, but not implemented here)
    try:
        if left.shape_for_testing == right.shape_for_testing and all(
                left.evaluates_on_edges(dim) == right.evaluates_on_edges(dim)
                for dim in ["primary", "secondary", "tertiary"]):
            if pybamm.is_matrix_one(left):
                return right
            elif pybamm.is_matrix_one(right):
                return left
            # also check for negative one
            if pybamm.is_matrix_minus_one(left):
                return -right
            elif pybamm.is_matrix_minus_one(right):
                return -left
    except NotImplementedError:
        # Some symbols cannot answer these shape/edge queries; in that case
        # just skip the matrix-one shortcut.
        pass

    # Return constant if both sides are constant
    if left.is_constant() and right.is_constant():
        return pybamm.simplify_if_constant(pybamm.Multiplication(left, right))

    # Simplify (B @ c) * a to (a * B) @ c if (a * B) is constant
    # This is a common construction that appears from discretisation of spatial
    # operators
    # (skipped when `right` is 2-D with more than one column)
    if (isinstance(left, MatrixMultiplication) and left.left.is_constant()
            and right.is_constant()
            and not (right.ndim_for_testing == 2
                     and right.shape_for_testing[1] > 1)):
        l_left, l_right = left.orphans
        new_left = right * l_left
        # Special hack for the case where l_left is a matrix one
        # because of weird domain errors otherwise
        # (when the product simplified back to `right` itself, take a copy so
        # the clear_domains() call below does not act on `right`)
        if new_left == right and isinstance(right, pybamm.Array):
            new_left = right.new_copy()
        # be careful about domains to avoid weird errors
        new_left.clear_domains()
        new_mul = new_left @ l_right
        # Keep the domain of the old left
        new_mul.copy_domains(left)
        return new_mul

    elif isinstance(left, Multiplication) and right.is_constant():
        # Simplify (a * b) * c to (a * c) * b if (a * c) is constant
        if left.left.is_constant():
            l_left, l_right = left.orphans
            new_left = l_left * right
            return new_left * l_right
        # Simplify (a * b) * c to a * (b * c) if (b * c) is constant
        elif left.right.is_constant():
            l_left, l_right = left.orphans
            new_right = l_right * right
            return l_left * new_right
    elif isinstance(left, Division) and right.is_constant():
        # Simplify (a / b) * c to a * (c / b) if (c / b) is constant
        if left.right.is_constant():
            l_left, l_right = left.orphans
            new_right = right / l_right
            return l_left * new_right

    # Simplify a * (B @ c) to (a * B) @ c if (a * B) is constant
    # (skipped when `left` is 2-D with more than one column)
    if (isinstance(right, MatrixMultiplication) and right.left.is_constant()
            and left.is_constant()
            and not (left.ndim_for_testing == 2
                     and left.shape_for_testing[1] > 1)):
        r_left, r_right = right.orphans
        new_left = left * r_left
        # Special hack for the case where r_left is a matrix one
        # because of weird domain errors otherwise
        # (when the product simplified back to `left` itself, take a copy so
        # the clear_domains() call below does not act on `left`)
        if new_left == left and isinstance(left, pybamm.Array):
            new_left = left.new_copy()
        # be careful about domains to avoid weird errors
        new_left.clear_domains()
        new_mul = new_left @ r_right
        # Keep the domain of the old right
        new_mul.copy_domains(right)
        return new_mul

    elif isinstance(right, Multiplication) and left.is_constant():
        # Simplify a * (b * c) to (a * b) * c if (a * b) is constant
        if right.left.is_constant():
            r_left, r_right = right.orphans
            new_left = left * r_left
            return new_left * r_right
        # Simplify a * (b * c) to (a * c) * b if (a * c) is constant
        elif right.right.is_constant():
            r_left, r_right = right.orphans
            new_left = left * r_right
            return new_left * r_left
    elif isinstance(right, Division) and left.is_constant():
        # Simplify a * (b / c) to (a / c) * b if (a / c) is constant
        if right.right.is_constant():
            r_left, r_right = right.orphans
            new_left = left / r_right
            return new_left * r_left
    # Simplify a * (b + c) to (a * b) + (a * c) if (a * b) or (a * c) is constant
    # This is a common construction that appears from discretisation of spatial
    # operators
    # Also do this for cases like a * (b @ c + d) where (a * b) is constant
    # NOTE(review): unlike the sibling branches, this one does not require
    # left.is_constant(), so a non-constant `left` is duplicated into both
    # terms. Confirm this is intended (it can grow the expression tree).
    elif isinstance(right, Addition):
        mul_classes = (
            pybamm.Multiplication,
            pybamm.MatrixMultiplication,
            pybamm.Division,
        )
        if (right.left.is_constant() or right.right.is_constant()
                or (isinstance(right.left, mul_classes)
                    and right.left.left.is_constant())
                or (isinstance(right.right, mul_classes)
                    and right.right.left.is_constant())):
            r_left, r_right = right.orphans
            # Only distribute when each summand's domain matches the sum's
            # domain or is empty.
            if (r_left.domain == right.domain or r_left.domain == []) and (
                    r_right.domain == right.domain or r_right.domain == []):
                return (left * r_left) + (left * r_right)

    # Negation simplifications
    if isinstance(left, pybamm.Negate) and right.is_constant():
        # Simplify (-a) * b to a * (-b) if (-b) is constant
        return left.orphans[0] * (-right)
    elif isinstance(right, pybamm.Negate) and left.is_constant():
        # Simplify a * (-b) to (-a) * b if (-a) is constant
        return (-left) * right.orphans[0]

    # No rewrite applied; the both-constant case already returned above.
    return pybamm.Multiplication(left, right)
def simplified_division(left, right):
    """
    Build ``left / right``, applying algebraic simplifications where possible.

    Rewrites are attempted in this order (the first one that returns wins):

    1. broadcast/concatenation handling via
       ``simplified_binary_broadcast_concatenation``;
    2. a scalar-zero or zero-matrix numerator gives zeros;
    3. a scalar-zero denominator raises ``ZeroDivisionError``;
    4. dividing by a scalar one, or by itself (same ``id``), is simplified;
    5. dividing by a matrix one / minus one, only when shapes and edge/node
       evaluation agree;
    6. both operands constant gives ``simplify_if_constant`` of the quotient;
    7. the constant denominator is folded into a constant factor of a
       ``MatrixMultiplication`` or ``Multiplication`` numerator;
    8. a ``Negate`` is moved onto the constant operand.

    Otherwise returns
    ``pybamm.simplify_if_constant(pybamm.Division(left, right))``.

    Raises
    ------
    ZeroDivisionError
        If ``right`` is a scalar zero and ``left`` is not a zero. Because the
        zero-numerator checks run first, ``0 / 0`` returns zeros rather than
        raising. A zero *matrix* denominator is not rejected here.
        NOTE(review): confirm that this ``0 / 0`` behaviour is intended.
    """
    left, right = simplify_elementwise_binary_broadcasts(left, right)

    # Check for Concatenations and Broadcasts
    out = simplified_binary_broadcast_concatenation(left, right,
                                                    simplified_division)
    if out is not None:
        return out

    # zero divided by anything returns zero (being careful about shape)
    if pybamm.is_scalar_zero(left):
        return pybamm.zeros_like(right)

    # matrix zero divided by anything returns matrix zero (i.e. itself)
    if pybamm.is_matrix_zero(left):
        return pybamm.zeros_like(pybamm.Division(left, right))

    # anything divided by zero raises error
    # (only reached when the numerator is not zero; see docstring)
    if pybamm.is_scalar_zero(right):
        raise ZeroDivisionError

    # anything divided by one is itself
    if pybamm.is_scalar_one(right):
        return left

    # a symbol divided by itself is 1s of the same shape
    if left.id == right.id:
        return pybamm.ones_like(left)

    # anything divided by a matrix one returns itself if
    # - the shapes are the same
    # - both left and right evaluate on edges, or both evaluate on nodes, in all
    # dimensions
    # (and possibly more generally, but not implemented here)
    try:
        if left.shape_for_testing == right.shape_for_testing and all(
                left.evaluates_on_edges(dim) == right.evaluates_on_edges(dim)
                for dim in ["primary", "secondary", "tertiary"]):
            if pybamm.is_matrix_one(right):
                return left
            # also check for negative one
            if pybamm.is_matrix_minus_one(right):
                return -left
    except NotImplementedError:
        # Some symbols cannot answer these shape/edge queries; in that case
        # just skip the matrix-one shortcut.
        pass

    # Return constant if both sides are constant
    if left.is_constant() and right.is_constant():
        return pybamm.simplify_if_constant(pybamm.Division(left, right))

    # Simplify (B @ c) / a to (B / a) @ c if (B / a) is constant
    # This is a common construction that appears from discretisation of averages
    elif isinstance(left, MatrixMultiplication) and right.is_constant():
        l_left, l_right = left.orphans
        new_left = l_left / right
        if new_left.is_constant():
            # be careful about domains to avoid weird errors
            new_left.clear_domains()
            new_division = new_left @ l_right
            # Keep the domain of the old left
            new_division.copy_domains(left)
            return new_division

    # Note: right is not required to be constant here; the candidate quotient
    # is built first and simply discarded if it is not constant.
    if isinstance(left, Multiplication):
        # Simplify (a * b) / c to (a / c) * b if (a / c) is constant
        if left.left.is_constant():
            l_left, l_right = left.orphans
            new_left = l_left / right
            if new_left.is_constant():
                return new_left * l_right
        # Simplify (a * b) / c to a * (b / c) if (b / c) is constant
        elif left.right.is_constant():
            l_left, l_right = left.orphans
            new_right = l_right / right
            if new_right.is_constant():
                return l_left * new_right

    # Negation simplifications
    # (chained with `elif`, so skipped whenever `left` is a Multiplication)
    elif isinstance(left, pybamm.Negate) and right.is_constant():
        # Simplify (-a) / b to a / (-b) if (-b) is constant
        return left.orphans[0] / (-right)
    elif isinstance(right, pybamm.Negate) and left.is_constant():
        # Simplify a / (-b) to (-a) / b if (-a) is constant
        return (-left) / right.orphans[0]

    # No rewrite applied: build the node and fold it if it is constant.
    return pybamm.simplify_if_constant(pybamm.Division(left, right))