def normalize_minus(expr):
    """Move unary minus signs in *expr* onto the leading factor of products.

    Three rewrite rules are handed to ``replace_all`` (pattern names used in
    the match dictionary shown in brackets):

    * ``a l___ (-m) r___  ->  (-a) l___ m r___``   [ld, l, md, r]
    * ``(--a) l___        ->  a l___``             [ld, l]
    * ``-(a l___)         ->  (-a) l___``          [ld, l]

    Args:
        expr: the expression tree to normalise.

    Returns:
        The expression produced by ``replace_all`` for *expr* under these rules.
    """
    lead = PatternDot("ld")
    negated = PatternDot("md")
    middle = PatternStar("l")
    rest = PatternStar("r")

    # A negation sitting on an inner factor is hoisted onto the leading factor.
    hoist_minus = RewriteRule(
        Times([lead, middle, Minus([negated]), rest]),
        Replacement(lambda d: Times([Minus([d["ld"]]), d["l"], d["md"], d["r"]])),
    )
    # Hoisting can produce a double negation on the leading factor; cancel it.
    cancel_double_minus = RewriteRule(
        Times([Minus([Minus([lead])]), middle]),
        Replacement(lambda d: Times([d["ld"], d["l"]])),
    )
    # A negated whole product becomes a product with a negated leading factor.
    distribute_minus = RewriteRule(
        Minus([Times([lead, middle])]),
        Replacement(lambda d: Times([Minus([d["ld"]]), d["l"]])),
    )
    return replace_all(expr, [hoist_minus, cancel_double_minus, distribute_minus])
def before_to_rule(self, before):
    """Turn an equation into rules that rewrite its right-hand side into its left.

    ``before`` is unpacked as ``lhs, rhs = before.get_children()``.  The shape
    of ``rhs`` decides which rules are generated:

    * ``Plus``:
        - ``rest___ + <summands of rhs>  ->  rest + <children of lhs>``;
        - every summand of ``rhs`` wrapped as ``l___ summand r___`` inside one
          sum, rewritten to ``l lhs r``. This pattern is used both as built and
          after ``to_canonical``. Patterns sharing a name presumably must bind
          the same value. TODO confirm in the matcher.
    * ``Times``:
        - ``l___ <factors of rhs> r___  ->  l lhs r``;
        - the canonical negated form of ``rhs`` inside a product, rewritten to
          the canonical negated form of ``lhs``.
    * anything else: ``rhs`` rewrites to the first child of ``lhs`` only.

    ``self`` is not used.

    Returns:
        A list of ``RewriteRule`` objects.
    """
    lhs, rhs = before.get_children()

    def splice_lhs(d):
        # l___ <matched rhs> r___  ->  l___ <children of lhs> r___
        return Times(d["l"].get_children() + lhs.get_children() + d["r"].get_children())

    if isinstance(rhs, Plus):
        def summands_in_context():
            # Built fresh on every call: each summand of rhs framed by l___ / r___.
            return Plus([
                Times([PatternStar("l"), summand, PatternStar("r")])
                for summand in rhs.get_children()
            ])

        return [
            RewriteRule(
                Plus([PatternStar("rest")] + rhs.get_children()),
                Replacement(lambda d: Plus(d["rest"].get_children() + lhs.get_children()))),
            RewriteRule(summands_in_context(), Replacement(splice_lhs)),
            RewriteRule(to_canonical(summands_in_context()), Replacement(splice_lhs)),
        ]

    if isinstance(rhs, Times):
        # [TODO] Needed for chol loop invariants 2 and 3; should be fixed
        # elsewhere (maybe -1 instead of the minus operator).
        def splice_negated_lhs(d):
            return Times(d["l"].get_children()
                         + [to_canonical(Minus([Times(lhs.get_children())]))]
                         + d["r"].get_children())

        return [
            RewriteRule(
                Times([PatternStar("l")] + rhs.get_children() + [PatternStar("r")]),
                Replacement(splice_lhs)),
            RewriteRule(
                Times([PatternStar("l"),
                       to_canonical(Minus([Times(rhs.get_children())])),
                       PatternStar("r")]),
                Replacement(splice_negated_lhs)),
        ]

    # [FIX] What if there are multiple outputs? Only the first lhs child is used.
    return [RewriteRule(rhs, Replacement(lhs.children[0]))]
def expr_to_rule_rhs_lhs(self, predicates):
    """Build, for each predicate, rewrite rules that replace its rhs by its lhs.

    Each element ``p`` of ``predicates`` is unpacked as ``lhs, rhs = p.children``.

    * If ``lhs`` has exactly one child (``lhs_sym``), several rules are built
      so that ``rhs`` is recognised not only verbatim but also inside a larger
      sum or product, negated, and/or transposed. Each rule rewrites the match
      in terms of ``lhs_sym``. Which variants are built depends on whether
      ``rhs`` is a ``Plus``, a ``Times`` or something else.
    * Otherwise a single rule ``rhs -> lhs`` is built.

    ``self`` is not used.

    Returns:
        A list with one entry per predicate, in input order. Each entry is the
        list of ``RewriteRule`` objects generated for that predicate.
    """
    rules = []
    # Pattern variables shared by all generated rules. Judging by their use
    # below, PatternStar binds a run of operands (accessed via ``.children``)
    # and PatternDot a single operand. TODO confirm in core.expression.
    t = PatternStar("t")
    l = PatternStar("l")
    ld = PatternDot("ld")
    r = PatternStar("r")
    for p in predicates:
        pr = []
        lhs, rhs = p.children
        if len(lhs.children) == 1:
            #lhs_sym = WrapOutBef( lhs.children[0] )
            lhs_sym = lhs.children[0]
            # Every replacement below is written as
            # (lambda lhs: lambda d: ...)(lhs_sym) so that each closure captures
            # this predicate's lhs_sym. lhs_sym is rebound on every loop
            # iteration, and a plain lambda would see only the last value.
            if isinstance(rhs, Plus):
                # t___ + rhs -> t + lhs
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [lhs]))(lhs_sym)
                pr.append(
                    RewriteRule(Plus([t] + rhs.children), Replacement(repl_f)))
                # t___ + l___ rhs_i r___ + ... -> t + l lhs r
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times(d["l"].children + [lhs] + d["r"].children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            Times([l] + [ch] + [r]) for ch in rhs.children
                        ]), Replacement(repl_f)))
                # t___ + (-rhs_1) r___ + (-rhs_2) r___ + ... -> t + (-lhs) r
                # (negations simplified and canonicalised on both sides)
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([simplify(to_canonical(Minus([lhs])))] + d["r"].children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            Times([simplify(to_canonical(Minus([ch])))] + [r])
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                # A - B C in L B C R + -L A R (minus pushed all the way to the
                # left, and the whole thing negated). Both pattern and result go
                # through normalize_minus.
                repl_f = (lambda lhs: lambda d: normalize_minus(
                    Plus(d["t"].children + [
                        Times([d["ld"]] + d["l"].children + [Minus([lhs])] +
                              d["r"].children)
                    ])))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            normalize_minus(Times([ld, l, Minus([ch]), r]))
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                # A - B C in -L B C R + L A R (minus pushed all the way to the left)
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([d["ld"]] + d["l"].children + [lhs] + d["r"].children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            normalize_minus(Times([ld, l, ch, r]))
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                # Disabled earlier variant of the transposed rule below:
                #repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [Times([simplify(to_canonical(Minus(lhs.children)))] + d["r"].children)]))(lhs_sym)
                #pr.append( RewriteRule( Plus([t] + [
                #Times([ Minus([ld]), l, ch, r]) if not isinstance(ch, Minus) \
                #else Times([ l, ch.children[0], r]) \
                #for ch in rhs.children ]),
                #Replacement(repl_f) ) )
                # Transposed summands: each non-negated summand appears as
                # -ld l trans(rhs_i) r, each negated one as l trans(rhs_i) r;
                # rewritten to t + -(trans(lhs)) r.
                # NOTE(review): the replacement does not use d["ld"] or d["l"],
                # so whatever those matched is dropped from the result (the
                # disabled variant above does the same). Confirm this is intended.
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([simplify(to_canonical(Minus([Transpose([lhs])])))] +
                          d["r"].children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            Times([Minus([ld]), l, simplify(to_canonical(Transpose([ch]))), r])
                            if not isinstance(ch, Minus)
                            else Times([l, simplify(to_canonical(Transpose([ch]))), r])
                            for ch in rhs.children
                        ]),
                        Replacement(repl_f)))
            elif isinstance(rhs, Times):
                # l___ rhs r___ -> l lhs r
                repl_f = (lambda lhs: lambda d: Times(d["l"].children + [lhs] +
                                                      d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(Times([l] + rhs.children + [r]), Replacement(repl_f)))
                # l___ trans(rhs) r___ -> l trans(lhs) r
                repl_f = (lambda lhs: lambda d: Times(d["l"].children + [Transpose([lhs])] +
                                                      d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([l, simplify(to_canonical(Transpose([rhs]))), r]),
                        Replacement(repl_f)))
                # [TODO] Handling of Minus is awkward here. Should negation be
                # represented as a -1 factor, with the operator removed internally?
                # (-rhs) r___ -> (-lhs) r. This only matches when the negated
                # product is the leading factor, because there is no l___ pattern.
                repl_f = (lambda lhs: lambda d: Times(
                    [simplify(to_canonical(Minus([Times([lhs])])))] +
                    d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([simplify(to_canonical(Minus([Times(rhs.get_children())])))] +
                              [r]),
                        Replacement(repl_f)))
                # -(trans(rhs)) r___ -> -(trans(lhs)) r  (leading factor only)
                repl_f = (lambda lhs: lambda d: Times(
                    [simplify(to_canonical(Minus([Transpose([Times([lhs])])])))] +
                    d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([
                            simplify(to_canonical(Minus([Transpose([Times(rhs.get_children())])])))
                        ] + [r]),
                        Replacement(repl_f)))
            else:
                # Neither sum nor product: rewrite rhs itself, and its transpose
                # too unless the canonical transpose is a plain operand.
                pr.append(RewriteRule(rhs, Replacement(lhs_sym)))
                new_rhs = simplify(to_canonical(Transpose([rhs])))
                if not isOperand(new_rhs):
                    # NOTE(review): recomputes the same expression as new_rhs;
                    # presumably new_rhs could be reused. Confirm that to_canonical
                    # / simplify have no side effects first.
                    pr.append(
                        RewriteRule(
                            simplify(to_canonical(Transpose([rhs]))),
                            Replacement(Transpose([lhs_sym]))))
        else:
            # Several outputs on the lhs: only the verbatim rhs -> lhs rule.
            pr.append(RewriteRule(rhs, Replacement(lhs)))
        rules.append(pr)
    return rules
from core.expression import Symbol, Matrix, Vector, \ Equal, Plus, Minus, Times, Transpose, \ PatternDot, PatternStar from core.properties import * from core.InferenceOfProperties import * from core.functional import Constraint, Replacement, RewriteRule, replace_all from core.builtin_operands import Zero, Identity LHS = PatternDot("LHS") RHS = PatternDot("RHS") subexpr = PatternDot("subexpr") PD1 = PatternDot("PD1") PD2 = PatternDot("PD2") PS1 = PatternStar("PS1") PS2 = PatternStar("PS2") PS3 = PatternStar("PS3") PSLeft = PatternStar("PSLeft") PSRight = PatternStar("PSRight") simplify_rules = [ # Minus # --expr -> expr RewriteRule(Minus([Minus([subexpr])]), Replacement(lambda d: d["subexpr"])), # -0 -> 0 RewriteRule((Minus([subexpr]), Constraint(lambda d: isZero(d["subexpr"]))), Replacement(lambda d: d["subexpr"])), # Plus # Plus(a) -> a
def solve_equations(self):
    """Solve the distributed partitioned postcondition into a PME.

    Steps:

    1. Every sub-equation of ``self.distributed_partitioned_postcondition``
       is registered with ``update_TOE`` in both directions.
    2. The output parts to compute are collected. These are the ``Symbol``
       nodes of each output operand's partitioning in ``self.partitionings``.
    3. Passes run until every output part is either ``INPUT`` or zero. On each
       pass, each pending sub-equation is deep-copied and rewritten with
       ``self.known_ops``. A result that is an ``Equal`` whose lhs consists
       only of ``Symbol`` nodes is recorded as solved, and its lhs symbols
       (and a ``Predicate`` rhs) are marked ``INPUT``.
       - If a pass makes no progress, the pending sub-equation with the fewest
         distinct unknown outputs is handed to ``NestedInstance.rec_instance``.
         Ties are broken by node count, then by first position. The returned
         patterns and PMEs are registered before the next pass.
       - Sub-equations that are still pending and still mention an output are
         re-canonicalised for the next pass.
    4. A solved rhs that occurs inside another solved equation is replaced by
       that equation's lhs symbol. The result is stored in
       ``self.solved_subequations`` and printed.
    5. Outputs that were marked ``INPUT`` while solving are reset to ``OUTPUT``.

    NOTE(review): if the recursive instance never makes progress, the loop in
    step 3 can repeat indefinitely.

    Raises:
        RuntimeError: if generation is stuck and no pending sub-equation is
            left to try as a recursive instance. Previously this surfaced as an
            ``AttributeError`` on ``None``.
    """
    dist = self.distributed_partitioned_postcondition

    # Update TOE: register every sub-equation in both directions.
    for eq in dist.flatten_children():
        for subeq in eq:
            lhs, rhs = subeq.get_children()
            update_TOE(RewriteRule(lhs, Replacement(rhs)))
            update_TOE(RewriteRule(rhs, Replacement(lhs)))

    # Collect all output parts to be computed.
    # [TODO] Confirm that self.partitionings (rather than
    # self.basic_partitionings) is the right source here.
    basic_part = self.partitionings
    all_outputs = []
    for op in self.operands:
        if op.isOutput():
            all_outputs.extend([
                node for node in basic_part[op.get_name()].iterate_preorder()
                if isinstance(node, Symbol)
            ])

    def unknowns_and_size(candidate):
        # Ranking key for the recursive-instance candidate:
        # (number of distinct output Symbols, total number of nodes).
        unknowns = []
        size = 0
        for node in candidate.iterate_preorder():
            size += 1
            if isinstance(node, Symbol) and node.isOutput() and node not in unknowns:
                unknowns.append(node)
        return (len(unknowns), size)

    # Iterate until the PME is solved.
    subeqs_to_solve = filter_zero_zero(dist.flatten_children())
    subeqs_solved = []
    # Any initial value that can never equal a real set of output names, so
    # the first pass is never reported as stuck.
    solved_outputs = {""}
    solved = False
    while not solved:
        for eq in subeqs_to_solve:
            new = replace(copy.deepcopy(eq), self.known_ops)
            if not isinstance(new, Equal):
                continue
            lhs, rhs = new.get_children()
            if all(isinstance(elem, Symbol) for elem in lhs):
                # Set input property. # SOLVED?
                if isinstance(rhs, Predicate):
                    rhs.set_property(INPUT)
                for op in lhs:
                    op.set_property(INPUT)
                subeqs_solved.append(new)
            # Otherwise the lhs is not a plain list of symbols (the original
            # code notes e.g. trans(X_BL) in lyapunov); skip it this pass.

        cur_solved_outputs = {op.get_name() for op in all_outputs if isInput(op)}
        if solved_outputs == cur_solved_outputs:
            print("")
            print("[WARNING] PME Generation is stuck - Trying recursive instance")
            print("")
            # min() returns the first of equally-ranked candidates, matching
            # the original strict-less-than scan, but without a magic sentinel.
            subeq = min(subeqs_to_solve, key=unknowns_and_size, default=None)
            if subeq is None:
                raise RuntimeError(
                    "PME generation is stuck and no unsolved sub-equation is "
                    "left to try as a recursive instance")
            # Local import, as in the original code (the module is only
            # needed on this path).
            from NestedInstance import rec_instance
            ((md_name, md_data), new_patts, new_pmes) = rec_instance(subeq.children[0])
            self.known_pmes.extend(new_pmes)
            print("NPMES:", len(new_pmes))
            for patt in new_patts:
                self.known_ops.append(patt)
            pm.DB[md_name] = md_data
        solved_outputs = cur_solved_outputs

        if all(isInput(op) or op.isZero() for op in all_outputs):
            solved = True
        else:
            # Keep sub-equations that are still pending and still mention an output.
            new_to_solve = [
                eq for eq in subeqs_to_solve
                if not isinstance(eq, Equal) and any(
                    isinstance(op, Symbol) and isOutput(op)
                    for op in eq.iterate_preorder())
            ]
            subeqs_to_solve = [
                simplify(to_canonicalIO(eq))._cleanup() for eq in new_to_solve
            ]

    # Replace rhs with lhs where a solved rhs appears as a subexpression of
    # another solved equation.
    #
    # Create replacements when suited.
    reuse_rules = []
    for i, eq in enumerate(subeqs_solved):
        lhs, rhs = eq.children
        if isinstance(rhs, Predicate) or rhs.get_head() in (Minus, Transpose, Inverse):
            continue
        lhs_sym = lhs.children[0]
        t_pat = PatternStar("t")
        l_pat = PatternStar("l")
        r_pat = PatternStar("r")
        eq_rules = []
        # t___ + l___ rhs r___  ->  t + l lhs r
        # (outer lambda binds the current lhs_sym; avoids late binding)
        repl_f = (lambda lhs_sym: lambda d: Plus(
            [d["t"], Times([d["l"], lhs_sym, d["r"]])]))(lhs_sym)
        eq_rules.append(
            RewriteRule(Plus([t_pat, Times([l_pat, rhs, r_pat])]), Replacement(repl_f)))
        # A - B C in -L B C R + L A R (minus pushed all the way to the left)
        if len(rhs.children) > 1:
            repl_f = (lambda lhs_sym: lambda d: Times(
                [Minus([lhs_sym])] + d["r"].children))(lhs_sym)
            eq_rules.append(
                RewriteRule(
                    Times([Minus([rhs.children[0]]), *rhs.children[1:], r_pat]),
                    Replacement(repl_f)))
        reuse_rules.append((i, eq_rules))

    # Replace: each solved equation is rewritten with every other equation's rules.
    self.solved_subequations = []
    for i, eq in enumerate(subeqs_solved):
        other_rules = [rule for j, rules in reuse_rules if j != i for rule in rules]
        self.solved_subequations.append(replace_all(eq, other_rules))

    # [TODO] Add T to operands and bind dimensions
    #self.bind_temporaries( )

    # Reset outputs that were marked INPUT while solving.
    for op in solved_outputs:
        TOS[op][0].set_property(OUTPUT)

    print("* PME ")
    for eq in self.solved_subequations:
        print("* ", eq)
if rhs not in matched_this_node: ### matched_this_node.append( rhs ) ### ongoing.append( alg[:] + [tile, new] ) # Set size of new temporary lhs.children[0].size = rhs.get_size() else: # Aestetic. Simply to avoid missing T? values. TOS.push_back_temp( ) TOS._TOS.unset_operand( tile.children[0].children[0] ) if matched_in_this_level: ### break PD1 = PatternDot("PD1") PD2 = PatternDot("PD2") PS1 = PatternStar("PS1") PS2 = PatternStar("PS2") grouping_rules = [ # A B + A C D -> A (B + C D) RewriteRule( Plus([ Times([ PD1, PS1 ]), Times([ PD1, PS2 ]) ]), Replacement(lambda d: Times([ d["PD1"], Plus([Times([d["PS1"]]), Times([d["PS2"]])]) ])) ), # A B - A C D -> A (B - C D) RewriteRule( Plus([ Times([ PD1, PS1 ]), Times([ Minus([PD1]), PD2, PS2 ]) ]), Replacement(lambda d: Times([ d["PD1"], Plus([ Times([d["PS1"]]), Times([Minus([d["PD2"]]), Times([d["PS2"]])])]) ])) ), # B A + C D A -> (B + C D) A RewriteRule(
from core.functional import Constraint, Replacement, RewriteRule, match, replace import core.properties as properties from core.InferenceOfProperties import isUpperTriangular from core.prop_to_queryfunc import propagate_properties from CoreExtension import isOperand import storage import core.TOS as TOS alpha = PatternDot( "alpha" ) beta = PatternDot( "beta" ) A = PatternDot( "A" ) B = PatternDot( "B" ) C = PatternDot( "C" ) D_PS = PatternStar( "D_PS") left = PatternStar( "left" ) middle = PatternStar( "middle" ) right = PatternStar( "right" ) class Instruction( object ): def __init__( self, pattern, create_rewrite_rule, create_tile ): self.pattern = pattern self.create_rewrite_rule = create_rewrite_rule self.create_tile = create_tile def match( self, expr ): yield from match( expr, self.pattern ) def tile( self, tree, node, match_dict ):