def remove_identities(self):
    """Drops identity assignments and propagates the resulting renames.

    Only equations whose right-hand side is a temporary are special. There
    are three cases:

    - ``tmpX = tmpX``: the equation is dropped.
    - ``tmpX = tmpY``: the equation is dropped, and ``tmpX`` is replaced
      with ``tmpY`` on the right-hand side of the equations that follow.
      This is necessary to ensure that code generation still works. This
      case can happen when setting equivalent expressions does not work
      perfectly.
    - ``X = tmpY`` where ``X`` is not a temporary (e.g. an intermediate):
      the equation is kept, and ``X`` is replaced with ``tmpY`` on the
      right-hand side of the equations that follow.

    Replacing an operand in the other equations only works if the operand
    that is replaced is not part of any computation yet.

    Sequences of assignments such as

        tmp2 = tmp3
        tmp1 = tmp2

    are supported: ``tmp1`` is mapped directly to ``tmp3``.

    Returns:
        Equations: A new object without identities, and with replaced
            operands.
    """
    substitutions = {}
    kept = []

    for equation in self.equations:
        if temporaries.is_temporary(equation.rhs):
            if equation.lhs != equation.rhs:
                # If the right-hand side is itself scheduled for
                # replacement, map directly to its final target.
                substitutions[equation.lhs] = substitutions.get(
                    equation.rhs, equation.rhs)
                if temporaries.is_temporary(equation.lhs):
                    continue
            else:
                continue

        if not substitutions:
            kept.append(equation)
            continue

        rules = [
            matchpy.ReplacementRule(matchpy.Pattern(old), Replacer(new))
            for old, new in substitutions.items()
        ]
        rewritten_rhs = matchpy.replace_all(equation.rhs, rules)
        kept.append(ae.Equal(equation.lhs, rewritten_rhs))

    return Equations(*kept)
def remove_identities(self):
    """Removes equations where both sides are temporaries.

    An operand is treated as a temporary if its name is contained in
    ``temporaries._equivalent_expressions``. For equations with a temporary
    on the right-hand side, there are three cases:

    - Both sides are the same temporary: the equation is removed.
    - Both sides are different temporaries: the equation is removed, and
      the temporary on the left-hand side is replaced with the one on the
      right-hand side in the right-hand sides of the remaining equations.
      This is necessary to ensure that code generation still works. This
      case can happen when setting equivalent expressions does not work
      perfectly.
    - The left-hand side is not a temporary (e.g. an intermediate): the
      equation is kept, and the left-hand side operand is replaced with
      the temporary in the right-hand sides of the remaining equations.

    Replacing an operand in the other equations only works if the operand
    that is replaced is not part of any computation yet. Furthermore, it is
    assumed that there are no cases such as

        tmp2 = tmp3
        tmp1 = tmp2

    Returns:
        Equations: A new object with equations removed and operands
            replaced.
    """
    known_temporaries = temporaries._equivalent_expressions
    equations = list(self.equations)

    dropped = set()
    substitutions = []
    for position, equation in enumerate(equations):
        if equation.rhs.name not in known_temporaries:
            continue
        if equation.lhs.name in known_temporaries:
            dropped.add(position)
            if equation.lhs != equation.rhs:
                substitutions.append(equation)
        else:
            # TODO it would be better to only do this if the operand is
            # known to be an intermediate.
            substitutions.append(equation)

    equations = [
        equation for position, equation in enumerate(equations)
        if position not in dropped
    ]

    for source in substitutions:
        rule = matchpy.ReplacementRule(matchpy.Pattern(source.lhs),
                                       lambda: source.rhs)
        equations = [
            ae.Equal(equation.lhs,
                     matchpy.replace_all(equation.rhs, (rule, )))
            for equation in equations
        ]

    # TODO do we want to manipulate the table of temporaries here?
    return Equations(*equations)
def remove_explicit_transposition(self, eqn_idx=None):
    """Removes equations of the form tmpX = tmpY^T.

    With common subexpression elimination and the application of
    transposed operands, it is possible to reach assignments of the form
    tmpX = tmpY^T. To avoid that an explicit transposition is computed,
    this function replaces tmpX in all subsequent assignments with tmpY^T.

    Args:
        eqn_idx (int, optional): Only test if self.equations[eqn_idx] is an
            explicit transposition. If None (the default), all equations
            are tested. Index 0 is a valid value.

    Returns:
        Equations: A new object with explicit transpositions removed, or
            self if no replacement is necessary.
    """
    # Compare with None explicitly: eqn_idx == 0 is a valid index and must
    # not fall back to testing all equations.
    if eqn_idx is not None:
        indices = [eqn_idx]
    else:
        indices = range(len(self.equations))

    replacement_rules = []
    remove = set()
    for idx in indices:
        equation = self.equations[idx]
        if (temporaries.is_temporary(equation.lhs)
                and isinstance(equation.rhs, ae.Transpose)
                and isinstance(equation.rhs.operand, ae.Symbol)):
            if equation.lhs != equation.rhs.operand:
                # If they are equal, the operand is symmetric, which means
                # that this assignment is an identity and can be removed.
                replacement_rules.append(
                    matchpy.ReplacementRule(matchpy.Pattern(equation.lhs),
                                            Replacer(equation.rhs)))
            remove.add(idx)

    if not replacement_rules:
        return self

    # It is sufficient to start replacing at the smallest index in remove.
    first_removed = min(remove)
    new_equations = []
    for i, equation in enumerate(self.equations):
        if i in remove:
            continue
        if i < first_removed:
            new_equations.append(equation)
        else:
            new_equations.append(
                ae.Equal(
                    equation.lhs,
                    matchpy.replace_all(equation.rhs, replacement_rules)))
    return Equations(*new_equations)
def __init__(self, left: Expression, right: Expression) -> None:
    """
    Store a rewrite rule from ``left`` to ``right`` with renamed variables.

    Both sides are renamed with the same variable renaming, which is
    collected from ``left``. ``self.lhs`` holds the renamed left side
    wrapped in a matchpy pattern.

    :param left: Expression to find
    :param right: Expression to rewrite to. Its variables must be a subset
        of the variables of :arg:`left`
    :raises ValueError: if ``right`` uses a variable that does not occur in
        ``left``
    """
    right_variables = get_variables(right)
    left_variables = get_variables(left)
    if not right_variables <= left_variables:
        raise ValueError(
            "Variables on right of rule with no equivalent on the left")

    renaming = ManyToOneMatcher._collect_variable_renaming(left)
    self.left = rename_variables(left, renaming)
    self.right = rename_variables(right, renaming)
    self.lhs = matchpy.Pattern(self.left)
def replace_auxiliaries(self):
    """Replaces auxiliaries.

    If there is an equation X = rhs where X has the property "auxiliary",
    the equation is removed and all occurrences of X in other equations
    are replaced with rhs.

    self.equations is replaced with the list of rewritten equations;
    nothing is returned.
    """

    def constant(value):
        # Bind value when the rule is created. A lambda that reads the
        # loop variable instead would see its last value for every rule,
        # so all auxiliaries would be replaced with the same rhs.
        return lambda **_: value

    replacement_rules = []
    remaining = []
    for eqn in self.equations:
        if eqn.lhs.has_property(Property.AUXILIARY):
            replacement_rules.append(
                (matchpy.Pattern(eqn.lhs), constant(eqn.rhs)))
        else:
            remaining.append(eqn)

    self.equations = [
        matchpy.replace_all(equation, replacement_rules)
        for equation in remaining
    ]
##################### # Cholesky # # Note: We assume that all symmetric matrices are stored as lower triangular # matrices. Actually, that's not necessary in Julia, obtaining the other half is # just more expensive. Test which conversion is more expensive. # Actually, with storage format conversions, we don't need this assumption. _A = matchpy.Wildcard.symbol("_A", symbol_type=ae.Matrix) _L = matchpy.Wildcard.symbol("_L") cf = lambda d: (d["N"]**3) / 3 cholesky = FactorizationKernel( matchpy.Pattern( _A, matchpy.CustomConstraint(lambda _A: _A.has_property(Property.SPSD))), [InputOperand(_A, StorageFormat.symmetric_triangular)], Times(_L, Transpose(_L)), [ OutputOperand(_L, _A, ("N", "N"), [Property.LOWER_TRIANGULAR, Property.NON_SINGULAR], StorageFormat.lower_triangular) ], cf, None, CodeTemplate("""LinearAlgebra.LAPACK.potrf!('L', $_A)"""), None, [SizeArgument("N", _A, "rows")], )
# Wildcards shared by the patterns below.
WD1 = matchpy.Wildcard.dot("WD1")
WD2 = matchpy.Wildcard.dot("WD2")
WD3 = matchpy.Wildcard.dot("WD3")
SYM1 = matchpy.Wildcard.symbol("SYM1")
SYM2 = matchpy.Wildcard.symbol("SYM2")
SYM3 = matchpy.Wildcard.symbol("SYM3")
# It's important that this Wildcard has the same name as the one in the
# pattern for Cholesky.
_A = matchpy.Wildcard.symbol("_A")
WS1 = matchpy.Wildcard.star("WS1")
WP1 = matchpy.Wildcard.plus("WP1")
WP2 = matchpy.Wildcard.plus("WP2")

# Matches sums that contain SYM1^T SYM2 SYM1 and WD1 SYM3 (plus arbitrary
# further summands WS1), where SYM1 is orthogonal, SYM2 is diagonal, WD1 is a
# scalar and SYM3 is an identity.
eigen1 = matchpy.Pattern(
    Plus(Times(Transpose(SYM1), SYM2, SYM1), Times(WD1, SYM3), WS1),
    matchpy.CustomConstraint(
        lambda SYM1, SYM2, SYM3, WD1: SYM1.has_property(Property.ORTHOGONAL)
        and SYM2.has_property(Property.DIAGONAL) and WD1.has_property(
            Property.SCALAR) and SYM3.has_property(Property.IDENTITY)))


def eigen1_callback(substitution, equations, eqn_idx, position):
    # Here, the "Eigen-Trick" is applied. That is, given
    # Plus([Q^T W Q + alpha I]), tmp = W + alpha I is extracted and the
    # entire expression is replaced with Times([Q^T tmp Q]).
    # The sum can not be computed directly with decompose_sum
    # because it is not necessarily a sufficiently simple
    # sum.

    equations_list = list(equations.equations)

    # The sum W + alpha I, built from the operands bound by eigen1.
    diagonal_sum = Plus(substitution["SYM2"],
                        Times(substitution["WD1"], substitution["SYM3"]))
def register(*args):
    """Add a replacement rule to ``replacer``.

    The positional arguments are ``pattern, *constraints, replacement``:
    the first one is the expression to match, the last one is the
    replacement, and everything in between is passed to the pattern as
    constraints. At least two arguments are required.
    """
    pattern, *constraints, replacement = args
    constrained_pattern = matchpy.Pattern(pattern, *constraints)
    new_rule = matchpy.ReplacementRule(constrained_pattern, replacement)
    replacer.add(new_rule)
def register(pattern: _T, replacement: typing.Callable[..., _T],
             *constraints: matchpy.Constraint) -> None:
    """Add a replacement rule to ``replacer``.

    Args:
        pattern: Expression to match.
        replacement: Callable that is stored as the replacement of the
            rule.
        *constraints: Constraints that are attached to the pattern.
    """
    constrained_pattern = matchpy.Pattern(pattern, *constraints)
    new_rule = matchpy.ReplacementRule(constrained_pattern, replacement)
    replacer.add(new_rule)
# NOTE(review): presumably the last statement of a definition that starts
# before this excerpt; confirm its indentation against the full file.
os.remove(path_to_file)

# Wildcards for the rules below: WD1 and WD2 each match one operand, WS1 and
# WS2 match the (possibly empty) remaining factors of a product.
WD1 = matchpy.Wildcard.dot("WD1")
WD2 = matchpy.Wildcard.dot("WD2")
WS1 = matchpy.Wildcard.star("WS1")
WS2 = matchpy.Wildcard.star("WS2")

# WD1 / WD2 have to be a matrix or a vector.
PS1 = matchpy.CustomConstraint(lambda WD1: WD1.has_property(Property.MATRIX)
                               or WD1.has_property(Property.VECTOR))
PS2 = matchpy.CustomConstraint(lambda WD2: WD2.has_property(Property.MATRIX)
                               or WD2.has_property(Property.VECTOR))
# WD1 / WD2 must not be an inverse. These constraints are not used by the
# rules defined in this section.
notInv1 = matchpy.CustomConstraint(lambda WD1: not is_inverse(WD1))
notInv2 = matchpy.CustomConstraint(lambda WD2: not is_inverse(WD2))

# Inverse(WD1) WD2 -> LinSolveL(WD1, WD2)
linsolveL = matchpy.ReplacementRule(
    matchpy.Pattern(Times(WS1, Inverse(WD1), WD2, WS2), PS1, PS2),
    lambda WS1, WD1, WD2, WS2: Times(*WS1, LinSolveL(WD1, WD2), *WS2))
# InverseTranspose(WD1) WD2 -> LinSolveL(Transpose(WD1), WD2)
linsolveLT = matchpy.ReplacementRule(
    matchpy.Pattern(Times(WS1, InverseTranspose(WD1), WD2, WS2), PS1, PS2),
    lambda WS1, WD1, WD2, WS2: Times(
        *WS1, LinSolveL(Transpose(WD1), WD2), *WS2))
# WD1 Inverse(WD2) -> LinSolveR(WD1, WD2)
linsolveR = matchpy.ReplacementRule(
    matchpy.Pattern(Times(WS1, WD1, Inverse(WD2), WS2), PS1, PS2),
    lambda WS1, WD1, WD2, WS2: Times(*WS1, LinSolveR(WD1, WD2), *WS2))
# WD1 InverseTranspose(WD2) -> LinSolveR(WD1, Transpose(WD2))
linsolveRT = matchpy.ReplacementRule(
    matchpy.Pattern(Times(WS1, WD1, InverseTranspose(WD2), WS2), PS1, PS2),
    lambda WS1, WD1, WD2, WS2: Times(
        *WS1, LinSolveR(WD1, Transpose(WD2)), *WS2))
def _set_match(self, match_dict, input_expr):
    """Auxiliary function for set_match()

    Computes only those things that are independent of whether temporaries
    are reused or not.

    As a side effect, rules are added to temporaries.equivalence_replacer
    that rewrite the output expression (and, depending on the kernel and
    on the properties of input_expr, its inverse, transpose and inverse
    transpose) back to the equivalent of input_expr.

    Args:
        match_dict: Maps the variable names of the wildcards of this kernel
            to the matched operands.
        input_expr: The matched input expression. Its indices are used for
            the new output operands.

    Returns:
        tuple: (_output_expr, _arg_dict, _partial_operand_dict, kernel_io),
            i.e. the replacement template with the new operands
            substituted, the values of all size arguments by name, the new
            output operands by wildcard name, and the KernelIO object.
    """

    # Constructing input.
    kernel_io = KernelIO()
    for input_operand in self.input_operands:
        kernel_io.add_input(
            input_operand.operand,
            match_dict[input_operand.operand.variable_name],
            input_operand.storage_format)

    # Values of all size arguments, by argument name.
    _arg_dict = dict()
    for arg in self.arguments:
        if isinstance(arg, SizeArgument):
            _arg_dict[arg.name] = arg.get_value(match_dict)

    _partial_operand_dict = dict()

    # replacement_dict maps wildcard names to operands
    replacement_dict = dict()
    for output_operand in self.output_operands:
        # output_operand.operand.name[1:] because it's a Wildcard, we drop the _
        name = "".join([
            output_operand.operand.variable_name[1:],
            temporaries.get_identifier()
        ])
        size = (_arg_dict[output_operand.size[0]],
                _arg_dict[output_operand.size[1]])
        # TODO what if the output is a scalar? Check sizes.
        operand = Matrix(name, size, input_expr.indices)
        operand.set_property(Property.FACTOR)
        # The loop variable of the generator expression is local to it
        # (Python 3 scoping), so the new output operand is not overwritten.
        operand.factorization_labels = set(
            operand[0].name for operand in kernel_io.input_operands)
        for property in output_operand.properties:
            operand.set_property(property)

        replacement_dict[output_operand.operand.variable_name] = operand

        # Constructing output.
        if output_operand.overwriting:
            kernel_io.add_output(output_operand.overwriting, operand,
                                 output_operand.storage_format)
        else:
            kernel_io.add_output(output_operand.operand, operand,
                                 output_operand.storage_format)

        _partial_operand_dict[
            output_operand.operand.variable_name] = operand

    _output_expr = matchpy.substitute(self.replacement_template,
                                      replacement_dict)

    # Register rules that map the factorized form back to the equivalent of
    # the input expression.
    input_equiv = temporaries.get_equivalent(input_expr)
    temporaries.equivalence_replacer.add(
        matchpy.ReplacementRule(
            matchpy.Pattern(Times(ctx1, _output_expr, ctx2)),
            lambda ctx1, ctx2: Times(*ctx1, input_equiv, *ctx2)))
    if input_expr.has_property(Property.SQUARE):
        temporaries.equivalence_replacer.add(
            matchpy.ReplacementRule(
                matchpy.Pattern(Times(ctx1, invert(_output_expr), ctx2)),
                lambda ctx1, ctx2: Times(*ctx1, invert(input_equiv), *ctx2
                                         )))

    # There is no need to generate transposed pattern for factorizations
    # with symmetric output; Cholesky (id 0) and Eigen (id 4).
    if self.id in {1, 2, 3, 5, 6, 7}:
        temporaries.equivalence_replacer.add(
            matchpy.ReplacementRule(
                matchpy.Pattern(Times(ctx1, transpose(_output_expr), ctx2)),
                lambda ctx1, ctx2: Times(*ctx1, transpose(input_equiv),
                                         *ctx2)))
        if input_expr.has_property(Property.SQUARE):
            temporaries.equivalence_replacer.add(
                matchpy.ReplacementRule(
                    matchpy.Pattern(
                        Times(ctx1, invert_transpose(_output_expr), ctx2)),
                    lambda ctx1, ctx2: Times(
                        *ctx1, invert_transpose(input_equiv), *ctx2)))

    return _output_expr, _arg_dict, _partial_operand_dict, kernel_io
def generate_matchpy_matcher(pattern_list):
    """Build a many-to-one matcher for a collection of expressions.

    Args:
        pattern_list: Iterable of expressions. Each one is wrapped in a
            matchpy.Pattern (without constraints) and added to the matcher
            in iteration order.

    Returns:
        matchpy.ManyToOneMatcher: The matcher that contains all patterns.
    """
    many_to_one = matchpy.ManyToOneMatcher()
    for wrapped in map(matchpy.Pattern, pattern_list):
        many_to_one.add(wrapped)
    return many_to_one
def generate_kernels(self):
    """Generator for ReductionKernel objects.

    For every expression variant and every combination of values of the
    remaining kernel variants, the expression is specialized, simplified,
    and turned into a pattern with property constraints. One kernel is
    yielded per combination and per entry of self.kernel_types.

    Yields:
        ReductionKernel: Each object represents one instance of a BLAS-like
            kernel.
    """
    # TODO using dictionaries for arg_vals does not make a lot of sense
    # the lookup functionality is never used!

    types = [type(variant) for variant in self.variants]

    # iterating over expression variants
    for expr_arg, expr in self.expr_variant.arg_vals.items():
        # copy is not enough, deepcopy is not necessary
        arguments_copy1 = [copy.copy(arg) for arg in self.arguments]
        if self.expr_variant.arg_name is not None:
            # This kernel does have expression variants.
            # TODO I don't like the previous comment.
            arguments_copy1.append(
                ConstantArgument(self.expr_variant.arg_name, expr_arg))

        # Dealing with the remaining kernel variants

        # Step 1: Generating all combinations.
        for arg_val_pairs in itertools.product(
                *[variant.arg_vals.items() for variant in self.variants]):

            expr_copy2 = expr
            # copy is not enough, deepcopy is not necessary
            arguments_copy2 = [copy.copy(arg) for arg in arguments_copy1]
            constraints_dict = dict()
            kernel_io = KernelIO()

            # Step 2: For each combination, iterating over all kernel
            # variant objects.
            for i, _type in enumerate(types):
                if _type == PropertyKV:
                    # Adding property constraints to constraint object
                    arg, prop = arg_val_pairs[i]
                    # Properties can not be stored on the symbol, because
                    # the same Symbol instance is used for multiple
                    # generated kernels (the alternative would be to
                    # traverse expr_copy2, which contains a copy of the
                    # symbol).
                    constraints_dict.setdefault(
                        self.variants[i].wildcard_name, set()).add(prop)
                    arguments_copy2.append(
                        ConstantArgument(self.variants[i].arg_name, arg))
                elif _type == DefaultValueKV:
                    # replacing symbol with default value
                    arg, value = arg_val_pairs[i]
                    if arg is not None:
                        expr_copy2 = replace_symbol(
                            expr_copy2, self.variants[i].operand, value)
                        kernel_io.add_input(
                            self.variants[i].operand, value,
                            storage_format.StorageFormat.full)
                elif _type == OperatorKV:
                    # replacing operator placeholder with actual operator
                    arg, new_operator = arg_val_pairs[i]
                    expr_copy2 = replace_operator(
                        expr_copy2, self.variants[i].operator, new_operator)
                    for argument in arguments_copy2:
                        if argument.operand:
                            argument.operand = replace_operator(
                                argument.operand, self.variants[i].operator,
                                new_operator)
                    arguments_copy2.append(
                        ConstantArgument(self.variants[i].arg_name, arg))

            expr_copy2 = simplify(expr_copy2)

            # Collect every operand once, in preorder. Recording the name
            # is what prevents an operand that occurs several times in the
            # expression from producing duplicate constraints below.
            remaining_operands = []
            seen_before = set()
            for node, _ in expr_copy2.preorder_iter():
                if isinstance(node, Symbol) and node.name not in seen_before:
                    seen_before.add(node.name)
                    remaining_operands.append(node)

            # Add property constraints (including Matrix, Vector and Scalar)
            # This has to be done after simplifying to make sure that only
            # those operands are included that actually show up in the
            # expression.
            constraints_list = []
            for operand in remaining_operands:
                _wildcard_name = self.wildcards[operand.name].variable_name
                if operand.properties or _wildcard_name in constraints_dict:
                    property_set = operand.properties.union(
                        constraints_dict.get(_wildcard_name, set()))
                    # Those properties can be removed because variables use
                    # symbol_type
                    property_set.difference_update(
                        (Property.MATRIX, Property.VECTOR, Property.SCALAR))
                    if property_set:
                        constraints_list.append(
                            PropertyConstraint(_wildcard_name, property_set))

            if self.constraints:
                operand_names = set(op.name for op in remaining_operands)
                renaming = {
                    name: self.wildcards[name].variable_name
                    for name in operand_names
                }
                # Only keep constraints whose variables all still occur in
                # the expression.
                for constraint in self.constraints:
                    if constraint.variables <= operand_names:
                        constraints_list.append(
                            constraint.with_renamed_vars(renaming))

            remaining_wildcards = [
                self.wildcards[operand.name]
                for operand in remaining_operands
            ]

            remaining_input_operands = []
            for input_operand in self.input_operands:
                if input_operand.operand in remaining_wildcards:
                    remaining_input_operands.append(input_operand)

            # Replace symbols with wildcards in the expression
            expr_copy2 = replace_symbols(expr_copy2, self.wildcards)

            kernel_io.replace_variables(self.wildcards)

            for kernel_type in self.kernel_types:
                if kernel_type is KernelType.scaling:
                    # For scaling kernels, the operands are reversed.
                    expr_copy3 = Times(*reversed(expr_copy2.operands))
                else:
                    expr_copy3 = expr_copy2

                yield ReductionKernel(
                    matchpy.Pattern(expr_copy3, *constraints_list),
                    remaining_input_operands, self.return_value,
                    self.cost_function, self.pre_code, self.signature,
                    self.post_code, arguments_copy2, kernel_type, kernel_io)
def __init__(self,
             pattern,
             input_operands,
             output_operand,
             cost_function,
             pre_code,
             signature,
             post_code,
             arguments,
             type=KernelType.identity,
             kernel_io=None):
    """Derives all patterns and replacement templates of the kernel.

    Args:
        pattern: Pattern of the operation that is computed. Its expression
            and constraints are read.
        input_operands: Input operands of the kernel.
        output_operand: Output operand of the kernel. If its operand is
            equal to the operand of one of the input operands,
            is_overwriting is set to True.
        cost_function, pre_code, signature, post_code, arguments: Passed
            on unchanged to the constructor of the base class.
        type (KernelType, optional): For transpose and conjugate_transpose,
            pattern and replacement are (conjugate) transposed. For all
            other values, they are used as they are.
        kernel_io (optional): If it is falsy, a new KernelIO object is
            used instead.
    """
    super().__init__(cost_function, pre_code, signature, post_code, arguments)

    """The wildcard which represents the output operand.
    If it also appears in the pattern, it is assumed that it gets
    overwritten.
    """
    self.output_operand = output_operand

    self.input_operands = input_operands

    self.is_overwriting = False
    for input_operand in self.input_operands:
        if input_operand.operand == self.output_operand.operand:
            self.is_overwriting = True
            break

    self.kernel_io = kernel_io
    if not self.kernel_io:
        self.kernel_io = KernelIO()

    # Template for the operation that is computed. Is used to construct the
    # temporary.
    # Does not contain context.
    # Does not take type into account.
    # Example TRSV (transpose type):
    # Times(Inverse(X_), y_)
    # (X is wildcard for a matrix, y for a vector)
    self.operation_template = pattern

    # The pattern object to be used for matching.
    # Can be used for the matrix chain algorithm.
    # Does not contain context.
    # Takes type into account.
    # Example TRSV (transpose type):
    # Times(Transpose(y_), InverseTranspose(X_))
    self.pattern = pattern

    # Template for the replacement.
    # Does not contain context.
    # Takes type into account.
    # Example TRSV (transpose type):
    # Transpose(op)
    self.replacement_template = _op

    # TODO is it necessary to simplify?
    pattern_expr = self.pattern.expression

    # Modifying pattern and replacement to take type into account.
    # TODO this doesn't work for some kernels (e.g. diagsv), because
    # variables don't have properties that might be necessary to simplify
    # expressions. Those properties are part of the constraints, not the
    # expression itself.
    if type == KernelType.identity:
        pass
    elif type == KernelType.transpose:
        pattern_expr = transpose(pattern_expr)
        # TODO ugly hack
        # Operands that are constrained to be symmetric, or square and
        # diagonal, get their transposition removed again.
        for constraint in self.pattern.constraints:
            if isinstance(constraint, PropertyConstraint) and (
                (Property.SQUARE in constraint.properties
                 and Property.DIAGONAL in constraint.properties)
                    or Property.SYMMETRIC in constraint.properties):
                pattern_expr = remove_transpose(pattern_expr,
                                                constraint.variable)
        self.replacement_template = transpose(self.replacement_template)
    elif type == KernelType.conjugate_transpose:
        pattern_expr = conjugate_transpose(pattern_expr)
        self.replacement_template = conjugate_transpose(
            self.replacement_template)

    self.pattern = matchpy.Pattern(pattern_expr, *self.pattern.constraints)

    # The pattern object to be used for matching.
    # Contains context.
    # Takes type into account.
    # Example TRSV (transpose type):
    # Times(ctx1___, Transpose(y_), InverseTranspose(X_), ctx2___)
    # This is a placeholder; the attribute is assigned its final value
    # after the context has been added below.
    self.pattern_with_context = pattern

    # Template for the replacement.
    # Contains context.
    # Takes type into account.
    # Example TRSV (transpose type):
    # Times(ctx1, Transpose(op), ctx2)
    self.replacement_with_context_template = None

    # adding contexts
    if isinstance(pattern.expression, Times):
        pattern_with_context_expr = Times(ctx1, self.pattern.expression,
                                          ctx2)
        self.replacement_with_context_template = Times(
            ctx1, self.replacement_template, ctx2)
    elif isinstance(pattern.expression, Plus):
        pattern_with_context_expr = Plus(ctx1, self.pattern.expression)
        self.replacement_with_context_template = Plus(
            ctx1, self.replacement_template)
    else:
        # For unary kernels
        pattern_with_context_expr = self.pattern.expression
        self.replacement_with_context_template = self.replacement_template

    self.pattern_with_context = matchpy.Pattern(pattern_with_context_expr,
                                                *self.pattern.constraints)

    # One entry per wildcard of the pattern, in preorder: the properties of
    # the first PropertyConstraint for that wildcard, or an empty frozenset
    # if there is none.
    property_list = []
    for expr, pos in self.pattern.expression.preorder_iter():
        if isinstance(expr, matchpy.Wildcard):
            for constraint in self.pattern.constraints:
                if isinstance(
                        constraint, PropertyConstraint
                ) and constraint.variable == expr.variable_name:
                    property_list.append(constraint.properties)
                    break
            else:
                property_list.append(frozenset())
    self.property_tuple = PropertyTuple(property_list)
def rule(pat, call):
    """Create a replacement rule that rewrites matches of ``pat``.

    Args:
        pat: Expression to match; it is wrapped in a matchpy.Pattern without
            constraints.
        call: Callable that is stored as the replacement of the rule.

    Returns:
        matchpy.ReplacementRule: The new rule.
    """
    wrapped = matchpy.Pattern(pat)
    return matchpy.ReplacementRule(wrapped, call)
import matchpy ##################### # Cholesky # # id 0 # # It's not possible to use uplo = 'U' here because that changes the output. _A = matchpy.Wildcard.symbol("_A", symbol_type=ae.Matrix) _L = matchpy.Wildcard.symbol("_L") cholesky = FactorizationKernel( matchpy.Pattern(_A, PropertyConstraint("_A", {Property.SPSD})), [InputOperand(_A, StorageFormat.symmetric_lower_triangular)], Times(_L, Transpose(_L)), [OutputOperand(_L, _A, ("N", "N"), [Property.LOWER_TRIANGULAR, Property.NON_SINGULAR], StorageFormat.lower_triangular)], lambda N: (N**3)/3, CodeTemplate("""LAPACK.potrf!('L', $_A)"""), [SizeArgument("N", _A, "rows")], ) ##################### # LU with pivoting # # id 1 _A = matchpy.Wildcard.symbol("_A", symbol_type=ae.Matrix)