Ejemplo n.º 1
0
def eigen2_callback(substitution, equations, eqn_idx, position):
    # Apply the "Eigen-Trick": for a matched sum Q W Q^T + alpha I (plus any
    # further summands bound to WS1), the diagonal part tmp = W + alpha I is
    # extracted into a temporary and the matched expression is rewritten as
    # Q tmp Q^T (followed by the further summands, if any).
    # decompose_sum is deliberately not used for W + alpha I, because that
    # sum is not necessarily simple enough for it.
    equations_list = list(equations.equations)

    q_matrix = substitution["SYM1"]
    diagonal_sum = Plus(substitution["SYM2"],
                        Times(substitution["WD1"], substitution["SYM3"]))
    tmp = temporaries.create_tmp(diagonal_sum, True)

    factored = Times(q_matrix, tmp, Transpose(q_matrix))
    extra_summands = substitution["WS1"]
    # Only wrap in Plus when there is something to add to the product.
    replacement = Plus(factored, *extra_summands) if extra_summands else factored

    match_position = (1, ) + tuple(position)
    rewritten = matchpy.replace(equations_list[eqn_idx], match_position,
                                replacement)
    equations_list[eqn_idx] = simplify(rewritten)

    # The definition of tmp must precede the equation that now uses it.
    equations_list.insert(eqn_idx, Equal(tmp, diagonal_sum))
    return (Equations(*equations_list), ())
Ejemplo n.º 2
0
def apply_kernel_anywhere(expr, discrimination_net):
    """Apply the first matching kernel at any position inside an expression.

    The nodes of ``expr`` are visited in preorder. For each node, the first
    match reported by ``discrimination_net`` (if any) is turned into a
    matched kernel, and that node of ``expr`` is replaced by the kernel's
    replacement. Only one kernel is ever applied.

    ``discrimination_net`` is expected to be built from kernel patterns
    without context variables (``Kernel.pattern``).

    Args:
        expr (Expression): The expression a kernel should be applied to.
        discrimination_net (DiscriminationNet): Discrimination net over kernel
            patterns without context; its final labels must be the kernels.

    Returns:
        A tuple ``(replacement, matched_kernel)``: the rewritten expression
        (Expression) and the matched kernel (MatchedKernel). If nothing
        matched anywhere, ``(expr, None)`` with ``expr`` unchanged.
    """
    for node, pos in expr.preorder_iter():
        first_match = next(iter(discrimination_net.match(node)), None)
        if first_match is None:
            continue
        kernel, substitution = first_match
        matched_kernel = kernel.set_match(substitution, False)
        rewritten = matchpy.replace(expr, pos, matched_kernel.replacement)
        return rewritten, matched_kernel
    return (expr, None)
Ejemplo n.º 3
0
def trick4_callback(substitution, equations, eqn_idx, position):
    """Rewrite ``A^T B + B^T A + A^T C A`` (C symmetric) using a temporary.

    Uses the identity (valid because C is symmetric)
        A^T B + B^T A + A^T C A = A^T (B + 1/2 C A) + (B + 1/2 C A)^T A.
    The sum ``1/2 C A + B`` is bound to a new temporary ``tmp``. A defining
    equation ``tmp = 1/2 C A + B`` is inserted before the rewritten equation.

    Substitution keys: ``WD1`` = A, ``WD2`` = B, ``WD3`` = C, and ``WS1`` =
    any further summands of the matched sum, which are carried over unchanged.

    Args:
        substitution: Match substitution providing ``WD1``, ``WD2``, ``WD3``
            and ``WS1``.
        equations: The current ``Equations``.
        eqn_idx: Index of the equation containing the match.
        position: Position of the match inside operand 1 of the equation
            (presumably its right-hand side). Any sequence of indices, list
            or tuple, is accepted.

    Returns:
        ``(new_equations, ())``. No kernels are matched here.
    """
    equations_list = list(equations.equations)

    one_half = ConstantScalar(0.5)
    sum_expr = Plus(Times(one_half, substitution["WD3"], substitution["WD1"]),
                    substitution["WD2"])
    tmp = temporaries.create_tmp(sum_expr, True)
    new_equation = Equal(tmp, sum_expr)

    replacement = Plus(Times(Transpose(substitution["WD1"]), tmp),
                       Times(Transpose(tmp), substitution["WD1"]),
                       *substitution["WS1"])

    # tuple() makes this work for list positions as well; ``(1, ) + position``
    # raises TypeError when ``position`` is a list.
    equations_list[eqn_idx] = matchpy.replace(equations_list[eqn_idx],
                                              (1, ) + tuple(position),
                                              replacement)

    # The definition of tmp must precede the equation that now uses it.
    equations_list.insert(eqn_idx, new_equation)
    new_equations = Equations(*equations_list)

    return (new_equations, ())
Ejemplo n.º 4
0
    def apply_all(self,
                  expr: Expression,
                  max_count: Optional[int] = None) -> Expression:
        """Apply the rules to :arg:`expr` repeatedly until none matches.

        Each round walks ``expr`` in preorder. At the first subexpression for
        which the matcher reports any match, it applies the first reported
        rule, splices the result back into ``expr``, and starts a new round
        on the rewritten expression.

        :param expr: Expression to replace in.
        :param max_count: Maximum number of rule applications, if any
        :returns: Expression with rules applied as often as possible (or
            ``max_count`` times)
        :raises TypeError: if splicing a rule's result into ``expr`` does not
            produce an ``Expression``."""
        any_change = True
        apply_count = 0
        while any_change and (max_count is None or apply_count < max_count):
            any_change = False
            for subexpr, pos in expr.preorder_iter():
                # Only the matcher lookup may legitimately be empty. Using a
                # default instead of catching StopIteration keeps a
                # StopIteration raised by apply_match/replace from being
                # mistaken for "no match".
                first_match = next(iter(self.matcher.match(subexpr)), None)
                if first_match is None:
                    continue
                rule, subst = first_match
                new_subexpr = rule.apply_match(subst)
                new_expr = matchpy.replace(expr, pos, new_subexpr)
                if not isinstance(new_expr, Expression):
                    raise TypeError(
                        "Result of swapping part of an expression by an expression is not an expression"
                    )  # NOQA
                expr = new_expr
                any_change = True
                apply_count += 1
                # expr was rebuilt, so the current traversal is stale; restart.
                break
        return expr
Ejemplo n.º 5
0
def symmetric_product_callback(substitution, equations, eqn_idx, position):
    # Symmetric product: the factor _A of the matched product (between the
    # factor lists WP1 and WP2) is replaced by the replacement of the matched
    # Cholesky kernel.
    equations_list = list(equations.equations)

    # Skip this trick when the match lies inside an inverse, because Cholesky
    # is applied in that case anyway. Every node on the path from the
    # right-hand side down to (but excluding) the match is checked. The
    # matched node itself is a Times, so it needs no check.
    node = equations_list[eqn_idx].rhs
    for step in position:
        if is_inverse(node):
            return None
        node = node[step]

    cholesky_match = collections_module.cholesky.set_match(
        {"_A": substitution["_A"]}, False)

    factors = [*substitution["WP1"], cholesky_match.replacement,
               *substitution["WP2"]]
    target = (1, ) + tuple(position)
    equations_list[eqn_idx] = matchpy.replace(equations_list[eqn_idx], target,
                                              Times(*factors))

    return (Equations(*equations_list), (cholesky_match, ))
Ejemplo n.º 6
0
def apply_matrix_chain_algorithm(equations,
                                 eqn_idx,
                                 initial_pos,
                                 explicit_inversion=False):
    """Solve the matrix chain at ``initial_pos`` and yield the result.

    Generator yielding at most one ``(equations, matched_kernels)`` pair.
    The chain in equation ``eqn_idx`` is replaced by the solver's ``tmp``,
    and the equations are normalized with identities removed. If the solver
    raises ``MatrixChainNotComputable``, nothing is yielded.
    """
    original_equation = equations[eqn_idx]
    try:
        solver = matrix_chain_solver.MatrixChainSolver(
            original_equation[initial_pos], explicit_inversion)
    except matrix_chain_solver.MatrixChainNotComputable:
        return

    chain_tmp = solver.tmp
    kernels = solver.matched_kernels

    rewritten = matchpy.replace(original_equation, initial_pos, chain_tmp)
    updated = equations.set(eqn_idx, rewritten).to_normalform()

    temporaries.set_equivalent_upwards(original_equation.rhs,
                                       updated[eqn_idx].rhs)
    # set_equivalent_upwards must run before remove_identities, because
    # removing identities may shift equation indices so that eqn_idx no
    # longer refers to the rewritten equation.
    yield (updated.remove_identities(), kernels)
Ejemplo n.º 7
0
def apply_reductions(equations, eqn_idx, initial_pos):
    """Yield one rewritten ``Equations`` per applicable reduction kernel.

    Every subexpression of the node at ``initial_pos`` in equation
    ``eqn_idx`` is matched against ``collections_module.reduction_MA``. For
    each group of matches, the optimal match is selected. Matches whose
    operation right-hand side is blocked are skipped. The rest are applied
    and yielded as ``(equations_copy, (matched_kernel, ))``.
    """
    for node, relative_pos in equations[eqn_idx][initial_pos].preorder_iter():
        absolute_pos = initial_pos + relative_pos

        match_groups = collections_module.reduction_MA.match(node).grouped()
        for group in match_groups:
            kernel, substitution = select_optimal_match(group)
            matched_kernel = kernel.set_match(substitution, True)
            if is_blocked(matched_kernel.operation.rhs):
                continue

            rewritten = matchpy.replace(equations[eqn_idx], absolute_pos,
                                        matched_kernel.replacement)
            updated = equations.set(eqn_idx, rewritten).to_normalform()

            temporaries.set_equivalent(equations[eqn_idx].rhs,
                                       updated[eqn_idx].rhs)
            # set_equivalent must run before remove_identities, because
            # removing identities may shift equation indices so that eqn_idx
            # no longer refers to the rewritten equation.
            yield (updated.remove_identities(), (matched_kernel, ))
Ejemplo n.º 8
0
def apply_sum_algorithm(equations, eqn_idx, initial_pos):
    """Decompose the sum at ``initial_pos`` and yield the rewritten equations.

    Generator yielding exactly one ``(equations, matched_kernels)`` pair, as
    produced by ``decompose_sum`` followed by normalization and identity
    removal.
    """
    # Note: addition deliberately uses only a binary kernel; there is no
    # variadic addition kernel.
    target_equation = equations[eqn_idx]
    decomposed, kernels = decompose_sum(target_equation[initial_pos])

    rewritten = matchpy.replace(target_equation, initial_pos, decomposed)
    updated = equations.set(eqn_idx, rewritten)
    yield (updated.to_normalform().remove_identities(), kernels)
Ejemplo n.º 9
0
    def TR_unary_kernels(self, expression):
        """Return every single-step rewrite of ``expression`` by a unary kernel.

        Each subexpression (in preorder) is matched against
        ``collections_module.unary_kernel_DN``. When an optimal kernel is
        selected, that subexpression is replaced by the kernel's replacement.
        The result is a list of ``(transformed_expression, (matched_kernel, ))``
        pairs, one per matching subexpression.
        """
        results = []
        for subexpression, position in expression.preorder_iter():
            candidates = collections_module.unary_kernel_DN.match(subexpression)
            kernel, substitution = select_optimal_match(candidates)
            if not kernel:
                continue

            matched = kernel.set_match(substitution, False)
            rewritten = matchpy.replace(expression, position,
                                        matched.replacement)
            results.append((rewritten, (matched, )))
        return results
Ejemplo n.º 10
0
def symmetric_product_callback(substitution, equations, eqn_idx, position):
    """Replace the factor ``_A`` of a matched product by a Cholesky kernel.

    The matched product ``WP1 _A WP2`` in equation ``eqn_idx`` is rebuilt as
    ``WP1 <cholesky replacement> WP2``. The new equations are marked
    equivalent to the old ones.

    Returns:
        ``(new_equations, (cholesky_match, ))``.
    """
    equations_list = list(equations.equations)

    cholesky_match = collections_module.cholesky.set_match(
        {"_A": substitution["_A"]}, False)

    new_product = Times(*substitution["WP1"], cholesky_match.replacement,
                        *substitution["WP2"])
    target = (1, ) + tuple(position)
    equations_list[eqn_idx] = matchpy.replace(equations_list[eqn_idx], target,
                                              new_product)

    result = Equations(*equations_list)
    result.set_equivalent(equations)
    return (result, (cholesky_match, ))
Ejemplo n.º 11
0
    def apply_each_once(self, expr: Expression,
                        only: Optional[Container[RewriteRule]] = None) ->\
                       Iterable[Tuple[RewriteRule, Expression]]:  # NOQA
        """Yield every rewrite of :param:`expr` obtainable by one rule application.

        Subexpressions of ``expr`` are visited in preorder. For every rule the
        matcher reports for a subexpression, the rule is applied to that
        match and the result is spliced back into ``expr``.

        :param expr: Expression to match against.
        :param only: If present, only record matches from the given rules.
        :returns: A lazy iterator of ``(rule, new_expr)`` pairs, one per match.
            A rule that matches at several positions contributes several
            pairs, in preorder of the matched positions. This is not a
            mapping, and it is not limited to one match per rule.
        :raises TypeError: if splicing a rule's result into ``expr`` does not
            produce an ``Expression``."""
        for subexpr, pos in expr.preorder_iter():
            for rule, subst in self.matcher.match(subexpr):
                if only is not None and rule not in only:
                    continue
                new_subexpr = rule.apply_match(subst)
                new_expr = matchpy.replace(expr, pos, new_subexpr)
                if not isinstance(new_expr, Expression):
                    raise TypeError(
                        "Result of swapping part of an expression by an expression is not an expression"
                    )  # NOQA
                yield (rule, new_expr)
Ejemplo n.º 12
0
def apply_unary_kernels(equations, eqn_idx, initial_pos):
    """Yield one rewritten ``Equations`` per subexpression with a unary kernel.

    Every subexpression of the node at ``initial_pos`` in equation
    ``eqn_idx`` is matched against ``collections_module.unary_kernel_DN``.
    When an optimal kernel is selected, its replacement is spliced in, and
    the normalized equations (identities removed) are yielded as
    ``(equations_copy, (matched_kernel, ))``.
    """
    for node, relative_pos in equations[eqn_idx][initial_pos].preorder_iter():
        absolute_pos = initial_pos + relative_pos

        kernel, substitution = select_optimal_match(
            collections_module.unary_kernel_DN.match(node))
        if not kernel:
            continue

        matched_kernel = kernel.set_match(substitution, False)
        rewritten = matchpy.replace(equations[eqn_idx], absolute_pos,
                                    matched_kernel.replacement)
        updated = equations.set(eqn_idx, rewritten)
        yield (updated.to_normalform().remove_identities(), (matched_kernel, ))
Ejemplo n.º 13
0
def trick3_callback(substitution, equations, eqn_idx, position):
    """Rewrite ``A^T B + B^T A + A^T A`` so both products share one sum.

    Uses the identity
        A^T B + B^T A + A^T A = A^T (B + 1/2 A) + (B + 1/2 A)^T A.
    The sum ``1/2 A + B`` is decomposed with ``decompose_sum``, and its
    result ``tmp`` is used in both products.

    Substitution keys: ``WD1`` = A, ``WD2`` = B, and ``WS1`` = any further
    summands of the matched sum, which are carried over unchanged.

    Args:
        substitution: Match substitution providing ``WD1``, ``WD2``, ``WS1``.
        equations: The current ``Equations``.
        eqn_idx: Index of the equation containing the match.
        position: Position of the match inside operand 1 of the equation
            (presumably its right-hand side). Any sequence of indices, list
            or tuple, is accepted.

    Returns:
        ``(new_equations, matched_kernels)``, where ``matched_kernels`` are
        the kernels reported by ``decompose_sum``.
    """
    equations_list = list(equations.equations)
    one_half = ConstantScalar(0.5)
    sum_expr = Plus(Times(one_half, substitution["WD1"]), substitution["WD2"])
    tmp, matched_kernels = decompose_sum(sum_expr)

    replacement = Plus(Times(Transpose(substitution["WD1"]), tmp),
                       Times(Transpose(tmp), substitution["WD1"]),
                       *substitution["WS1"])

    # tuple() makes this work for list positions as well; ``(1, ) + position``
    # raises TypeError when ``position`` is a list.
    equations_list[eqn_idx] = matchpy.replace(equations_list[eqn_idx],
                                              (1, ) + tuple(position),
                                              replacement)

    new_equations = Equations(*equations_list)

    return (new_equations, matched_kernels)