def to_canonical_IO(self):
    """Canonicalize every equation and keep its smallest equivalent form.

    Each equation is first rewritten with ``canonical_rules``,
    ``canonicalIO_rules`` and ``simplify_rules``.  Then three candidates are
    ranked by ``num_nodes()``:

    1) the equation as is,
    2) the equation with ``Minus`` applied to both sides, and
    3) the equation with ``Transpose`` applied to both sides

    (candidates 2 and 3 are passed through ``to_canonical`` and ``simplify``).
    ``self.equation`` ends up as an ``NList`` of the selected forms and
    ``self.operands`` is restricted to operands still found in them.
    """
    rewritten = []
    for equation in self.equation:
        rewritten.append(
            replace_all(
                equation,
                canonical_rules + canonicalIO_rules + simplify_rules))
    self.equation = rewritten

    smallest = []
    for equation in self.equation:
        negated = simplify(
            to_canonical(
                Equal([Minus([equation.lhs()]), Minus([equation.rhs()])])))
        transposed = simplify(
            to_canonical(
                Equal([Transpose([equation.lhs()]),
                       Transpose([equation.rhs()])])))
        # Rank as (node count, candidate) pairs; on equal counts the
        # candidates themselves are compared, as tuple ordering dictates.
        ranked = [(candidate.num_nodes(), candidate)
                  for candidate in (equation, negated, transposed)]
        smallest.append(min(ranked)[1])

    # Set minimal forms
    self.equation = NList(smallest)

    # Keep only the operands that still occur somewhere in the equations.
    present = list(
        itertools.chain.from_iterable(
            list(equation.iterate_preorder()) for equation in self.equation))
    self.operands = [op for op in self.operands if op in present]
def normalize_minus(expr):
    """Return ``expr`` rewritten by three minus-normalization rules.

    The rules, handed to ``replace_all`` in this order, are:

    * ``Times([ld, l, Minus([md]), r])`` -> ``Times([Minus([ld]), l, md, r])``
      (a minus on a later factor is moved onto the first factor),
    * ``Times([Minus([Minus([ld])]), l])`` -> ``Times([ld, l])``
      (a double minus on the first factor is dropped),
    * ``Minus([Times([ld, l])])`` -> ``Times([Minus([ld]), l])``
      (a minus around a product is pushed onto its first factor).
    """
    first = PatternDot("ld")
    negated = PatternDot("md")
    middle = PatternStar("l")
    tail = PatternStar("r")

    def _hoist_inner_minus(d):
        return Times([Minus([d["ld"]]), d["l"], d["md"], d["r"]])

    def _drop_double_minus(d):
        return Times([d["ld"], d["l"]])

    def _push_minus_into_product(d):
        return Times([Minus([d["ld"]]), d["l"]])

    rules = [
        RewriteRule(
            Times([first, middle, Minus([negated]), tail]),
            Replacement(_hoist_inner_minus)),
        RewriteRule(
            Times([Minus([Minus([first])]), middle]),
            Replacement(_drop_double_minus)),
        RewriteRule(
            Minus([Times([first, middle])]),
            Replacement(_push_minus_into_product)),
    ]
    return replace_all(expr, rules)
def express_in_terms_of_input(self, assignments):
    """Return copies of ``assignments`` with their rhs expanded by substitution.

    The rules produced by ``self.expr_to_rule_lhs_rhs(assignments)`` are
    applied to the rhs (``children[1]``) of a deep copy of every assignment,
    and the result is passed through ``to_canonical`` and ``simplify``.  The
    assignments given as argument are not modified.

    Args:
        assignments: iterable of expressions; the first lhs entry is read as
            ``children[0].children[0]`` and the rhs as ``children[1]``.

    Returns:
        A list with one rewritten deep copy per assignment.

    Raises:
        ValueError: if the first lhs entry of any assignment is an input
            (``isInput()`` is true).
    """
    # Materialize once so a generator argument survives the validation pass.
    assignments = list(assignments)

    # In LU, there may be (overwritable) inputs as lhs; those are not
    # supported here.  Every offending assignment raises, so no further
    # filtering is needed afterwards.
    for a in assignments:
        if a.children[0].children[0].isInput():
            raise ValueError(
                "assignment with an input operand as lhs is not supported: "
                "%s" % (a,))

    expand_rules = list(
        itertools.chain.from_iterable(
            self.expr_to_rule_lhs_rhs(assignments)))

    new_ass = []
    for expr in assignments:
        new = copy.deepcopy(expr)
        new.children[1] = simplify(
            to_canonical(replace_all(new.children[1], expand_rules)))
        new_ass.append(new)
    return new_ass
def all_tilings( expr ):
    """Generate tilings of ``expr`` with the instructions in ``instruction_set``.

    This is a generator.  It keeps a stack (``ongoing``) of partial results;
    each partial result is a list whose last element is the expression that
    still has to be tiled.  When that last element satisfies ``isOperand`` the
    list is yielded.  Otherwise the expression is rewritten with
    ``grouping_rules`` and, for every match of an instruction, a new partial
    result ``alg + [tile, new]`` is pushed, where ``tile, new`` is what
    ``instr.tile`` returns.

    A partial result for which no instruction matches is dropped without
    being yielded.
    """
    # Stack of partial tilings; the most recently pushed one is expanded first.
    ongoing = [ [expr] ]
    while len( ongoing ) > 0:
        alg = ongoing.pop()
        # The last entry is the part that remains to be tiled.
        to_tile = alg.pop()
        if isOperand( to_tile ):
            # Nothing left to tile: put the operand back and emit the result.
            alg.append( to_tile )
            yield alg
            continue
        to_tile = replace_all( copy.deepcopy( to_tile), grouping_rules )
        # Collections are visited last-to-first; the search stops at the
        # first collection that produced at least one match.
        for collection in reversed(instruction_set):
            matched_in_this_level = False ### These are just control vars
                                          ### to avoid redundancies
            for instr in collection:
                for node in to_tile.iterate_preorder():
                    # To avoid redundancies
                    matched_this_node = [] ###
                    if isinstance( instr, tuple ): # A way to deactivate some for quick development/debugging?
                        continue
                    for _m in instr.match( node ):
                        matched_in_this_level = True ###
                        #_T = TOS.new_temp()
                        # Tile a fresh copy so other matches still see the
                        # untouched expression.
                        new = copy.deepcopy( to_tile )
                        #tile, new = instr.tile( new, _m, _T )
                        tile, new = instr.tile( new, node, _m )
                        lhs, rhs = tile.get_children()
                        if rhs not in matched_this_node: ###
                            matched_this_node.append( rhs ) ###
                            ongoing.append( alg[:] + [tile, new] )
                            # Set size of new temporary
                            lhs.children[0].size = rhs.get_size()
                        else:
                            # Aesthetic. Simply to avoid missing T? values:
                            # this rhs was already produced for this node, so
                            # the tile is discarded and its lhs operand is
                            # unset in TOS.
                            TOS.push_back_temp( )
                            TOS._TOS.unset_operand( tile.children[0].children[0] )
            if matched_in_this_level: ###
                break
def find_updates_v2(self, before, after):
    """Derive the updates that lead from the ``before`` to the ``after`` state.

    If a part is (partially) computed in the before and does not appear in
    the after, or going from before to after requires undoing some
    computation, the invariant is potentially unstable and more expensive:
    it is skipped and ``None`` is returned.

    Args:
        before: expressions describing the state before the update; each has
            an lhs (``children[0]``) and an rhs (``children[1]``).
        after: expressions describing the state after the update.

    Returns:
        ``None`` when the invariant is skipped; otherwise the list of tiled
        update expressions with the ``WrapOutBef``/``WrapOutAft`` wrappers
        removed.

    Note:
        ``before`` and ``after`` are modified in place (their children are
        replaced), and progress is reported with ``print``.
    """
    dict_bef = {str(u.get_children()[0]): u for u in before}
    dict_aft = {str(u.get_children()[0]): u for u in after}

    # ``reason`` doubles as the "skip" flag; it used to be referenced without
    # ever being assigned, which raised NameError on the skip path.
    reason = None
    for k, v in dict_bef.items():
        if k not in dict_aft:
            reason = "%s not in %s" % (k, dict_aft.keys())
            break
        rules = list(itertools.chain(*self.expr_to_rule_rhs_lhs([v])))
        # Work on a deep copy so the comparison below is against the
        # untouched after-expression.
        t = replace(copy.deepcopy(dict_aft[k]), rules)
        if dict_aft[k] == t:
            reason = "%s would require undoing job" % k
            break
    if reason is not None:
        print("[INFO] Skipping invariant: %s" % reason)
        return None

    # Wrap outputs for before and after
    WrapBefOut = WrapOutBef
    for u in before:
        u.children[0] = NList([WrapBefOut(l) for l in u.children[0]])
    for u in after:
        u.children[0] = NList([WrapOutAft(l) for l in u.children[0]])

    # replace before in after
    wrap_rules_before = []
    for u in before:
        print(u)
        lhs, rhs = u.get_children()
        if len(lhs.children) > 1:
            continue
        rules = self.expr_to_rule_rhs_lhs([u])
        wrap_rules_before.append(list(itertools.chain(*rules)))

    # Rewrite the pattern of each rule group with all earlier groups,
    # starting from the last group.
    for i, rule in enumerate(reversed(wrap_rules_before)):
        idx = len(wrap_rules_before) - i - 1
        for j in range(idx - 1, -1, -1):
            for _rule in rule:
                _rule.pattern = replace_all(_rule.pattern,
                                            wrap_rules_before[j])
    wrap_rules_before = list(itertools.chain(*wrap_rules_before))

    for u in after:
        _, rhs = u.get_children()
        u.children[1] = simplify(
            to_canonical(replace_all(copy.deepcopy(rhs), wrap_rules_before)))

    # replace after in after, repeating until nothing changes
    done = False
    while not done:
        wrap_rules_after = []
        for u in after:
            lhs, rhs = u.get_children()
            if len(lhs.children) > 1:
                # Keep the slot so indices stay aligned with ``after``.
                wrap_rules_after.append([])
                continue
            rules = self.expr_to_rule_rhs_lhs([u])
            wrap_rules_after.append(list(itertools.chain(*rules)))

        after_top = [copy.deepcopy(u) for u in after]
        for i, u in enumerate(after):
            _, rhs = u.get_children()
            # Use the rules of every other update, never an update's own.
            rules = list(
                itertools.chain.from_iterable(wrap_rules_after[:i] +
                                              wrap_rules_after[i + 1:]))
            u.children[1] = simplify(
                to_canonical(replace_all(copy.deepcopy(rhs), rules)))
        done = True
        for top, bot in zip(after_top, after):
            if top != bot:
                done = False
                break

    # [TODO] Multiple lhss, won't work
    updates = []
    for u in after:
        lhs, rhs = u.get_children()
        lhs = lhs.children[0]  # NList[op] -> op
        if isinstance(rhs, WrapBefOut) and isinstance(lhs, WrapOutAft) and \
                matchq(lhs.children[0], rhs.children[0]):
            # The after value is exactly the before value: nothing to do.
            continue
        updates.append(u)

    tiled_updates = []
    for u in updates:
        print("* ", u)
        tilings = list(tile_expr(u))
        if len(tilings) > 1:
            print("[WARNING] Multiple (%d) tilings for expression %s" %
                  (len(tilings), u))
            print(" Discarding all but one")
        tiled_updates.extend(tilings[0])
    tiled_updates = sort(tiled_updates)
    print("* Tiled update")
    for t in tiled_updates:
        print("* ", t)

    # Drop WrapOutBef's
    # Drop WrapOutAft's
    s = PatternDot("s")
    updates = []
    for u in tiled_updates:
        u = replace_all(
            u, [RewriteRule(WrapOutAft(s), Replacement(lambda d: d["s"]))])
        u = replace_all(
            u, [RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))])
        updates.append(u)
    return updates
def find_updates(self, before, after):
    """Derive the updates that lead from the ``before`` to the ``after`` state.

    If a part is (partially) computed in the before and does not appear in
    the after, or going from before to after requires undoing some
    computation, the invariant is potentially unstable and more expensive:
    it is skipped and ``None`` is returned.

    Args:
        before: expressions describing the state before the update; each has
            an lhs (``children[0]``) and an rhs (``children[1]``).
        after: expressions describing the state after the update.

    Returns:
        ``None`` when ``self.express_in_terms_of_input`` raises or when the
        invariant is skipped; otherwise the list of tiled update expressions
        with the ``WrapOutBef``/``WrapOutAft`` wrappers removed.

    Note:
        ``before`` and ``after`` are modified in place (their children are
        replaced), and progress is reported with ``print``.
    """
    try:
        before_finputs = self.express_in_terms_of_input(before)
        after_finputs = self.express_in_terms_of_input(after)
    except Exception:
        # [TODO] In LU's variant 5, parts of A appear as lhs's
        # Only ordinary errors are treated as "no updates"; KeyboardInterrupt
        # and SystemExit must still propagate.
        return None

    dict_bef = {str(u.get_children()[0]): u for u in before_finputs}
    dict_aft = {str(u.get_children()[0]): u for u in after_finputs}

    # lhs operands whose expression is the same before and after.
    same = []
    ignore = False
    for k, v in dict_bef.items():
        if k in dict_aft and matchq(v, dict_aft[k]):
            same.extend(v.children[0].children)
        if k not in dict_aft:
            ignore = True
            reason = "%s not in %s" % (k, dict_aft.keys())
            break
        else:
            rules = self.expr_to_rule_rhs_lhs([v])
            rules = list(itertools.chain(*rules))
            # Work on a deep copy so the comparison below is against the
            # untouched after-expression.
            expr_copy = copy.deepcopy(dict_aft[k])
            t = replace(expr_copy, rules)
            if dict_aft[k] == t:
                ignore = True
                reason = "%s would require undoing job" % k
                break
    if ignore:
        print("[INFO] Skipping invariant: %s" % reason)
        return None

    # Wrap outputs for before and after
    WrapBefOut = WrapOutBef
    lhss = []
    for u in before:
        lhss.extend(u.children[0])
        u.children[0] = NList([WrapBefOut(l) for l in u.children[0]])
    for u in before:
        u.children[1] = replace(
            u.children[1],
            [RewriteRule(l, Replacement(WrapBefOut(l))) for l in lhss])

    lhss = []
    for u in after:
        lhss.extend(u.children[0])
        u.children[0] = NList([WrapOutAft(l) for l in u.children[0]])
    # Operands that did not change keep the "before" wrapper.
    wrap_rules_after = [
        RewriteRule(l, Replacement(WrapBefOut(l))) if l in same else
        RewriteRule(l, Replacement(WrapOutAft(l))) for l in lhss
    ]
    for u in after:
        u.children[1] = replace(u.children[1], wrap_rules_after)

    # replace before in before
    wrap_rules_before = []
    for u in before:
        lhs, rhs = u.get_children()
        rules = self.expr_to_rule_rhs_lhs([u])
        wrap_rules_before.append(list(itertools.chain(*rules)))

    # For every rule, also add a variant whose pattern has been rewritten
    # with the rules of all the other updates (only when that changes it).
    new_rules = []
    for i, rules in enumerate(wrap_rules_before):
        new_rules.append([])
        for rule in rules:
            new_r = copy.deepcopy(rule)
            new_r.pattern = replace_all(
                new_r.pattern,
                list(
                    itertools.chain.from_iterable(wrap_rules_before[:i] +
                                                  wrap_rules_before[i + 1:])))
            if new_r.pattern != rule.pattern:
                new_rules[-1].append(new_r)
    for r1, r2 in zip(new_rules, wrap_rules_before):
        r2.extend(r1)

    wrap_rules_before = list(itertools.chain(*wrap_rules_before))

    # replace before in after, repeating until nothing changes
    done = False
    while not done:
        after_top = [copy.deepcopy(u) for u in after]
        for u in after:
            _, rhs = u.get_children()
            u.children[1] = simplify(
                to_canonical(
                    replace_all(copy.deepcopy(rhs), wrap_rules_before)))
        done = True
        for top, bot in zip(after_top, after):
            if top != bot:
                done = False
                break

    # replace after in after, repeating until nothing changes
    done = False
    while not done:
        wrap_rules_after = []
        for u in after:
            lhs, rhs = u.get_children()
            rules = self.expr_to_rule_rhs_lhs([u])
            wrap_rules_after.append(list(itertools.chain(*rules)))

        after_top = [copy.deepcopy(u) for u in after]
        for i, u in enumerate(after):
            _, rhs = u.get_children()
            # Use the rules of every other update, never an update's own.
            rules = list(
                itertools.chain.from_iterable(wrap_rules_after[:i] +
                                              wrap_rules_after[i + 1:]))
            u.children[1] = simplify(
                to_canonical(replace_all(copy.deepcopy(rhs), rules)))
        done = True
        for top, bot in zip(after_top, after):
            if top != bot:
                done = False
                break

    # [TODO] Multiple lhss, won't work
    updates = []
    for u in after:
        lhs, rhs = u.get_children()
        if len(lhs.children) == 1:
            lhs = lhs.children[0]  # NList[op] -> op
            if isinstance(rhs, WrapBefOut) and isinstance(lhs, WrapOutAft) and \
                    matchq(lhs.children[0], rhs.children[0]):
                # The after value is exactly the before value: nothing to do.
                continue
        elif not isinstance(rhs, NList):
            # multiple outputs/predicate in rhs,
            # but not complete (otherwise it would be NList)
            pass
        else:
            # Skip only if every (lhs, rhs) pair is an unchanged operand.
            to_skip = True
            for l, r in zip(lhs.children, rhs.children):
                if not (isinstance(r, WrapBefOut) and
                        isinstance(l, WrapOutAft) and
                        matchq(l.children[0], r.children[0])):
                    to_skip = False
                    break
            if to_skip:
                continue
        updates.append(u)

    tiled_updates = []
    for u in updates:
        print("* ", u)
        tilings = list(tile_expr(u))
        if len(tilings) > 1:
            print("[WARNING] Multiple (%d) tilings for expression %s" %
                  (len(tilings), u))
            print(" Discarding all but one")
        tiled_updates.extend(tilings[0])
    tiled_updates = sort(tiled_updates)
    print("* Tiled update")
    for t in tiled_updates:
        print("* ", t)

    # Drop WrapOutBef's
    # Drop WrapOutAft's
    s = PatternDot("s")
    updates = []
    for u in tiled_updates:
        u = replace_all(
            u, [RewriteRule(WrapOutAft(s), Replacement(lambda d: d["s"]))])
        u = replace_all(
            u, [RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))])
        updates.append(u)
    return updates
def generate_loop_based_algorithms(self):
    """Build one ``Algorithm`` per usable loop invariant and store them.

    Iterates over ``self.pmes`` and ``self.linvs`` in lockstep.  For every
    loop invariant it computes the initialization, the before and after
    predicates (the after predicate uses the negated traversal), and the
    updates via ``self.find_updates``.  Invariants for which
    ``find_updates`` returns ``None`` or an empty list are skipped.

    On return ``self.algs`` holds one list of algorithms per PME.  Progress
    is reported with ``print``.
    """
    print("* Generating Loop-based algorithms...")
    self.algs = []
    # Variant number; only advanced when an algorithm is actually created.
    variant = 1
    for pme, linvs in zip(self.pmes, self.linvs):
        algs = []
        for linv in linvs:
            print("* ")
            print("* Loop invariant", variant)
            for expr in linv.expressions:
                print("* ", expr)
            print("* ")
            trav, init_state, _ = linv.traversals[
                0]  # this would be another for loop
            init = self.algorithm_initialization(init_state)
            print("* Init")
            #print( init_state )
            print("* ", init)
            # Strip the WrapOutBef wrappers from the initialization.
            s = PatternDot("s")
            init = [
                replace_all(i, [
                    RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))
                ]) for i in init
            ]
            print("* Before")
            repart, before = self.generate_predicate_before(
                pme, trav, linv.expressions, linv)
            print("* After")
            # Same traversal with both components negated.
            reversed_trav = dict([(k, (r * -1, c * -1))
                                  for k, (r, c) in trav.items()])
            cont_with, after = self.generate_predicate_before(
                pme, reversed_trav, linv.expressions, linv)
            # find updates
            print("* Updates")
            updates = self.find_updates(before, after)
            if updates is None:
                #variant += 1
                continue
            # Tile updates
            for u in updates:
                infer_properties(u)
            final_updates = []
            # [DIEGO] Fixing some output pieces being labeled as input
            outputs = []
            for u in updates:
                lhs, rhs = u.children
                for l in lhs:
                    if not l.isTemporary():
                        outputs.append(l)
                        l.set_property(OUTPUT)
            #
            for u in updates:
                #[DIEGO] u.children[0].children[0].set_property( OUTPUT )
                # Every rhs symbol that was not collected as an output above
                # is marked as input.
                for node in u.children[1].iterate_preorder():
                    #if isinstance(node, Symbol):
                    if isinstance(node, Symbol) and node not in outputs:
                        node.set_property(INPUT)
                #
                #copy_u = replace( copy.deepcopy(u), known_ops_single )
                copy_u = copy.deepcopy(u)
                # NOTE(review): the result of tile_expr is indexed directly
                # here, while find_updates wraps it in list() first --
                # confirm tile_expr returns an indexable sequence.
                copy_u = tile_expr(copy.deepcopy(copy_u))[0]  # One tiling
                final_updates.extend(copy_u)
            if len(updates) == 0:
                print("No updates!! Should only happen in copy")
                continue
            algs.append(
                Algorithm(linv, variant, init, repart, cont_with, before,
                          after, final_updates))
            algs[-1].prepare_for_code_generation()
            variant += 1
        self.algs.append(algs)
def learn_pattern(self):
    """Turn this operation's equation into rewrite rules and register them.

    A pattern is built from ``self.equation`` by replacing every operand with
    a ``PatternDot`` replacement string, together with a ``Constraint`` built
    from the operands' properties and a ``Replacement`` string that describes
    an ``Equal([NList(outputs), Predicate(name, inputs, output sizes)])``.

    * Single assignment (one equation whose first lhs entry is a ``Symbol``):
      two rules are inserted at the front of ``known_ops_single`` -- the
      pattern itself and a negated variant normalized with
      ``normalize_minus``.
    * Otherwise: one rule is inserted at the front of ``known_ops`` and
      pickled to ``OUTPUT/<self.name>_patterns``.

    In both cases a rule mapping the predicate back to the equation is
    appended to ``op_to_implicit``.  The learnt pattern is printed.
    """
    inops = [op for op in self.operands if op.isInput()]
    outops = [op for op in self.operands if op.isOutput()]
    #
    single_assignment = len( self.equation.children ) == 1 and \
        isinstance(self.equation.children[0].children[0], Symbol) # eq.lhs.single_entry_in_NL
    #
    # The replacement is given as source text -- presumably evaluated later
    # by the rewriting machinery; TODO confirm.
    op_to_pattern = [
        RewriteRule(op, Replacement("PatternDot(%s.name)" % op.name))
        for op in self.operands
    ]
    pattern = NList([
        replace_all(copy.deepcopy(eq), op_to_pattern) for eq in self.equation
    ])
    if single_assignment:
        props_str = [
            symbol_props_to_constraints_no_io(op) for op in self.operands
        ]
        # Empty constraint strings are left out of the conjunction.
        constraint = Constraint(" and ".join(
            [prop for prop in props_str if prop]))
    else:
        constraint = Constraint(" and ".join(
            [symbol_props_to_constraints(op) for op in self.operands]))
    # [TODO] Tuple for get_size
    replacement = Replacement(
        "Equal([ NList([%s]), Predicate( \"%s\", [%s], [%s] ) ])" %
        (", ".join([op.name for op in outops]), self.name, ", ".join([
            op.name for op in inops
        ]), ", ".join(["%s.get_size()" % op.get_name() for op in outops])))
    # [TODO] This should be part of the verbose option
    print("* Learnt pattern")
    print("* ", pattern, end="")
    if constraint.to_eval:
        print("with ", constraint.to_eval)
    print(" --> ")
    print("* ", replacement.to_eval)
    print("**********************************")
    # [TODO] Maybe sort known ops by specificity (a la mathematica)
    #known_ops.insert( 0, RewriteRule( (pattern, constraint), replacement ) )
    if single_assignment:
        # Note: expr is the equation stored inside ``pattern``, so wrapping
        # its lhs in an NList modifies ``pattern`` as well.
        expr = pattern.children[0]
        expr.children[0] = NList([expr.children[0]])
        known_ops_single.insert(
            0, RewriteRule((expr, constraint), replacement))
        # With minus
        replacement = Replacement(
            "Equal([ NList([%s]), Minus([ Predicate( \"%s\", [%s], [%s] ) ]) ])"
            % (", ".join([op.name for op in outops]), self.name, ", ".join(
                [op.name for op in inops]), ", ".join(
                    ["%s.get_size()" % op.get_name() for op in outops])))
        # Negate the rhs on a copy so the rule registered above is untouched.
        expr = copy.deepcopy(expr)
        expr.children[1] = Minus([expr.children[1]])
        expr.children[1] = normalize_minus(copy.deepcopy(expr.children[1]))
        known_ops_single.insert(
            0, RewriteRule((expr, constraint), replacement))
        #with open(os.path.join("OUTPUT", self.name+"_patterns"), "wb") as patt_f:
            #pickle.dump( known_ops_single[1], patt_f )
            #pickle.dump( known_ops_single[0], patt_f )
    else:
        known_ops.insert(0, RewriteRule((pattern, constraint), replacement))
        with open(os.path.join("OUTPUT", self.name + "_patterns"),
                  "wb") as patt_f:
            pickle.dump(known_ops[0], patt_f)
    # Rule from the predicate form back to the equation.
    pattern = Equal([
        NList([PatternDot(op.get_name()) for op in outops]),
        Predicate(self.name, [PatternDot(op.get_name()) for op in inops],
                  [op.get_size() for op in outops])
    ])
    replacement = Replacement(equation2replacement(self.equation))
    op_to_implicit.append(RewriteRule(pattern, replacement))
def simplify(expr):
    """Return ``expr`` rewritten with ``simplify_rules`` via ``replace_all``."""
    simplified = replace_all(expr, simplify_rules)
    return simplified
def to_canonicalIO(expr):
    """Return ``expr`` rewritten with ``canonical_rules`` followed by ``canonicalIO_rules``."""
    rules = canonical_rules + canonicalIO_rules
    return replace_all(expr, rules)
def to_canonical(expr):
    """Return ``expr`` rewritten with ``canonical_rules`` followed by ``simplify_rules_base``."""
    rules = canonical_rules + simplify_rules_base
    return replace_all(expr, rules)