def convertToASTs(self, tile_level, loop_info_table, node=None):
    '''
    Generate the ASTs that correspond to this sequence of "simple" loops.

    node: the abstraction to convert; defaults to self.loops.
      - a tuple (iname, subnodes) becomes one loop over iname whose body is
        the conversion of subnodes;
      - a list is converted element by element (a list is returned);
      - anything else (an already-built statement) is returned unchanged.
    loop_info_table: maps an iterator name to the 3-tuple
        (tile-size variable name, tile-iterator variable name, stride exp).
    tile_level: forwarded unchanged through the recursion; not otherwise used.

    A tuple node yields a loop built by self.ast_util.createForLoop with
    lower bound titer, upper bound titer + (tsize - (st)), and stride st.
    '''
    if node is None:
        node = self.loops

    if isinstance(node, tuple):
        iname, subnodes = node
        tsize_name, titer_name, st_exp = loop_info_table[iname]
        index_id = ast.IdentExp(iname)
        lb = ast.IdentExp(titer_name)
        # span of one tile relative to its start: tsize - (st)
        tmp = ast.BinOpExp(ast.IdentExp(tsize_name), ast.ParenthExp(st_exp),
                           ast.BinOpExp.SUB)
        ub = ast.BinOpExp(lb, ast.ParenthExp(tmp), ast.BinOpExp.ADD)
        lbody = ast.CompStmt(
            self.convertToASTs(tile_level, loop_info_table, subnodes))
        return self.ast_util.createForLoop(index_id, lb, ub, st_exp, lbody)

    if isinstance(node, list):
        return [self.convertToASTs(tile_level, loop_info_table, n)
                for n in node]

    return node
def __getMultiLevelTileLoop(self, tile_level, iter_names, st_exps, lbody):
    '''
    Wrap lbody in a multilevel-tile loop nest.

    Example: with iterator names (i,j) and two tiling levels the result is
      for (it1=it2; it1<=it2+(T2i-T1i); it1+=T1i)
        for (jt1=jt2; jt1<=jt2+(T2j-T1j); jt1+=T1j)
          for (i=it1; i<=it1+(T1i-sti); i+=sti)
            for (j=jt1; j<=jt1+(T1j-stj); j+=stj)
              lbody

    The nest is built inside-out, so the iterator lists are walked in
    reverse order. st_exps is only used for the innermost (level-1) loops.
    '''
    names = iter_names[::-1]
    strides = st_exps[::-1]
    nest = lbody

    # level 1: loops over the original iterators, stepping by their strides
    if tile_level >= 1:
        for name, stride in zip(names, strides):
            tsize = self.__getTileSizeName(name, 1)
            start = ast.IdentExp(self.__getTileIterName(name, 1))
            nest = self.__getIntraTileLoop(name, tsize, start, stride, nest)

    # levels >= 2: the level-(k-1) tile iterator walks one level-k tile,
    # stepping by the level-(k-1) tile size
    for level in range(2, tile_level + 1):
        for name in names:
            inner_iter = self.__getTileIterName(name, level - 1)
            tsize = self.__getTileSizeName(name, level)
            start = ast.IdentExp(self.__getTileIterName(name, level))
            step = ast.IdentExp(self.__getTileSizeName(name, level - 1))
            nest = self.__getIntraTileLoop(inner_iter, tsize, start, step, nest)

    return nest
def __getIntraTileLoop(self, iter_name, tsize_name, lb_exp, st_exp, lbody):
    '''
    Build a loop that walks the points of a single tile:
        for (i=lb; i<=lb+(Ti-st); i+=st)
            lbody
    where i is iter_name and Ti is the variable named tsize_name.
    '''
    index = ast.IdentExp(iter_name)
    tile_span = ast.BinOpExp(ast.IdentExp(tsize_name),
                             ast.ParenthExp(st_exp),
                             ast.BinOpExp.SUB)
    upper = ast.BinOpExp(lb_exp, ast.ParenthExp(tile_span), ast.BinOpExp.ADD)
    return self.ast_util.createForLoop(index, lb_exp, upper, st_exp, lbody)
def __getInterTileLoop(self, iter_name, tsize_name, lb_exp, ub_exp, st_exp, lbody):
    """
    Build a loop that steps from tile to tile:
        for (i=lb; i<=ub-(Ti-st); i+=Ti)
            lbody
    Pulling the bound in by (Ti-st) means that a tile starting at i, whose
    last point is i+(Ti-st), never extends past ub.
    """
    index = ast.IdentExp(iter_name)
    tile_span = ast.BinOpExp(ast.IdentExp(tsize_name),
                             ast.ParenthExp(st_exp),
                             ast.BinOpExp.SUB)
    upper = ast.BinOpExp(ub_exp, ast.ParenthExp(tile_span), ast.BinOpExp.SUB)
    step = ast.IdentExp(tsize_name)
    return self.ast_util.createForLoop(index, lb_exp, upper, step, lbody)
def p_keyword_argument(p):
    'keyword_argument : ID EQUALS assignment_expression'
    # (the docstring above is the PLY grammar rule and must not change)
    # shift PLY's line number by the module-level __start_line_no offset
    line = p.lineno(1) + __start_line_no - 1
    keyword = ast.IdentExp(p[1], line)
    p[0] = ast.BinOpExp(keyword, p[3], ast.BinOpExp.EQ_ASGN, line)
def p_primary_expression_1(p):
    "primary_expression : ID"
    # (the docstring above is the PLY grammar rule and must not change)
    ident_name = p[1]
    p[0] = ast.IdentExp(ident_name)
def __tile(self, stmt, tile_level, outer_loop_infos, preceding_stmts, lbound_info, int_vars):
    '''
    Apply tiling on the given statement.

    stmt: an expression statement, if statement, or for-loop.
    tile_level: positive number of tiling levels.
    outer_loop_infos: list of (iname, loop-info) pairs for the enclosing
        loops, where loop-info is the 5-tuple returned by
        self.ast_util.getForLoopInfo.
    preceding_stmts: list of (is_tiled, stmts) groups produced so far; it is
        copied, never modified.
    lbound_info: None, or (lb_name, ub_name, need_prolog, need_epilog,
        need_tiled_loop) describing the precomputed full-tile bounds of stmt.
    int_vars: list extended in place with newly introduced integer variables.

    Returns a new list of (is_tiled, stmts) groups: preceding_stmts followed
    by the code generated for stmt. Consecutive untiled statements are merged
    into one (False, [...]) group.
    '''

    def append_untiled(stmt_seq, new_stmt):
        # Append new_stmt to stmt_seq (in place) as untiled code, merging it
        # into the last group when that group is untiled as well.
        if stmt_seq:
            is_tiled, last_stmts = stmt_seq[-1]
            if not is_tiled:
                stmt_seq[-1] = (False, last_stmts + [new_stmt])
                return
        stmt_seq.append((False, [new_stmt]))

    # complain if the tiling level is not a positive integer
    if not isinstance(tile_level, int) or tile_level <= 0:
        err('orio.module.ortil.transformation internal error: invalid tiling level: %s' % tile_level)

    # cannot handle a directly nested compound statement
    if isinstance(stmt, ast.CompStmt):
        err('orio.module.ortil.transformation internal error: unexpected compound statement directly nested inside ' +
            'another compound statement')

    # expression and if statements are not tiled; they join the untiled code
    if isinstance(stmt, (ast.ExpStmt, ast.IfStmt)):
        preceding_stmts = preceding_stmts[:]
        append_untiled(preceding_stmts, stmt)
        return preceding_stmts

    # to tile the given for-loop statement
    if isinstance(stmt, ast.ForStmt):
        # check whether this loop is already fully tiled
        this_fully_tiled = stmt.fully_tiled

        # extract loop structure information
        this_linfo = self.ast_util.getForLoopInfo(stmt)
        this_id, this_lb_exp, this_ub_exp, this_st_exp, this_lbody = this_linfo
        this_iname = this_id.name

        # information about the enclosing loops extended with this loop
        n_outer_loop_infos = outer_loop_infos + [(this_iname, this_linfo)]
        n_outer_loop_inames = [i for i, _ in n_outer_loop_infos]
        n_loop_info_table = dict(n_outer_loop_infos)

        # loop bounds used for iterating over the full rectangular tiles
        need_prolog = False
        rect_lb_exp = this_lb_exp
        rect_ub_exp = this_ub_exp
        if lbound_info:
            # need_epilog is not used: the epilog is fused into the cleanup loop below
            lb_name, ub_name, need_prolog, _, need_tiled_loop = lbound_info
            rect_lb_exp = ast.IdentExp(lb_name)
            rect_ub_exp = ast.IdentExp(ub_name)
            if not need_tiled_loop:
                err('orio.module.ortil.transformation internal error: unexpected case where generation of the main ' +
                    'rectangular tiled loop is needed')

        # get explicit loop-bound scanning code
        t = self.__getLoopBoundScanningStmts(this_lbody.stmts, tile_level, n_outer_loop_inames, n_loop_info_table)
        scan_stmts, lbound_info_seq, ivars = t

        # record the newly declared integer variables
        int_vars.extend(ivars)

        res_stmts = preceding_stmts[:]

        # prolog: iterations from the original lower bound up to the start of the full tiles
        if need_prolog:
            ub = ast.BinOpExp(rect_lb_exp, ast.ParenthExp(this_st_exp), ast.BinOpExp.SUB)
            prolog_code = self.ast_util.createForLoop(
                this_id, this_lb_exp, ub, this_st_exp, this_lbody)
            append_untiled(res_stmts, prolog_code)

        # Generate the body of the main rectangularly tiled loop. The body may
        # contain if-statement branches, each of which is transformed
        # recursively. Processed statements are encoded as a list, e.g.
        #   s1; if (t-exp) {s2; s3;} else s4;
        # is represented as
        #   [s1, t-exp, [s2, s3], [s4]]
        contain_loop = False
        processed_stmts = []
        if_branches = [processed_stmts]  # the statement lists of all open if-branches
        for s, binfo in zip(this_lbody.stmts, lbound_info_seq):
            if isinstance(s, ast.ForStmt):
                contain_loop = True

            n_if_branches = []
            for p_stmts in if_branches:
                # replicate the statement (just to be safe)
                s = s.replicate()

                # not a loop whose bounds depend on outer loop iterators
                if binfo is None:
                    n_p_stmts = self.__tile(s, tile_level, n_outer_loop_infos, p_stmts, binfo, int_vars)
                    p_stmts[:] = n_p_stmts
                    n_if_branches.append(p_stmts)
                    continue

                # (optimization) a one-time loop needs no lb<ub test branch
                _, _, _, _, need_tiled_loop = binfo
                if not need_tiled_loop:
                    append_untiled(p_stmts, s)
                    n_if_branches.append(p_stmts)
                    continue

                # (optimization) feed only the last processed group into the
                # recursion and leave the earlier ones untouched, to limit code size
                last_p_stmts = [p_stmts.pop()] if p_stmts else []

                # recursively tile this rectangularly tiled loop (true branch)
                n_p_stmts = self.__tile(s.replicate(), tile_level, n_outer_loop_infos, last_p_stmts, binfo,
                                        int_vars)

                # statements for the true (tiled) and false (untiled) conditions
                true_p_stmts = n_p_stmts
                false_p_stmts = last_p_stmts
                append_untiled(false_p_stmts, s)

                # encode the branching: test expression, true list, false list
                lbn, ubn, _, _, _ = binfo
                test_exp = ast.BinOpExp(ast.IdentExp(lbn), ast.IdentExp(ubn), ast.BinOpExp.LT)
                p_stmts.append(test_exp)
                p_stmts.append(true_p_stmts)
                p_stmts.append(false_p_stmts)
                n_if_branches.append(true_p_stmts)
                n_if_branches.append(false_p_stmts)

            if_branches = n_if_branches

        # loop-bound scanning statements first, then the converted body
        lbody_stmts = list(scan_stmts)
        lbody_stmts.extend(
            self.__convertToASTs(processed_stmts, tile_level, contain_loop, n_outer_loop_inames,
                                 n_loop_info_table, int_vars))

        # generate the main rectangularly tiled loop
        lbody = ast.CompStmt(lbody_stmts)
        iname = self.__getTileIterName(this_iname, tile_level)
        tname = self.__getTileSizeName(this_iname, tile_level)
        tiled_code = self.__getInterTileLoop(iname, tname, rect_lb_exp, rect_ub_exp, this_st_exp, lbody)
        res_stmts.append((True, [tiled_code]))

        # mark the loop if it iterates over the full rectangular tiles
        self.__labelFullCoreTiledLoop(tiled_code, n_outer_loop_inames)

        # cleanup: continue from the tile iterator's final value up to the
        # original upper bound (the epilog is fused into this loop)
        if not this_fully_tiled:
            lb = ast.IdentExp(self.__getTileIterName(this_iname, tile_level))
            cleanup_code = self.ast_util.createForLoop(this_id, lb, this_ub_exp, this_st_exp, this_lbody)
            res_stmts.append((False, [cleanup_code]))

        return res_stmts

    # unknown statement
    err('orio.module.ortil.transformation internal error: unknown type of statement: %s' % stmt.__class__.__name__)
def __convertToASTs(self, dstmts, tile_level, contain_loop, loop_inames, loop_info_table, int_vars):
    '''
    Recursively turn a list of processed statements into a list of ASTs.

    dstmts mixes two kinds of entries:
      - (is_tiled, stmts) groups;
      - an if-branch encoded as three consecutive entries: a BinOpExp test,
        the list for the true branch, and the list for the false branch.
    For example [s1, t-exp, [s2, s3], [s4]] stands for
        s1; if (t-exp) {s2; s3;} else s4;

    Tiled groups are emitted as-is. When contain_loop is true, untiled groups
    are wrapped in a single-level intra-tile loop nest over loop_inames
    (optionally re-tiled for boundary tiles). Otherwise they are wrapped in a
    multilevel tile loop nest.
    '''
    result = []
    idx = 0
    while idx < len(dstmts):
        entry = dstmts[idx]

        if not isinstance(entry, tuple):
            # if-branch: test expression followed by the true and false lists
            if not isinstance(entry, ast.BinOpExp):
                err('orio.module.ortil.transformation internal error: a test expression is expected' )
            true_asts = self.__convertToASTs(dstmts[idx + 1], tile_level, contain_loop,
                                             loop_inames, loop_info_table, int_vars)
            false_asts = self.__convertToASTs(dstmts[idx + 2], tile_level, contain_loop,
                                              loop_inames, loop_info_table, int_vars)
            cond = entry.replicate()
            result.append(ast.IfStmt(cond, ast.CompStmt(true_asts), ast.CompStmt(false_asts)))
            idx += 3
            continue

        is_tiled, group = entry
        group = [g.replicate() for g in group]

        if is_tiled:
            # already tiled: no enclosing loop nest is needed
            result.extend(group)
        elif contain_loop:
            # enclose with a single-level intra-tile loop nest, built inside-out
            nest = ast.CompStmt(group)
            for iname in reversed(loop_inames):
                tname = self.__getTileSizeName(iname, tile_level)
                start = ast.IdentExp(self.__getTileIterName(iname, tile_level))
                _, _, _, stride, _ = loop_info_table[iname]
                nest = self.__getIntraTileLoop(iname, tname, start, stride, nest)
                nest.fully_tiled = True
            # optionally tile the boundary tiles recursively
            if self.use_boundary_tiling:
                new_tile_level = min(tile_level - 1, self.recursive_tile_level)
                if new_tile_level > 0:
                    nest = self.__startTiling(nest, new_tile_level, int_vars)
            result.append(nest)
        else:
            # enclose with a multilevel tile loop nest
            body = ast.CompStmt(group)
            strides = [st for (_, _, _, st, _) in (loop_info_table[n] for n in loop_inames)]
            result.append(self.__getMultiLevelTileLoop(tile_level, loop_inames, strides, body))

        idx += 1

    return result
def __staticLoopBoundScanning(self, stmts, tile_level, outer_loop_inames, loop_info_table):
    '''
    Assuming that the loop-bound expressions are affine functions of outer
    loop iterators and global parameters, determine the loop bounds of full
    tiles at compile time. This is an optimization strategy to produce more
    efficient code.

    Assumptions:
      1. Lower bound expression must be in the form of: max(e_1,e_2,e_3,...,e_n)
      2. Upper bound expression must be in the form of: min(e_1,e_2,e_3,...,e_n)
    where e_i is an affine function of outer loop iterators and global
    parameters.

    Note that max(x,y,z) is implemented as nested binary max functions:
    max(x,max(y,z)). The same condition applies for the min function. When
    n=1, the max/min function is not needed.

    Returns (scan_stmts, lbound_info_seq, int_vars):
      scan_stmts      -- ExpStmt assignments of the new bound variables;
      lbound_info_seq -- one entry per element of stmts: None for non-loops
                         and for loops whose bounds mention no outer loop
                         iterator, else (lb_name, ub_name, need_prolog,
                         need_epilog, need_tiled_loop);
      int_vars        -- names of the newly introduced bound variables.
    '''
    scan_stmts = []
    lbound_info_seq = []
    int_vars = []

    # lowest (lval) and highest (uval) value of each outer iterator within
    # its current tile: it and it + (T - (st))
    val_table = {}
    for iname in outer_loop_inames:
        _, _, _, st_exp, _ = loop_info_table[iname]
        lval = ast.IdentExp(self.__getTileIterName(iname, tile_level))
        t = ast.BinOpExp(
            ast.IdentExp(self.__getTileSizeName(iname, tile_level)),
            ast.ParenthExp(st_exp.replicate()), ast.BinOpExp.SUB)
        uval = ast.BinOpExp(lval.replicate(), ast.ParenthExp(t), ast.BinOpExp.ADD)
        val_table[iname] = (lval, uval)

    # determine loop bounds that are affine functions of outer loop iterators
    lb_exps_table = {}
    ub_exps_table = {}
    for stmt in stmts:
        # skip all non-loop statements
        if not isinstance(stmt, ast.ForStmt):
            lbound_info_seq.append(None)
            continue

        _, lb_exp, ub_exp, _, _ = self.ast_util.getForLoopInfo(stmt)

        # Outer iterators referenced by each bound. These must be real lists:
        # on Python 3 filter() returns a lazy iterator that is always truthy and
        # has no len(), which broke the emptiness and len() checks below.
        lb_inames = [i for i in outer_loop_inames
                     if self.ast_util.containIdentName(lb_exp, i)]
        ub_inames = [i for i in outer_loop_inames
                     if self.ast_util.containIdentName(ub_exp, i)]

        # skip loops with bound expressions that are free of outer loop iterators
        if not lb_inames and not ub_inames:
            lbound_info_seq.append(None)
            continue

        # identical lower and upper bounds: the loop runs only once
        is_one_time_loop = str(lb_exp) == str(ub_exp)

        # decide whether a prolog, an epilog and the main tiled loop are needed
        if is_one_time_loop:
            need_tiled_loop = False
            need_prolog = False
            need_epilog = False
        else:
            need_tiled_loop = True
            need_prolog = len(lb_inames) > 0
            need_epilog = len(ub_inames) > 0

        # fresh variable names for the new lower and upper loop bounds
        if need_tiled_loop:
            lb_name, ub_name = self.__getLoopBoundNames()
            int_vars.extend([lb_name, ub_name])
        else:
            lb_name = ''
            ub_name = ''

        lbound_info_seq.append((lb_name, ub_name, need_prolog, need_epilog, need_tiled_loop))

        # a one-time loop needs no loop-bound scanning code
        if not need_tiled_loop:
            continue

        # value of the new lower loop bound (reuse an equal, earlier bound)
        lb_key = str(lb_exp)
        if lb_key in lb_exps_table:
            lb_var = lb_exps_table[lb_key]
            a = ast.BinOpExp(ast.IdentExp(lb_name), lb_var.replicate(), ast.BinOpExp.EQ_ASGN)
        else:
            if need_prolog:
                t = self.__findMinMaxVal('max', lb_exp.replicate(), lb_inames, val_table)
                a = ast.BinOpExp(ast.IdentExp(lb_name), t.replicate(), ast.BinOpExp.EQ_ASGN)
            else:
                a = ast.BinOpExp(ast.IdentExp(lb_name), lb_exp.replicate(), ast.BinOpExp.EQ_ASGN)
            lb_exps_table[lb_key] = ast.IdentExp(lb_name)
        scan_stmts.append(ast.ExpStmt(a))

        # value of the new upper loop bound (reuse an equal, earlier bound)
        ub_key = str(ub_exp)
        if ub_key in ub_exps_table:
            ub_var = ub_exps_table[ub_key]
            a = ast.BinOpExp(ast.IdentExp(ub_name), ub_var.replicate(), ast.BinOpExp.EQ_ASGN)
        else:
            if need_epilog:
                t = self.__findMinMaxVal('min', ub_exp.replicate(), ub_inames, val_table)
                a = ast.BinOpExp(ast.IdentExp(ub_name), t.replicate(), ast.BinOpExp.EQ_ASGN)
            else:
                a = ast.BinOpExp(ast.IdentExp(ub_name), ub_exp.replicate(), ast.BinOpExp.EQ_ASGN)
            ub_exps_table[ub_key] = ast.IdentExp(ub_name)
        scan_stmts.append(ast.ExpStmt(a))

    return (scan_stmts, lbound_info_seq, int_vars)
def __replaceArrRefs(self, tnode, replace_table):
    """
    Replace array references by identifiers.

    An ArrayRefExp whose printed form (str(node)) is a key of replace_table
    becomes a new IdentExp named by the mapped value. Other nodes are
    rewritten in place: their child fields are reassigned, and the same node
    is returned. Unsupported node types are reported through err().
    """
    # leaves that can never contain an array reference
    if isinstance(tnode, (ast.NumLitExp, ast.StringLitExp, ast.IdentExp)):
        return tnode

    if isinstance(tnode, ast.ArrayRefExp):
        key = str(tnode)
        if key not in replace_table:
            return tnode
        return ast.IdentExp(replace_table[key])

    if isinstance(tnode, ast.FunCallExp):
        tnode.exp = self.__replaceArrRefs(tnode.exp, replace_table)
        tnode.args = [self.__replaceArrRefs(arg, replace_table) for arg in tnode.args]
        return tnode

    if isinstance(tnode, ast.UnaryExp):
        tnode.exp = self.__replaceArrRefs(tnode.exp, replace_table)
        return tnode

    if isinstance(tnode, ast.BinOpExp):
        tnode.lhs = self.__replaceArrRefs(tnode.lhs, replace_table)
        tnode.rhs = self.__replaceArrRefs(tnode.rhs, replace_table)
        return tnode

    if isinstance(tnode, ast.ParenthExp):
        tnode.exp = self.__replaceArrRefs(tnode.exp, replace_table)
        return tnode

    if isinstance(tnode, ast.ExpStmt):
        if tnode.exp:
            tnode.exp = self.__replaceArrRefs(tnode.exp, replace_table)
        return tnode

    if isinstance(tnode, ast.CompStmt):
        tnode.stmts = [self.__replaceArrRefs(child, replace_table) for child in tnode.stmts]
        return tnode

    if isinstance(tnode, ast.IfStmt):
        tnode.test = self.__replaceArrRefs(tnode.test, replace_table)
        tnode.true_stmt = self.__replaceArrRefs(tnode.true_stmt, replace_table)
        if tnode.false_stmt:
            tnode.false_stmt = self.__replaceArrRefs(tnode.false_stmt, replace_table)
        return tnode

    if isinstance(tnode, ast.ForStmt):
        # each control expression is optional
        if tnode.init:
            tnode.init = self.__replaceArrRefs(tnode.init, replace_table)
        if tnode.test:
            tnode.test = self.__replaceArrRefs(tnode.test, replace_table)
        if tnode.iter:
            tnode.iter = self.__replaceArrRefs(tnode.iter, replace_table)
        tnode.stmt = self.__replaceArrRefs(tnode.stmt, replace_table)
        return tnode

    err("orio.module.ortildriver.transformation internal error:OrTilDriver: unknown type of AST: %s" % tnode.__class__.__name__)
def p_expression_4(p):
    'expression : ID ID EQUALS expression'
    # (the docstring above is the PLY grammar rule and must not change)
    # presumably "<type> <name> = <expression>"; shift PLY's line number by
    # the module-level __start_line_no offset
    line = p.lineno(1) + __start_line_no - 1
    var_ident = ast.IdentExp(p[2])
    p[0] = ast.VarDeclInit(p[1], var_ident, p[4], line)
def p_primary_expression_1(p):
    "primary_expression : ID"
    # (the docstring above is the PLY grammar rule and must not change)
    # getLineNumber is defined elsewhere and maps PLY's line number
    line = getLineNumber(p.lineno(1))
    p[0] = ast.IdentExp(p[1], line)
def getForLoopInfo(self, stmt):
    '''
    Describe a loop of the form
        for (id = lb; id <= ub; id = id + st) bod
    A "<" test and an "id++" / "++id" iteration are accepted as well.

    Compound statements wrapping exactly one statement are unwrapped first.
    Malformed loops are reported through err().

    Returns (id, lb, ub, st, bod), built from new or replicated nodes.
    lb, ub or st is None when the corresponding control expression is
    absent. A "<" test yields ub = rhs - 1, and an increment yields st = 1.
    '''
    # unwrap compound statements that hold a single statement
    while isinstance(stmt, ast.CompStmt) and len(stmt.stmts) == 1:
        stmt = stmt.stmts[0]

    if not isinstance(stmt, ast.ForStmt):
        err('orio.module.tilic.ast_util: Tilic: input not a loop statement' )

    # initialization must be "id = lb"
    if stmt.init:
        init = stmt.init
        well_formed = (isinstance(init, ast.BinOpExp)
                       and init.op_type == ast.BinOpExp.EQ_ASGN
                       and isinstance(init.lhs, ast.IdentExp))
        if not well_formed:
            err('orio.module.tilic.ast_util: Tilic: loop initialization not in "id = lb" form' )

    # test must be "id <= ub" or "id < ub"
    if stmt.test:
        test = stmt.test
        well_formed = (isinstance(test, ast.BinOpExp)
                       and test.op_type in (ast.BinOpExp.LE, ast.BinOpExp.LT)
                       and isinstance(test.lhs, ast.IdentExp))
        if not well_formed:
            err('orio.module.tilic.ast_util: Tilic: loop test not in "id <= ub" or "id < ub" form' )

    # iteration must be "id = id + st" or a pre/post increment of id
    if stmt.iter:
        step = stmt.iter
        well_formed = (
            isinstance(step, ast.BinOpExp)
            and step.op_type == ast.BinOpExp.EQ_ASGN
            and isinstance(step.lhs, ast.IdentExp)
            and isinstance(step.rhs, ast.BinOpExp)
            and isinstance(step.rhs.lhs, ast.IdentExp)
            and step.rhs.op_type == ast.BinOpExp.ADD
            and step.lhs.name == step.rhs.lhs.name
        ) or (
            isinstance(step, ast.UnaryExp)
            and isinstance(step.exp, ast.IdentExp)
            and step.op_type in (ast.UnaryExp.PRE_INC, ast.UnaryExp.POST_INC)
        )
        if not well_formed:
            err('orio.module.tilic.ast_util: Tilic: loop iteration not in "id++" or "id += st" or "id = id + st" form' )

    if not (stmt.init or stmt.test or stmt.iter):
        err('orio.module.tilic.ast_util: Tilic: loop with an empty control expression cannot be handled' )

    # every control expression must use the same iterator
    inames = []
    if stmt.init:
        inames.append(stmt.init.lhs.name)
    if stmt.test:
        inames.append(stmt.test.lhs.name)
    if stmt.iter:
        step = stmt.iter
        inames.append(step.lhs.name if isinstance(step, ast.BinOpExp) else step.exp.name)
    if len(set(inames)) > 1:
        err('orio.module.tilic.ast_util: Tilic: different iterator names used in the loop control expressions' )

    index = ast.IdentExp(inames[0])
    lb = stmt.init.rhs.replicate() if stmt.init else None

    if not stmt.test:
        ub = None
    elif stmt.test.op_type == ast.BinOpExp.LT:
        # "id < e" becomes the inclusive bound "e - 1"
        ub = ast.BinOpExp(stmt.test.rhs.replicate(), ast.NumLitExp(1, ast.NumLitExp.INT), ast.BinOpExp.SUB)
    else:
        ub = stmt.test.rhs.replicate()

    if not stmt.iter:
        st = None
    elif isinstance(stmt.iter, ast.BinOpExp):
        st = stmt.iter.rhs.rhs.replicate()
    else:
        st = ast.NumLitExp(1, ast.NumLitExp.INT)

    bod = stmt.stmt.replicate()
    return (index, lb, ub, st, bod)
def p_primary_expression_1(p):
    'primary_expression : ID'
    # (the docstring above is the PLY grammar rule and must not change)
    # shift PLY's line number by the module-level __start_line_no offset
    line = p.lineno(1) + __start_line_no - 1
    p[0] = ast.IdentExp(p[1], line)
def __getLoopBoundScanningStmts(self, stmts, tile_level, outer_loop_inames, loop_info_table):
    '''
    Generate explicit loop-bound scanning code. At runtime this code
    determines the latest start and the earliest end of scanning full tiles.

    When self.affine_lbound_exps is set, the bounds are computed at compile
    time by __staticLoopBoundScanning instead.

    Returns (scan_stmts, lbound_info_seq, int_vars):
      scan_stmts      -- initialization statements, then scanning loops
                         (running max of lower bounds / min of upper bounds),
                         then copies of bounds shared with earlier loops;
      lbound_info_seq -- one entry per element of stmts: None for non-loops
                         and for loops whose bounds mention no outer loop
                         iterator, else (lb_name, ub_name, need_prolog,
                         need_epilog, need_tiled_loop);
      int_vars        -- names of the newly introduced bound variables.
    '''
    # (optimization) determine the loop bounds of full tiles at compile time
    if self.affine_lbound_exps:
        return self.__staticLoopBoundScanning(stmts, tile_level, outer_loop_inames, loop_info_table)

    lbound_info_seq = []
    int_vars = []

    # 32-bit int extremes: initial values of the running max/min
    min_int = ast.NumLitExp(-2147483648, ast.NumLitExp.INT)
    max_int = ast.NumLitExp(2147483647, ast.NumLitExp.INT)
    lb_exps_table = {}
    ub_exps_table = {}
    pre_scan_stmts = []
    post_scan_stmts = []
    scan_loops = SimpleLoops()
    for stmt in stmts:
        # skip all non-loop statements
        if not isinstance(stmt, ast.ForStmt):
            lbound_info_seq.append(None)
            continue

        _, lb_exp, ub_exp, _, _ = self.ast_util.getForLoopInfo(stmt)

        # Outer iterators referenced by each bound. These must be real lists:
        # on Python 3 filter() returns a lazy iterator that is always truthy and
        # has no len(), which broke the emptiness and len() checks below.
        lb_inames = [i for i in outer_loop_inames
                     if self.ast_util.containIdentName(lb_exp, i)]
        ub_inames = [i for i in outer_loop_inames
                     if self.ast_util.containIdentName(ub_exp, i)]

        # skip loops with bound expressions that are free of outer loop iterators
        if not lb_inames and not ub_inames:
            lbound_info_seq.append(None)
            continue

        # identical lower and upper bounds: the loop runs only once
        is_one_time_loop = str(lb_exp) == str(ub_exp)

        # decide whether a prolog, an epilog and the main tiled loop are needed
        if is_one_time_loop:
            need_tiled_loop = False
            need_prolog = False
            need_epilog = False
        else:
            need_tiled_loop = True
            need_prolog = len(lb_inames) > 0
            need_epilog = len(ub_inames) > 0

        # fresh variable names for the new lower and upper loop bounds
        if need_tiled_loop:
            lb_name, ub_name = self.__getLoopBoundNames()
            int_vars.extend([lb_name, ub_name])
        else:
            lb_name = ''
            ub_name = ''

        lbound_info_seq.append((lb_name, ub_name, need_prolog, need_epilog, need_tiled_loop))

        # a one-time loop needs no loop-bound scanning code
        if not need_tiled_loop:
            continue

        # loop-bound scanning code for the prolog
        if str(lb_exp) in lb_exps_table:
            lb_var = lb_exps_table[str(lb_exp)]
            a = ast.BinOpExp(ast.IdentExp(lb_name), lb_var.replicate(), ast.BinOpExp.EQ_ASGN)
            post_scan_stmts.append(ast.ExpStmt(a))
        else:
            if need_prolog:
                a = ast.BinOpExp(ast.IdentExp(lb_name), min_int.replicate(), ast.BinOpExp.EQ_ASGN)
                pre_scan_stmts.append(ast.ExpStmt(a))
                a = ast.BinOpExp(
                    ast.IdentExp(lb_name),
                    ast.FunCallExp(ast.IdentExp('max'), [ast.IdentExp(lb_name), lb_exp.replicate()]),
                    ast.BinOpExp.EQ_ASGN)
                scan_loops.insertLoop(lb_inames, ast.ExpStmt(a))
            else:
                a = ast.BinOpExp(ast.IdentExp(lb_name), lb_exp.replicate(), ast.BinOpExp.EQ_ASGN)
                pre_scan_stmts.append(ast.ExpStmt(a))
            lb_exps_table[str(lb_exp)] = ast.IdentExp(lb_name)

        # loop-bound scanning code for the epilog
        if str(ub_exp) in ub_exps_table:
            ub_var = ub_exps_table[str(ub_exp)]
            a = ast.BinOpExp(ast.IdentExp(ub_name), ub_var.replicate(), ast.BinOpExp.EQ_ASGN)
            post_scan_stmts.append(ast.ExpStmt(a))
        else:
            if need_epilog:
                a = ast.BinOpExp(ast.IdentExp(ub_name), max_int.replicate(), ast.BinOpExp.EQ_ASGN)
                pre_scan_stmts.append(ast.ExpStmt(a))
                a = ast.BinOpExp(
                    ast.IdentExp(ub_name),
                    ast.FunCallExp(ast.IdentExp('min'), [ast.IdentExp(ub_name), ub_exp.replicate()]),
                    ast.BinOpExp.EQ_ASGN)
                scan_loops.insertLoop(ub_inames, ast.ExpStmt(a))
            else:
                a = ast.BinOpExp(ast.IdentExp(ub_name), ub_exp.replicate(), ast.BinOpExp.EQ_ASGN)
                pre_scan_stmts.append(ast.ExpStmt(a))
            ub_exps_table[str(ub_exp)] = ast.IdentExp(ub_name)

    # build a new loop information table for generating the loop-bound scanning
    # code: iname -> (tile-size name, tile-iterator name, stride)
    n_loop_info_table = {}
    for iname, linfo in loop_info_table.items():
        _, _, _, st_exp, _ = linfo
        n_loop_info_table[iname] = (self.__getTileSizeName(iname, tile_level),
                                    self.__getTileIterName(iname, tile_level), st_exp)

    # convert the "SimpleLoops" abstractions into loop ASTs
    scan_loop_stmts = scan_loops.convertToASTs(tile_level, n_loop_info_table)

    # merge all scanning statements
    scan_stmts = pre_scan_stmts + scan_loop_stmts + post_scan_stmts

    return (scan_stmts, lbound_info_seq, int_vars)
def getForLoopInfo(self, stmt):
    '''
    Given a for-loop statement, extract information about its loop structure.
    The for-loop must have the form
        for (<id> = <exp>; <id> <= <exp>; <iter>) <stmt>
    where the test may also be "<id> < <exp>", and <iter> is one of
    <id>++, ++<id>, <id>--, --<id>, <id> = <id> + <exp>, <id> = <id> - <exp>.

    Compound statements wrapping exactly one statement are unwrapped first.
    Redundant parentheses around the control expressions and their operands
    are stripped from stmt IN PLACE. Malformed loops are reported via err().

    Returns (index_id, lbound_exp, ubound_exp, stride_exp, loop_body).
    - A bound or stride is None when its control expression is absent.
    - A "<" test yields the inclusive upper bound "<exp> - 1".
    - Decrements and subtractions yield a negated stride.
    '''

    def strip_parens(exp):
        # drop any number of enclosing parentheses
        while isinstance(exp, ast.ParenthExp):
            exp = exp.exp
        return exp

    # get rid of compound statements that contain only a single statement
    while isinstance(stmt, ast.CompStmt) and len(stmt.stmts) == 1:
        stmt = stmt.stmts[0]

    # check if it is a for-loop statement
    if not isinstance(stmt, ast.ForStmt):
        err('orio.module.ortil.ast_util: OrTil:%s: not a for-loop statement' % stmt.line_no)

    # check initialization expression: <id> = <exp>
    if stmt.init:
        stmt.init = strip_parens(stmt.init)
        init = stmt.init
        valid = False
        if isinstance(init, ast.BinOpExp) and init.op_type == ast.BinOpExp.EQ_ASGN:
            init.lhs = strip_parens(init.lhs)
            init.rhs = strip_parens(init.rhs)
            valid = isinstance(init.lhs, ast.IdentExp)
        if not valid:
            err('orio.module.ortil.ast_util:%s: loop initialization expression not in "<id> = <exp>" form'
                % init.line_no)

    # check test expression: <id> <= <exp> or <id> < <exp>
    if stmt.test:
        stmt.test = strip_parens(stmt.test)
        test = stmt.test
        valid = False
        if isinstance(test, ast.BinOpExp) and test.op_type in (ast.BinOpExp.LT, ast.BinOpExp.LE):
            test.lhs = strip_parens(test.lhs)
            test.rhs = strip_parens(test.rhs)
            valid = isinstance(test.lhs, ast.IdentExp)
        if not valid:
            # the whole message is formatted (previously '%' bound only to the
            # second literal, which has no '%s', raising TypeError instead)
            err(('orio.module.ortil.ast_util:%s: loop test expression not in "<id> <= <exp>" or ' +
                 '"<id> < <exp>" form') % test.line_no)

    # check iteration expression
    if stmt.iter:
        stmt.iter = strip_parens(stmt.iter)
        step = stmt.iter
        valid = False
        if isinstance(step, ast.BinOpExp) and step.op_type == ast.BinOpExp.EQ_ASGN:
            step.lhs = strip_parens(step.lhs)
            step.rhs = strip_parens(step.rhs)
            if (isinstance(step.lhs, ast.IdentExp)
                    and isinstance(step.rhs, ast.BinOpExp)
                    and step.rhs.op_type in (ast.BinOpExp.ADD, ast.BinOpExp.SUB)):
                step.rhs.lhs = strip_parens(step.rhs.lhs)
                step.rhs.rhs = strip_parens(step.rhs.rhs)
                valid = (isinstance(step.rhs.lhs, ast.IdentExp)
                         and step.lhs.name == step.rhs.lhs.name)
        elif (isinstance(step, ast.UnaryExp)
              and step.op_type in (ast.UnaryExp.POST_INC, ast.UnaryExp.PRE_INC,
                                   ast.UnaryExp.POST_DEC, ast.UnaryExp.PRE_DEC)):
            step.exp = strip_parens(step.exp)
            valid = isinstance(step.exp, ast.IdentExp)
        if not valid:
            err(('orio.module.ortil.ast_util:%s: loop iteration expression not in "<id>++" or "<id>--" or ' +
                 '"<id> += <exp>" or "<id> = <id> + <exp>" form') % step.line_no)

    # check if the control expressions are all empty
    if not stmt.init and not stmt.test and not stmt.iter:
        err('orio.module.ortil.ast_util:%s: a loop with an empty control expression cannot be handled'
            % stmt.line_no)

    # check if the iterator names are all the same
    inames = []
    if stmt.init:
        inames.append(stmt.init.lhs.name)
    if stmt.test:
        inames.append(stmt.test.lhs.name)
    if stmt.iter:
        if isinstance(stmt.iter, ast.BinOpExp):
            inames.append(stmt.iter.lhs.name)
        else:
            assert isinstance(stmt.iter, ast.UnaryExp), 'internal error:OrTil: not unary'
            inames.append(stmt.iter.exp.name)
    if inames.count(inames[0]) != len(inames):
        err('orio.module.ortil.ast_util:%s: iterator names across init, test, and iter exps must be the same'
            % stmt.line_no)

    # extract for-loop structure information
    index_id = ast.IdentExp(inames[0])
    lbound_exp = None
    ubound_exp = None
    stride_exp = None
    if stmt.init:
        lbound_exp = stmt.init.rhs.replicate()
    if stmt.test:
        if stmt.test.op_type == ast.BinOpExp.LT:
            # "<id> < e" becomes the inclusive bound "e - 1"
            ubound_exp = ast.BinOpExp(stmt.test.rhs.replicate(), ast.NumLitExp(1, ast.NumLitExp.INT),
                                      ast.BinOpExp.SUB)
        else:
            ubound_exp = stmt.test.rhs.replicate()
    if stmt.iter:
        if isinstance(stmt.iter, ast.BinOpExp):
            stride_exp = stmt.iter.rhs.rhs.replicate()
            if isinstance(stride_exp, ast.BinOpExp):
                stride_exp = ast.ParenthExp(stride_exp)
            if stmt.iter.rhs.op_type == ast.BinOpExp.SUB:
                stride_exp = ast.UnaryExp(stride_exp, ast.UnaryExp.MINUS)
        elif isinstance(stmt.iter, ast.UnaryExp):
            if stmt.iter.op_type in (ast.UnaryExp.POST_INC, ast.UnaryExp.PRE_INC):
                stride_exp = ast.NumLitExp(1, ast.NumLitExp.INT)
            elif stmt.iter.op_type in (ast.UnaryExp.POST_DEC, ast.UnaryExp.PRE_DEC):
                stride_exp = ast.NumLitExp(-1, ast.NumLitExp.INT)
            else:
                err('orio.module.ortil.ast_util internal error:OrTil: unexpected unary operation type')
        else:
            err('orio.module.ortil.ast_util internal error:OrTil: unexpected type of iteration expression')
    loop_body = stmt.stmt.replicate()

    return (index_id, lbound_exp, ubound_exp, stride_exp, loop_body)
def p_primary_expression_1(p):
    "primary_expression : ID"
    # (the docstring above is the PLY grammar rule and must not change)
    # this variant passes the shifted line number to IdentExp as a string
    line_text = str(p.lineno(1) + __start_line_no - 1)
    p[0] = ast.IdentExp(p[1], line_no=line_text)