Пример #1
0
    def remove_identities(self):
        """Removes equations of the form tmpX = tmpY.

        There are three different cases:

        If both sides of the equation are the same temporary, the equation can
        simply be removed.

        If both sides are different temporaries, the equation is removed and
        all occurrences of the temporary on the left-hand side in the
        subsequent equations are replaced by the temporary on the right-hand
        side. This is necessary to ensure that code generation still works.
        This case can happen when setting equivalent expressions does not work
        perfectly.

        If the operand on the left-hand side is not a temporary (e.g. an
        intermediate), and the one on the right-hand side is a temporary, the
        equation is kept, and all subsequent occurrences of the left-hand side
        are replaced with the temporary.

        Replacing the temporary in the other equations only works if the
        temporary that is replaced is not part of any computation yet.

        Replacing temporaries even works if there are sequences of assignments
        such as:
        tmp2 = tmp3
        tmp1 = tmp2

        Returns:
            Equations: self without identities, and with replaced temporaries.
        """
        kept = []
        # Operand to be eliminated -> temporary that replaces it.
        substitutions = {}
        # Replacement rules derived from `substitutions`. They are only rebuilt
        # when `substitutions` changed since the last rebuild, instead of once
        # per equation.
        rules = []
        rules_outdated = False

        for equation in self.equations:
            lhs, rhs = equation.lhs, equation.rhs
            if temporaries.is_temporary(rhs):
                if lhs == rhs:
                    # tmpX = tmpX is a no-op; drop it.
                    continue
                # If rhs is itself being replaced, map lhs directly to rhs's
                # replacement, so that chains such as tmp2 = tmp3; tmp1 = tmp2
                # collapse to tmp1 -> tmp3.
                substitutions[lhs] = substitutions.get(rhs, rhs)
                rules_outdated = True
                if temporaries.is_temporary(lhs):
                    # tmpX = tmpY: the assignment is dropped, tmpX is
                    # substituted in the following equations instead.
                    continue

            if not substitutions:
                kept.append(equation)
                continue

            if rules_outdated:
                rules = [
                    matchpy.ReplacementRule(matchpy.Pattern(old), Replacer(new))
                    for old, new in substitutions.items()
                ]
                rules_outdated = False
            kept.append(ae.Equal(lhs, matchpy.replace_all(rhs, rules)))

        return Equations(*kept)
Пример #2
0
    def remove_identities(self):
        """Removes equations where both sides are temporaries.

        There are three different cases.

        If both sides of an equation are the same, it can simply be removed.

        If both sides are different temporaries, the equation is removed and
        all occurrences of the temporary on the left-hand side in the other
        equations are replaced by the temporary on the right-hand side. This is
        necessary to ensure that code generation still works. This case can
        happen when setting equivalent expressions does not work perfectly.

        If the operand on the left-hand side is not a temporary (e.g. an
        intermediate), and the one on the right-hand side is a temporary, the
        equation is kept, and all occurrences of the left-hand side in the
        other equations are replaced with the temporary.

        Replacing the temporary in the other equations only works if the
        temporary that is replaced is not part of any computation yet.

        Chains of assignments such as
        tmp2 = tmp3
        tmp1 = tmp2
        are resolved: the right-hand side of every replacement is first
        rewritten with all earlier replacements, so tmp1 ends up replaced by
        tmp3.

        Returns:
            Equations: self with equations removed.
        """
        remove = set()
        replace_eqns = []
        equations = list(self.equations)
        for n, equation in enumerate(equations):
            if equation.rhs.name not in temporaries._equivalent_expressions:
                continue
            if equation.lhs.name in temporaries._equivalent_expressions:
                remove.add(n)
                if equation.lhs != equation.rhs:
                    replace_eqns.append(equation)
            else:
                # TODO it would be better to only do this if the operand is known to be an intermediate.
                replace_eqns.append(equation)

        equations = [
            equation for n, equation in enumerate(equations) if n not in remove
        ]

        applied_rules = []
        for replace_eqn in replace_eqns:
            # Rewrite the target with all earlier replacements so that chains
            # (tmp2 = tmp3; tmp1 = tmp2) resolve to the final temporary instead
            # of reintroducing an already eliminated one.
            target = replace_eqn.rhs
            if applied_rules:
                target = matchpy.replace_all(target, applied_rules)
            if target == replace_eqn.lhs:
                # Only possible for cyclic assignments; a rule X -> X would
                # make replace_all loop forever.
                continue
            # `target` is bound as a default argument because the rule is kept
            # in `applied_rules` beyond this iteration (avoids late binding).
            rule = matchpy.ReplacementRule(matchpy.Pattern(replace_eqn.lhs),
                                           lambda _target=target: _target)
            applied_rules.append(rule)
            equations = [
                ae.Equal(equation.lhs,
                         matchpy.replace_all(equation.rhs, (rule, )))
                for equation in equations
            ]

        # TODO do we want to manipulate the table of temporaries here?
        return Equations(*equations)
Пример #3
0
    def remove_explicit_transposition(self, eqn_idx=None):
        """Removes equations of the form tmpX = tmpY^T.

        With common subexpression elimination and the application of transposed
        operands, it is possible to reach assignments of the form tmpX = tmpY^T.
        To avoid that an explicit transposition is computed, this function
        replaces tmpX in all subsequent assignments with tmpY^T.

        Assignments of the form tmpX = tmpX^T (the operand is symmetric) are
        identities; they are removed without any replacement.

        Args:
            eqn_idx (int, optional): Only test if self.equations[eqn_idx] is an
                explicit transposition. 0 refers to the first equation and
                negative indices count from the end. If None, all equations are
                tested.

        Returns:
            Equations: self with explicit transposition removed, or self itself
            if no equation was removed.

        Raises:
            IndexError: If eqn_idx is out of range.
        """
        n_equations = len(self.equations)
        if eqn_idx is None:
            indices = range(n_equations)
        else:
            # `is None` (not truthiness) so that eqn_idx=0 is honoured.
            # Indexing a range normalizes negative indices (needed for the
            # `i in remove` / `min(remove)` logic below) and raises IndexError
            # when out of range.
            indices = [range(n_equations)[eqn_idx]]

        replacement_rules = []
        remove = set()
        for idx in indices:
            equation = self.equations[idx]
            rhs = equation.rhs
            if (temporaries.is_temporary(equation.lhs)
                    and isinstance(rhs, ae.Transpose)
                    and isinstance(rhs.operand, ae.Symbol)):
                if equation.lhs != rhs.operand:
                    replacement_rules.append(
                        matchpy.ReplacementRule(matchpy.Pattern(equation.lhs),
                                                Replacer(rhs)))
                # If lhs and operand are equal, the operand is symmetric, which
                # means that this assignment is an identity and is removed
                # without adding a replacement rule.
                remove.add(idx)

        if not remove:
            return self

        # It is sufficient to start replacing at the smallest index in remove.
        first_removed = min(remove)
        new_equations = []
        for i, equation in enumerate(self.equations):
            if i in remove:
                continue
            if i < first_removed or not replacement_rules:
                new_equations.append(equation)
            else:
                new_equations.append(
                    ae.Equal(
                        equation.lhs,
                        matchpy.replace_all(equation.rhs, replacement_rules)))
        return Equations(*new_equations)
Пример #4
0
def register(*args):
    """Adds a replacement rule to the global `replacer`.

    The positional arguments are read as ``pattern, *constraints,
    replacement``: the first one is the pattern, the last one is the
    replacement callable, and everything in between is attached to the
    pattern as constraints. Fewer than two arguments raise ValueError.
    """
    first, *middle, last = args
    new_rule = matchpy.ReplacementRule(matchpy.Pattern(first, *middle), last)
    replacer.add(new_rule)
Пример #5
0
def register(pattern: _T, replacement: typing.Callable[..., _T],
             *constraints: matchpy.Constraint) -> None:
    """Adds a replacement rule to the global `replacer`.

    Args:
        pattern: Expression to match.
        replacement: Callable that builds the replacement expression; matchpy
            passes the matched wildcards to it as keyword arguments.
        *constraints: Additional constraints attached to the pattern.
    """
    guarded_pattern = matchpy.Pattern(pattern, *constraints)
    replacer.add(matchpy.ReplacementRule(guarded_pattern, replacement))
Пример #6
0
                os.remove(path_to_file)


# Single-operand wildcards (WD*) and sequence wildcards (WS*) used by the
# linear-system rewrite rules below.
WD1 = matchpy.Wildcard.dot("WD1")
WD2 = matchpy.Wildcard.dot("WD2")
WS1 = matchpy.Wildcard.star("WS1")
WS2 = matchpy.Wildcard.star("WS2")

# The operand bound to WD1 / WD2 must be a matrix or a vector. The lambda
# parameter names must equal the wildcard names: matchpy binds by name.
PS1 = matchpy.CustomConstraint(
    lambda WD1: any(WD1.has_property(prop)
                    for prop in (Property.MATRIX, Property.VECTOR)))
PS2 = matchpy.CustomConstraint(
    lambda WD2: any(WD2.has_property(prop)
                    for prop in (Property.MATRIX, Property.VECTOR)))

# The operand bound to WD1 / WD2 must not be an inverse.
notInv1 = matchpy.CustomConstraint(lambda WD1: not is_inverse(WD1))
notInv2 = matchpy.CustomConstraint(lambda WD2: not is_inverse(WD2))

# ... * Inverse(A) * B * ...  ->  ... * LinSolveL(A, B) * ...
linsolveL = matchpy.ReplacementRule(
    matchpy.Pattern(Times(WS1, Inverse(WD1), WD2, WS2), PS1, PS2),
    lambda WS1, WD1, WD2, WS2: Times(
        *WS1,
        LinSolveL(WD1, WD2),
        *WS2,
    ),
)

# ... * InverseTranspose(A) * B * ...  ->  ... * LinSolveL(A^T, B) * ...
linsolveLT = matchpy.ReplacementRule(
    matchpy.Pattern(Times(WS1, InverseTranspose(WD1), WD2, WS2), PS1, PS2),
    lambda WS1, WD1, WD2, WS2: Times(
        *WS1,
        LinSolveL(Transpose(WD1), WD2),
        *WS2,
    ),
)

# ... * A * Inverse(B) * ...  ->  ... * LinSolveR(A, B) * ...
linsolveR = matchpy.ReplacementRule(
    matchpy.Pattern(Times(WS1, WD1, Inverse(WD2), WS2), PS1, PS2),
    lambda WS1, WD1, WD2, WS2: Times(
        *WS1,
        LinSolveR(WD1, WD2),
        *WS2,
    ),
)

# ... * A * InverseTranspose(B) * ...  ->  ... * LinSolveR(A, B^T) * ...
linsolveRT = matchpy.ReplacementRule(
    matchpy.Pattern(Times(WS1, WD1, InverseTranspose(WD2), WS2), PS1, PS2),
    lambda WS1, WD1, WD2, WS2: Times(
        *WS1,
        LinSolveR(WD1, Transpose(WD2)),
        *WS2,
    ),
)
Пример #7
0
    def _set_match(self, match_dict, input_expr):
        """Auxiliary function for set_match().

        Computes only those things that are independent of whether temporaries
        are reused or not: the kernel's inputs and outputs, its size arguments,
        the output expression, and the equivalence rules that map the output
        expression back to the matched input expression.

        Args:
            match_dict: Substitution for this kernel's pattern, indexed by
                wildcard variable names (presumably a matchpy substitution —
                confirm against set_match()).
            input_expr: The expression matched by this kernel.

        Returns:
            tuple: ``(_output_expr, _arg_dict, _partial_operand_dict,
            kernel_io)``. ``_arg_dict`` maps size-argument names to their
            values, ``_partial_operand_dict`` maps the wildcard names of
            non-overwriting outputs to the newly created operands, and
            ``kernel_io`` is the populated KernelIO.
        """

        # Constructing input.
        kernel_io = KernelIO()
        for input_operand in self.input_operands:
            kernel_io.add_input(
                input_operand.operand,
                match_dict[input_operand.operand.variable_name],
                input_operand.storage_format)

        _arg_dict = {
            arg.name: arg.get_value(match_dict)
            for arg in self.arguments
            if isinstance(arg, SizeArgument)
        }

        _partial_operand_dict = dict()

        # replacement_dict maps wildcard variable names to operands
        replacement_dict = dict()
        for output_operand in self.output_operands:
            wildcard_name = output_operand.operand.variable_name
            # The operand is a Wildcard, so its variable name starts with "_";
            # drop it and append an identifier from the temporaries module.
            name = "".join([wildcard_name[1:], temporaries.get_identifier()])
            size = (_arg_dict[output_operand.size[0]],
                    _arg_dict[output_operand.size[1]])
            # TODO what if the output is a scalar? Check sizes.

            operand = Matrix(name, size, input_expr.indices)
            operand.set_property(Property.FACTOR)
            # A fresh set per output, so outputs never share the same object.
            operand.factorization_labels = {
                io_operand[0].name for io_operand in kernel_io.input_operands
            }
            for prop in output_operand.properties:
                operand.set_property(prop)

            replacement_dict[wildcard_name] = operand

            # Constructing output.
            if output_operand.overwriting:
                kernel_io.add_output(output_operand.overwriting, operand,
                                     output_operand.storage_format)
            else:
                kernel_io.add_output(output_operand.operand, operand,
                                     output_operand.storage_format)
                _partial_operand_dict[wildcard_name] = operand

        _output_expr = matchpy.substitute(self.replacement_template,
                                          replacement_dict)

        input_equiv = temporaries.get_equivalent(input_expr)

        def add_equivalence(wrap):
            """Registers ctx1 * wrap(output) * ctx2 -> ctx1 * wrap(input_equiv) * ctx2.

            `wrap` is applied to _output_expr when the pattern is built, and to
            input_equiv each time the rule fires.
            """
            temporaries.equivalence_replacer.add(
                matchpy.ReplacementRule(
                    matchpy.Pattern(Times(ctx1, wrap(_output_expr), ctx2)),
                    lambda ctx1, ctx2: Times(*ctx1, wrap(input_equiv), *ctx2)))

        def unchanged(expr):
            return expr

        add_equivalence(unchanged)
        if input_expr.has_property(Property.SQUARE):
            add_equivalence(invert)
        # There is no need to generate transposed pattern for factorizations
        # with symmetric output; Cholesky (id 0) and Eigen (id 4).
        if self.id in {1, 2, 3, 5, 6, 7}:
            add_equivalence(transpose)
            if input_expr.has_property(Property.SQUARE):
                add_equivalence(invert_transpose)

        return _output_expr, _arg_dict, _partial_operand_dict, kernel_io
Пример #8
0
def rule(pat, call):
    """Builds a matchpy replacement rule from an unconstrained pattern.

    Args:
        pat: Pattern expression; no constraints are attached.
        call: Callable producing the replacement expression.

    Returns:
        matchpy.ReplacementRule: the pattern paired with `call`.
    """
    wrapped_pattern = matchpy.Pattern(pat)
    return matchpy.ReplacementRule(wrapped_pattern, call)