def unroll_to_depth( max_depth ):
    print "\n\nunroll_to_depth(%d)" % max_depth
    print "SYNC UP"
    sys.stdout.flush()

    cur = chill.cur_indices(0)
    thread_idxs = chill.thread_indices()
    guard_idx = thread_idxs[-1]  # last one

    print "cur    indices",
    print_array(cur)
    print "thread indices", 
    print_array(thread_idxs)
    print "guard_idx = %s" % guard_idx

    #print "thread_idxs = ",
    #print thread_idxs
    guard_idx = thread_idxs[-1]
    #print "guard_idx = %s" % guard_idx

    #  HERE FIND OUT THE LOOPS WHICH ARE COMMON BETWEEN STATEMENTS
    common_loops = []
    comm_loops_cnt = 0
    num_stmts = chill.num_statements()
    print "num statements %d" % num_stmts

    for stmt in range(num_stmts):
        sys.stdout.flush()
        print "\nSTMT %d" % stmt,
        cur_idxs = chill.cur_indices(stmt)
        print "Current Indices:",
        for c in cur_idxs[:-1]:
            print "%s," % c,
        print "%s" % cur_idxs[-1]   # last one
        sys.stdout.flush()
        #print_code()
        
        if chk_cur_level(stmt, "tx") > 0:
            
            for ii in range(find_cur_level(stmt,"tx")-1):
                print "ii = %d\ncur_idxs[%d] = '%s'" % (ii+1, ii+1, cur_idxs[ii]) # print to match lua
                id = cur_idxs[ii]
                if id not in ["bx", "by", "", "tx", "ty"]:

                    print "id %s is not in the list" % id

                    for stmt1 in range(stmt+1, num_stmts):
                        print "\nii %d stmt1 is %d" % (ii+1, stmt1)  # print to match lua 
                        cur_idxs1 = chill.cur_indices(stmt1)
                        print "\nstmt1 cur_idxs1 is ",
                        for ind in cur_idxs1[:-1]:
                            print "%s," % ind,
                        print "%s" % cur_idxs1[-1]

                        print "cur level(%d, %s) = %d" % (stmt, "tx", find_cur_level(stmt,"tx") )
                        sys.stdout.flush()

                        endrange = find_cur_level(stmt,"tx")-1
                        print "for iii=1, %d do" % endrange
                        sys.stdout.flush()
                        for iii in range(endrange):   # off by one?  TODO 
                            print "stmt %d   ii %d   iii %d\n" % (stmt, ii+1, iii+1),
                            sys.stdout.flush()
                            
                            if iii >= len(cur_idxs1):
                                print "stmt %d   ii %d   iii %d  cur_idxs1[%d] = NIL" % (stmt, ii+1, iii+1, iii+1, )  # print to match lua 
                            else:
                                print "stmt %d   ii %d   iii %d  cur_idxs1[%d] = '%s'" % (stmt, ii+1, iii+1, iii+1, cur_idxs1[iii])  # print to match lua 
                            sys.stdout.flush()

                            # this will still probably die 
                            if iii < len(cur_idxs1) and [iii] not in ["bx", "by", "tx", "ty", ""]:
                                if cur_idxs[ii] == cur_idxs1[iii]:
                                    print "\nfound idx:%s" % cur_idxs[ii]
                                    common_loops.append(cur_idxs[ii])
                                    print "cl[%d] = '%s'" % ( comm_loops_cnt, cur_idxs[ii] )
                                    comm_loops_cnt = len(common_loops)

    if len(common_loops) > 0:
        print "\n COMM LOOPS :TOTAL %d, and are " % comm_loops_cnt,
        print common_loops, 
        print " this loop : %s" % common_loops[0]
    else:
        print "UNROLL can't unroll any loops?"


    while True:  # break at bottom of loop   (repeat in lua)
        old_num_statements = chill.num_statements()
        print "old_num_statements %d" % old_num_statements

        for stmt in range(old_num_statements):
            cur_idxs = chill.cur_indices(stmt)
            print "stmt %d    cur_idxs =" % stmt,
            index = 0
            for i in cur_idxs:
                index +=1
                if index == len(cur_idxs):
                    print "%s" %i
                else:
                    print "%s," % i,

            if len(cur_idxs) > 0:
                guard_level = -1
                if chk_cur_level(stmt, guard_idx) > 0:
                    guard_level = find_cur_level(stmt,guard_idx)
                print "guard_level(sp) = %d" % guard_level
                if guard_level > -1:
                    level = next_clean_level(cur_idxs,guard_level)
                    print "next clean level %d" % level

                    
                    #print "looking at %d" % stmt
                    #print "comparing %d and %d in" % (guard_level, level),
                    #index = 0
                    #for i in cur_idxs:
                    #index +=1
                    #if index == len(cur_idxs):
                    #    print "%s" %i
                    #else:
                    #    print "%s," % i,

                    # need to handle max_depth
                    num_unrolled = 0
                    level_unroll_comm = level
                    level_arr = []

                    #print "before while, level = %d" % level 
                    while level >= 0:
                        print "while: level = %d" % level 
                        if num_unrolled == max_depth:
                            break

                        print "Unrolling %d at level %d index %s" % ( stmt, level, cur_idxs[guard_level])  # ??? 
                        level_arr.append(level)

                        guard_level = find_cur_level(stmt,guard_idx)
                        level = next_clean_level(cur_idxs,level+1)

                    print "OK, NOW WE UNROLL"
                    if level_unroll_comm >= 0:
                        level_arr.reverse()  
                        for i,lev in enumerate(level_arr):
                            print "\ni=%d" % i
                            print "[Unroll]unroll(%d, %d, 0)" % (stmt, lev)
                            chill.unroll(stmt, lev, 0)


        new_num_statements = chill.num_statements()
        if old_num_statements == new_num_statements:
            break  # exit infinite loop
def unroll(statement, level, unroll_amount):
    """Forward one unroll request to ``chill.unroll``.

    Args:
        statement: statement number, passed through unchanged.
        level: loop level, passed through unchanged.
        unroll_amount: unroll amount, passed through unchanged.

    Returns None; whatever ``chill.unroll`` returns is discarded.
    """
    do_unroll = chill.unroll
    do_unroll(statement, level, unroll_amount)
# Example #3
# 0
def unroll_to_depth( max_depth ):
    print "\n\nunroll_to_depth(%d)" % max_depth
    print "SYNC UP"
    sys.stdout.flush()

    cur = chill.cur_indices(0)
    thread_idxs = chill.thread_indices()
    guard_idx = thread_idxs[-1]  # last one

    print "cur    indices",
    print_array(cur)
    print "thread indices", 
    print_array(thread_idxs)
    print "guard_idx = %s" % guard_idx

    #print "thread_idxs = ",
    #print thread_idxs
    guard_idx = thread_idxs[-1]
    #print "guard_idx = %s" % guard_idx

    #  HERE FIND OUT THE LOOPS WHICH ARE COMMON BETWEEN STATEMENTS
    common_loops = []
    comm_loops_cnt = 0
    num_stmts = chill.num_statements()
    print "num statements %d" % num_stmts

    for stmt in range(num_stmts):
        sys.stdout.flush()
        print "\nSTMT %d" % stmt,
        cur_idxs = chill.cur_indices(stmt)
        print "Current Indices:",
        for c in cur_idxs[:-1]:
            print "%s," % c,
        print "%s" % cur_idxs[-1]   # last one
        sys.stdout.flush()
        #print_code()
        
        if chk_cur_level(stmt, "tx") > 0:
            
            for ii in range(find_cur_level(stmt,"tx")-1):
                print "ii = %d\ncur_idxs[%d] = '%s'" % (ii+1, ii+1, cur_idxs[ii]) # print to match lua
                id = cur_idxs[ii]
                if id not in ["bx", "by", "", "tx", "ty"]:

                    print "id %s is not in the list" % id

                    for stmt1 in range(stmt+1, num_stmts):
                        print "\nii %d stmt1 is %d" % (ii+1, stmt1)  # print to match lua 
                        cur_idxs1 = chill.cur_indices(stmt1)
                        print "\nstmt1 cur_idxs1 is ",
                        for ind in cur_idxs1[:-1]:
                            print "%s," % ind,
                        print "%s" % cur_idxs1[-1]

                        print "cur level(%d, %s) = %d" % (stmt, "tx", find_cur_level(stmt,"tx") )
                        sys.stdout.flush()

                        endrange = find_cur_level(stmt,"tx")-1
                        print "for iii=1, %d do" % endrange
                        sys.stdout.flush()
                        for iii in range(endrange):   # off by one?  TODO 
                            print "stmt %d   ii %d   iii %d\n" % (stmt, ii+1, iii+1),
                            sys.stdout.flush()
                            
                            if iii >= len(cur_idxs1):
                                print "stmt %d   ii %d   iii %d  cur_idxs1[%d] = NIL" % (stmt, ii+1, iii+1, iii+1, )  # print to match lua 
                            else:
                                print "stmt %d   ii %d   iii %d  cur_idxs1[%d] = '%s'" % (stmt, ii+1, iii+1, iii+1, cur_idxs1[iii])  # print to match lua 
                            sys.stdout.flush()

                            # this will still probably die 
                            if iii < len(cur_idxs1) and [iii] not in ["bx", "by", "tx", "ty", ""]:
                                if cur_idxs[ii] == cur_idxs1[iii]:
                                    print "\nfound idx:%s" % cur_idxs[ii]
                                    common_loops.append(cur_idxs[ii])
                                    print "cl[%d] = '%s'" % ( comm_loops_cnt, cur_idxs[ii] )
                                    comm_loops_cnt = len(common_loops)

    if len(common_loops) > 0:
        print "\n COMM LOOPS :TOTAL %d, and are " % comm_loops_cnt,
        print common_loops, 
        print " this loop : %s" % common_loops[0]
    else:
        print "UNROLL can't unroll any loops?"


    while True:  # break at bottom of loop   (repeat in lua)
        old_num_statements = chill.num_statements()
        print "old_num_statements %d" % old_num_statements

        for stmt in range(old_num_statements):
            cur_idxs = chill.cur_indices(stmt)
            print "stmt %d    cur_idxs =" % stmt,
            index = 0
            for i in cur_idxs:
                index +=1
                if index == len(cur_idxs):
                    print "%s" %i
                else:
                    print "%s," % i,

            if len(cur_idxs) > 0:
                guard_level = -1
                if chk_cur_level(stmt, guard_idx) > 0:
                    guard_level = find_cur_level(stmt,guard_idx)
                print "guard_level(sp) = %d" % guard_level
                if guard_level > -1:
                    level = next_clean_level(cur_idxs,guard_level)
                    print "next clean level %d" % level

                    
                    #print "looking at %d" % stmt
                    #print "comparing %d and %d in" % (guard_level, level),
                    #index = 0
                    #for i in cur_idxs:
                    #index +=1
                    #if index == len(cur_idxs):
                    #    print "%s" %i
                    #else:
                    #    print "%s," % i,

                    # need to handle max_depth
                    num_unrolled = 0
                    level_unroll_comm = level
                    level_arr = []

                    #print "before while, level = %d" % level 
                    while level >= 0:
                        print "while: level = %d" % level 
                        if num_unrolled == max_depth:
                            break

                        print "Unrolling %d at level %d index %s" % ( stmt, level, cur_idxs[guard_level])  # ??? 
                        level_arr.append(level)

                        guard_level = find_cur_level(stmt,guard_idx)
                        level = next_clean_level(cur_idxs,level+1)

                    print "OK, NOW WE UNROLL"
                    if level_unroll_comm >= 0:
                        level_arr.reverse()  
                        for i,lev in enumerate(level_arr):
                            print "\ni=%d" % i
                            print "[Unroll]unroll(%d, %d, 0)" % (stmt, lev)
                            chill.unroll(stmt, lev, 0)


        new_num_statements = chill.num_statements()
        if old_num_statements == new_num_statements:
            break  # exit infinite loop
# Example #4
# 0
import chill

# CHiLL recipe: read the C source, select procedure foo and its loop 0,
# apply the transformations below, and write the modified source.
_SRC_FILE = 'src/nestedloops.c'
_DST_FILE = 'src/nestedloops_modified.c'
_PROCEDURE = 'foo'
_LOOP_ID = 0

chill.source(_SRC_FILE)
chill.destination(_DST_FILE)
chill.procedure(_PROCEDURE)
chill.loop(_LOOP_ID)

# New loop order for the nest (see the CHiLL docs for chill.permute).
chill.permute([2, 3, 1])
# Arguments in (statement, level, unroll_amount) order.
chill.unroll(0, 3, 2)
# Example #5
# 0
def unroll(statement, level, unroll_amount):
    """Forward one unroll request to ``chill.unroll``.

    Args:
        statement: statement number, passed through unchanged.
        level: loop level, passed through unchanged.
        unroll_amount: unroll amount, passed through unchanged.

    Returns None; whatever ``chill.unroll`` returns is discarded.
    """
    do_unroll = chill.unroll
    do_unroll(statement, level, unroll_amount)