def to_canonical_IO(self):
    """Rewrite ``self.equation`` into its canonical input/output form.

    Each equation is first rewritten with the canonical, canonical-IO and
    simplification rule sets.  Then three equivalent forms of every
    equation are built and the one with the fewest nodes is kept:

    1) the equation as is,
    2) the equation with minus applied to both sides, and
    3) the equation with transpose applied to both sides.

    Ties in node count are resolved in favour of the earliest form in the
    list above.  Finally ``self.operands`` is restricted to the operands
    that still occur in the resulting equations.
    """
    rules = canonical_rules + canonicalIO_rules + simplify_rules
    self.equation = [replace_all(eq, rules) for eq in self.equation]

    minimal = []
    for eq in self.equation:
        alternative_forms = [
            eq,
            simplify(
                to_canonical(
                    Equal([Minus([eq.lhs()]), Minus([eq.rhs()])]))),
            simplify(
                to_canonical(
                    Equal([Transpose([eq.lhs()]),
                           Transpose([eq.rhs()])]))),
        ]
        # Rank on node count only.  Taking min() over (count, expression)
        # tuples would fall back to comparing the expressions themselves
        # whenever two forms have the same size.
        minimal.append(min(alternative_forms, key=lambda alt: alt.num_nodes()))

    # Set minimal forms
    self.equation = NList(minimal)

    # Keep only the operands that still appear somewhere in the equations.
    nodes = list(
        itertools.chain.from_iterable(
            eq.iterate_preorder() for eq in self.equation))
    self.operands = [op for op in self.operands if op in nodes]
def tile_expr( _expr ):
    """Return every tiling of the equation ``_expr``.

    ``_expr`` is expected to have exactly two children (lhs, rhs).  The
    result is a list of tilings, each tiling being a list of statements.

    * Predicate rhs: each non-operand argument is tiled independently via
      ``all_tilings`` and the per-argument tilings are combined with a
      cartesian product.  The last entry of an argument tiling replaces the
      argument in a copy of the predicate; the remaining entries are
      emitted before the final ``lhs = predicate`` statement.
    * Operand rhs: the only tiling is the original equation itself.
    * Anything else: each tiling of the rhs has its last entry dropped and
      its one-to-last statement redirected to the actual lhs.
    """
    tiles = []
    expr = copy.deepcopy( _expr )
    lhs, rhs = expr.get_children()

    if isinstance( rhs, Predicate ):
        # Operands need no tiling: their single "tiling" is the operand.
        tiles_per_arg = [
            [[arg]] if isOperand( arg ) else list( all_tilings( arg ) )
            for arg in rhs.get_children()
        ]
        for comb in itertools.product( *tiles_per_arg ):
            new_pred = copy.deepcopy( rhs )
            for idx, arg_tiling in enumerate( comb ):
                new_pred.set_children( idx, arg_tiling[-1] )
            updates = [ stmt for arg_tiling in comb for stmt in arg_tiling[:-1] ]
            tiles.append( updates + [Equal([lhs, new_pred])] )
        return tiles

    if isOperand( rhs ):
        tiles.append( [_expr] )
        return tiles

    for tiling in all_tilings( rhs ):
        # The last entry is just a (temporary) symbol: the output of the
        # one-to-last statement.  Discard it.
        tiling.pop()
        one_to_last = tiling.pop()
        target, _ = expr.get_children()
        # Assign the one-to-last statement to the actual lhs of the equation.
        one_to_last.set_children( 0, target )
        tiling.append( one_to_last )
        tiles.append( tiling )
    return tiles
def algorithm_initialization(self, init_state):
    """Build the initialization statements for an algorithm.

    Every expression in ``init_state`` is split into its (lhs, rhs)
    children.  For each child of the lhs that is not zero, one statement
    ``Equal([NList([child]), rhs])`` is produced.

    Returns:
        list: the generated statements, in traversal order.
    """
    init = []
    for expr in init_state:
        lhs, rhs = expr.get_children()
        for lch in lhs.get_children():
            # Only the lhs entry is checked; an earlier variant also
            # skipped statements whose rhs is zero.
            if isZero(lch):
                continue
            init.append(Equal([NList([lch]), rhs]))
    return init
def tile( self, tree, node, match_dict ):
    """Replace ``node`` inside ``tree`` by a fresh temporary symbol.

    A new symbol ``T<n>`` (``n`` obtained from ``TOS.new_temp()``) is
    created and flagged as temporary.

    Returns:
        tuple: ``(assignment, new_tree)`` where ``assignment`` is the
        statement defining the temporary and ``new_tree`` is ``tree`` after
        the rewrite that introduces the temporary.
    """
    temp = Symbol( "T" + str( TOS.new_temp() ) )
    temp.set_property( properties.TEMPORARY )

    if not isinstance( node, Predicate ):
        # Properties and storage info are propagated from one instance of
        # the tile; a second, fresh instance goes into the assignment.
        tile_body = self.create_tile( match_dict )
        propagate_properties( tile_body, temp )
        propagate_st_info( tile_body, temp )
        ## [FIXME] Quick and dirty test: an upper-triangular check on the
        ## tile that would tag the temporary as UPPER_TRIANGULAR is disabled.
        assignment = Equal([ NList([ temp ]), self.create_tile( match_dict ) ])
        new_tree = replace( tree, [self.create_rewrite_rule( match_dict, temp )] )
        return assignment, new_tree

    if node == tree:
        print( "[Warning] If predicate has multiple outputs, may break if not careful from caller" )
    assignment = Equal([ NList([ temp ]), copy.deepcopy( node ) ])
    new_tree = replace( tree, [RewriteRule( copy.deepcopy( node ), Replacement(temp) )] )
    return assignment, new_tree
def _checkPredicateOverwrite(statement):
    """Expand ``statement`` so that operand overwriting becomes explicit.

    Two situations are handled:

    * The rhs is not a ``Predicate``.  If it contains exactly one temporary
      symbol with the same size as the (non-temporary) first lhs operand,
      and no rhs symbol shares storage with that lhs operand, the temporary
      is first copied into the lhs operand and then replaced by it in the
      rhs.
    * The rhs is a ``Predicate``.  ``pm.DB[rhs.name].overwrite`` is read as
      a list of ``(input index, output index)`` pairs; for each pair whose
      operands differ and do not already share storage, a copy statement
      is emitted before or after ``statement`` and the corresponding child
      of ``statement`` is redirected.

    ``statement`` (and possibly operand storage info) is modified in place.

    Returns:
        list: the statements to emit in place of ``statement``.
    """
    lhs = statement.lhs()
    rhs = statement.rhs()
    if not isinstance(rhs, Predicate):
        rhs_ops = [
            node for node in rhs.iterate_preorder()
            if isinstance(node, Symbol)
        ]
        tmp_ops = [op for op in rhs_ops if op.isTemporary()]
        if len(tmp_ops) == 1:  # [FIXME] Quick and dirty to play with temporaries
            tmp = tmp_ops[0]
            if lhs.children[0].size == tmp.size:
                if not lhs.children[0].isTemporary():
                    # Does any rhs symbol live in the same storage as the
                    # lhs operand?  Symbols without st_info are ignored.
                    overwrites = False
                    for op in rhs_ops:
                        try:
                            overwrites = (
                                lhs.children[0].st_info[1] == op.st_info[1])
                        except AttributeError:
                            pass
                        if overwrites:
                            break
                    if not overwrites:
                        statements = []
                        statements.append(Equal([NList(lhs.children), tmp]))
                        statement.children[1] = replace(
                            copy.deepcopy(rhs),
                            [RewriteRule(tmp, Replacement(lhs.children[0]))])
                        statements.append(statement)
                        return statements
        # Every other case (e.g. TRSM 2x2 ...) leaves the statement as is.
        return [statement]

    if not pm.DB[rhs.name].overwrite:  # []
        return [statement]

    statements = []
    # [FIXME] Assumes one single operands get overwritten. Will break in the future
    already_copied = []
    for inp, out in pm.DB[rhs.name].overwrite:
        if inp in already_copied:
            continue
        already_copied.append(inp)
        #
        if rhs.children[inp] != lhs.children[out]:
            # [FIXME] All should have st_into
            try:
                overwrites = (
                    lhs.children[out].st_info[1] == rhs.children[inp].st_info[1])
            except AttributeError:
                overwrites = False
            if overwrites:
                statements.append(statement)  # [FIXME] Gosh...
                continue

            inpop = rhs.children[inp]
            outop = lhs.children[out]
            if inpop.isTemporary() or inpop.isInput():
                # if multiple outputs overwrite input (e.g., LU)
                if len([o for i, o in pm.DB[rhs.name].overwrite
                        if i == inp]) > 1:
                    try:
                        # LU (ABR = T3; [LBR,UBR] = LU(ABR))
                        outop = TOS._TOS[outop.st_info[1].name][0]
                    except Exception:
                        # Best effort: keep the original output operand
                        # when the lookup is not possible.  A bare except
                        # would also hide KeyboardInterrupt/SystemExit.
                        pass
                    outop.st_info = (None, outop)
                #
                # Copy the input into the output, then compute in place.
                statements.append(Equal([NList([outop]), inpop]))
                rhs.children[inp] = outop
                statements.append(statement)
            else:
                # Compute in place on the input, then copy the result out.
                lhs.children[out] = rhs.children[inp]
                statements.append(statement)
                statements.append(Equal([NList([inpop]), outop]))
        else:
            statements.append(statement)
    return statements
def learn_pattern(self):
    """Register rewrite rules that recognise this operation.

    A pattern is built from ``self.equation`` by replacing every operand
    with a pattern variable, together with a constraint derived from the
    operands' properties, and a replacement that evaluates to
    ``outputs = Predicate(self.name, inputs, output sizes)``.

    * Single assignment (one equation whose lhs is a ``Symbol``): the rule,
      plus a variant matching the negated rhs, is inserted at the front of
      ``known_ops_single``.
    * Otherwise the rule is inserted at the front of ``known_ops`` and
      pickled to ``OUTPUT/<name>_patterns``.

    In both cases a rule mapping the predicate back to the operation's
    equations is appended to ``op_to_implicit``.
    """
    inops = [op for op in self.operands if op.isInput()]
    outops = [op for op in self.operands if op.isOutput()]

    def build_replacement(template):
        # Fill in: output names, operation name, input names, output sizes.
        # [TODO] Tuple for get_size
        return Replacement(template % (
            ", ".join([op.name for op in outops]),
            self.name,
            ", ".join([op.name for op in inops]),
            ", ".join(["%s.get_size()" % op.get_name() for op in outops]),
        ))

    equations = self.equation.children
    # eq.lhs.single_entry_in_NL
    single_assignment = len(equations) == 1 and \
        isinstance(equations[0].children[0], Symbol)
    #
    op_to_pattern = [
        RewriteRule(op, Replacement("PatternDot(%s.name)" % op.name))
        for op in self.operands
    ]
    pattern = NList([
        replace_all(copy.deepcopy(eq), op_to_pattern)
        for eq in self.equation
    ])

    if single_assignment:
        # Operands that yield no constraint text are left out.
        constraint_text = " and ".join(filter(None, [
            symbol_props_to_constraints_no_io(op) for op in self.operands
        ]))
    else:
        constraint_text = " and ".join(
            [symbol_props_to_constraints(op) for op in self.operands])
    constraint = Constraint(constraint_text)

    replacement = build_replacement(
        "Equal([ NList([%s]), Predicate( \"%s\", [%s], [%s] ) ])")

    # [TODO] This should be part of the verbose option
    print("* Learnt pattern")
    print("* ", pattern, end="")
    if constraint.to_eval:
        print("with ", constraint.to_eval)
    print(" --> ")
    print("* ", replacement.to_eval)
    print("**********************************")

    # [TODO] Maybe sort known ops by specificity (a la mathematica)
    if not single_assignment:
        known_ops.insert(0, RewriteRule((pattern, constraint), replacement))
        patterns_path = os.path.join("OUTPUT", self.name + "_patterns")
        with open(patterns_path, "wb") as patt_f:
            pickle.dump(known_ops[0], patt_f)
    else:
        expr = pattern.children[0]
        expr.children[0] = NList([expr.children[0]])
        known_ops_single.insert(
            0, RewriteRule((expr, constraint), replacement))
        # With minus
        replacement = build_replacement(
            "Equal([ NList([%s]), Minus([ Predicate( \"%s\", [%s], [%s] ) ]) ])")
        expr = copy.deepcopy(expr)
        expr.children[1] = Minus([expr.children[1]])
        expr.children[1] = normalize_minus(copy.deepcopy(expr.children[1]))
        known_ops_single.insert(
            0, RewriteRule((expr, constraint), replacement))
        # (Pickling of the two single-assignment rules is disabled.)

    # Rule taking the predicate back to the operation's implicit equations.
    pattern = Equal([
        NList([PatternDot(op.get_name()) for op in outops]),
        Predicate(self.name,
                  [PatternDot(op.get_name()) for op in inops],
                  [op.get_size() for op in outops])
    ])
    replacement = Replacement(equation2replacement(self.equation))
    op_to_implicit.append(RewriteRule(pattern, replacement))
from BindDimensions import bindDimensions
from graph import build_dep_graph, subgraphs, zero_out_lower
from Partitioning import partition_shape, repartition, repartition_shape, repart_group
from Tiling import tile_expr
from utils import sort
from RewritingExtension import *

from BackEnd.MatlabCode import generate_matlab_code, click2matlab
#from BackEnd.LatexCode import generate_latex_code
from BackEnd.LGenCode import generate_lgen_files
#from BackEnd.CCode import generate_c_code

# Module-level registry of rewrite rules for known operations.  It is seeded
# with a single rule; further rules are inserted at the front elsewhere in
# this module.
known_ops = [
    # Straight assignment: an output symbol on the lhs and an input on the
    # rhs is rewritten to "[LHS] = RHS".
    RewriteRule(
        (NList([Equal([PatternDot("LHS"), PatternDot("RHS")])]),
         Constraint(
             "isinstance(LHS, Symbol) and isOutput(LHS) and isInput(RHS)")),
        Replacement("Equal([ NList([LHS]), RHS ])"))
]
# Rules for operations made of one single assignment (filled at run time).
known_ops_single = []
# NOTE(review): presumably the PMEs learnt so far -- not populated in the
# visible code; confirm against the rest of the module.
known_pmes = []
# Rules mapping an operation's predicate back to its defining equation(s)
# (filled at run time).
op_to_implicit = []


# Add a verbose option
# For now it would help to illustrate the process
class Operation(object):
    def __init__(self, name, operands, equation, overwrite):
        self.name = name
        self.operands = operands
canonicalIO_rules = [ ## Equal[Minus, Minus] -> remove minuses ## add support for head_[_] #RewriteRule( #( #Equal([ Minus([LHS]), Minus([RHS]) ]), #Constraint("True") #), #Replacement("Equal([ LHS, RHS ])") #), # Input to the right # Plus RewriteRule( ( Equal([ Plus([PS1, subexpr, PS2]), RHS ]), Constraint(lambda d: isInput(d["subexpr"])) ), Replacement(lambda d: Equal([ Plus([ d["PS1"], d["PS2"] ]), \ Plus([ d["RHS"], Minus([d["subexpr"]])]) ])) ), # Minus RewriteRule( ( Equal([ Minus([subexpr]), RHS ]), Constraint(lambda d: isInput(d["subexpr"])) ), Replacement(lambda d: Equal([ Zero(d["subexpr"].get_size()), \ Plus([d["RHS"], d["subexpr"]]) ])) ), # Transpose
def equation(self, ast):
    """Build an ``Equal`` node from a parsed equation.

    Args:
        ast: mapping with the entries ``'lhs'`` and ``'rhs'``; both values
            must provide ``get_size()``.

    Returns:
        ``Equal([lhs, rhs])``.

    Raises:
        SizeError: if the two sides report different sizes.  The message
            includes both sizes.
    """
    lhs = ast['lhs']
    rhs = ast['rhs']
    lhs_size = lhs.get_size()
    rhs_size = rhs.get_size()
    if lhs_size != rhs_size:
        # The message needs one placeholder per size; without them the
        # %-formatting itself fails with TypeError and the SizeError is
        # never raised.
        raise SizeError(
            "Equation's lhs and rhs have different sizes: %s vs %s"
            % (lhs_size, rhs_size))
    return Equal([lhs, rhs])
class triu(Operator):
    """Unary operator node named ``triu``.

    NOTE(review): presumably mirrors Matlab's ``triu`` (upper-triangular
    part); the visible code only establishes that it wraps one argument and
    takes that argument's size.
    """

    def __init__(self, arg):
        # Single child, no further attributes, flagged as a unary operator.
        Operator.__init__(self, [arg], [], UNARY)
        # The node has the same size as its argument.
        self.size = arg.get_size()

# Patterns for inv to trsm
A = PatternDot("A")
B = PatternDot("B")
C = PatternDot("C")
# [TODO] Complete the set of patterns
trsm_patterns = [
    # C = B * inv(A)^T  -->  C = mrdiv(tril(A)^T, B)
    # The variant constrained to lower-triangular storage is disabled:
    #RewriteRule( (Equal([ C, Times([ B, Transpose([Inverse([A])]) ]) ]), Constraint("A.st_info[0] == ST_LOWER")), \
    RewriteRule( Equal([ C, Times([ B, Transpose([Inverse([A])]) ]) ]), \
                 Replacement( lambda d: Equal([ d["C"], mrdiv([ Transpose([tril(d["A"])]), d["B"] ]) ]) ) ),
]

#
# Produces Matlab code for:
# - Loop-based code
#
def generate_matlab_code(operation, matlab_dir):
    """Write the recursive Matlab code of ``operation`` into ``matlab_dir``.

    The output file is ``<matlab_dir>/<operation.name>.m``; its contents are
    produced by ``generate_matlab_recursive_code`` from the last entry of
    ``operation.pmes``.
    """
    # Recursive code
    out_path = os.path.join(matlab_dir, operation.name + ".m") # [FIX] At some this should be opname_rec.m
    with open(out_path, "w") as out:
        generate_matlab_recursive_code(operation, operation.pmes[-1], out) # pmes[-1] should be the 2x2 one
pm.DB["rdiv_utn"].overwrite = [] pm.DB["rdiv_utn_ow"] = pm.PredicateMetadata("rdiv_utn_ow", tuple()) pm.DB["rdiv_utn_ow"].overwrite = [(1, 0)] pm.DB["rdiv_utu"] = pm.PredicateMetadata("rdiv_utu", tuple()) pm.DB["rdiv_utu"].overwrite = [] pm.DB["rdiv_utu_ow"] = pm.PredicateMetadata("rdiv_utu_ow", tuple()) pm.DB["rdiv_utu_ow"].overwrite = [(1, 0)] A = PatternDot("A") B = PatternDot("B") X = PatternDot("X") trsm2lgen_rules = [ # X = i(t(A)) B -> ldiv_lni RewriteRule(( Equal([NList([X]), Times([Inverse([A]), B])]), Constraint( "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name == X.name" )), Replacement(lambda d: Equal([ NList([d["X"]]), Predicate("ldiv_lni", [d["A"], d["B"]], [d["A"].get_size(), d["B"].get_size()]) ]))), # X = i(t(A)) B -> ldiv_lni_ow RewriteRule( (Equal([NList([X]), Times([Inverse([A]), B])]), Constraint( "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name != X.name" )), Replacement(lambda d: Equal([