def normalize_minus(expr):
    """Apply the three Minus-normalising rewrite rules to ``expr``.

    The rules are handed to ``replace_all`` in this order:

    1. ``ld l___ (-md) r___``  ->  ``(-ld) l md r``
    2. ``(--ld) l___``         ->  ``ld l``
    3. ``-(ld l___)``          ->  ``(-ld) l``

    Returns whatever ``replace_all`` returns for ``expr``.
    """
    ld = PatternDot("ld")
    md = PatternDot("md")
    l = PatternStar("l")
    r = PatternStar("r")

    def move_minus_to_front(d):
        # Minus is taken off the inner factor and put on the leading one.
        return Times([Minus([d["ld"]]), d["l"], d["md"], d["r"]])

    def drop_double_minus(d):
        return Times([d["ld"], d["l"]])

    def push_minus_inside(d):
        return Times([Minus([d["ld"]]), d["l"]])

    minus_rules = [
        RewriteRule(Times([ld, l, Minus([md]), r]),
                    Replacement(move_minus_to_front)),
        RewriteRule(Times([Minus([Minus([ld])]), l]),
                    Replacement(drop_double_minus)),
        RewriteRule(Minus([Times([ld, l])]),
                    Replacement(push_minus_inside)),
    ]
    return replace_all(expr, minus_rules)
def initial_rewrite(operand, part, trav_shape):
    """Build the rewrite rules for the initial state of a traversal.

    Every cell of ``part`` is mapped to either ``operand`` (exactly one cell)
    or a ``Zero((sZERO, sZERO))`` object (all other cells). Which cell
    receives the operand depends on the partitioning shape and on the
    traversal direction ``trav_shape``.

    Parameters:
        operand: expression placed in the single non-zero cell.
        part: partitioned object; must provide ``get_children()`` (rows of
            cells) and be iterable row by row.
        trav_shape: ``(row_direction, column_direction)`` tuple, each
            component being 0, 1 or -1.

    Returns:
        A flat, row-major list of ``RewriteRule(cell, Replacement(value))``.
        A 1x1 partitioning yields an empty list.

    Raises:
        ValueError: if the partitioning shape / traversal combination is not
            one of the supported ones (previously this surfaced as an
            ``UnboundLocalError`` on ``rewrite``).
    """
    part_ch = part.get_children()
    part_shape = (len(part_ch), len(part_ch[0]))
    if part_shape == (1, 1):
        return []

    # (partitioning shape, traversal direction) -> (row, col) of the cell
    # that initially holds the full operand.
    operand_cell_for = {
        # [L|R]
        ((1, 2), (0, 1)): (0, 1),    # Left to Right
        ((1, 2), (0, -1)): (0, 0),   # Right to Left
        # [T;B]
        ((2, 1), (1, 0)): (1, 0),    # Top to Bottom
        ((2, 1), (-1, 0)): (0, 0),   # Bottom to Top
        # [TL TR; BL BR]
        ((2, 2), (1, 1)): (1, 1),    # Top Left to Bottom Right
        ((2, 2), (1, -1)): (1, 0),   # Top Right to Bottom Left
        ((2, 2), (-1, 1)): (0, 1),   # Bottom Left to Top Right
        ((2, 2), (-1, -1)): (0, 0),  # Bottom Right to Top Left
    }
    try:
        operand_cell = operand_cell_for.get((part_shape, trav_shape))
    except TypeError:
        # Unhashable trav_shape (e.g. a list) never matched the original
        # tuple comparisons either.
        operand_cell = None
    if operand_cell is None:
        raise ValueError(
            "Unsupported partitioning/traversal combination: "
            "part_shape=%r, trav_shape=%r" % (part_shape, trav_shape))

    zero = Zero((sZERO, sZERO))
    n_rows, n_cols = part_shape
    rewrite = [[operand if (i, j) == operand_cell else zero
                for j in range(n_cols)]
               for i in range(n_rows)]
    return [
        RewriteRule(p, Replacement(rew))
        for p_row, rew_row in zip(part, rewrite)
        for p, rew in zip(p_row, rew_row)
    ]
def expr_to_rule_lhs_rhs(self, predicates):
    """Turn single-output predicates into ``lhs -> rhs`` rewrite rules.

    Each element of ``predicates`` must expose ``children`` as a
    ``(lhs, rhs)`` pair. Only predicates whose ``lhs`` has exactly one child
    contribute; each of those yields a one-element list holding the rule
    that replaces that child by ``rhs``.

    Returns a list of one-element rule lists, in input order.
    """
    sides = (p.children for p in predicates)
    return [
        [RewriteRule(lhs.children[0], Replacement(rhs))]
        for lhs, rhs in sides
        if len(lhs.children) == 1
    ]
def compress_tiles(tiled_subpme):
    """Inline tiles into the later tiles that use them.

    ``tiled_subpme`` is an iterable of tile lists. Each tile exposes
    ``get_children()`` returning ``(lhs, rhs)``. For every list a new list is
    produced: a tile whose rhs is a ``Predicate`` is always kept; any other
    tile is substituted into the rhs of the following tiles that contain its
    first lhs child, and is kept only if no such substitution happened.

    Note: the inner lists of ``tiled_subpme`` are modified in place
    (``tiles[j] = new_other``); kept tiles are deep copies.

    Returns a list with one list of tiles per input list.
    """
    compressed = []
    for tiles in tiled_subpme:
        compressed.append([])
        i = 0
        # A while loop is used because tiles[j] (j > i) may be rewritten
        # during the scan and must be re-read on later iterations.
        while i < len(tiles):
            #for i, tile in enumerate( tiles ):
            tile = tiles[i]
            lhs, rhs = tile.get_children()
            if isinstance(rhs, Predicate):
                # Predicates are never inlined into other tiles.
                compressed[-1].append(copy.deepcopy(tile))
                i += 1
                continue
            replaced = False
            for j, other_tile in enumerate(tiles[i + 1:], start=i + 1):
                olhs, orhs = other_tile.get_children()
                lhs_ch = lhs.get_children()[0]
                if contains(orhs, lhs_ch):
                    new_other = copy.deepcopy(other_tile)
                    #new_other = replace( new_other, [RewriteRule(lhs_ch, Replacement(rhs))] )
                    new_other.children[1] = replace(
                        new_other.rhs(),
                        [RewriteRule(lhs_ch, Replacement(rhs))])
                    tiles[j] = new_other
                    replaced = True
                # if tile lhs = f(rhs) and other_tile overwrites lhs (lhs = f(lhs))
                # stop replacing with "old" lhs
                # NOTE(review): nesting reconstructed from a collapsed source
                # line; this check is assumed to sit at loop level, not inside
                # the `if contains(...)` above - confirm.
                if lhs_ch in olhs.children:
                    break
            if not replaced:
                compressed[-1].append(copy.deepcopy(tile))
            i += 1
    return compressed
def tile( self, tree, node, match_dict ):
    """Replace ``node`` inside ``tree`` by a fresh temporary symbol.

    A new symbol ``T<n>`` (``n`` from ``TOS.new_temp()``) is created and
    flagged ``properties.TEMPORARY``.

    Returns a pair ``(assignment, new_tree)``: ``assignment`` is an ``Equal``
    binding the temporary to the tiled expression, ``new_tree`` is the result
    of ``replace`` on ``tree``.

    NOTE(review): the nesting below was reconstructed from a collapsed source
    line. `if isinstance( node, Predicate ):` is assumed to be live code that
    follows a bare `#` separator, with the `node == tree` check only guarding
    the warning - confirm against the original file.
    """
    _T = TOS.new_temp()
    temp = Symbol("T" + str(_T))
    temp.set_property( properties.TEMPORARY )
    #
    if isinstance( node, Predicate ):
        if node == tree:
            print( "[Warning] If predicate has multiple outputs, may break if not careful from caller" )
        # The predicate itself becomes the tile; it is deep-copied both for
        # the assignment and for the pattern that is replaced in the tree.
        return Equal([ NList([ temp ]), copy.deepcopy( node ) ]), \
               replace( tree, [RewriteRule( copy.deepcopy( node ), Replacement(temp) )] )
    else:
        # Build the tile once to transfer its properties and storage info
        # to the temporary...
        tile_expr = self.create_tile( match_dict )
        propagate_properties( tile_expr, temp )
        propagate_st_info( tile_expr, temp )
        ## [FIXME] Quick and dirty test
        #if isUpperTriangular( tile_expr ):
            #print( temp, "is upper triangular" )
            #temp.set_property( properties.UPPER_TRIANGULAR )
        # ...and a second time for the returned assignment (a separate
        # create_tile call, not tile_expr).
        return Equal([ NList([ temp ]), self.create_tile( match_dict ) ]), \
               replace( tree, [self.create_rewrite_rule( match_dict, temp )] )
def opt_copy_propagation(self):
    """Propagate ``temporary = symbol`` copies into later statements.

    For every ``Equal`` statement in ``self.statements`` whose left-hand side
    is a single temporary and whose right-hand side is a ``Symbol``, the
    temporary is replaced by that symbol in the right-hand sides of the
    subsequent ``Equal`` statements. Statements are modified in place;
    nothing is returned.
    """
    for i, st in enumerate(self.statements):
        if isinstance(st, Equal):
            lhs, rhs = st.children
            lhs = lhs.children
            if len(lhs) == 1 and lhs[0].isTemporary() and isinstance(
                    st.rhs(), Symbol):
                for oth_st in self.statements[i + 1:]:
                    if not isinstance( oth_st, Equal ):
                        #[CHECK] Careful. Not the case, but generally speaking could be some part/repart
                        continue
                    oth_lhs, oth_rhs = oth_st.lhs(), oth_st.rhs()
                    if contains(oth_rhs, lhs[0]):
                        #\ and not (isinstance( oth_rhs, Predicate ) and pm.DB[oth_rhs.name].overwrite):
                        #[FIXME] Can be less restrictive
                        new_rhs = replace(
                            oth_rhs, [RewriteRule(lhs[0], Replacement(rhs))])
                        oth_st.children[1] = new_rhs
                    # Stop once a later statement redefines the temporary
                    # (or its lhs compares equal to the copied symbol).
                    # NOTE(review): nesting reconstructed from a collapsed
                    # source line; assumed to be at loop level - confirm.
                    if oth_lhs.children[0] == lhs[0] or oth_lhs == rhs:
                        break
def _PassStorage(st):
    """Replace symbols in a statement by the object recorded in ``st_info[1]``.

    ``lpla._while`` statements are handled by recursing into their body.
    For an ``Equal`` statement, every ``Symbol`` ``n`` (other than ``Zero``)
    with ``n.st_info[1] != n`` is scheduled for replacement by
    ``n.st_info[1]``, which first receives ``n``'s ``st_info``. The rules are
    then applied with ``replace(st, ...)``. Other statements are ignored.
    """
    if isinstance(st, lpla._while):
        for inner in st.body:
            _PassStorage(inner)
        return
    if not isinstance(st, Equal):
        return

    storage_rules = []
    for node in st.iterate_preorder():
        # Zero has no st_info...
        if isinstance(node, Zero) or not isinstance(node, Symbol):
            continue
        if node.st_info[1] != node:
            target = node.st_info[1]
            target.st_info = node.st_info
            # E.g., in Cholesky, node is L_11, and target is A_11
            storage_rules.append(RewriteRule(node, Replacement(target)))
    replace(st, storage_rules)
def before_to_rule(self, before):
    """Build rules that rewrite occurrences of ``before``'s rhs into its lhs.

    ``before`` provides ``get_children()`` returning ``(lhs, rhs)``. The
    rules depend on the head of ``rhs``:

    * ``Plus``: the sum inside a larger sum, the sum with each term
      surrounded by common left/right factors, and the canonical form of the
      latter.
    * ``Times``: the product inside a larger product, and its negated
      canonical form.
    * anything else: ``rhs -> lhs.children[0]``.

    Returns the list of ``RewriteRule`` objects.
    """
    lhs, rhs = before.get_children()

    def lhs_between_factors(d):
        # l___ <lhs children> r___
        return Times(d["l"].get_children() + lhs.get_children() +
                     d["r"].get_children())

    def between_stars(ch):
        # Fresh pattern stars for every use, as in the inline construction.
        return Times([PatternStar("l"), ch, PatternStar("r")])

    if isinstance(rhs, Plus):
        def lhs_after_rest(d):
            return Plus(d["rest"].get_children() + lhs.get_children())

        #Replacement("Plus(rest+lhs.get_children())") )
        return [
            RewriteRule(
                Plus([PatternStar("rest")] + rhs.get_children()),
                Replacement(lhs_after_rest)),
            RewriteRule(
                Plus([between_stars(ch) for ch in rhs.get_children()]),
                Replacement(lhs_between_factors)),
            RewriteRule(
                to_canonical(
                    Plus([between_stars(ch) for ch in rhs.get_children()])),
                Replacement(lhs_between_factors)),
        ]

    if isinstance(rhs, Times):
        def negated_lhs_between_factors(d):
            return Times(d["l"].get_children() +
                         [to_canonical(Minus([Times(lhs.get_children())]))] +
                         d["r"].get_children())

        return [
            RewriteRule(
                Times([PatternStar("l")] + rhs.get_children() +
                      [PatternStar("r")]),
                Replacement(lhs_between_factors)),
            # [TODO] For chol loop invariants 2 and 3, should fix elsewhere (maybe -1 instead of minus operator)
            RewriteRule(
                Times([
                    PatternStar("l"),
                    to_canonical(Minus([Times(rhs.get_children())])),
                    PatternStar("r"),
                ]),
                Replacement(negated_lhs_between_factors)),
        ]

    # [FIX] what if multiple outputs?
    #rules = [ RewriteRule( rhs, Replacement( lhs ) ) ]
    return [RewriteRule(rhs, Replacement(lhs.children[0]))]
def partition(self):
    """Partition every operand and rewrite the equation accordingly.

    For each operand in ``self.operands`` two partitionings are stored under
    the operand's name: ``self.basic_partitionings`` (from
    ``partition_shape_with_storage``, used for code generation) and
    ``self.partitionings`` (from ``partition``, which also receives the
    operand's properties). Each operand is then replaced by the latter in a
    deep copy of every equation; the result is stored in
    ``self.partitioned_postcondition`` and printed.
    """
    self.basic_partitionings = {}
    self.partitionings = {}
    operand_rules = []
    for operand in self.operands:
        name = operand.get_name()
        shape = self.part_shape[name]
        # For code generation
        #part = partition_shape( operand, self.part_shape[op_name] )
        self.basic_partitionings[name] = partition_shape_with_storage(
            operand, shape)
        full_part = partition(operand, shape, operand.get_properties())
        self.partitionings[name] = full_part
        operand_rules.append(RewriteRule(operand, Replacement(full_part)))

    partitioned = []
    for eq in self.equation:
        partitioned.append(replace(copy.deepcopy(eq), operand_rules))
    self.partitioned_postcondition = partitioned

    print("* Partitioned postcondition")
    for eq in partitioned:
        print("* ", eq)
def solve_equations(self):
    """Solve the distributed, partitioned postcondition (the PME).

    Steps, as implemented below:

    1. Register every sub-equation with ``update_TOE`` in both directions.
    2. Collect the ``Symbol`` parts of all output operands.
    3. Repeatedly match the pending sub-equations against
       ``self.known_ops``; matched ones are recorded and their lhs symbols
       are flagged ``INPUT``. If an iteration makes no progress,
       ``rec_instance`` is called on the pending equation with the fewest
       unknowns and its patterns are added to the known operations.
    4. Build reuse rules so that the rhs of a solved equation appearing in
       another solved equation is replaced by the corresponding lhs.
    5. Flag the solved outputs as ``OUTPUT`` again and print the result.

    The result is stored in ``self.solved_subequations``; nothing is
    returned. ``self.known_ops``, ``self.known_pmes``, ``pm.DB`` and ``TOS``
    entries are modified as a side effect.

    NOTE(review): block nesting was reconstructed from collapsed source
    lines; in particular ``while not solved:`` is taken to be live code
    (``solved`` and ``subeqs_to_solve`` are updated in the loop body).
    """
    dist = self.distributed_partitioned_postcondition
    # Update TOE
    for eq in dist.flatten_children():
        for subeq in eq:
            lhs, rhs = subeq.get_children()
            update_TOE(RewriteRule(lhs, Replacement(rhs)))
            update_TOE(RewriteRule(rhs, Replacement(lhs)))
    # Collect all output parts to be computed
    basic_part = self.basic_partitionings
    basic_part = self.partitionings  #[TODO] Ok?
    all_outputs = []
    for op in self.operands:
        if op.isOutput():
            all_outputs.extend([
                node
                for node in basic_part[op.get_name()].iterate_preorder()
                if isinstance(node, Symbol)
            ])
    # Iterate until PME is solved
    subeqs_to_solve = dist.flatten_children()
    subeqs_to_solve = filter_zero_zero(subeqs_to_solve)
    subeqs_solved = []
    solved_outputs = set([
        ""
    ])  # any initialization that cannot render the equality below true
    solved = False
    #
    while not solved:
        for eq in subeqs_to_solve:
            new = replace(copy.deepcopy(eq), self.known_ops)
            if isinstance(new, Equal):
                lhs, rhs = new.get_children()
                # Set input property
                if all([isinstance(elem, Symbol) for elem in lhs]):
                    if isinstance(rhs, Predicate):
                        rhs.set_property(INPUT)  # SOLVED?
                    for op in lhs:
                        op.set_property(INPUT)  # SOLVED?
                    subeqs_solved.append(new)
                else:
                    #print( "Can it be?" ) # Yes, e.g., trans(X_BL) in lyapunov
                    #raise Exception
                    pass
        cur_solved_outputs = set(
            [op.get_name() for op in all_outputs if isInput(op)])
        #print( "SOLVED:", cur_solved_outputs )
        # No new output became an input in this pass: fall back to a nested
        # instance for the most promising pending equation.
        if solved_outputs == cur_solved_outputs:
            print("")
            print(
                "[WARNING] PME Generation is stuck - Trying recursive instance"
            )
            print("")
            # Pick the equation minimising (number of distinct output
            # symbols, number of nodes).
            minimum = (100000, 100000)
            subeq = None
            for eq in subeqs_to_solve:
                unks = []
                cnt = 0
                for node in eq.iterate_preorder():
                    cnt += 1
                    if isinstance(node, Symbol) and node.isOutput(
                    ) and node not in unks:
                        unks.append(node)
                if (len(unks), cnt) < minimum:
                    subeq = eq
                    minimum = (len(unks), cnt)
            from NestedInstance import rec_instance
            ((md_name, md_data), new_patts,
             new_pmes) = rec_instance(subeq.children[0])
            self.known_pmes.extend(new_pmes)
            print("NPMES:", len(new_pmes))
            for patt in new_patts:
                self.known_ops.append(patt)
            pm.DB[md_name] = md_data
            #raise Exception
        solved_outputs = cur_solved_outputs
        if all([isInput(op) or op.isZero() for op in all_outputs]):
            solved = True
        else:
            # Keep only the pending (non-Equal) equations that still mention
            # an output symbol, and canonicalise them for the next pass.
            new_to_solve = []
            for eq in subeqs_to_solve:
                if not isinstance( eq, Equal ) and \
                        len([op for op in eq.iterate_preorder() if isinstance(op, Symbol) and isOutput(op)]) != 0:
                    new_to_solve.append(eq)
            subeqs_to_solve = [
                simplify(to_canonicalIO(eq))._cleanup() for eq in new_to_solve
            ]
    self.solved_subequations = subeqs_solved
    # Replace rhs with lhs if they appear as subexpressions of other rhss
    #
    # Create replacements when suited
    reuse_rules = []
    for i, eq in enumerate(subeqs_solved):
        eq_rules = []
        lhs, rhs = eq.children
        if not isinstance(rhs, Predicate) and rhs.get_head() not in (
                Minus, Transpose, Inverse):
            lhs_sym = lhs.children[0]
            t = PatternStar("t")
            l = PatternStar("l")
            r = PatternStar("r")
            # The outer lambda binds lhs_sym at definition time.
            repl_f = (lambda lhs_sym: lambda d: Plus(
                [d["t"], Times([d["l"], lhs_sym, d["r"]])]))(lhs_sym)
            eq_rules.append(
                RewriteRule(Plus([t, Times([l, rhs, r])]),
                            Replacement(repl_f)))
            # A - B C in -L B C R + L A R (minus pushed all the way to the left)
            if len(rhs.children) > 1:
                repl_f = (lambda lhs_sym: lambda d: Times(
                    [Minus([lhs_sym])] + d["r"].children))(lhs_sym)
                eq_rules.append(
                    RewriteRule(
                        Times([
                            Minus([rhs.children[0]]), *rhs.children[1:], r
                        ]), Replacement(repl_f)))
        reuse_rules.append((i, eq_rules))
    # Replace
    self.solved_subequations = []
    for i, eq in enumerate(subeqs_solved):
        # Never apply an equation's own rules to itself.
        rules = [r for j, r in reuse_rules if i != j]
        self.solved_subequations.append(
            replace_all(eq, list(itertools.chain(*rules))))
    # [TODO] Add T to operands and bind dimensions
    #self.bind_temporaries( )
    # Reset outputs
    for op in solved_outputs:
        TOS[op][0].set_property(OUTPUT)
    print("* PME ")
    for eq in self.solved_subequations:
        print("* ", eq)
pm.DB["rdiv_utu_ow"].overwrite = [(1, 0)] A = PatternDot("A") B = PatternDot("B") X = PatternDot("X") trsm2lgen_rules = [ # X = i(t(A)) B -> ldiv_lni RewriteRule(( Equal([NList([X]), Times([Inverse([A]), B])]), Constraint( "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name == X.name" )), Replacement(lambda d: Equal([ NList([d["X"]]), Predicate("ldiv_lni", [d["A"], d["B"]], [d["A"].get_size(), d["B"].get_size()]) ]))), # X = i(t(A)) B -> ldiv_lni_ow RewriteRule( (Equal([NList([X]), Times([Inverse([A]), B])]), Constraint( "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name != X.name" )), Replacement(lambda d: Equal([ NList([d["X"]]), Predicate("ldiv_lni_ow", [d["A"], d["B"]], [d["A"].get_size(), d["B"].get_size()]) ]))), # X = i(t(A)) B -> ldiv_lnu RewriteRule(
def expr_to_rule_rhs_lhs(self, predicates):
    """Build, per predicate, rules that rewrite its rhs back into its lhs.

    Each predicate exposes ``children`` as ``(lhs, rhs)``. For a single
    output (``len(lhs.children) == 1``) several variants of the rhs are
    turned into patterns, depending on whether the rhs is a ``Plus``, a
    ``Times`` or something else; for multiple outputs a single rule
    ``rhs -> lhs`` is produced.

    Returns a list with one list of ``RewriteRule`` objects per predicate.

    Replacement functions are built as ``(lambda lhs: lambda d: ...)(lhs_sym)``
    so that the current ``lhs_sym`` is bound when the rule is created.

    NOTE(review): block nesting was reconstructed from collapsed source
    lines; the placement of the transposed-rhs rule inside the final inner
    ``else`` and of the last ``else`` (paired with the single-output test)
    should be confirmed.
    """
    rules = []
    t = PatternStar("t")
    l = PatternStar("l")
    ld = PatternDot("ld")
    r = PatternStar("r")
    for p in predicates:
        pr = []
        lhs, rhs = p.children
        if len(lhs.children) == 1:
            #lhs_sym = WrapOutBef( lhs.children[0] )
            lhs_sym = lhs.children[0]
            if isinstance(rhs, Plus):
                # t___ + rhs -> t + lhs
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children +
                                                     [lhs]))(lhs_sym)
                pr.append(
                    RewriteRule(Plus([t] + rhs.children),
                                Replacement(repl_f)))
                # t___ + l___ rhs_i r___ + ... -> t + l lhs r
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times(d["l"].children + [lhs] + d["r"].children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            Times([l] + [ch] + [r]) for ch in rhs.children
                        ]), Replacement(repl_f)))
                # Same, with every term negated (canonical form) and only a
                # right context.
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([simplify(to_canonical(Minus([lhs])))] + d["r"].
                          children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            Times([simplify(to_canonical(Minus([ch])))] +
                                  [r]) for ch in rhs.children
                        ]), Replacement(repl_f)))
                # A - B C in L B C R + -L A R (minus pushed all the way to the left, and whole thing negated)
                repl_f = (lambda lhs: lambda d: normalize_minus(
                    Plus(d["t"].children + [
                        Times([d["ld"]] + d["l"].children + [Minus([lhs])] +
                              d["r"].children)
                    ])))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            normalize_minus(Times([ld, l, Minus([ch]), r]))
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                # A - B C in -L B C R + L A R (minus pushed all the way to the left)
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([d["ld"]] + d["l"].children + [lhs] + d["r"].
                          children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            normalize_minus(Times([ld, l, ch, r]))
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                #repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [Times([simplify(to_canonical(Minus(lhs.children)))] + d["r"].children)]))(lhs_sym)
                #pr.append( RewriteRule( Plus([t] + [
                                #Times([ Minus([ld]), l, ch, r]) if not isinstance(ch, Minus) \
                                        #else Times([ l, ch.children[0], r]) \
                                        #for ch in rhs.children ]),
                                #Replacement(repl_f) ) )
                # Transposed variant; terms that are not already a Minus get
                # a negated leading factor in the pattern.
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([
                        simplify(to_canonical(Minus([Transpose([lhs])])))
                    ] + d["r"].children)
                ]))(lhs_sym)
                pr.append( RewriteRule( Plus([t] + [
                                Times([ Minus([ld]), l, simplify(to_canonical(Transpose([ch]))), r]) if not isinstance(ch, Minus) \
                                        else Times([ l, simplify(to_canonical(Transpose([ch]))), r]) \
                                        for ch in rhs.children ]),
                                Replacement(repl_f) ) )
            elif isinstance(rhs, Times):
                # l___ rhs r___ -> l lhs r
                repl_f = (lambda lhs: lambda d: Times(d[
                    "l"].children + [lhs] + d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(Times([l] + rhs.children + [r]),
                                Replacement(repl_f)))
                # l___ rhs^T r___ -> l lhs^T r
                repl_f = (lambda lhs: lambda d: Times(d[
                    "l"].children + [Transpose([lhs])] + d["r"].children)
                          )(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([
                            l,
                            simplify(to_canonical(Transpose([rhs]))), r
                        ]), Replacement(repl_f)))
                # [TODO] Minus is awkward to handle. Should go for -1 and remove the operator internally?
                repl_f = (lambda lhs: lambda d: Times([
                    simplify(to_canonical(Minus([Times([lhs])])))
                ] + d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([
                            simplify(
                                to_canonical(
                                    Minus([Times(rhs.get_children())])))
                        ] + [r]), Replacement(repl_f)))
                repl_f = (lambda lhs: lambda d: Times([
                    simplify(
                        to_canonical(Minus([Transpose([Times([lhs])])])))
                ] + d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([
                            simplify(
                                to_canonical(
                                    Minus([
                                        Transpose(
                                            [Times(rhs.get_children())])
                                    ])))
                        ] + [r]), Replacement(repl_f)))
            else:
                pr.append(RewriteRule(rhs, Replacement(lhs_sym)))
                # Also match the transposed rhs, unless transposing yields a
                # plain operand.
                new_rhs = simplify(to_canonical(Transpose([rhs])))
                if not isOperand(new_rhs):
                    pr.append(
                        RewriteRule(
                            simplify(to_canonical(Transpose([rhs]))),
                            Replacement(Transpose([lhs_sym]))))
        else:
            # Multiple outputs: replace the rhs by the whole lhs.
            pr.append(RewriteRule(rhs, Replacement(lhs)))
        rules.append(pr)
    return rules
def find_updates(self, before, after):
    """Derive the loop-body updates that lead from ``before`` to ``after``.

    ``before`` and ``after`` are lists of equations (objects with
    ``children`` / ``get_children()`` giving ``(lhs, rhs)``); both lists are
    modified in place (outputs are wrapped in ``WrapOutBef`` /
    ``WrapOutAft`` and right-hand sides are rewritten).

    Returns:
        The list of tiled update statements with the wrappers removed, or
        ``None`` when the invariant is skipped: expressing the predicates in
        terms of the inputs failed, a part computed in ``before`` is missing
        from ``after``, or reaching ``after`` would require undoing
        computation.

    Only ``Exception`` subclasses raised by ``express_in_terms_of_input``
    are treated as "skip"; ``KeyboardInterrupt`` and ``SystemExit`` now
    propagate instead of being swallowed by a bare ``except``.
    """
    # If a part is (partially) computed in the before and
    # does not appear in the after or
    # going from before to after requires undoing some computation
    # it is potentially unstable, and more expensive: ignore
    try:
        before_finputs = self.express_in_terms_of_input(before)
        after_finputs = self.express_in_terms_of_input(after)
    except Exception:
        # [TODO] In LU's variant 5, parts of A appear as lhs's
        return None
    #
    dict_bef = dict([(str(u.get_children()[0]), u) for u in before_finputs])
    dict_aft = dict([(str(u.get_children()[0]), u) for u in after_finputs])
    same = []
    ignore = False
    for k, v in dict_bef.items():
        if k in dict_aft and matchq(v, dict_aft[k]):
            # Unchanged between before and after.
            same.extend(v.children[0].children)
        if k not in dict_aft:
            ignore = True
            reason = "%s not in %s" % (k, dict_aft.keys())
            break
        else:
            rules = self.expr_to_rule_rhs_lhs([v])
            rules = list(itertools.chain(*rules))
            expr_copy = copy.deepcopy(dict_aft[k])
            t = replace(expr_copy, rules)
            #if v == replace( expr_copy, rules ):
            # The before-expression cannot be found inside the after one.
            if dict_aft[k] == t:
                ignore = True
                reason = "%s would require undoing job" % k
                break
    if ignore:
        print("[INFO] Skipping invariant: %s" % reason)
        return None
    #
    # Wrap outputs for before and after
    WrapBefOut = WrapOutBef
    lhss = []
    for u in before:
        lhss.extend(u.children[0])
        u.children[0] = NList([WrapBefOut(l) for l in u.children[0]])
    for u in before:
        u.children[1] = replace(
            u.children[1],
            [RewriteRule(l, Replacement(WrapBefOut(l))) for l in lhss])
    #
    lhss = []
    for u in after:
        lhss.extend(u.children[0])
        u.children[0] = NList([WrapOutAft(l) for l in u.children[0]])
    # Parts that did not change keep their "before" wrapper.
    wrap_rules_after = \
        [
            RewriteRule(l, Replacement(WrapBefOut(l))) if l in same else
            RewriteRule(l, Replacement(WrapOutAft(l)))
            for l in lhss
        ]
    for u in after:
        u.children[1] = replace(u.children[1], wrap_rules_after)
    # replace before in before
    wrap_rules_before = []
    for u in before:
        lhs, rhs = u.get_children()
        #if len(lhs.children) > 1:
            #wrap_rules_before.append([])
            #continue
        rules = self.expr_to_rule_rhs_lhs([u])
        wrap_rules_before.append(list(itertools.chain(*rules)))
    #
    # Add, for each rule, the variant whose pattern has been rewritten with
    # the rules of all the other before-equations (if that changes it).
    new_rules = []
    for i, rules in enumerate(wrap_rules_before):
        new_rules.append([])
        for rule in rules:
            new_r = copy.deepcopy(rule)
            new_r.pattern = replace_all(
                new_r.pattern,
                list(
                    itertools.chain.from_iterable(wrap_rules_before[:i] +
                                                  wrap_rules_before[i + 1:])))
            if new_r.pattern != rule.pattern:
                new_rules[-1].append(new_r)
    for r1, r2 in zip(new_rules, wrap_rules_before):
        r2.extend(r1)
    #
    wrap_rules_before = list(itertools.chain(*wrap_rules_before))
    # Replace before in after, until nothing changes.
    done = False
    while not done:
        after_top = [copy.deepcopy(u) for u in after]
        for i, u in enumerate(after):
            _, rhs = u.get_children()
            u.children[1] = simplify(
                to_canonical(
                    replace_all(copy.deepcopy(rhs), wrap_rules_before)))
        done = True
        for top, bot in zip(after_top, after):
            if top != bot:
                done = False
                break
    # replace after in after
    done = False
    while not done:
        # The rules are rebuilt on every pass because the after-equations
        # change as they are rewritten.
        wrap_rules_after = []
        for u in after:
            lhs, rhs = u.get_children()
            #if len(lhs.children) > 1:
                #wrap_rules_after.append([])
                #continue
            rules = self.expr_to_rule_rhs_lhs([u])
            wrap_rules_after.append(list(itertools.chain(*rules)))
        #
        after_top = [copy.deepcopy(u) for u in after]
        for i, u in enumerate(after):
            _, rhs = u.get_children()
            # An equation is never rewritten with its own rules.
            rules = list(
                itertools.chain.from_iterable(wrap_rules_after[:i] +
                                              wrap_rules_after[i + 1:]))
            u.children[1] = simplify(
                to_canonical(replace_all(copy.deepcopy(rhs), rules)))
        done = True
        for top, bot in zip(after_top, after):
            if top != bot:
                done = False
                break
    # [TODO] Multiple lhss, won't work
    updates = []
    for u in after:
        lhs, rhs = u.get_children()
        if len(lhs.children) == 1:
            lhs = lhs.children[0]  # NList[op] -> op
            # "after(X) = before(X)": nothing to compute.
            if isinstance(rhs, WrapBefOut) and isinstance(lhs, WrapOutAft) and \
                    matchq(lhs.children[0], rhs.children[0]):
                continue
        elif not isinstance(rhs, NList):
            # multiple outputs/predicate in rhs,
            # but not complete (otherwise it would be NList)
            pass
        else:
            to_skip = True
            for l, r in zip(lhs.children, rhs.children):
                if not( isinstance(r, WrapBefOut) and isinstance(l, WrapOutAft) and \
                        matchq(l.children[0], r.children[0]) ):
                    to_skip = False
                    break
            if to_skip:
                continue
        updates.append(u)
    #
    tiled_updates = []
    for u in updates:
        print("* ", u)
        tilings = list(tile_expr(u))
        if len(tilings) > 1:
            print("[WARNING] Multiple (%d) tilings for expression %s" %
                  (len(tilings), u))
            print(" Discarding all but one")
        tiled_updates.extend(tilings[0])
    tiled_updates = sort(tiled_updates)
    print("* Tiled update")
    for t in tiled_updates:
        print("* ", t)
    # Drop WrapOutBef's
    # Drop WrapOutAft's
    s = PatternDot("s")
    updates = []
    for u in tiled_updates:
        u = replace_all(
            u, [RewriteRule(WrapOutAft(s), Replacement(lambda d: d["s"]))])
        u = replace_all(
            u, [RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))])
        updates.append(u)
    return updates
def find_updates_v2(self, before, after):
    """Derive the loop-body updates that lead from ``before`` to ``after``.

    Variant of ``find_updates`` that works directly on the given equations
    (objects with ``children`` / ``get_children()`` giving ``(lhs, rhs)``);
    both lists are modified in place.

    Returns:
        The list of tiled update statements with the ``WrapOutBef`` /
        ``WrapOutAft`` wrappers removed, or ``None`` when the invariant is
        skipped because a part computed in ``before`` is missing from
        ``after`` or reaching ``after`` would require undoing computation.

    The skip message used to reference an undefined ``reason`` and crashed
    with ``NameError``; the reason is now recorded in both skip branches.
    """
    # If a part is (partially) computed in the before and
    # does not appear in the after or
    # going from before to after requires undoing some computation
    # it is potentially unstable, and more expensive: ignore
    dict_bef = dict([(str(u.get_children()[0]), u) for u in before])
    dict_aft = dict([(str(u.get_children()[0]), u) for u in after])
    ignore = False
    reason = ""
    for k, v in dict_bef.items():
        if k not in dict_aft:
            ignore = True
            reason = "%s not in %s" % (k, dict_aft.keys())
            break
        else:
            rules = self.expr_to_rule_rhs_lhs([v])
            rules = list(itertools.chain(*rules))
            expr_copy = copy.deepcopy(dict_aft[k])
            t = replace(expr_copy, rules)
            #if v == replace( expr_copy, rules ):
            # The before-expression cannot be found inside the after one.
            if dict_aft[k] == t:
                ignore = True
                reason = "%s would require undoing job" % k
                break
    if ignore:
        print("[INFO] Skipping invariant: %s" % reason)
        return None
    #
    # Wrap outputs for before and after
    WrapBefOut = WrapOutBef
    for u in before:
        u.children[0] = NList([WrapBefOut(l) for l in u.children[0]])
    #
    for u in after:
        u.children[0] = NList([WrapOutAft(l) for l in u.children[0]])
    # replace before in after
    wrap_rules_before = []
    for u in before:
        print(u)
        lhs, rhs = u.get_children()
        if len(lhs.children) > 1:
            continue
        rules = self.expr_to_rule_rhs_lhs([u])
        wrap_rules_before.append(list(itertools.chain(*rules)))
    #
    # Rewrite the patterns of each rule group with the groups that precede
    # it, going from the last group backwards.
    for i, rule in enumerate(reversed(wrap_rules_before)):
        idx = len(wrap_rules_before) - i - 1
        for j in range(idx - 1, -1, -1):
            for _rule in rule:
                _rule.pattern = replace_all(_rule.pattern,
                                            wrap_rules_before[j])
    wrap_rules_before = list(itertools.chain(*wrap_rules_before))
    #
    for u in after:
        _, rhs = u.get_children()
        u.children[1] = simplify(
            to_canonical(replace_all(copy.deepcopy(rhs), wrap_rules_before)))
    # replace after in after
    done = False
    while not done:
        # The rules are rebuilt on every pass because the after-equations
        # change as they are rewritten.
        wrap_rules_after = []
        for u in after:
            lhs, rhs = u.get_children()
            if len(lhs.children) > 1:
                # Placeholder keeps the list aligned with ``after``.
                wrap_rules_after.append([])
                continue
            rules = self.expr_to_rule_rhs_lhs([u])
            wrap_rules_after.append(list(itertools.chain(*rules)))
        #
        after_top = [copy.deepcopy(u) for u in after]
        for i, u in enumerate(after):
            _, rhs = u.get_children()
            # An equation is never rewritten with its own rules.
            rules = list(
                itertools.chain.from_iterable(wrap_rules_after[:i] +
                                              wrap_rules_after[i + 1:]))
            u.children[1] = simplify(
                to_canonical(replace_all(copy.deepcopy(rhs), rules)))
        done = True
        for top, bot in zip(after_top, after):
            if top != bot:
                done = False
                break
    # [TODO] Multiple lhss, won't work
    updates = []
    for u in after:
        lhs, rhs = u.get_children()
        lhs = lhs.children[0]  # NList[op] -> op
        # "after(X) = before(X)": nothing to compute.
        if isinstance(rhs, WrapBefOut) and isinstance(lhs, WrapOutAft) and \
                matchq(lhs.children[0], rhs.children[0]):
            continue
        updates.append(u)
    #
    tiled_updates = []
    for u in updates:
        print("* ", u)
        tilings = list(tile_expr(u))
        if len(tilings) > 1:
            print("[WARNING] Multiple (%d) tilings for expression %s" %
                  (len(tilings), u))
            print(" Discarding all but one")
        tiled_updates.extend(tilings[0])
    tiled_updates = sort(tiled_updates)
    print("* Tiled update")
    for t in tiled_updates:
        print("* ", t)
    # Drop WrapOutBef's
    # Drop WrapOutAft's
    s = PatternDot("s")
    updates = []
    for u in tiled_updates:
        u = replace_all(
            u, [RewriteRule(WrapOutAft(s), Replacement(lambda d: d["s"]))])
        u = replace_all(
            u, [RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))])
        updates.append(u)
    return updates
def learn_pattern(self):
    """Register rewrite rules that recognise this operation.

    A pattern is built from ``self.equation`` by replacing every operand
    with a ``PatternDot`` of the same name, together with a constraint
    derived from the operands' properties and a replacement that produces
    ``Equal([NList([outputs]), Predicate(self.name, [inputs], [sizes])])``.

    * Single assignment: the rule and a negated variant (predicate wrapped
      in ``Minus``) are inserted at the front of ``known_ops_single``.
    * Otherwise: the rule is inserted at the front of ``known_ops`` and
      pickled to ``OUTPUT/<name>_patterns``.

    In both cases a rule mapping the predicate back to its defining
    equation is appended to ``op_to_implicit``. The learnt pattern is
    printed. Nothing is returned.

    NOTE(review): block nesting was reconstructed from collapsed source
    lines; the ``with open(...)`` is assumed to belong to the ``else``
    branch and the final three statements to run unconditionally - confirm.
    """
    inops = [op for op in self.operands if op.isInput()]
    outops = [op for op in self.operands if op.isOutput()]
    #
    single_assignment = len( self.equation.children ) == 1 and \
                        isinstance(self.equation.children[0].children[0], Symbol) # eq.lhs.single_entry_in_NL
    #
    # Replacements given as strings are stored as code text (see the
    # ``to_eval`` attribute printed below).
    op_to_pattern = [
        RewriteRule(op, Replacement("PatternDot(%s.name)" % op.name))
        for op in self.operands
    ]
    pattern = NList([
        replace_all(copy.deepcopy(eq), op_to_pattern) for eq in self.equation
    ])
    if single_assignment:
        props_str = [
            symbol_props_to_constraints_no_io(op) for op in self.operands
        ]
        # Operands without constraints contribute empty strings; drop them.
        constraint = Constraint(" and ".join(
            [prop for prop in props_str if prop]))
    else:
        constraint = Constraint(" and ".join(
            [symbol_props_to_constraints(op) for op in self.operands]))
    # [TODO] Tuple for get_size
    replacement = Replacement(
        "Equal([ NList([%s]), Predicate( \"%s\", [%s], [%s] ) ])" %
        (", ".join([op.name for op in outops]), self.name, ", ".join([
            op.name for op in inops
        ]), ", ".join(["%s.get_size()" % op.get_name() for op in outops])))
    # [TODO] This should be part of the verbose option
    print("* Learnt pattern")
    print("* ", pattern, end="")
    if constraint.to_eval:
        print("with ", constraint.to_eval)
    print(" --> ")
    print("* ", replacement.to_eval)
    print("**********************************")
    # [TODO] Maybe sort known ops by specificity (a la mathematica)
    #known_ops.insert( 0, RewriteRule( (pattern, constraint), replacement ) )
    if single_assignment:
        expr = pattern.children[0]
        # Wrap the single lhs in an NList.
        expr.children[0] = NList([expr.children[0]])
        known_ops_single.insert(
            0, RewriteRule((expr, constraint), replacement))
        # With minus
        replacement = Replacement(
            "Equal([ NList([%s]), Minus([ Predicate( \"%s\", [%s], [%s] ) ]) ])"
            % (", ".join([op.name for op in outops]), self.name, ", ".join(
                [op.name for op in inops]), ", ".join(
                    ["%s.get_size()" % op.get_name() for op in outops])))
        expr = copy.deepcopy(expr)
        expr.children[1] = Minus([expr.children[1]])
        expr.children[1] = normalize_minus(copy.deepcopy(expr.children[1]))
        known_ops_single.insert(
            0, RewriteRule((expr, constraint), replacement))
        #with open(os.path.join("OUTPUT", self.name+"_patterns"), "wb") as patt_f:
            #pickle.dump( known_ops_single[1], patt_f )
            #pickle.dump( known_ops_single[0], patt_f )
    else:
        known_ops.insert(0, RewriteRule((pattern, constraint), replacement))
        with open(os.path.join("OUTPUT", self.name + "_patterns"),
                  "wb") as patt_f:
            pickle.dump(known_ops[0], patt_f)
    # Reverse direction: predicate -> defining equation.
    pattern = Equal([
        NList([PatternDot(op.get_name()) for op in outops]),
        Predicate(self.name, [PatternDot(op.get_name()) for op in inops],
                  [op.get_size() for op in outops])
    ])
    replacement = Replacement(equation2replacement(self.equation))
    op_to_implicit.append(RewriteRule(pattern, replacement))
def generate_loop_based_algorithms(self):
    """Generate one loop-based algorithm per usable loop invariant.

    For every ``(pme, linvs)`` pair from ``self.pmes`` / ``self.linvs`` and
    every loop invariant, the initialization, the "before" and "after"
    predicates and the updates between them are computed. Invariants for
    which ``find_updates`` returns ``None`` or an empty list are skipped.
    The resulting ``Algorithm`` objects are stored, grouped per PME, in
    ``self.algs``; ``variant`` numbers only the invariants that produced an
    algorithm. Progress is printed. Nothing is returned.

    NOTE(review): block nesting was reconstructed from a collapsed source
    line - confirm against the original file.
    """
    print("* Generating Loop-based algorithms...")
    self.algs = []
    variant = 1
    for pme, linvs in zip(self.pmes, self.linvs):
        algs = []
        for linv in linvs:
            print("* ")
            print("* Loop invariant", variant)
            for expr in linv.expressions:
                print("* ", expr)
            print("* ")
            trav, init_state, _ = linv.traversals[
                0]  # this would be another for loop
            init = self.algorithm_initialization(init_state)
            print("* Init")
            #print( init_state )
            print("* ", init)
            # Strip the WrapOutBef wrappers from the initialization.
            s = PatternDot("s")
            init = [
                replace_all(i, [
                    RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))
                ]) for i in init
            ]
            print("* Before")
            repart, before = self.generate_predicate_before(
                pme, trav, linv.expressions, linv)
            print("* After")
            # The "after" predicate is obtained with the same routine, using
            # the traversal with both directions negated.
            reversed_trav = dict([(k, (r * -1, c * -1))
                                  for k, (r, c) in trav.items()])
            cont_with, after = self.generate_predicate_before(
                pme, reversed_trav, linv.expressions, linv)
            # find updates
            print("* Updates")
            updates = self.find_updates(before, after)
            if updates is None:
                #variant += 1
                continue
            # Tile updates
            for u in updates:
                infer_properties(u)
            final_updates = []
            # [DIEGO] Fixing some output pieces being labeled as input
            outputs = []
            for u in updates:
                lhs, rhs = u.children
                for l in lhs:
                    if not l.isTemporary():
                        outputs.append(l)
                        l.set_property(OUTPUT)
            #
            for u in updates:
                #[DIEGO] u.children[0].children[0].set_property( OUTPUT )
                for node in u.children[1].iterate_preorder():
                    #if isinstance(node, Symbol):
                    if isinstance(node, Symbol) and node not in outputs:
                        node.set_property(INPUT)
                #
                #copy_u = replace( copy.deepcopy(u), known_ops_single )
                copy_u = copy.deepcopy(u)
                copy_u = tile_expr(copy.deepcopy(copy_u))[0]  # One tiling
                final_updates.extend(copy_u)
            if len(updates) == 0:
                print("No updates!! Should only happen in copy")
                continue
            algs.append(
                Algorithm(linv, variant, init, repart, cont_with, before,
                          after, final_updates))
            algs[-1].prepare_for_code_generation()
            variant += 1
        self.algs.append(algs)
LHS = PatternDot("LHS") RHS = PatternDot("RHS") subexpr = PatternDot("subexpr") PD1 = PatternDot("PD1") PD2 = PatternDot("PD2") PS1 = PatternStar("PS1") PS2 = PatternStar("PS2") PS3 = PatternStar("PS3") PSLeft = PatternStar("PSLeft") PSRight = PatternStar("PSRight") simplify_rules = [ # Minus # --expr -> expr RewriteRule(Minus([Minus([subexpr])]), Replacement(lambda d: d["subexpr"])), # -0 -> 0 RewriteRule((Minus([subexpr]), Constraint(lambda d: isZero(d["subexpr"]))), Replacement(lambda d: d["subexpr"])), # Plus # Plus(a) -> a RewriteRule(Plus([subexpr]), Replacement(lambda d: d["subexpr"])), # Plus( a___, 0, b___ ) -> Plus( a, b ) RewriteRule((Plus([PS1, subexpr, PS2 ]), Constraint(lambda d: isZero(d["subexpr"]))), Replacement(lambda d: Plus([d["PS1"], d["PS2"]]))), # a - a -> 0 RewriteRule( Plus([PS1, subexpr, PS2, Minus([subexpr]), PS3]), Replacement(lambda d: Plus( [d["PS1"],
from Tiling import tile_expr from utils import sort from RewritingExtension import * from BackEnd.MatlabCode import generate_matlab_code, click2matlab #from BackEnd.LatexCode import generate_latex_code from BackEnd.LGenCode import generate_lgen_files #from BackEnd.CCode import generate_c_code known_ops = [ # Straight assignment RewriteRule( (NList([Equal([PatternDot("LHS"), PatternDot("RHS")])]), Constraint( "isinstance(LHS, Symbol) and isOutput(LHS) and isInput(RHS)")), Replacement("Equal([ NList([LHS]), RHS ])")) ] known_ops_single = [] known_pmes = [] op_to_implicit = [] # Add a verbose option # For now it would help to illustrate the process class Operation(object): def __init__(self, name, operands, equation, overwrite): self.name = name self.operands = operands self.equation = equation # list because it may be a coupled equation self.overwrite = overwrite
pass # [FIXME] Default? # [TODO] Complete instruction_set = [ # Basic Ops - Completeness of the instruction set # [ # Inverse Instruction( ( Inverse([ A ]), Constraint("isOperand(A)") ), lambda d, t: \ RewriteRule( Inverse([ d["A"] ]), Replacement( "%s" % t ) ), lambda d: Inverse([ d["A"] ]) ), # Transpose Instruction( ( Transpose([ A ]), Constraint("isOperand(A)") ), lambda d, t: \ RewriteRule( Transpose([ d["A"] ]), Replacement( "%s" % t ) ), lambda d: Transpose([ d["A"] ]) ), # Minus #( Minus([ A ]), Constraint("isOperand(A, Symbol)") ),
def flatten_blocked_operation_click(expr):
    """Flatten ``expr`` after distributing ``WrapOutBef`` over its argument.

    On a deep copy of ``expr``, every ``WrapOutBef(x)`` is replaced by a
    ``BlockedExpression`` built by mapping ``WrapOutBef`` over ``x`` with
    ``map_thread`` (level 2), keeping ``x.size`` and ``x.shape``. The result
    is passed to ``flatten_blocked_operation`` and returned.
    """
    wrapped = PatternDot("PD")

    def wrap_each_block(d):
        blocked = d["PD"]
        return BlockedExpression(
            map_thread(WrapOutBef, [blocked], 2), d["PD"].size,
            d["PD"].shape)

    distribute_wrapper = RewriteRule(WrapOutBef(wrapped),
                                     Replacement(wrap_each_block))
    rewritten = replace(copy.deepcopy(expr), [distribute_wrapper])
    return flatten_blocked_operation(rewritten)
class triu(Operator): def __init__(self, arg): Operator.__init__(self, [arg], [], UNARY) self.size = arg.get_size() # Patterns for inv to trsm A = PatternDot("A") B = PatternDot("B") C = PatternDot("C") # [TODO] Complete the set of patterns trsm_patterns = [ #RewriteRule( (Equal([ C, Times([ B, Transpose([Inverse([A])]) ]) ]), Constraint("A.st_info[0] == ST_LOWER")), \ RewriteRule( Equal([ C, Times([ B, Transpose([Inverse([A])]) ]) ]), \ Replacement( lambda d: Equal([ d["C"], mrdiv([ Transpose([tril(d["A"])]), d["B"] ]) ]) ) ), ] # # Produces Matlab code for: # - Loop-based code # def generate_matlab_code(operation, matlab_dir): # Recursive code out_path = os.path.join(matlab_dir, operation.name + ".m") # [FIX] At some this should be opname_rec.m with open(out_path, "w") as out: generate_matlab_recursive_code(operation, operation.pmes[-1], out) # pmes[-1] should be the 2x2 one # Loop based code
def generate_predicate_before(
        self, pme, trav, linv,
        linv_obj):  # [TODO] Cleanup, no need for linv AND linv_obj
    """Express a loop invariant in terms of the repartitioned operands.

    Parameters:
        pme: not used in this function body.
        trav: mapping from operand name to its traversal direction, passed
            to ``repartition``.
        linv: the loop invariant expressions, as an iterable of iterables.
        linv_obj: object providing ``linv_operands`` and
            ``linv_operands_basic_part``.

    Returns:
        ``(reparts, final)``: the repartitioning of each operand keyed by
        operand name, and the flattened, simplified list of equations (also
        printed).

    NOTE(review): block nesting was reconstructed from collapsed source
    lines - confirm against the original file.
    """
    new = [copy.deepcopy(expr) for expr in itertools.chain(*linv)]
    # Repartition
    reparts = dict()
    repart_rules = []
    for op in linv_obj.linv_operands:
        part = linv_obj.linv_operands_basic_part[op.get_name()]
        # [CHECK] _shape or not? Regular one needed for inheritance
        repart = repartition(op, part.shape, trav[op.get_name()])
        #
        #repart_shape = {(1,1):(1,1), (1,2):(1,3), (2,1):(3,1), (2,2):(3,3)}[part.shape]
        #repart = repartition_shape( op, repart_shape )
        #repart = repart_group( repart, repart_shape, trav[op.get_name()] )
        #
        # Pair each part with its repartitioned counterpart, cell by cell.
        for part_op, repart_op in zip(itertools.chain(*part),
                                      itertools.chain(*repart)):
            repart_rules.append(
                RewriteRule(part_op, Replacement(repart_op)))
        reparts[op.get_name()] = repart
    # Apply repartitionings
    new = [replace(expr, repart_rules) for expr in new]
    # Explicit functions to BlockedExpression
    # First flatten args, then replace
    for expr in new:
        lhs, rhs = expr.get_children()
        if isinstance(rhs, Predicate):
            for i, arg in enumerate(rhs.get_children()):
                #rhs.set_children( i, flatten_blocked_operation(arg) )
                rhs.set_children(i, flatten_blocked_operation_click(arg))
    new = [replace(expr, known_pmes) for expr in new]
    # Operators applied to BlockedExpression, into BlockedExpressions
    for expr in new:
        if isinstance(expr, BlockedExpression):
            continue
        _, rhs = expr.get_children()
        #print( rhs )
        # [TODO] Maybe "Sylv(...)"!!!
        #rhs = flatten_blocked_operation( rhs )
        rhs = flatten_blocked_operation_click(rhs)
        expr.set_children(1, rhs)
    # Flatten left-hand sides of the previous type of expressions
    for expr in new:
        if isinstance(expr, BlockedExpression):
            continue
        lhs, rhs = expr.get_children()
        new_lhs = []
        for out in lhs:
            if isinstance(out, Symbol):  # this is a temporary one
                # The temporary takes the size of the rhs and is partitioned
                # with the rhs shape.
                out.size = rhs.get_size()
                #part = partition_shape( out, rhs.shape )
                part = partition_shape(out, tuple(rhs.shape))
                new_lhs.append(part)
            else:
                new_lhs.append(out)
        lhs = BlockedExpression(
            map_thread(NList, new_lhs, 2), (0, 0), rhs.shape)
        expr.set_children(0, lhs)
    # Flatten the last type of expressions
    final = []
    for expr in new:
        if isinstance(expr, BlockedExpression):
            final.extend([
                simplify(to_canonical(eq)) for eq in itertools.chain(*expr)
            ])
        else:
            lhs, rhs = expr.get_children()
            # One equation per pair of corresponding lhs/rhs blocks.
            final.extend([
                simplify(to_canonical(eq))
                for eq in itertools.chain.from_iterable(
                    map_thread(Equal, [lhs, rhs], 2))
            ])
    final = filter_zero_zero(final)
    # remove expressions of the type " B_10^T = ..." (e.g., in symv)"
    _final = final
    final = []
    for expr in _final:
        lhs, rhs = expr.children
        lhs = lhs.children
        # [FIXME] == 1 only to make sure it does not affect other cases.
        #         Just want to let them break and study them in the future
        if len(lhs) == 1 and (not isOperand(lhs[0]) or lhs[0].isZero()):
            continue
        final.append(expr)
    #
    # expand in terms of input parts
    #
    #expand_rules = list(itertools.chain(*self.expr_to_rule_lhs_rhs( final )))
    #for expr in final:
        #expr.children[1] = simplify(to_canonical(replace_all(copy.deepcopy(expr.children[1]), expand_rules)))
    #
    # Print and return
    #
    for expr in final:
        print("* ", expr)
    return (reparts, final)
def learn_pattern(self):
    """Learn a rewrite rule for this operation, store it and persist it.

    The rule's pattern is ``outputs = Predicate(self.name, inputs, sizes)``,
    where every block of each operand's basic partitioning is replaced by a
    ``PatternDot`` carrying the block's name. The rule's replacement is
    obtained with ``equation2replacement`` from the partitioned output
    equations, whose right-hand sides are taken from
    ``self.solved_subequations`` wherever the left-hand sides match.

    Side effects: prints the learnt rule, appends it to ``self.known_pmes``
    and appends its pickle to the file ``OUTPUT/<self.name>_pmes``.
    """
    outops = [op for op in self.operands if op.isOutput()]

    # pattern
    predicate_inops = []
    predicate_outops = []
    for op in self.operands:
        basic_part = self.basic_partitionings[op.get_name()]
        # Turn every block of the partitioning into a pattern variable.
        to_pattern_vars = [
            RewriteRule(part_op, Replacement(PatternDot(part_op.get_name())))
            for part_op in itertools.chain(*basic_part)
        ]
        blocked = BlockedExpression(
            copy.deepcopy(basic_part), op.get_size(), basic_part.shape)
        target = predicate_inops if op.isInput() else predicate_outops
        target.append(replace(blocked, to_pattern_vars))
    pattern = Equal([
        NList(predicate_outops),
        Predicate(self.name, predicate_inops,
                  [op.get_size() for op in outops]),
    ])

    # replacement
    # [TODO] Tuple for get_size
    # [CHECK] self.partitionings or self.basic_partitionings?
    parts = self.partitionings
    lhss = map_thread(NList, [parts[op.get_name()] for op in outops], 2)

    # [TODO] This is a fix for lu (maybe coup sylv as well).
    #        Generalize and clean up
    # Strip a leading or trailing zero entry from an NList. The rules do not
    # depend on the cell, so they are built once instead of once per cell.
    drop_zero_rules = [
        RewriteRule(
            (NList([PatternPlus("PP"), PatternDot("PD")]),
             Constraint(lambda d: isZero(d["PD"]))),
            Replacement(lambda d: NList(d["PP"].get_children()))),
        RewriteRule(
            (NList([PatternDot("PD"), PatternPlus("PP")]),
             Constraint(lambda d: isZero(d["PD"]))),
            Replacement(lambda d: NList(d["PP"].get_children()))),
    ]
    for i, row in enumerate(lhss):
        for j, cell in enumerate(row):
            lhss[i][j] = replace(cell, drop_zero_rules)

    eqs = map_thread(Equal, [lhss, parts[outops[0].get_name()]], 2)
    output_shape = self.basic_partitionings[outops[0].get_name()].shape
    r, c = output_shape
    # Plug in the right-hand side of every already solved sub-equation.
    for eq in self.solved_subequations:
        lhs, rhs = eq.get_children()
        for row in range(r):
            for col in range(c):
                this_lhs, _ = eqs[row][col].get_children()
                if lhs == this_lhs:
                    eqs[row][col].set_children(1, rhs)

    # size does not matter
    replacement_str = equation2replacement(
        BlockedExpression(eqs, (0, 0), output_shape))

    print("* Learnt PME pattern")
    # Build the rule once; the same object is printed, stored and pickled.
    learnt_rule = RewriteRule(pattern, Replacement(replacement_str))
    print("* ", learnt_rule)
    self.known_pmes.append(learnt_rule)
    with open(os.path.join("OUTPUT", self.name + "_pmes"), "ab+") as pmes_f:
        pickle.dump(learnt_rule, pmes_f)
def fix_temporaries(self):
    """Register, size and partition the temporaries of ``self.expressions``.

    Every left-hand-side operand that is neither input nor output is treated
    as a temporary. For each one this method:

    * appends it to ``self.linv_operands``;
    * adds ``<name>_r`` / ``<name>_c`` to the groups of
      ``self.linv_bound_dimensions`` holding the dimensions it is bound to;
    * sets its size and records its partitioning in
      ``self.linv_operands_basic_part`` and ``self.linv_operands_part_shape``;
    * rewrites ``self.expressions`` (on deep copies) so that the temporary is
      replaced by the matching quadrant of that partitioning.
    """
    axis_of = {"R": 0, "C": 1}

    temp_exprs = []
    for expr in itertools.chain(*self.expressions):
        outputs, _ = expr.get_children()
        for out in outputs:
            # [FIXME] Temporaries are not labeled. Now they are, fix.
            if out.isInput() or out.isOutput():
                continue
            self.linv_operands.append(out)
            temp_exprs.append(expr)

    for expr in temp_exprs:
        outputs, rhs = expr.get_children()
        temp = outputs.get_children()[0]

        # Determine to which operands in the temporary bound
        # Also, which quadrant of the full temporary should be used
        (rows_op, r_dim), (cols_op, c_dim) = utils.size_as_func_of_operands(rhs)

        # set in bound_dimensions (rows first, then columns)
        for src_op, dim, suffix in ((rows_op, r_dim, "_r"),
                                    (cols_op, c_dim, "_c")):
            key = src_op.parent_op.get_name() + "_" + dim.lower()
            for group in self.linv_bound_dimensions:
                if key in group:
                    group.append(temp.get_name() + suffix)

        # partition temporary
        rows_parent = rows_op.parent_op.get_name()
        cols_parent = cols_op.parent_op.get_name()
        shape_rows = self.pme.part_shape[rows_parent][axis_of[r_dim]]
        shape_cols = self.pme.part_shape[cols_parent][axis_of[c_dim]]
        size_rows = rows_op.parent_op.get_size()[axis_of[r_dim]]
        size_cols = cols_op.parent_op.get_size()[axis_of[c_dim]]
        temp.size = (size_rows, size_cols)
        part = partition_shape(temp, (shape_rows, shape_cols))
        self.linv_operands_basic_part[temp.get_name()] = part
        self.linv_operands_part_shape[temp.get_name()] = (
            len(part.children), len(part.children[0]))
        quadrant = part[rows_op.quadrant[axis_of[r_dim]]][
            cols_op.quadrant[axis_of[c_dim]]]

        # replace the temporary with the proper quadrant in every expression of the linv
        # [TODO] why is loop invariant a list of lists?
        self.expressions = [
            [
                replace(copy.deepcopy(item),
                        [RewriteRule(temp, Replacement(quadrant))])
                for item in group
            ]
            for group in self.expressions
        ]
TOS.push_back_temp( ) TOS._TOS.unset_operand( tile.children[0].children[0] ) if matched_in_this_level: ### break PD1 = PatternDot("PD1") PD2 = PatternDot("PD2") PS1 = PatternStar("PS1") PS2 = PatternStar("PS2") grouping_rules = [ # A B + A C D -> A (B + C D) RewriteRule( Plus([ Times([ PD1, PS1 ]), Times([ PD1, PS2 ]) ]), Replacement(lambda d: Times([ d["PD1"], Plus([Times([d["PS1"]]), Times([d["PS2"]])]) ])) ), # A B - A C D -> A (B - C D) RewriteRule( Plus([ Times([ PD1, PS1 ]), Times([ Minus([PD1]), PD2, PS2 ]) ]), Replacement(lambda d: Times([ d["PD1"], Plus([ Times([d["PS1"]]), Times([Minus([d["PD2"]]), Times([d["PS2"]])])]) ])) ), # B A + C D A -> (B + C D) A RewriteRule( Plus([ Times([ PS1, PD1 ]), Times([ PS2, PD1 ]) ]), Replacement(lambda d: Times([ Plus([ Times([d["PS1"]]), Times([d["PS2"]]) ]), d["PD1"] ])) ), # B A - C D A -> (B - C D) A RewriteRule( Plus([ Times([ PS1, PD1 ]), Times([ Minus([PD2]), PS2, PD1 ]) ]), Replacement(lambda d: Times([ Plus([ Times([d["PS1"]]), Times([Minus([d["PD2"]]), d["PS2"]])]), d["PD1"] ]))
def _checkPredicateOverwrite(statement):
    """Expand ``statement`` into the statements that make overwriting explicit.

    Returns a list of statements containing ``statement`` itself, possibly
    preceded or followed by copy statements of the form ``Equal([NList(...),
    operand])``.

    * Right-hand side is not a ``Predicate``: if it contains exactly one
      temporary ``Symbol`` with the same size as the first output, that output
      is not a temporary, and no right-hand-side symbol shares ``st_info[1]``
      with it, a copy ``output = temporary`` is emitted and the temporary is
      replaced by the output in the statement.
    * Right-hand side is a ``Predicate``: the ``(inp, out)`` pairs of
      ``pm.DB[rhs.name].overwrite`` drive which operand is copied into which.

    NOTE: mutates its argument in place (``statement.children``,
    ``rhs.children``, ``lhs.children``) and may reassign ``st_info`` of an
    operand.
    """
    lhs = statement.lhs()
    rhs = statement.rhs()
    if not isinstance(rhs, Predicate):
        # All symbols appearing in the right-hand side, in preorder.
        rhs_ops = [
            node for node in rhs.iterate_preorder()
            if isinstance(node, Symbol)
        ]
        tmp_ops = [op for op in rhs_ops if op.isTemporary()]
        if len(tmp_ops
               ) == 1:  # [FIXME] Quick and dirty to play with temporaries
            tmp = tmp_ops[0]
            if lhs.children[0].size == tmp.size:
                if not lhs.children[0].isTemporary():
                    # Does any rhs operand share its storage operand
                    # (st_info[1]) with the output?
                    overwrites = False
                    for op in rhs_ops:
                        try:
                            overwrites = lhs.children[0].st_info[
                                1] == op.st_info[1]
                        except AttributeError:
                            # Operand without st_info: keep previous value.
                            pass
                        if overwrites:
                            break
                    if not overwrites:
                        # Copy the temporary into the output first, then
                        # compute on the output instead of the temporary.
                        statements = []
                        statements.append(Equal([NList(lhs.children), tmp]))
                        statement.children[1] = replace(
                            copy.deepcopy(rhs),
                            [RewriteRule(tmp, Replacement(lhs.children[0]))])
                        statements.append(statement)
                        return statements
            else:
                # TRSM 2x2 ...
                pass
        return [statement]
    if not pm.DB[rhs.name].overwrite:  # []
        # Nothing gets overwritten by this predicate.
        return [statement]
    statements = []
    # [FIXME] Assumes one single operands get overwritten. Will break in the future
    already_copied = []
    for inp, out in pm.DB[rhs.name].overwrite:
        # Handle every input position only once.
        if inp in already_copied:
            continue
        already_copied.append(inp)
        #
        if rhs.children[inp] != lhs.children[
                out]:  # [FIXME] All should have st_into
            try:
                overwrites = lhs.children[out].st_info[1] == rhs.children[
                    inp].st_info[1]
            except AttributeError:
                overwrites = False
            if overwrites:
                # Input and output already share st_info[1]: no copy needed.
                statements.append(statement)  #[FIXME] Gosh...
                continue
            inpop = rhs.children[inp]
            outop = lhs.children[out]
            if inpop.isTemporary() or inpop.isInput():
                # if multiple outputs overwrite input (e.g., LU)
                if len([o for i, o in pm.DB[rhs.name].overwrite if i == inp
                        ]) > 1:
                    try:
                        outop = TOS._TOS[outop.st_info[1].name][
                            0]  # LU (ABR = T3; [LBR,UBR] = LU(ABR))
                    except:
                        # Best effort: keep the original output operand.
                        pass
                    # NOTE(review): the nesting of this assignment could not
                    # be recovered from the collapsed source; it is placed
                    # inside the multiple-output branch -- confirm.
                    outop.st_info = (None, outop)
                #
                # Copy input into output, then run the predicate on the output.
                statements.append(Equal([NList([outop]), inpop]))
                rhs.children[inp] = outop
                statements.append(statement)
            else:
                # Compute in place on the input, then copy it to the output.
                lhs.children[out] = rhs.children[inp]
                statements.append(statement)
                statements.append(Equal([NList([inpop]), outop]))
        else:
            statements.append(statement)
    return statements
subexpr = PatternDot("subexpr") PD1 = PatternDot("PD1") PD2 = PatternDot("PD2") PS1 = PatternStar("PS1") PS2 = PatternStar("PS2") PS3 = PatternStar("PS3") PSLeft = PatternStar("PSLeft") PSRight = PatternStar("PSRight") # TODO: where do we place A*inv(A) -> I? canonical_rules = [ # a * ( b + c ) * e -> a*b*e + a*c*e RewriteRule( Times([PSLeft, Plus([PS1]), PSRight]), Replacement(lambda d: Plus( [Times([d["PSLeft"], term, d["PSRight"]]) for term in d["PS1"]]))), # -( a + b) -> (-a)+(-b) RewriteRule( Minus([Plus([PS1])]), Replacement(lambda d: Plus([Minus([term]) for term in d["PS1"]]))), # -( a * b) -> (-a) * b RewriteRule(Minus([Times([PD1, PS1])]), Replacement(lambda d: Times([Minus([d["PD1"]]), d["PS1"]]))), # a * b * (-c) -> (-a) * b * c RewriteRule( Times([PD1, PS1, Minus([PD2]), PS2]), Replacement(lambda d: Times( [Minus([d["PD1"]]), d["PS1"], d["PD2"], d["PS2"]]))), # Transpose( A + B + C ) -> A^T + B^T + C^T RewriteRule( Transpose([Plus([PS1])]),