def eigen2_callback(substitution, equations, eqn_idx, position):
    """Apply the "Eigen-Trick" to the matched sum.

    Given ``Plus([Q W Q^T + alpha I])``, the part ``tmp = W + alpha I`` is
    extracted into a new temporary and the whole matched expression is
    replaced with ``Times([Q tmp Q^T])``, plus any remaining summands
    captured in ``WS1``. A defining equation ``tmp = W + alpha I`` is
    inserted directly before the rewritten equation.

    The extracted sum is not computed via ``decompose_sum`` because it is
    not necessarily a sufficiently simple sum.

    Returns:
        ``(new_equations, ())`` -- no kernels are matched by this callback.
    """
    rows = list(equations.equations)

    diag = Plus(substitution["SYM2"],
                Times(substitution["WD1"], substitution["SYM3"]))
    tmp = temporaries.create_tmp(diag, True)

    q = substitution["SYM1"]
    conjugated = Times(q, tmp, Transpose(q))
    leftover = substitution["WS1"]
    replacement = Plus(conjugated, *leftover) if leftover else conjugated

    # Index 1 of the equation addresses its right-hand side.
    target = (1, ) + tuple(position)
    rows[eqn_idx] = simplify(matchpy.replace(rows[eqn_idx], target,
                                             replacement))
    rows.insert(eqn_idx, Equal(tmp, diag))
    return (Equations(*rows), ())
def apply_kernel_anywhere(expr, discrimination_net):
    """Apply one kernel somewhere inside an expression.

    Subexpressions of ``expr`` are visited in preorder. For the first one at
    which ``discrimination_net`` reports a match, the first reported
    ``(kernel, substitution)`` pair is used: the kernel is bound to the
    substitution and its replacement is spliced into ``expr`` at that
    position. The net is expected to be built from kernel patterns without
    context variables (``Kernel.pattern``), with the kernel as final_label.

    Args:
        expr (Expression): The expression where the kernel should be applied.
        discrimination_net (DiscriminationNet): Discrimination net for kernel
            patterns without context.

    Returns:
        A tuple ``(replacement, matched_kernel)``. If no match was found, the
        replacement is the unchanged input expression and ``matched_kernel``
        is ``None``.
    """
    for node, pos in expr.preorder_iter():
        first_match = next(iter(discrimination_net.match(node)), None)
        if first_match is None:
            continue
        kernel, substitution = first_match
        matched_kernel = kernel.set_match(substitution, False)
        rewritten = matchpy.replace(expr, pos, matched_kernel.replacement)
        return rewritten, matched_kernel
    return expr, None
def trick4_callback(substitution, equations, eqn_idx, position):
    """Rewrite ``A^T B + B^T A + A^T C A`` (C symmetric) via a temporary.

    Uses the identity

        A^T B + B^T A + A^T C A = A^T (B + 1/2 C A) + (B + 1/2 C A)^T A

    with ``WD1 = A``, ``WD2 = B`` and ``WD3 = C``. The parenthesised term is
    bound to a new temporary ``tmp`` through an equation inserted directly
    before the current one. The matched subexpression is replaced by
    ``A^T tmp + tmp^T A`` plus any remaining summands captured in ``WS1``.

    Args:
        substitution: Match substitution providing WD1, WD2, WD3 and WS1.
        equations: Current equations; ``equations.equations`` is read.
        eqn_idx: Index of the equation containing the match.
        position: Position of the match inside that equation's right-hand
            side. Any sequence (tuple or list) is accepted.

    Returns:
        ``(new_equations, ())`` -- no kernels are matched by this callback.
    """
    equations_list = list(equations.equations)

    one_half = ConstantScalar(0.5)
    sum_expr = Plus(Times(one_half, substitution["WD3"], substitution["WD1"]),
                    substitution["WD2"])
    tmp = temporaries.create_tmp(sum_expr, True)
    new_equation = Equal(tmp, sum_expr)

    replacement = Plus(Times(Transpose(substitution["WD1"]), tmp),
                       Times(Transpose(tmp), substitution["WD1"]),
                       *substitution["WS1"])
    # Normalise ``position`` to a tuple before prepending the leading 1 (the
    # equation's right-hand-side operand). Plain ``(1, ) + position`` raises
    # TypeError for list positions; the other callbacks already use tuple().
    equations_list[eqn_idx] = matchpy.replace(
        equations_list[eqn_idx], (1, ) + tuple(position), replacement)
    equations_list.insert(eqn_idx, new_equation)
    new_equations = Equations(*equations_list)
    return (new_equations, ())
def apply_all(self, expr: Expression,
              max_count: Optional[int] = None) -> Expression:
    """Apply the rules to ``expr`` repeatedly until none matches.

    On each pass, subexpressions are visited in preorder. The first
    subexpression for which the matcher yields a match is rewritten with
    that match's rule, and the walk restarts from the (new) root.

    :param expr: Expression to replace in.
    :param max_count: Maximum number of rule applications, or ``None`` for
        no limit.
    :returns: Expression with rules applied as much as possible (or until
        ``max_count`` applications have been made).
    :raises TypeError: if splicing a rewritten subexpression back into
        ``expr`` does not produce an ``Expression``.
    """
    any_change = True
    apply_count = 0
    while any_change and (max_count is None or apply_count < max_count):
        any_change = False
        for subexpr, pos in expr.preorder_iter():
            # Use a default instead of catching StopIteration around the
            # whole body. Otherwise a StopIteration escaping apply_match or
            # replace would be silently treated as "no match".
            match = next(iter(self.matcher.match(subexpr)), None)
            if match is None:
                continue
            rule, subst = match
            new_subexpr = rule.apply_match(subst)
            new_expr = matchpy.replace(expr, pos, new_subexpr)
            if not isinstance(new_expr, Expression):
                raise TypeError(
                    "Result of swapping part of an expression by an expression is not an expression"
                )  # NOQA
            expr = new_expr
            any_change = True
            apply_count += 1
            # The tree changed, so restart the preorder walk from the root.
            break
    return expr
def symmetric_product_callback(substitution, equations, eqn_idx, position):
    """Replace a matched symmetric product using the Cholesky kernel.

    The matched ``_A`` is bound to ``collections_module.cholesky`` and the
    subexpression at ``position`` is replaced by
    ``Times(*WP1, <cholesky replacement>, *WP2)``.

    The trick is skipped (``None`` is returned) when ``position`` lies inside
    an inverse, because Cholesky is applied in that case anyway.

    Returns:
        ``(new_equations, (matched_kernel,))``, or ``None`` when skipped.
    """
    rows = list(equations.equations)

    # Walk down from the right-hand side towards ``position`` and bail out
    # if any enclosing node is an inverse. The node at ``position`` itself
    # is a Times, so it is not checked.
    node = rows[eqn_idx].rhs
    for step in position:
        if is_inverse(node):
            return None
        node = node[step]

    cholesky = collections_module.cholesky.set_match(
        {"_A": substitution["_A"]}, False)
    product = Times(*substitution["WP1"], cholesky.replacement,
                    *substitution["WP2"])
    target = (1, ) + tuple(position)
    rows[eqn_idx] = matchpy.replace(rows[eqn_idx], target, product)
    return (Equations(*rows), (cholesky, ))
def apply_matrix_chain_algorithm(equations, eqn_idx, initial_pos,
                                 explicit_inversion=False):
    """Solve the matrix chain at ``initial_pos`` and yield the rewrite.

    This is a generator that yields at most one
    ``(equations, matched_kernels)`` pair. If constructing the
    ``MatrixChainSolver`` raises ``MatrixChainNotComputable``, nothing is
    yielded.
    """
    try:
        solver = matrix_chain_solver.MatrixChainSolver(
            equations[eqn_idx][initial_pos], explicit_inversion)
    except matrix_chain_solver.MatrixChainNotComputable:
        return

    chain_tmp, kernels = solver.tmp, solver.matched_kernels

    before = equations[eqn_idx]
    after = equations.set(
        eqn_idx, matchpy.replace(before, initial_pos, chain_tmp)
    ).to_normalform()
    temporaries.set_equivalent_upwards(before.rhs, after[eqn_idx].rhs)

    # remove_identities may drop equations and shift indices, so it must run
    # only after the equivalence above has been recorded via eqn_idx.
    yield (after.remove_identities(), kernels)
def apply_reductions(equations, eqn_idx, initial_pos):
    """Yield one rewritten system per applicable reduction-kernel match.

    Every subexpression of the node at ``initial_pos`` is matched against
    ``collections_module.reduction_MA``. For each group of matches, the
    match chosen by ``select_optimal_match`` is bound to its kernel. It is
    skipped when ``is_blocked`` reports its operation's right-hand side as
    blocked. Otherwise the rewritten equations and the matched kernel are
    yielded.
    """
    subtree = equations[eqn_idx][initial_pos]
    for node, rel_pos in subtree.preorder_iter():
        abs_pos = initial_pos + rel_pos
        match_groups = collections_module.reduction_MA.match(node).grouped()
        for group in match_groups:
            kernel, substitution = select_optimal_match(group)
            matched_kernel = kernel.set_match(substitution, True)
            if is_blocked(matched_kernel.operation.rhs):
                continue

            replacement = matched_kernel.replacement
            rewritten = matchpy.replace(equations[eqn_idx], abs_pos,
                                        replacement)
            candidate = equations.set(eqn_idx, rewritten).to_normalform()
            temporaries.set_equivalent(equations[eqn_idx].rhs,
                                       candidate[eqn_idx].rhs)
            # remove_identities may shift equation indices, so it has to run
            # after set_equivalent has used eqn_idx.
            yield (candidate.remove_identities(), (matched_kernel, ))
def apply_sum_algorithm(equations, eqn_idx, initial_pos):
    """Decompose the sum at ``initial_pos`` and yield the single rewrite.

    Note: for addition only a binary kernel is used; there is no variadic
    addition kernel.

    Yields:
        Exactly one ``(equations, matched_kernels)`` pair, where the
        equations are normalised and have identities removed.
    """
    target = equations[eqn_idx]
    decomposed, kernels = decompose_sum(target[initial_pos])
    rewritten = matchpy.replace(target, initial_pos, decomposed)
    updated = equations.set(eqn_idx, rewritten)
    yield (updated.to_normalform().remove_identities(), kernels)
def TR_unary_kernels(self, expression):
    """Collect one rewrite per subexpression that has a unary-kernel match.

    Each subexpression (visited in preorder) is matched against
    ``collections_module.unary_kernel_DN``. ``select_optimal_match`` picks the
    kernel, and each rewrite is applied independently to the original
    ``expression``.

    Returns:
        A list of ``(transformed_expression, (matched_kernel,))`` tuples.
    """
    results = []
    for node, pos in expression.preorder_iter():
        candidates = collections_module.unary_kernel_DN.match(node)
        kernel, substitution = select_optimal_match(candidates)
        if not kernel:
            continue
        bound = kernel.set_match(substitution, False)
        rewritten = matchpy.replace(expression, pos, bound.replacement)
        results.append((rewritten, (bound, )))
    return results
def symmetric_product_callback(substitution, equations, eqn_idx, position):
    """Replace a matched symmetric product using the Cholesky kernel.

    The matched ``_A`` is bound to ``collections_module.cholesky``. The
    subexpression at ``position`` (inside the right-hand side of equation
    ``eqn_idx``) is replaced by ``Times(*WP1, <cholesky replacement>, *WP2)``.
    The new equations are then marked equivalent to the input equations.

    Returns:
        ``(new_equations, (matched_kernel,))``.
    """
    rows = list(equations.equations)
    cholesky = collections_module.cholesky.set_match(
        {"_A": substitution["_A"]}, False)
    factors = [*substitution["WP1"], cholesky.replacement,
               *substitution["WP2"]]

    target = (1, ) + tuple(position)
    rows[eqn_idx] = matchpy.replace(rows[eqn_idx], target, Times(*factors))

    result = Equations(*rows)
    result.set_equivalent(equations)
    return (result, (cholesky, ))
def apply_each_once(
        self, expr: Expression,
        only: Optional[Container[RewriteRule]] = None,
) -> Iterable[Tuple[RewriteRule, Expression]]:  # NOQA
    """Lazily yield every single-step rewrite of ``expr``.

    Subexpressions of ``expr`` are visited in preorder. For each match the
    matcher reports there, the rule is applied to that match and the result
    is spliced into the *original* ``expr``. Rewrites do not accumulate:
    each yielded expression differs from ``expr`` at exactly one position.

    :param expr: Expression to match against.
    :param only: If present, only record matches from the given rules.
    :returns: An iterator of ``(rule, new_expr)`` pairs, one per match. A
        rule that matches several times is yielded once per match, in
        preorder, so its outermost match comes first.
    :raises TypeError: if splicing a rewritten subexpression back into
        ``expr`` does not produce an ``Expression``.
    """
    for subexpr, pos in expr.preorder_iter():
        for rule, subst in self.matcher.match(subexpr):
            # Filter before applying, so excluded rules cost no rewrite work.
            if only is not None and rule not in only:
                continue
            new_subexpr = rule.apply_match(subst)
            new_expr = matchpy.replace(expr, pos, new_subexpr)
            if not isinstance(new_expr, Expression):
                raise TypeError(
                    "Result of swapping part of an expression by an expression is not an expression"
                )  # NOQA
            yield (rule, new_expr)
def apply_unary_kernels(equations, eqn_idx, initial_pos):
    """Yield one rewritten system per unary-kernel match below ``initial_pos``.

    Every subexpression of the node at ``initial_pos`` (visited in preorder)
    is matched against ``collections_module.unary_kernel_DN``, and
    ``select_optimal_match`` picks the kernel. Each match is applied
    independently to the original equations, which are then normalised and
    have identities removed.

    Yields:
        ``(equations, (matched_kernel,))`` pairs.
    """
    subtree = equations[eqn_idx][initial_pos]
    for node, rel_pos in subtree.preorder_iter():
        abs_pos = initial_pos + rel_pos
        kernel, substitution = select_optimal_match(
            collections_module.unary_kernel_DN.match(node))
        if not kernel:
            continue
        bound = kernel.set_match(substitution, False)
        replacement = bound.replacement
        rewritten = matchpy.replace(equations[eqn_idx], abs_pos, replacement)
        updated = equations.set(eqn_idx, rewritten)
        yield (updated.to_normalform().remove_identities(), (bound, ))
def trick3_callback(substitution, equations, eqn_idx, position):
    """Rewrite ``A^T B + B^T A + A^T A`` via a shared subexpression.

    Uses the identity

        A^T B + B^T A + A^T A = A^T (B + 1/2 A) + (B + 1/2 A)^T A

    with ``WD1 = A`` and ``WD2 = B``. The parenthesised sum is computed with
    ``decompose_sum``, and the matched subexpression is replaced by
    ``A^T tmp + tmp^T A`` plus any remaining summands captured in ``WS1``.

    Args:
        substitution: Match substitution providing WD1, WD2 and WS1.
        equations: Current equations; ``equations.equations`` is read.
        eqn_idx: Index of the equation containing the match.
        position: Position of the match inside that equation's right-hand
            side. Any sequence (tuple or list) is accepted.

    Returns:
        ``(new_equations, matched_kernels)`` where the kernels are those
        reported by ``decompose_sum``.
    """
    equations_list = list(equations.equations)

    one_half = ConstantScalar(0.5)
    sum_expr = Plus(Times(one_half, substitution["WD1"]), substitution["WD2"])
    tmp, matched_kernels = decompose_sum(sum_expr)

    replacement = Plus(Times(Transpose(substitution["WD1"]), tmp),
                       Times(Transpose(tmp), substitution["WD1"]),
                       *substitution["WS1"])
    # Normalise ``position`` to a tuple before prepending the leading 1 (the
    # equation's right-hand-side operand). Plain ``(1, ) + position`` raises
    # TypeError for list positions; the other callbacks already use tuple().
    equations_list[eqn_idx] = matchpy.replace(
        equations_list[eqn_idx], (1, ) + tuple(position), replacement)
    new_equations = Equations(*equations_list)
    return (new_equations, matched_kernels)