def unroll_to_depth( max_depth ): print "\n\nunroll_to_depth(%d)" % max_depth print "SYNC UP" sys.stdout.flush() cur = chill.cur_indices(0) thread_idxs = chill.thread_indices() guard_idx = thread_idxs[-1] # last one print "cur indices", print_array(cur) print "thread indices", print_array(thread_idxs) print "guard_idx = %s" % guard_idx #print "thread_idxs = ", #print thread_idxs guard_idx = thread_idxs[-1] #print "guard_idx = %s" % guard_idx # HERE FIND OUT THE LOOPS WHICH ARE COMMON BETWEEN STATEMENTS common_loops = [] comm_loops_cnt = 0 num_stmts = chill.num_statements() print "num statements %d" % num_stmts for stmt in range(num_stmts): sys.stdout.flush() print "\nSTMT %d" % stmt, cur_idxs = chill.cur_indices(stmt) print "Current Indices:", for c in cur_idxs[:-1]: print "%s," % c, print "%s" % cur_idxs[-1] # last one sys.stdout.flush() #print_code() if chk_cur_level(stmt, "tx") > 0: for ii in range(find_cur_level(stmt,"tx")-1): print "ii = %d\ncur_idxs[%d] = '%s'" % (ii+1, ii+1, cur_idxs[ii]) # print to match lua id = cur_idxs[ii] if id not in ["bx", "by", "", "tx", "ty"]: print "id %s is not in the list" % id for stmt1 in range(stmt+1, num_stmts): print "\nii %d stmt1 is %d" % (ii+1, stmt1) # print to match lua cur_idxs1 = chill.cur_indices(stmt1) print "\nstmt1 cur_idxs1 is ", for ind in cur_idxs1[:-1]: print "%s," % ind, print "%s" % cur_idxs1[-1] print "cur level(%d, %s) = %d" % (stmt, "tx", find_cur_level(stmt,"tx") ) sys.stdout.flush() endrange = find_cur_level(stmt,"tx")-1 print "for iii=1, %d do" % endrange sys.stdout.flush() for iii in range(endrange): # off by one? TODO print "stmt %d ii %d iii %d\n" % (stmt, ii+1, iii+1), sys.stdout.flush() if iii >= len(cur_idxs1): print "stmt %d ii %d iii %d cur_idxs1[%d] = NIL" % (stmt, ii+1, iii+1, iii+1, ) # print to match lua else: print "stmt %d ii %d iii %d cur_idxs1[%d] = '%s'" % (stmt, ii+1, iii+1, iii+1, cur_idxs1[iii]) # print to match lua sys.stdout.flush() # this will still probably die if iii < len(cur_idxs1) and [iii] not in ["bx", "by", "tx", "ty", ""]: if cur_idxs[ii] == cur_idxs1[iii]: print "\nfound idx:%s" % cur_idxs[ii] common_loops.append(cur_idxs[ii]) print "cl[%d] = '%s'" % ( comm_loops_cnt, cur_idxs[ii] ) comm_loops_cnt = len(common_loops) if len(common_loops) > 0: print "\n COMM LOOPS :TOTAL %d, and are " % comm_loops_cnt, print common_loops, print " this loop : %s" % common_loops[0] else: print "UNROLL can't unroll any loops?" while True: # break at bottom of loop (repeat in lua) old_num_statements = chill.num_statements() print "old_num_statements %d" % old_num_statements for stmt in range(old_num_statements): cur_idxs = chill.cur_indices(stmt) print "stmt %d cur_idxs =" % stmt, index = 0 for i in cur_idxs: index +=1 if index == len(cur_idxs): print "%s" %i else: print "%s," % i, if len(cur_idxs) > 0: guard_level = -1 if chk_cur_level(stmt, guard_idx) > 0: guard_level = find_cur_level(stmt,guard_idx) print "guard_level(sp) = %d" % guard_level if guard_level > -1: level = next_clean_level(cur_idxs,guard_level) print "next clean level %d" % level #print "looking at %d" % stmt #print "comparing %d and %d in" % (guard_level, level), #index = 0 #for i in cur_idxs: #index +=1 #if index == len(cur_idxs): # print "%s" %i #else: # print "%s," % i, # need to handle max_depth num_unrolled = 0 level_unroll_comm = level level_arr = [] #print "before while, level = %d" % level while level >= 0: print "while: level = %d" % level if num_unrolled == max_depth: break print "Unrolling %d at level %d index %s" % ( stmt, level, cur_idxs[guard_level]) # ??? level_arr.append(level) guard_level = find_cur_level(stmt,guard_idx) level = next_clean_level(cur_idxs,level+1) print "OK, NOW WE UNROLL" if level_unroll_comm >= 0: level_arr.reverse() for i,lev in enumerate(level_arr): print "\ni=%d" % i print "[Unroll]unroll(%d, %d, 0)" % (stmt, lev) chill.unroll(stmt, lev, 0) new_num_statements = chill.num_statements() if old_num_statements == new_num_statements: break # exit infinite loop
def copy_to_registers( start_loop, array_name ): #print "\n\n****** starting copy to registers" #sys.stdout.flush() stmt = 0 # assume stmt 0 cur = chill.cur_indices(stmt) # calls C table_Size = len(cur) #print "Cur indices", #print_array(cur) #print "\nThe table size is %d" % table_Size #count=1 #for c in cur: # print "%d\t%s" % (count,c) # count += 1 #print_code() # would be much cleaner if not translating this code from lua! level_tx = -1 level_ty = -1 if is_in_indices(stmt,"tx"): level_tx = find_cur_level(stmt,"tx") if is_in_indices(stmt,"ty"): level_ty = find_cur_level(stmt,"ty") #print "level_tx %d level_ty %d" % ( level_tx, level_ty ) #sys.stdout.flush() ty_lookup_idx = "" org_level_ty = level_ty # UGLY logic. Lua index starts at 1, so all tests etc here are off by 1 from the lua code # level_ty initializes to -1 , which is not a valid index, and so there is added code to # make it not try to acccess offset -1. -1 IS a valid python array index # to top it off, the else below can assign a NIL to ty_lookup_idx! if level_ty != -1 and cur[level_ty] != "": #print "IF cur[%d] = %s" % ( level_ty, cur[level_ty] ) ty_lookup_idx = cur[level_ty] else: #print "ELSE ty_lookup_idx = cur[%d] = %s" % ( level_ty, cur[level_ty-1]) ty_lookup_idx = cur[level_ty-1] #print "ty_lookup_idx '%s'" % ty_lookup_idx if level_ty > -1: #print "\ntile3(%d,%d,%d)" % (stmt,level_ty,level_tx+1) chill.tile3(stmt,level_ty,level_tx+1) #print_code() cur = chill.cur_indices(stmt) # calls C table_Size = len(cur) #print "Cur indices ", #for c in cur: # print "%s," % c, #print "\nThe table size is %d" % len(cur) #count=1 #for c in cur: # print "%d\t%s" % (count,c) # count += 1 #sys.stdout.flush() if is_in_indices(stmt,"tx"): level_tx = find_cur_level(stmt,"tx") if ty_lookup_idx != "": # perhaps incorrect test if is_in_indices(stmt,ty_lookup_idx): level_ty = find_cur_level(stmt,ty_lookup_idx) ty_lookup = 1 idx_flag = -1 # find the level of the next valid index after ty+1 #print "\nlevel_ty %d" % level_ty if level_ty > -1: #print "table_Size %d" % table_Size for num in range(-1 + level_ty+ty_lookup,table_Size): # ?? off by one? #print "num=%d cur[num] = '%s'" % (num+1, cur[num]) # num+1 is lua index ???? sys.stdout.flush() if cur[num] != "": idx_flag = find_cur_level(stmt,cur[num]) #print "idx_flag = %d" % idx_flag break #print "\n(first) I am checking all indexes after ty+1 %s" % idx_flag #print_code() #print "" how_many_levels = 1 #print "idx_flag = %d I will check levels starting with %d" % (idx_flag, idx_flag+1) # lua arrays start at index 1. the next loop in lua starts at offset 0, since idx_flag can be -1 # thus the check for "not equal nil" in lua (bad idea) # python arrays start at 0, so will check for things that lua doesn't (?) startat = idx_flag + 1 if idx_flag == -1: startat = 1 # pretend we're lua for now. TODO: fix the logic for ch_lev in range(startat,table_Size+1): # logic may be wrong (off by one) #print "ch_lev %d" % ch_lev if ch_lev <= table_Size and cur[ch_lev-1] != "": #print "cur[%d] = '%s'" % ( ch_lev, cur[ch_lev-1] ) how_many_levels += 1 #print "\nHow Many Levels %d" % how_many_levels sys.stdout.flush() sys.stdout.flush() if how_many_levels< 2: while( idx_flag >= 0): for num in range(level_ty+ty_lookup,table_Size+1): #print "at top of loop, num is %d" % num #print "cur[num] = '%s'" % cur[num-1] if cur[num-1] != "": idx = cur[num-1] #print "idx '%s'" % idx sys.stdout.flush() curlev = find_cur_level(stmt,idx) #print "curlev %d" % curlev #print "\n[COPYTOREG]tile(%d,%d,%d)"%(stmt,curlev,level_tx) chill.tile3(stmt, curlev, curlev) curlev = find_cur_level(stmt,idx) #print "curlev %d" % curlev chill.tile3(stmt,curlev,level_tx) #print "hehe '%s'" % cur[num-1] cur = chill.cur_indices(stmt) #print "Cur indices INSIDE", #for c in cur: # print "%s," % c, table_Size = len(cur) #print "\nTable Size is: %d" % len(cur) level_tx = find_cur_level(stmt,"tx") #print "\n level TX is: %d" % level_tx level_ty = find_cur_level(stmt,ty_lookup_idx) #print "\n level TY is: %d" %level_ty idx_flag = -1 #print "idx_flag = -1" #- find the level of the next valid index after ty+1 #- the following was num, which conflicts with loop we're already in, and otherwise wasn't used (?) for num2 in range( -1 + level_ty+ty_lookup ,table_Size): # lua starts index at one #print "num mucking num = %d" % num2 if(cur[num2] != ""): #print "cur[%d] = '%s'" % ( num2, cur[num2] ) idx_flag = find_cur_level(stmt,cur[num2]) #print("\n(second) I am checking all indexes after ty+1 %s",cur[num2]) break #print "num mucked to %d idx_flag = %d" % (num, idx_flag) #print "at bottom of loop, num is %d" % num #print "done with levels" # this was a block comment ??? # for num in range(level_ty+1, table_Size+1): # print "num %d" % num # if cur[num-1] != "": # idx_flag = find_cur_level(stmt,cur[num-1]) ## ugly # print "idx_flag = %d" % idx_flag # change this all to reflect the real logic which is to normalize all loops inside the thread loops. # print "change this all ...\n" # print "level_ty+1 %d table_Size-1 %d idx_flag %d" %( level_ty+1, table_Size-1, idx_flag) # sys.stdout.flush() # sys.stdout.flush() # while level_ty+1 < (table_Size-1) and idx_flag >= 0: # print "*** level_ty %d" % level_ty # for num in range(level_ty+2,table_Size+1): # lua for includes second value # print "num %d cur[num] %s" % (num, cur[num]) # if cur[num] != "": # idx = cur[num] # print "idx='%s'" % idx # #print_code() #print "ARE WE SYNCED HERE?" #print_code() # [Malik] end logic start_level = find_cur_level(stmt, start_loop) # start_loop was passed parameter! # We should hold constant any block or tile loop block_idxs = chill.block_indices() thread_idxs = chill.thread_indices() #print"\nblock indices are" #for index, val in enumerate(block_idxs): # print "%d\t%s" % ( int(index)+1 , val ) #print"\nthread indices are" #for index, val in enumerate(thread_idxs): # print "%d\t%s" % ( int(index)+1 , val ) #print "\nStart Level: %d" % start_level hold_constant = [] #print("\n Now in Blocks") for idx in block_idxs: blocklevel = find_cur_level(stmt,idx) if blocklevel >= start_level: hold_constant.append(idx) #print "\nJust inserted block %s in hold_constant" %idx #print("\n Now in Threads") for idx in thread_idxs: blocklevel = find_cur_level(stmt,idx) if blocklevel >= start_level: hold_constant.append(idx) #print "\nJust inserted thread %s in hold_constant" %idx #print "\nhold constant table is: " #for index, val in enumerate(hold_constant): # print "%d\t%s" % ( int(index)+1 , val ) #print("\nbefore datacopy pvt") old_num_stmts = chill.num_statements() #sys.stdout.flush() #print "\n[DataCopy]datacopy_privatized(%d, %s, %s, " % (stmt, start_loop, array_name), #print hold_constant, #print ")" passtoC = [stmt, start_loop, array_name ] # a list passtoC.append( len(hold_constant ) ) for h in hold_constant: passtoC.append( h ) chill.datacopy_privatized( tuple( passtoC )) sys.stdout.flush() sys.stdout.flush() new_num_statements = chill.num_statements()