示例#1
0
    def readTransfArgs(self, perf_params, transf_args):
        '''Process the given transformation arguments.

        Each entry of transf_args is unpacked as (name, rhs, line_no).  The
        RHS is evaluated with eval(str(rhs), perf_params); an evaluation
        failure is reported through g.err.  The only recognized argument
        name is 'pragma_str', whose evaluated value is recorded together with
        its line number as the pair (rhs, line_no).  Any other name is
        reported through g.err as unrecognized.

        NOTE(review): the collected `pragmas` value is neither returned nor
        stored in the code visible here, so this method returns None.  The
        listing may be truncated -- confirm the intended hand-off (the
        (rhs, line_no) pair matches what checkTransfArgs unpacks).
        '''

        # expected argument names
        PRAGMAS = 'pragma_str'

        # value kept when no 'pragma_str' argument is supplied
        pragmas = []

        # iterate over all transformation arguments
        for aname, rhs, line_no in transf_args:

            # evaluate the RHS expression in the namespace of the performance parameters
            try:
                rhs = eval(str(rhs), perf_params)
            except Exception as e:
                g.err(__name__+': at line %s, failed to evaluate the argument expression: %s\n --> %s: %s'
                      % (line_no, rhs, e.__class__.__name__, e))

            # pragma directives
            if aname == PRAGMAS:
                pragmas = (rhs, line_no)

            # unknown argument name
            else:
                g.err(__name__+': %s: unrecognized transformation argument: "%s"' % (line_no, aname))
示例#2
0
    def __init__(self, language='C'):
        '''Create a code generator for the requested target language.

        Accepted names (case-insensitive) are 'c', 'c++' and 'cxx', which
        select CodeGen_C.  Any other name is reported through g.err, and
        self.generator is left as None.
        '''
        self.generator = None
        lang = language.lower()

        # Fortran (CodeGen_F) and CUDA (CodeGen_CUDA) back ends are currently
        # disabled here; only the C family is wired up.
        if lang not in ('c', 'c++', 'cxx'):
            g.err(__name__+': Unknown language specified for code generation: %s' % language)
        else:
            self.generator = CodeGen_C()
示例#3
0
    def getDeviceProps(self):
        '''Query the properties of the available OpenCL platforms and devices.

        Writes OPENCL_DEVICE_QUERY_SKELET to enum_opencl_props.c and compiles
        it. The compile command is self.tinfo.build_cmd when set, otherwise
        'gcc -framework OpenCL'. The executable is then run with its output
        redirected to enum_opencl_props.c.o.props, and that file is parsed.
        The source and executable are deleted after the run; the output file
        is left in place.

        Each non-blank output line is read with ast.literal_eval as a
        sequence. A first element of 'PLATFORM' starts a new platform, and
        'DEVICE' starts a new device on the current platform. Any other
        (name, value) pair is stored on the most recent device, or on the
        platform itself if no device has been seen since that platform
        started.

        Returns a list of platform dicts, each holding a 'devices' list of
        device-property dicts.  Failures are reported through g.err.
        '''
        # write the query code
        qsrc  = "enum_opencl_props.c"
        qexec = qsrc + ".o"
        qout  = qexec + ".props"

        try:
            with open(qsrc, 'w') as f:
                f.write(OPENCL_DEVICE_QUERY_SKELET)
        except (IOError, OSError):
            g.err('%s: cannot open file for writing: %s' % (self.__class__, qsrc))

        # compile the query
        if self.tinfo is not None and self.tinfo.build_cmd is not None:
            cmd = self.tinfo.build_cmd
        else:
            cmd = 'gcc -framework OpenCL'
        cmd += ' -o %s %s' % (qexec, qsrc)

        status = os.system(cmd)
        if status:
            g.err('%s: failed to compile OpenCL device query code: "%s"' % (self.__class__, cmd))

        # execute the query, capturing its output in qout
        runcmd = './%s > ./%s' % (qexec, qout)
        status = os.system(runcmd)
        if status:
            g.err('%s: failed to execute OpenCL device query code: "%s"' % (self.__class__, runcmd))
        os.remove(qsrc)
        os.remove(qexec)

        # read device properties
        platforms = []
        try:
            with open(qout, 'r') as f:
                lines = f.readlines()
        except (IOError, OSError):
            g.err('%s: cannot open query output file for reading: %s' % (self.__class__, qout))
            return platforms

        # Parse separately from opening so a malformed line is not misreported
        # as an unreadable file.
        try:
            mostRecentWasDevice = False
            for line in lines:
                if not line.strip():
                    continue  # tolerate blank lines such as a trailing newline
                eline = ast.literal_eval(line)
                if eline[0] == 'PLATFORM':
                    mostRecentWasDevice = False
                    platforms.append({'devices': []})
                elif eline[0] == 'DEVICE':
                    mostRecentWasDevice = True
                    platforms[-1]['devices'].append({})
                elif mostRecentWasDevice:
                    platforms[-1]['devices'][-1][eline[0]] = eline[1]
                else:
                    platforms[-1][eline[0]] = eline[1]
        except (ValueError, SyntaxError, TypeError, IndexError):
            g.err('%s: cannot parse query output file: %s' % (self.__class__, qout))

        # return queried device props
        return platforms
示例#4
0
    def __transformStmt(self, stmt):
        '''Recursively apply the code transformations contained in stmt.

        Behavior by statement type:
        - Expression statements and comments are returned unchanged.
        - Compound, if and for statements have their nested statements
          transformed in place and are returned.
        - A TransformStmt first has its nested statement transformed. Its
          argument names are then checked for repeats. Finally the submodule
          class named by the statement is loaded through self.dloader, and
          the result of that class's transform() is returned.
        - Unknown statement types are reported through g.err.
        '''

        if isinstance(stmt, ast.ExpStmt):
            return stmt

        elif isinstance(stmt, ast.CompStmt):
            stmt.kids = [self.__transformStmt(s) for s in stmt.kids]
            return stmt

        elif isinstance(stmt, ast.IfStmt):
            stmt.true_stmt = self.__transformStmt(stmt.true_stmt)
            if stmt.false_stmt:
                stmt.false_stmt = self.__transformStmt(stmt.false_stmt)
            return stmt

        elif isinstance(stmt, ast.ForStmt):
            stmt.stmt = self.__transformStmt(stmt.stmt)
            return stmt

        elif isinstance(stmt, ast.Comment):
            # A comment has no nested statement; return it untouched rather
            # than attaching a self-referencing `.stmt` attribute.
            return stmt

        elif isinstance(stmt, ast.TransformStmt):

            # transform the nested statement first, so inner transformations
            # are applied before this one
            stmt.stmt = self.__transformStmt(stmt.stmt)

            # check for repeated transformation argument names
            seen_names = set()
            for aname, _, line_no in stmt.args:
                if aname in seen_names:
                    g.err(__name__ + ': %s: repeated transformation argument: "%s"' % (line_no, aname))
                seen_names.add(aname)

            # dynamically load the transformation submodule class; its module
            # path is TSUBMOD_NAME.<name>.<name> with <name> lower-cased
            class_name = stmt.name
            submod_name = '.'.join([TSUBMOD_NAME, class_name.lower(), class_name.lower()])
            submod_class = self.dloader.loadClass(submod_name, class_name)

            # apply code transformations
            t = submod_class(self.perf_params, stmt.args, stmt.stmt, self.language, self.tinfo)
            transformed_stmt = t.transform()

            return transformed_stmt

        else:
            g.err(__name__+': internal error: unknown statement type: %s' % stmt.__class__.__name__)
示例#5
0
    def checkTransfArgs(self, pragmas):
        '''Check the semantics of the given transformation arguments.

        pragmas is unpacked as an (rhs, line_no) pair:
        - A string rhs is wrapped in a one-element list.
        - A list or tuple whose elements are all strings is returned as is.
        - Any other value is reported through g.err. If g.err returns, the
          value is still returned unchanged.
        '''

        # evaluate the pragma directives
        rhs, line_no = pragmas
        if isinstance(rhs, str):
            return [rhs]

        # all() replaces reduce/map, which is not a builtin on Python 3
        if not isinstance(rhs, (list, tuple)) or not all(isinstance(x, str) for x in rhs):
            g.err(__name__+':%s: pragma directives must be a list/tuple of strings: %s'
                  % (line_no, rhs))

        # return information about the transformation arguments
        return rhs
示例#6
0
    def containIdentName(self, exp, iname):
        '''
        Check whether the given expression contains an identifier named iname.

        Returns True when an IdentExp whose name equals iname is reachable
        from exp. The search goes through array references, function calls
        (callee and arguments), unary and binary operators, and parentheses.
        None, literals, NewAST and Comment nodes yield False.  Any other node
        type is reported through g.err.
        '''

        if exp is None:
            return False

        if isinstance(exp, orio.module.loop.ast.NumLitExp):
            return False

        elif isinstance(exp, orio.module.loop.ast.StringLitExp):
            return False

        elif isinstance(exp, orio.module.loop.ast.IdentExp):
            return exp.name == iname

        elif isinstance(exp, orio.module.loop.ast.ArrayRefExp):
            return self.containIdentName(exp.exp, iname) or self.containIdentName(exp.sub_exp, iname)

        elif isinstance(exp, orio.module.loop.ast.FunCallExp):
            # any() short-circuits on the first matching argument
            return (self.containIdentName(exp.exp, iname) or
                    any(self.containIdentName(a, iname) for a in exp.args))

        elif isinstance(exp, orio.module.loop.ast.UnaryExp):
            return self.containIdentName(exp.exp, iname)

        elif isinstance(exp, orio.module.loop.ast.BinOpExp):
            return self.containIdentName(exp.lhs, iname) or self.containIdentName(exp.rhs, iname)

        elif isinstance(exp, orio.module.loop.ast.ParenthExp):
            return self.containIdentName(exp.exp, iname)

        elif isinstance(exp, orio.module.loop.ast.NewAST):
            return False

        elif isinstance(exp, orio.module.loop.ast.Comment):
            return False

        else:
            g.err('orio.module.loop.ast_lib.common_lib internal error:  unexpected AST type: "%s"' % exp.__class__.__name__)
示例#7
0
    def isComplexExp(self, exp):
        '''
        Report whether exp is a "complex" expression.

        Numeric/string literals, plain identifiers and comments are simple
        (False).  Array references, function calls, binary operations and
        NewAST nodes are complex (True).  Unary and parenthesized expressions
        take the verdict of the expression they wrap.  Any other type,
        including None, is reported through g.err, and None is returned if
        g.err returns.
        '''
        loop_ast = orio.module.loop.ast

        # checks are kept in the same order as the type tests they replace
        if isinstance(exp, (loop_ast.NumLitExp, loop_ast.StringLitExp, loop_ast.IdentExp)):
            return False
        if isinstance(exp, (loop_ast.ArrayRefExp, loop_ast.FunCallExp)):
            return True
        if isinstance(exp, loop_ast.UnaryExp):
            return self.isComplexExp(exp.exp)
        if isinstance(exp, loop_ast.BinOpExp):
            return True
        if isinstance(exp, loop_ast.ParenthExp):
            return self.isComplexExp(exp.exp)
        if isinstance(exp, loop_ast.NewAST):
            return True
        if isinstance(exp, loop_ast.Comment):
            return False

        g.err('orio.module.loop.ast_lib.common_lib internal error:  unexpected AST type: "%s"' % exp.__class__.__name__)
示例#8
0
    def readTransfArgs(self, perf_params, transf_args):
        '''Process the given transformation arguments.

        Each entry of transf_args is unpacked as (name, rhs, line_no), and
        the RHS is evaluated with eval(str(rhs), perf_params).
        Recognized names:
          prefetch          -- the evaluated value is converted with list()
                               and appended to the prefetch list
          prefetch_distance -- must be a non-negative int (0, the default,
                               is accepted)
        Evaluation failures, invalid distances and unrecognized names are
        reported through g.err.

        NOTE(review): neither `prefetches` nor `dist` is returned in the code
        visible here; the listing may be truncated -- confirm the hand-off.
        '''

        # expected argument names
        PREFETCH = 'prefetch'
        DISTANCE = 'prefetch_distance'

        # default argument values
        prefetches = []
        dist = 0

        # iterate over all transformation arguments
        for aname, rhs, line_no in transf_args:

            # evaluate the RHS expression in the namespace of the performance parameters
            try:
                rhs = eval(str(rhs), perf_params)
            except Exception as e:
                g.err(__name__+': at line %s, failed to evaluate the argument expression: %s\n --> %s: %s'
                      % (line_no, rhs, e.__class__.__name__, e))

            if aname == PREFETCH:
                # NOTE(review): list() splits a bare string into characters;
                # callers presumably pass a list/tuple
                prefetches += list(rhs)

            elif aname == DISTANCE:
                # a negative prefetch distance is meaningless; 0 is the default
                if not isinstance(rhs, int) or rhs < 0:
                    g.err(__name__+': %s: %s must be a non-negative integer: %s\n' % (line_no, aname, rhs))
                else:
                    dist = rhs

            # unknown argument name
            else:
                g.err(__name__+': %s: unrecognized transformation argument: "%s"' % (line_no, aname))
示例#9
0
def t_error(t):
    '''Lexer error callback (PLY-style): report the illegal character and its line via g.err.'''
    message = 'orio.main.tspec.pparser.lexer: illegal character (%s) at line %s' % (
        t.value[0], t.lexer.lineno)
    g.err(message)
示例#10
0
 def t_error(self, t):
     '''Lexer error callback: report the illegal character and its line via g.err.'''
     offending_char = t.value[0]
     g.err('%s: illegal character (%s) at line %s' % (self.__class__, offending_char, t.lexer.lineno))
示例#11
0
文件: common_lib.py 项目: phrb/Orio
    def rewriteNode(self, r, n):
        """ Rewrite the given node with the given rewrite function: post-order traversal, in-place update.

        The AST children of n are rewritten first and stored back on n.
        Then r(n) is applied to n itself and its result is returned.

        Only AST children are traversed. The following are passed through
        unchanged:
        - Comment.text, GotoStmt.target and TransformStmt.name, which the
          rest of this code base uses as plain strings.
        - The TransformStmt argument list, whose entries are
          [name, rhs, line_no] triples.
        Empty optional children (ExpStmt.exp, IfStmt.false_stmt,
        ForStmt.init/test/iter) are skipped.  Unknown node types are
        reported through g.err and are not passed to r.
        """
        loop_ast = orio.module.loop.ast

        def rw(child):
            # recurse with the same rewrite function
            return self.rewriteNode(r, child)

        if isinstance(n, (loop_ast.NumLitExp, loop_ast.StringLitExp,
                          loop_ast.IdentExp, loop_ast.VarDecl)):
            pass  # leaves: nothing to traverse

        elif isinstance(n, loop_ast.ArrayRefExp):
            n.exp = rw(n.exp)
            n.sub_exp = rw(n.sub_exp)

        elif isinstance(n, loop_ast.FunCallExp):
            n.exp = rw(n.exp)
            # a list (not map) so the result is materialized on Python 3 too
            n.args = [rw(a) for a in n.args]

        elif isinstance(n, loop_ast.UnaryExp):
            n.exp = rw(n.exp)

        elif isinstance(n, loop_ast.BinOpExp):
            n.lhs = rw(n.lhs)
            n.rhs = rw(n.rhs)

        elif isinstance(n, loop_ast.ParenthExp):
            n.exp = rw(n.exp)

        elif isinstance(n, loop_ast.Comment):
            pass  # text is a plain string, not an AST node

        elif isinstance(n, loop_ast.ExpStmt):
            if n.exp:
                n.exp = rw(n.exp)

        elif isinstance(n, loop_ast.GotoStmt):
            pass  # target is a label string, not an AST node

        elif isinstance(n, loop_ast.CompStmt):
            n.stmts = [rw(s) for s in n.stmts]

        elif isinstance(n, loop_ast.IfStmt):
            n.test = rw(n.test)
            n.true_stmt = rw(n.true_stmt)
            if n.false_stmt:
                n.false_stmt = rw(n.false_stmt)

        elif isinstance(n, loop_ast.ForStmt):
            if n.init:
                n.init = rw(n.init)
            if n.test:
                n.test = rw(n.test)
            if n.iter:
                n.iter = rw(n.iter)
            n.stmt = rw(n.stmt)

        elif isinstance(n, loop_ast.AssignStmt):
            n.var = rw(n.var)
            n.exp = rw(n.exp)

        elif isinstance(n, loop_ast.TransformStmt):
            # name is a string and args is a list of [name, rhs, line_no]
            # triples; only the nested statement is an AST node
            n.stmt = rw(n.stmt)

        else:
            g.err(
                'orio.module.loop.ast_lib.common_lib.rewriteNode: unexpected AST type: "%s"'
                % n.__class__.__name__
            )
            return None

        return r(n)
示例#12
0
文件: pparser.py 项目: axelyamel/Orio
def t_error(t):
    '''Report an illegal input character, and the line it is on, through g.err.'''
    offending, line = t.value[0], t.lexer.lineno
    g.err('orio.main.tspec.pparser.lexer: illegal character (%s) at line %s' % (offending, line))
示例#13
0
    def generate(self, tnode, indent = '  ', extra_indent = '  '):
        '''To generate code that corresponds to the given AST.

        Recursively renders tnode (an expression, statement or declaration
        node of the `ast` module used here) as C-style source text; the error
        messages identify this as the OpenCL code generator.

        Parameters:
          tnode        -- the AST node to render
          indent       -- text placed before each statement at this level
          extra_indent -- text appended to indent for each nested level

        Returns the generated source as a string.  Unknown node types and
        unknown unary/binary operator codes are reported through g.err.
        '''

        s = ''

        if isinstance(tnode, ast.NumLitExp):
            s += str(tnode.val)

        elif isinstance(tnode, ast.StringLitExp):
            # string literals are emitted wrapped in double quotes
            s += '"' + str(tnode.val) + '"'

        elif isinstance(tnode, ast.IdentExp):
            s += str(tnode.name)

        elif isinstance(tnode, ast.ArrayRefExp):
            s += self.generate(tnode.exp, indent, extra_indent)
            s += '[' + self.generate(tnode.sub_exp, indent, extra_indent) + ']'

        elif isinstance(tnode, ast.FunCallExp):
            s += self.generate(tnode.exp, indent, extra_indent) + '('
            s += ','.join([self.generate(x, indent, extra_indent) for x in tnode.args])
            s += ')'

        elif isinstance(tnode, ast.UnaryExp):
            # render the operand, then attach the operator; s is still empty
            # here, so plain assignment loses nothing
            s = self.generate(tnode.exp, indent, extra_indent)
            if tnode.op_type == tnode.PLUS:
                s = '+' + s
            elif tnode.op_type == tnode.MINUS:
                s = '-' + s
            elif tnode.op_type == tnode.LNOT:
                s = '!' + s
            # NOTE(review): the padding space on ++/-- presumably keeps them
            # from fusing with an adjacent '+'/'-' token
            elif tnode.op_type == tnode.PRE_INC:
                s = ' ++' + s
            elif tnode.op_type == tnode.PRE_DEC:
                s = ' --' + s
            elif tnode.op_type == tnode.POST_INC:
                s = s + '++ '
            elif tnode.op_type == tnode.POST_DEC:
                s = s + '-- '
            elif tnode.op_type == tnode.DEREF:
                s = '*' + s
            elif tnode.op_type == tnode.ADDRESSOF:
                s = '&' + s
            else:
                g.err('orio.module.loop.codegen_opencl internal error: unknown unary operator type: %s' % tnode.op_type)

        elif isinstance(tnode, ast.BinOpExp):
            # operands are joined with no surrounding spaces and no added
            # parentheses
            s += self.generate(tnode.lhs, indent, extra_indent)
            if tnode.op_type == tnode.MUL:
                s += '*'
            elif tnode.op_type == tnode.DIV:
                s += '/'
            elif tnode.op_type == tnode.MOD:
                s += '%'
            elif tnode.op_type == tnode.ADD:
                s += '+'
            elif tnode.op_type == tnode.SUB:
                s += '-'
            elif tnode.op_type == tnode.LT:
                s += '<'
            elif tnode.op_type == tnode.GT:
                s += '>'
            elif tnode.op_type == tnode.LE:
                s += '<='
            elif tnode.op_type == tnode.GE:
                s += '>='
            elif tnode.op_type == tnode.EQ:
                s += '=='
            elif tnode.op_type == tnode.NE:
                s += '!='
            elif tnode.op_type == tnode.LOR:
                s += '||'
            elif tnode.op_type == tnode.LAND:
                s += '&&'
            elif tnode.op_type == tnode.COMMA:
                s += ','
            elif tnode.op_type == tnode.EQ_ASGN:
                s += '='
            elif tnode.op_type == tnode.ASGN_ADD:
                s += '+='
            elif tnode.op_type == tnode.ASGN_SHR:
                s += '>>='
            elif tnode.op_type == tnode.ASGN_SHL:
                s += '<<='
            elif tnode.op_type == tnode.BAND:
                s += '&'
            elif tnode.op_type == tnode.SHR:
                s += '>>'
            elif tnode.op_type == tnode.BOR:
                s += '|'
            else:
                g.err('orio.module.loop.codegen_opencl internal error: unknown binary operator type: %s' % tnode.op_type)
            s += self.generate(tnode.rhs, indent, extra_indent)

        elif isinstance(tnode, ast.TernaryExp):
            s += self.generate(tnode.test, indent, extra_indent) + '?'
            s += self.generate(tnode.true_expr,  indent, extra_indent) + ':'
            s += self.generate(tnode.false_expr, indent, extra_indent)

        elif isinstance(tnode, ast.ParenthExp):
            s += '(' + self.generate(tnode.exp, indent, extra_indent) + ')'

        elif isinstance(tnode, ast.Comment):
            # an empty comment still produces an indented blank line
            s += indent
            if tnode.text:
                s += '/*' + tnode.text + '*/'
            s += '\n'

        elif isinstance(tnode, ast.ExpStmt):
            # an optional label precedes the indentation; a missing
            # expression yields the empty statement ';'
            if tnode.getLabel(): s += tnode.getLabel() + ':'
            s += indent
            if tnode.exp:
                s += self.generate(tnode.exp, indent, extra_indent)
            s += ';\n'

        elif isinstance(tnode, ast.GotoStmt):
            # tnode.target is concatenated directly, so it is a label string.
            # NOTE(review): with no target only the label/indent is emitted
            # (no ';' and no newline)
            if tnode.getLabel(): s += tnode.getLabel() + ':'
            s += indent
            if tnode.target:
                s += 'goto ' + tnode.target + ';\n'

        elif isinstance(tnode, ast.CompStmt):
            # braces at the current indent, contained statements one level deeper
            s += indent + '{\n'
            for stmt in tnode.stmts:
                s += self.generate(stmt, indent + extra_indent, extra_indent)
            s += indent + '}\n'

        elif isinstance(tnode, ast.IfStmt):
            if tnode.getLabel(): s += tnode.getLabel() + ':'
            s += indent + 'if (' + self.generate(tnode.test, indent, extra_indent) + ') '
            if isinstance(tnode.true_stmt, ast.CompStmt):
                # splice the block onto the 'if (...) ' line by dropping the
                # indentation in front of its '{'
                tstmt_s = self.generate(tnode.true_stmt, indent, extra_indent)
                s += tstmt_s[tstmt_s.index('{'):]
                if tnode.false_stmt:
                    # drop the newline after '}' so 'else' follows on the same line
                    s = s[:-1] + ' else '
            else:
                # a single statement goes on its own line, one level deeper
                s += '\n'
                s += self.generate(tnode.true_stmt, indent + extra_indent, extra_indent)
                if tnode.false_stmt:
                    s += indent + 'else '
            if tnode.false_stmt:
                if isinstance(tnode.false_stmt, ast.CompStmt):
                    tstmt_s = self.generate(tnode.false_stmt, indent, extra_indent)
                    s += tstmt_s[tstmt_s.index('{'):]
                else:
                    s += '\n'
                    s += self.generate(tnode.false_stmt, indent + extra_indent, extra_indent)

        elif isinstance(tnode, ast.ForStmt):
            if tnode.getLabel(): s += tnode.getLabel() + ':'
            s += indent + 'for ('
            if tnode.init:
                # a declaration in the init clause is rendered inline, without
                # the indent and ';\n' a stand-alone VarDeclInit receives
                if isinstance(tnode.init, ast.VarDeclInit):
                  s += str(tnode.init.type_name) + ' '
                  s += self.generate(tnode.init.var_name, indent, extra_indent)
                  s += '=' + self.generate(tnode.init.init_exp, indent, extra_indent)
                else:
                  s += self.generate(tnode.init, indent, extra_indent)
            s += '; '
            if tnode.test:
                s += self.generate(tnode.test, indent, extra_indent)
            s += '; '
            if tnode.iter:
                s += self.generate(tnode.iter, indent, extra_indent)
            s += ') '
            if isinstance(tnode.stmt, ast.CompStmt):
                # same brace-splicing as for if statements
                stmt_s = self.generate(tnode.stmt, indent, extra_indent)
                s += stmt_s[stmt_s.index('{'):]
            else:
                s += '\n'
                s += self.generate(tnode.stmt, indent + extra_indent, extra_indent)

        elif isinstance(tnode, ast.AssignStmt):
            # tnode.var is concatenated directly, so it must already be a string
            if tnode.getLabel(): s += tnode.getLabel() + ':'
            s += indent + tnode.var + '='
            s += self.generate(tnode.exp, indent, extra_indent)
            s += ';\n'

        elif isinstance(tnode, ast.TransformStmt):
            # transformation statements must already have been replaced by
            # their generated code before reaching this point
            g.err('orio.module.loop.codegen_opencl internal error: a transformation statement is never generated as an output')

        elif isinstance(tnode, ast.VarDecl):
            # var_names holds either IdentExp nodes or plain strings; only the
            # first element is inspected to decide which.
            # NOTE(review): an empty var_names list raises IndexError here
            s += indent + str(tnode.type_name) + ' '
            if isinstance(tnode.var_names[0], ast.IdentExp):
                s += ', '.join(map(self.generate, tnode.var_names))
            else:
                s += ', '.join(tnode.var_names)
            s += ';\n'

        elif isinstance(tnode, ast.VarDeclInit):
            s += indent + str(tnode.type_name) + ' '
            s += self.generate(tnode.var_name, indent, extra_indent)
            s += '=' + self.generate(tnode.init_exp, indent, extra_indent)
            s += ';\n'

        elif isinstance(tnode, ast.FieldDecl):
            # emitted as 'type name' with no indentation and no terminator
            s += tnode.ty + ' '
            if isinstance(tnode.name, ast.IdentExp):
                s+= tnode.name.name
            else:
                s += tnode.name

        elif isinstance(tnode, ast.FunDecl):
            # parameters are rendered with the default indent arguments
            s += indent + ' '.join(tnode.modifiers) + ' '
            s += tnode.return_type + ' '
            s += tnode.name + '('
            s += ', '.join(map(self.generate, tnode.params)) + ') '
            s += self.generate(tnode.body, indent, extra_indent)

        elif isinstance(tnode, ast.Pragma):
            s += indent + '#pragma ' + str(tnode.pstring) + '\n'

        elif isinstance(tnode, ast.Container):
            # wrapper node: render the wrapped AST in place
            s += self.generate(tnode.ast, indent, extra_indent)

        elif isinstance(tnode, ast.WhileStmt):
            # NOTE(review): unlike if/for statements, no label is emitted here
            s += indent + 'while (' + self.generate(tnode.test, indent, extra_indent)
            s += ') '
            if isinstance(tnode.stmt, ast.CompStmt):
                stmt_s = self.generate(tnode.stmt, indent, extra_indent)
                s += stmt_s[stmt_s.index('{'):]
            else:
                s += '\n'
                s += self.generate(tnode.stmt, indent + extra_indent, extra_indent)

        elif isinstance(tnode, ast.CastExpr):
            s += '(' + tnode.ctype + ')'
            s += self.generate(tnode.expr, indent, extra_indent)

        else:
            g.err('orio.module.loop.codegen_opencl internal error: unrecognized type of AST: %s' % tnode.__class__.__name__)

        return s
示例#14
0
文件: common_lib.py 项目: phrb/Orio
    def replaceIdent(self, tnode, iname_from, iname_to):
        """Rename, in place, every identifier called iname_from to iname_to.

        The walk covers expressions (array references, calls, unary/binary
        operators, parentheses) and expression, compound, if and for
        statements.  Each child is reassigned from the recursive result, and
        the (mutated) input node is returned.  Literal, NewAST and Comment
        nodes come back untouched.  A TransformStmt or any unrecognized node
        type is reported through g.err.
        """
        loop_ast = orio.module.loop.ast
        unexpected = 'orio.module.loop.ast_lib.common_lib internal error:  unexpected AST type: "%s"'

        def recur(node):
            # recurse with the same source/target names
            return self.replaceIdent(node, iname_from, iname_to)

        if isinstance(tnode, (loop_ast.NumLitExp, loop_ast.StringLitExp)):
            return tnode

        if isinstance(tnode, loop_ast.IdentExp):
            if tnode.name == iname_from:
                tnode.name = iname_to
            return tnode

        if isinstance(tnode, loop_ast.ArrayRefExp):
            tnode.exp = recur(tnode.exp)
            tnode.sub_exp = recur(tnode.sub_exp)
            return tnode

        if isinstance(tnode, loop_ast.FunCallExp):
            tnode.exp = recur(tnode.exp)
            tnode.args = [recur(arg) for arg in tnode.args]
            return tnode

        if isinstance(tnode, loop_ast.UnaryExp):
            tnode.exp = recur(tnode.exp)
            return tnode

        if isinstance(tnode, loop_ast.BinOpExp):
            tnode.lhs = recur(tnode.lhs)
            tnode.rhs = recur(tnode.rhs)
            return tnode

        if isinstance(tnode, loop_ast.ParenthExp):
            tnode.exp = recur(tnode.exp)
            return tnode

        if isinstance(tnode, loop_ast.ExpStmt):
            if tnode.exp:
                tnode.exp = recur(tnode.exp)
            return tnode

        if isinstance(tnode, loop_ast.CompStmt):
            tnode.stmts = [recur(child) for child in tnode.stmts]
            return tnode

        if isinstance(tnode, loop_ast.IfStmt):
            tnode.test = recur(tnode.test)
            tnode.true_stmt = recur(tnode.true_stmt)
            if tnode.false_stmt:
                tnode.false_stmt = recur(tnode.false_stmt)
            return tnode

        if isinstance(tnode, loop_ast.ForStmt):
            # the three loop-header parts are optional; the body is not
            for field in ('init', 'test', 'iter'):
                part = getattr(tnode, field)
                if part:
                    setattr(tnode, field, recur(part))
            tnode.stmt = recur(tnode.stmt)
            return tnode

        if isinstance(tnode, loop_ast.TransformStmt):
            # transformation statements are explicitly rejected
            g.err(unexpected % tnode.__class__.__name__)
            return None

        if isinstance(tnode, (loop_ast.NewAST, loop_ast.Comment)):
            return tnode

        g.err(unexpected % tnode.__class__.__name__)
示例#15
0
 def t_error(self, t):
     '''Lexer error callback: report the illegal character and its line via g.err.'''
     bad_char, lineno = t.value[0], t.lexer.lineno
     g.err('orio.module.loops.lexer: illegal character (%s) at line %s' % (bad_char, lineno))
示例#16
0
文件: cuda.py 项目: zhjp0/Orio
    def readTransfArgs(self, perf_params, transf_args):
        '''Process the given transformation arguments.

        Each entry of transf_args is unpacked as (name, rhs, line_no), and
        the RHS is evaluated with eval(rhs, perf_params).

        Recognized names are threadCount, blockCount, cacheBlocks,
        pinHostMem, streamCount, domain, dataOnDevice, unrollInner and
        preferL1Size. Their values are validated against the device
        properties in self.props: warpSize, multiProcessorCount,
        maxThreadsPerBlock, maxGridSize, deviceOverlap, concurrentKernels,
        major and minor.

        Validation problems are appended, one line each, to the local
        `errors` string. Evaluation failures and unknown argument names are
        reported immediately through g.err.

        NOTE(review): neither `errors` nor the validated values are returned
        or reported in the code visible here; the listing appears truncated
        -- confirm how the results reach the caller.
        '''

        # expected argument names
        THREADCOUNT = 'threadCount'
        BLOCKCOUNT = 'blockCount'
        CB = 'cacheBlocks'
        PHM = 'pinHostMem'
        STREAMCOUNT = 'streamCount'
        DOMAIN = 'domain'
        DOD = 'dataOnDevice'
        UIF = 'unrollInner'
        PREFERL1SZ = 'preferL1Size'

        # default argument values: one warp per block, one block per multiprocessor
        szwarp = self.props['warpSize']
        smcount = self.props['multiProcessorCount']
        threadCount = szwarp
        blockCount = smcount
        cacheBlocks = False
        pinHost = False
        streamCount = 1
        domain = None
        dataOnDevice = False
        unrollInner = None
        preferL1Size = 0

        # validation messages are accumulated here, each terminated by '\n'
        errors = ''
        for aname, rhs, line_no in transf_args:

            # evaluate the RHS expression
            try:
                rhs = eval(rhs, perf_params)
            except Exception as e:
                g.err(
                    'orio.module.loop.submodule.cuda.cuda: %s: failed to evaluate the argument expression: %s\n --> %s: %s'
                    % (line_no, rhs, e.__class__.__name__, e))

            if aname == THREADCOUNT:
                if (not isinstance(rhs, int) or rhs <= 0
                        or rhs > self.props['maxThreadsPerBlock']):
                    errors += ('line %s: threadCount must be a positive integer less than device limit of maxThreadsPerBlock of %s: %s\n'
                               % (line_no, self.props['maxThreadsPerBlock'], rhs))
                elif rhs % szwarp != 0:
                    errors += ('line %s: threadCount is not a multiple of warp size of %s: %s\n'
                               % (line_no, szwarp, rhs))
                else:
                    threadCount = rhs
            elif aname == BLOCKCOUNT:
                if (not isinstance(rhs, int) or rhs <= 0
                        or rhs > self.props['maxGridSize'][0]):
                    errors += ('line %s: %s must be a positive integer less than device limit of maxGridSize[0]=%s: %s\n'
                               % (line_no, aname, self.props['maxGridSize'][0], rhs))
                elif rhs % smcount != 0:
                    errors += ('line %s: blockCount is not a multiple of SM count of %s: %s\n'
                               % (line_no, smcount, rhs))
                else:
                    blockCount = rhs
            elif aname == CB:
                if not isinstance(rhs, bool):
                    errors += 'line %s: %s must be a boolean: %s\n' % (
                        line_no, aname, rhs)
                else:
                    cacheBlocks = rhs
            elif aname == PHM:
                if not isinstance(rhs, bool):
                    errors += 'line %s: %s must be a boolean: %s\n' % (
                        line_no, aname, rhs)
                else:
                    pinHost = rhs
            elif aname == STREAMCOUNT:
                if not isinstance(rhs, int) or rhs <= 0:
                    errors += 'line %s: %s must be a positive integer: %s\n' % (
                        line_no, aname, rhs)
                else:
                    # multiple streams need transfer/kernel overlap and
                    # concurrent kernels; problems are recorded, value kept
                    if rhs > 1:
                        overlap = self.props['deviceOverlap']
                        if overlap == 0:
                            errors += '%s=%s: deviceOverlap=%s, overlap of data transfer and kernel execution is not supported\n' % (
                                aname, rhs, overlap)
                        concs = self.props['concurrentKernels']
                        if concs == 0:
                            errors += '%s=%s: device concurrentKernels=%s, concurrent kernel execution is not supported\n' % (
                                aname, rhs, concs)
                    streamCount = rhs
            elif aname == DOMAIN:
                if not isinstance(rhs, str):
                    errors += 'line %s: %s must be a string: %s\n' % (
                        line_no, aname, rhs)
                else:
                    domain = rhs
            elif aname == DOD:
                if not isinstance(rhs, bool):
                    errors += 'line %s: %s must be a boolean: %s\n' % (
                        line_no, aname, rhs)
                else:
                    dataOnDevice = rhs
            elif aname == UIF:
                if not isinstance(rhs, int) or rhs <= 0:
                    errors += 'line %s: %s must be a positive integer: %s\n' % (
                        line_no, aname, rhs)
                else:
                    unrollInner = rhs
            elif aname == PREFERL1SZ:
                if not isinstance(rhs, int) or rhs not in [16, 32, 48]:
                    errors += 'line %s: %s must be either 16, 32 or 48 KB: %s\n' % (
                        line_no, aname, rhs)
                else:
                    # L1 resizing depends on compute capability; problems are
                    # recorded, value kept
                    major = self.props['major']
                    if major < 2:
                        errors += '%s=%s: L1 cache is not resizable on compute capability less than 2.x, current comp.cap.=%s.%s\n' % (
                            aname, rhs, major, self.props['minor'])
                    elif major < 3 and rhs == 32:
                        errors += '%s=%s: L1 cache cannot be set to %s on compute capability less than 3.x, current comp.cap.=%s.%s\n' % (
                            aname, rhs, rhs, major, self.props['minor'])
                    preferL1Size = rhs
            else:
                g.err('%s: %s: unrecognized transformation argument: "%s"' %
                      (self.__class__, line_no, aname))
示例#17
0
    def readTransfArgs(self, perf_params, transf_args):
        '''Process the given transformation arguments.

        Each entry of transf_args is unpacked as (name, rhs, line_no), and
        the RHS is evaluated with eval(rhs, perf_params).

        Recognized names are platform, device, workGroups,
        workItemsPerGroup, cacheBlocks, streamCount, unrollInner, clFlags,
        vecHint and sizeHint. Their values are currently stored without
        validation. The CUDA-style names threadCount and blockCount are
        accepted as aliases for workItemsPerGroup and workGroups, with a
        g.warn. Evaluation failures and unknown names are reported through
        g.err.

        NOTE(review): the collected values are not returned in the code
        visible here; the listing may be truncated -- confirm the hand-off.
        '''

        # expected argument names
        PLATFORM    = 'platform'
        DEVICE      = 'device'
        WORKGROUPS  = 'workGroups'
        WORKITEMS   = 'workItemsPerGroup'
        CB          = 'cacheBlocks'
        STREAMCOUNT = 'streamCount'
        UIF         = 'unrollInner'
        CLFLAGS     = 'clFlags'
        THREADCOUNT = 'threadCount'
        BLOCKCOUNT  = 'blockCount'
        VECHINT     = 'vecHint'
        SIZEHINT    = 'sizeHint'

        # default argument values
        platform = 0
        device = 0
        workGroups  = None
        workItemsPerGroup   = None
        cacheBlocks  = False
        streamCount  = 1
        unrollInner  = None
        clFlags      = None
        vecHint      = 0
        sizeHint     = False

        # iterate over all transformation arguments
        errors = ''
        for aname, rhs, line_no in transf_args:
            # evaluate the RHS expression
            try:
                rhs = eval(rhs, perf_params)
            except Exception as e:
                g.err('orio.module.loop.submodule.opencl.opencl: %s: failed to evaluate the argument expression: %s\n --> %s: %s' % (line_no, rhs,e.__class__.__name__, e))

            if aname == PLATFORM:
                # TODO: validate
                platform = rhs
            elif aname == DEVICE:
                # TODO: validate
                device = rhs
            elif aname == WORKGROUPS:
                # TODO: validate
                workGroups = rhs
            elif aname == WORKITEMS:
                # TODO: validate
                workItemsPerGroup = rhs
            elif aname == CB:
                # TODO: validate
                cacheBlocks = rhs
            elif aname == STREAMCOUNT:
                # TODO: validate
                streamCount = rhs
            elif aname == UIF:
                # TODO: validate
                unrollInner = rhs
            elif aname == CLFLAGS:
                clFlags = rhs
            elif aname == THREADCOUNT:
                # CUDA alias
                g.warn("Interpreting CUDA threadCount as OpenCL workItemsPerGroup")
                workItemsPerGroup = rhs
            elif aname == BLOCKCOUNT:
                # CUDA alias
                g.warn("Interpreting CUDA blockCount as OpenCL workGroups")
                workGroups = rhs
            elif aname == VECHINT:
                vecHint = rhs
            elif aname == SIZEHINT:
                sizeHint = rhs
            else:
                g.err('%s: %s: unrecognized transformation argument: "%s"' % (self.__class__, line_no, aname))
示例#18
0
def p_error(p):
    '''PLY syntax-error callback: report the error through ``g.err``.

    :param p: the offending token, or ``None`` when PLY reports that the
        error was caused by reaching the end of the input.
    '''
    if p is None:
        # PLY passes None at end of input; there is no token to describe.
        g.err("orio.module.splingo.parser: unexpected end of input")
        return
    g.err(
        "orio.module.splingo.parser: error in input line #%s, at token-type '%s', token-value '%s'"
        % (p.lineno, p.type, p.value))
示例#19
0
    def generate(self, tnode, indent='  ', extra_indent='  '):
        '''Generate C source code for the given AST node.

        The node is rendered recursively: expression nodes become inline
        text, statement nodes become indented lines.

        :param tnode: AST node (one of the ``ast`` classes handled below).
        :param indent: prefix written before each statement emitted at this
            nesting level.
        :param extra_indent: amount appended to ``indent`` for each nested
            statement level.
        :returns: the generated C code as a string.

        Unsupported node or operator types are reported through ``g.err``.
        Side effect: ``self.alldecls`` (the set of ``VarDecl`` lines already
        emitted, used to suppress duplicate declarations) is reset at the
        start of every ``CompStmt`` and after every ``for`` body that is a
        ``CompStmt``.
        '''

        s = ''

        if isinstance(tnode, ast.NumLitExp):
            s += str(tnode.val)

        elif isinstance(tnode, ast.StringLitExp):
            s += str(tnode.val)

        elif isinstance(tnode, ast.IdentExp):
            s += str(tnode.name)

        elif isinstance(tnode, ast.ArrayRefExp):
            s += self.generate(tnode.exp, indent, extra_indent)
            s += '[' + self.generate(tnode.sub_exp, indent, extra_indent) + ']'

        elif isinstance(tnode, ast.FunCallExp):
            s += self.generate(tnode.exp, indent, extra_indent) + '('
            s += ','.join(
                [self.generate(x, indent, extra_indent) for x in tnode.args])
            s += ')'

        elif isinstance(tnode, ast.UnaryExp):
            # Generate the operand first, then attach the operator as a
            # prefix or suffix.
            s = self.generate(tnode.exp, indent, extra_indent)
            if tnode.op_type == tnode.PLUS:
                s = '+' + s
            elif tnode.op_type == tnode.MINUS:
                s = '-' + s
            elif tnode.op_type == tnode.LNOT:
                s = '!' + s
            elif tnode.op_type == tnode.PRE_INC:
                s = ' ++' + s
            elif tnode.op_type == tnode.PRE_DEC:
                s = ' --' + s
            elif tnode.op_type == tnode.POST_INC:
                s = s + '++ '
            elif tnode.op_type == tnode.POST_DEC:
                s = s + '-- '
            elif tnode.op_type == tnode.DEREF:
                s = '*' + s
            elif tnode.op_type == tnode.ADDRESSOF:
                s = '&' + s
            else:
                g.err(
                    'orio.module.loop.codegen internal error: unknown unary operator type: %s'
                    % tnode.op_type)

        elif isinstance(tnode, ast.BinOpExp):
            # Rendered as '<lhs> <op> <rhs>' with a space on each side of the
            # operator (the comma operator only gets a trailing space).
            s += self.generate(tnode.lhs, indent, extra_indent)
            if tnode.op_type == tnode.MUL:
                s += ' * '
            elif tnode.op_type == tnode.DIV:
                s += ' / '
            elif tnode.op_type == tnode.MOD:
                s += ' % '
            elif tnode.op_type == tnode.ADD:
                s += ' + '
            elif tnode.op_type == tnode.SUB:
                s += ' - '
            elif tnode.op_type == tnode.LT:
                s += ' < '
            elif tnode.op_type == tnode.GT:
                s += ' > '
            elif tnode.op_type == tnode.LE:
                s += ' <= '
            elif tnode.op_type == tnode.GE:
                s += ' >= '
            elif tnode.op_type == tnode.EQ:
                s += ' == '
            elif tnode.op_type == tnode.NE:
                s += ' != '
            elif tnode.op_type == tnode.LOR:
                s += ' || '
            elif tnode.op_type == tnode.LAND:
                s += ' && '
            elif tnode.op_type == tnode.COMMA:
                s += ', '
            elif tnode.op_type == tnode.EQ_ASGN:
                #print "(((((( Binop: tnode.lhs.meta=%s, tnode.rhs.meta=%s ))))) " \
                #    % (str(tnode.lhs.meta),str(tnode.rhs.meta))

                s += ' = '
            else:
                g.err(
                    'orio.module.loop.codegen internal error: unknown binary operator type: %s'
                    % tnode.op_type)
            s += self.generate(tnode.rhs, indent, extra_indent)

        elif isinstance(tnode, ast.ParenthExp):
            s += '(' + self.generate(tnode.exp, indent, extra_indent) + ')'

        elif isinstance(tnode, ast.Comment):
            # An empty comment still produces an (indented) blank line.
            s += indent
            if tnode.text:
                s += '/*' + tnode.text + '*/'
            s += '\n'

        elif isinstance(tnode, ast.ExpStmt):
            # The optional label is written before the indentation.
            if tnode.getLabel(): s += tnode.getLabel() + ':'
            s += indent
            if tnode.exp:
                s += self.generate(tnode.exp, indent, extra_indent)
            s += ';\n'

        elif isinstance(tnode, ast.GotoStmt):
            if tnode.getLabel(): s += tnode.getLabel() + ':'
            s += indent
            if tnode.target:
                s += 'goto ' + tnode.target + ';\n'

        elif isinstance(tnode, ast.CompStmt):
            try:
                tmp = tnode.meta.get('id')
                fake_loop = False
                # When marker loops are enabled (g.Globals().marker_loops) and
                # this block carries an 'id' in its meta data, wrap it in the
                # single-iteration loop 'for (int <id>=0; <id> < 1; <id>++)'.
                # NOTE(review): presumably this makes the region identifiable
                # by name in the generated code -- confirm intent.
                #if tmp and (not tmp in self.ids):
                if tmp and g.Globals().marker_loops:
                    #self.ids.append(tmp)
                    fake_loop = True
                    #s += tmp + ': \n'
                    fake_scope_loop = 'for (int %s=0; %s < 1; %s++)' % (
                        tmp, tmp, tmp)
                    s += indent + fake_scope_loop
                    old_indent = indent
                    indent += extra_indent
                s += indent + '{\n'

                # A new block starts: forget declarations emitted so far so
                # they are not suppressed as duplicates inside this scope.
                self.alldecls = set([])
                for stmt in tnode.stmts:
                    g.debug('generating code for stmt type: %s' %
                            stmt.__class__.__name__,
                            obj=self,
                            level=7)
                    s += self.generate(stmt, indent + extra_indent,
                                       extra_indent)
                    g.debug('code so far:' + s, obj=self, level=7)

                s += indent + '}\n'
                if fake_loop: indent = old_indent
            except Exception as e:
                # Any failure is re-reported through g.err with the node's
                # source line number.
                g.err(
                    'orio.module.loop.codegen:%s: encountered an error in C code generation for CompStmt: %s %s'
                    % (tnode.line_no, e.__class__, e))

        elif isinstance(tnode, ast.IfStmt):
            try:
                if tnode.getLabel(): s += tnode.getLabel() + ':'
                s += indent + 'if (' + self.generate(tnode.test, indent,
                                                     extra_indent) + ') '
                if isinstance(tnode.true_stmt, ast.CompStmt):
                    tstmt_s = self.generate(tnode.true_stmt, indent,
                                            extra_indent)
                    # Keep everything from the first '{' so the brace follows
                    # 'if (...) ' on the same line (drops the block's leading
                    # indent and any marker-loop header).
                    s += tstmt_s[tstmt_s.index('{'):]
                    if tnode.false_stmt:
                        # Replace the newline after '}' so 'else' follows it.
                        s = s[:-1] + ' else '
                else:
                    s += '\n'
                    s += self.generate(tnode.true_stmt, indent + extra_indent,
                                       extra_indent)
                    if tnode.false_stmt:
                        s += indent + 'else '
                if tnode.false_stmt:
                    if isinstance(tnode.false_stmt, ast.CompStmt):
                        tstmt_s = self.generate(tnode.false_stmt, indent,
                                                extra_indent)
                        s += tstmt_s[tstmt_s.index('{'):]
                    else:
                        s += '\n'
                        s += self.generate(tnode.false_stmt,
                                           indent + extra_indent, extra_indent)
            except Exception as e:
                g.err(
                    'orio.module.loop.codegen:%s: encountered an error in C code generation for IfStmt: %s %s '
                    % (tnode.line_no, e.__class__, e))

        elif isinstance(tnode, ast.ForStmt):
            try:
                tmp = tnode.meta.get('id')
                fake_loop = False
                # Only add a marker loop when the parent CompStmt/ForStmt does
                # not already carry an 'id' of its own.
                parent_with_id = False
                if tnode.parent:
                    if isinstance(tnode.parent, ast.CompStmt) or isinstance(
                            tnode.parent, ast.ForStmt):
                        if tnode.parent.meta.get('id'):
                            parent_with_id = True
                if not parent_with_id and tmp and g.Globals(
                ).marker_loops:  # and tmp not in self.ids:
                    #self.ids.append(tmp)
                    fake_loop = True
                    #s += tmp + ': \n'
                    fake_scope_loop = 'for (int %s=0; %s < 1; %s++)' % (
                        tmp, tmp, tmp)
                    s += indent + fake_scope_loop + ' {\n'
                    old_indent = indent
                    indent += extra_indent
                local_decl = True

                # In some cases, we wish loop index variables to be accessible after the
                # corresponding loop. For example, the remainder loop generated by register tiling reuses the
                # index variable from the preceding loop, hence, it is declared before the actual loop,
                # so that it can be accessed later.
                if tnode.init and tnode.meta.get('declare_vars_outside'):
                    s += indent + 'int %s;\n' % ', '.join(
                        tnode.meta['declare_vars_outside'])
                    local_decl = False
                s += indent + 'for ('
                if tnode.init:
                    # An assignment-style init declares the index as 'int'
                    # inside the for header unless it was declared above.
                    if isinstance(tnode.init, ast.BinOpExp) and local_decl:
                        #if tnode.init.lhs.name.startswith('_orio_'):  # Orio-generated variable
                        s += 'int '
                    s += self.generate(tnode.init, indent, extra_indent)
                s += '; '
                if tnode.test:
                    s += self.generate(tnode.test, indent, extra_indent)
                s += '; '
                if tnode.iter:
                    s += self.generate(tnode.iter, indent, extra_indent)
                s += ') '
                if isinstance(tnode.stmt, ast.CompStmt):
                    stmt_s = self.generate(tnode.stmt, indent, extra_indent)
                    # Put the body's opening brace on the 'for' line.
                    s += stmt_s[stmt_s.index('{'):]
                    self.alldecls = set([])
                else:
                    s += '\n'
                    s += self.generate(tnode.stmt, indent + extra_indent,
                                       extra_indent)

                if fake_loop and tmp:
                    # Close the marker loop opened above.
                    s += indent + '} // ' + fake_scope_loop + '\n'
                    indent = old_indent
            except Exception as e:
                g.err(
                    'orio.module.loop.codegen:%s: encountered an error in C code generation: %s %s'
                    % (tnode.line_no, e.__class__, e))

        elif isinstance(tnode, ast.TransformStmt):
            g.err(
                'orio.module.loop.codegen internal error: a transformation statement is never generated as an output'
            )

        elif isinstance(tnode, ast.VarDecl):
            qual = ''
            if tnode.qualifier.strip():
                qual = str(tnode.qualifier) + ' '
            sv = indent + qual + str(tnode.type_name) + ' '
            sv += ', '.join(tnode.var_names)
            sv += ';\n'
            # Skip a declaration line identical to one already emitted in the
            # current scope.
            if not sv in self.alldecls:
                s += sv
                self.alldecls.add(sv)

        elif isinstance(tnode, ast.VarDeclInit):
            qual = ''
            if tnode.qualifier.strip():
                qual = str(tnode.qualifier) + ' '
            s += indent + qual + str(tnode.type_name) + ' '
            s += self.generate(tnode.var_name, indent, extra_indent)
            s += '=' + self.generate(tnode.init_exp, indent, extra_indent)
            # NOTE(review): unlike VarDecl, no trailing newline is emitted
            # here -- confirm whether callers rely on this.
            s += ';'

        elif isinstance(tnode, ast.Pragma):
            s += '#pragma ' + str(tnode.pstring) + '\n'

        elif isinstance(tnode, ast.Container):
            s += self.generate(tnode.ast, indent, extra_indent)

        elif isinstance(tnode, ast.DeclStmt):
            # Each contained declaration is generated with an empty
            # extra_indent.
            for d in tnode.vars():
                s += self.generate(d, indent, '')
        else:
            g.err(
                'orio.module.loop.codegen internal error: unrecognized type of AST: %s\n%s'
                % (tnode.__class__.__name__, str(tnode)))
        return s
示例#20
0
    def generate(self,
                 tnode,
                 indent='  ',
                 extra_indent='  ',
                 doloop_inc=False):
        '''Generate Fortran source code for the given (C-like) AST node.

        The node is rendered recursively: expressions become inline text,
        statements become indented lines, ``for`` loops become
        ``do ... end do`` and ``if`` statements become
        ``if (...) then ... end if``.

        :param tnode: AST node to render.
        :param indent: prefix written before each statement at this level.
        :param extra_indent: amount appended to ``indent`` for each nested
            statement level.
        :param doloop_inc: when True and ``tnode`` is a binary expression other
            than MOD/COMMA, the left operand is omitted. The ``do``-loop code
            uses this to turn an update such as ``i = i + 2`` (passing its
            right-hand side ``i + 2``) into the step ``+2``.
        :returns: the generated Fortran code as a string.

        Unsupported constructs are reported through ``g.err``.
        '''

        s = ''

        if isinstance(tnode, ast.NumLitExp):
            s += str(tnode.val)

        elif isinstance(tnode, ast.StringLitExp):
            s += str(tnode.val)

        elif isinstance(tnode, ast.IdentExp):
            s += str(tnode.name)

        elif isinstance(tnode, ast.ArrayRefExp):
            # Collapse a chain of nested references a[i][j] into a single
            # Fortran subscript list a(i,j).
            tmpnode = tnode
            prevtmpnode = tnode
            indices = []
            while isinstance(tmpnode, ast.ArrayRefExp):
                indices.append(tmpnode.sub_exp)
                prevtmpnode = tmpnode
                tmpnode = tmpnode.exp

            indices.reverse()
            s += self.generate(prevtmpnode.exp, indent,
                               extra_indent)  # the variable name
            s += '(' + ','.join(
                [self.generate(x, indent, extra_indent)
                 for x in indices]) + ')'

        elif isinstance(tnode, ast.FunCallExp):
            s += self.generate(tnode.exp, indent, extra_indent) + '('
            s += ','.join(
                [self.generate(x, indent, extra_indent) for x in tnode.args])
            s += ')'

        elif isinstance(tnode, ast.UnaryExp):
            s = self.generate(tnode.exp, indent, extra_indent)
            if tnode.op_type == tnode.PLUS:
                s = '+' + s
            elif tnode.op_type == tnode.MINUS:
                s = '-' + s
            elif tnode.op_type == tnode.LNOT:
                # Logical negation is '.NOT.' in Fortran; NOT() is the bitwise
                # integer intrinsic.
                s = '.NOT.(' + s + ')'
            elif tnode.op_type in (tnode.PRE_INC, tnode.POST_INC):
                # Fortran has no ++/--: emit the operand followed by an
                # explicit update statement on its own line.
                s += '\n' + indent + s + ' = ' + s + ' + 1\n'
            elif tnode.op_type in (tnode.PRE_DEC, tnode.POST_DEC):
                s += '\n' + indent + s + ' = ' + s + ' - 1\n'
            else:
                g.err(
                    'orio.module.loop.codegen internal error: unknown unary operator type: %s'
                    % tnode.op_type)

        elif isinstance(tnode, ast.BinOpExp):
            if tnode.op_type not in [tnode.MOD, tnode.COMMA]:
                if not doloop_inc:
                    s += self.generate(tnode.lhs, indent, extra_indent)
                if tnode.op_type == tnode.MUL:
                    s += '*'
                elif tnode.op_type == tnode.DIV:
                    s += '/'
                elif tnode.op_type == tnode.ADD:
                    s += '+'
                elif tnode.op_type == tnode.SUB:
                    s += '-'
                elif tnode.op_type == tnode.LT:
                    s += '<'
                elif tnode.op_type == tnode.GT:
                    s += '>'
                elif tnode.op_type == tnode.LE:
                    s += '<='
                elif tnode.op_type == tnode.GE:
                    s += '>='
                elif tnode.op_type == tnode.EQ:
                    s += '=='
                elif tnode.op_type == tnode.NE:
                    # '!' starts a comment in Fortran; '/=' is not-equal.
                    s += '/='
                elif tnode.op_type == tnode.LOR:
                    s += '.OR.'
                elif tnode.op_type == tnode.LAND:
                    s += '.AND.'
                elif tnode.op_type == tnode.EQ_ASGN:
                    s += '='
                else:
                    g.err(
                        'orio.module.loop.codegen internal error: unknown binary operator type: %s'
                        % tnode.op_type)

                s += self.generate(tnode.rhs, indent, extra_indent)

            else:

                if tnode.op_type == tnode.MOD:
                    # Fortran has no '%' operator; use the MOD intrinsic.
                    s += 'MOD(' + self.generate(tnode.lhs, indent, extra_indent) + ', ' \
                        + self.generate(tnode.rhs, indent, extra_indent) + ')'
                elif tnode.op_type == tnode.COMMA:
                    # TODO: We need to implement an AST canonicalization step for Fortran before generating the code.
                    print(
                        'internal warning: Fortran code generator does not fully support the comma operator -- the generated code may not compile.'
                    )
                    s += self.generate(tnode.rhs, indent, extra_indent)
                    s += '\n' + indent + self.generate(tnode.lhs, indent,
                                                       extra_indent)
                    s += '\n! ORIO Warining: check code above and fix problems.'

        elif isinstance(tnode, ast.ParenthExp):
            s += '(' + self.generate(tnode.exp, indent, extra_indent) + ')'

        elif isinstance(tnode, ast.Comment):
            s += indent
            if tnode.text:
                s += '!' + tnode.text
            s += '\n'

        elif isinstance(tnode, ast.ExpStmt):
            if tnode.getLabel(): s += tnode.getLabel() + ' '
            s += indent
            if tnode.exp:
                s += self.generate(tnode.exp, indent, extra_indent)
            s += '\n'

        elif isinstance(tnode, ast.GotoStmt):
            if tnode.getLabel(): s += tnode.getLabel() + ' '
            s += indent
            if tnode.target:
                s += 'goto ' + tnode.target + '\n'

        elif isinstance(tnode, ast.CompStmt):
            s += indent + '{\n'
            if tnode.getLabel(): s += tnode.getLabel() + ' '
            for stmt in tnode.stmts:
                s += self.generate(stmt, indent + extra_indent, extra_indent)
            s += indent + '}\n'

        elif isinstance(tnode, ast.IfStmt):
            if tnode.getLabel(): s += tnode.getLabel() + ' '
            s += indent + 'if (' + self.generate(tnode.test, indent,
                                                 extra_indent) + ') then \n'
            if isinstance(tnode.true_stmt, ast.CompStmt):
                tstmt_s = self.generate(tnode.true_stmt, indent, extra_indent)
                # TODO: fix below cludge -- { is missing for some reason in some compound ifs
                if tstmt_s.count('{') > 0: s += tstmt_s[tstmt_s.index('{'):]
                else: s += tstmt_s
                if tnode.false_stmt:
                    s = s[:-1] + ' else '
            else:
                s += '\n'
                s += self.generate(tnode.true_stmt, indent + extra_indent,
                                   extra_indent)
                if tnode.false_stmt:
                    s += indent + 'else '
            if tnode.false_stmt:
                if isinstance(tnode.false_stmt, ast.CompStmt):
                    tstmt_s = self.generate(tnode.false_stmt, indent,
                                            extra_indent)
                    s += tstmt_s[tstmt_s.index('{'):]
                else:
                    s += '\n'
                    s += self.generate(tnode.false_stmt, indent + extra_indent,
                                       extra_indent)
            s += indent + 'end if\n'

        elif isinstance(tnode, ast.ForStmt):
            # Emitted as 'do <init>, <bound>, <step>' ... 'end do'.
            if tnode.getLabel(): s += tnode.getLabel() + ' '
            s += indent + 'do '
            if tnode.init:
                s += self.generate(tnode.init, indent, extra_indent)
            if not tnode.test:
                g.err(
                    'orio.module.loop.codegen:  missing loop test expression. Fortran code generation requires a loop test expression.'
                )

            if not tnode.iter:
                g.err(
                    'orio.module.loop.codegen:  missing loop increment expression. Fortran code generation requires a loop increment expression.'
                )
            s += ', '
            if not isinstance(tnode.test, ast.BinOpExp):
                g.err(
                    'orio.module.loop.codegen internal error: cannot handle code generation for loop test expression'
                )

            if tnode.test.op_type not in [
                    tnode.test.LE, tnode.test.LT, tnode.test.GE, tnode.test.GT
            ]:
                g.err(
                    'orio.module.loop.codegen internal error: cannot generate Fortran loop, only <, >, <=, >= are recognized in the loop limit test'
                )

            # Generate the loop bound from the right-hand side of the test.
            # NOTE(review): the bound is emitted unchanged for '<'/'>' as well
            # as '<='/'>='; Fortran DO bounds are inclusive, so this presumably
            # relies on the AST using <=/>= tests -- confirm.
            s += self.generate(tnode.test.rhs, indent, extra_indent)

            s += ', '
            # Generate the loop increment/decrement step

            if not isinstance(tnode.iter, (ast.BinOpExp, ast.UnaryExp)):
                g.err(
                    'orio.module.loop.codegen internal error: cannot handle code generation for loop increment expression'
                )

            unary = isinstance(tnode.iter, ast.UnaryExp)
            incr_decr = []
            if unary:
                incr_decr = [
                    tnode.iter.POST_DEC, tnode.iter.PRE_DEC,
                    tnode.iter.POST_INC, tnode.iter.PRE_INC
                ]

            if not ((isinstance(tnode.iter, ast.BinOpExp)
                     and tnode.iter.op_type == tnode.iter.EQ_ASGN) or
                    (isinstance(tnode.iter, ast.UnaryExp)
                     and tnode.iter.op_type in incr_decr)):
                g.err(
                    'orio.module.loop.codegen internal error: cannot handle code generation for loop increment expression'
                )

            # A decrementing '--' loop with a '>'/'>=' test steps by -1.
            if tnode.test.op_type in [tnode.test.GT, tnode.test.GE] \
                and unary and tnode.iter.op_type in [tnode.iter.PRE_DEC, tnode.iter.POST_DEC]:
                s += '-'

            if unary and tnode.iter.op_type in incr_decr:  # ++i, i++, --i, i--
                s += '1'
            else:
                # 'i = i + k': doloop_inc drops the left operand, giving '+k'.
                s += self.generate(tnode.iter.rhs, indent, extra_indent, True)

            if isinstance(tnode.stmt, ast.CompStmt):
                s += self.generate(tnode.stmt, indent, extra_indent)
                s += '\n' + indent + 'end do\n'
            else:
                s += '\n'
                s += self.generate(tnode.stmt, indent + extra_indent,
                                   extra_indent)
                s += '\n' + indent + 'end do\n'

        elif isinstance(tnode, ast.TransformStmt):
            g.err(
                'orio.module.loop.codegen internal error: a transformation statement is never generated as an output'
            )

        elif isinstance(tnode, ast.VarDecl):
            # self.ftypes maps C type names to Fortran type declarations.
            if tnode.type_name not in self.ftypes:
                g.err(
                    'orio.module.loop.codegen internal error: Cannot generate Fortran type for '
                    + tnode.type_name)

            s += indent + str(self.ftypes[tnode.type_name]) + ' '
            s += ', '.join(tnode.var_names)
            s += '\n'

        elif isinstance(tnode, ast.Pragma):
            s += '$pragma ' + str(tnode.pstring) + '\n'

        elif isinstance(tnode, ast.Container):
            if tnode.getLabel(): s += tnode.getLabel() + ' '
            s += self.generate(tnode.ast, indent, extra_indent)

        else:
            g.err(
                'orio.module.loop.codegen internal error: unrecognized type of AST: %s'
                % tnode.__class__.__name__)

        return s
示例#21
0
    def readTransfArgs(self, perf_params, transf_args):
        '''Evaluate and validate the OpenCL transformation arguments.

        :param perf_params: mapping of performance-parameter names to values.
            It serves as the global namespace when evaluating each argument
            expression and is not modified.
        :param transf_args: iterable of ``(aname, rhs, line_no)`` triples; each
            ``rhs`` is passed to ``eval``.
        :returns: dict mapping each OpenCL argument name ('platform',
            'device', 'workGroups', 'workItemsPerGroup', 'cacheBlocks',
            'streamCount', 'unrollInner', 'clFlags', 'vecHint', 'sizeHint')
            to its evaluated value, or to its default when not given.

        The CUDA names 'threadCount' and 'blockCount' are accepted as aliases
        for 'workItemsPerGroup' and 'workGroups' (with a warning). Expressions
        that fail to evaluate and unrecognized argument names are reported
        through ``g.err``.
        '''

        # expected argument names
        PLATFORM    = 'platform'
        DEVICE      = 'device'
        WORKGROUPS  = 'workGroups'
        WORKITEMS   = 'workItemsPerGroup'
        CB          = 'cacheBlocks'
        STREAMCOUNT = 'streamCount'
        UIF         = 'unrollInner'
        CLFLAGS     = 'clFlags'
        THREADCOUNT = 'threadCount'
        BLOCKCOUNT  = 'blockCount'
        VECHINT     = 'vecHint'
        SIZEHINT    = 'sizeHint'

        # default argument values (also the set of accepted OpenCL names)
        # TODO: validate the evaluated values
        args = {
            PLATFORM: 0,
            DEVICE: 0,
            WORKGROUPS: None,
            WORKITEMS: None,
            CB: False,
            STREAMCOUNT: 1,
            UIF: None,
            CLFLAGS: None,
            VECHINT: 0,
            SIZEHINT: False,
        }

        # CUDA argument names accepted as aliases: name -> (target, warning)
        aliases = {
            THREADCOUNT: (WORKITEMS, "Interpreting CUDA threadCount as OpenCL workItemsPerGroup"),
            BLOCKCOUNT: (WORKGROUPS, "Interpreting CUDA blockCount as OpenCL workGroups"),
        }

        # eval() inserts '__builtins__' into the globals dict it is given, so
        # evaluate against a copy to leave the caller's perf_params untouched.
        # NOTE: eval executes arbitrary code from the annotated source; only
        # use with trusted input.
        namespace = dict(perf_params)

        for aname, rhs, line_no in transf_args:
            # evaluate the RHS expression
            try:
                rhs = eval(rhs, namespace)
            except Exception as e:
                g.err('orio.module.loop.submodule.opencl.opencl: %s: failed to evaluate the argument expression: %s\n --> %s: %s' % (line_no, rhs,e.__class__.__name__, e))

            if aname in args:
                args[aname] = rhs
            elif aname in aliases:
                target, warning = aliases[aname]
                g.warn(warning)
                args[target] = rhs
            else:
                g.err('%s: %s: unrecognized transformation argument: "%s"' % (self.__class__, line_no, aname))

        # return evaluated transformation arguments
        return args
示例#22
0
    def generate(self, tnode, indent = '  ', extra_indent = '  '):
        '''Render ``tnode`` (and, recursively, its children) as C source text.

        ``indent`` prefixes every statement emitted at the current level and
        ``extra_indent`` is added for each nested statement level.
        Expressions are rendered inline with no spaces around operators.
        Unknown node or operator types are reported through ``g.err``.
        ``self.alldecls`` remembers ``VarDecl`` lines already emitted so that
        identical declarations are skipped; it is cleared when a compound
        statement starts and after a ``for`` loop's compound body.
        '''
        out = ''

        if isinstance(tnode, (ast.NumLitExp, ast.StringLitExp)):
            out += str(tnode.val)

        elif isinstance(tnode, ast.IdentExp):
            out += str(tnode.name)

        elif isinstance(tnode, ast.ArrayRefExp):
            base = self.generate(tnode.exp, indent, extra_indent)
            subscript = self.generate(tnode.sub_exp, indent, extra_indent)
            out += base + '[' + subscript + ']'

        elif isinstance(tnode, ast.FunCallExp):
            callee = self.generate(tnode.exp, indent, extra_indent)
            arg_list = ','.join([self.generate(arg, indent, extra_indent)
                                 for arg in tnode.args])
            out += callee + '(' + arg_list + ')'

        elif isinstance(tnode, ast.UnaryExp):
            operand = self.generate(tnode.exp, indent, extra_indent)
            # (operator constant name, symbol, symbol goes before operand?)
            # checked in order; constants are looked up only as needed
            unary_table = (
                ('PLUS', '+', True),
                ('MINUS', '-', True),
                ('LNOT', '!', True),
                ('PRE_INC', ' ++', True),
                ('PRE_DEC', ' --', True),
                ('POST_INC', '++ ', False),
                ('POST_DEC', '-- ', False),
                ('DEREF', '*', True),
                ('ADDRESSOF', '&', True),
            )
            for const_name, symbol, is_prefix in unary_table:
                if tnode.op_type == getattr(tnode, const_name):
                    if is_prefix:
                        out = symbol + operand
                    else:
                        out = operand + symbol
                    break
            else:
                out = operand
                g.err('orio.module.loop.codegen internal error: unknown unary operator type: %s' % tnode.op_type)

        elif isinstance(tnode, ast.BinOpExp):
            # (operator constant name, symbol), checked in order
            binop_table = (
                ('MUL', '*'), ('DIV', '/'), ('MOD', '%'),
                ('ADD', '+'), ('SUB', '-'),
                ('LT', '<'), ('GT', '>'), ('LE', '<='), ('GE', '>='),
                ('EQ', '=='), ('NE', '!='),
                ('LOR', '||'), ('LAND', '&&'),
                ('COMMA', ','), ('EQ_ASGN', '='),
            )
            out += self.generate(tnode.lhs, indent, extra_indent)
            for const_name, symbol in binop_table:
                if tnode.op_type == getattr(tnode, const_name):
                    out += symbol
                    break
            else:
                g.err('orio.module.loop.codegen internal error: unknown binary operator type: %s' % tnode.op_type)
            out += self.generate(tnode.rhs, indent, extra_indent)

        elif isinstance(tnode, ast.ParenthExp):
            out += '(' + self.generate(tnode.exp, indent, extra_indent) + ')'

        elif isinstance(tnode, ast.Comment):
            text = '/*' + tnode.text + '*/' if tnode.text else ''
            out += indent + text + '\n'

        elif isinstance(tnode, ast.ExpStmt):
            label = tnode.getLabel()
            if label:
                out += label + ':'
            out += indent
            if tnode.exp:
                out += self.generate(tnode.exp, indent, extra_indent)
            out += ';\n'

        elif isinstance(tnode, ast.GotoStmt):
            label = tnode.getLabel()
            if label:
                out += label + ':'
            out += indent
            if tnode.target:
                out += 'goto ' + tnode.target + ';\n'

        elif isinstance(tnode, ast.CompStmt):
            # a new scope: previously emitted declarations no longer count
            self.alldecls = set()
            inner = ''.join([self.generate(child, indent + extra_indent, extra_indent)
                             for child in tnode.stmts])
            out += indent + '{\n' + inner + indent + '}\n'

        elif isinstance(tnode, ast.IfStmt):
            label = tnode.getLabel()
            if label:
                out += label + ':'
            condition = self.generate(tnode.test, indent, extra_indent)
            out += indent + 'if (' + condition + ') '
            else_branch = tnode.false_stmt
            if isinstance(tnode.true_stmt, ast.CompStmt):
                then_code = self.generate(tnode.true_stmt, indent, extra_indent)
                # keep the opening brace on the 'if' line
                out += then_code[then_code.index('{'):]
                if else_branch:
                    # swap the newline after '}' for ' else '
                    out = out[:-1] + ' else '
            else:
                out += '\n' + self.generate(tnode.true_stmt, indent + extra_indent, extra_indent)
                if else_branch:
                    out += indent + 'else '
            if else_branch:
                if isinstance(else_branch, ast.CompStmt):
                    else_code = self.generate(else_branch, indent, extra_indent)
                    out += else_code[else_code.index('{'):]
                else:
                    out += '\n' + self.generate(else_branch, indent + extra_indent, extra_indent)

        elif isinstance(tnode, ast.ForStmt):
            label = tnode.getLabel()
            if label:
                out += label + ':'
            header_parts = [
                self.generate(part, indent, extra_indent) if part else ''
                for part in (tnode.init, tnode.test, tnode.iter)
            ]
            out += indent + 'for (' + '; '.join(header_parts) + ') '
            body = tnode.stmt
            if isinstance(body, ast.CompStmt):
                body_code = self.generate(body, indent, extra_indent)
                out += body_code[body_code.index('{'):]
                self.alldecls = set()
            else:
                out += '\n' + self.generate(body, indent + extra_indent, extra_indent)

        elif isinstance(tnode, ast.TransformStmt):
            g.err('orio.module.loop.codegen internal error: a transformation statement is never generated as an output')

        elif isinstance(tnode, ast.VarDecl):
            decl_line = indent + str(tnode.type_name) + ' ' + ', '.join(tnode.var_names) + ';\n'
            if decl_line not in self.alldecls:
                out += decl_line
                self.alldecls.add(decl_line)

        elif isinstance(tnode, ast.VarDeclInit):
            head = indent + str(tnode.type_name) + ' '
            target = self.generate(tnode.var_name, indent, extra_indent)
            value = self.generate(tnode.init_exp, indent, extra_indent)
            out += head + target + '=' + value + ';\n'

        elif isinstance(tnode, ast.Pragma):
            out += '#pragma ' + str(tnode.pstring) + '\n'

        elif isinstance(tnode, ast.Container):
            out += self.generate(tnode.ast, indent, extra_indent)

        else:
            g.err('orio.module.loop.codegen internal error: unrecognized type of AST: %s' % tnode.__class__.__name__)

        return out
示例#23
0
def p_error(p):
    '''Report a parse error through g.err.

    @param p: the unexpected token, or None when the parser ran out of input
              before the grammar was satisfied (PLY passes None in that case)
    '''
    if p is None:
        # No token to describe (and no p.lexer to compute a column from).
        g.err(__name__ + ': unexpected end of input')
        return
    col = find_column(p.lexer.lexdata, p)
    g.err(
        __name__ +
        ": unexpected token-type '%s', token-value '%s' at line %s, column %s"
        % (p.type, p.value, p.lineno, col))
示例#24
0
    def generate(self, tnode, indent = '  ', extra_indent = '  '):
        '''Render an AST node (expression, statement or declaration) as C source text.

        @param tnode: the AST node to render
        @param indent: whitespace prefixed to statement-level output
        @param extra_indent: whitespace added for each nested statement level
        @return: the generated code; '' after an internal error is reported via g.err
        '''

        def gen(node, ind=indent):
            # Recursive rendering that keeps extra_indent; only the indent varies.
            return self.generate(node, ind, extra_indent)

        def body_of(stmt):
            # Text placed after an 'if (...) ', 'for (...) ' or 'while (...) ' header:
            # a compound body continues on the same line starting at its '{',
            # any other body goes on a new line one nesting level deeper.
            if isinstance(stmt, ast.CompStmt):
                text = gen(stmt)
                return text[text.index('{'):]
            return '\n' + gen(stmt, indent + extra_indent)

        if isinstance(tnode, ast.Comment):
            return indent + '/*' + tnode.text + '*/\n'

        if isinstance(tnode, ast.LitExp):
            # escaped with the Python 2 'string-escape' codec
            return str(tnode.val).encode('string-escape')

        if isinstance(tnode, ast.IdentExp):
            return str(tnode.name)

        if isinstance(tnode, ast.ArrayRefExp):
            return gen(tnode.exp) + '[' + gen(tnode.sub) + ']'

        if isinstance(tnode, ast.CallExp):
            return gen(tnode.exp) + '(' + ','.join(gen(arg) for arg in tnode.args) + ')'

        if isinstance(tnode, ast.CastExp):
            return '(' + gen(tnode.castto) + ')' + gen(tnode.exp)

        if isinstance(tnode, ast.UnaryExp):
            operand = gen(tnode.exp)
            # (operator constant on the node, text before, text after); first match wins
            for const_name, before, after in (
                    ('PLUS', '+', ''),
                    ('MINUS', '-', ''),
                    ('LNOT', '!', ''),
                    ('BNOT', '~', ''),
                    ('PRE_INC', ' ++', ''),
                    ('PRE_DEC', ' --', ''),
                    ('POST_INC', '', '++ '),
                    ('POST_DEC', '', '-- '),
                    ('DEREF', '*', ''),
                    ('ADDRESSOF', '&', ''),
                    ('SIZEOF', 'sizeof ', '')):
                if tnode.op_type == getattr(tnode, const_name):
                    return before + operand + after
            g.err(__name__+': internal error: unknown unary operator type: %s' % tnode.op_type)
            return operand

        if isinstance(tnode, ast.BinOpExp):
            lhs = gen(tnode.lhs)
            # (operator constant on the node, C spelling); first match wins
            for const_name, op_text in (
                    ('PLUS', '+'),
                    ('MINUS', '-'),
                    ('MULT', '*'),
                    ('DIV', '/'),
                    ('MOD', '%'),
                    ('LT', '<'),
                    ('GT', '>'),
                    ('LE', '<='),
                    ('GE', '>='),
                    ('EE', '=='),
                    ('NE', '!='),
                    ('LOR', '||'),
                    ('LAND', '&&'),
                    ('EQ', '='),
                    ('PLUSEQ', '+='),
                    ('MINUSEQ', '-='),
                    ('MULTEQ', '*='),
                    ('DIVEQ', '/='),
                    ('MODEQ', '%='),
                    ('COMMA', ','),
                    ('BOR', '|'),
                    ('BAND', '&'),
                    ('BXOR', '^'),
                    ('BSHL', '<<'),
                    ('BSHR', '>>'),
                    ('BSHLEQ', '<<='),
                    ('BSHREQ', '>>='),
                    ('BANDEQ', '&='),
                    ('BXOREQ', '^='),
                    ('BOREQ', '|='),
                    ('DOT', '.'),
                    ('SELECT', '->')):
                if tnode.op_type == getattr(tnode, const_name):
                    break
            else:
                op_text = ''
                g.err(__name__+': internal error: unknown binary operator type: %s' % tnode.op_type)
            return lhs + op_text + gen(tnode.rhs)

        if isinstance(tnode, ast.TernaryExp):
            return gen(tnode.test) + '?' + gen(tnode.true_exp) + ':' + gen(tnode.false_exp)

        if isinstance(tnode, ast.ParenExp):
            return '(' + gen(tnode.exp) + ')'

        if isinstance(tnode, ast.CompStmt):
            inner = ''.join(gen(stmt, indent + extra_indent) for stmt in tnode.kids)
            return indent + '{\n' + inner + indent + '}\n'

        if isinstance(tnode, ast.ExpStmt):
            return indent + gen(tnode.exp) + ';\n'

        if isinstance(tnode, ast.IfStmt):
            head = indent + 'if (' + gen(tnode.test) + ') '
            then_text = body_of(tnode.true_stmt)
            if not tnode.false_stmt:
                return head + then_text
            if isinstance(tnode.true_stmt, ast.CompStmt):
                # drop the newline after the closing brace to get '} else '
                joined = (head + then_text)[:-1] + ' else '
            else:
                joined = head + then_text + indent + 'else '
            return joined + body_of(tnode.false_stmt)

        if isinstance(tnode, ast.ForStmt):
            clauses = [gen(part) if part else ''
                       for part in (tnode.init, tnode.test, tnode.iter)]
            return indent + 'for (' + '; '.join(clauses) + ') ' + body_of(tnode.stmt)

        if isinstance(tnode, ast.WhileStmt):
            return indent + 'while (' + gen(tnode.test) + ') ' + body_of(tnode.stmt)

        if isinstance(tnode, ast.VarDec):
            # isnested declarations get neither the leading indent nor the ';\n'
            nested = tnode.isnested
            text = ' '.join(tnode.type_name) + ' ' + ', '.join(gen(v) for v in tnode.var_inits)
            if nested:
                return text
            return indent + text + ';\n'

        if isinstance(tnode, ast.ParamDec):
            return indent + str(tnode.ty) + ' ' + str(tnode.name)

        if isinstance(tnode, ast.FunDec):
            return (indent + str(tnode.return_type) + ' ' + str(tnode.modifiers)
                    + tnode.name + '('
                    + ', '.join(gen(param) for param in tnode.params)
                    + ')' + gen(tnode.body))

        if isinstance(tnode, ast.Pragma):
            return indent + '#pragma ' + str(tnode.pstring) + '\n'

        if isinstance(tnode, ast.TransformStmt):
            g.err(__name__+': internal error: a transformation statement is never generated as an output')
            return ''

        g.err(__name__+': internal error: unrecognized type of AST: %s' % tnode.__class__.__name__)
        return ''
示例#25
0
 def t_error(self, t):
     '''Report a character the lexer cannot tokenize, through g.err.'''
     offending = t.value[0]
     line_no = t.lexer.lineno
     g.err('%s: illegal character (%s) at line %s' %
           (self.__class__, offending, line_no))
示例#26
0
文件: cuda.py 项目: zhjp0/Orio
    def getDeviceProps(self):
        '''Get the CUDA device properties and adapt the build command to the device.

        The properties are read from self.tinfo.device_spec_file when it is set.
        Otherwise the query program CUDA_DEVICE_QUERY_SKELET is written to
        enum_cuda_props.cu, compiled with nvcc and run (it is expected to produce
        enum_cuda_props.cu.o.props); this step is skipped when that output file
        already exists.  Each non-blank line of the properties file must be a Python
        literal whose first two items are (name, value).

        Side effect: when self.tinfo is set, its build_cmd is rewritten: a gcc command
        (or a missing one) becomes nvcc, an -arch flag for the detected compute
        capability is appended unless one is present, and @CFLAGS is appended when
        self.perf_params defines CFLAGS.

        @return: dict mapping property names (e.g. 'devId', 'major', 'minor') to values
        '''

        # First, check if user specified the device properties file
        # (self.tinfo may be None, as the build-command code below also assumes)
        if self.tinfo is not None and self.tinfo.device_spec_file:
            qout = self.tinfo.device_spec_file
        else:
            # generate the query code
            qsrc = "enum_cuda_props.cu"
            qexec = qsrc + ".o"
            qout = qexec + ".props"
            if not os.path.exists(qout):
                # check for nvcc
                qcmd = 'which nvcc'
                status = os.system(qcmd)
                if status != 0:
                    g.err("%s: could not locate nvcc with '%s'" %
                          (self.__class__, qcmd))

                try:
                    with open(qsrc, 'w') as f:
                        f.write(CUDA_DEVICE_QUERY_SKELET)
                except (IOError, OSError):
                    g.err('%s: cannot open file for writing: %s' %
                          (self.__class__, qsrc))

                # compile the query
                cmd = 'nvcc -o %s %s' % (qexec, qsrc)
                status = os.system(cmd)
                if status:
                    g.err(
                        '%s: failed to compile cuda device query code: "%s"' %
                        (self.__class__, cmd))

                # execute the query
                runcmd = './%s' % (qexec)
                status = os.system(runcmd)
                if status:
                    g.err(
                        '%s: failed to execute cuda device query code: "%s"' %
                        (self.__class__, runcmd))
                os.remove(qsrc)
                os.remove(qexec)

        # read device properties (file is closed even if reading fails)
        lines = []
        try:
            with open(qout, 'r') as f:
                lines = f.readlines()
        except (IOError, OSError):
            g.err('%s: cannot open query output file for reading: %s' %
                  (self.__class__, qout))

        props = {}
        for line in lines:
            text = line.strip()
            if not text:
                continue  # tolerate blank lines
            try:
                eline = ast.literal_eval(text)
                props[eline[0]] = eline[1]
            except (ValueError, SyntaxError, TypeError, IndexError, KeyError):
                # a malformed line is a parse problem, not an open failure
                g.err('%s: cannot parse query output file %s, line: %s' %
                      (self.__class__, qout, text))

        if props['devId'] == -2:
            g.err("%s: there is no CUDA 1.0 enabled GPU on this machine" %
                  self.__class__)

        # compare (major, minor) as a version pair against compute capability 1.3
        if (props['major'], props['minor']) < (1, 3):
            g.warn(
                "%s: running on compute capability less than 1.3 is not recommended, detected %s.%s."
                % (self.__class__, props['major'], props['minor']))

        # target the device's own compute capability in the build command
        bcmd = None
        if self.tinfo is not None:
            bcmd = self.tinfo.build_cmd
        if not bcmd:
            # no build command configured: start from the gcc default
            bcmd = 'gcc'

        if bcmd.startswith('gcc'):
            # NOTE(review): this replaces the whole command, so any flags that
            # followed 'gcc' are dropped -- confirm this is intended
            bcmd = 'nvcc'
        if '-arch' not in bcmd:
            bcmd += ' -arch=sm_' + str(props['major']) + str(props['minor'])
        if (self.perf_params is not None and 'CFLAGS' in self.perf_params
                and '@CFLAGS' not in bcmd):
            bcmd += ' @CFLAGS'
        if self.tinfo is not None:
            self.tinfo.build_cmd = bcmd

        # return queried device props
        return props
示例#27
0
    def collectNode(self, f, n):
        ''' Collect within the given node a list using the given function: pre-order traversal.

        @param f: callable applied to visited AST nodes; it must return a list.  A node's
                  own list comes first, followed by the lists of its children.
        @param n: the AST node to traverse
        @return: the concatenation of the lists produced by f

        Note: IfStmt and ForStmt nodes are not passed to f themselves; only their
        children are visited.  Unsupported node types are reported through g.err.
        '''

        if isinstance(n, orio.module.loop.ast.NumLitExp):
            return f(n)

        elif isinstance(n, orio.module.loop.ast.StringLitExp):
            return f(n)

        elif isinstance(n, orio.module.loop.ast.IdentExp):
            return f(n)

        elif isinstance(n, orio.module.loop.ast.ArrayRefExp):
            return f(n) + self.collectNode(f, n.exp) + self.collectNode(f, n.sub_exp)

        elif isinstance(n, orio.module.loop.ast.FunCallExp):
            return reduce(lambda x,y: x + y,
                          [self.collectNode(f, a) for a in n.args],
                          f(n))

        elif isinstance(n, orio.module.loop.ast.CastExpr):
            return f(n) + self.collectNode(f, n.expr)

        elif isinstance(n, orio.module.loop.ast.UnaryExp):
            return f(n) + self.collectNode(f, n.exp)

        elif isinstance(n, orio.module.loop.ast.BinOpExp):
            return f(n) + self.collectNode(f, n.lhs) + self.collectNode(f, n.rhs)

        elif isinstance(n, orio.module.loop.ast.ParenthExp):
            return f(n) + self.collectNode(f, n.exp)

        elif isinstance(n, orio.module.loop.ast.Comment):
            # n.text is the raw comment string, not an AST node; recursing into it
            # would only reach the "unexpected AST type" error below
            return f(n)

        elif isinstance(n, orio.module.loop.ast.ExpStmt):
            # the expression is optional (empty statement); descend only when present
            result = f(n)
            if n.exp:
                result = result + self.collectNode(f, n.exp)
            return result

        elif isinstance(n, orio.module.loop.ast.GotoStmt):
            # n.target is the label name (a string), not an AST node
            return f(n)
        elif isinstance(n, orio.module.loop.ast.VarDecl):
            return f(n)
        elif isinstance(n, orio.module.loop.ast.VarDeclInit):
            return f(n)
        elif isinstance(n, orio.module.loop.ast.CompStmt):
            return reduce(lambda x,y: x + y,
                          [self.collectNode(f, a) for a in n.stmts],
                          f(n))

        elif isinstance(n, orio.module.loop.ast.IfStmt):
            result = self.collectNode(f, n.test) + self.collectNode(f, n.true_stmt)
            if n.false_stmt:
                result += self.collectNode(f, n.false_stmt)
            return result

        elif isinstance(n, orio.module.loop.ast.ForStmt):
            result = []
            if n.init:
                result += self.collectNode(f, n.init)
            if n.test:
                result += self.collectNode(f, n.test)
            if n.iter:
                result += self.collectNode(f, n.iter)
            result += self.collectNode(f, n.stmt)
            return result

        elif isinstance(n, orio.module.loop.ast.AssignStmt):
            return f(n) + self.collectNode(f, n.var) + self.collectNode(f, n.exp)

        elif isinstance(n, orio.module.loop.ast.TransformStmt):
            return f(n) + self.collectNode(f, n.name) + self.collectNode(f, n.args) + self.collectNode(f, n.stmt)

        else:
            g.err('orio.module.loop.ast_lib.common_lib.collectNode: unexpected AST type: "%s"' % n.__class__.__name__)
示例#28
0
文件: codegen_cuda.py 项目: phrb/Orio
    def generate(self, tnode, indent="  ", extra_indent="  "):
        """Render an AST node as CUDA C source text.

        Args:
            tnode: the AST node to render.
            indent: whitespace prefixed to statement-level output.
            extra_indent: whitespace added for each nested statement level.

        Returns:
            The generated code; "" after an internal error is reported via g.err.
        """

        def gen(node, ind=indent):
            # Recursive rendering that keeps extra_indent; only the indent varies.
            return self.generate(node, ind, extra_indent)

        def label_of(stmt):
            # "label:" prefix for a labelled statement, "" otherwise
            return stmt.getLabel() + ":" if stmt.getLabel() else ""

        def body_of(stmt):
            # Text placed after an if/for/while header: a compound body continues on
            # the same line from its "{", any other body goes on a deeper new line.
            if isinstance(stmt, ast.CompStmt):
                text = gen(stmt)
                return text[text.index("{") :]
            return "\n" + gen(stmt, indent + extra_indent)

        if isinstance(tnode, ast.NumLitExp):
            return str(tnode.val)

        if isinstance(tnode, ast.StringLitExp):
            return '"' + str(tnode.val) + '"'

        if isinstance(tnode, ast.IdentExp):
            return str(tnode.name)

        if isinstance(tnode, ast.ArrayRefExp):
            return gen(tnode.exp) + "[" + gen(tnode.sub_exp) + "]"

        if isinstance(tnode, ast.FunCallExp):
            return gen(tnode.exp) + "(" + ",".join(gen(arg) for arg in tnode.args) + ")"

        if isinstance(tnode, ast.UnaryExp):
            operand = gen(tnode.exp)
            # (operator constant on the node, text before, text after); first match wins
            for const_name, before, after in (
                ("PLUS", "+", ""),
                ("MINUS", "-", ""),
                ("LNOT", "!", ""),
                ("PRE_INC", " ++", ""),
                ("PRE_DEC", " --", ""),
                ("POST_INC", "", "++ "),
                ("POST_DEC", "", "-- "),
                ("DEREF", "*", ""),
                ("ADDRESSOF", "&", ""),
            ):
                if tnode.op_type == getattr(tnode, const_name):
                    return before + operand + after
            g.err(
                "orio.module.loop.codegen_cuda internal error: unknown unary operator type: %s"
                % tnode.op_type
            )
            return operand

        if isinstance(tnode, ast.BinOpExp):
            lhs = gen(tnode.lhs)
            # (operator constant on the node, C spelling); first match wins
            for const_name, op_text in (
                ("MUL", "*"),
                ("DIV", "/"),
                ("MOD", "%"),
                ("ADD", "+"),
                ("SUB", "-"),
                ("LT", "<"),
                ("GT", ">"),
                ("LE", "<="),
                ("GE", ">="),
                ("EQ", "=="),
                ("NE", "!="),
                ("LOR", "||"),
                ("LAND", "&&"),
                ("COMMA", ","),
                ("EQ_ASGN", "="),
                ("ASGN_ADD", "+="),
                ("ASGN_SHR", ">>="),
                ("ASGN_SHL", "<<="),
                ("BAND", "&"),
                ("SHR", ">>"),
            ):
                if tnode.op_type == getattr(tnode, const_name):
                    break
            else:
                op_text = ""
                g.err(
                    "orio.module.loop.codegen_cuda internal error: unknown binary operator type: %s"
                    % tnode.op_type
                )
            return lhs + op_text + gen(tnode.rhs)

        if isinstance(tnode, ast.TernaryExp):
            return (
                gen(tnode.test) + "?" + gen(tnode.true_expr) + ":" + gen(tnode.false_expr)
            )

        if isinstance(tnode, ast.ParenthExp):
            return "(" + gen(tnode.exp) + ")"

        if isinstance(tnode, ast.Comment):
            return indent + ("/*" + tnode.text + "*/" if tnode.text else "") + "\n"

        if isinstance(tnode, ast.ExpStmt):
            return (
                label_of(tnode) + indent + (gen(tnode.exp) if tnode.exp else "") + ";\n"
            )

        if isinstance(tnode, ast.GotoStmt):
            return (
                label_of(tnode)
                + indent
                + ("goto " + tnode.target + ";\n" if tnode.target else "")
            )

        if isinstance(tnode, ast.CompStmt):
            inner = "".join(gen(stmt, indent + extra_indent) for stmt in tnode.stmts)
            return indent + "{\n" + inner + indent + "}\n"

        if isinstance(tnode, ast.IfStmt):
            head = label_of(tnode) + indent + "if (" + gen(tnode.test) + ") "
            then_text = body_of(tnode.true_stmt)
            if not tnode.false_stmt:
                return head + then_text
            if isinstance(tnode.true_stmt, ast.CompStmt):
                # drop the newline after the closing brace to get "} else "
                joined = (head + then_text)[:-1] + " else "
            else:
                joined = head + then_text + indent + "else "
            return joined + body_of(tnode.false_stmt)

        if isinstance(tnode, ast.ForStmt):
            head = label_of(tnode) + indent + "for ("
            init = tnode.init
            if not init:
                init_text = ""
            elif isinstance(init, ast.VarDeclInit):
                # declaration inside the init clause: no indent and no ";\n"
                init_text = (
                    str(init.type_name)
                    + " "
                    + gen(init.var_name)
                    + "="
                    + gen(init.init_exp)
                )
            else:
                init_text = gen(init)
            test_text = gen(tnode.test) if tnode.test else ""
            iter_text = gen(tnode.iter) if tnode.iter else ""
            return (
                head
                + init_text
                + "; "
                + test_text
                + "; "
                + iter_text
                + ") "
                + body_of(tnode.stmt)
            )

        if isinstance(tnode, ast.AssignStmt):
            return label_of(tnode) + indent + tnode.var + "=" + gen(tnode.exp) + ";\n"

        if isinstance(tnode, ast.TransformStmt):
            g.err(
                "orio.module.loop.codegen_cuda internal error: a transformation statement is never generated as an output"
            )
            return ""

        if isinstance(tnode, ast.VarDecl):
            prefix = indent + str(tnode.type_name) + " "
            names = tnode.var_names
            if isinstance(names[0], ast.IdentExp):
                joined = ", ".join(map(self.generate, names))
            else:
                joined = ", ".join(names)
            return prefix + joined + ";\n"

        if isinstance(tnode, ast.VarDeclInit):
            return (
                indent
                + str(tnode.type_name)
                + " "
                + gen(tnode.var_name)
                + "="
                + gen(tnode.init_exp)
                + ";\n"
            )

        if isinstance(tnode, ast.FieldDecl):
            name = tnode.name
            return tnode.ty + " " + (name.name if isinstance(name, ast.IdentExp) else name)

        if isinstance(tnode, ast.FunDecl):
            return (
                indent
                + " ".join(tnode.modifiers)
                + " "
                + tnode.return_type
                + " "
                + tnode.name
                + "("
                + ", ".join(map(self.generate, tnode.params))
                + ") "
                + gen(tnode.body)
            )

        if isinstance(tnode, ast.Pragma):
            return indent + "#pragma " + str(tnode.pstring) + "\n"

        if isinstance(tnode, ast.Container):
            return gen(tnode.ast)

        if isinstance(tnode, ast.WhileStmt):
            return indent + "while (" + gen(tnode.test) + ") " + body_of(tnode.stmt)

        if isinstance(tnode, ast.CastExpr):
            return "(" + tnode.ctype + ")" + gen(tnode.expr)

        g.err(
            "orio.module.loop.codegen_cuda internal error: unrecognized type of AST: %s"
            % tnode.__class__.__name__
        )
        return ""
示例#29
0
    def replaceIdent(self, tnode, iname_from, iname_to):
        '''Rename identifiers in place: every IdentExp named iname_from becomes iname_to.

        Child links of the visited nodes are re-assigned with the rewritten children
        and the same tnode object is returned.  TransformStmt nodes and node types not
        handled below are reported through g.err (None is returned for them).
        '''
        node_types = orio.module.loop.ast

        def rewrite(child):
            # apply the same renaming to a child node
            return self.replaceIdent(child, iname_from, iname_to)

        # literals contain no identifiers
        if isinstance(tnode, (node_types.NumLitExp, node_types.StringLitExp)):
            return tnode

        if isinstance(tnode, node_types.IdentExp):
            if tnode.name == iname_from:
                tnode.name = iname_to
            return tnode

        if isinstance(tnode, node_types.ArrayRefExp):
            tnode.exp = rewrite(tnode.exp)
            tnode.sub_exp = rewrite(tnode.sub_exp)
            return tnode

        if isinstance(tnode, node_types.FunCallExp):
            tnode.exp = rewrite(tnode.exp)
            tnode.args = [rewrite(arg) for arg in tnode.args]
            return tnode

        if isinstance(tnode, node_types.UnaryExp):
            tnode.exp = rewrite(tnode.exp)
            return tnode

        if isinstance(tnode, node_types.BinOpExp):
            tnode.lhs = rewrite(tnode.lhs)
            tnode.rhs = rewrite(tnode.rhs)
            return tnode

        if isinstance(tnode, node_types.ParenthExp):
            tnode.exp = rewrite(tnode.exp)
            return tnode

        if isinstance(tnode, node_types.ExpStmt):
            if tnode.exp:
                tnode.exp = rewrite(tnode.exp)
            return tnode

        if isinstance(tnode, node_types.CompStmt):
            tnode.stmts = [rewrite(stmt) for stmt in tnode.stmts]
            return tnode

        if isinstance(tnode, node_types.IfStmt):
            tnode.test = rewrite(tnode.test)
            tnode.true_stmt = rewrite(tnode.true_stmt)
            if tnode.false_stmt:
                tnode.false_stmt = rewrite(tnode.false_stmt)
            return tnode

        if isinstance(tnode, node_types.ForStmt):
            # the three loop-header clauses are optional
            for field in ('init', 'test', 'iter'):
                part = getattr(tnode, field)
                if part:
                    setattr(tnode, field, rewrite(part))
            tnode.stmt = rewrite(tnode.stmt)
            return tnode

        if isinstance(tnode, node_types.TransformStmt):
            g.err('orio.module.loop.ast_lib.common_lib internal error:  unexpected AST type: "%s"' % tnode.__class__.__name__)
            return None

        if isinstance(tnode, (node_types.NewAST, node_types.Comment)):
            return tnode

        g.err('orio.module.loop.ast_lib.common_lib internal error:  unexpected AST type: "%s"' % tnode.__class__.__name__)
示例#30
0
文件: common_lib.py 项目: phrb/Orio
    def collectNode(self, f, n):
        """ Collect within the given node a list using the given function: pre-order traversal.

        Args:
            f: callable applied to visited AST nodes; it must return a list.  A node's
                own list comes first, followed by the lists of its children.
            n: the AST node to traverse.

        Returns:
            The concatenation of the lists produced by f.

        Note: IfStmt and ForStmt nodes are not passed to f themselves; only their
        children are visited.  Unsupported node types are reported through g.err.
        """

        if isinstance(n, orio.module.loop.ast.NumLitExp):
            return f(n)

        elif isinstance(n, orio.module.loop.ast.StringLitExp):
            return f(n)

        elif isinstance(n, orio.module.loop.ast.IdentExp):
            return f(n)

        elif isinstance(n, orio.module.loop.ast.ArrayRefExp):
            return f(n) + self.collectNode(f, n.exp) + self.collectNode(f, n.sub_exp)

        elif isinstance(n, orio.module.loop.ast.FunCallExp):
            return reduce(
                lambda x, y: x + y, [self.collectNode(f, a) for a in n.args], f(n)
            )

        elif isinstance(n, orio.module.loop.ast.CastExpr):
            return f(n) + self.collectNode(f, n.expr)

        elif isinstance(n, orio.module.loop.ast.UnaryExp):
            return f(n) + self.collectNode(f, n.exp)

        elif isinstance(n, orio.module.loop.ast.BinOpExp):
            return f(n) + self.collectNode(f, n.lhs) + self.collectNode(f, n.rhs)

        elif isinstance(n, orio.module.loop.ast.ParenthExp):
            return f(n) + self.collectNode(f, n.exp)

        elif isinstance(n, orio.module.loop.ast.Comment):
            # n.text is the raw comment string (codegen emits "/*" + text + "*/"),
            # not an AST node; recursing into it would only reach the error below
            return f(n)

        elif isinstance(n, orio.module.loop.ast.ExpStmt):
            # the expression is optional (empty statement); descend only when present
            result = f(n)
            if n.exp:
                result = result + self.collectNode(f, n.exp)
            return result

        elif isinstance(n, orio.module.loop.ast.GotoStmt):
            # n.target is the label name (codegen emits "goto " + target), not a node
            return f(n)
        elif isinstance(n, orio.module.loop.ast.VarDecl):
            return f(n)
        elif isinstance(n, orio.module.loop.ast.VarDeclInit):
            return f(n)
        elif isinstance(n, orio.module.loop.ast.CompStmt):
            return reduce(
                lambda x, y: x + y, [self.collectNode(f, a) for a in n.stmts], f(n)
            )

        elif isinstance(n, orio.module.loop.ast.IfStmt):
            result = self.collectNode(f, n.test) + self.collectNode(f, n.true_stmt)
            if n.false_stmt:
                result += self.collectNode(f, n.false_stmt)
            return result

        elif isinstance(n, orio.module.loop.ast.ForStmt):
            result = []
            if n.init:
                result += self.collectNode(f, n.init)
            if n.test:
                result += self.collectNode(f, n.test)
            if n.iter:
                result += self.collectNode(f, n.iter)
            result += self.collectNode(f, n.stmt)
            return result

        elif isinstance(n, orio.module.loop.ast.AssignStmt):
            return f(n) + self.collectNode(f, n.var) + self.collectNode(f, n.exp)

        elif isinstance(n, orio.module.loop.ast.TransformStmt):
            return (
                f(n)
                + self.collectNode(f, n.name)
                + self.collectNode(f, n.args)
                + self.collectNode(f, n.stmt)
            )

        else:
            g.err(
                'orio.module.loop.ast_lib.common_lib.collectNode: unexpected AST type: "%s"'
                % n.__class__.__name__
            )
示例#31
0
 def t_error(self, t):
     '''Report a character the loops lexer cannot tokenize, through g.err.'''
     template = "orio.module.loops.lexer: illegal character (%s) at line %s"
     g.err(template % (t.value[0], t.lexer.lineno))
示例#32
0
文件: pparser.py 项目: nchaimov/Orio
def p_error(p):
    '''Parser error hook: report a syntax error through g.err.

    :param p: the offending token (with ``lineno``, ``type`` and ``value``
        attributes), or None. Under PLY's ``p_error`` convention, None is
        passed when the error occurs at the end of the input; dereferencing
        it would raise AttributeError and hide the real diagnostic.
    '''
    if p is None:
        g.err("orio.main.tspec.pparser.parser: error in input: unexpected end of input")
    else:
        g.err("orio.main.tspec.pparser.parser: error in input line #%s, at token-type '%s', token-value '%s'"
              % (p.lineno, p.type, p.value))
示例#33
0
def p_error(p):
    '''Parser error hook: report a syntax error through g.err.

    :param p: the offending token (with ``lineno``, ``type`` and ``value``
        attributes), or None. Under PLY's ``p_error`` convention, None is
        passed when the error occurs at the end of the input; dereferencing
        it would raise AttributeError and hide the real diagnostic.
    '''
    if p is None:
        g.err("orio.module.splingo.parser: error in input: unexpected end of input")
    else:
        g.err("orio.module.splingo.parser: error in input line #%s, at token-type '%s', token-value '%s'" % (p.lineno, p.type, p.value))
示例#34
0
    def generate(self, tnode, indent='  ', extra_indent='  '):
        '''To generate code that corresponds to the given AST.

        Recursively renders ``tnode`` and its children as C source text.

        :param tnode: AST node to render.
        :param indent: string prefixed to each generated statement line.
        :param extra_indent: string appended to ``indent`` for each nested
            statement level (bodies of compound/if/for/while statements).
        :returns: the generated code as a ``str``.

        Unknown unary/binary operators, unrecognized node types and
        ``TransformStmt`` nodes are reported through ``g.err``.
        '''

        s = ''

        if isinstance(tnode, ast.Comment):
            s += indent + '/*' + tnode.text + '*/\n'

        elif isinstance(tnode, ast.LitExp):
            # NOTE: the 'string-escape' codec exists only in Python 2; it
            # escapes the text the way a Python 2 string literal would be
            # written (backslashes, quotes, non-printable characters).
            s += str(tnode.val).encode('string-escape')

        elif isinstance(tnode, ast.IdentExp):
            s += str(tnode.name)

        elif isinstance(tnode, ast.ArrayRefExp):
            s += self.generate(tnode.exp, indent, extra_indent)
            s += '[' + self.generate(tnode.sub, indent, extra_indent) + ']'

        elif isinstance(tnode, ast.CallExp):
            s += self.generate(tnode.exp, indent, extra_indent) + '('
            s += ','.join(
                map(lambda x: self.generate(x, indent, extra_indent),
                    tnode.args))
            s += ')'

        elif isinstance(tnode, ast.CastExp):
            s += '(' + self.generate(tnode.castto, indent, extra_indent) + ')'
            s += self.generate(tnode.exp, indent, extra_indent)

        elif isinstance(tnode, ast.UnaryExp):
            # Render the operand first (s is still empty here), then attach
            # the operator on the proper side.
            s += self.generate(tnode.exp, indent, extra_indent)
            if tnode.op_type == tnode.PLUS:
                s = '+' + s
            elif tnode.op_type == tnode.MINUS:
                s = '-' + s
            elif tnode.op_type == tnode.LNOT:
                s = '!' + s
            elif tnode.op_type == tnode.BNOT:
                s = '~' + s
            # The padding space around ++/-- presumably keeps them from fusing
            # with a neighbouring '+'/'-' token (e.g. "a+ ++b").
            elif tnode.op_type == tnode.PRE_INC:
                s = ' ++' + s
            elif tnode.op_type == tnode.PRE_DEC:
                s = ' --' + s
            elif tnode.op_type == tnode.POST_INC:
                s += '++ '
            elif tnode.op_type == tnode.POST_DEC:
                s += '-- '
            elif tnode.op_type == tnode.DEREF:
                s = '*' + s
            elif tnode.op_type == tnode.ADDRESSOF:
                s = '&' + s
            elif tnode.op_type == tnode.SIZEOF:
                s = 'sizeof ' + s
            else:
                g.err(__name__ +
                      ': internal error: unknown unary operator type: %s' %
                      tnode.op_type)

        elif isinstance(tnode, ast.BinOpExp):
            # lhs, operator, rhs with no spaces and no added parentheses;
            # grouping presumably comes from explicit ParenExp nodes.
            s += self.generate(tnode.lhs, indent, extra_indent)
            if tnode.op_type == tnode.PLUS:
                s += '+'
            elif tnode.op_type == tnode.MINUS:
                s += '-'
            elif tnode.op_type == tnode.MULT:
                s += '*'
            elif tnode.op_type == tnode.DIV:
                s += '/'
            elif tnode.op_type == tnode.MOD:
                s += '%'
            elif tnode.op_type == tnode.LT:
                s += '<'
            elif tnode.op_type == tnode.GT:
                s += '>'
            elif tnode.op_type == tnode.LE:
                s += '<='
            elif tnode.op_type == tnode.GE:
                s += '>='
            elif tnode.op_type == tnode.EE:
                s += '=='
            elif tnode.op_type == tnode.NE:
                s += '!='
            elif tnode.op_type == tnode.LOR:
                s += '||'
            elif tnode.op_type == tnode.LAND:
                s += '&&'
            elif tnode.op_type == tnode.EQ:
                # in this AST, EQ is assignment; equality comparison is EE
                s += '='
            elif tnode.op_type == tnode.PLUSEQ:
                s += '+='
            elif tnode.op_type == tnode.MINUSEQ:
                s += '-='
            elif tnode.op_type == tnode.MULTEQ:
                s += '*='
            elif tnode.op_type == tnode.DIVEQ:
                s += '/='
            elif tnode.op_type == tnode.MODEQ:
                s += '%='
            elif tnode.op_type == tnode.COMMA:
                s += ','
            elif tnode.op_type == tnode.BOR:
                s += '|'
            elif tnode.op_type == tnode.BAND:
                s += '&'
            elif tnode.op_type == tnode.BXOR:
                s += '^'
            elif tnode.op_type == tnode.BSHL:
                s += '<<'
            elif tnode.op_type == tnode.BSHR:
                s += '>>'
            elif tnode.op_type == tnode.BSHLEQ:
                s += '<<='
            elif tnode.op_type == tnode.BSHREQ:
                s += '>>='
            elif tnode.op_type == tnode.BANDEQ:
                s += '&='
            elif tnode.op_type == tnode.BXOREQ:
                s += '^='
            elif tnode.op_type == tnode.BOREQ:
                s += '|='
            elif tnode.op_type == tnode.DOT:
                s += '.'
            elif tnode.op_type == tnode.SELECT:
                s += '->'
            else:
                g.err(__name__ +
                      ': internal error: unknown binary operator type: %s' %
                      tnode.op_type)
            s += self.generate(tnode.rhs, indent, extra_indent)

        elif isinstance(tnode, ast.TernaryExp):
            s += self.generate(tnode.test, indent, extra_indent) + '?'
            s += self.generate(tnode.true_exp, indent, extra_indent) + ':'
            s += self.generate(tnode.false_exp, indent, extra_indent)

        elif isinstance(tnode, ast.ParenExp):
            s += '(' + self.generate(tnode.exp, indent, extra_indent) + ')'

        elif isinstance(tnode, ast.CompStmt):
            # children are rendered one nesting level deeper than the braces
            s += indent + '{\n'
            for stmt in tnode.kids:
                s += self.generate(stmt, indent + extra_indent, extra_indent)
            s += indent + '}\n'

        elif isinstance(tnode, ast.ExpStmt):
            s += indent + self.generate(tnode.exp, indent,
                                        extra_indent) + ';\n'

        elif isinstance(tnode, ast.IfStmt):
            s += indent + 'if (' + self.generate(tnode.test, indent,
                                                 extra_indent) + ') '
            if isinstance(tnode.true_stmt, ast.CompStmt):
                # Slice from '{' to drop the compound statement's own indent,
                # so the brace sits on the 'if' line.
                tstmt_s = self.generate(tnode.true_stmt, indent, extra_indent)
                s += tstmt_s[tstmt_s.index('{'):]
                if tnode.false_stmt:
                    # drop the newline after '}' so 'else' follows on the same line
                    s = s[:-1] + ' else '
            else:
                s += '\n'
                s += self.generate(tnode.true_stmt, indent + extra_indent,
                                   extra_indent)
                if tnode.false_stmt:
                    s += indent + 'else '
            if tnode.false_stmt:
                if isinstance(tnode.false_stmt, ast.CompStmt):
                    tstmt_s = self.generate(tnode.false_stmt, indent,
                                            extra_indent)
                    s += tstmt_s[tstmt_s.index('{'):]
                else:
                    s += '\n'
                    s += self.generate(tnode.false_stmt, indent + extra_indent,
                                       extra_indent)

        elif isinstance(tnode, ast.ForStmt):
            # init/test/iter are optional; a missing part leaves an empty slot
            s += indent + 'for ('
            if tnode.init:
                s += self.generate(tnode.init, indent, extra_indent)
            s += '; '
            if tnode.test:
                s += self.generate(tnode.test, indent, extra_indent)
            s += '; '
            if tnode.iter:
                s += self.generate(tnode.iter, indent, extra_indent)
            s += ') '
            if isinstance(tnode.stmt, ast.CompStmt):
                stmt_s = self.generate(tnode.stmt, indent, extra_indent)
                s += stmt_s[stmt_s.index('{'):]
            else:
                s += '\n'
                s += self.generate(tnode.stmt, indent + extra_indent,
                                   extra_indent)

        elif isinstance(tnode, ast.WhileStmt):
            s += indent + 'while (' + self.generate(tnode.test, indent,
                                                    extra_indent) + ') '
            if isinstance(tnode.stmt, ast.CompStmt):
                stmt_s = self.generate(tnode.stmt, indent, extra_indent)
                s += stmt_s[stmt_s.index('{'):]
            else:
                s += '\n'
                s += self.generate(tnode.stmt, indent + extra_indent,
                                   extra_indent)

        elif isinstance(tnode, ast.VarDec):
            # Nested declarations (presumably e.g. inside a for-init) are
            # emitted without leading indent and without the trailing ';'.
            if not tnode.isnested:
                s += indent
            s += ' '.join(tnode.type_name) + ' '
            s += ', '.join(
                map(lambda x: self.generate(x, indent, extra_indent),
                    tnode.var_inits))
            if not tnode.isnested:
                s += ';\n'

        elif isinstance(tnode, ast.ParamDec):
            # NOTE(review): `indent` is prefixed here even though parameters
            # are rendered inside the FunDec argument list, which adds
            # leading spaces to each parameter.
            s += indent + str(tnode.ty) + ' ' + str(tnode.name)

        elif isinstance(tnode, ast.FunDec):
            # NOTE(review): str(tnode.modifiers) is concatenated directly with
            # the function name with no separator -- confirm the type/format
            # of `modifiers` (a non-empty value would fuse with the name).
            s += indent + str(tnode.return_type) + ' ' + str(tnode.modifiers)
            s += tnode.name + '('
            s += ', '.join(
                map(lambda x: self.generate(x, indent, extra_indent),
                    tnode.params))
            s += ')' + self.generate(tnode.body, indent, extra_indent)

        elif isinstance(tnode, ast.Pragma):
            s += indent + '#pragma ' + str(tnode.pstring) + '\n'

        elif isinstance(tnode, ast.TransformStmt):
            # transformation statements must be expanded before code generation
            g.err(
                __name__ +
                ': internal error: a transformation statement is never generated as an output'
            )

        else:
            g.err(__name__ + ': internal error: unrecognized type of AST: %s' %
                  tnode.__class__.__name__)

        return s
示例#35
0
    def getDeviceProps(self):
        '''Get device properties.

        On the first call a small CUDA query program (CUDA_DEVICE_QUERY_SKELET)
        is written to ``enum_cuda_props.cu`` in the current directory,
        compiled with nvcc and executed. The program is expected to write one
        Python-literal ``(name, value)`` pair per line to
        ``enum_cuda_props.cu.o.props``. Later calls reuse that file.

        Side effects: when ``self.tinfo`` is not None, its ``build_cmd`` is
        rewritten as follows:
          - a gcc command becomes nvcc;
          - ``-arch=sm_<major><minor>`` is appended if no -arch flag is present;
          - ``@CFLAGS`` is appended when ``self.perf_params`` contains 'CFLAGS'.

        :returns: dict mapping property name -> value read from the props file.

        Failures (missing nvcc, I/O errors, failed compile/run, malformed
        output, no usable GPU) are reported through ``g.err``.
        '''
        qsrc  = "enum_cuda_props.cu"
        qexec = qsrc + ".o"
        qout  = qexec + ".props"

        if not os.path.exists(qout):
            # check for nvcc
            qcmd = 'which nvcc'
            status = os.system(qcmd)
            if status != 0:
                g.err("%s: could not locate nvcc with '%s'" % (self.__class__, qcmd))

            # write the query code
            try:
                with open(qsrc, 'w') as f:
                    f.write(CUDA_DEVICE_QUERY_SKELET)
            except (IOError, OSError):
                g.err('%s: cannot open file for writing: %s' % (self.__class__, qsrc))

            # compile the query
            cmd = 'nvcc -o %s %s' % (qexec, qsrc)
            status = os.system(cmd)
            if status:
                g.err('%s: failed to compile cuda device query code: "%s"' % (self.__class__, cmd))

            # execute the query; it is expected to produce `qout`
            runcmd = './%s' % (qexec)
            status = os.system(runcmd)
            if status:
                g.err('%s: failed to execute cuda device query code: "%s"' % (self.__class__, runcmd))
            os.remove(qsrc)
            os.remove(qexec)

        # read device properties: one literal (name, value) pair per line
        props = {}
        try:
            with open(qout, 'r') as f:
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank lines, e.g. a trailing newline
                    eline = ast.literal_eval(line)
                    props[eline[0]] = eline[1]
        except (IOError, OSError):
            g.err('%s: cannot open query output file for reading: %s' % (self.__class__, qout))
        except (ValueError, SyntaxError, TypeError, IndexError, KeyError):
            g.err('%s: malformed query output file: %s' % (self.__class__, qout))

        if props['devId'] == -2:
            g.err("%s: there is no CUDA 1.0 enabled GPU on this machine" % self.__class__)

        # warn for compute capability (major.minor) below 1.3
        if (props['major'], props['minor']) < (1, 3):
            g.warn("%s: running on compute capability less than 1.3 is not recommended, detected %s.%s." % (self.__class__, props['major'], props['minor']))

        # set the arch to the latest supported by the device
        if self.tinfo is None or self.tinfo.build_cmd is None:
            bcmd = "gcc"
        else:
            bcmd = self.tinfo.build_cmd

        if bcmd.startswith('gcc'):
            bcmd = 'nvcc'
        if '-arch' not in bcmd:
            bcmd += ' -arch=sm_' + str(props['major']) + str(props['minor'])
        if 'CFLAGS' in self.perf_params and '@CFLAGS' not in bcmd:
            bcmd += ' @CFLAGS'
        # Only store the command when there is a target-info object to store
        # it on; the original crashed with AttributeError when tinfo was None.
        if self.tinfo is not None:
            self.tinfo.build_cmd = bcmd

        # return queried device props
        return props
示例#36
0
    def generate(self, tnode, indent = '  ', extra_indent = '  '):
        '''To generate code that corresponds to the given AST.

        Recursively renders ``tnode`` and its children as C-like source text
        (the error messages identify this as the OpenCL code generator).

        :param tnode: AST node to render.
        :param indent: string prefixed to each generated statement line.
        :param extra_indent: string appended to ``indent`` for each nested
            statement level.
        :returns: the generated code as a ``str``.

        Unknown operators, unrecognized node types and ``TransformStmt``
        nodes are reported through ``g.err``.
        '''

        s = ''

        if isinstance(tnode, ast.NumLitExp):
            s += str(tnode.val)

        elif isinstance(tnode, ast.StringLitExp):
            # NOTE(review): the value is wrapped in double quotes without any
            # escaping; assumes tnode.val contains no '"' or backslashes.
            s += '"' + str(tnode.val) + '"'

        elif isinstance(tnode, ast.IdentExp):
            s += str(tnode.name)

        elif isinstance(tnode, ast.ArrayRefExp):
            s += self.generate(tnode.exp, indent, extra_indent)
            s += '[' + self.generate(tnode.sub_exp, indent, extra_indent) + ']'

        elif isinstance(tnode, ast.FunCallExp):
            s += self.generate(tnode.exp, indent, extra_indent) + '('
            s += ','.join(map(lambda x: self.generate(x, indent, extra_indent), tnode.args))
            s += ')'

        elif isinstance(tnode, ast.UnaryExp):
            # render the operand, then attach the operator on the proper side
            s = self.generate(tnode.exp, indent, extra_indent)
            if tnode.op_type == tnode.PLUS:
                s = '+' + s
            elif tnode.op_type == tnode.MINUS:
                s = '-' + s
            elif tnode.op_type == tnode.LNOT:
                s = '!' + s
            elif tnode.op_type == tnode.PRE_INC:
                s = ' ++' + s
            elif tnode.op_type == tnode.PRE_DEC:
                s = ' --' + s
            elif tnode.op_type == tnode.POST_INC:
                s = s + '++ '
            elif tnode.op_type == tnode.POST_DEC:
                s = s + '-- '
            elif tnode.op_type == tnode.DEREF:
                s = '*' + s
            elif tnode.op_type == tnode.ADDRESSOF:
                s = '&' + s
            else:
                g.err('orio.module.loop.codegen_opencl internal error: unknown unary operator type: %s' % tnode.op_type)

        elif isinstance(tnode, ast.BinOpExp):
            # lhs, operator, rhs with no spaces and no added parentheses
            s += self.generate(tnode.lhs, indent, extra_indent)
            if tnode.op_type == tnode.MUL:
                s += '*'
            elif tnode.op_type == tnode.DIV:
                s += '/'
            elif tnode.op_type == tnode.MOD:
                s += '%'
            elif tnode.op_type == tnode.ADD:
                s += '+'
            elif tnode.op_type == tnode.SUB:
                s += '-'
            elif tnode.op_type == tnode.LT:
                s += '<'
            elif tnode.op_type == tnode.GT:
                s += '>'
            elif tnode.op_type == tnode.LE:
                s += '<='
            elif tnode.op_type == tnode.GE:
                s += '>='
            elif tnode.op_type == tnode.EQ:
                # in this AST, EQ is equality comparison; assignment is EQ_ASGN
                s += '=='
            elif tnode.op_type == tnode.NE:
                s += '!='
            elif tnode.op_type == tnode.LOR:
                s += '||'
            elif tnode.op_type == tnode.LAND:
                s += '&&'
            elif tnode.op_type == tnode.COMMA:
                s += ','
            elif tnode.op_type == tnode.EQ_ASGN:
                s += '='
            elif tnode.op_type == tnode.ASGN_ADD:
                s += '+='
            elif tnode.op_type == tnode.ASGN_SHR:
                s += '>>='
            elif tnode.op_type == tnode.ASGN_SHL:
                s += '<<='
            elif tnode.op_type == tnode.BAND:
                s += '&'
            elif tnode.op_type == tnode.SHR:
                s += '>>'
            elif tnode.op_type == tnode.BOR:
                s += '|'
            else:
                g.err('orio.module.loop.codegen_opencl internal error: unknown binary operator type: %s' % tnode.op_type)
            s += self.generate(tnode.rhs, indent, extra_indent)

        elif isinstance(tnode, ast.TernaryExp):
            s += self.generate(tnode.test, indent, extra_indent) + '?'
            s += self.generate(tnode.true_expr,  indent, extra_indent) + ':'
            s += self.generate(tnode.false_expr, indent, extra_indent)

        elif isinstance(tnode, ast.ParenthExp):
            s += '(' + self.generate(tnode.exp, indent, extra_indent) + ')'

        elif isinstance(tnode, ast.Comment):
            # an empty comment still yields an indented blank line
            s += indent
            if tnode.text:
                s += '/*' + tnode.text + '*/'
            s += '\n'
            
        elif isinstance(tnode, ast.ExpStmt):
            # an optional label precedes the indentation; a missing exp
            # produces an empty statement ';'
            if tnode.getLabel(): s += tnode.getLabel() + ':'
            s += indent
            if tnode.exp:
                s += self.generate(tnode.exp, indent, extra_indent)
            s += ';\n'

        elif isinstance(tnode, ast.GotoStmt):
            # NOTE(review): with an empty target only the label/indent is
            # emitted, with no terminating newline.
            if tnode.getLabel(): s += tnode.getLabel() + ':'
            s += indent
            if tnode.target:
                s += 'goto ' + tnode.target + ';\n'
                
        elif isinstance(tnode, ast.CompStmt):
            # children are rendered one nesting level deeper than the braces
            s += indent + '{\n'
            for stmt in tnode.stmts:
                s += self.generate(stmt, indent + extra_indent, extra_indent)
            s += indent + '}\n'

        elif isinstance(tnode, ast.IfStmt):
            if tnode.getLabel(): s += tnode.getLabel() + ':'
            s += indent + 'if (' + self.generate(tnode.test, indent, extra_indent) + ') '
            if isinstance(tnode.true_stmt, ast.CompStmt):
                # Slice from '{' to drop the compound statement's own indent,
                # so the brace sits on the 'if' line.
                tstmt_s = self.generate(tnode.true_stmt, indent, extra_indent)
                s += tstmt_s[tstmt_s.index('{'):]
                if tnode.false_stmt:
                    # drop the newline after '}' so 'else' follows on the same line
                    s = s[:-1] + ' else '
            else:
                s += '\n'
                s += self.generate(tnode.true_stmt, indent + extra_indent, extra_indent)
                if tnode.false_stmt:
                    s += indent + 'else '
            if tnode.false_stmt:
                if isinstance(tnode.false_stmt, ast.CompStmt):
                    tstmt_s = self.generate(tnode.false_stmt, indent, extra_indent)
                    s += tstmt_s[tstmt_s.index('{'):]
                else:
                    s += '\n'
                    s += self.generate(tnode.false_stmt, indent + extra_indent, extra_indent)

        elif isinstance(tnode, ast.ForStmt):
            if tnode.getLabel(): s += tnode.getLabel() + ':'
            s += indent + 'for ('
            if tnode.init:
                # A VarDeclInit initializer is rendered inline here because its
                # statement form (below) would add indent and a trailing ';\n'.
                if isinstance(tnode.init, ast.VarDeclInit):
                  s += str(tnode.init.type_name) + ' '
                  s += self.generate(tnode.init.var_name, indent, extra_indent)
                  s += '=' + self.generate(tnode.init.init_exp, indent, extra_indent)
                else:
                  s += self.generate(tnode.init, indent, extra_indent)
            s += '; '
            if tnode.test:
                s += self.generate(tnode.test, indent, extra_indent)
            s += '; '
            if tnode.iter:
                s += self.generate(tnode.iter, indent, extra_indent)
            s += ') '
            if isinstance(tnode.stmt, ast.CompStmt): 
                stmt_s = self.generate(tnode.stmt, indent, extra_indent)
                s += stmt_s[stmt_s.index('{'):]
            else:
                s += '\n'
                s += self.generate(tnode.stmt, indent + extra_indent, extra_indent)

        elif isinstance(tnode, ast.AssignStmt):
            # NOTE(review): tnode.var is concatenated directly, so it is
            # assumed to be a plain string rather than an AST node.
            if tnode.getLabel(): s += tnode.getLabel() + ':'
            s += indent + tnode.var + '='
            s += self.generate(tnode.exp, indent, extra_indent)
            s += ';\n'
            
        elif isinstance(tnode, ast.TransformStmt):
            # transformation statements must be expanded before code generation
            g.err('orio.module.loop.codegen_opencl internal error: a transformation statement is never generated as an output')

        elif isinstance(tnode, ast.VarDecl):
            # var_names holds either IdentExp nodes (rendered) or plain strings;
            # only the first element is inspected, so the list is presumably
            # homogeneous.
            s += indent + str(tnode.type_name) + ' '
            if isinstance(tnode.var_names[0], ast.IdentExp): 
              s += ', '.join(map(self.generate, tnode.var_names))
            else:
              s += ', '.join(tnode.var_names)
            s += ';\n'

        elif isinstance(tnode, ast.VarDeclInit):
            s += indent + str(tnode.type_name) + ' '
            s += self.generate(tnode.var_name, indent, extra_indent)
            s += '=' + self.generate(tnode.init_exp, indent, extra_indent)
            s += ';\n'

        elif isinstance(tnode, ast.FieldDecl):
            # a single "type name" pair with no indentation and no terminator
            s += tnode.ty + ' '
            if isinstance(tnode.name, ast.IdentExp):
              s+= tnode.name.name
            else:
              s += tnode.name

        elif isinstance(tnode, ast.FunDecl):
            s += indent + ' '.join(tnode.modifiers) + ' '
            s += tnode.return_type + ' '
            s += tnode.name + '('
            s += ', '.join(map(self.generate, tnode.params)) + ') '
            s += self.generate(tnode.body, indent, extra_indent)

        elif isinstance(tnode, ast.Pragma):
            s += indent + '#pragma ' + str(tnode.pstring) + '\n'

        elif isinstance(tnode, ast.Container):
            # wrapper node: render the wrapped AST
            s += self.generate(tnode.ast, indent, extra_indent)

        elif isinstance(tnode, ast.WhileStmt):
            s += indent + 'while (' + self.generate(tnode.test, indent, extra_indent)
            s += ') '
            if isinstance(tnode.stmt, ast.CompStmt): 
                stmt_s = self.generate(tnode.stmt, indent, extra_indent)
                s += stmt_s[stmt_s.index('{'):]
            else:
                s += '\n'
                s += self.generate(tnode.stmt, indent + extra_indent, extra_indent)

        elif isinstance(tnode, ast.CastExpr):
            # ctype is concatenated directly (assumed to be a type string)
            s += '(' + tnode.ctype + ')'
            s += self.generate(tnode.expr, indent, extra_indent)
        
        else:
            g.err('orio.module.loop.codegen_opencl internal error: unrecognized type of AST: %s' % tnode.__class__.__name__)

        return s
示例#37
0
    def readTransfArgs(self, perf_params, transf_args):
        '''Process the given transformation arguments.

        Evaluates and validates the CUDA transformation arguments against the
        device properties stored in ``self.props`` (warp size, SM count,
        maxThreadsPerBlock, maxGridSize, deviceOverlap, concurrentKernels,
        compute capability major/minor).

        :param perf_params: dict used as the globals when each argument's
            right-hand side is evaluated with ``eval``.
        :param transf_args: iterable of ``(aname, rhs, line_no)`` triples.

        An RHS that fails to evaluate, or an unrecognized argument name, is
        reported immediately through ``g.err``. All other validation problems
        are accumulated, one message per line, and reported together through
        ``g.err`` after every argument has been examined.

        NOTE(review): the validated settings are only bound to local
        variables; this block neither stores nor returns them. Confirm
        against callers whether a return statement is missing.
        '''

        # expected argument names
        THREADCOUNT = 'threadCount'
        BLOCKCOUNT  = 'blockCount'
        CB          = 'cacheBlocks'
        PHM         = 'pinHostMem'
        STREAMCOUNT = 'streamCount'
        DOMAIN      = 'domain'
        DOD         = 'dataOnDevice'
        UIF         = 'unrollInner'
        PREFERL1SZ  = 'preferL1Size'

        # default argument values (replaced by valid user-supplied values)
        szwarp  = self.props['warpSize']
        smcount = self.props['multiProcessorCount']
        threadCount  = szwarp
        blockCount   = smcount
        cacheBlocks  = False
        pinHost      = False
        streamCount  = 1
        domain       = None
        dataOnDevice = False
        unrollInner  = None
        preferL1Size = 0

        # iterate over all transformation arguments
        errors = ''
        for aname, rhs, line_no in transf_args:

            # Evaluate the RHS expression.
            # NOTE: eval() executes arbitrary code -- rhs and perf_params must
            # come from a trusted tuning specification.
            try:
                rhs = eval(rhs, perf_params)
            except Exception as e:
                g.err('orio.module.loop.submodule.cuda.cuda: %s: failed to evaluate the argument expression: %s\n --> %s: %s' % (line_no, rhs,e.__class__.__name__, e))

            if aname == THREADCOUNT:
                if not isinstance(rhs, int) or rhs <= 0 or rhs > self.props['maxThreadsPerBlock']:
                    errors += 'line %s: threadCount must be a positive integer less than device limit of maxThreadsPerBlock of %s: %s\n' % (line_no, self.props['maxThreadsPerBlock'], rhs)
                elif rhs % szwarp != 0:
                    errors += 'line %s: threadCount is not a multiple of warp size of %s: %s\n' % (line_no, szwarp, rhs)
                else:
                    threadCount = rhs
            elif aname == BLOCKCOUNT:
                if not isinstance(rhs, int) or rhs <= 0 or rhs > self.props['maxGridSize'][0]:
                    errors += 'line %s: %s must be a positive integer less than device limit of maxGridSize[0]=%s: %s\n' % (line_no, aname, self.props['maxGridSize'][0], rhs)
                elif rhs % smcount != 0:
                    errors += 'line %s: blockCount is not a multiple of SM count of %s: %s\n' % (line_no, smcount, rhs)
                else:
                    blockCount = rhs
            elif aname == CB:
                if not isinstance(rhs, bool):
                    errors += 'line %s: %s must be a boolean: %s\n' % (line_no, aname, rhs)
                else:
                    cacheBlocks = rhs
            elif aname == PHM:
                if not isinstance(rhs, bool):
                    errors += 'line %s: %s must be a boolean: %s\n' % (line_no, aname, rhs)
                else:
                    pinHost = rhs
            elif aname == STREAMCOUNT:
                if not isinstance(rhs, int) or rhs <= 0:
                    errors += 'line %s: %s must be a positive integer: %s\n' % (line_no, aname, rhs)
                else:
                    # multiple streams need overlap of transfer/compute and
                    # concurrent kernels on the device
                    if rhs > 1:
                        overlap = self.props['deviceOverlap']
                        if overlap == 0:
                            errors += '%s=%s: deviceOverlap=%s, overlap of data transfer and kernel execution is not supported\n' % (aname, rhs, overlap)
                        concs = self.props['concurrentKernels']
                        if concs == 0:
                            errors += '%s=%s: device concurrentKernels=%s, concurrent kernel execution is not supported\n' % (aname, rhs, concs)
                    streamCount = rhs
            elif aname == DOMAIN:
                if not isinstance(rhs, str):
                    errors += 'line %s: %s must be a string: %s\n' % (line_no, aname, rhs)
                else:
                    domain = rhs
            elif aname == DOD:
                if not isinstance(rhs, bool):
                    errors += 'line %s: %s must be a boolean: %s\n' % (line_no, aname, rhs)
                else:
                    dataOnDevice = rhs
            elif aname == UIF:
                if not isinstance(rhs, int) or rhs <= 0:
                    errors += 'line %s: %s must be a positive integer: %s\n' % (line_no, aname, rhs)
                else:
                    unrollInner = rhs
            elif aname == PREFERL1SZ:
                if not isinstance(rhs, int) or rhs not in [16,32,48]:
                    errors += 'line %s: %s must be either 16, 32 or 48 KB: %s\n' % (line_no, aname, rhs)
                else:
                    major = self.props['major']
                    if major < 2:
                        errors += '%s=%s: L1 cache is not resizable on compute capability less than 2.x, current comp.cap.=%s.%s\n' % (aname, rhs, major, self.props['minor'])
                    elif major < 3 and rhs == 32:
                        errors += '%s=%s: L1 cache cannot be set to %s on compute capability less than 3.x, current comp.cap.=%s.%s\n' % (aname, rhs, rhs, major, self.props['minor'])
                    preferL1Size = rhs
            else:
                g.err('%s: %s: unrecognized transformation argument: "%s"' % (self.__class__, line_no, aname))

        # Report every accumulated validation problem at once. Previously
        # these messages were collected and then silently dropped.
        if errors:
            g.err('orio.module.loop.submodule.cuda.cuda: %s' % errors)
示例#38
0
    def generate(self, tnode, indent = '  ', extra_indent = '  ', doloop_inc = False):
        '''To generate code that corresponds to the given AST'''

        s = ''

        if isinstance(tnode, ast.NumLitExp):
            s += str(tnode.val)

        elif isinstance(tnode, ast.StringLitExp):
            s += str(tnode.val)

        elif isinstance(tnode, ast.IdentExp):
            s += str(tnode.name)

        elif isinstance(tnode, ast.ArrayRefExp):            
            # Now get all the indices
            tmpnode = tnode
            prevtmpnode = tnode
            indices = []
            while isinstance(tmpnode, ast.ArrayRefExp):
                indices.append(tmpnode.sub_exp)
                prevtmpnode = tmpnode
                tmpnode = tmpnode.exp
            
            indices.reverse()
            s += self.generate(prevtmpnode.exp, indent, extra_indent)  # the variable name
            s += '(' + ','.join([self.generate(x, indent, extra_indent) for x in indices]) + ')'

        elif isinstance(tnode, ast.FunCallExp):
            s += self.generate(tnode.exp, indent, extra_indent) + '('
            s += ','.join(map(lambda x: self.generate(x, indent, extra_indent), tnode.args))
            s += ')'

        elif isinstance(tnode, ast.UnaryExp):
            s = self.generate(tnode.exp, indent, extra_indent)
            if tnode.op_type == tnode.PLUS:
                s = '+' + s
            elif tnode.op_type == tnode.MINUS:
                
                s = '-' + s
            elif tnode.op_type == tnode.LNOT:
                s = 'NOT(' + s + ')'
            elif tnode.op_type == tnode.PRE_INC:
                s += '\n' + indent + s + ' = ' + s + ' + 1\n'
            elif tnode.op_type == tnode.PRE_DEC:
                s += '\n' + indent + s + ' = ' + s + ' - 1\n'
            elif tnode.op_type == tnode.POST_INC:
                s += s + '\n' + indent + s + ' = ' + s + ' + 1\n'
            elif tnode.op_type == tnode.POST_DEC:
                s += s + '\n' + indent + s + ' = ' + s + ' - 1\n'
            else:
                g.err('orio.module.loop.codegen internal error: unknown unary operator type: %s' % tnode.op_type)

        elif isinstance(tnode, ast.BinOpExp):
            if tnode.op_type not in [tnode.MOD, tnode.COMMA]:
                if not doloop_inc: 
                    s += self.generate(tnode.lhs, indent, extra_indent)
                if tnode.op_type == tnode.MUL:
                    s += '*'
                elif tnode.op_type == tnode.DIV:
                    s += '/'
                elif tnode.op_type == tnode.ADD:
                    s += '+'
                elif tnode.op_type == tnode.SUB:
                    s += '-'
                elif tnode.op_type == tnode.LT:
                    s += '<'
                elif tnode.op_type == tnode.GT:
                    s += '>'
                elif tnode.op_type == tnode.LE:
                    s += '<='
                elif tnode.op_type == tnode.GE:
                    s += '>='
                elif tnode.op_type == tnode.EQ:
                    s += '=='
                elif tnode.op_type == tnode.NE:
                    s += '!='
                elif tnode.op_type == tnode.LOR:
                    s += '.OR.'
                elif tnode.op_type == tnode.LAND:
                    s += '.AND.'
                elif tnode.op_type == tnode.EQ_ASGN:
                    s += '='
                else:
                    g.err('orio.module.loop.codegen internal error: unknown binary operator type: %s' % tnode.op_type)
                    
                s += self.generate(tnode.rhs, indent, extra_indent)
                
            else:
                
                if tnode.op_type == tnode.MOD:
                    s += 'MOD(' + self.generate(tnode.lhs, indent, extra_indent) + ', ' \
                        + self.generate(tnode.rhs, indent, extra_indent) + ')'
                elif tnode.op_type == tnode.COMMA:
                    # TODO: We need to implement an AST canonicalization step for Fortran before generating the code.
                    print 'internal warning: Fortran code generator does not fully support the comma operator -- the generated code may not compile.'
                    s += self.generate(tnode.rhs, indent, extra_indent) 
                    s += '\n' + indent + self.generate(tnode.lhs, indent, extra_indent)
                    s +='\n! ORIO Warining: check code above and fix problems.'

        elif isinstance(tnode, ast.ParenthExp):
            s += '(' + self.generate(tnode.exp, indent, extra_indent) + ')'

        elif isinstance(tnode, ast.Comment):
            s += indent
            if tnode.text:
                s += '!' + tnode.text 
            s += '\n'

        elif isinstance(tnode, ast.ExpStmt):
            if tnode.getLabel(): s += tnode.getLabel() + ' '
            s += indent
            if tnode.exp:
                s += self.generate(tnode.exp, indent, extra_indent)
            s += '\n'
            
        elif isinstance(tnode, ast.GotoStmt):
            if tnode.getLabel(): s += tnode.getLabel() + ' '
            s += indent
            if tnode.target:
                s += 'goto ' + tnode.target + '\n'

        elif isinstance(tnode, ast.CompStmt):
            s += indent + '\n'
            if tnode.getLabel(): s += tnode.getLabel() + ' '
            for stmt in tnode.stmts:
                s += self.generate(stmt, indent + extra_indent, extra_indent)
            s += indent + '\n'

        elif isinstance(tnode, ast.IfStmt):
            if tnode.getLabel(): s += tnode.getLabel() + ' '
            s += indent + 'if (' + self.generate(tnode.test, indent, extra_indent) + ') then \n'
            if isinstance(tnode.true_stmt, ast.CompStmt):
                tstmt_s = self.generate(tnode.true_stmt, indent, extra_indent)
                # TODO: fix below cludge -- { is missing for some reason in some compound ifs
                if tstmt_s.count('{') > 0: s += tstmt_s[tstmt_s.index('{'):]
                else: s += tstmt_s
                if tnode.false_stmt:
                    s = s[:-1] + ' else '
            else:
                s += '\n'
                s += self.generate(tnode.true_stmt, indent + extra_indent, extra_indent)
                if tnode.false_stmt:
                    s += indent + 'else '
            if tnode.false_stmt:
                if isinstance(tnode.false_stmt, ast.CompStmt):
                    tstmt_s = self.generate(tnode.false_stmt, indent, extra_indent)
                    s += tstmt_s[tstmt_s.index('{'):]
                else:
                    s += '\n'
                    s += self.generate(tnode.false_stmt, indent + extra_indent, extra_indent)
            s += indent + 'end if\n'

        elif isinstance(tnode, ast.ForStmt):
            if tnode.getLabel(): s += tnode.getLabel() + ' '
            s += indent + 'do ' 
            if tnode.init:
                s += self.generate(tnode.init, indent, extra_indent)
            if not tnode.test:
                g.err('orio.module.loop.codegen:  missing loop test expression. Fortran code generation requires a loop test expression.')
                
            if not tnode.iter:
                g.err('orio.module.loop.codegen:  missing loop increment expression. Fortran code generation requires a loop increment expression.')
            s += ', '
            if not isinstance(tnode.test, ast.BinOpExp):
                g.err('orio.module.loop.codegen internal error: cannot handle code generation for loop test expression')
                
            if tnode.test.op_type not in [tnode.test.LE, tnode.test.LT, tnode.test.GE, tnode.test.GT]: 
                g.err('orio.module.loop.codegen internal error: cannot generate Fortran loop, only <, >, <=, >= are recognized in the loop limit test')
            
            # Generate the loop bound        
            s += self.generate(tnode.test.rhs, indent, extra_indent)
            
            # Check whether we need to change the bound
            #if tnode.test.op_type == tnode.test.LE: s += ' + 1'
            #if tnode.test.op_type == tnode.test.GE: s += ' - 1'
            
            s += ', '
            # Generate the loop increment/decrement step
            
            if not isinstance(tnode.iter, (ast.BinOpExp, ast.UnaryExp)):
                g.err('orio.module.loop.codegen internal error: cannot handle code generation for loop increment expression')
 
            unary = False
            if isinstance(tnode.iter, ast.UnaryExp):
                incr_decr = [tnode.iter.POST_DEC, tnode.iter.PRE_DEC, tnode.iter.POST_INC, tnode.iter.PRE_INC]
                unary = True
                
            if not ((isinstance(tnode.iter, ast.BinOpExp) and tnode.iter.op_type == tnode.iter.EQ_ASGN)
                    or (isinstance(tnode.iter, ast.UnaryExp) and tnode.iter.op_type in incr_decr)): 
                g.err('orio.module.loop.codegen internal error: cannot handle code generation for loop increment expression')

            if tnode.test.op_type in [tnode.test.GT, tnode.test.GE] \
                and unary and tnode.iter.op_type in [tnode.iter.PRE_DEC, tnode.iter.POST_DEC]:
                s += '-'
               
            if unary and tnode.iter.op_type in incr_decr:   # ++i
                s += '1'
            else:
                s += self.generate(tnode.iter.rhs, indent, extra_indent, True)
            
            if isinstance(tnode.stmt, ast.CompStmt): 
                s += self.generate(tnode.stmt, indent, extra_indent)
                s += '\n' + indent + 'end do\n'
            else:
                s += '\n'
                s += self.generate(tnode.stmt, indent + extra_indent, extra_indent)
                s += '\n' + indent + 'end do\n'

        elif isinstance(tnode, ast.TransformStmt):
            g.err('orio.module.loop.codegen internal error: a transformation statement is never generated as an output')

        elif isinstance(tnode, ast.VarDecl):
            
            if tnode.type_name not in self.ftypes.keys():
                g.err('orio.module.loop.codegen internal error: Cannot generate Fortran type for ' + tnode.type_name)
                
            s += indent + str(self.ftypes[tnode.type_name]) + ' '
            s += ', '.join(tnode.var_names)
            s += '\n'

        elif isinstance(tnode, ast.Pragma):
            s += '$pragma ' + str(tnode.pstring) + '\n'

        elif isinstance(tnode, ast.Container):
            if tnode.getLabel(): s += tnode.getLabel() + ' '
            s += self.generate(tnode.ast, indent, extra_indent)

        else:
            g.err('orio.module.loop.codegen internal error: unrecognized type of AST: %s' % tnode.__class__.__name__)

        return s
示例#39
0
文件: printer.py 项目: axelyamel/Orio
    def pp(self, n, indent='  '):
        '''Pretty-print the given AST node as C-like source text.

        Dispatches on the class of ``n`` and recursively pretty-prints its
        children.  Statement nodes are prefixed with ``indent``; statements
        nested inside a block or control-flow body get an additional
        ``self.extra_indent``.

        @param n: the AST node to print
        @param indent: indentation string used for statement nodes
        @return: the generated text; unknown operators and node types are
                 reported through ``g.err``
        '''

        s = ''
        if isinstance(n, Comment):
            s += indent
            if n.comment:
                s += '/*' + n.comment + '*/'
            s += '\n'

        elif isinstance(n, LitExp):
            if n.lit_type == LitExp.STRING:
                s += '"' + str(n.val) + '"'
            else:
                s += str(n.val)

        elif isinstance(n, IdentExp):
            s += str(n.name)

        elif isinstance(n, ArrayRefExp):
            s += self.pp(n.exp, indent)
            s += '[' + self.pp(n.sub, indent) + ']'

        elif isinstance(n, CallExp):
            s += self.pp(n.exp, indent) + '('
            s += ','.join(self.pp(x, indent) for x in n.args)
            s += ')'

        elif isinstance(n, UnaryExp):
            # Print the operand first, then attach the operator as a prefix
            # or (transpose, post-increment/decrement) as a suffix.
            s = self.pp(n.exp, indent)
            if   n.oper == n.PLUS:      s = '+' + s
            elif n.oper == n.MINUS:     s = '-' + s
            elif n.oper == n.LNOT:      s = '!' + s
            elif n.oper == n.TRANSPOSE: s += "'"
            elif n.oper == n.PRE_INC:   s = ' ++' + s
            elif n.oper == n.PRE_DEC:   s = ' --' + s
            elif n.oper == n.POST_INC:  s += '++ '
            elif n.oper == n.POST_DEC:  s += '-- '
            else: g.err('%s: unknown unary operator type: %s' % (self.__class__, n.oper))

        elif isinstance(n, BinOpExp):
            s += self.pp(n.lhs, indent)
            if   n.oper == n.PLUS:    s += '+'
            elif n.oper == n.MINUS:   s += '-'
            elif n.oper == n.MULT:    s += '*'
            elif n.oper == n.DIV:     s += '/'
            elif n.oper == n.MOD:     s += '%'
            elif n.oper == n.LT:      s += '<'
            elif n.oper == n.GT:      s += '>'
            elif n.oper == n.LE:      s += '<='
            elif n.oper == n.GE:      s += '>='
            elif n.oper == n.EE:      s += '=='
            elif n.oper == n.NE:      s += '!='
            elif n.oper == n.LOR:     s += '||'
            elif n.oper == n.LAND:    s += '&&'
            elif n.oper == n.EQ:      s += '='
            elif n.oper == n.EQPLUS:  s += '+='
            elif n.oper == n.EQMINUS: s += '-='
            elif n.oper == n.EQMULT:  s += '*='
            elif n.oper == n.EQDIV:   s += '/='
            elif n.oper == n.COMMA:   s += ','
            else: g.err('%s: unknown binary operator type: %s' % (self.__class__, n.oper))
            s += self.pp(n.rhs, indent)

        elif isinstance(n, ParenExp):
            s += '(' + self.pp(n.exp, indent) + ')'

        elif isinstance(n, ExpStmt):
            s += indent + self.pp(n.exp, indent) + ';\n'

        elif isinstance(n, CompStmt):
            s += indent + '{\n'
            for stmt in n.stmts:
                s += self.pp(stmt, indent + self.extra_indent)
            s += indent + '}\n'

        elif isinstance(n, IfStmt):
            s += indent + 'if (' + self.pp(n.test, indent) + ') '
            if isinstance(n.then_s, CompStmt):
                # Keep '{' on the header line (drop the block's leading indent).
                tstmt_s = self.pp(n.then_s, indent)
                s += tstmt_s[tstmt_s.index('{'):]
                if n.else_s:
                    # Drop the newline after '}' so 'else' shares its line.
                    s = s[:-1] + ' else '
            else:
                s += '\n'
                s += self.pp(n.then_s, indent + self.extra_indent)
                if n.else_s:
                    s += indent + 'else '
            if n.else_s:
                if isinstance(n.else_s, CompStmt):
                    tstmt_s = self.pp(n.else_s, indent)
                    s += tstmt_s[tstmt_s.index('{'):]
                else:
                    s += '\n'
                    s += self.pp(n.else_s, indent + self.extra_indent)

        elif isinstance(n, ForStmt):
            s += indent + 'for ('
            if n.init:
                s += self.pp(n.init, indent)
            s += '; '
            if n.test:
                s += self.pp(n.test, indent)
            s += '; '
            if n.itr:
                s += self.pp(n.itr, indent)
            s += ') '
            if isinstance(n.stmt, CompStmt):
                stmt_s = self.pp(n.stmt, indent)
                s += stmt_s[stmt_s.index('{'):]
            else:
                s += '\n'
                s += self.pp(n.stmt, indent + self.extra_indent)

        elif isinstance(n, WhileStmt):
            s += indent + 'while (' + self.pp(n.test, indent)
            s += ') '
            if isinstance(n.stmt, CompStmt):
                stmt_s = self.pp(n.stmt, indent)
                s += stmt_s[stmt_s.index('{'):]
            else:
                s += '\n'
                s += self.pp(n.stmt, indent + self.extra_indent)

        elif isinstance(n, VarInit):
            s += self.pp(n.var_name, indent)
            if n.init_exp:
                s += '=' + self.pp(n.init_exp, indent)

        elif isinstance(n, VarDec):
            # Qualifiers are separate words and must be space-separated (as
            # FunDec qualifiers are); concatenating them would fuse e.g.
            # ['const', 'volatile'] into 'constvolatile'.
            if n.quals:
                s += ' '.join(n.quals) + ' '
            s += str(n.type_name) + ' '
            s += ', '.join(self.pp(v) for v in n.var_inits)
            if n.isAtomic:
                # A stand-alone declaration statement gets indent and ';'.
                s = indent + s + ';\n'

        elif isinstance(n, ParamDec):
            s += self.pp(n.ty, indent) + ' ' + self.pp(n.name, indent)

        elif isinstance(n, FunDec):
            # Emit the qualifier prefix (and its separating space) only when
            # there are qualifiers; otherwise e.g. 'static' and 'int' fuse.
            if n.quals:
                s += ' '.join(n.quals) + ' '
            s += self.pp(n.rtype, indent) + ' '
            s += self.pp(n.name, indent) + '('
            s += ', '.join(self.pp(p) for p in n.params) + ') '
            s += self.pp(n.body, indent)

        elif isinstance(n, TransformStmt):
            g.err('%s: a transformation statement is never generated as an output' % self.__class__)

        else:
            g.err('%s: unrecognized type of AST: (%s, %s)' % (self.__class__, n.__class__.__name__,n))

        return s
示例#40
0
    def rewriteNode(self, r, n):
        ''' Rewrite the given node with the given rewrite function: post-order traversal, in-place update.

        The children of ``n`` that are AST nodes are rewritten first and the
        results are stored back into ``n``; ``r`` is then applied to ``n``
        itself and its result is returned.

        Attributes holding plain Python values rather than AST nodes are not
        descended into: a comment's text and a goto's target label are
        strings, and a transformation statement's name and argument list are
        metadata.  Recursing into them would always end in the
        "unexpected AST type" error at the bottom of this method.

        @param r: callable mapping an AST node to its (possibly new) node
        @param n: the AST node to rewrite
        @return: ``r(n)`` once all child nodes have been rewritten; an
                 unexpected node type is reported via ``g.err``
        '''
        loop_ast = orio.module.loop.ast

        # Leaf nodes: nothing to descend into.
        if isinstance(n, (loop_ast.NumLitExp, loop_ast.StringLitExp,
                          loop_ast.IdentExp, loop_ast.VarDecl)):
            return r(n)

        elif isinstance(n, loop_ast.ArrayRefExp):
            n.exp = self.rewriteNode(r, n.exp)
            n.sub_exp = self.rewriteNode(r, n.sub_exp)
            return r(n)

        elif isinstance(n, loop_ast.FunCallExp):
            n.exp = self.rewriteNode(r, n.exp)
            n.args = [self.rewriteNode(r, x) for x in n.args]
            return r(n)

        elif isinstance(n, loop_ast.UnaryExp):
            n.exp = self.rewriteNode(r, n.exp)
            return r(n)

        elif isinstance(n, loop_ast.BinOpExp):
            n.lhs = self.rewriteNode(r, n.lhs)
            n.rhs = self.rewriteNode(r, n.rhs)
            return r(n)

        elif isinstance(n, loop_ast.ParenthExp):
            n.exp = self.rewriteNode(r, n.exp)
            return r(n)

        elif isinstance(n, loop_ast.Comment):
            # n.text is a plain string, not an AST node.
            return r(n)

        elif isinstance(n, loop_ast.ExpStmt):
            # The expression may be absent (None); only rewrite it if present.
            if n.exp:
                n.exp = self.rewriteNode(r, n.exp)
            return r(n)

        elif isinstance(n, loop_ast.GotoStmt):
            # n.target is a label string, not an AST node.
            return r(n)

        elif isinstance(n, loop_ast.CompStmt):
            n.stmts = [self.rewriteNode(r, x) for x in n.stmts]
            return r(n)

        elif isinstance(n, loop_ast.IfStmt):
            n.test = self.rewriteNode(r, n.test)
            n.true_stmt = self.rewriteNode(r, n.true_stmt)
            if n.false_stmt:
                n.false_stmt = self.rewriteNode(r, n.false_stmt)
            return r(n)

        elif isinstance(n, loop_ast.ForStmt):
            if n.init:
                n.init = self.rewriteNode(r, n.init)
            if n.test:
                n.test = self.rewriteNode(r, n.test)
            if n.iter:
                n.iter = self.rewriteNode(r, n.iter)
            n.stmt = self.rewriteNode(r, n.stmt)
            return r(n)

        elif isinstance(n, loop_ast.AssignStmt):
            n.var = self.rewriteNode(r, n.var)
            n.exp = self.rewriteNode(r, n.exp)
            return r(n)

        elif isinstance(n, loop_ast.TransformStmt):
            # Only the transformed statement is an AST; the transformation's
            # name and argument list are metadata and are left as-is.
            n.stmt = self.rewriteNode(r, n.stmt)
            return r(n)

        else:
            g.err('orio.module.loop.ast_lib.common_lib.rewriteNode: unexpected AST type: "%s"' % n.__class__.__name__)
示例#41
0
    def generate(self, tnode, indent="  ", extra_indent="  "):
        """Return the C source text for the AST node ``tnode``.

        Expressions are rendered inline.  Statements start with ``indent``,
        and statements nested in a block or in a control-flow body are
        rendered one ``extra_indent`` step deeper.  Unknown operators and
        unknown node types are reported through ``g.err``.
        """

        def emit(node, at=indent):
            # Recurse, keeping the same indentation step.
            return self.generate(node, at, extra_indent)

        def body_of(stmt):
            # A compound body keeps its "{" on the header line (its leading
            # indentation is trimmed off); any other body is placed on a
            # fresh line, one indentation step deeper.
            if isinstance(stmt, ast.CompStmt):
                text = emit(stmt)
                return text[text.index("{"):]
            return "\n" + emit(stmt, indent + extra_indent)

        if isinstance(tnode, ast.Comment):
            return indent + "/*" + tnode.text + "*/\n"

        if isinstance(tnode, ast.LitExp):
            return str(tnode.val).encode("string-escape")

        if isinstance(tnode, ast.IdentExp):
            return str(tnode.name)

        if isinstance(tnode, ast.ArrayRefExp):
            return emit(tnode.exp) + "[" + emit(tnode.sub) + "]"

        if isinstance(tnode, ast.CallExp):
            callee = emit(tnode.exp)
            return callee + "(" + ",".join(emit(arg) for arg in tnode.args) + ")"

        if isinstance(tnode, ast.CastExp):
            return "(" + emit(tnode.castto) + ")" + emit(tnode.exp)

        if isinstance(tnode, ast.UnaryExp):
            operand = emit(tnode.exp)
            # (operator constant name, text before operand, text after it),
            # checked in order; the first match wins.
            for op_name, before, after in (
                ("PLUS", "+", ""),
                ("MINUS", "-", ""),
                ("LNOT", "!", ""),
                ("BNOT", "~", ""),
                ("PRE_INC", " ++", ""),
                ("PRE_DEC", " --", ""),
                ("POST_INC", "", "++ "),
                ("POST_DEC", "", "-- "),
                ("DEREF", "*", ""),
                ("ADDRESSOF", "&", ""),
                ("SIZEOF", "sizeof ", ""),
            ):
                if tnode.op_type == getattr(tnode, op_name):
                    return before + operand + after
            g.err(__name__ +
                  ": internal error: unknown unary operator type: %s" %
                  tnode.op_type)
            return operand

        if isinstance(tnode, ast.BinOpExp):
            left = emit(tnode.lhs)
            # (operator constant name, C spelling), checked in order.
            for op_name, symbol in (
                ("PLUS", "+"), ("MINUS", "-"), ("MULT", "*"),
                ("DIV", "/"), ("MOD", "%"),
                ("LT", "<"), ("GT", ">"), ("LE", "<="), ("GE", ">="),
                ("EE", "=="), ("NE", "!="),
                ("LOR", "||"), ("LAND", "&&"),
                ("EQ", "="), ("PLUSEQ", "+="), ("MINUSEQ", "-="),
                ("MULTEQ", "*="), ("DIVEQ", "/="), ("MODEQ", "%="),
                ("COMMA", ","),
                ("BOR", "|"), ("BAND", "&"), ("BXOR", "^"),
                ("BSHL", "<<"), ("BSHR", ">>"),
                ("BSHLEQ", "<<="), ("BSHREQ", ">>="),
                ("BANDEQ", "&="), ("BXOREQ", "^="), ("BOREQ", "|="),
                ("DOT", "."), ("SELECT", "->"),
            ):
                if tnode.op_type == getattr(tnode, op_name):
                    break
            else:
                symbol = ""
                g.err(__name__ +
                      ": internal error: unknown binary operator type: %s" %
                      tnode.op_type)
            return left + symbol + emit(tnode.rhs)

        if isinstance(tnode, ast.TernaryExp):
            return (emit(tnode.test) + "?" + emit(tnode.true_exp) + ":"
                    + emit(tnode.false_exp))

        if isinstance(tnode, ast.ParenExp):
            return "(" + emit(tnode.exp) + ")"

        if isinstance(tnode, ast.CompStmt):
            inner = indent + extra_indent
            statements = "".join(emit(kid, inner) for kid in tnode.kids)
            return indent + "{\n" + statements + indent + "}\n"

        if isinstance(tnode, ast.ExpStmt):
            return indent + emit(tnode.exp) + ";\n"

        if isinstance(tnode, ast.IfStmt):
            text = (indent + "if (" + emit(tnode.test) + ") "
                    + body_of(tnode.true_stmt))
            if tnode.false_stmt:
                if isinstance(tnode.true_stmt, ast.CompStmt):
                    # Drop the newline after "}" so "else" shares its line.
                    text = text[:-1] + " else "
                else:
                    text += indent + "else "
                text += body_of(tnode.false_stmt)
            return text

        if isinstance(tnode, ast.ForStmt):
            clauses = [emit(part) if part else ""
                       for part in (tnode.init, tnode.test, tnode.iter)]
            return (indent + "for (" + "; ".join(clauses) + ") "
                    + body_of(tnode.stmt))

        if isinstance(tnode, ast.WhileStmt):
            return (indent + "while (" + emit(tnode.test) + ") "
                    + body_of(tnode.stmt))

        if isinstance(tnode, ast.VarDec):
            # Nested declarations (e.g. inside another construct) carry no
            # indentation and no terminating semicolon.
            lead, tail = ("", "") if tnode.isnested else (indent, ";\n")
            return (lead + " ".join(tnode.type_name) + " "
                    + ", ".join(emit(v) for v in tnode.var_inits) + tail)

        if isinstance(tnode, ast.ParamDec):
            return indent + str(tnode.ty) + " " + str(tnode.name)

        if isinstance(tnode, ast.FunDec):
            signature = (indent + str(tnode.return_type) + " "
                         + str(tnode.modifiers))
            signature += tnode.name + "("
            signature += ", ".join(emit(p) for p in tnode.params)
            return signature + ")" + emit(tnode.body)

        if isinstance(tnode, ast.Pragma):
            return indent + "#pragma " + str(tnode.pstring) + "\n"

        if isinstance(tnode, ast.TransformStmt):
            g.err(
                __name__ +
                ": internal error: a transformation statement is never generated as an output"
            )
            return ""

        g.err(__name__ + ": internal error: unrecognized type of AST: %s" %
              tnode.__class__.__name__)
        return ""
示例#42
0
def p_error(p):
    """Report a syntax error detected by the PLY parser.

    PLY passes the offending token, or ``None`` when the input ended
    unexpectedly.  The end-of-input case has no lexer, position, or value to
    report, so it gets its own message instead of crashing on ``p.lexer``.

    @param p: the offending token, or None at unexpected end of input
    """
    if p is None:
        g.err(__name__ + ": unexpected end of input")
        return
    col = find_column(p.lexer.lexdata, p)
    g.err(
        __name__
        + ": unexpected token-type '%s', token-value '%s' at line %s, column %s" % (p.type, p.value, p.lineno, col)
    )
示例#43
0
    def pp(self, n, indent="  "):
        """Pretty-print the given AST node as C-like source text.

        Dispatches on the class of ``n`` and recursively pretty-prints its
        children.  Statement nodes are prefixed with ``indent``; statements
        nested inside a block or control-flow body get an additional
        ``self.extra_indent``.

        @param n: the AST node to print
        @param indent: indentation string used for statement nodes
        @return: the generated text; unknown operators and node types are
                 reported through ``g.err``
        """

        s = ""
        if isinstance(n, Comment):
            s += indent
            if n.comment:
                s += "/*" + n.comment + "*/"
            s += "\n"

        elif isinstance(n, LitExp):
            if n.lit_type == LitExp.STRING:
                s += '"' + str(n.val) + '"'
            else:
                s += str(n.val)

        elif isinstance(n, IdentExp):
            s += str(n.name)

        elif isinstance(n, ArrayRefExp):
            s += self.pp(n.exp, indent)
            s += "[" + self.pp(n.sub, indent) + "]"

        elif isinstance(n, CallExp):
            s += self.pp(n.exp, indent) + "("
            s += ",".join(self.pp(x, indent) for x in n.args)
            s += ")"

        elif isinstance(n, UnaryExp):
            # Print the operand first, then attach the operator as a prefix
            # or (transpose, post-increment/decrement) as a suffix.
            s = self.pp(n.exp, indent)
            if n.oper == n.PLUS:
                s = "+" + s
            elif n.oper == n.MINUS:
                s = "-" + s
            elif n.oper == n.LNOT:
                s = "!" + s
            elif n.oper == n.TRANSPOSE:
                s += "'"
            elif n.oper == n.PRE_INC:
                s = " ++" + s
            elif n.oper == n.PRE_DEC:
                s = " --" + s
            elif n.oper == n.POST_INC:
                s += "++ "
            elif n.oper == n.POST_DEC:
                s += "-- "
            else:
                g.err("%s: unknown unary operator type: %s" % (self.__class__, n.oper))

        elif isinstance(n, BinOpExp):
            s += self.pp(n.lhs, indent)
            if n.oper == n.PLUS:
                s += "+"
            elif n.oper == n.MINUS:
                s += "-"
            elif n.oper == n.MULT:
                s += "*"
            elif n.oper == n.DIV:
                s += "/"
            elif n.oper == n.MOD:
                s += "%"
            elif n.oper == n.LT:
                s += "<"
            elif n.oper == n.GT:
                s += ">"
            elif n.oper == n.LE:
                s += "<="
            elif n.oper == n.GE:
                s += ">="
            elif n.oper == n.EE:
                s += "=="
            elif n.oper == n.NE:
                s += "!="
            elif n.oper == n.LOR:
                s += "||"
            elif n.oper == n.LAND:
                s += "&&"
            elif n.oper == n.EQ:
                s += "="
            elif n.oper == n.EQPLUS:
                s += "+="
            elif n.oper == n.EQMINUS:
                s += "-="
            elif n.oper == n.EQMULT:
                s += "*="
            elif n.oper == n.EQDIV:
                s += "/="
            elif n.oper == n.COMMA:
                s += ","
            else:
                g.err("%s: unknown binary operator type: %s" % (self.__class__, n.oper))
            s += self.pp(n.rhs, indent)

        elif isinstance(n, ParenExp):
            s += "(" + self.pp(n.exp, indent) + ")"

        elif isinstance(n, ExpStmt):
            s += indent + self.pp(n.exp, indent) + ";\n"

        elif isinstance(n, CompStmt):
            s += indent + "{\n"
            for stmt in n.stmts:
                s += self.pp(stmt, indent + self.extra_indent)
            s += indent + "}\n"

        elif isinstance(n, IfStmt):
            s += indent + "if (" + self.pp(n.test, indent) + ") "
            if isinstance(n.then_s, CompStmt):
                # Keep "{" on the header line (drop the block's leading indent).
                tstmt_s = self.pp(n.then_s, indent)
                s += tstmt_s[tstmt_s.index("{") :]
                if n.else_s:
                    # Drop the newline after "}" so "else" shares its line.
                    s = s[:-1] + " else "
            else:
                s += "\n"
                s += self.pp(n.then_s, indent + self.extra_indent)
                if n.else_s:
                    s += indent + "else "
            if n.else_s:
                if isinstance(n.else_s, CompStmt):
                    tstmt_s = self.pp(n.else_s, indent)
                    s += tstmt_s[tstmt_s.index("{") :]
                else:
                    s += "\n"
                    s += self.pp(n.else_s, indent + self.extra_indent)

        elif isinstance(n, ForStmt):
            s += indent + "for ("
            if n.init:
                s += self.pp(n.init, indent)
            s += "; "
            if n.test:
                s += self.pp(n.test, indent)
            s += "; "
            if n.itr:
                s += self.pp(n.itr, indent)
            s += ") "
            if isinstance(n.stmt, CompStmt):
                stmt_s = self.pp(n.stmt, indent)
                s += stmt_s[stmt_s.index("{") :]
            else:
                s += "\n"
                s += self.pp(n.stmt, indent + self.extra_indent)

        elif isinstance(n, WhileStmt):
            s += indent + "while (" + self.pp(n.test, indent)
            s += ") "
            if isinstance(n.stmt, CompStmt):
                stmt_s = self.pp(n.stmt, indent)
                s += stmt_s[stmt_s.index("{") :]
            else:
                s += "\n"
                s += self.pp(n.stmt, indent + self.extra_indent)

        elif isinstance(n, VarInit):
            s += self.pp(n.var_name, indent)
            if n.init_exp:
                s += "=" + self.pp(n.init_exp, indent)

        elif isinstance(n, VarDec):
            # Qualifiers are separate words and must be space-separated (as
            # FunDec qualifiers are); concatenating them would fuse e.g.
            # ["const", "volatile"] into "constvolatile".
            if n.quals:
                s += " ".join(n.quals) + " "
            s += str(n.type_name) + " "
            s += ", ".join(self.pp(v) for v in n.var_inits)
            if n.isAtomic:
                # A stand-alone declaration statement gets indent and ";".
                s = indent + s + ";\n"

        elif isinstance(n, ParamDec):
            s += self.pp(n.ty, indent) + " " + self.pp(n.name, indent)

        elif isinstance(n, FunDec):
            # Emit the qualifier prefix (and its separating space) only when
            # there are qualifiers; otherwise e.g. "static" and "int" fuse.
            if n.quals:
                s += " ".join(n.quals) + " "
            s += self.pp(n.rtype, indent) + " "
            s += self.pp(n.name, indent) + "("
            s += ", ".join(self.pp(p) for p in n.params) + ") "
            s += self.pp(n.body, indent)

        elif isinstance(n, TransformStmt):
            g.err(
                "%s: a transformation statement is never generated as an output"
                % self.__class__
            )

        else:
            g.err(
                "%s: unrecognized type of AST: (%s, %s)"
                % (self.__class__, n.__class__.__name__, n)
            )

        return s