Esempio n. 1
0
    def remove_identities(self):
        """Removes equations of the form tmpX = tmpY.

        There are three different cases:

        If both sides of the equation are the same temporary, the equation can
        simply be removed.

        If both sides are different temporaries, all occurrences of the
        temporary on the left-hand side in the other equations are replaced by
        the temporary on the right-hand side. This is necessary to ensure that
        code generation still works. This case can happen when setting
        equivalent expressions does not work perfectly.

        If the operand on the left-hand side is an intermediate, and the one on
        the right-hand side is a temporary, all occurrences of the intermediate
        will be replaced with the temporary. The equation itself is kept.

        Replacing the temporary in the other equations only works if the
        temporary that is replaced is not part of any computation yet.

        Replacing temporaries even works if there are sequences of assignments
        such as:
        tmp2 = tmp3
        tmp1 = tmp2

        Returns:
            Equations: self without identities, and with replaced temporaries.
        """
        new_equations = []
        # Maps an operand that has to disappear to the temporary replacing it.
        mapping = dict()
        # Replacement rules derived from `mapping`. They are only rebuilt when
        # `mapping` changes instead of once for every remaining equation.
        rules = []
        for equation in self.equations:
            if temporaries.is_temporary(equation.rhs):
                if equation.lhs == equation.rhs:
                    # tmpX = tmpX is a pure identity: drop it.
                    continue
                # Resolve chains such as tmp2 = tmp3; tmp1 = tmp2 so that tmp1
                # is mapped directly to tmp3.
                mapping[equation.lhs] = mapping.get(equation.rhs,
                                                    equation.rhs)
                rules = [
                    matchpy.ReplacementRule(matchpy.Pattern(old),
                                            Replacer(new))
                    for old, new in mapping.items()
                ]
                if temporaries.is_temporary(equation.lhs):
                    # tmpX = tmpY: drop the equation; tmpX is replaced by its
                    # mapped temporary in all subsequent equations.
                    continue

            # `rules` is non-empty exactly when `mapping` is non-empty.
            if rules:
                new_equations.append(
                    ae.Equal(equation.lhs,
                             matchpy.replace_all(equation.rhs, rules)))
            else:
                new_equations.append(equation)

        return Equations(*new_equations)
Esempio n. 2
0
    def remove_identities(self):
        """Removes equations where both sides are temporaries.

        An operand counts as a temporary here when its name is a key of
        ``temporaries._equivalent_expressions``. Equations whose right-hand
        side is such a temporary are handled as follows:

        * Both sides are the same temporary: the equation is dropped.
        * Both sides are different temporaries: the equation is dropped, and
          every occurrence of the left-hand side in the right-hand sides of the
          remaining equations is replaced by the right-hand side. This keeps
          code generation working when setting equivalent expressions does not
          work perfectly.
        * The left-hand side is not a temporary (presumably an intermediate):
          the equation is kept, and occurrences of the left-hand side in the
          remaining right-hand sides are replaced by the right-hand side.

        Replacing only works if the replaced operand is not part of any
        computation yet. Chains of assignments such as
        tmp2 = tmp3
        tmp1 = tmp2
        are assumed not to occur.

        Returns:
            Equations: self with equations removed.
        """
        kept = []
        identities = []
        for equation in self.equations:
            if equation.rhs.name not in temporaries._equivalent_expressions:
                kept.append(equation)
            elif equation.lhs.name in temporaries._equivalent_expressions:
                # Both sides are temporaries: the equation is never kept; it
                # only drives a substitution if the two sides differ.
                if equation.lhs != equation.rhs:
                    identities.append(equation)
            else:
                # TODO it would be better to only do this if the operand is known to be an intermediate.
                identities.append(equation)
                kept.append(equation)

        for identity in identities:
            # The lambda is consumed by replace_all within this iteration, so
            # it always sees the current `identity`.
            rule = matchpy.ReplacementRule(matchpy.Pattern(identity.lhs),
                                           lambda: identity.rhs)
            kept = [
                ae.Equal(eqn.lhs, matchpy.replace_all(eqn.rhs, (rule, )))
                for eqn in kept
            ]

        # TODO do we want to manipulate the table of temporaries here?
        return Equations(*kept)
Esempio n. 3
0
    def remove_explicit_transposition(self, eqn_idx=None):
        """Removes equations of the form tmpX = tmpY^T.

        With common subexpression elimination and the application of transposed
        operands, it is possible to reach assignments of the form tmpX = tmpY^T.
        To avoid that an explicit transposition is computed, this function
        replaces tmpX in all subsequent assignments with tmpY^T.

        Assignments of the form tmpX = tmpX^T (the operand is symmetric) are
        identities; they are removed without any replacement.

        Args:
            eqn_idx (int, optional): Only test if self.equations[eqn_idx] is an
                explicit transposition. If None (the default), all equations
                are tested. 0 is a valid index.

        Returns:
            Equations: self with explicit transpositions removed, or self
                unchanged if nothing had to be removed.
        """
        # Compare against None explicitly: eqn_idx == 0 must not be treated
        # as "test all equations".
        if eqn_idx is not None:
            indices = [eqn_idx]
        else:
            indices = range(len(self.equations))

        replacement_rules = []
        remove = set()
        for idx in indices:
            equation = self.equations[idx]
            if (temporaries.is_temporary(equation.lhs)
                    and isinstance(equation.rhs, ae.Transpose)
                    and isinstance(equation.rhs.operand, ae.Symbol)):
                # If both sides are equal, the operand is symmetric, which
                # means that this assignment is an identity: it is removed,
                # but no replacement is necessary.
                if equation.lhs != equation.rhs.operand:
                    replacement_rules.append(
                        matchpy.ReplacementRule(matchpy.Pattern(equation.lhs),
                                                Replacer(equation.rhs)))
                remove.add(idx)

        # Gate on `remove`, not on `replacement_rules`: identities produce no
        # rule but still have to be dropped.
        if not remove:
            return self

        # It is sufficient to start replacing at the smallest index in remove.
        first_removed = min(remove)
        new_equations = []
        for i, equation in enumerate(self.equations):
            if i in remove:
                continue
            if i < first_removed or not replacement_rules:
                new_equations.append(equation)
            else:
                new_equations.append(
                    ae.Equal(
                        equation.lhs,
                        matchpy.replace_all(equation.rhs, replacement_rules)))
        return Equations(*new_equations)
Esempio n. 4
0
 def __init__(self, left: Expression, right: Expression) -> None:
     """Build a rewrite rule that turns ``left`` into ``right``.

     Both sides are renamed with the same variable renaming, computed from
     ``left`` by ``ManyToOneMatcher._collect_variable_renaming``.

     :param left: Expression to find.
     :param right: Expression to rewrite to. It may only use variables
         that also occur in ``left``.
     :raises ValueError: if ``right`` uses a variable that ``left`` lacks.
     """
     right_vars = get_variables(right)
     if not right_vars <= get_variables(left):
         raise ValueError(
             "Variables on right of rule with no equivalent on the left")
     renaming = ManyToOneMatcher._collect_variable_renaming(left)
     self.left, self.right = (rename_variables(left, renaming),
                              rename_variables(right, renaming))
     self.lhs = matchpy.Pattern(self.left)
Esempio n. 5
0
    def replace_auxiliaries(self):
        """Replaces auxiliaries.

        If there is an equation X = rhs where X has the property "auxiliary",
        the equation is removed and all occurrences of X in the remaining
        equations are replaced with rhs. ``self.equations`` is rebound to a new
        list of rewritten equations.
        """

        def constant(value):
            # Bind `value` now. A plain `lambda **_: eqn.rhs` looks `eqn` up
            # only when matchpy calls it, i.e. after the loops below have
            # rebound `eqn`, so every rule would return the rhs of the last
            # removed equation.
            return lambda **_: value

        replacement_rules = []
        remove = []
        for eqn in self.equations:
            if eqn.lhs.has_property(Property.AUXILIARY):
                replacement_rules.append(
                    (matchpy.Pattern(eqn.lhs), constant(eqn.rhs)))
                remove.append(eqn)

        for eqn in remove:
            self.equations.remove(eqn)

        self.equations = [
            matchpy.replace_all(equation, replacement_rules)
            for equation in self.equations
        ]
Esempio n. 6
0
#####################
# Cholesky
#
# Note: We assume that all symmetric matrices are stored as lower triangular
# matrices. Actually, that's not necessary in Julia; obtaining the other half
# is just more expensive. TODO: test which conversion is more expensive.
# Actually, with storage format conversions, we don't need this assumption.

# Wildcard for the matrix to factor; only matches symbols of type ae.Matrix.
_A = matchpy.Wildcard.symbol("_A", symbol_type=ae.Matrix)
# Wildcard for the factor; appears in the replacement Times(_L, Transpose(_L)).
_L = matchpy.Wildcard.symbol("_L")
# Cost function: N**3 / 3, with N read from the dict argument under key "N"
# (presumably the value of SizeArgument("N", ...) below -- TODO confirm).
cf = lambda d: (d["N"]**3) / 3

# Factorization kernel rewriting an SPSD matrix _A as _L _L^T.
cholesky = FactorizationKernel(
    matchpy.Pattern(
        _A,
        # Only matches operands with the SPSD property.
        matchpy.CustomConstraint(lambda _A: _A.has_property(Property.SPSD))),
    [InputOperand(_A, StorageFormat.symmetric_triangular)],
    Times(_L, Transpose(_L)),
    [
        # NOTE(review): the second argument (_A) looks like the operand whose
        # storage _L overwrites -- confirm against OutputOperand.
        OutputOperand(_L, _A, ("N", "N"),
                      [Property.LOWER_TRIANGULAR, Property.NON_SINGULAR],
                      StorageFormat.lower_triangular)
    ],
    cf,
    None,
    # Generated Julia code: LAPACK potrf! using the lower triangle ('L').
    CodeTemplate("""LinearAlgebra.LAPACK.potrf!('L', $_A)"""),
    None,
    # "N" is taken from the number of rows of the matched _A.
    [SizeArgument("N", _A, "rows")],
)
Esempio n. 7
0
# Wildcards shared by the patterns in this module.
# Wildcard.dot matches exactly one operand.
WD1 = matchpy.Wildcard.dot("WD1")
WD2 = matchpy.Wildcard.dot("WD2")
WD3 = matchpy.Wildcard.dot("WD3")
# Wildcard.symbol matches a single symbol.
SYM1 = matchpy.Wildcard.symbol("SYM1")
SYM2 = matchpy.Wildcard.symbol("SYM2")
SYM3 = matchpy.Wildcard.symbol("SYM3")
_A = matchpy.Wildcard.symbol(
    "_A"
)  # It's important that this Wildcard has the same name as the one in the pattern for Cholesky
# Wildcard.star matches zero or more operands, Wildcard.plus one or more.
WS1 = matchpy.Wildcard.star("WS1")
WP1 = matchpy.Wildcard.plus("WP1")
WP2 = matchpy.Wildcard.plus("WP2")

# Matches a sum Q^T W Q + alpha * I + (any remaining summands), where Q is
# orthogonal, W is diagonal, alpha is a scalar and I is the identity.
eigen1 = matchpy.Pattern(
    Plus(Times(Transpose(SYM1), SYM2, SYM1), Times(WD1, SYM3), WS1),
    matchpy.CustomConstraint(
        lambda SYM1, SYM2, SYM3, WD1: SYM1.has_property(Property.ORTHOGONAL)
        and SYM2.has_property(Property.DIAGONAL) and WD1.has_property(
            Property.SCALAR) and SYM3.has_property(Property.IDENTITY)))


def eigen1_callback(substitution, equations, eqn_idx, position):
    """Callback for a match of the `eigen1` pattern.

    NOTE(review): the body visible here stops after computing
    `diagonal_sum`; `equations_list`, `diagonal_sum`, `eqn_idx` and
    `position` are not used further and nothing is returned. This looks
    truncated -- confirm against the full source.

    Args:
        substitution: matchpy substitution for `eigen1`, indexed by wildcard
            name ("SYM2", "WD1", "SYM3" are read here).
        equations: object whose `equations` attribute is iterable.
        eqn_idx: presumably the index of the matched equation -- TODO confirm.
        position: presumably the position of the match in that equation --
            TODO confirm.
    """
    # Here, the "Eigen-Trick" is applied. That is, given
    # Plus([Q^T W Q + alpha I]), tmp = W + alpha I is extracted and the
    # entire expression is replaced with Times([Q^T tmp Q])
    # The sum can not be computed directly with decompose_sum
    # because it is not necessarily a sufficiently simple
    # sum.
    equations_list = list(equations.equations)

    # W + alpha * I, built from the matched diagonal, scalar and identity.
    diagonal_sum = Plus(substitution["SYM2"],
                        Times(substitution["WD1"], substitution["SYM3"]))
Esempio n. 8
0
def register(*args):
    """Add a replacement rule to the module-level ``replacer``.

    The positional arguments are the pattern expression first, the
    replacement last, and any number of matchpy constraints in between.
    Fewer than two arguments raise ValueError (from the tuple unpacking).
    """
    head, *guards, tail = args
    guarded_pattern = matchpy.Pattern(head, *guards)
    replacer.add(matchpy.ReplacementRule(guarded_pattern, tail))
Esempio n. 9
0
def register(pattern: _T, replacement: typing.Callable[..., _T],
             *constraints: matchpy.Constraint) -> None:
    """Add ``pattern`` -> ``replacement`` to the module-level ``replacer``.

    Args:
        pattern: Expression to match.
        replacement: Callable producing the replacement expression.
        *constraints: Extra matchpy constraints attached to the pattern.
    """
    constrained = matchpy.Pattern(pattern, *constraints)
    rule = matchpy.ReplacementRule(constrained, replacement)
    replacer.add(rule)
Esempio n. 10
0
                os.remove(path_to_file)


# Single-operand wildcards (Wildcard.dot matches exactly one operand).
WD1 = matchpy.Wildcard.dot("WD1")
WD2 = matchpy.Wildcard.dot("WD2")
# Sequence wildcards for the surrounding factors of a product
# (Wildcard.star matches zero or more operands).
WS1 = matchpy.Wildcard.star("WS1")
WS2 = matchpy.Wildcard.star("WS2")
# PS1 / PS2: WD1 / WD2 must be a matrix or a vector.
PS1 = matchpy.CustomConstraint(lambda WD1: WD1.has_property(Property.MATRIX) or
                               WD1.has_property(Property.VECTOR))
PS2 = matchpy.CustomConstraint(lambda WD2: WD2.has_property(Property.MATRIX) or
                               WD2.has_property(Property.VECTOR))
# notInv1 / notInv2: WD1 / WD2 must not be an inverse.
# NOTE(review): these two are not used by the rules below -- confirm whether
# they are used elsewhere or were meant to be attached to these patterns.
notInv1 = matchpy.CustomConstraint(lambda WD1: not is_inverse(WD1))
notInv2 = matchpy.CustomConstraint(lambda WD2: not is_inverse(WD2))

# ... inv(A) B ...  ->  ... LinSolveL(A, B) ...
linsolveL = matchpy.ReplacementRule(
    matchpy.Pattern(Times(WS1, Inverse(WD1), WD2, WS2), PS1, PS2),
    lambda WS1, WD1, WD2, WS2: Times(*WS1, LinSolveL(WD1, WD2), *WS2))

# ... inv(A^T) B ...  ->  ... LinSolveL(A^T, B) ...
linsolveLT = matchpy.ReplacementRule(
    matchpy.Pattern(Times(WS1, InverseTranspose(WD1), WD2, WS2), PS1,
                    PS2), lambda WS1, WD1, WD2, WS2: Times(
                        *WS1, LinSolveL(Transpose(WD1), WD2), *WS2))

# ... A inv(B) ...  ->  ... LinSolveR(A, B) ...
linsolveR = matchpy.ReplacementRule(
    matchpy.Pattern(Times(WS1, WD1, Inverse(WD2), WS2), PS1, PS2),
    lambda WS1, WD1, WD2, WS2: Times(*WS1, LinSolveR(WD1, WD2), *WS2))

# ... A inv(B^T) ...  ->  ... LinSolveR(A, B^T) ...
linsolveRT = matchpy.ReplacementRule(
    matchpy.Pattern(Times(WS1, WD1, InverseTranspose(WD2), WS2), PS1,
                    PS2), lambda WS1, WD1, WD2, WS2: Times(
                        *WS1, LinSolveR(WD1, Transpose(WD2)), *WS2))
Esempio n. 11
0
    def _set_match(self, match_dict, input_expr):
        """Auxiliary function for set_match()

        Computes only those things that are independent of whether temporaries
        are reused or not.

        Args:
            match_dict: matchpy substitution for this kernel; it is indexed by
                the wildcard variable names of the input operands.
            input_expr: The matched input expression. Its ``indices`` are
                given to the new output operands, and its equivalent
                expression is registered with
                ``temporaries.equivalence_replacer``.

        Returns:
            tuple: ``(output_expr, arg_dict, partial_operand_dict, kernel_io)``:
                the replacement template with the new output operands
                substituted; the values of all SizeArguments by name; the
                non-overwriting output operands by wildcard name; and the
                KernelIO describing inputs and outputs.

        Side effects:
            Adds replacement rules to ``temporaries.equivalence_replacer``.
        """

        # Constructing input.
        kernel_io = KernelIO()
        for input_operand in self.input_operands:
            kernel_io.add_input(
                input_operand.operand,
                match_dict[input_operand.operand.variable_name],
                input_operand.storage_format)

        _arg_dict = dict()
        for arg in self.arguments:
            if isinstance(arg, SizeArgument):
                _arg_dict[arg.name] = arg.get_value(match_dict)

        _partial_operand_dict = dict()

        # replacement_dict maps wildcard names to operands
        replacement_dict = dict()
        for output_operand in self.output_operands:
            wildcard_name = output_operand.operand.variable_name
            # wildcard_name[1:] drops the leading "_" of the wildcard name.
            name = "".join([wildcard_name[1:], temporaries.get_identifier()])
            size = (_arg_dict[output_operand.size[0]],
                    _arg_dict[output_operand.size[1]])
            # TODO what if the output is a scalar? Check sizes.

            operand = Matrix(name, size, input_expr.indices)
            operand.set_property(Property.FACTOR)
            operand.factorization_labels = {
                io_operand[0].name for io_operand in kernel_io.input_operands
            }
            for prop in output_operand.properties:
                operand.set_property(prop)

            replacement_dict[wildcard_name] = operand

            # Constructing output.
            if output_operand.overwriting:
                kernel_io.add_output(output_operand.overwriting, operand,
                                     output_operand.storage_format)
            else:
                kernel_io.add_output(output_operand.operand, operand,
                                     output_operand.storage_format)
                _partial_operand_dict[wildcard_name] = operand

        _output_expr = matchpy.substitute(self.replacement_template,
                                          replacement_dict)

        # Register rules that turn the factored form (and, where applicable,
        # its inverse/transpose variants) back into the input's equivalent.
        input_equiv = temporaries.get_equivalent(input_expr)
        temporaries.equivalence_replacer.add(
            matchpy.ReplacementRule(
                matchpy.Pattern(Times(ctx1, _output_expr, ctx2)),
                lambda ctx1, ctx2: Times(*ctx1, input_equiv, *ctx2)))
        if input_expr.has_property(Property.SQUARE):
            temporaries.equivalence_replacer.add(
                matchpy.ReplacementRule(
                    matchpy.Pattern(Times(ctx1, invert(_output_expr), ctx2)),
                    lambda ctx1, ctx2: Times(*ctx1, invert(input_equiv), *ctx2
                                             )))
        # There is no need to generate transposed pattern for factorizations
        # with symmetric output; Cholesky (id 0) and Eigen (id 4).
        if self.id in {1, 2, 3, 5, 6, 7}:
            temporaries.equivalence_replacer.add(
                matchpy.ReplacementRule(
                    matchpy.Pattern(Times(ctx1, transpose(_output_expr),
                                          ctx2)), lambda ctx1, ctx2:
                    Times(*ctx1, transpose(input_equiv), *ctx2)))
            if input_expr.has_property(Property.SQUARE):
                temporaries.equivalence_replacer.add(
                    matchpy.ReplacementRule(
                        matchpy.Pattern(
                            Times(ctx1, invert_transpose(_output_expr), ctx2)),
                        lambda ctx1, ctx2: Times(
                            *ctx1, invert_transpose(input_equiv), *ctx2)))

        return _output_expr, _arg_dict, _partial_operand_dict, kernel_io
Esempio n. 12
0
def generate_matchpy_matcher(pattern_list):
    """Build a matchpy ManyToOneMatcher holding every expression of
    ``pattern_list``, each wrapped in an unconstrained ``matchpy.Pattern``.
    """
    many_to_one = matchpy.ManyToOneMatcher()
    for wrapped in map(matchpy.Pattern, pattern_list):
        many_to_one.add(wrapped)
    return many_to_one
Esempio n. 13
0
    def generate_kernels(self):
        """Generator for ReductionKernel objects.

        For every expression variant and every combination of the remaining
        kernel variants (PropertyKV, DefaultValueKV, OperatorKV), the
        expression, arguments, constraints and KernelIO are assembled, and one
        ReductionKernel is yielded per entry of ``self.kernel_types``.

        Yields:
            ReductionKernel: Each object represents one instance of a BLAS-like
                kernel.
        """
        # TODO using dictionaries for arg_vals does not make a lot of sense
        # the lookup functionality is never used!

        types = [type(variant) for variant in self.variants]
        # iterating over expression variants
        for expr_arg, expr in self.expr_variant.arg_vals.items():

            # Shallow copies suffice: the OperatorKV branch below only rebinds
            # `argument.operand` on the copies; deepcopy is not necessary.
            arguments_copy1 = [
                copy.copy(argument) for argument in self.arguments
            ]

            if self.expr_variant.arg_name is not None:
                # This kernel has expression variants; record which one this
                # kernel is generated for.
                arguments_copy1.append(
                    ConstantArgument(self.expr_variant.arg_name, expr_arg))

            # Dealing with the remaining kernel variants
            # Step 1: Generating all combinations.
            for arg_val_pairs in itertools.product(
                    *[variant.arg_vals.items() for variant in self.variants]):
                expr_copy2 = expr
                arguments_copy2 = [
                    copy.copy(argument) for argument in arguments_copy1
                ]

                constraints_dict = dict()

                kernel_io = KernelIO()

                # Step 2: For each combination, iterating over all kernel
                # variant objects.
                for variant, _type, (kv_arg, kv_value) in zip(
                        self.variants, types, arg_val_pairs):
                    if _type == PropertyKV:
                        # Adding property constraints to constraint object.
                        # Properties can not be stored on the symbol, because
                        # the same Symbol instance is used for multiple
                        # generated kernels (the alternative would be to
                        # traverse expr_copy2, which contains a copy of the
                        # symbol).
                        constraints_dict.setdefault(variant.wildcard_name,
                                                    set()).add(kv_value)
                        arguments_copy2.append(
                            ConstantArgument(variant.arg_name, kv_arg))
                    elif _type == DefaultValueKV:
                        # replacing symbol with default value
                        if kv_arg is not None:
                            expr_copy2 = replace_symbol(
                                expr_copy2, variant.operand, kv_value)
                            kernel_io.add_input(
                                variant.operand, kv_value,
                                storage_format.StorageFormat.full)
                    elif _type == OperatorKV:
                        # replacing operator placeholder with actual operator
                        expr_copy2 = replace_operator(
                            expr_copy2, variant.operator, kv_value)
                        for argument in arguments_copy2:
                            if argument.operand:
                                argument.operand = replace_operator(
                                    argument.operand, variant.operator,
                                    kv_value)
                        arguments_copy2.append(
                            ConstantArgument(variant.arg_name, kv_arg))

                expr_copy2 = simplify(expr_copy2)

                # Distinct symbols (by name) that remain after simplification,
                # in preorder. Without recording seen names, an operand that
                # occurs several times would produce duplicate constraints.
                remaining_operands = []
                seen_before = set()
                for node, _ in expr_copy2.preorder_iter():
                    if isinstance(node,
                                  Symbol) and node.name not in seen_before:
                        seen_before.add(node.name)
                        remaining_operands.append(node)

                # Add property constraints (including Matrix, Vector and Scalar)
                # This has to be done after simplifying to make sure that only
                # those operands are included that actually show up in the
                # expression.
                constraints_list = []
                for operand in remaining_operands:
                    _wildcard_name = self.wildcards[operand.name].variable_name
                    if operand.properties or _wildcard_name in constraints_dict:
                        property_set = operand.properties.union(
                            constraints_dict.get(_wildcard_name, set()))
                        # Those properties can be removed because variables use symbol_type
                        property_set.difference_update(
                            (Property.MATRIX, Property.VECTOR,
                             Property.SCALAR))
                        if property_set:
                            constraints_list.append(
                                PropertyConstraint(_wildcard_name,
                                                   property_set))

                if self.constraints:
                    operand_names = set(op.name for op in remaining_operands)
                    renaming = {
                        name: self.wildcards[name].variable_name
                        for name in operand_names
                    }
                    # Only keep constraints whose variables all survived
                    # simplification.
                    for constraint in self.constraints:
                        if constraint.variables <= operand_names:
                            constraints_list.append(
                                constraint.with_renamed_vars(renaming))

                remaining_wildcards = [
                    self.wildcards[operand.name]
                    for operand in remaining_operands
                ]

                remaining_input_operands = [
                    input_operand for input_operand in self.input_operands
                    if input_operand.operand in remaining_wildcards
                ]

                # Replace symbols with wildcards in the expression
                expr_copy2 = replace_symbols(expr_copy2, self.wildcards)

                kernel_io.replace_variables(self.wildcards)
                # NOTE(review): the kernels yielded below for different
                # kernel_types share the same kernel_io and arguments_copy2
                # objects.
                for kernel_type in self.kernel_types:

                    if kernel_type is KernelType.scaling:
                        expr_copy3 = Times(*reversed(expr_copy2.operands))
                    else:
                        expr_copy3 = expr_copy2

                    yield ReductionKernel(
                        matchpy.Pattern(expr_copy3, *constraints_list),
                        remaining_input_operands, self.return_value,
                        self.cost_function, self.pre_code, self.signature,
                        self.post_code, arguments_copy2, kernel_type,
                        kernel_io)
Esempio n. 14
0
    def __init__(self,
                 pattern,
                 input_operands,
                 output_operand,
                 cost_function,
                 pre_code,
                 signature,
                 post_code,
                 arguments,
                 type=KernelType.identity,
                 kernel_io=None):
        """Sets up the patterns and replacement templates of the kernel.

        Args:
            pattern (matchpy.Pattern): Pattern of the operation, without
                context and without taking ``type`` into account.
            input_operands: Input operands; each has an ``operand`` attribute.
            output_operand: Describes the output operand. If its ``operand``
                also appears among the input operands, the kernel is
                overwriting (``self.is_overwriting``).
            cost_function, pre_code, signature, post_code, arguments: Passed
                on to the base class.
            type (KernelType): identity, transpose or conjugate_transpose; the
                pattern and the replacement template are adapted accordingly.
                (The name shadows the builtin but is kept for keyword callers.)
            kernel_io (KernelIO, optional): Replaced by a new KernelIO if
                falsy.
        """
        super().__init__(cost_function, pre_code, signature, post_code,
                         arguments)
        # The wildcard which represents the output operand. If it also appears
        # among the input operands, it is assumed that it gets overwritten.
        self.output_operand = output_operand

        self.input_operands = input_operands

        self.is_overwriting = any(
            input_operand.operand == self.output_operand.operand
            for input_operand in self.input_operands)

        self.kernel_io = kernel_io
        if not self.kernel_io:
            self.kernel_io = KernelIO()

        # Template for the operation that is computed. Is used to construct the
        # temporary.
        # Does not contain context.
        # Does not take type into account.
        # Example TRSV (transpose type):
        #   Times(Inverse(X_), y_)
        #   (X is wildcard for a matrix, y for a vector)
        self.operation_template = pattern

        # The pattern object to be used for matching.
        # Can be used for the matrix chain algorithm.
        # Does not contain context.
        # Takes type into account.
        # Example TRSV (transpose type):
        #   Times(Transpose(y_), InverseTranspose(X_))
        self.pattern = pattern

        # Template for the replacement.
        # Does not contain context.
        # Takes type into account.
        # Example TRSV (transpose type):
        #   Transpose(op)
        self.replacement_template = _op

        # TODO is it necessary to simplify?

        pattern_expr = self.pattern.expression

        # Modifying pattern and replacement to take type into account.
        # TODO this doesn't work for some kernels (e.g. diagsv), because
        # variables don't have properties that might be necessary to simplify
        # expressions. Those properties are part of the constraints, not the
        # expression itself.
        if type == KernelType.identity:
            pass
        elif type == KernelType.transpose:
            pattern_expr = transpose(pattern_expr)
            # TODO ugly hack
            # Operands constrained to be square diagonal or symmetric are
            # their own transpose, so the transposition is removed for them.
            for constraint in self.pattern.constraints:
                if isinstance(constraint, PropertyConstraint) and (
                    (Property.SQUARE in constraint.properties
                     and Property.DIAGONAL in constraint.properties)
                        or Property.SYMMETRIC in constraint.properties):
                    pattern_expr = remove_transpose(pattern_expr,
                                                    constraint.variable)
            self.replacement_template = transpose(self.replacement_template)
        elif type == KernelType.conjugate_transpose:
            pattern_expr = conjugate_transpose(pattern_expr)
            self.replacement_template = conjugate_transpose(
                self.replacement_template)

        self.pattern = matchpy.Pattern(pattern_expr, *self.pattern.constraints)

        # The pattern object to be used for matching.
        # Contains context.
        # Takes type into account.
        # Example TRSV (transpose type):
        #   Times(ctx1___, Transpose(y_), InverseTranspose(X_), ctx2___)
        # (Placeholder; the final value is assigned below.)
        self.pattern_with_context = pattern

        # Template for the replacement.
        # Contains context.
        # Takes type into account.
        # Example TRSV (transpose type):
        #   Times(ctx1, Transpose(op), ctx2)
        # (Placeholder; the final value is assigned below.)
        self.replacement_with_context_template = None

        # adding contexts
        # NOTE(review): the operator is checked on the original `pattern`,
        # while the context is wrapped around the type-adapted
        # self.pattern.expression -- confirm both always agree.
        if isinstance(pattern.expression, Times):
            pattern_with_context_expr = Times(ctx1, self.pattern.expression,
                                              ctx2)
            self.replacement_with_context_template = Times(
                ctx1, self.replacement_template, ctx2)
        elif isinstance(pattern.expression, Plus):
            pattern_with_context_expr = Plus(ctx1, self.pattern.expression)
            self.replacement_with_context_template = Plus(
                ctx1, self.replacement_template)
        else:  # For unary kernels
            pattern_with_context_expr = self.pattern.expression
            self.replacement_with_context_template = self.replacement_template

        self.pattern_with_context = matchpy.Pattern(pattern_with_context_expr,
                                                    *self.pattern.constraints)

        # One entry per wildcard in preorder: the properties required by its
        # PropertyConstraint (the first matching one), or an empty frozenset.
        property_list = []
        for expr, _ in self.pattern.expression.preorder_iter():
            if isinstance(expr, matchpy.Wildcard):
                for constraint in self.pattern.constraints:
                    if isinstance(
                            constraint, PropertyConstraint
                    ) and constraint.variable == expr.variable_name:
                        property_list.append(constraint.properties)
                        break
                else:
                    property_list.append(frozenset())
        self.property_tuple = PropertyTuple(property_list)
Esempio n. 15
0
def rule(pat, call):
    """Pair ``pat`` (wrapped in an unconstrained matchpy.Pattern) with the
    replacement ``call`` as a matchpy.ReplacementRule."""
    wrapped = matchpy.Pattern(pat)
    return matchpy.ReplacementRule(wrapped, call)
Esempio n. 16
0
import matchpy


#####################
# Cholesky
#
# id 0
#
# It's not possible to use uplo = 'U' here because that changes the output.

# Wildcard for the matrix to factor; only matches symbols of type ae.Matrix.
_A = matchpy.Wildcard.symbol("_A", symbol_type=ae.Matrix)
# Wildcard for the factor; appears in the replacement Times(_L, Transpose(_L)).
_L = matchpy.Wildcard.symbol("_L")

# Factorization kernel rewriting an SPSD matrix _A as _L _L^T.
cholesky = FactorizationKernel(
    # Only matches operands with the SPSD property.
    matchpy.Pattern(_A, PropertyConstraint("_A", {Property.SPSD})),
    [InputOperand(_A, StorageFormat.symmetric_lower_triangular)],
    Times(_L, Transpose(_L)),
    # NOTE(review): the second OutputOperand argument (_A) looks like the
    # operand whose storage _L overwrites -- confirm against OutputOperand.
    [OutputOperand(_L, _A, ("N", "N"), [Property.LOWER_TRIANGULAR, Property.NON_SINGULAR], StorageFormat.lower_triangular)],
    # Cost: N**3 / 3 (N presumably supplied from the SizeArgument below).
    lambda N: (N**3)/3,
    # Generated Julia code: LAPACK potrf! using the lower triangle ('L').
    CodeTemplate("""LAPACK.potrf!('L', $_A)"""),
    # "N" is taken from the number of rows of the matched _A.
    [SizeArgument("N", _A, "rows")],
    )


#####################
# LU with pivoting
#
# id 1

_A = matchpy.Wildcard.symbol("_A", symbol_type=ae.Matrix)