示例#1
0
    def convertToASTs(self, tile_level, loop_info_table, node=None):
        '''
        Generate the ASTs that correspond to this sequence of "simple" loops.

        Parameters:
          tile_level      -- forwarded unchanged to the recursive calls
          loop_info_table -- maps an iterator name to the triple
                             (tile-size name, tile-iterator name, stride expression)
          node            -- the structure to convert; self.loops when omitted

        Conversion rules:
          tuple (iname, subnodes) -- a loop built by ast_util.createForLoop with
                                     lower bound  ti,
                                     upper bound  ti + (Ti - (st)),
                                     stride       st,
                                     and the converted subnodes as a compound body
          list                    -- a list with every element converted
          anything else           -- returned unchanged
        '''

        # identity test on purpose: "==" would be dispatched to the node's own __eq__
        if node is None:
            node = self.loops

        if isinstance(node, tuple):
            iname, subnodes = node
            tsize_name, titer_name, st_exp = loop_info_table[iname]
            index_id = ast.IdentExp(iname)
            lb = ast.IdentExp(titer_name)
            tmp = ast.BinOpExp(ast.IdentExp(tsize_name),
                               ast.ParenthExp(st_exp), ast.BinOpExp.SUB)
            ub = ast.BinOpExp(lb, ast.ParenthExp(tmp), ast.BinOpExp.ADD)
            lbody = ast.CompStmt(
                self.convertToASTs(tile_level, loop_info_table, subnodes))
            return self.ast_util.createForLoop(index_id, lb, ub, st_exp, lbody)

        if isinstance(node, list):
            return [
                self.convertToASTs(tile_level, loop_info_table, n)
                for n in node
            ]

        # a leaf (already an AST or any other object) is passed through as is
        return node
示例#2
0
    def __getIntraTileLoop(self, iter_name, tsize_name, lb_exp, st_exp, lbody):
        '''
        Build the loop that walks the iterations inside a single tile:
          for (i=lb; i<=lb+(Ti-st); i+=st)
             lbody
        where i is iter_name, Ti is tsize_name, lb is lb_exp and st is st_exp.
        The loop itself is created by ast_util.createForLoop.
        '''

        loop_index = ast.IdentExp(iter_name)

        # Ti - (st)
        tile_span = ast.BinOpExp(ast.IdentExp(tsize_name),
                                 ast.ParenthExp(st_exp), ast.BinOpExp.SUB)

        # lb + (Ti - (st))
        upper_bound = ast.BinOpExp(lb_exp, ast.ParenthExp(tile_span),
                                   ast.BinOpExp.ADD)

        return self.ast_util.createForLoop(loop_index, lb_exp, upper_bound,
                                           st_exp, lbody)
示例#3
0
    def __getInterTileLoop(self, iter_name, tsize_name, lb_exp, ub_exp, st_exp,
                           lbody):
        """
        Build the loop that steps from one tile to the next:
          for (i=lb; i<=ub-(Ti-st); i+=Ti)
             lbody
        where i is iter_name, Ti is tsize_name, lb/ub are lb_exp/ub_exp and
        st is st_exp. The loop itself is created by ast_util.createForLoop.
        """

        loop_index = ast.IdentExp(iter_name)

        # Ti - (st)
        tile_span = ast.BinOpExp(ast.IdentExp(tsize_name),
                                 ast.ParenthExp(st_exp), ast.BinOpExp.SUB)

        # ub - (Ti - (st))
        upper_bound = ast.BinOpExp(ub_exp, ast.ParenthExp(tile_span),
                                   ast.BinOpExp.SUB)

        # the stride of an inter-tile loop is the tile size itself
        tile_stride = ast.IdentExp(tsize_name)

        return self.ast_util.createForLoop(loop_index, lb_exp, upper_bound,
                                           tile_stride, lbody)
示例#4
0
    def getForLoopInfo(self, stmt):
        '''
        Given a for-loop statement, extract information about its loop structure.

        The for-loop must be in the following form:
          for (<id> = <exp>; <id> <= <exp>; <id> = <id> + <exp>)
            <stmt>
        with these accepted variations:
          - the test may use "<" instead of "<="
          - the iteration may subtract ("<id> = <id> - <exp>") or be a pre/post
            increment/decrement of <id>
          - any of the three control expressions may be empty, but not all of them
        A compound statement wrapping exactly one statement is unwrapped first.

        Side effect: parentheses around the control expressions and around their
        operands are removed from the given statement in place.

        Returns the tuple (index_id, lbound_exp, ubound_exp, stride_exp, loop_body):
          index_id   -- a new IdentExp carrying the iterator name
          lbound_exp -- replica of the init right-hand side (None if no init)
          ubound_exp -- replica of the inclusive upper bound; "<id> < e" yields "e - 1"
                        (None if no test)
          stride_exp -- replica of the stride; negated for a subtraction, 1 / -1 for
                        increment / decrement (None if no iter)
          loop_body  -- replica of the loop body

        Every violation of the expected form is reported through err().
        '''

        # get rid of compound statement that contains only a single statement
        while isinstance(stmt, ast.CompStmt) and len(stmt.stmts) == 1:
            stmt = stmt.stmts[0]

        # check if it is a for-loop statement
        if not isinstance(stmt, ast.ForStmt):
            err('orio.module.ortil.ast_util: OrTil:%s: not a for-loop statement'
                % stmt.line_no)

        # check initialization expression
        # (each "while True" below runs at most once: it either breaks on success
        # or falls through to the err() call)
        if stmt.init:
            while True:
                while isinstance(stmt.init, ast.ParenthExp):
                    stmt.init = stmt.init.exp
                if (isinstance(stmt.init, ast.BinOpExp)
                        and stmt.init.op_type == ast.BinOpExp.EQ_ASGN):
                    while isinstance(stmt.init.lhs, ast.ParenthExp):
                        stmt.init.lhs = stmt.init.lhs.exp
                    while isinstance(stmt.init.rhs, ast.ParenthExp):
                        stmt.init.rhs = stmt.init.rhs.exp
                    if isinstance(stmt.init.lhs, ast.IdentExp):
                        break
                err('orio.module.ortil.ast_util:%s: loop initialization expression not in "<id> = <exp>" form'
                    % stmt.init.line_no)

        # check test expression
        if stmt.test:
            while True:
                while isinstance(stmt.test, ast.ParenthExp):
                    stmt.test = stmt.test.exp
                if (isinstance(stmt.test, ast.BinOpExp) and stmt.test.op_type
                        in (ast.BinOpExp.LT, ast.BinOpExp.LE)):
                    while isinstance(stmt.test.lhs, ast.ParenthExp):
                        stmt.test.lhs = stmt.test.lhs.exp
                    while isinstance(stmt.test.rhs, ast.ParenthExp):
                        stmt.test.rhs = stmt.test.rhs.exp
                    if isinstance(stmt.test.lhs, ast.IdentExp):
                        break
                # the concatenation must be parenthesized: "%" binds tighter than "+",
                # and the second literal alone has no placeholder for the line number
                err((
                    'orio.module.ortil.ast_util:%s: loop test expression not in "<id> <= <exp>" or '
                    + '"<id> < <exp>"form') % stmt.test.line_no)

        # check iteration expression
        if stmt.iter:
            while True:
                while isinstance(stmt.iter, ast.ParenthExp):
                    stmt.iter = stmt.iter.exp
                if (isinstance(stmt.iter, ast.BinOpExp)
                        and stmt.iter.op_type == ast.BinOpExp.EQ_ASGN):
                    while isinstance(stmt.iter.lhs, ast.ParenthExp):
                        stmt.iter.lhs = stmt.iter.lhs.exp
                    while isinstance(stmt.iter.rhs, ast.ParenthExp):
                        stmt.iter.rhs = stmt.iter.rhs.exp
                    if isinstance(stmt.iter.lhs, ast.IdentExp):
                        if (isinstance(stmt.iter.rhs, ast.BinOpExp)
                                and stmt.iter.rhs.op_type
                                in (ast.BinOpExp.ADD, ast.BinOpExp.SUB)):
                            while isinstance(stmt.iter.rhs.lhs,
                                             ast.ParenthExp):
                                stmt.iter.rhs.lhs = stmt.iter.rhs.lhs.exp
                            while isinstance(stmt.iter.rhs.rhs,
                                             ast.ParenthExp):
                                stmt.iter.rhs.rhs = stmt.iter.rhs.rhs.exp
                            # the assigned identifier must also be the left operand
                            if (isinstance(stmt.iter.rhs.lhs, ast.IdentExp)
                                    and stmt.iter.lhs.name
                                    == stmt.iter.rhs.lhs.name):
                                break
                elif (isinstance(stmt.iter, ast.UnaryExp) and stmt.iter.op_type
                      in (ast.UnaryExp.POST_INC, ast.UnaryExp.PRE_INC,
                          ast.UnaryExp.POST_DEC, ast.UnaryExp.PRE_DEC)):
                    while isinstance(stmt.iter.exp, ast.ParenthExp):
                        stmt.iter.exp = stmt.iter.exp.exp
                    if isinstance(stmt.iter.exp, ast.IdentExp):
                        break
                err((
                    'orio.module.ortil.ast_util:%s: loop iteration expression not in "<id>++" or "<id>--" or '
                    + '"<id> += <exp>" or "<id> = <id> + <exp>" form') %
                    stmt.iter.line_no)

        # check if the control expressions are all empty
        if not stmt.init and not stmt.test and not stmt.iter:
            err('orio.module.ortil.ast_util:%s: a loop with an empty control expression cannot be handled'
                % stmt.line_no)

        # check if the iterator names are all the same
        init_iname = None
        test_iname = None
        iter_iname = None
        if stmt.init:
            init_iname = stmt.init.lhs.name
        if stmt.test:
            test_iname = stmt.test.lhs.name
        if stmt.iter:
            if isinstance(stmt.iter, ast.BinOpExp):
                iter_iname = stmt.iter.lhs.name
            else:
                assert (isinstance(
                    stmt.iter,
                    ast.UnaryExp)), 'internal error:OrTil: not unary'
                iter_iname = stmt.iter.exp.name
        inames = []
        if init_iname:
            inames.append(init_iname)
        if test_iname:
            inames.append(test_iname)
        if iter_iname:
            inames.append(iter_iname)
        # all collected names are equal iff the first one accounts for every entry
        if inames.count(inames[0]) != len(inames):
            err('orio.module.ortil.ast_util:%s: iterator names across init, test, and iter exps must be the same'
                % stmt.line_no)

        # extract for-loop structure information
        index_id = ast.IdentExp(inames[0])
        lbound_exp = None
        ubound_exp = None
        stride_exp = None
        if stmt.init:
            lbound_exp = stmt.init.rhs.replicate()
        if stmt.test:
            if stmt.test.op_type == ast.BinOpExp.LT:
                # normalize the exclusive bound "i < e" to the inclusive "i <= e - 1"
                ubound_exp = ast.BinOpExp(stmt.test.rhs.replicate(),
                                          ast.NumLitExp(1, ast.NumLitExp.INT),
                                          ast.BinOpExp.SUB)
            else:
                ubound_exp = stmt.test.rhs.replicate()
        if stmt.iter:
            if isinstance(stmt.iter, ast.BinOpExp):
                stride_exp = stmt.iter.rhs.rhs.replicate()
                # keep a compound stride grouped when it is embedded in a larger expression
                if isinstance(stride_exp, ast.BinOpExp):
                    stride_exp = ast.ParenthExp(stride_exp)
                if stmt.iter.rhs.op_type == ast.BinOpExp.SUB:
                    stride_exp = ast.UnaryExp(stride_exp, ast.UnaryExp.MINUS)
            elif isinstance(stmt.iter, ast.UnaryExp):
                if stmt.iter.op_type in (ast.UnaryExp.POST_INC,
                                         ast.UnaryExp.PRE_INC):
                    stride_exp = ast.NumLitExp(1, ast.NumLitExp.INT)
                elif stmt.iter.op_type in (ast.UnaryExp.POST_DEC,
                                           ast.UnaryExp.PRE_DEC):
                    stride_exp = ast.NumLitExp(-1, ast.NumLitExp.INT)
                else:
                    err('orio.module.ortil.ast_util internal error:OrTil: unexpected unary operation type'
                        )
            else:
                err('orio.module.ortil.ast_utilinternal error:OrTil: unexpected type of iteration expression'
                    )

        loop_body = stmt.stmt.replicate()
        for_loop_info = (index_id, lbound_exp, ubound_exp, stride_exp,
                         loop_body)

        # return the for-loop structure information
        return for_loop_info
示例#5
0
def p_primary_expression_5(p):
    '''primary_expression : LPAREN expression RPAREN'''
    # NOTE: the docstring above is the grammar production and must stay verbatim
    inner_exp = p[2]
    # line of the opening parenthesis, shifted by the module-level start-line offset
    line_no = p.lineno(1) + __start_line_no - 1
    p[0] = ast.ParenthExp(inner_exp, line_no)
示例#6
0
文件: code_parser.py 项目: phrb/Orio
def p_primary_expression_5(p):
    """primary_expression : LPAREN expression RPAREN"""
    # NOTE: the docstring above is the grammar production and must stay verbatim
    inner_exp = p[2]  # the expression between the two parentheses
    p[0] = ast.ParenthExp(inner_exp)
示例#7
0
    def __tile(self, stmt, tile_level, outer_loop_infos, preceding_stmts,
               lbound_info, int_vars):
        '''
        Apply tiling on the given statement.

        Parameters:
          stmt             -- the statement to process: an expression statement, an
                              if-statement or a for-loop (a compound statement is rejected)
          tile_level       -- the tiling level; must be a positive integer
          outer_loop_infos -- list of (iterator name, loop info) pairs of the enclosing loops
          preceding_stmts  -- list of (is_tiled, [statements]) pairs produced so far;
                              it is copied, never modified
          lbound_info      -- None, or the tuple
                              (lb_name, ub_name, need_prolog, need_epilog, need_tiled_loop)
                              describing the bounds of the rectangular part of this loop
          int_vars         -- list that is extended in place with the names of newly
                              introduced integer variables

        Returns a new list of (is_tiled, [statements]) pairs: the preceding statements
        followed by the result of processing the given statement. Consecutive untiled
        statements are merged into a single (False, [...]) entry.
        '''

        # complain if the tiling level is not a positive integer
        if not isinstance(tile_level, int) or tile_level <= 0:
            err('orio.module.ortil.transformation internal error: invalid tiling level: %s'
                % tile_level)

        # cannot handle a directly nested compound statement
        if isinstance(stmt, ast.CompStmt):
            err('orio.module.ortil.transformation internal error: unexpected compound statement directly nested inside '
                + 'another compound statement')

        # to handle the given expression statement or if statement
        # (it stays untiled: append it to the trailing untiled group, or open a new group)
        if isinstance(stmt, ast.ExpStmt) or isinstance(stmt, ast.IfStmt):
            preceding_stmts = preceding_stmts[:]
            if preceding_stmts:
                is_tiled, last_stmts = preceding_stmts.pop()
                if is_tiled:
                    preceding_stmts.append((is_tiled, last_stmts))
                    preceding_stmts.append((False, [stmt]))
                else:
                    preceding_stmts.append((False, last_stmts + [stmt]))
            else:
                preceding_stmts.append((False, [stmt]))
            return preceding_stmts

        # to tile the given for-loop statement
        if isinstance(stmt, ast.ForStmt):

            # check if this loop is already tiled and whether it's fully tiled
            this_fully_tiled = stmt.fully_tiled

            # extract loop structure information
            this_linfo = self.ast_util.getForLoopInfo(stmt)
            this_id, this_lb_exp, this_ub_exp, this_st_exp, this_lbody = this_linfo
            this_iname = this_id.name

            # get information about the (extended) tiled outer loops
            # (the n_* variants additionally include this loop)
            outer_loop_inames = [i for i, _ in outer_loop_infos]
            loop_info_table = dict(outer_loop_infos)
            n_outer_loop_infos = outer_loop_infos + [(this_iname, this_linfo)]
            n_outer_loop_inames = [i for i, _ in n_outer_loop_infos]
            n_loop_info_table = dict(n_outer_loop_infos)

            # prepare loop bounds information (for iterating full rectangular tiles)
            need_prolog = False
            need_epilog = False
            rect_lb_exp = this_lb_exp
            rect_ub_exp = this_ub_exp
            if lbound_info:
                # the rectangular part runs between the precomputed bound variables
                lb_name, ub_name, need_prolog, need_epilog, need_tiled_loop = lbound_info
                rect_lb_exp = ast.IdentExp(lb_name)
                rect_ub_exp = ast.IdentExp(ub_name)
                if not need_tiled_loop:
                    err('orio.module.ortil.transformation internal error: unexpected case where generation of the orio.main.'
                        + 'rectangular tiled loop is needed')

            # get explicit loop-bound scanning code
            t = self.__getLoopBoundScanningStmts(this_lbody.stmts, tile_level,
                                                 n_outer_loop_inames,
                                                 n_loop_info_table)
            scan_stmts, lbound_info_seq, ivars = t

            # update the newly declared integer variables
            int_vars.extend(ivars)

            # initialize the resulting statements
            res_stmts = preceding_stmts[:]

            # generate the prolog code
            # (an untiled copy of the loop running from the original lower bound up to
            # one stride before the rectangular lower bound)
            if need_prolog:
                ub = ast.BinOpExp(rect_lb_exp, ast.ParenthExp(this_st_exp),
                                  ast.BinOpExp.SUB)
                prolog_code = self.ast_util.createForLoop(
                    this_id, this_lb_exp, ub, this_st_exp, this_lbody)
                if res_stmts:
                    is_tiled, last_stmts = res_stmts.pop()
                    if is_tiled:
                        res_stmts.append((is_tiled, last_stmts))
                        res_stmts.append((False, [prolog_code]))
                    else:
                        res_stmts.append((False, last_stmts + [prolog_code]))
                else:
                    res_stmts.append((False, [prolog_code]))

            # start generating the orio.main.rectangularly tiled code
            # (note: the body of the tiled code may contain if-statement branches,
            # each needed to be recursively transformed)
            # example of the resulting processed statements:
            #   s1; if (t-exp) {s2; s3;} else s4;
            # is represented as the following list:
            #   [s1, t-exp, [s2, s3], [s4]]
            contain_loop = False
            processed_stmts = []
            if_branches = [
                processed_stmts
            ]  # a container for storing list of if-branch statements
            for s, binfo in zip(this_lbody.stmts, lbound_info_seq):

                # check if one of the enclosed statements is a loop
                if isinstance(s, ast.ForStmt):
                    contain_loop = True

                # perform transformation on each if-branch statements
                n_if_branches = []
                for p_stmts in if_branches:

                    # replicate the statement (just to be safe)
                    s = s.replicate()

                    # this is NOT a loop statement with bound expressions that are functions of
                    # outer loop iterators
                    if binfo == None:
                        n_p_stmts = self.__tile(s, tile_level,
                                                n_outer_loop_infos, p_stmts,
                                                binfo, int_vars)
                        # replace the content of p_stmts in place: the same list object
                        # is referenced from the enclosing processed-statement structure
                        while len(p_stmts) > 0:
                            p_stmts.pop()
                        p_stmts.extend(n_p_stmts)
                        n_if_branches.append(p_stmts)
                        continue

                    # (optimization) special handling for one-time loop --> remove the if's true
                    # condition (i.e., lb<ub) since it will never be executed.
                    _, _, _, _, need_tiled_loop = binfo
                    if not need_tiled_loop:
                        if p_stmts:
                            is_tiled, last_stmts = p_stmts.pop()
                            if is_tiled:
                                p_stmts.append((is_tiled, last_stmts))
                                p_stmts.append((False, [s]))
                            else:
                                p_stmts.append((False, last_stmts + [s]))
                        else:
                            p_stmts.append((False, [s]))
                        n_if_branches.append(p_stmts)
                        continue

                    # (optimization) recursively feed in the last processed statement only, and
                    # leave the other preceeding statements untouched --> for reducing code size
                    if len(p_stmts) > 0:
                        p = p_stmts.pop()
                        last_p_stmts = [p]
                    else:
                        last_p_stmts = []

                    # perform a recursion to further tile this rectangularly tiled loop
                    n_p_stmts = self.__tile(s.replicate(), tile_level,
                                            n_outer_loop_infos, last_p_stmts,
                                            binfo, int_vars)

                    # compute the processed statements for both true and false conditions
                    # (true branch: the tiled version; false branch: the loop kept untiled)
                    true_p_stmts = n_p_stmts
                    false_p_stmts = last_p_stmts
                    if false_p_stmts:
                        is_tiled, last_stmts = false_p_stmts.pop()
                        if is_tiled:
                            false_p_stmts.append((is_tiled, last_stmts))
                            false_p_stmts.append((False, [s]))
                        else:
                            false_p_stmts.append((False, last_stmts + [s]))
                    else:
                        false_p_stmts.append((False, [s]))

                    # create two sets of if-branch statements
                    # (encoded in p_stmts as: test expression, true list, false list)
                    lbn, ubn, _, _, _ = binfo
                    test_exp = ast.BinOpExp(ast.IdentExp(lbn),
                                            ast.IdentExp(ubn), ast.BinOpExp.LT)
                    p_stmts.append(test_exp)
                    p_stmts.append(true_p_stmts)
                    p_stmts.append(false_p_stmts)
                    n_if_branches.append(true_p_stmts)
                    n_if_branches.append(false_p_stmts)

                # update the if-branch statements
                if_branches = n_if_branches

            # combine the loop-bound scanning statements
            lbody_stmts = []
            lbody_stmts.extend(scan_stmts)

            # convert the processed statements into AST
            lbody_stmts.extend(
                self.__convertToASTs(processed_stmts, tile_level, contain_loop,
                                     n_outer_loop_inames, n_loop_info_table,
                                     int_vars))

            # generate the orio.main.rectangularly tiled code
            lbody = ast.CompStmt(lbody_stmts)
            iname = self.__getTileIterName(this_iname, tile_level)
            tname = self.__getTileSizeName(this_iname, tile_level)
            tiled_code = self.__getInterTileLoop(iname, tname, rect_lb_exp,
                                                 rect_ub_exp, this_st_exp,
                                                 lbody)
            res_stmts.append((True, [tiled_code]))

            # mark the loop if it's a loop iterating the full rectangular tiles
            self.__labelFullCoreTiledLoop(tiled_code, n_outer_loop_inames)

            # generate the cleanup code (the epilog is already fused)
            # (an untiled copy of the loop that starts at the tile iterator and runs
            # to the original upper bound)
            if not this_fully_tiled:
                lb = ast.IdentExp(
                    self.__getTileIterName(this_iname, tile_level))
                cleanup_code = self.ast_util.createForLoop(
                    this_id, lb, this_ub_exp, this_st_exp, this_lbody)
                res_stmts.append((False, [cleanup_code]))

            # return the resulting statements
            return res_stmts

        # unknown statement
        err('orio.module.ortil.transformation internal error: unknown type of statement: %s'
            % stmt.__class__.__name__)
示例#8
0
    def __staticLoopBoundScanning(self, stmts, tile_level, outer_loop_inames,
                                  loop_info_table):
        '''
        Assuming that the loop-bound expressions are affine functions of outer loop iterators and
        global parameters, we can determine the loop bounds of full tiles in compile time.
        This is an optimization strategy to produce more efficient code.
        Assumptions:
          1. Lower bound expression must be in the form of: max(e_1,e_2,e_3,...,e_n)
          2. Upper bound expression must be in the form of: min(e_1,e_2,e_3,...,e_n)
          where e_i is an affine function of outer loop iterators and global parameters
        Note that max(x,y,z) is implemented as nested binary max functions: max(x,max(y,z)). The same
        condition applies for min function.
        When n=1, max/min function is not needed.

        Parameters:
          stmts             -- the statements of the loop body to scan
          tile_level        -- the tiling level used to derive tile iterator/size names
          outer_loop_inames -- iterator names of the enclosing loops
          loop_info_table   -- maps an iterator name to its 5-element loop information
                               (the stride expression is the fourth element)

        Returns the tuple (scan_stmts, lbound_info_seq, int_vars):
          scan_stmts      -- expression statements that assign the new bound variables
          lbound_info_seq -- one entry per statement in stmts: None for a non-loop or a loop
                             whose bounds do not mention an outer iterator, otherwise
                             (lb_name, ub_name, need_prolog, need_epilog, need_tiled_loop)
          int_vars        -- names of the newly introduced bound variables
        '''

        # initialize all returned variables
        scan_stmts = []
        lbound_info_seq = []
        int_vars = []

        # generate the lower and upper values of each inter-tile loop
        # (lower: the tile iterator; upper: tile iterator + (tile size - (stride)))
        val_table = {}
        for iname in outer_loop_inames:
            _, _, _, st_exp, _ = loop_info_table[iname]
            lval = ast.IdentExp(self.__getTileIterName(iname, tile_level))
            t = ast.BinOpExp(
                ast.IdentExp(self.__getTileSizeName(iname, tile_level)),
                ast.ParenthExp(st_exp.replicate()), ast.BinOpExp.SUB)
            uval = ast.BinOpExp(lval.replicate(), ast.ParenthExp(t),
                                ast.BinOpExp.ADD)
            val_table[iname] = (lval, uval)

        # iterate over each statement to determine loop bounds that are affine functions
        # of outer loop iterators
        lb_exps_table = {}
        ub_exps_table = {}
        for stmt in stmts:

            # skip all non loop statements
            if not isinstance(stmt, ast.ForStmt):
                lbound_info_seq.append(None)
                continue

            # extract this loop structure (only the bound expressions are needed here)
            _, lb_exp, ub_exp, _, _ = self.ast_util.getForLoopInfo(stmt)

            # see if the loop bound expressions are bound/free of outer loop iterators
            # (real lists are required: they are tested for emptiness, measured with
            # len() and searched more than once, which a lazy filter object cannot do)
            lb_inames = [
                i for i in outer_loop_inames
                if self.ast_util.containIdentName(lb_exp, i)
            ]
            ub_inames = [
                i for i in outer_loop_inames
                if self.ast_util.containIdentName(ub_exp, i)
            ]

            # skip loops with bound expressions that are free of outer loop iterators
            if not lb_inames and not ub_inames:
                lbound_info_seq.append(None)
                continue

            # check if this loop runs only once
            # (the two bounds are compared through their textual form)
            is_one_time_loop = str(lb_exp) == str(ub_exp)

            # generate booleans to indicate the needs of prolog, epilog, and orio.main.tiled loop
            if is_one_time_loop:
                need_tiled_loop = False
                need_prolog = False
                need_epilog = False
            else:
                need_tiled_loop = True
                need_prolog = len(lb_inames) > 0
                need_epilog = len(ub_inames) > 0

            # generate new variable names for both the new lower and upper loop bounds
            if need_tiled_loop:
                lb_name, ub_name = self.__getLoopBoundNames()
                int_vars.extend([lb_name, ub_name])
            else:
                lb_name = ''
                ub_name = ''

            # append information about the new loop bounds
            lbinfo = (lb_name, ub_name, need_prolog, need_epilog,
                      need_tiled_loop)
            lbound_info_seq.append(lbinfo)

            # skip generating loop-bound scanning code (if it's a one-time loop)
            if not need_tiled_loop:
                continue

            # determine the value of the new lower loop bound
            # (an expression seen before reuses the variable that already holds its value)
            if str(lb_exp) in lb_exps_table:
                lb_var = lb_exps_table[str(lb_exp)]
                a = ast.BinOpExp(ast.IdentExp(lb_name), lb_var.replicate(),
                                 ast.BinOpExp.EQ_ASGN)
            else:
                if need_prolog:
                    t = self.__findMinMaxVal('max', lb_exp.replicate(),
                                             lb_inames, val_table)
                    a = ast.BinOpExp(ast.IdentExp(lb_name), t.replicate(),
                                     ast.BinOpExp.EQ_ASGN)
                else:
                    a = ast.BinOpExp(ast.IdentExp(lb_name), lb_exp.replicate(),
                                     ast.BinOpExp.EQ_ASGN)
                lb_exps_table[str(lb_exp)] = ast.IdentExp(lb_name)
            scan_stmts.append(ast.ExpStmt(a))

            # determine the value of the new upper loop bound
            if str(ub_exp) in ub_exps_table:
                ub_var = ub_exps_table[str(ub_exp)]
                a = ast.BinOpExp(ast.IdentExp(ub_name), ub_var.replicate(),
                                 ast.BinOpExp.EQ_ASGN)
            else:
                if need_epilog:
                    t = self.__findMinMaxVal('min', ub_exp.replicate(),
                                             ub_inames, val_table)
                    a = ast.BinOpExp(ast.IdentExp(ub_name), t.replicate(),
                                     ast.BinOpExp.EQ_ASGN)
                else:
                    a = ast.BinOpExp(ast.IdentExp(ub_name), ub_exp.replicate(),
                                     ast.BinOpExp.EQ_ASGN)
                ub_exps_table[str(ub_exp)] = ast.IdentExp(ub_name)
            scan_stmts.append(ast.ExpStmt(a))

        # return all necessary information
        return (scan_stmts, lbound_info_seq, int_vars)
示例#9
0
    def __findMinMaxVal(self,
                        min_or_max,
                        exp,
                        var_names,
                        val_table,
                        up_sign=1):
        '''
        To statically find the actual min/max value of the given expression, based on the given
        bound variables. The given table records the lowest and highest values of each bound variable.
        The up_sign argument carries the positive/negative sign from the upper level of the AST.

        Parameters:
          min_or_max -- either 'min' or 'max'
          exp        -- the expression to transform; it is modified in place, so callers
                        that need the original must pass a replica
          var_names  -- names of the identifiers to be replaced
          val_table  -- maps an identifier name to its (lowest value, highest value) pair
          up_sign    -- 1 or -1; flipped below a unary minus and for the right operand
                        of a subtraction

        Returns the transformed expression, in which every identifier listed in var_names
        is replaced by a parenthesized replica of its lowest or highest value.
        Unsupported expressions are reported through err().
        '''

        # numerical expression
        if isinstance(exp, ast.NumLitExp):
            return exp

        # string expression
        elif isinstance(exp, ast.StringLitExp):
            err('orio.module.ortil.transformation: OrTil: invalid string expression found in loop bound expression: %s'
                % exp)

        # identifier expression
        elif isinstance(exp, ast.IdentExp):

            # do nothing if the identifier is not in the given list of variables to be replaced
            if exp.name not in var_names:
                return exp

            # replace the identifier with its apropriate value (depending on min/max, and upper sign)
            # (a positive occurrence takes the highest value for 'max' and the lowest for
            # 'min'; a negated occurrence takes the opposite end)
            lval, uval = val_table[exp.name]
            if min_or_max == 'max':
                if up_sign == 1:
                    val = ast.ParenthExp(uval.replicate())
                else:
                    val = ast.ParenthExp(lval.replicate())
            elif min_or_max == 'min':
                if up_sign == 1:
                    val = ast.ParenthExp(lval.replicate())
                else:
                    val = ast.ParenthExp(uval.replicate())
            else:
                err('orio.module.ortil.transformation internal error: unrecognized min/max argument value'
                    )

            # return the obtained min/max value
            return val

        # array reference expression
        elif isinstance(exp, ast.ArrayRefExp):
            err((
                'orio.module.ortil.transformation: invalid array-reference expression found in loop bound '
                + 'expression: %s') % exp)

        # function call expression
        elif isinstance(exp, ast.FunCallExp):

            # check the function name
            if (not isinstance(exp.exp, ast.IdentExp)) or exp.exp.name not in (
                    'min', 'max'):
                err((
                    'orio.module.ortil.transformation: function name found in loop bound expression must be '
                    + 'min/max, obtained: %s') % exp.exp)

            # recursion on each function argument
            # (the new list is built from the existing arguments before it replaces them;
            # clearing exp.args first would leave nothing to iterate over)
            exp.args = [
                self.__findMinMaxVal(min_or_max, a, var_names, val_table,
                                     up_sign) for a in exp.args
            ]

            # return the computed expression
            return exp

        # unary operation expression
        elif isinstance(exp, ast.UnaryExp):

            # check the unary operation
            if exp.op_type not in (ast.UnaryExp.PLUS, ast.UnaryExp.MINUS):
                err((
                    'orio.module.ortil.transformation: unary operation found in loop bound expression must '
                    + 'be +/-, obtained: %s') % exp.exp)

            # update the sign, and do recursion on the inner expression
            if exp.op_type == ast.UnaryExp.MINUS:
                up_sign *= -1
            exp.exp = self.__findMinMaxVal(min_or_max, exp.exp, var_names,
                                           val_table, up_sign)

            # return the computed expression
            return exp

        # binary operation expression
        elif isinstance(exp, ast.BinOpExp):

            # check the binary operation
            if exp.op_type not in (ast.BinOpExp.ADD, ast.BinOpExp.SUB,
                                   ast.BinOpExp.MUL):
                err((
                    'orio.module.ortil.transformation: binary operation found in loop bound expression must '
                    + 'be +/-/*, obtained: %s') % exp)

            # do recursion on both operands
            # (only the right operand of a subtraction sees the flipped sign)
            exp.lhs = self.__findMinMaxVal(min_or_max, exp.lhs, var_names,
                                           val_table, up_sign)
            if exp.op_type == ast.BinOpExp.SUB:
                up_sign *= -1
            exp.rhs = self.__findMinMaxVal(min_or_max, exp.rhs, var_names,
                                           val_table, up_sign)

            # return the computed expression
            return exp

        # parenthesized expression
        elif isinstance(exp, ast.ParenthExp):
            parenth_before = isinstance(exp.exp, ast.ParenthExp)
            exp.exp = self.__findMinMaxVal(min_or_max, exp.exp, var_names,
                                           val_table, up_sign)
            parenth_after = isinstance(exp.exp, ast.ParenthExp)
            # the substitution already wraps its value in parentheses: drop this
            # level instead of producing a doubled pair
            if (not parenth_before) and parenth_after:
                return exp.exp
            return exp

        # unrecognized expression
        else:
            err('orio.module.ortil.transformation internal error: unknown type of expression: %s'
                % exp.__class__.__name__)
示例#10
0
    def __addIdentWithConstant(self, tnode, iname, constant):
        """
        Replace every identifier named iname in the given AST by the parenthesized
        sum "(iname + constant)".

        The tree is rewritten in place and returned; only a matching identifier
        node is substituted by a newly built node. Expression nodes (literals,
        identifiers, array references, function calls, unary/binary/parenthesized
        expressions) and statement nodes (expression, compound, if and for
        statements) are handled; any other node type is reported through err().
        Sub-expressions of an array reference are additionally constant-folded
        when self.constant_folding is set.
        """

        def shift(node):
            # recursion shorthand with the fixed name and constant
            return self.__addIdentWithConstant(node, iname, constant)

        # literals contain no identifier
        if isinstance(tnode, (ast.NumLitExp, ast.StringLitExp)):
            return tnode

        if isinstance(tnode, ast.IdentExp):
            if tnode.name != iname:
                return tnode
            shifted = ast.BinOpExp(
                tnode.replicate(),
                ast.NumLitExp(constant, ast.NumLitExp.INT),
                ast.BinOpExp.ADD,
            )
            return ast.ParenthExp(shifted)

        if isinstance(tnode, ast.ArrayRefExp):
            tnode.exp = shift(tnode.exp)
            tnode.sub_exp = shift(tnode.sub_exp)
            if self.constant_folding:
                tnode.exp = self.__foldConstant(tnode.exp)
                tnode.sub_exp = self.__foldConstant(tnode.sub_exp)
            return tnode

        if isinstance(tnode, ast.FunCallExp):
            tnode.exp = shift(tnode.exp)
            tnode.args = list(map(shift, tnode.args))
            return tnode

        if isinstance(tnode, ast.UnaryExp):
            tnode.exp = shift(tnode.exp)
            return tnode

        if isinstance(tnode, ast.BinOpExp):
            tnode.lhs = shift(tnode.lhs)
            tnode.rhs = shift(tnode.rhs)
            return tnode

        if isinstance(tnode, ast.ParenthExp):
            tnode.exp = shift(tnode.exp)
            return tnode

        if isinstance(tnode, ast.ExpStmt):
            # an expression statement may carry no expression at all
            if tnode.exp:
                tnode.exp = shift(tnode.exp)
            return tnode

        if isinstance(tnode, ast.CompStmt):
            tnode.stmts = list(map(shift, tnode.stmts))
            return tnode

        if isinstance(tnode, ast.IfStmt):
            tnode.test = shift(tnode.test)
            tnode.true_stmt = shift(tnode.true_stmt)
            if tnode.false_stmt:
                tnode.false_stmt = shift(tnode.false_stmt)
            return tnode

        if isinstance(tnode, ast.ForStmt):
            # each control expression is optional; the body is always processed
            for field in ("init", "test", "iter"):
                if getattr(tnode, field):
                    setattr(tnode, field, shift(getattr(tnode, field)))
            tnode.stmt = shift(tnode.stmt)
            return tnode

        err("orio.module.ortildriver.transformation internal error: unknown type of AST: %s"
            % tnode.__class__.__name__)
示例#11
0
文件: code_parser.py 项目: zhjp0/Orio
def p_primary_expression_5(p):
    '''primary_expression : LPAREN expression RPAREN'''
    # NOTE: the docstring above is the grammar production and must stay verbatim
    inner_exp = p[2]  # the expression between the two parentheses
    p[0] = ast.ParenthExp(inner_exp)
示例#12
0
def p_primary_expression_5(p):
    """primary_expression : LPAREN expression RPAREN"""
    # NOTE: the docstring above is the grammar production and must stay verbatim
    inner_exp = p[2]
    # line of the opening parenthesis, translated by getLineNumber
    line_no = getLineNumber(p.lineno(1))
    p[0] = ast.ParenthExp(inner_exp, line_no)
示例#13
0
def p_primary_expression_5(p):
    """primary_expression : LPAREN expression RPAREN"""
    # NOTE: the docstring above is the grammar production and must stay verbatim
    inner_exp = p[2]
    # line of the opening parenthesis, shifted by the start-line offset and
    # passed on as a string
    line_text = str(p.lineno(1) + __start_line_no - 1)
    p[0] = ast.ParenthExp(inner_exp, line_no=line_text)