예제 #1
0
def simplified_function(func_class, child):
    """
    Build ``func_class(child)``, simplifying it before returning.

    If ``child`` is a :class:`pybamm.Broadcast`, the function is pushed inside
    the broadcast: it is applied (recursively) to the broadcast's inner child,
    and the result is wrapped in a new copy of the same broadcast. Otherwise
    the function is applied directly and the result is simplified if constant.

    Currently only implemented for one-child functions.
    """
    if not isinstance(child, pybamm.Broadcast):
        return pybamm.simplify_if_constant(func_class(child))

    # Apply the function to the un-broadcast child, then re-broadcast
    inner_child = child.orphans[0]
    inner_result = simplified_function(func_class, inner_child)
    return child._unary_new_copy(pybamm.simplify_if_constant(inner_result))
예제 #2
0
def min(child):
    """
    Return the min function of ``child``: ``np.min`` wrapped in a
    :class:`Function`, simplified if constant (keeping domains).

    Not to be confused with :meth:`pybamm.minimum`, which returns the smaller
    of two objects.
    """
    min_of_child = Function(np.min, child)
    return pybamm.simplify_if_constant(min_of_child, keep_domains=True)
예제 #3
0
def simplified_domain_concatenation(children, mesh, copy_this=None):
    """
    Perform simplifications on a domain concatenation.

    A :class:`DomainConcatenation` of ``children`` is always built first, so
    that its domain and auxiliary domains can be read.

    Children can be merged into a single StateVector over
    ``slice(first_start, last_stop)`` only if all of the following hold:

    - every child is a :class:`pybamm.StateVector`;
    - their y-slices follow on from one another in order;
    - their evaluation arrays together select each index from the first
      child's start to the last child's stop exactly once.

    Otherwise the concatenation is returned, simplified if it is constant.

    Parameters
    ----------
    children : list
        The symbols to concatenate.
    mesh
        Passed through unchanged to :class:`DomainConcatenation`.
    copy_this : optional
        Passed through unchanged to :class:`DomainConcatenation`.
    """
    # Create the DomainConcatenation to read domain and child domain
    concat = DomainConcatenation(children, mesh, copy_this=copy_this)

    if children and all(isinstance(child, pybamm.StateVector) for child in children):
        first_start = children[0].y_slices[0].start
        last_stop = children[-1].y_slices[-1].stop

        # The merged StateVector reads y[first_start:last_stop] in order, so
        # the children's slices must chain: each starts where the previous one
        # stopped. Without this check, out-of-order children whose union is
        # contiguous would be merged into a wrongly-ordered vector.
        all_slices = [y_slice for child in children for y_slice in child.y_slices]
        in_order = all(
            prev.stop == nxt.start for prev, nxt in zip(all_slices, all_slices[1:])
        )

        if in_order:
            # Pad every evaluation array to the longest one so they can be
            # summed. Use the maximum length instead of assuming the last
            # child's array is the longest; otherwise np.zeros would receive a
            # negative size and raise.
            eval_arrays = [child.evaluation_array for child in children]
            longest = max(len(array) for array in eval_arrays)
            total = sum(
                np.concatenate([array, np.zeros(longest - len(array))])
                for array in eval_arrays
            )
            # Every index in the merged range must be selected exactly once
            if np.all(total[first_start:last_stop] == 1):
                return pybamm.StateVector(
                    slice(first_start, last_stop),
                    domain=concat.domain,
                    auxiliary_domains=concat.auxiliary_domains,
                )

    return pybamm.simplify_if_constant(concat)
예제 #4
0
파일: symbol.py 프로젝트: yonas-y/PyBaMM
 def __abs__(self):
     """
     Return the absolute value of this symbol.

     An exact :class:`AbsoluteValue` is used when the ``abs_smoothing`` setting
     is ``"exact"`` or when this symbol is constant, since smoothing is then
     unnecessary. Otherwise the smooth approximation from
     :func:`pybamm.smooth_absolute_value` is used. The result is simplified if
     constant, keeping domains.
     """
     smoothing = pybamm.settings.abs_smoothing
     use_exact = smoothing == "exact" or is_constant(self)
     result = (
         pybamm.AbsoluteValue(self)
         if use_exact
         else pybamm.smooth_absolute_value(self, smoothing)
     )
     return pybamm.simplify_if_constant(result, keep_domains=True)
예제 #5
0
 def __ge__(self, other):
     """
     Return ``self >= other`` as an exact or smoothed step.

     ``EqualHeaviside(other, self)`` is used when the ``heaviside_smoothing``
     setting is ``"exact"`` or when both operands are constant. Otherwise
     :func:`pybamm.sigmoid` is used with that smoothing value. The result is
     simplified if constant.
     """
     smoothing = pybamm.settings.heaviside_smoothing
     exact = smoothing == "exact" or (is_constant(self) and is_constant(other))
     step = (
         pybamm.EqualHeaviside(other, self)
         if exact
         else pybamm.sigmoid(other, self, smoothing)
     )
     return pybamm.simplify_if_constant(step)
예제 #6
0
 def __abs__(self):
     """
     Return the absolute value of this symbol, simplifying where possible.

     - If this symbol is already an :class:`AbsoluteValue`, it is returned
       unchanged.
     - For a :class:`pybamm.Broadcast`, the absolute value of the inner child
       is taken (recursively), simplified if constant, and re-broadcast.
     - Otherwise an exact :class:`AbsoluteValue` is used when the
       ``abs_smoothing`` setting is ``"exact"`` or this symbol is constant,
       and :func:`pybamm.smooth_absolute_value` is used otherwise. The result
       is simplified if constant.
     """
     if isinstance(self, pybamm.AbsoluteValue):
         # abs(abs(x)) == abs(x): no second wrapper needed
         return self

     if isinstance(self, pybamm.Broadcast):
         inner_abs = abs(self.orphans[0])
         return self._unary_new_copy(pybamm.simplify_if_constant(inner_abs))

     smoothing = pybamm.settings.abs_smoothing
     if smoothing == "exact" or is_constant(self):
         result = pybamm.AbsoluteValue(self)
     else:
         result = pybamm.smooth_absolute_value(self, smoothing)
     return pybamm.simplify_if_constant(result)
예제 #7
0
def simplified_numpy_concatenation(*children):
    """
    Build a :class:`NumpyConcatenation` of ``children``, simplified.

    Children that are themselves NumpyConcatenations are flattened one level:
    their own children (taken as orphans) are spliced in at that position. The
    resulting concatenation is simplified if constant.
    """
    flattened = []
    for child in children:
        parts = child.orphans if isinstance(child, NumpyConcatenation) else [child]
        flattened.extend(parts)
    return pybamm.simplify_if_constant(NumpyConcatenation(*flattened))
예제 #8
0
def simplified_matrix_multiplication(left, right):
    """
    Return a simplified symbol for ``left @ right``.

    Rewrites are tried in order; the first one that applies returns:

    - A zero matrix on either side gives zeros shaped like the full
      :class:`pybamm.MatrixMultiplication`.
    - ``A @ (b * c)`` with constant ``A`` becomes ``(A * b) @ c`` if ``b``
      evaluates to a constant number, or ``(A * c) @ b`` if ``c`` does.
    - ``A @ (b / c)`` with constant ``A`` becomes ``(A / c) @ b`` if ``c``
      evaluates to a constant number. The result keeps ``A``'s domains.
    - ``A @ (B @ c)`` with constant ``A`` and ``B`` becomes ``(A @ B) @ c``.
      The result keeps the domains of ``B @ c``.
    - Otherwise, ``A @ (b + c)`` becomes ``(A @ b) + (A @ c)`` if ``b`` or
      ``c`` is constant and neither has ``size_for_testing == 1``.

    If none applies, a :class:`pybamm.MatrixMultiplication` is built and
    simplified if constant.
    """
    left, right = preprocess_binary(left, right)
    # The product's shape depends on both operands, so build it only to read
    # its shape for the zeros
    if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right):
        return pybamm.zeros_like(pybamm.MatrixMultiplication(left, right))

    # NOTE(review): unlike the Division branch below, the two Multiplication
    # rewrites do not copy the original domains onto the result — confirm this
    # is intentional.
    if isinstance(right, Multiplication) and left.is_constant():
        # Simplify A @ (b * c) to (A * b) @ c if (A * b) is constant
        if right.left.evaluates_to_constant_number():
            r_left, r_right = right.orphans
            new_left = left * r_left
            return new_left @ r_right
        # Simplify A @ (b * c) to (A * c) @ b if (A * c) is constant
        elif right.right.evaluates_to_constant_number():
            r_left, r_right = right.orphans
            new_left = left * r_right
            return new_left @ r_left
    elif isinstance(right, Division) and left.is_constant():
        # Simplify A @ (b / c) to (A / c) @ b if (A / c) is constant
        if right.right.evaluates_to_constant_number():
            r_left, r_right = right.orphans
            new_left = left / r_right
            new_mul = new_left @ r_left
            # Keep the domain of the old left
            new_mul.copy_domains(left)
            return new_mul

    # Simplify A @ (B @ c) to (A @ B) @ c if (A @ B) is constant
    # This is a common construction that appears from discretisation of spatial
    # operators
    if (isinstance(right, MatrixMultiplication) and right.left.is_constant()
            and left.is_constant()):
        r_left, r_right = right.orphans
        new_left = left @ r_left
        # be careful about domains to avoid weird errors
        new_left.clear_domains()
        new_mul = new_left @ r_right
        # Keep the domain of the old right
        new_mul.copy_domains(right)
        return new_mul

    # Simplify A @ (b + c) to (A @ b) + (A @ c) if (A @ b) or (A @ c) is constant
    # This is a common construction that appears from discretisation of spatial
    # operators
    # Don't do this if either b or c is a number as this will lead to matmul errors
    # NOTE(review): the condition below only checks that b or c is constant
    # (A is not checked), and "is a number" is tested as size_for_testing == 1.
    # This elif pairs with the MatrixMultiplication `if` above, which always
    # returns, so it is effectively "otherwise, if right is an Addition".
    elif isinstance(right, Addition):
        if (right.left.is_constant() or right.right.is_constant()
            ) and not (right.left.size_for_testing == 1
                       or right.right.size_for_testing == 1):
            r_left, r_right = right.orphans
            return (left @ r_left) + (left @ r_right)

    # No rewrite applied: build the plain product
    return pybamm.simplify_if_constant(pybamm.MatrixMultiplication(
        left, right))
예제 #9
0
def maximum(left, right):
    """
    Return the larger of ``left`` and ``right``, possibly smoothed.

    An exact :class:`Maximum` is used when the ``max_smoothing`` setting is
    ``"exact"`` or both operands are constant. Otherwise the smooth
    approximation :func:`pybamm.softplus` is used with that smoothing value.
    The result is simplified if constant, keeping domains.

    Not to be confused with :meth:`pybamm.max`, which returns the max function
    of a single child.
    """
    smoothing = pybamm.settings.max_smoothing
    if smoothing != "exact" and not (
        pybamm.is_constant(left) and pybamm.is_constant(right)
    ):
        result = pybamm.softplus(left, right, smoothing)
    else:
        result = Maximum(left, right)
    return pybamm.simplify_if_constant(result, keep_domains=True)
예제 #10
0
 def __neg__(self):
     """
     Return the negation of this symbol, simplifying where possible.

     - ``-(-x)`` returns the inner child (taken via ``orphans``).
     - For a :class:`pybamm.Broadcast`, the inner child is negated and
       re-broadcast.
     - For a :class:`pybamm.Concatenation` whose children are all constant,
       each child is negated and the results are concatenated again.
     - Otherwise a :class:`Negate` is built and simplified if constant.
     """
     if isinstance(self, pybamm.Negate):
         return self.orphans[0]
     if isinstance(self, pybamm.Broadcast):
         negated_child = -self.orphans[0]
         return self._unary_new_copy(negated_child)
     if isinstance(self, pybamm.Concatenation):
         if all(child.is_constant() for child in self.children):
             negated_children = [-child for child in self.orphans]
             return pybamm.concatenation(*negated_children)
     return pybamm.simplify_if_constant(pybamm.Negate(self))
예제 #11
0
def simplified_subtraction(left, right):
    """
    Return a simplified symbol for ``left - right``.

    Simplifications are applied in order:

    1. Broadcasts are simplified. Concatenations and broadcasts are handled by
       ``simplified_binary_broadcast_concatenation``, whose result is returned
       if it is not None.
    2. Scalar zeros: ``0 - b`` gives ``-b`` and ``a - 0`` gives ``a``.
    3. Zero matrices: if the other operand evaluates to a number, it is
       multiplied by ones shaped like the zero matrix. Otherwise the zero is
       dropped only if two conditions hold: it is no larger than the other
       operand in every ``shape_for_testing`` dimension, and both operands
       agree on ``evaluates_on_edges`` for the primary, secondary and tertiary
       dimensions.
    4. ``x - x`` (same ``id``) gives zeros shaped like ``x``.
    5. Otherwise a :class:`Subtraction` is built and simplified if constant.

    Note
    ----
    Scalars are checked before matrices because
    (Zero Matrix) - (Zero Scalar) should return (Zero Matrix), not
    -(Zero Scalar).
    """
    left, right = simplify_elementwise_binary_broadcasts(left, right)

    special_case = simplified_binary_broadcast_concatenation(
        left, right, simplified_subtraction
    )
    if special_case is not None:
        return special_case

    def edges_agree():
        # Both operands must evaluate on edges identically in every dimension
        return all(
            left.evaluates_on_edges(dim) == right.evaluates_on_edges(dim)
            for dim in ["primary", "secondary", "tertiary"]
        )

    # Scalar zeros first (see Note)
    if pybamm.is_scalar_zero(left):
        return -right
    if pybamm.is_scalar_zero(right):
        return left

    # Zero matrices: only drop the zero when it cannot change the result's
    # shape (see comments in simplified_addition)
    if pybamm.is_matrix_zero(left):
        if right.evaluates_to_number():
            return -right * pybamm.ones_like(left)
        fits = all(
            zero_size <= other_size
            for zero_size, other_size in zip(
                left.shape_for_testing, right.shape_for_testing
            )
        )
        if fits and edges_agree():
            return -right
    if pybamm.is_matrix_zero(right):
        if left.evaluates_to_number():
            return left * pybamm.ones_like(right)
        fits = all(
            other_size >= zero_size
            for other_size, zero_size in zip(
                left.shape_for_testing, right.shape_for_testing
            )
        )
        if fits and edges_agree():
            return left

    # Subtracting a symbol from itself leaves zeros of the same shape
    if left.id == right.id:
        return pybamm.zeros_like(left)

    return pybamm.simplify_if_constant(pybamm.Subtraction(left, right))
예제 #12
0
 def simplify_with_mat_mul(nodes, types):
     """
     Fold adjacent constant nodes together, then hand off to ``fold_multiply``.

     Nodes from index 1 onward are paired with ``types`` from index 1 onward.
     A node is merged into the last kept node only if both are constant and
     both ``evaluate_ignoring_errors()`` to something other than None. The
     merge uses ``@`` when the node's type is
     :class:`pybamm.MatrixMultiplication` and ``*=`` otherwise, and the merged
     node is simplified if constant. Nodes that cannot be merged are kept
     together with their type. The kept nodes and types are passed to
     ``fold_multiply`` and its result is returned.
     """
     kept_nodes = [nodes[0]]
     kept_types = [types[0]]
     for node, op_type in zip(nodes[1:], types[1:]):
         previous = kept_nodes[-1]
         can_fold = (
             previous.is_constant()
             and node.is_constant()
             and previous.evaluate_ignoring_errors() is not None
             and node.evaluate_ignoring_errors() is not None
         )
         if not can_fold:
             kept_nodes.append(node)
             kept_types.append(op_type)
             continue
         combined = previous
         if op_type == pybamm.MatrixMultiplication:
             combined = combined @ node
         else:
             combined *= node
         kept_nodes[-1] = pybamm.simplify_if_constant(combined)
     return fold_multiply(kept_nodes, kept_types)
예제 #13
0
    def _function_new_copy(self, children):
        """
        Return a new copy of this function applied to ``children``.

        The copy reuses this function's ``function``, ``name``,
        ``derivative`` and ``differentiated_function``.

        Parameters
        ----------
        children : list
            The children of the new function.

        Returns
        -------
        :class:`pybamm.Function` (or its simplified form)
            The new function, passed through ``pybamm.simplify_if_constant``.
        """
        copied = pybamm.Function(
            self.function,
            *children,
            name=self.name,
            derivative=self.derivative,
            differentiated_function=self.differentiated_function,
        )
        return pybamm.simplify_if_constant(copied)
예제 #14
0
def inner(left, right):
    """
    Return the inner product of two symbols, simplified.

    - A scalar zero on either side gives zeros shaped like the other operand.
    - A zero matrix on either side gives zeros shaped like the full
      :class:`Inner` product.
    - A scalar one on either side returns the other operand unchanged.
    - Otherwise an :class:`Inner` is built and simplified if constant.
    """
    left, right = preprocess_binary(left, right)
    operand_pairs = ((left, right), (right, left))

    # x * 0: zeros with the other operand's shape
    for candidate, other in operand_pairs:
        if pybamm.is_scalar_zero(candidate):
            return pybamm.zeros_like(other)

    # With a zero matrix the result shape depends on both operands
    if any(pybamm.is_matrix_zero(operand) for operand in (left, right)):
        return pybamm.zeros_like(pybamm.Inner(left, right))

    # x * 1: the other operand unchanged
    for candidate, other in operand_pairs:
        if pybamm.is_scalar_one(candidate):
            return other

    return pybamm.simplify_if_constant(pybamm.Inner(left, right))
예제 #15
0
def simplified_power(left, right):
    """
    Return a simplified symbol for ``left ** right``.

    Simplifications are applied in order:

    1. Broadcasts are simplified. Concatenations and broadcasts are handled by
       ``simplified_binary_broadcast_concatenation``, whose result is returned
       if it is not None.
    2. ``x ** 0`` gives ones shaped like ``x``. This is checked first, so
       ``0 ** 0`` gives one.
    3. ``0 ** y`` gives zeros shaped like ``y``.
    4. ``x ** 1`` gives ``x``.
    5. ``(a * b) ** c`` becomes ``(a ** c) * (b ** c)`` and ``(a / b) ** c``
       becomes ``(a ** c) / (b ** c)``. This applies when ``a`` or ``b`` is
       constant and at least one of the new powers is constant.
    6. Otherwise a :class:`Power` is built and simplified if constant.
    """
    left, right = simplify_elementwise_binary_broadcasts(left, right)

    # Check for Concatenations and Broadcasts
    out = simplified_binary_broadcast_concatenation(left, right, simplified_power)
    if out is not None:
        return out

    # anything to the power of zero is one
    if pybamm.is_scalar_zero(right):
        return pybamm.ones_like(left)

    # zero to the power of anything is zero. Shape the zeros like the exponent,
    # mirroring ones_like(left) above: a bare Scalar(0) would drop the shape
    # whenever the exponent is a vector/matrix, unlike the unsimplified Power.
    # NOTE(review): a negative exponent would give inf, not 0 — this rewrite
    # assumes that does not occur.
    if pybamm.is_scalar_zero(left):
        return pybamm.zeros_like(right)

    # anything to the power of one is itself
    if pybamm.is_scalar_one(right):
        return left

    if isinstance(left, (Multiplication, Division)):
        # Distribute the power over a product/quotient when it exposes a
        # constant factor that can then be folded
        if left.left.is_constant() or left.right.is_constant():
            l_left, l_right = left.orphans
            new_left = l_left**right
            new_right = l_right**right
            if new_left.is_constant() or new_right.is_constant():
                if isinstance(left, Multiplication):
                    return new_left * new_right
                return new_left / new_right

    return pybamm.simplify_if_constant(pybamm.Power(left, right))
예제 #16
0
 def __neg__(self):
     """
     Return ``-self`` as a :class:`Negate` object, simplified if constant
     (keeping domains).
     """
     negated = pybamm.Negate(self)
     return pybamm.simplify_if_constant(negated, keep_domains=True)
예제 #17
0
 def __ge__(self, other):
     """
     Return ``self >= other`` as ``Heaviside(other, self, equal=True)``,
     simplified if constant (keeping domains).
     """
     step = pybamm.Heaviside(other, self, equal=True)
     return pybamm.simplify_if_constant(step, keep_domains=True)
예제 #18
0
 def __rpow__(self, other):
     """
     Return ``other ** self`` as a :class:`Power` object, simplified if
     constant (keeping domains).
     """
     power = pybamm.Power(other, self)
     return pybamm.simplify_if_constant(power, keep_domains=True)
예제 #19
0
 def __rtruediv__(self, other):
     """
     Return ``other / self`` as a :class:`Division` object, simplified if
     constant (keeping domains).
     """
     quotient = pybamm.Division(other, self)
     return pybamm.simplify_if_constant(quotient, keep_domains=True)
예제 #20
0
 def __rmatmul__(self, other):
     """
     Return ``other @ self`` as a :class:`MatrixMultiplication` object,
     simplified if constant (keeping domains).
     """
     product = pybamm.MatrixMultiplication(other, self)
     return pybamm.simplify_if_constant(product, keep_domains=True)
예제 #21
0
 def __rsub__(self, other):
     """
     Return ``other - self`` as a :class:`Subtraction` object, simplified if
     constant (keeping domains).
     """
     difference = pybamm.Subtraction(other, self)
     return pybamm.simplify_if_constant(difference, keep_domains=True)
예제 #22
0
 def __radd__(self, other):
     """
     Return ``other + self`` as an :class:`Addition` object, simplified if
     constant (keeping domains).
     """
     total = pybamm.Addition(other, self)
     return pybamm.simplify_if_constant(total, keep_domains=True)
예제 #23
0
def maximum(left, right):
    """
    Return the larger of ``left`` and ``right`` as a :class:`Maximum`,
    simplified if constant (keeping domains).

    Not to be confused with :meth:`pybamm.max`, which returns the max function
    of a single child.
    """
    larger = Maximum(left, right)
    return pybamm.simplify_if_constant(larger, keep_domains=True)
예제 #24
0
 def __abs__(self):
     """
     Return ``abs(self)`` as an :class:`AbsoluteValue` object, simplified if
     constant (keeping domains).
     """
     magnitude = pybamm.AbsoluteValue(self)
     return pybamm.simplify_if_constant(magnitude, keep_domains=True)
예제 #25
0
def sech(child):
    """
    Return the hyperbolic secant of ``child``, built as ``1 / Cosh(child)``
    and simplified if constant (keeping domains).
    """
    cosh_of_child = Cosh(child)
    reciprocal = 1 / cosh_of_child
    return pybamm.simplify_if_constant(reciprocal, keep_domains=True)
예제 #26
0
def min(child):
    """
    Return the min function of ``child``: ``np.min`` wrapped in a
    :class:`Function`, simplified if constant (keeping domains).
    """
    min_of_child = Function(np.min, child)
    return pybamm.simplify_if_constant(min_of_child, keep_domains=True)
예제 #27
0
 def __getitem__(self, key):
     """
     Return ``self[key]`` as an :class:`Index` object, simplified if constant
     (keeping domains).
     """
     indexed = pybamm.Index(self, key)
     return pybamm.simplify_if_constant(indexed, keep_domains=True)
예제 #28
0
def sqrt(child):
    """
    Return the square root function of ``child`` (:class:`Sqrt`), simplified
    if constant (keeping domains).
    """
    root = Sqrt(child)
    return pybamm.simplify_if_constant(root, keep_domains=True)
예제 #29
0
def sin(child):
    """
    Return the sine function of ``child`` (:class:`Sin`), simplified if
    constant (keeping domains).
    """
    sine = Sin(child)
    return pybamm.simplify_if_constant(sine, keep_domains=True)
예제 #30
0
def tanh(child):
    """
    Return the hyperbolic tangent function of ``child`` (:class:`Tanh`),
    simplified if constant (keeping domains).
    """
    hyperbolic_tangent = Tanh(child)
    return pybamm.simplify_if_constant(hyperbolic_tangent, keep_domains=True)