def unroll_to_depth( max_depth ): print "\n\nunroll_to_depth(%d)" % max_depth print "SYNC UP" sys.stdout.flush() cur = chill.cur_indices(0) thread_idxs = chill.thread_indices() guard_idx = thread_idxs[-1] # last one print "cur indices", print_array(cur) print "thread indices", print_array(thread_idxs) print "guard_idx = %s" % guard_idx #print "thread_idxs = ", #print thread_idxs guard_idx = thread_idxs[-1] #print "guard_idx = %s" % guard_idx # HERE FIND OUT THE LOOPS WHICH ARE COMMON BETWEEN STATEMENTS common_loops = [] comm_loops_cnt = 0 num_stmts = chill.num_statements() print "num statements %d" % num_stmts for stmt in range(num_stmts): sys.stdout.flush() print "\nSTMT %d" % stmt, cur_idxs = chill.cur_indices(stmt) print "Current Indices:", for c in cur_idxs[:-1]: print "%s," % c, print "%s" % cur_idxs[-1] # last one sys.stdout.flush() #print_code() if chk_cur_level(stmt, "tx") > 0: for ii in range(find_cur_level(stmt,"tx")-1): print "ii = %d\ncur_idxs[%d] = '%s'" % (ii+1, ii+1, cur_idxs[ii]) # print to match lua id = cur_idxs[ii] if id not in ["bx", "by", "", "tx", "ty"]: print "id %s is not in the list" % id for stmt1 in range(stmt+1, num_stmts): print "\nii %d stmt1 is %d" % (ii+1, stmt1) # print to match lua cur_idxs1 = chill.cur_indices(stmt1) print "\nstmt1 cur_idxs1 is ", for ind in cur_idxs1[:-1]: print "%s," % ind, print "%s" % cur_idxs1[-1] print "cur level(%d, %s) = %d" % (stmt, "tx", find_cur_level(stmt,"tx") ) sys.stdout.flush() endrange = find_cur_level(stmt,"tx")-1 print "for iii=1, %d do" % endrange sys.stdout.flush() for iii in range(endrange): # off by one? 
TODO print "stmt %d ii %d iii %d\n" % (stmt, ii+1, iii+1), sys.stdout.flush() if iii >= len(cur_idxs1): print "stmt %d ii %d iii %d cur_idxs1[%d] = NIL" % (stmt, ii+1, iii+1, iii+1, ) # print to match lua else: print "stmt %d ii %d iii %d cur_idxs1[%d] = '%s'" % (stmt, ii+1, iii+1, iii+1, cur_idxs1[iii]) # print to match lua sys.stdout.flush() # this will still probably die if iii < len(cur_idxs1) and [iii] not in ["bx", "by", "tx", "ty", ""]: if cur_idxs[ii] == cur_idxs1[iii]: print "\nfound idx:%s" % cur_idxs[ii] common_loops.append(cur_idxs[ii]) print "cl[%d] = '%s'" % ( comm_loops_cnt, cur_idxs[ii] ) comm_loops_cnt = len(common_loops) if len(common_loops) > 0: print "\n COMM LOOPS :TOTAL %d, and are " % comm_loops_cnt, print common_loops, print " this loop : %s" % common_loops[0] else: print "UNROLL can't unroll any loops?" while True: # break at bottom of loop (repeat in lua) old_num_statements = chill.num_statements() print "old_num_statements %d" % old_num_statements for stmt in range(old_num_statements): cur_idxs = chill.cur_indices(stmt) print "stmt %d cur_idxs =" % stmt, index = 0 for i in cur_idxs: index +=1 if index == len(cur_idxs): print "%s" %i else: print "%s," % i, if len(cur_idxs) > 0: guard_level = -1 if chk_cur_level(stmt, guard_idx) > 0: guard_level = find_cur_level(stmt,guard_idx) print "guard_level(sp) = %d" % guard_level if guard_level > -1: level = next_clean_level(cur_idxs,guard_level) print "next clean level %d" % level #print "looking at %d" % stmt #print "comparing %d and %d in" % (guard_level, level), #index = 0 #for i in cur_idxs: #index +=1 #if index == len(cur_idxs): # print "%s" %i #else: # print "%s," % i, # need to handle max_depth num_unrolled = 0 level_unroll_comm = level level_arr = [] #print "before while, level = %d" % level while level >= 0: print "while: level = %d" % level if num_unrolled == max_depth: break print "Unrolling %d at level %d index %s" % ( stmt, level, cur_idxs[guard_level]) # ??? 
level_arr.append(level) guard_level = find_cur_level(stmt,guard_idx) level = next_clean_level(cur_idxs,level+1) print "OK, NOW WE UNROLL" if level_unroll_comm >= 0: level_arr.reverse() for i,lev in enumerate(level_arr): print "\ni=%d" % i print "[Unroll]unroll(%d, %d, 0)" % (stmt, lev) chill.unroll(stmt, lev, 0) new_num_statements = chill.num_statements() if old_num_statements == new_num_statements: break # exit infinite loop
def unroll(statement, level, unroll_amount):
    """Forward an unroll request unchanged to ``chill.unroll``.

    Args:
        statement: Passed through as the first argument of ``chill.unroll``.
        level: Passed through as the second argument.
        unroll_amount: Passed through as the third argument.

    Returns:
        None; whatever ``chill.unroll`` returns is discarded.
    """
    chill.unroll(
        statement,
        level,
        unroll_amount,
    )
# CHiLL script: transforms procedure 'foo' from src/nestedloops.c and names
# src/nestedloops_modified.c as the destination.
import chill

# Source file, destination file, and the procedure / loop to work on.
chill.source('src/nestedloops.c')
chill.destination('src/nestedloops_modified.c')
chill.procedure('foo')
chill.loop(0)

# Transformations, issued in this order: a permutation with order [2, 3, 1],
# then an unroll call with arguments (statement 0, level 3, amount 2).
chill.permute([2, 3, 1])
chill.unroll(0, 3, 2)