def normalize_minus(expr):
    """Rewrite `expr` so that minus signs sit on the leading factor of products.

    Three rewrite rules are applied through `replace_all`:

    * ``ld * l... * (-md) * r...``  ->  ``(-ld) * l... * md * r...``
    * ``(--ld) * l...``             ->  ``ld * l...``
    * ``-(ld * l...)``              ->  ``(-ld) * l...``

    Returns whatever `replace_all` returns for `expr` under these rules.
    """
    head = PatternDot("ld")
    negated = PatternDot("md")
    prefix = PatternStar("l")
    suffix = PatternStar("r")

    def move_minus_to_head(d):
        # The sign carried by an inner factor is transferred to the first one.
        return Times([Minus([d["ld"]]), d["l"], d["md"], d["r"]])

    def cancel_double_minus(d):
        # A doubly negated leading factor is simply the factor itself.
        return Times([d["ld"], d["l"]])

    def push_minus_inside(d):
        # A negated product becomes a product with a negated first factor.
        return Times([Minus([d["ld"]]), d["l"]])

    minus_rules = [
        RewriteRule(
            Times([head, prefix, Minus([negated]), suffix]),
            Replacement(move_minus_to_head),
        ),
        RewriteRule(
            Times([Minus([Minus([head])]), prefix]),
            Replacement(cancel_double_minus),
        ),
        RewriteRule(
            Minus([Times([head, prefix])]),
            Replacement(push_minus_inside),
        ),
    ]
    return replace_all(expr, minus_rules)
def expr_to_rule_rhs_lhs(self, predicates):
    """Build rewrite rules that replace each predicate's rhs by its lhs.

    Every element of `predicates` is unpacked as ``lhs, rhs = p.children``.
    For each one a list of `RewriteRule` objects is created whose patterns
    match (variants of) `rhs` and whose replacements produce the matching
    variant of `lhs`.

    Returns a list with one inner list of rules per predicate, in the same
    order as `predicates`.
    """
    rules = []
    t = PatternStar("t")
    l = PatternStar("l")
    ld = PatternDot("ld")
    r = PatternStar("r")
    for p in predicates:
        pr = []
        lhs, rhs = p.children
        if len(lhs.children) == 1:
            #lhs_sym = WrapOutBef( lhs.children[0] )
            lhs_sym = lhs.children[0]
            # The replacement callbacks below are written as
            # (lambda lhs: lambda d: ...)(lhs_sym) so that the current value
            # of lhs_sym is bound when the rule is created.
            if isinstance(rhs, Plus):
                # t___ + rhs -> t + lhs
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [lhs]))(lhs_sym)
                pr.append(
                    RewriteRule(Plus([t] + rhs.children), Replacement(repl_f)))
                # t___ + l___ rhs_i r___ + ... -> t + l lhs r
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times(d["l"].children + [lhs] + d["r"].children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            Times([l] + [ch] + [r]) for ch in rhs.children
                        ]), Replacement(repl_f)))
                # Same idea with every summand negated (and canonicalized):
                # t___ + (-rhs_i) r___ + ... -> t + (-lhs) r
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([simplify(to_canonical(Minus([lhs])))] + d["r"].
                          children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            Times([simplify(to_canonical(Minus([ch])))] + [r])
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                # A - B C in  L B C R + -L A R  (minus pushed all the way to the left, and whole thing negated)
                repl_f = (lambda lhs: lambda d: normalize_minus(
                    Plus(d["t"].children + [
                        Times([d["ld"]] + d["l"].children + [Minus([lhs])] + d["r"].children)
                    ])))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            normalize_minus(Times([ld, l, Minus([ch]), r]))
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                # A - B C in  -L B C R + L A R  (minus pushed all the way to the left)
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([d["ld"]] + d["l"].children + [lhs] + d["r"].
                          children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            normalize_minus(Times([ld, l, ch, r]))
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                #repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [Times([simplify(to_canonical(Minus(lhs.children)))] + d["r"].children)]))(lhs_sym)
                #pr.append( RewriteRule( Plus([t] + [
                                #Times([ Minus([ld]), l, ch, r]) if not isinstance(ch, Minus) \
                                        #else Times([ l, ch.children[0], r]) \
                                #for ch in rhs.children ]),
                            #Replacement(repl_f) ) )
                # Transposed summands.
                # NOTE(review): the pattern binds ld and l, but the
                # replacement only reuses r -- confirm this is intended.
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([
                        simplify(to_canonical(Minus([Transpose([lhs])])))
                    ] + d["r"].children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            Times([ Minus([ld]), l, simplify(to_canonical(Transpose([ch]))), r]) if not isinstance(ch, Minus) \
                                    else Times([ l, simplify(to_canonical(Transpose([ch]))), r]) \
                            for ch in rhs.children ]),
                        Replacement(repl_f) ) )
            elif isinstance(rhs, Times):
                # l___ rhs r___ -> l lhs r
                repl_f = (lambda lhs: lambda d: Times(d[
                    "l"].children + [lhs] + d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(Times([l] + rhs.children + [r]), Replacement(repl_f)))
                # l___ rhs^T r___ -> l lhs^T r
                repl_f = (lambda lhs: lambda d: Times(d[
                    "l"].children + [Transpose([lhs])] + d["r"].children)
                          )(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([
                            l, simplify(to_canonical(Transpose([rhs]))), r
                        ]), Replacement(repl_f)))
                # [TODO] Minus is awkward to handle. Should go for -1 and remove the operator internally?
                # (-rhs) r___ -> (-lhs) r
                repl_f = (lambda lhs: lambda d: Times([
                    simplify(to_canonical(Minus([Times([lhs])])))
                ] + d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([
                            simplify(
                                to_canonical(
                                    Minus([Times(rhs.get_children())])))
                        ] + [r]), Replacement(repl_f)))
                # (-rhs^T) r___ -> (-lhs^T) r
                repl_f = (lambda lhs: lambda d: Times([
                    simplify(
                        to_canonical(Minus([Transpose([Times([lhs])])])))
                ] + d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([
                            simplify(
                                to_canonical(
                                    Minus([
                                        Transpose(
                                            [Times(rhs.get_children())])
                                    ])))
                        ] + [r]), Replacement(repl_f)))
            else:
                # Any other rhs: match it verbatim, plus its canonical
                # transpose unless that transpose is itself an operand.
                pr.append(RewriteRule(rhs, Replacement(lhs_sym)))
                new_rhs = simplify(to_canonical(Transpose([rhs])))
                if not isOperand(new_rhs):
                    pr.append(
                        RewriteRule(
                            simplify(to_canonical(Transpose([rhs]))),
                            Replacement(Transpose([lhs_sym]))))
        else:
            # lhs does not have exactly one child: a single rule maps rhs to
            # the whole lhs.
            pr.append(RewriteRule(rhs, Replacement(lhs)))
        rules.append(pr)
    return rules
def find_updates_v2(self, before, after):
    """Derive the updates that take the `before` predicates to the `after` ones.

    `before` and `after` are sequences of equations whose children are
    ``(lhs, rhs)``; both sequences are modified in place (their outputs are
    wrapped and their right-hand sides rewritten).

    Returns the sorted, tiled updates with the WrapOutBef/WrapOutAft wrappers
    removed, or None when the invariant is skipped. An invariant is skipped
    when an lhs of `before` is missing from `after`, or when applying the
    rhs->lhs rules of a `before` equation leaves the matching `after`
    equation unchanged.
    """
    # If a part is (partially) computed in the before and
    #   does not appear in the after or
    #   going from before to after requires undoing some computation
    # it is potentially unstable, and more expensive: ignore
    dict_bef = {str(u.get_children()[0]): u for u in before}
    dict_aft = {str(u.get_children()[0]): u for u in after}
    # `reason` doubles as the "skip" flag; it used to be referenced in the
    # message below without ever being assigned.
    reason = None
    for k, v in dict_bef.items():
        if k not in dict_aft:
            reason = "%s not in %s" % (k, dict_aft.keys())
            break
        rules = list(itertools.chain(*self.expr_to_rule_rhs_lhs([v])))
        expr_copy = copy.deepcopy(dict_aft[k])
        t = replace(expr_copy, rules)
        if dict_aft[k] == t:
            reason = "%s would require undoing job" % k
            break
    if reason is not None:
        print("[INFO] Skipping invariant: %s" % reason)
        return None
    #
    # Wrap outputs for before and after
    WrapBefOut = WrapOutBef
    for u in before:
        u.children[0] = NList([WrapBefOut(l) for l in u.children[0]])
    for u in after:
        u.children[0] = NList([WrapOutAft(l) for l in u.children[0]])
    # replace before in after
    wrap_rules_before = []
    for u in before:
        print(u)
        lhs, rhs = u.get_children()
        if len(lhs.children) > 1:
            continue
        rules = self.expr_to_rule_rhs_lhs([u])
        wrap_rules_before.append(list(itertools.chain(*rules)))
    # Rewrite the patterns of later rule groups with every earlier group,
    # walking the groups from last to first.
    for i, rule in enumerate(reversed(wrap_rules_before)):
        idx = len(wrap_rules_before) - i - 1
        for j in range(idx - 1, -1, -1):
            for _rule in rule:
                _rule.pattern = replace_all(_rule.pattern,
                                            wrap_rules_before[j])
    wrap_rules_before = list(itertools.chain(*wrap_rules_before))
    #
    for u in after:
        _, rhs = u.get_children()
        u.children[1] = simplify(
            to_canonical(replace_all(copy.deepcopy(rhs), wrap_rules_before)))
    # replace after in after, until nothing changes any more
    done = False
    while not done:
        wrap_rules_after = []
        for u in after:
            lhs, rhs = u.get_children()
            if len(lhs.children) > 1:
                # Keep the positions aligned with `after`.
                wrap_rules_after.append([])
                continue
            rules = self.expr_to_rule_rhs_lhs([u])
            wrap_rules_after.append(list(itertools.chain(*rules)))
        #
        after_top = [copy.deepcopy(u) for u in after]
        for i, u in enumerate(after):
            _, rhs = u.get_children()
            # Use the rules of every equation except the one being rewritten.
            rules = list(
                itertools.chain.from_iterable(wrap_rules_after[:i] +
                                              wrap_rules_after[i + 1:]))
            u.children[1] = simplify(
                to_canonical(replace_all(copy.deepcopy(rhs), rules)))
        done = True
        for top, bot in zip(after_top, after):
            if top != bot:
                done = False
                break
    # [TODO] Multiple lhss, won't work
    updates = []
    for u in after:
        lhs, rhs = u.get_children()
        lhs = lhs.children[0]  # NList[op] -> op
        # Nothing to update when the "after" output is exactly the wrapped
        # "before" value.
        if isinstance(rhs, WrapBefOut) and isinstance(lhs, WrapOutAft) and \
                matchq(lhs.children[0], rhs.children[0]):
            continue
        updates.append(u)
    #
    tiled_updates = []
    for u in updates:
        print("* ", u)
        tilings = list(tile_expr(u))
        if len(tilings) > 1:
            print("[WARNING] Multiple (%d) tilings for expression %s" %
                  (len(tilings), u))
            print(" Discarding all but one")
        tiled_updates.extend(tilings[0])
    tiled_updates = sort(tiled_updates)
    print("* Tiled update")
    for t in tiled_updates:
        print("* ", t)
    # Drop WrapOutBef's
    # Drop WrapOutAft's
    s = PatternDot("s")
    updates = []
    for u in tiled_updates:
        u = replace_all(
            u, [RewriteRule(WrapOutAft(s), Replacement(lambda d: d["s"]))])
        u = replace_all(
            u, [RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))])
        updates.append(u)
    return updates
def find_updates(self, before, after):
    """Derive the updates that take the `before` predicates to the `after` ones.

    `before` and `after` are sequences of equations whose children are
    ``(lhs, rhs)``; both sequences are modified in place (their outputs are
    wrapped and their right-hand sides rewritten).

    Returns the sorted, tiled updates with the WrapOutBef/WrapOutAft wrappers
    removed, or None when the invariant is skipped. An invariant is skipped
    when the predicates cannot be expressed in terms of the input, when an
    lhs of `before` is missing from `after`, or when applying the rhs->lhs
    rules of a `before` equation leaves the matching `after` equation
    unchanged.
    """
    # If a part is (partially) computed in the before and
    #   does not appear in the after or
    #   going from before to after requires undoing some computation
    # it is potentially unstable, and more expensive: ignore
    try:
        before_finputs = self.express_in_terms_of_input(before)
        after_finputs = self.express_in_terms_of_input(after)
    except Exception:
        # [TODO] In LU's variant 5, parts of A appear as lhs's
        # Only ordinary errors are treated as "skip this invariant";
        # KeyboardInterrupt/SystemExit must keep propagating.
        return None
    #
    dict_bef = {str(u.get_children()[0]): u for u in before_finputs}
    dict_aft = {str(u.get_children()[0]): u for u in after_finputs}
    same = []
    reason = None
    for k, v in dict_bef.items():
        if k in dict_aft and matchq(v, dict_aft[k]):
            same.extend(v.children[0].children)
        if k not in dict_aft:
            reason = "%s not in %s" % (k, dict_aft.keys())
            break
        rules = list(itertools.chain(*self.expr_to_rule_rhs_lhs([v])))
        expr_copy = copy.deepcopy(dict_aft[k])
        t = replace(expr_copy, rules)
        if dict_aft[k] == t:
            reason = "%s would require undoing job" % k
            break
    if reason is not None:
        print("[INFO] Skipping invariant: %s" % reason)
        return None
    #
    # Wrap outputs for before and after
    WrapBefOut = WrapOutBef
    lhss = []
    for u in before:
        lhss.extend(u.children[0])
        u.children[0] = NList([WrapBefOut(l) for l in u.children[0]])
    for u in before:
        u.children[1] = replace(
            u.children[1],
            [RewriteRule(l, Replacement(WrapBefOut(l))) for l in lhss])
    #
    lhss = []
    for u in after:
        lhss.extend(u.children[0])
        u.children[0] = NList([WrapOutAft(l) for l in u.children[0]])
    # Outputs whose equation matched between before and after (collected in
    # `same`) are wrapped as "before" values.
    wrap_rules_after = [
        RewriteRule(l, Replacement(WrapBefOut(l)))
        if l in same else RewriteRule(l, Replacement(WrapOutAft(l)))
        for l in lhss
    ]
    for u in after:
        u.children[1] = replace(u.children[1], wrap_rules_after)
    # replace before in before
    wrap_rules_before = []
    for u in before:
        #if len(lhs.children) > 1:
            #wrap_rules_before.append([])
            #continue
        rules = self.expr_to_rule_rhs_lhs([u])
        wrap_rules_before.append(list(itertools.chain(*rules)))
    # For every rule, also derive a variant whose pattern has been rewritten
    # with the rules of all the other equations; keep it only if it differs.
    new_rules = []
    for i, rules in enumerate(wrap_rules_before):
        other_rules = list(
            itertools.chain.from_iterable(wrap_rules_before[:i] +
                                          wrap_rules_before[i + 1:]))
        new_rules.append([])
        for rule in rules:
            new_r = copy.deepcopy(rule)
            new_r.pattern = replace_all(new_r.pattern, other_rules)
            if new_r.pattern != rule.pattern:
                new_rules[-1].append(new_r)
    for r1, r2 in zip(new_rules, wrap_rules_before):
        r2.extend(r1)
    #
    wrap_rules_before = list(itertools.chain(*wrap_rules_before))
    # replace before in after, until nothing changes any more
    done = False
    while not done:
        after_top = [copy.deepcopy(u) for u in after]
        for u in after:
            _, rhs = u.get_children()
            u.children[1] = simplify(
                to_canonical(
                    replace_all(copy.deepcopy(rhs), wrap_rules_before)))
        done = True
        for top, bot in zip(after_top, after):
            if top != bot:
                done = False
                break
    # replace after in after, until nothing changes any more
    done = False
    while not done:
        wrap_rules_after = []
        for u in after:
            #if len(lhs.children) > 1:
                #wrap_rules_after.append([])
                #continue
            rules = self.expr_to_rule_rhs_lhs([u])
            wrap_rules_after.append(list(itertools.chain(*rules)))
        #
        after_top = [copy.deepcopy(u) for u in after]
        for i, u in enumerate(after):
            _, rhs = u.get_children()
            # Use the rules of every equation except the one being rewritten.
            rules = list(
                itertools.chain.from_iterable(wrap_rules_after[:i] +
                                              wrap_rules_after[i + 1:]))
            u.children[1] = simplify(
                to_canonical(replace_all(copy.deepcopy(rhs), rules)))
        done = True
        for top, bot in zip(after_top, after):
            if top != bot:
                done = False
                break
    # [TODO] Multiple lhss, won't work
    updates = []
    for u in after:
        lhs, rhs = u.get_children()
        if len(lhs.children) == 1:
            lhs = lhs.children[0]  # NList[op] -> op
            # Nothing to update when the "after" output is exactly the
            # wrapped "before" value.
            if isinstance(rhs, WrapBefOut) and isinstance(lhs, WrapOutAft) and \
                    matchq(lhs.children[0], rhs.children[0]):
                continue
        elif not isinstance(rhs, NList):
            # multiple outputs/predicate in rhs,
            # but not complete (otherwise it would be NList)
            pass
        else:
            to_skip = True
            for l, r in zip(lhs.children, rhs.children):
                if not( isinstance(r, WrapBefOut) and isinstance(l, WrapOutAft) and \
                        matchq(l.children[0], r.children[0]) ):
                    to_skip = False
                    break
            if to_skip:
                continue
        updates.append(u)
    #
    tiled_updates = []
    for u in updates:
        print("* ", u)
        tilings = list(tile_expr(u))
        if len(tilings) > 1:
            print("[WARNING] Multiple (%d) tilings for expression %s" %
                  (len(tilings), u))
            print(" Discarding all but one")
        tiled_updates.extend(tilings[0])
    tiled_updates = sort(tiled_updates)
    print("* Tiled update")
    for t in tiled_updates:
        print("* ", t)
    # Drop WrapOutBef's
    # Drop WrapOutAft's
    s = PatternDot("s")
    updates = []
    for u in tiled_updates:
        u = replace_all(
            u, [RewriteRule(WrapOutAft(s), Replacement(lambda d: d["s"]))])
        u = replace_all(
            u, [RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))])
        updates.append(u)
    return updates
def generate_loop_based_algorithms(self):
    """Generate one algorithm per loop invariant and store them in `self.algs`.

    `self.pmes` and `self.linvs` are walked in lockstep. For every loop
    invariant the initialization, the before/after predicates and the updates
    are computed; invariants for which `find_updates` returns None or an
    empty list are skipped. `self.algs` ends up holding one list of
    `Algorithm` objects per PME. The variant counter is shared across PMEs
    and only advances when an algorithm is actually produced.
    """
    print("* Generating Loop-based algorithms...")
    self.algs = []
    variant = 1
    for pme, invariants in zip(self.pmes, self.linvs):
        pme_algs = []
        for linv in invariants:
            print("* ")
            print("* Loop invariant", variant)
            for invariant_expr in linv.expressions:
                print("*     ", invariant_expr)
            print("* ")
            # Only the first traversal is used (this would be another loop).
            trav, init_state, _ = linv.traversals[0]
            init = self.algorithm_initialization(init_state)
            print("* Init")
            print("*     ", init)
            wrapped = PatternDot("s")
            init = [
                replace_all(
                    stmt,
                    [RewriteRule(WrapOutBef(wrapped),
                                 Replacement(lambda d: d["s"]))])
                for stmt in init
            ]
            print("* Before")
            repart, before = self.generate_predicate_before(
                pme, trav, linv.expressions, linv)
            print("* After")
            # The "after" predicate is generated with the traversal reversed.
            reversed_trav = {
                key: (rows * -1, cols * -1)
                for key, (rows, cols) in trav.items()
            }
            cont_with, after = self.generate_predicate_before(
                pme, reversed_trav, linv.expressions, linv)
            print("* Updates")
            updates = self.find_updates(before, after)
            if updates is None:
                continue
            for update in updates:
                infer_properties(update)
            final_updates = []
            # [DIEGO] Fixing some output pieces being labeled as input
            outputs = []
            for update in updates:
                lhs, _ = update.children
                for out in lhs:
                    if out.isTemporary():
                        continue
                    outputs.append(out)
                    out.set_property(OUTPUT)
            for update in updates:
                for node in update.children[1].iterate_preorder():
                    if not isinstance(node, Symbol) or node in outputs:
                        continue
                    node.set_property(INPUT)
                update_copy = copy.deepcopy(update)
                # Keep a single tiling of the update.
                one_tiling = tile_expr(copy.deepcopy(update_copy))[0]
                final_updates.extend(one_tiling)
            if not len(updates):
                print("No updates!! Should only happen in copy")
                continue
            algorithm = Algorithm(linv, variant, init, repart, cont_with,
                                  before, after, final_updates)
            pme_algs.append(algorithm)
            pme_algs[-1].prepare_for_code_generation()
            variant += 1
        self.algs.append(pme_algs)
def learn_pattern(self):
    """Register rewrite rules that recognize this operation.

    A pattern is built from `self.equation` by replacing every operand with
    a pattern variable, and a constraint string is built from the operands'
    properties. The resulting rule(s) are inserted at the front of the
    module-level `known_ops_single` (single assignment) or `known_ops`
    (otherwise), and a rule mapping the operation's predicate back to its
    equation is appended to `op_to_implicit`.
    """
    inops = [op for op in self.operands if op.isInput()]
    outops = [op for op in self.operands if op.isOutput()]
    #
    # True when the equation has a single child whose first child is a Symbol.
    single_assignment = len( self.equation.children ) == 1 and \
                        isinstance(self.equation.children[0].children[0], Symbol) # eq.lhs.single_entry_in_NL
    #
    # The replacements here are strings holding code, not expression objects.
    op_to_pattern = [
        RewriteRule(op, Replacement("PatternDot(%s.name)" % op.name))
        for op in self.operands
    ]
    pattern = NList([
        replace_all(copy.deepcopy(eq), op_to_pattern)
        for eq in self.equation
    ])
    if single_assignment:
        props_str = [
            symbol_props_to_constraints_no_io(op) for op in self.operands
        ]
        # Empty property strings are dropped before joining.
        constraint = Constraint(" and ".join(
            [prop for prop in props_str if prop]))
    else:
        constraint = Constraint(" and ".join(
            [symbol_props_to_constraints(op) for op in self.operands]))
    # [TODO] Tuple for get_size
    replacement = Replacement(
        "Equal([ NList([%s]), Predicate( \"%s\", [%s], [%s] ) ])" %
        (", ".join([op.name for op in outops]), self.name, ", ".join([
            op.name for op in inops
        ]), ", ".join(["%s.get_size()" % op.get_name() for op in outops])))
    # [TODO] This should be part of the verbose option
    print("* Learnt pattern")
    print("*   ", pattern, end="")
    if constraint.to_eval:
        print("with        ", constraint.to_eval)
    print(" --> ")
    print("*          ", replacement.to_eval)
    print("**********************************")
    # [TODO] Maybe sort known ops by specificity (a la mathematica)
    #known_ops.insert( 0, RewriteRule( (pattern, constraint), replacement ) )
    if single_assignment:
        expr = pattern.children[0]
        # Wrap the lhs in an NList.
        expr.children[0] = NList([expr.children[0]])
        known_ops_single.insert(
            0, RewriteRule((expr, constraint), replacement))
        # With minus
        # Second rule: same lhs, negated rhs, producing a negated Predicate.
        replacement = Replacement(
            "Equal([ NList([%s]), Minus([ Predicate( \"%s\", [%s], [%s] ) ]) ])"
            % (", ".join([op.name for op in outops]), self.name, ", ".join(
                [op.name for op in inops]), ", ".join(
                    ["%s.get_size()" % op.get_name() for op in outops])))
        expr = copy.deepcopy(expr)
        expr.children[1] = Minus([expr.children[1]])
        expr.children[1] = normalize_minus(copy.deepcopy(expr.children[1]))
        known_ops_single.insert(
            0, RewriteRule((expr, constraint), replacement))
        #with open(os.path.join("OUTPUT", self.name+"_patterns"), "wb") as patt_f:
            #pickle.dump( known_ops_single[1], patt_f )
            #pickle.dump( known_ops_single[0], patt_f )
    else:
        known_ops.insert(0, RewriteRule((pattern, constraint), replacement))
        # NOTE(review): this dump is assumed to belong to the else branch,
        # mirroring the commented-out dump above -- confirm.
        with open(os.path.join("OUTPUT", self.name + "_patterns"),
                  "wb") as patt_f:
            pickle.dump(known_ops[0], patt_f)
    # Reverse rule: from the operation's predicate back to its equation.
    pattern = Equal([
        NList([PatternDot(op.get_name()) for op in outops]),
        Predicate(self.name, [PatternDot(op.get_name()) for op in inops],
                  [op.get_size() for op in outops])
    ])
    replacement = Replacement(equation2replacement(self.equation))
    op_to_implicit.append(RewriteRule(pattern, replacement))
from BindDimensions import bindDimensions from graph import build_dep_graph, subgraphs, zero_out_lower from Partitioning import partition_shape, repartition, repartition_shape, repart_group from Tiling import tile_expr from utils import sort from RewritingExtension import * from BackEnd.MatlabCode import generate_matlab_code, click2matlab #from BackEnd.LatexCode import generate_latex_code from BackEnd.LGenCode import generate_lgen_files #from BackEnd.CCode import generate_c_code known_ops = [ # Straight assignment RewriteRule( (NList([Equal([PatternDot("LHS"), PatternDot("RHS")])]), Constraint( "isinstance(LHS, Symbol) and isOutput(LHS) and isInput(RHS)")), Replacement("Equal([ NList([LHS]), RHS ])")) ] known_ops_single = [] known_pmes = [] op_to_implicit = [] # Add a verbose option # For now it would help to illustrate the process class Operation(object): def __init__(self, name, operands, equation, overwrite): self.name = name self.operands = operands
from core.expression import Symbol, Matrix, Vector, \ Equal, Plus, Minus, Times, Transpose, \ PatternDot, PatternStar from core.properties import * from core.InferenceOfProperties import * from core.functional import Constraint, Replacement, RewriteRule, replace_all from core.builtin_operands import Zero, Identity LHS = PatternDot("LHS") RHS = PatternDot("RHS") subexpr = PatternDot("subexpr") PD1 = PatternDot("PD1") PD2 = PatternDot("PD2") PS1 = PatternStar("PS1") PS2 = PatternStar("PS2") PS3 = PatternStar("PS3") PSLeft = PatternStar("PSLeft") PSRight = PatternStar("PSRight") simplify_rules = [ # Minus # --expr -> expr RewriteRule(Minus([Minus([subexpr])]), Replacement(lambda d: d["subexpr"])), # -0 -> 0 RewriteRule((Minus([subexpr]), Constraint(lambda d: isZero(d["subexpr"]))), Replacement(lambda d: d["subexpr"])), # Plus # Plus(a) -> a
def flatten_blocked_operation_click(expr):
    """Flatten a copy of `expr` after pushing WrapOutBef into blocked operands.

    Every ``WrapOutBef(x)`` in a deep copy of `expr` is replaced by a
    `BlockedExpression` built by threading `WrapOutBef` over `x` (depth 2)
    and reusing `x.size` and `x.shape`; the result is then handed to
    `flatten_blocked_operation`. `expr` itself is not rewritten.
    """
    wrapped = PatternDot("PD")

    def wrap_each_block(d):
        blocked = d["PD"]
        return BlockedExpression(
            map_thread(WrapOutBef, [blocked], 2), blocked.size, blocked.shape)

    distribute_wrapper = RewriteRule(
        WrapOutBef(wrapped), Replacement(wrap_each_block))
    rewritten = replace(copy.deepcopy(expr), [distribute_wrapper])
    return flatten_blocked_operation(rewritten)
def learn_pattern(self):
    """Learn the rewrite rule associated with this operation's PME.

    The pattern is an `Equal` whose lhs holds the blocked output operands
    (each block turned into a pattern variable) and whose rhs is a
    `Predicate` over the blocked inputs. The replacement is the string
    returned by `equation2replacement` for the blocked equations, with the
    right-hand sides of `self.solved_subequations` substituted in where the
    left-hand sides match. The rule is appended to `self.known_pmes` and
    pickled (append mode) into ``OUTPUT/<name>_pmes``.
    """
    inops = [operand for operand in self.operands if operand.isInput()]  # not used below
    outops = [operand for operand in self.operands if operand.isOutput()]

    # pattern
    pattern_inputs = []
    pattern_outputs = []
    for operand in self.operands:
        blocks = self.basic_partitionings[operand.get_name()]
        to_pattern_vars = [
            RewriteRule(block, Replacement(PatternDot(block.get_name())))
            for block in itertools.chain(*blocks)
        ]
        blocked = BlockedExpression(
            copy.deepcopy(self.basic_partitionings[operand.get_name()]),
            operand.get_size(), blocks.shape)
        destination = pattern_inputs if operand.isInput() else pattern_outputs
        destination.append(replace(blocked, to_pattern_vars))
    output_sizes = [operand.get_size() for operand in outops]
    pattern = Equal([
        NList(pattern_outputs),
        Predicate(self.name, pattern_inputs, output_sizes),
    ])

    # replacement
    # [TODO] Tuple for get_size
    partitionings = self.partitionings
    lhss = map_thread(
        NList, [partitionings[operand.get_name()] for operand in outops], 2)

    def drop_zero_entry(cell):
        # [TODO] This is a fix for lu (maybe coup sylv as well).
        #        Generalize and clean up
        # An NList with a zero as last or first entry keeps only the rest.
        return replace(cell, [
            RewriteRule(
                (NList([PatternPlus("PP"), PatternDot("PD")]),
                 Constraint(lambda d: isZero(d["PD"]))),
                Replacement(lambda d: NList(d["PP"].get_children()))),
            RewriteRule(
                (NList([PatternDot("PD"), PatternPlus("PP")]),
                 Constraint(lambda d: isZero(d["PD"]))),
                Replacement(lambda d: NList(d["PP"].get_children()))),
        ])

    for i, row in enumerate(lhss):
        for j, cell in enumerate(row):
            lhss[i][j] = drop_zero_entry(cell)

    # [CHECK] partitionings vs. basic_partitionings
    parts = self.partitionings
    eqs = map_thread(Equal, [lhss, parts[outops[0].get_name()]], 2)
    output_shape = self.basic_partitionings[outops[0].get_name()].shape
    n_rows, n_cols = output_shape
    # Plug already solved subequations into the cells with the same lhs.
    for solved in self.solved_subequations:
        solved_lhs, solved_rhs = solved.get_children()
        for i, j in itertools.product(range(n_rows), range(n_cols)):
            cell_lhs, _ = eqs[i][j].get_children()
            if solved_lhs == cell_lhs:
                eqs[i][j].set_children(1, solved_rhs)

    # The size passed here, (0, 0), does not matter.
    blocked_eqs = BlockedExpression(eqs, (0, 0), output_shape)
    replacement_str = equation2replacement(blocked_eqs)

    print("* Learnt PME pattern")
    print("*    ", RewriteRule(pattern, Replacement(replacement_str)))
    self.known_pmes.append(
        RewriteRule(pattern, Replacement(replacement_str)))
    pmes_path = os.path.join("OUTPUT", self.name + "_pmes")
    with open(pmes_path, "ab+") as pmes_f:
        pickle.dump(self.known_pmes[-1], pmes_f)
tile, new = instr.tile( new, node, _m ) lhs, rhs = tile.get_children() if rhs not in matched_this_node: ### matched_this_node.append( rhs ) ### ongoing.append( alg[:] + [tile, new] ) # Set size of new temporary lhs.children[0].size = rhs.get_size() else: # Aestetic. Simply to avoid missing T? values. TOS.push_back_temp( ) TOS._TOS.unset_operand( tile.children[0].children[0] ) if matched_in_this_level: ### break PD1 = PatternDot("PD1") PD2 = PatternDot("PD2") PS1 = PatternStar("PS1") PS2 = PatternStar("PS2") grouping_rules = [ # A B + A C D -> A (B + C D) RewriteRule( Plus([ Times([ PD1, PS1 ]), Times([ PD1, PS2 ]) ]), Replacement(lambda d: Times([ d["PD1"], Plus([Times([d["PS1"]]), Times([d["PS2"]])]) ])) ), # A B - A C D -> A (B - C D) RewriteRule( Plus([ Times([ PD1, PS1 ]), Times([ Minus([PD1]), PD2, PS2 ]) ]), Replacement(lambda d: Times([ d["PD1"], Plus([ Times([d["PS1"]]), Times([Minus([d["PD2"]]), Times([d["PS2"]])])]) ])) ),
class tril(Operator): def __init__(self, arg): Operator.__init__(self, [arg], [], UNARY) self.size = arg.get_size() class triu(Operator): def __init__(self, arg): Operator.__init__(self, [arg], [], UNARY) self.size = arg.get_size() # Patterns for inv to trsm A = PatternDot("A") B = PatternDot("B") C = PatternDot("C") # [TODO] Complete the set of patterns trsm_patterns = [ #RewriteRule( (Equal([ C, Times([ B, Transpose([Inverse([A])]) ]) ]), Constraint("A.st_info[0] == ST_LOWER")), \ RewriteRule( Equal([ C, Times([ B, Transpose([Inverse([A])]) ]) ]), \ Replacement( lambda d: Equal([ d["C"], mrdiv([ Transpose([tril(d["A"])]), d["B"] ]) ]) ) ), ] # # Produces Matlab code for: # - Loop-based code # def generate_matlab_code(operation, matlab_dir):
import copy from core.expression import Equal, Plus, Times, Minus, Transpose, Inverse, \ Symbol, PatternDot, PatternStar, NList, Predicate from core.functional import Constraint, Replacement, RewriteRule, match, replace import core.properties as properties from core.InferenceOfProperties import isUpperTriangular from core.prop_to_queryfunc import propagate_properties from CoreExtension import isOperand import storage import core.TOS as TOS alpha = PatternDot( "alpha" ) beta = PatternDot( "beta" ) A = PatternDot( "A" ) B = PatternDot( "B" ) C = PatternDot( "C" ) D_PS = PatternStar( "D_PS") left = PatternStar( "left" ) middle = PatternStar( "middle" ) right = PatternStar( "right" ) class Instruction( object ): def __init__( self, pattern, create_rewrite_rule, create_tile ): self.pattern = pattern self.create_rewrite_rule = create_rewrite_rule self.create_tile = create_tile
pm.DB["rdiv_unu_ow"] = pm.PredicateMetadata("rdiv_unu_ow", tuple()) pm.DB["rdiv_unu_ow"].overwrite = [(1, 0)] pm.DB["rdiv_uti"] = pm.PredicateMetadata("rdiv_uti", tuple()) pm.DB["rdiv_uti"].overwrite = [] pm.DB["rdiv_uti_ow"] = pm.PredicateMetadata("rdiv_uti_ow", tuple()) pm.DB["rdiv_uti_ow"].overwrite = [(1, 0)] pm.DB["rdiv_utn"] = pm.PredicateMetadata("rdiv_utn", tuple()) pm.DB["rdiv_utn"].overwrite = [] pm.DB["rdiv_utn_ow"] = pm.PredicateMetadata("rdiv_utn_ow", tuple()) pm.DB["rdiv_utn_ow"].overwrite = [(1, 0)] pm.DB["rdiv_utu"] = pm.PredicateMetadata("rdiv_utu", tuple()) pm.DB["rdiv_utu"].overwrite = [] pm.DB["rdiv_utu_ow"] = pm.PredicateMetadata("rdiv_utu_ow", tuple()) pm.DB["rdiv_utu_ow"].overwrite = [(1, 0)] A = PatternDot("A") B = PatternDot("B") X = PatternDot("X") trsm2lgen_rules = [ # X = i(t(A)) B -> ldiv_lni RewriteRule(( Equal([NList([X]), Times([Inverse([A]), B])]), Constraint( "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name == X.name" )), Replacement(lambda d: Equal([ NList([d["X"]]), Predicate("ldiv_lni", [d["A"], d["B"]], [d["A"].get_size(), d["B"].get_size()]) ]))),