Example #1
0
class SemanticValidator(ASTListener):
    def __init__(self):
        """Create a validator with no recorded errors and a new symbol table."""
        # Diagnostics are collected as strings, in the order they are found.
        self.errors = []
        self.symbolTable = SymbolTable()

    def enterProgram(self, node):
        self.symbolTable.newScope()

    def exitProgram(self, node):
        self.symbolTable.endScope()

    def enterCompound(self, node):
        self.symbolTable.newScope()

    def exitCompound(self, node):
        self.symbolTable.endScope()

    def enterVariableDecl(self, node):
        declList = node.declList
        for varDeclInit in declList.declInitializeList:
            # Check if new var already exists in current scope
            symbolInfo = self.symbolTable.getSymbolInCurrentScope(varDeclInit.name)
            if symbolInfo is None:
                if type(varDeclInit) is Variable.ArrayInitialize:
                    self.symbolTable.addSymbol(varDeclInit.name, ArrayInfo(node.type, varDeclInit.size))
                else:
                    self.symbolTable.addSymbol(varDeclInit.name, VarInfo(node.type))
            else:
                self.errors.append(varDeclInit.getPosition() + ": Redefinition of '" + varDeclInit.name + "'")

    def enterVarDeclInitialize(self, node):
        symbolInfo = self.symbolTable.getSymbol(node.name)
        if node.expression is not None:
            # if (node.expression is Expression.BinOp and (symbolInfo.type == "char" and not hasattr(symbolInfo,type))):
            #     self.errors.append(
            #         node.getPosition() + ": Type mismatch: expected '" + symbolInfo.type + "' but found 'BinOp'")
            # else:
            #     getTypeResult = getType(node.right, symbolInfo.type, self.symbolTable)
            #     if symbolInfo.type != getTypeResult[0] and getTypeResult[0] != "undefined input":
            #         self.errors.append(getTypeResult[1] + ": Type mismatch: expected '" + symbolInfo.type + "' but found '" +getTypeResult[0] + "'")
            if (isinstance(node.expression,Expression.BinOp)and (symbolInfo.type == "char" and not hasattr(symbolInfo,"size"))):
                self.errors.append(node.getPosition() + ": Type mismatch: expected '" + symbolInfo.type + "' but found 'BinOp'")
            else:
                getTypeResult = getType(node.expression, symbolInfo.type, self.symbolTable)
                if symbolInfo.type != getTypeResult[0] and getTypeResult[0] != "undefined input":
                    self.errors.append(getTypeResult[1] + ": Type mismatch for '" + node.name + "': expected '" + symbolInfo.type + "' but found '" +getTypeResult[0] + "'")

    def enterCall(self, node):
        symbolInfo = self.symbolTable.getSymbol(node.funcName)
        if symbolInfo is None or type(symbolInfo) is not FunctionInfo:
            if node.funcName == "printf" or node.funcName == "scanf":
                if len(node.args) > 0:
                    paramType = getType(node.args[0], "string", self.symbolTable)
                    if paramType[0] != "string" and paramType[0] != "char*":
                        self.errors.append(paramType[1] + ": Wrong parameter type for '" + node.funcName + "'! Expected: 'char*' found '" + paramType[0] + "'")
                else:
                    self.errors.append(node.getPosition() + ": Wrong amount of parameters for '" + node.funcName + "'! Expected at least one argument")

            else:
                self.errors.append(node.getPosition() + ": Undefined reference to '" + node.funcName + "'")
        else:
            symbolInfo.used = True
            if symbolInfo.isDecl:
                self.errors.append(node.getPosition() + ": Undefined reference to '" + node.funcName + "'")
            else:
                if len(node.args) == len(symbolInfo.paramTypes):
                    for i in range (0, len(symbolInfo.paramTypes)):
                        foundParamType = getType(node.args[i], symbolInfo.paramTypes[i],self.symbolTable)[0]
                        if foundParamType != symbolInfo.paramTypes[i]:
                            self.errors.append(node.getPosition() + ": Wrong parameter type for '" + node.funcName + "'! Expected: '" + symbolInfo.paramTypes[i] + "' found '" + foundParamType + "'")
                else:
                    self.errors.append(node.getPosition() + ": Wrong amount of parameters for '" + node.funcName + "'! Expected: " + str(len(symbolInfo.paramTypes)) + " found " + str(len(node.args)))

    def enterMutable(self, node):
        symbolInfo = self.symbolTable.getSymbol(node.name)
        if symbolInfo is None or (type(symbolInfo) is not VarInfo and type(symbolInfo) is not ArrayInfo):
            self.errors.append(node.getPosition() + ": Undefined reference to '" + node.name + "'")
        else:
            symbolInfo.used = True

    def enterSubScript(self, node):
        """Check that a subscripted name resolves to an array and mark it used.

        Only the kind of the symbol is verified; the index value itself is
        not range-checked here.
        """
        arrayName = node.mutable.name
        info = self.symbolTable.getSymbol(arrayName)
        if info is not None and type(info) is ArrayInfo:
            info.used = True
            return
        self.errors.append(node.getPosition() + ": Subscripted value '" + arrayName + "' is not an array")

    def enterFunctionDef(self, node):
        """Register a function definition and open the scope for its body.

        Three situations are distinguished by what the name currently
        resolves to:

        * nothing -- the function is added to the symbol table;
        * a ``FunctionInfo`` flagged ``isDecl`` (a prior prototype) -- the
          definition's return type, parameter count and parameter types are
          compared with the prototype, mismatches are reported, and the
          definition replaces the prototype;
        * anything else -- a redefinition error is reported and the existing
          symbol is left untouched.

        In every case a new scope is opened and the definition's parameters
        are declared inside it.  This keeps the scope stack balanced with
        ``exitFunctionDef`` (which always closes one scope), keeps parameters
        out of the enclosing scope, and lets the body be checked even when
        the header itself was rejected.
        """
        existing = self.symbolTable.getSymbol(node.name)
        params = node.params.params
        paramTypes = [param.type for param in params]

        if existing is None:
            self.symbolTable.addSymbol(node.name, FunctionInfo(node.returns, paramTypes))

        elif type(existing) is FunctionInfo and existing.isDecl:
            # Previous declaration => check that the definition matches it.
            if node.returns != existing.type:
                self.errors.append(node.getPosition() + ": Wrong return type for '" + node.name + "'! Expected: '" + existing.type + "' found '" + node.returns + "'")

            if len(params) == len(existing.paramTypes):
                for param, declaredType in zip(params, existing.paramTypes):
                    if param.type != declaredType:
                        self.errors.append(param.getPosition() + ": Wrong parameter type for '" + node.name + "'! Expected: '" + declaredType + "' found '" + param.type + "'")
            else:
                self.errors.append(node.getPosition() + ": Wrong amount of parameters for '" + node.name + "'! Expected: " + str(len(existing.paramTypes)) + " found " + str(len(params)))

            self.symbolTable.addSymbol(node.name, FunctionInfo(node.returns, paramTypes))

        else:
            self.errors.append(node.getPosition() + ": Redefinition of '" + node.name + "'")

        # The function symbol lives in the enclosing scope (so recursive and
        # later calls resolve); its parameters live in the function's own scope.
        self.symbolTable.newScope()
        for param in params:
            self.symbolTable.addSymbol(param.name, VarInfo(param.type))

    def exitFunctionDef(self, node):
        self.symbolTable.endScope()

    def enterFunctionDecl(self, node):
        # Check if new function already exists
        symbolInfo = self.symbolTable.getSymbol(node.name)
        if symbolInfo is None:
            params = node.params
            paramTypes = []
            for param in params.params:
                paramTypes.append(param.type)
            self.symbolTable.addSymbol(node.name, FunctionInfo(node.returns, paramTypes, isDecl=True))
        else:
            self.errors.append(node.getPosition() + ": Redefinition of '" + node.name + "'")

    def enterAssign(self, node):
        """Type-check the right-hand side of an assignment.

        The target is resolved through its name (or, for a subscript, the
        name of the subscripted variable).  Unresolved targets are ignored
        here.  A binary operation assigned to a scalar ``char`` is rejected
        outright; otherwise floating targets require a floating value and
        integral targets require an integral value.  Other target types are
        not checked.
        """
        target = node.left
        if type(target) is Expression.SubScript:
            targetName = target.mutable.name
        else:
            targetName = target.name

        info = self.symbolTable.getSymbol(targetName)
        if info is None:
            return

        declared = info.type
        if isinstance(node.right, Expression.BinOp) and (declared == "char" and not hasattr(info, "size")):
            self.errors.append(node.getPosition() + ": Type mismatch: expected '" + declared + "' but found 'BinOp'")
            return

        floating = ("float", "double")
        integral = ("int", "long", "signed", "unsigned")

        result = getType(node.right, declared, self.symbolTable)
        found = result[0]
        if declared in floating:
            mismatch = found not in floating
        elif declared in integral:
            mismatch = found not in integral
        else:
            mismatch = False

        if mismatch:
            self.errors.append(result[1] + ": Type mismatch: expected '" + declared + "' but found '" + found + "'")

    def enterBinOp(self, node):
        """Reject binary operations that mix textual and non-textual operands.

        "Textual" means a type of "string" or "char".  Each operand is
        resolved to a type: plain operands through ``getType``, nested
        binary operations through ``checkBinOp`` (whose result holds the two
        inner operand types at [0] and [1] and a node at [2] used for the
        error position).  A nested operation that itself mixes the two
        categories is reported at that node and suppresses the outer check.
        Identical messages are recorded only once.
        """
        textual = ("string", "char")

        def report(message):
            # The same nested operation can be reached more than once.
            if message not in self.errors:
                self.errors.append(message)

        def resolve(operand):
            """Return (type, ok); ok is False when a nested mismatch was reported."""
            if type(operand) is not Expression.BinOp:
                return getType(operand, "", self.symbolTable)[0], True

            inner = checkBinOp(operand, self.symbolTable)
            if (inner[0] in textual) == (inner[1] in textual):
                return inner[0], True

            report(inner[2].getPosition() + ": type mismatch! Cannot compare '" + inner[
                0] + "' with '" + inner[1] + "'")
            return None, False

        leftType, leftOk = resolve(node.left)
        rightType, rightOk = resolve(node.right)
        if not (leftOk and rightOk):
            return

        if (leftType in textual) != (rightType in textual):
            report(node.getPosition() + ": type mismatch! Cannot compare '" + leftType + "' with '" + rightType + "'")
Example #2
0
class DecafSemanticChecker(DecafVisitor):
    def __init__(self):
        """Set up the symbol table and the two assembly text buffers."""
        super().__init__()
        self.st = SymbolTable()
        # Generated assembly is accumulated as text: `head` starts the data
        # section, `body` starts with the export of `main`.
        self.body = '.global main\n'
        self.head = '.data\n'

    def visitProgram(self, ctx: DecafParser.ProgramContext):
        """Visit the whole program inside one enclosing scope (returns None)."""
        table = self.st
        table.enterScope()
        self.visitChildren(ctx)
        table.exitScope()

    def visitField_decl(self, ctx: DecafParser.Field_declContext):
        """Declare every field named in a field declaration.

        Each field gets a ``VarSymbol`` with ``mem=HEAP`` and a size of
        8 bytes per element (one element unless an array length literal is
        present).  A non-positive array length and a name already present
        according to ``probe`` are reported with ``print``; a duplicate is
        not added again.
        """
        line_num = ctx.start.line
        data_type = ctx.data_type().getText()

        for field in ctx.field_name():
            field_name = field.getText()

            length_literal = field.int_literal()
            if length_literal:
                array_size = int(field.int_literal().getText())
                if array_size <= 0:
                    print('Error on line', line_num, 'array declared with length 0')
            else:
                array_size = 1

            existing = self.st.probe(field_name)
            if existing:
                print('Error on line', line_num, 'field', field_name, 'is already declared on line', existing.line)
                continue

            self.st.addSymbol(VarSymbol(id=field_name,
                                        type=data_type,
                                        line=line_num,
                                        size=8 * array_size,
                                        mem=HEAP))

        self.visitChildren(ctx)

    def visitMethod_decl(self, ctx: DecafParser.Method_declContext):
        """Declare a method and emit the assembly skeleton around its body.

        The method symbol is added to the current scope.  Inside a new scope
        every parameter is declared as an 8-byte ``STACK`` symbol and a
        ``movq`` is emitted that stores the i-th entry of ``param_registers``
        to that parameter's address relative to %rsp.  The body is then
        visited and a ``ret`` is appended.
        """
        line_num = ctx.start.line
        method_name = ctx.ID(0).getText()
        method_type = ctx.return_type().getText()

        # ID(0) is the method name, so the i-th parameter name is ID(i + 1).
        method_params = []
        for position, type_ctx in enumerate(ctx.data_type(), start=1):
            param_type = type_ctx.getText()
            param_name = ctx.ID(position).getText()
            method_params.append(
                VarSymbol(id=param_name, type=param_type, line=line_num, size=8, mem=STACK))

        self.st.addSymbol(MethodSymbol(method_name, method_type, line_num, method_params))

        self.body += method_name + ':\n'
        if method_name == 'main':
            self.body += 'movq %rsp, %rbp\n'

        self.st.enterScope()
        for index, param in enumerate(method_params):
            self.st.addSymbol(param)
            self.body += 'movq ' + param_registers[index] + ',' + str(param.getAddr()) + '(%rsp)\n'

        self.visitChildren(ctx)
        self.body += 'ret\n'
        self.st.exitScope()

    def visitMethod_call(self, ctx: DecafParser.Method_callContext):
        """Emit code for a method call, or stage the string arguments of a callout.

        Method call: each argument expression is visited (leaving its value
        in %rax) and spilled to a fresh 8-byte slot below %rsp; the slots are
        then reloaded, last argument first, into ``param_registers``.  %rsp
        is lowered around the ``call`` by the current frame size, padded by
        8 bytes when needed so that the adjustment is an odd multiple of 8.

        Callout: every string-literal argument is emitted into the data
        section under a label derived from its token start index, and the
        label's address is stored in a new stack slot.  Non-string callout
        arguments are ignored and no ``call`` is emitted for callouts.
        """
        if ctx.method_name():
            arg_count = len(ctx.expr())

            for i in range(arg_count):
                self.visit(ctx.expr(i))
                self.st.stack_pointer[-1] += 8
                self.body += 'movq %rax, ' + str(-self.st.stack_pointer[-1]) + '(%rsp)\n'

            # The most recently spilled slot belongs to the last argument, so
            # registers are filled from the highest index down.
            for i in range(arg_count):
                self.body += 'movq ' + str(-self.st.stack_pointer[-1]) + '(%rsp), ' + param_registers[
                    arg_count - i - 1] + '\n'
                self.st.stack_pointer[-1] -= 8

            # adjust stack to 16-byte alignment (multiple of 8 that is not divisible by 16)
            # The parity test must be parenthesised: the previous form
            # `stack_len // 8 + 1 % 2` evaluated `1 % 2` first and therefore
            # always added a further stack_len + 8 bytes.
            stack_len = self.st.stack_pointer[-1]
            stack_len += ((stack_len // 8 + 1) % 2) * 8
            self.body += 'subq $' + str(stack_len) + ', %rsp\n'

            method_name = ctx.method_name().getText()
            self.body += 'call ' + method_name + '\n'

            self.body += 'addq $' + str(stack_len) + ', %rsp\n'
        elif ctx.CALLOUT():
            for arg in ctx.callout_arg():
                if not arg.STRING_LITERAL():
                    # Only string-literal callout arguments are handled so far.
                    continue
                callout_arg = arg.getText()
                label = 'str' + str(arg.start.start)
                self.head += label + ': .asciz ' + callout_arg + '\n'
                self.st.stack_pointer[-1] += 8
                self.body += 'movq $' + label + ', ' + str(-self.st.stack_pointer[-1]) + '(%rsp)\n'

    def visitExpr(self, ctx: DecafParser.ExprContext):
        """Emit code that leaves the value of an expression in %rax.

        Literals are loaded as immediates.  Locations are loaded from their
        symbol's address, relative to %rbp for ``HEAP`` symbols and to %rsp
        otherwise.  A location that cannot be resolved is reported and no
        load is emitted.  Every other expression is handled by visiting its
        children.
        """
        if ctx.literal():
            self.body += 'movq $' + ctx.getText() + ', %rax\n'
        elif ctx.location():
            loc_name = ctx.getText()
            location = self.st.lookup(loc_name)
            if location is None:
                # NOTE(review): assumes lookup() answers None for unknown
                # names -- confirm against SymbolTable.  Without this guard
                # an undeclared name aborts the whole run on `.getAddr()`.
                print('Error on line', ctx.start.line, 'identifier', loc_name, 'is not declared')
                return

            base_register = '%rbp' if location.mem == HEAP else '%rsp'
            self.body += 'movq ' + str(location.getAddr()) + '(' + base_register + '), %rax\n'
        else:
            self.visitChildren(ctx)
Example #3
0
class DecafSemanticChecker(DecafVisitor):
    def __init__(self):
        """Initialise the visitor base class and attach a new symbol table."""
        super().__init__()
        table = SymbolTable()
        self.st = table

    def visitProgram(self, ctx: DecafParser.ProgramContext):
        """Check the whole program inside one scope and require a `main`.

        After all children have been visited, the still-open scope is probed
        for a symbol named 'main'; its absence is reported (rule 3).
        """
        self.st.enterScope()

        self.visitChildren(ctx)

        # 3
        if self.st.probe('main') is None:
            print('Program has no method main')

        self.st.exitScope()

    def visitField_decl(self, ctx: DecafParser.Field_declContext):
        """Declare every field of a field declaration.

        Reports a non-positive array length (rule 4) and a name that is
        already declared in the current scope (rule 1).  New fields are
        stored as ``VarSymbol`` with ``mem=HEAP`` and 8 bytes per element.
        """
        line_num = ctx.start.line
        data_type = ctx.data_type().getText()

        for field in ctx.field_arg():
            # Symbols are stored under the bare identifier, so the duplicate
            # check has to use it too; the full text of an array field also
            # contains its "[n]" suffix and would never match.
            field_name = field.ID().getText()

            array_size = 1
            size_literal = field.int_literal()
            if size_literal is not None:
                array_size = int(size_literal.getText())
                if array_size <= 0:
                    # 4
                    print('Error on line ' + str(line_num) + ', array \'' +
                          field_name +
                          '\' must have a declaration value greater than 0')

            existing = self.st.probe(field_name)
            if existing is not None:
                # 1
                print('Error on line ' + str(line_num) + ', variable \'' +
                      field_name + '\' has already been declared on line ' +
                      str(existing.line))
            else:
                self.st.addSymbol(VarSymbol(id=field_name,
                                            type=data_type,
                                            line=line_num,
                                            size=8 * array_size,
                                            mem=HEAP))

        self.visitChildren(ctx)

    def visitMethod_decl(self, ctx: DecafParser.Method_declContext):
        """Declare a method and check the rules attached to its header.

        Checks performed:
          * `main` must not take parameters (rule 3);
          * a method declared after `main` triggers a warning (rule 3);
          * a name already declared in the current scope is reported and the
            earlier symbol is kept (rule 1);
          * a method without a data type (i.e. void) must not contain a
            return statement (rules 7 and 8).  Only the statements directly
            inside the method's block are inspected.

        The method symbol stores its parameters as "<type> <name>" strings.
        """
        line_num = ctx.start.line
        method_name = ctx.ID().getText()

        main_check = self.st.probe('main')
        previous = self.st.probe(method_name)

        if ctx.data_type() is not None:
            data_type = ctx.data_type().getText()
        else:
            data_type = "void"

        if method_name == "main" and len(ctx.method_arg()) > 0:
            # 3
            print('Error on line', line_num,
                  ', method \'main\' cannot have any parameters')

        if main_check is not None:
            # 3
            print('Warning detected function ' + method_name +
                  ' declared after main, this will not be executed')

        method_args = [
            arg.data_type().getText() + ' ' + arg.ID().getText()
            for arg in ctx.method_arg()
        ]

        if previous is not None:
            # 1
            print('Error on line ' + str(line_num) + ', method \'' +
                  method_name + '\' has already been declared on line ' +
                  str(previous.line))
        else:
            self.st.addSymbol(MethodSymbol(id=method_name,
                                           type=data_type,
                                           line=line_num,
                                           params=method_args))

        # Parameters, declarations and statements are all checked through
        # the regular child traversal; nothing is visited a second time.
        self.visitChildren(ctx)

        # 7 and 8
        if ctx.VOID() is not None:
            for statement in ctx.block().statement():
                if statement.RETURN() is not None:
                    print('Error on line ' + str(line_num) +
                          ', method should not have a return statement')
        # TODO: for non-void methods, compare the returned expression's type
        # with the method's declared type.

    def visitMethod_arg(self, ctx: DecafParser.Method_argContext):
        """Declare one method parameter as an 8-byte ``STACK`` symbol.

        A name that is already declared in the current scope is reported
        (rule 1) and not declared again.
        """
        line_num = ctx.start.line
        data_type = ctx.data_type().getText()
        arg_name = ctx.ID().getText()
        existing = self.st.probe(arg_name)

        if existing is not None:
            # 1
            print('Error on line ' + str(line_num) + ', variable \'' +
                  arg_name + '\' has already been declared on line ' +
                  str(existing.line))
        else:
            self.st.addSymbol(VarSymbol(id=arg_name,
                                        type=data_type,
                                        line=line_num,
                                        size=8,
                                        mem=STACK))

        self.visitChildren(ctx)

    def visitField_arg(self, ctx: DecafParser.Field_argContext):
        """Visit the children of a field argument; the result is discarded."""
        self.visitChildren(ctx)
        return None

    def visitData_type(self, ctx: DecafParser.Data_typeContext):
        """Visit the children of a data type node; the result is discarded."""
        self.visitChildren(ctx)
        return None

    def visitBlock(self, ctx: DecafParser.BlockContext):
        """Visit the declarations and statements of a block.

        No scope is opened here, so names declared in the block share the
        scope that is current when the block is reached.
        """
        self.visitChildren(ctx)

    def visitVar_decl(self, ctx: DecafParser.Var_declContext):
        """Declare every identifier of a local variable declaration.

        Each new name becomes an 8-byte ``STACK`` symbol of the declared
        type.  A name already present in the current scope is reported
        (rule 1) and left as it was.
        """
        line_num = ctx.start.line
        data_type = ctx.data_type().getText()

        for var_id in ctx.ID():
            name = var_id.getText()
            previous = self.st.probe(name)

            if previous is not None:
                # 1
                print('Error on line ' + str(line_num) + ', variable \'' +
                      name + '\' has already been declared on line ' +
                      str(previous.line))
                continue

            self.st.addSymbol(VarSymbol(id=name,
                                        type=data_type,
                                        line=line_num,
                                        size=8,
                                        mem=STACK))

        self.visitChildren(ctx)

    def visitStatement(self, ctx: DecafParser.StatementContext):
        """Apply the statement-level semantic rules.

        * ``if``: the condition is visited once and the remaining children
          (the guarded blocks) are traversed so their contents are checked.
        * assignment: the target must be declared (rule 2); ``+=`` / ``-=``
          need int on both sides (rule 16); otherwise target and value types
          must match (rule 15).
        * ``for``: both bound expressions must be int (rule 17).
        * anything else: default child traversal.
        """
        line_num = ctx.start.line

        if ctx.IF() is not None:
            condition = ctx.expr(0)
            if condition is not None:
                self.visit(condition)
            # TODO: rule 11 (the condition must be boolean) is not enforced
            # yet; it needs visitExpr to type comparisons as 'boolean'.

            # Walk everything except the condition that was just visited.
            for child in ctx.getChildren():
                if child is not condition:
                    self.visit(child)

        elif ctx.location() is not None:
            location = ctx.location()
            loc_type = self.visit(location)
            expr_type = self.visit(ctx.expr(0))
            operator = ctx.assign_op().getText()
            # Resolve by the bare identifier: the full location text of an
            # array element ("a[i]") is never a symbol name.
            identifier = self.st.lookup(location.ID().getText())

            # 2
            if identifier is None:
                print('Error on line ' + str(line_num) + ', identifier \'' +
                      location.getText() + '\' has not been declared')

            # 16 - TODO change error message
            elif loc_type is not None:
                if (loc_type != 'int'
                        or expr_type != 'int') and (operator == '-='
                                                    or operator == '+='):
                    print(
                        'Error on line ' + str(line_num) +
                        ' variables must be of type int when in an incrementing/decrementing assignment'
                    )

                # 15
                elif loc_type != expr_type:
                    print('Error on line ' + str(line_num) +
                          ' type mismatched in expression')

        # 17
        elif ctx.FOR() is not None:
            expr_type_a = self.visit(ctx.expr(0))
            expr_type_b = self.visit(ctx.expr(1))

            if expr_type_a != 'int' or expr_type_b != 'int':
                print('Error on line ' + str(line_num) +
                      ' for statement expressions must be of type int')
            # NOTE(review): the loop body is not traversed here.  Doing so
            # probably requires declaring the loop index first -- confirm
            # against the grammar before enabling it.

        else:
            self.visitChildren(ctx)

    def visitLocation(self, ctx: DecafParser.LocationContext):
        """Check a variable or array-element reference and return its type.

        For an indexed location (rule 10) the symbol must be an array and
        the index expression must be of type int.  The index is typed by
        visiting it, so literals, variables and compound expressions are all
        accepted when they are int.

        Returns the symbol's declared type, or '' when the name is unknown.
        """
        line_num = ctx.start.line
        var_id = ctx.ID().getText()
        var_symb = self.st.lookup(var_id)

        # 10 - an indexed location must refer to an array
        if ctx.LSQUARE():
            # 10i
            # NOTE(review): a size of 8 is also what a one-element array
            # gets, so `a[1]`-style declarations are flagged here too.
            if var_symb is not None and var_symb.size == 8:
                print('Error on line ' + str(line_num) +
                      ', array variable \'' + var_id + '\' is not an array')

            # 10ii
            # An empty or None type means the index could not be typed (and
            # any underlying problem was already reported), so only a known
            # non-int type is flagged.
            index_type = self.visit(ctx.expr())
            if index_type and index_type != 'int':
                print('Error on line ' + str(line_num) +
                      ', array variable \'' + var_id +
                      '\' value must be of type integer')
        else:
            self.visitChildren(ctx)

        return var_symb.type if var_symb is not None else ''

    def visitExpr(self, ctx: DecafParser.ExprContext):
        """Type-check an expression and return its type.

        Returns 'int' or 'boolean' (or a symbol's declared type) when the
        type is known, None when an error makes it unusable, and '' when the
        expression has a form that is not typed here.

        Binary expressions:
          * arithmetic and relational operators need two int operands
            (rule 12); arithmetic yields int, relational yields boolean;
          * equality operators need operands of the same type (rule 13) and
            yield boolean;
          * conditional operators need two boolean operands (rule 14) and
            yield boolean.
        An operand whose type is already None has been reported before, so
        no second message is printed for it.
        """
        line_num = ctx.start.line
        expr_type = ""
        operands = ctx.expr()

        if len(operands) == 2:
            type_a = self.visit(operands[0])
            type_b = self.visit(operands[1])
            op = ctx.bin_op()
            already_reported = type_a is None or type_b is None

            if op.rel_op() is not None or op.arith_op() is not None:
                # 12
                valid = type_a == 'int' and type_b == 'int'
                message = ' operands must be of type int'
                result = 'int' if op.arith_op() is not None else 'boolean'
            elif op.eq_op() is not None:
                # 13
                valid = type_a == type_b
                message = ' operands must be of same type'
                result = 'boolean'
            elif op.cond_op() is not None:
                # 14
                valid = type_a == 'boolean' and type_b == 'boolean'
                message = ' operands must be of type boolean'
                result = 'boolean'
            else:
                # Unknown operator kind: keep the operand type when both agree.
                valid = type_a == type_b
                message = None
                result = type_a

            if valid:
                expr_type = result
            else:
                expr_type = None
                if message is not None and not already_reported:
                    print('Error on line ' + str(line_num) + message)

        elif ctx.location() is not None:
            var_name = ctx.location().ID().getText()
            var_symbol = self.st.lookup(var_name)

            if var_symbol is not None:
                expr_type = var_symbol.type
            else:
                # 2
                expr_type = None
                print('Error on line ' + str(line_num) + ', ' + var_name +
                      ' has not been declared')

        elif ctx.data_literal() is not None:
            literal = ctx.data_literal()
            if literal.int_literal() is not None:
                expr_type = 'int'
            elif literal.bool_literal() is not None:
                expr_type = 'boolean'
            else:
                expr_type = None

        elif ctx.method_call() is not None:
            call = ctx.method_call()
            name_ctx = call.method_name()

            if name_ctx is None:
                # A call without a method name (a callout) has no symbol to
                # look up; its children are still checked.
                self.visitChildren(call)
            else:
                method_name = name_ctx.getText()
                method_symbol = self.st.lookup(method_name)
                if method_symbol is not None:
                    expr_type = method_symbol.type
                else:
                    # 2
                    expr_type = None
                    print('Error on line ' + str(line_num) + ', ' +
                          method_name + ' has not been declared')
                # The arguments are checked whether or not the method is known.
                self.visit(call)

        elif ctx.EXCLAMATION() is not None:
            operand_type = self.visit(operands[0])

            # 14
            if operand_type == 'boolean':
                expr_type = 'boolean'
            else:
                expr_type = None
                if operand_type is not None:
                    print('Error on line ' + str(line_num) +
                          ' operand must be of type boolean')

        elif len(operands) == 1:
            # A wrapper around a single sub-expression (e.g. parentheses)
            # has the type of that sub-expression.
            expr_type = self.visit(operands[0])

        else:
            self.visitChildren(ctx)

        return expr_type

    def visitMethod_call(self, ctx: DecafParser.Method_callContext):
        """Check a method call against the called method's declaration.

        * Rule 6: a void method must not be used where a value is needed,
          i.e. when the call sits directly inside an expression.
        * Rule 5: the number of arguments must match the declaration and
          each argument's type must match the declared parameter type.

        Arguments are typed by visiting them, so any expression is allowed
        as an argument.  Calls without a method name (callouts) and calls to
        unknown methods only have their children checked.
        """
        line_num = ctx.start.line
        method_name = ctx.method_name()

        if method_name is None:
            self.visitChildren(ctx)
            return

        name = method_name.getText()
        method_symbol = self.st.probe(name)

        # Visiting each argument both checks it and yields its type; the
        # arguments are therefore not visited again afterwards.
        arg_types = [self.visit(arg) for arg in ctx.expr()]

        if method_symbol is None:
            return

        # 6
        used_as_value = isinstance(ctx.parentCtx, DecafParser.ExprContext)
        if method_symbol.type == 'void' and used_as_value:
            print('Error on line ' + str(line_num) + ' method \'' +
                  name + '\' must return a value')

        # 5
        if len(method_symbol.params) != len(arg_types):
            print('Error on line ' + str(line_num) + ' method ' +
                  name + ' needs ' +
                  str(len(method_symbol.params)) + ' arguments has ' +
                  str(len(arg_types)))
            return

        for declared, arg_type in zip(method_symbol.params, arg_types):
            # Parameters are recorded as "<type> <name>".
            param_type = declared.split(' ')[0]
            # An empty or None argument type is unknown or already reported.
            if arg_type and arg_type != param_type:
                print('Error on line ' + str(line_num) +
                      ' type mismatch on method \'' +
                      name + '\' parameters')
class DecafCodeGenVisitor(DecafVisitor):
    """Decaf parse-tree visitor that records symbols and emits method labels.

    ``head`` holds the ``.data`` section text and ``body`` the code text.
    Unless documented otherwise, a ``visit*`` method only walks the children
    of its node and returns the result of ``visitChildren``.
    """

    def __init__(self):
        super().__init__()
        self.st = SymbolTable()
        self.head = '.data\n'
        self.body = '.global main\n'

    def visitProgram(self, ctx: DecafParser.ProgramContext):
        """Walk the whole program inside its own symbol-table scope."""
        self.st.enterScope()
        return_val = self.visitChildren(ctx)
        self.st.exitScope()
        return return_val

    def visitVar_id(self, ctx: DecafParser.Var_idContext):
        """Visit a parse tree produced by DecafParser#var_id."""
        return self.visitChildren(ctx)

    def visitArray_id(self, ctx: DecafParser.Array_idContext):
        """Visit a parse tree produced by DecafParser#array_id."""
        return self.visitChildren(ctx)

    def visitField_declr(self, ctx: DecafParser.Field_declrContext):
        """Visit a parse tree produced by DecafParser#field_declr."""
        return self.visitChildren(ctx)

    def visitField_var(self, ctx: DecafParser.Field_varContext):
        """Visit a parse tree produced by DecafParser#field_var."""
        return self.visitChildren(ctx)

    def visitMethod_declr(self, ctx: DecafParser.Method_declrContext):
        """Register a method symbol and wrap its body in ``<name>:`` ... ``ret``.

        The first child of the declaration is read as the return type and the
        second as the method name.
        """
        # Store plain text, not parse-tree nodes, so the method symbol is
        # keyed the same way as the variable symbols (which use strings).
        method_type = ctx.getChild(0).getText()
        method_id = ctx.getChild(1).getText()
        self.st.addSymbol(MethodSymbol(method_id, method_type, 0, ""))
        self.body += "\n" + method_id + ":\n"
        retval = self.visitChildren(ctx)
        self.body += "\nret\n"
        return retval

    def visitReturn_type(self, ctx: DecafParser.Return_typeContext):
        """Visit a parse tree produced by DecafParser#return_type."""
        return self.visitChildren(ctx)

    def visitBlock(self, ctx: DecafParser.BlockContext):
        """Walk a block inside a fresh symbol-table scope."""
        self.st.enterScope()
        return_val = self.visitChildren(ctx)
        self.st.exitScope()
        return return_val

    def visitStatement(self, ctx: DecafParser.StatementContext):
        """Placeholder for statement code generation.

        Intended steps: 1. recognise the statement type, 2. write the code
        generation procedure for each special case.

        NOTE(review): the children are not visited and ``None`` is returned,
        so nothing nested in a statement (locations, expressions, method
        calls, inner blocks) is reached by this visitor yet.
        """
        if ctx.assign_op():
            pass

    def visitVardeclr(self, ctx: DecafParser.VardeclrContext):
        """Register every variable named in a declaration in the current scope.

        The children are classified by their text: ``int``/``boolean`` sets
        the type for the names that follow, ``,`` and ``;`` are skipped and
        every other child is taken as a variable name.
        """
        vtype = None
        for child in ctx.children:
            tkn = child.getText()
            if tkn in ('int', 'boolean'):
                vtype = tkn
            elif tkn in (',', ';'):
                continue
            else:
                if self.st.probe(tkn) is not None:
                    raiseErr("Variable " + tkn + " already declared in scope.",
                             ctx)
                    # If raiseErr returns, keep the first declaration instead
                    # of replacing it with the duplicate.
                    continue
                self.st.addSymbol(VarSymbol(tkn, vtype, 0, 8, 0))
        return self.visitChildren(ctx)

    def visitMethod_call_inter(self,
                               ctx: DecafParser.Method_call_interContext):
        """Visit a parse tree produced by DecafParser#method_call_inter."""
        return self.visitChildren(ctx)

    def visitMethod_call(self, ctx: DecafParser.Method_callContext):
        """Visit a parse tree produced by DecafParser#method_call."""
        return self.visitChildren(ctx)

    def visitExpr(self, ctx: DecafParser.ExprContext):
        """Visit a parse tree produced by DecafParser#expr."""
        return self.visitChildren(ctx)

    def visitLocation(self, ctx: DecafParser.LocationContext):
        """Report a location whose variable is not visible in any scope."""
        vid = str(ctx.getChild(0).getChild(0))
        if self.st.lookup(vid) is None:
            raiseErr("Variable " + vid + " not declared.", ctx)
        return self.visitChildren(ctx)

    def visitCallout_arg(self, ctx: DecafParser.Callout_argContext):
        """Visit a parse tree produced by DecafParser#callout_arg."""
        return self.visitChildren(ctx)

    def visitInt_literal(self, ctx: DecafParser.Int_literalContext):
        """Visit a parse tree produced by DecafParser#int_literal."""
        return self.visitChildren(ctx)

    def visitRel_op(self, ctx: DecafParser.Rel_opContext):
        """Visit a parse tree produced by DecafParser#rel_op."""
        return self.visitChildren(ctx)

    def visitEq_op(self, ctx: DecafParser.Eq_opContext):
        """Visit a parse tree produced by DecafParser#eq_op."""
        return self.visitChildren(ctx)

    def visitCond_op(self, ctx: DecafParser.Cond_opContext):
        """Visit a parse tree produced by DecafParser#cond_op."""
        return self.visitChildren(ctx)

    def visitLiteral(self, ctx: DecafParser.LiteralContext):
        """Visit a parse tree produced by DecafParser#literal."""
        return self.visitChildren(ctx)

    def visitBin_op(self, ctx: DecafParser.Bin_opContext):
        """Visit a parse tree produced by DecafParser#bin_op."""
        return self.visitChildren(ctx)

    def visitArith_op(self, ctx: DecafParser.Arith_opContext):
        """Visit a parse tree produced by DecafParser#arith_op."""
        return self.visitChildren(ctx)

    def visitVar_type(self, ctx: DecafParser.Var_typeContext):
        """Visit a parse tree produced by DecafParser#var_type."""
        return self.visitChildren(ctx)

    def visitAssign_op(self, ctx: DecafParser.Assign_opContext):
        """Visit a parse tree produced by DecafParser#assign_op."""
        return self.visitChildren(ctx)

    def visitMethod_name(self, ctx: DecafParser.Method_nameContext):
        """Visit a parse tree produced by DecafParser#method_name."""
        return self.visitChildren(ctx)
class DecafCodeGenVisitor(DecafVisitor):
    """Decaf visitor that emits x86-64 assembly text (AT&T syntax).

    ``head`` collects the ``.data`` section and ``body`` the code. Expression
    code leaves its result in ``%rax``. Errors are reported with ``print``.
    """

    # Counters used to build unique labels for if statements, callout strings
    # and loops.
    IF_LABEL_COUNT = 1
    CALLOUT_COUNT = 1
    LOOP_COUNT = 1

    def __init__(self):
        """Set up the symbol table and the header of the assembly output."""
        super().__init__()
        self.st = SymbolTable()
        self.head = '.data\n'
        self.body = '.global main\n'
        # (continue_label, break_label) for every enclosing for loop,
        # innermost last; used to resolve continue/break statements.
        self.loop_labels = []

    def _declare_variable(self, name, line):
        """Add an ``int`` variable to the current scope and emit its first store.

        Returns the new ``VarSymbol``.
        """
        var_symbol = VarSymbol(id=name, type='int', line=line, size=8,
                               mem=self.st.stack_pointer)
        self.st.addSymbol(var_symbol)
        var_addr = var_symbol.getAddr()
        self.body += '\tmovq %rax, -' + str(var_addr[0]) + '(%rsp)\n'
        return var_symbol

    def visitProgram(self, ctx:DecafParser.ProgramContext):
        """Visit the program and check that a parameterless main method exists."""
        self.st.enterScope()
        self.visitChildren(ctx)
        method_symbol = self.st.lookup('main')

        if method_symbol is None:
            print('[Error]: No main method has been declared.')
        elif getattr(method_symbol, 'params', None):
            # Inspect the parameters recorded for main (the previous check
            # tested a local list that was always empty).
            print('[Error]: The main method cannot contain paramaters.')
        self.body += 'ret\n'
        self.st.exitScope()

    def visitMethod_decl(self, ctx:DecafParser.Method_declContext):
        """Emit the method label, register its parameters and its symbol.

        ``ctx.ID(0)`` is the method name; every further identifier is taken
        as a parameter name.
        """
        identifiers = ctx.ID()
        method_name = identifiers[0].getText()
        # NOTE(review): TYPE(0) is stored as returned by the parser; confirm
        # it is the return type and not the first parameter type when the
        # method has no typed return.
        return_type = ctx.TYPE(0)
        line_number = ctx.start.line

        if self.st.probe(method_name) is not None:
            # line_number is an int and must be converted before joining.
            print('[Error]: The method ' + method_name + ' on line: ' +
                  str(line_number) + ' was already declared!')
        else:
            self.body += method_name
            self.body += ':\n'

        # Skip the method name itself: only the remaining identifiers are
        # parameters and get a variable symbol.
        params = []
        for identifier in identifiers[1:]:
            param_name = identifier.getText()
            params.append(param_name)
            if self.st.probe(param_name) is None:
                self._declare_variable(param_name, line_number)

        method_symbol = MethodSymbol(id=method_name, type=return_type,
                                     line=line_number, params=params)
        self.st.addSymbol(method_symbol)

        return self.visitChildren(ctx)

    def visitBlock(self, ctx:DecafParser.BlockContext):
        """Visit a block inside a new scope."""
        self.st.enterScope()
        visit = self.visitChildren(ctx)
        self.st.exitScope()
        return visit

    def visitExpr(self, ctx:DecafParser.ExprContext):
        """Emit code that leaves the value of the expression in ``%rax``.

        Handles variable reads, literals and ``+ - * /`` on two operands;
        any other expression form only has its children visited.
        """
        if ctx.location():
            var_name = ctx.location().getText()
            # An array access is looked up under the bare array name.
            var_symbol = self.st.lookup(var_name.split('[', 1)[0])
            if var_symbol is None:
                print('[Error]: Variable', var_name, 'has not been declared. Found on line', ctx.start.line)
            else:
                var_addr = var_symbol.getAddr()
                self.body += '\tmovq -' + str(var_addr[0]) + '(%rsp), %rax\n'

        elif ctx.literal():
            number = ctx.literal().getText()
            if number == 'false':
                number = '0'
            elif number == 'true':
                number = '1'
            self.body += '\tmovq $' + number + ', %rax\n'

        elif len(ctx.expr()) > 1:
            # Evaluate the left operand and park it in a temporary stack slot
            # while the right operand is evaluated.
            self.visit(ctx.expr(0))
            self.st.stack_pointer[-1] += 8
            self.body += '\tmovq %rax, ' + str(-self.st.stack_pointer[-1]) + '(%rsp)\n'

            self.visit(ctx.expr(1))
            self.body += '\tmovq ' + str(-self.st.stack_pointer[-1]) + '(%rsp), %r10\n'
            self.st.stack_pointer[-1] -= 8
            self.body += '\tmovq %rax, %r11\n'

            # %r10 = left operand, %r11 = right operand; the result is built
            # in %r11.
            operator = str(ctx.BIN_OP()) if ctx.BIN_OP() else None
            if operator == '+':
                self.body += '\taddq %r10, %r11\n'
            elif operator == '*':
                self.body += '\timul %r10, %r11\n'
            elif operator == '-':
                # left - right: subtract into %r10, then move to %r11.
                self.body += '\tsubq %r11, %r10\n'
                self.body += '\tmovq %r10, %r11\n'
            elif operator == '/':
                # Signed division of %rdx:%rax; cqto sign-extends the
                # dividend, and the quotient comes back in %rax.
                self.body += '\tmovq %r10, %rax\n'
                self.body += '\tcqto\n'
                self.body += '\tidivq %r11\n'
                self.body += '\tmovq %rax, %r11\n'

            self.body += '\tmovq %r11, %rax\n'

        else:
            # Other forms (e.g. a method call or a single nested expression)
            # are handled by their own visit methods.
            self.visitChildren(ctx)

    def visitVar_decl(self, ctx:DecafParser.Var_declContext):
        """Declare every variable of a declaration (``int x, y, z``)."""
        for identifier in ctx.ID():
            var_name = identifier.getText()
            if "[" in var_name:
                # Array-style names are not handled by this visitor.
                continue
            var_symbol = self.st.probe(var_name)
            if var_symbol is None:
                self._declare_variable(var_name, ctx.start.line)
            else:
                print('[Error]:', var_symbol.id + ', declared on line', ctx.start.line, 'has already been declared on line', var_symbol.line)

        return self.visitChildren(ctx)

    def visitStatement(self, ctx:DecafParser.StatementContext):
        """Emit code for break/continue, if, return and for statements.

        Any other statement only has its children visited.
        """
        if ctx.CONTINUE() is not None or ctx.BREAK() is not None:
            if self.loop_labels:
                continue_label, break_label = self.loop_labels[-1]
                target = continue_label if ctx.CONTINUE() is not None else break_label
                self.body += '\tjmp ' + target + '\n'
            else:
                print('[Error]: break/continue used outside of a for loop on line', ctx.start.line)
            return self.visitChildren(ctx)

        if ctx.RETURN():
            # Evaluate the returned expression into %rax before returning;
            # nothing is emitted after the ret.
            if ctx.expr():
                self.visit(ctx.expr(0))
            self.body += '\tret\n'
            return None

        if ctx.FOR():
            self._emit_for(ctx)
            return None

        if ctx.IF():
            self.st.enterScope()
            # Labels use underscores: '-' is not valid in an assembler symbol.
            if_label = 'if_label_' + str(self.IF_LABEL_COUNT)
            self.IF_LABEL_COUNT = self.IF_LABEL_COUNT + 1
            # NOTE(review): the comparison is emitted before the condition is
            # visited, and the l/e/g branch targets are never defined; the
            # branch logic still has to be implemented.
            self.body += '\tcmp %r11, %r10\n'
            self.body += '\tjl ' + if_label + 'l\n'
            self.body += '\tje ' + if_label + 'e\n'
            self.body += '\tjg ' + if_label + 'g\n'
            self.body += '\tret\n'
            self.body += if_label + ':\n'
            self.st.exitScope()

        return self.visitChildren(ctx)

    def _emit_for(self, ctx):
        """Emit a counting loop: ``%rbx`` runs from expr(0) up to, not including, expr(1).

        NOTE(review): the counter lives in ``%rbx`` only; it is not stored in
        the loop variable and is overwritten by a nested for loop.
        """
        self.st.enterScope()
        # Reserve the label number before the body is visited so that nested
        # loops get their own labels.
        loop_id = str(self.LOOP_COUNT)
        self.LOOP_COUNT = self.LOOP_COUNT + 1
        begin_label = 'begin_for_' + loop_id
        continue_label = 'continue_for_' + loop_id
        end_label = 'end_for_' + loop_id

        start_value = ctx.expr(0)
        end_value = ctx.expr(1)

        self.visit(start_value)
        self.body += '\tmovq %rax, %rbx\n'
        self.body += begin_label + ':\n'
        self.visit(end_value)
        self.body += '\tcmp %rax, %rbx\n'
        self.body += '\tjge ' + end_label + '\n'

        # Visit the loop body once; the bound expressions were emitted above.
        self.loop_labels.append((continue_label, end_label))
        for child in ctx.children:
            if child is not start_value and child is not end_value:
                self.visit(child)
        self.loop_labels.pop()

        self.body += continue_label + ':\n'
        self.body += '\taddq $1, %rbx\n'
        self.body += '\tjmp ' + begin_label + '\n'
        self.body += end_label + ':\n'
        self.st.exitScope()

    def visitField_decl(self, ctx:DecafParser.Field_declContext):
        """Declare every field of a declaration; arrays use their bare name."""
        for field in ctx.field_name():
            # "a[10]" is registered and checked as "a" so that duplicate
            # array declarations are detected as well.
            var_name = field.getText().split('[', 1)[0]
            var_symbol = self.st.probe(var_name)
            if var_symbol is None:
                self._declare_variable(var_name, ctx.start.line)
            else:
                print('[Error]:', var_symbol.id + ', declared on line', ctx.start.line, 'has already been declared on line', var_symbol.line)
        return self.visitChildren(ctx)

    def visitMethod_call(self, ctx:DecafParser.Method_callContext):
        """Emit a ``call`` for a declared method; report unknown methods.

        Calls that carry callout arguments, or that have no method name, are
        left to the child visitors.
        """
        name_ctx = ctx.method_name()
        is_plain_call = name_ctx is not None and not ctx.callout_arg()
        method_name = name_ctx.getText() if is_plain_call else None
        declared = is_plain_call and self.st.lookup(method_name) is not None

        if is_plain_call and not declared:
            print('[Error]: Call to a function that does not exist: ' + method_name + ' on line: ' + str(ctx.start.line))

        visit = self.visitChildren(ctx)

        if declared:
            # Emitted after the children so that the callee's return value in
            # %rax is not overwritten by argument code; call (not jmp) lets
            # the callee's ret come back here.
            self.body += '\tcall ' + method_name + '\n'
        return visit

    def visitCallout_arg(self, ctx:DecafParser.Callout_argContext):
        """Store a string argument in the data section and print it with printf."""
        string_literal = ctx.STRING_LITERAL()
        if string_literal is not None:
            label = 'string' + str(self.CALLOUT_COUNT)
            self.CALLOUT_COUNT = self.CALLOUT_COUNT + 1
            self.head += label + ': .asciz ' + str(string_literal) + '\n'
            # Pass the address of the string (its label), not the counter.
            self.body += '\tmovq $' + label + ', %rdi\n'
            self.body += '\tsubq $8, %rsp\n'
            self.body += '\tcall printf\n'
            self.body += '\taddq $8, %rsp\n'

        return self.visitChildren(ctx)
Example #6
0
class DecafSemanticChecker(DecafVisitor):
    """Semantic checks for Decaf programs; every error is reported with ``print``."""

    # Shared tail of the method-signature error messages.
    _SIGNATURE_RULE = ", the number and types of arguments in a method call must be the same as the number and types of the formals"

    def __init__(self):
        super().__init__()
        # start with an empty symbol table
        self.st = SymbolTable()

    @staticmethod
    def _int_literal_value(text):
        """Return the value of a decimal or ``0x`` hexadecimal literal text."""
        if text.lower().startswith('0x'):
            return int(text, 16)
        return int(text)

    def visitProgram(self, ctx: DecafParser.ProgramContext):
        """Visit the whole program inside its own symbol-table scope."""
        self.st.enterScope()
        self.visitChildren(ctx)
        self.st.exitScope()

    def visitVar_decl(self, ctx: DecafParser.Var_declContext):
        """Semantic rule: no identifier is declared twice in the same scope.

        Test with testdata/semantics/illegal-01.dcf
        """
        line_num = ctx.start.line
        for var_decl in ctx.ID():
            var_name = var_decl.getText()  # the variable name (eg. x)
            var_symbol = self.st.probe(var_name)

            if var_symbol is not None:
                # the name already exists in the current scope
                print('Error on line', line_num, 'variable \'', var_name,
                      '\' already declared on line', var_symbol.line)
            else:
                self.st.addSymbol(
                    VarSymbol(id=var_name,
                              type='int',
                              line=line_num,
                              size=8,
                              mem=STACK))

        return self.visitChildren(ctx)

    def visitStatement(self, ctx: DecafParser.StatementContext):
        """Semantic rule: no identifier is used before it is declared."""
        if ctx.location() is not None:
            line_num = ctx.start.line
            var_name = ctx.location().ID().getText()

            if self.st.lookup(var_name) is None:
                print('Error on line', line_num, 'variable \'', var_name,
                      '\'is not declared')

        self.visitChildren(ctx)

    # semantic rule: warn the user that any method defined after the main method will never be executed.

    def visitField_name(self, ctx: DecafParser.Field_nameContext):
        """Semantic rule: the int_literal in an array declaration must be greater than 0."""
        size_literal = ctx.int_literal()
        if size_literal is not None:
            # Read the literal text so that hexadecimal sizes are accepted
            # as well as decimal ones.
            if self._int_literal_value(size_literal.getText()) < 1:
                line_num = ctx.start.line
                var_name = ctx.ID().getText()
                print("Error on line", line_num, "variable '", var_name,
                      "' array size must be greater than 0")

        return self.visitChildren(ctx)

    def visitMethod_decl(self, ctx: DecafParser.Method_declContext):
        """Record a method and the data types of its parameters in the current scope."""
        method_symbol = MethodSymbol(
            id=ctx.ID()[0].getText(),
            type=ctx.return_type().getText(),
            line=ctx.start.line,
            # parameter data types as strings, in declaration order
            params=[data_type.getText() for data_type in ctx.data_type()])
        self.st.addSymbol(method_symbol)
        return self.visitChildren(ctx)

    def visitMethod_call(self, ctx: DecafParser.Method_callContext):
        """Semantic rule 5: a call must match the declared signature.

        The number of arguments must equal the number of formals, and every
        literal argument must be of the declared parameter type. Arguments
        that are not literals are not type checked here. The children are
        visited in every case.
        """
        line_num = ctx.start.line
        name_ctx = ctx.method_name()
        if name_ctx is None:
            # No method name on this call (presumably a callout - confirm
            # against the grammar); there is no signature to compare with.
            return self.visitChildren(ctx)

        method_name = name_ctx.getText()
        method_symbol = self.st.lookup(method_name)
        if method_symbol is None:
            print('Error on line', line_num, 'method \'', method_name,
                  '\'is not declared')
            return self.visitChildren(ctx)

        method_symbol_params = method_symbol.params
        args = ctx.expr()
        if len(args) != len(method_symbol_params):
            print("Error you passed an incorrect combination of parameters",
                  "on line", line_num, self._SIGNATURE_RULE)
            return self.visitChildren(ctx)

        for expected_type, arg in zip(method_symbol_params, args):
            if expected_type not in ('int', 'boolean'):
                print(
                    "missing method_symbol_params with data type classification:",
                    expected_type, " on line number", line_num,
                    self._SIGNATURE_RULE)
                continue

            literal = arg.literal()
            if literal is None:
                # Not a literal argument: its type is unknown at this point.
                continue

            if expected_type == 'int' and literal.int_literal() is None:
                print("Error incorrect parameter data type expected",
                      expected_type, "received value", literal.getText(),
                      "on line", line_num, self._SIGNATURE_RULE)
            elif expected_type == 'boolean' and literal.bool_literal() is None:
                print("Error incorrect parameter data type expected",
                      expected_type, "received", literal.getText(),
                      "on line", line_num, self._SIGNATURE_RULE)

        return self.visitChildren(ctx)
Example #7
0
class DecafSemanticChecker(DecafVisitor):
    """Decaf visitor that checks declarations and emits x86-64 assembly text.

    ``head`` collects the ``.data`` section and ``body`` the code. Expression
    code leaves its result in ``%rax``. Errors are reported with ``print``.
    """

    def __init__(self):
        super().__init__()
        self.head = '.data\n'
        self.body = '.global main\n'
        self.st = SymbolTable()

    @staticmethod
    def _int_literal_value(text):
        """Return the value of a decimal or ``0x`` hexadecimal literal text."""
        if text.lower().startswith('0x'):
            return int(text, 16)
        return int(text)

    def visitProgram(self, ctx:DecafParser.ProgramContext):
        """Visit the whole program inside its own symbol-table scope."""
        self.st.enterScope()
        self.visitChildren(ctx)
        self.st.exitScope()

    def visitField_decl(self, ctx:DecafParser.Field_declContext):
        """Register the fields of a declaration and validate array sizes."""
        line_num = ctx.start.line
        data_type = ctx.data_type().getText()

        for f in ctx.field_arg():
            field_name = f.ID().getText()

            if f.int_literal() is not None:
                array_size = self._int_literal_value(f.int_literal().getText())
                if array_size <= 0:
                    print('Error on line', line_num,', array \'', field_name,'\' must have a declaration value greater than 0')

            # Probe by identifier, the key the symbol is stored under, so a
            # redeclared array ("a[10]") is detected too.
            field_symbol = self.st.probe(field_name)
            if field_symbol is not None:
                print('Error on line', line_num,', variable \'', field_name,'\' has already been declared on line',field_symbol.line)
            else:
                self.st.addSymbol(VarSymbol(id=field_name,
                                            type=data_type,
                                            line=line_num,
                                            size=8,
                                            mem=HEAP))

        self.visitChildren(ctx)

    def visitMethod_decl(self, ctx:DecafParser.Method_declContext):
        """Register a method, emit its label and store its parameters.

        The method symbol goes into the enclosing scope; the parameters go
        into a new scope that also covers the method body. Each parameter is
        copied from its entry in ``param_registers`` to its stack address.
        """
        line_num = ctx.start.line
        method_name = ctx.ID().getText()

        if ctx.data_type() is not None:
            data_type = ctx.data_type().getText()
        else:
            data_type = "void"

        method_params = [
            VarSymbol(id=arg.ID().getText(), type=arg.data_type().getText(),
                      line=line_num, size=8, mem=STACK)
            for arg in ctx.method_arg()
        ]

        # Register the method itself (the symbol used to be built and then
        # dropped), reporting a clash with an earlier declaration.
        previous = self.st.probe(method_name)
        if previous is not None:
            print('Error on line ' + str(line_num) + ', method \'' + method_name +
                  '\' has already been declared on line ' + str(previous.line))
        else:
            self.st.addSymbol(MethodSymbol(id=method_name,
                                           type=data_type,
                                           line=line_num,
                                           params=method_params))

        self.body += method_name + ':\n'

        if method_name == 'main':
            self.body += 'movq %rsp, %rbp\n'

        self.st.enterScope()
        for i, param in enumerate(method_params):
            self.st.addSymbol(param)
            # Save each method parameter to its location on the stack.
            self.body += 'movq ' + param_registers[i] + ',' + str(param.getAddr()) + '(%rsp)\n'

        self.visitChildren(ctx)
        self.body += 'ret\n'
        self.st.exitScope()

    def visitVar_decl(self, ctx: DecafParser.Var_declContext):
        """Register the variables of a declaration in the current scope."""
        line_num = ctx.start.line
        data_type = ctx.data_type().getText()

        for v in ctx.ID():
            var_name = v.getText()
            id_symbol = self.st.probe(var_name)

            if id_symbol is not None:
                # 1
                print('Error on line ' + str(
                    line_num) + ', variable \'' + var_name + '\' has already been declared on line ' + str(
                    id_symbol.line))
            else:
                self.st.addSymbol(VarSymbol(id=var_name,
                                            type=data_type,
                                            line=line_num,
                                            size=8,
                                            mem=STACK))

        self.visitChildren(ctx)

    def visitExpr(self, ctx:DecafParser.ExprContext):
        """Emit code that leaves the value of the expression in ``%rax``.

        Handles integer and boolean literals, variable reads and arithmetic
        on two operands; any other form only has its children visited.
        """
        if ctx.data_literal():
            literal_text = ctx.data_literal().getText()
            if ctx.data_literal().int_literal() is not None:
                self.body += 'movq $' + literal_text + ', %rax\n'
            elif literal_text == 'true':
                self.body += 'movq $1, %rax\n'
            elif literal_text == 'false':
                self.body += 'movq $0, %rax\n'

        elif ctx.location():
            loc_name = ctx.location().getText()
            location = self.st.lookup(loc_name)
            if location is None:
                print('Error on line ' + str(ctx.start.line) + ', variable \'' +
                      loc_name + '\' has not been declared')
                return

            addr = location.getAddr()
            # Heap symbols are addressed from %rbp, all others from %rsp.
            base = '(%rbp)' if location.mem == HEAP else '(%rsp)'
            self.body += 'movq ' + str(addr) + base + ', %rax\n'

        elif len(ctx.expr()) == 2:
            # Evaluate the left operand and park it in a temporary stack slot
            # while the right operand is evaluated.
            self.visit(ctx.expr(0))
            self.body += 'movq %rax, %r10\n'

            self.st.stack_pointer[-1] += 8
            self.body += 'movq %r10, ' + str(-self.st.stack_pointer[-1]) + '(%rsp)\n'

            self.visit(ctx.expr(1))
            self.body += 'movq %rax, %r11\n'

            self.body += 'movq ' + str(-self.st.stack_pointer[-1]) + '(%rsp), %r10\n'
            self.st.stack_pointer[-1] -= 8

            # %r10 = left operand, %r11 = right operand; result goes to %r11.
            bin_op = ctx.bin_op()
            arith_op = bin_op.arith_op() if bin_op is not None else None
            if arith_op is None:
                # NOTE(review): non-arithmetic operators are not translated
                # yet; the right operand is returned unchanged.
                pass
            elif arith_op.ADD():
                self.body += 'addq %r10, %r11\n'
            elif arith_op.SUB():
                self.body += 'subq %r11, %r10\n'
                self.body += 'movq %r10, %r11\n'
            elif arith_op.MUL():
                self.body += 'imul %r10, %r11\n'
            elif arith_op.DIV():
                # Signed division of %rdx:%rax; cqto sign-extends the
                # dividend, and the quotient comes back in %rax.
                self.body += 'movq %r10, %rax\n'
                self.body += 'cqto\n'
                self.body += 'idivq %r11\n'
                self.body += 'movq %rax, %r11\n'

            self.body += 'movq %r11, %rax\n'

        else:
            self.visitChildren(ctx)

    def visitMethod_call(self, ctx:DecafParser.Method_callContext):
        """Emit a call to a named method, passing arguments in ``param_registers``.

        Each argument is evaluated into a temporary stack slot first and the
        slots are then loaded into the registers in argument order. Callouts
        are not translated yet.
        """
        if ctx.method_name():
            # Evaluate every argument into its own temporary slot below %rsp
            # (negative offsets, like the temporaries used for expressions).
            slots = []
            for arg in ctx.expr():
                self.visit(arg)
                self.st.stack_pointer[-1] += 8
                slot = -self.st.stack_pointer[-1]
                slots.append(slot)
                self.body += 'movq %rax, ' + str(slot) + '(%rsp)\n'

            # Argument z goes to register z.
            for z, slot in enumerate(slots):
                self.body += 'movq ' + str(slot) + '(%rsp), ' + param_registers[z] + '\n'
            self.st.stack_pointer[-1] -= 8 * len(slots)

            # Current position stored in the symbol table.
            # Needs to be 16 byte aligned or we get segmentation errors.
            stack_len = self.st.stack_pointer[-1]
            stack_len = stack_len + (int(stack_len/8+1) % 2) * 8
            self.body += 'subq $' + str(stack_len) + ', %rsp\n'
            method_name = ctx.method_name().getText()
            self.body += 'call ' + method_name + '\n'
            self.body += 'addq $' + str(stack_len) + ', %rsp\n'

        elif ctx.CALLOUT():
            pass

        else:
            self.visitChildren(ctx)
Example #8
0
class c2llvmVisitor(tinycVisitor):
    """Visitor over the tinyc parse tree that emits IR into ``self.module``."""

    # Declarator kinds, returned as the first tuple element by
    # visitDirectDeclarator / visitDeclarator.
    BASE = 0  # plain identifier
    ARRAY = 1  # declarator followed by '['
    FUNC = 2  # declarator followed by '('

    def __init__(self):
        """Set up the IR module, the symbol table and the visitor state.

        A variadic ``printf`` declaration is added to the module and
        registered in the symbol table under the name ``'printf'``.
        """
        super(c2llvmVisitor, self).__init__()

        module = ir.Module()
        self.module = module
        self.scope_depth = 0
        module.triple = "x86_64-unknown-linux-gnu"  # llvm.Target.from_default_triple()
        module.data_layout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"

        # Symbol table
        self.symbol_table = SymbolTable()
        printf_signature = ir.FunctionType(
            LLVMTypes().int32,
            [ir.PointerType(LLVMTypes().int8)],
            var_arg=True,
        )
        self.symbol_table.addSymbol(
            'printf', ir.Function(module, printf_signature, "printf"))

        # Type produced by the most recent type specifier.
        self.cur_type = None
        # Counter bumped for every string literal that is visited.
        self.constants = 0
        # Jump targets of the loop currently being generated, if any.
        self.continue_block = None
        self.break_block = None

    # Visit a parse tree produced by tinycParser#program.
    def visitProgram(self, ctx: tinycParser.ProgramContext):
        """Visit every direct child of the program node, in order."""
        print('visit program')
        for child in map(ctx.getChild, range(ctx.getChildCount())):
            self.visit(child)
        return None

    # Visit a parse tree produced by tinycParser#include.
    def visitInclude(self, ctx: tinycParser.IncludeContext):
        """Log the text of the node's third child, then visit the children."""
        included = ctx.getChild(2).getText()
        print('visit include', included)
        outcome = self.visitChildren(ctx)
        return outcome

    # Visit a parse tree produced by tinycParser#translationUnit.
    def visitTranslationUnit(self, ctx: tinycParser.TranslationUnitContext):
        """Log the visit and fall back to the default child traversal."""
        print('visit trans unit')
        outcome = self.visitChildren(ctx)
        return outcome

    # Visit a parse tree produced by tinycParser#function.
    def visitFunction(self, ctx: tinycParser.FunctionContext):
        """Generate IR for a function definition.

        Child 0 is visited for the return type and child 1 for the declarator,
        which yields ``(kind, name, function type, argument names)``.  The
        function is registered in the scope enclosing the definition, its
        arguments are registered in one fresh scope, and the body is visited
        with a new builder positioned in the function's entry block.

        Raises:
            Exception: if a symbol with the function's name already exists.
        """
        ret_type = self.visit(ctx.getChild(0))
        self.cur_type = ret_type
        _, func_name, func_type, arg_names = self.visit(ctx.getChild(1))

        if self.symbol_table.getSymbol(func_name):
            # TODO: error handling
            raise Exception("redefine function error!")

        llvm_func = ir.Function(self.module, func_type, name=func_name)
        print("func is", func_name, func_type)
        # Registered before the body scope is opened, so the function stays
        # visible after the definition and scope depth stays balanced.
        self.symbol_table.addSymbol(func_name, llvm_func)
        self.builder = ir.IRBuilder(
            llvm_func.append_basic_block(name=".entry" + func_name))

        # Exactly one scope per function body; it is closed again below.
        self.symbol_table.enterScope()
        for arg, llvm_arg in zip(arg_names, llvm_func.args):
            print('func add argname', arg, llvm_arg)
            print(type(llvm_arg))
            self.symbol_table.addSymbol(arg, llvm_arg)

        # A new function body starts outside of any loop.
        self.continue_block = None
        self.break_block = None
        self.visit(ctx.compoundStatement())

        # A void function may fall off the end of its body; every basic
        # block still needs a terminator, so add the implicit return.
        if (isinstance(ret_type, ir.VoidType)
                and not self.builder.block.is_terminated):
            self.builder.ret_void()

        self.symbol_table.exitScope()
        return

    # Visit a parse tree produced by tinycParser#typeSpecifier.
    def visitTypeSpecifier(self, ctx: tinycParser.TypeSpecifierContext):
        """Translate a type keyword into its IR type.

        ``int`` gives a 32-bit integer, ``char`` an 8-bit integer and ``void``
        the void type.  Any other spelling gives ``None``.
        """
        factories = {
            'int': lambda: ir.IntType(32),
            'void': lambda: ir.VoidType(),
            'char': lambda: ir.IntType(8),
        }
        make_type = factories.get(ctx.getText())
        if make_type is None:
            # TODO: report unsupported type keywords as an error
            return None
        return make_type()

    # Visit a parse tree produced by tinycParser#compoundStatement.
    def visitCompoundStatement(self,
                               ctx: tinycParser.CompoundStatementContext):
        """Visit the children of a compound statement and return the result."""
        outcome = self.visitChildren(ctx)
        return outcome

    # Visit a parse tree produced by tinycParser#compoundUnit.
    def visitCompoundUnit(self, ctx: tinycParser.CompoundUnitContext):
        """Visit the children of a compound unit and return the result."""
        outcome = self.visitChildren(ctx)
        return outcome

    # Visit a parse tree produced by tinycParser#declaration.
    def visitDeclaration(self, ctx: tinycParser.DeclarationContext):
        """Resolve the declared type, then visit the declarators.

        The resolved type is kept in ``self.cur_type``, which
        ``visitDirectDeclarator`` reads.  Nothing is returned.
        """
        declared_type = self.visit(ctx.typeSpecifier())
        self.cur_type = declared_type
        print('var_type:', declared_type)
        self.visit(ctx.initDeclaration())

    # Visit a parse tree produced by tinycParser#initDeclaration.
    def visitInitDeclaration(self, ctx: tinycParser.InitDeclarationContext):
        """Visit every even-indexed child; return the last visit's result.

        Odd-indexed children are skipped (presumably the separators between
        declarators).  ``None`` is returned when there are no children.
        """
        last_result = None
        for position in range(0, ctx.getChildCount(), 2):
            last_result = self.visit(ctx.getChild(position))
        return last_result

    # Visit a parse tree produced by tinycParser#initDeclarator.
    def visitInitDeclarator(self, ctx: tinycParser.InitDeclaratorContext):
        """Allocate storage for one declarator and store its initializer.

        The declarator is visited first and yields ``(kind, name, type, _)``.
        A node with exactly three children carries an initializer.

        * ``BASE``: allocate the declared type, register the slot, store the
          initializer value if there is one.
        * ``ARRAY``: a list initializer is wrapped in a constant of the
          declared type; any other initializer value dictates the allocated
          type itself.

        Raises:
            Exception: if the declarator is a function declarator.
        """
        kind, name, declared_type, _ = self.visit(ctx.declarator())
        has_init = len(ctx.children) == 3

        if kind == self.FUNC:
            # TODO: error
            raise Exception("init declarator cannot be a func")

        if kind == self.BASE:
            # Allocate memory
            slot = self.builder.alloca(declared_type)
            self.symbol_table.addSymbol(name, slot)
            if not has_init:
                return
            initial = self.visit(ctx.initializer())
            # TODO: check here -- the value is stored without conversion
            print('initiaze to ', initial)
            print(isinstance(initial, ir.IntType))
            print(type(initial))
            print('help me teacher!!', initial, slot)
            self.builder.store(initial, slot)
            return

        if kind != self.ARRAY:
            return

        storage_type = declared_type
        print('var_name type', name, storage_type)
        stored = None
        if has_init:
            initial = self.visit(ctx.initializer())
            if isinstance(initial, list):
                stored = ir.Constant(storage_type, initial)
            else:
                # The initializer value brings its own type along.
                storage_type = initial.type
                stored = initial
                print('Array initiaze to ', initial, type(initial))
        slot = self.builder.alloca(storage_type)
        print('addr', slot, 'llvm_tpe', declared_type)
        self.symbol_table.addSymbol(name, slot)
        if has_init:
            self.builder.store(stored, slot)

    def visitInitializer(self, ctx: tinycParser.InitializerContext):
        """Visit the assignment expression of a one-child initializer,
        otherwise the nested initializer list, and return the result."""
        is_single = len(ctx.children) == 1
        target = ctx.assignmentExpression() if is_single else ctx.initializerList()
        return self.visit(target)

    def visitInitializerList(self, ctx: tinycParser.InitializerListContext):
        """Collect the initializer values of a (recursive) initializer list.

        When the node has more than one child the nested list is visited
        first and this node's own initializer is appended to its result.
        """
        print('visitInitializerList', ctx.children)
        if len(ctx.children) == 1:
            collected = []
        else:
            collected = self.visit(ctx.initializerList())
        collected.append(self.visit(ctx.initializer()))
        print('init_list', collected)
        return collected

    def visitIterationStatement(self,
                                ctx: tinycParser.IterationStatementContext):
        """Generate IR for a ``for`` loop.

        Five blocks are appended, named after the current block:
        ``.loop_start`` -> ``.loop_cond`` -> ``.loop_body`` ->
        ``.loop_update`` -> back to ``.loop_cond``, with ``.loop_end`` as the
        exit.  While the body is visited ``self.continue_block`` and
        ``self.break_block`` point at the update and end blocks; the previous
        targets are restored afterwards.  The loop gets its own symbol-table
        scope.

        Raises:
            Exception: for any loop keyword other than ``for``.
        """
        keyword = ctx.getChild(0).getText()
        if keyword != 'for':
            # Other loop kinds have no code generation yet.  Fail loudly, and
            # before a scope is opened, instead of silently emitting nothing.
            raise Exception('not implemented')

        self.symbol_table.enterScope()
        prefix = self.builder.block.name

        # Init clause: starts at child 2; the index is advanced past whatever
        # the clause occupies (an expression is followed by one more child,
        # presumably its ';').
        child_idx = 2
        init_rule = getRuleName(ctx.children[child_idx])
        if init_rule == 'expression':
            self.visit(ctx.children[child_idx])
            child_idx += 2
        elif init_rule == 'declaration':
            self.visit(ctx.declaration())
            child_idx += 1
        else:
            child_idx += 1

        start_block = self.builder.append_basic_block(name=prefix +
                                                      ".loop_start")
        cond_block = self.builder.append_basic_block(name=prefix +
                                                     ".loop_cond")
        loop_block = self.builder.append_basic_block(name=prefix +
                                                     ".loop_body")
        update_block = self.builder.append_basic_block(name=prefix +
                                                       ".loop_update")
        end_block = self.builder.append_basic_block(name=prefix +
                                                    ".loop_end")

        last_continue, last_break = self.continue_block, self.break_block
        self.continue_block, self.break_block = update_block, end_block

        self.builder.branch(start_block)
        self.builder.position_at_start(start_block)

        self.builder.branch(cond_block)
        self.builder.position_at_start(cond_block)

        # Condition clause: a missing condition means "always enter the body".
        if getRuleName(ctx.children[child_idx]) == 'expression':
            cond_val = self.visit(ctx.children[child_idx])
            # Debug output only; not every value carries a `flags` attribute.
            print('cond_val', getattr(cond_val, 'flags', None))
            cond_val = whether_is_true(self.builder, cond_val)
            buf = cond_val.get_reference()

            buf = LLVMTypes.bool(buf)
            print('converted_cond', buf)

            self.builder.cbranch(buf, loop_block, end_block)
            child_idx += 2
        else:
            child_idx += 1
            self.builder.branch(loop_block)

        self.builder.position_at_start(loop_block)
        self.visit(ctx.statement())
        # The body may already have ended its block with a terminator
        # (e.g. `continue` or `return`); a second one is not allowed.
        if not self.builder.block.is_terminated:
            self.builder.branch(update_block)

        # Update clause, then back to the condition.
        self.builder.position_at_start(update_block)
        if getRuleName(ctx.children[child_idx]) == 'expression':
            self.visit(ctx.children[child_idx])
        self.builder.branch(cond_block)

        self.builder.position_at_start(end_block)
        self.symbol_table.exitScope()
        self.continue_block = last_continue
        self.break_block = last_break
        # TODO: support the remaining loop kinds

    # Visit a parse tree produced by tinycParser#returnStatement.
    def visitReturnStatement(self, ctx: tinycParser.ReturnStatementContext):
        """Generate IR for a jump statement.

        * ``return`` with two children emits ``ret void``; otherwise the
          expression is visited and its value returned.
        * ``continue`` / ``break`` branch to ``self.continue_block`` /
          ``self.break_block``, which the loop code maintains.

        Raises:
            Exception: if ``continue`` or ``break`` is used while no loop
                target is set, or for any other jump keyword.
        """
        jump_instru = ctx.children[0].getText()
        if jump_instru == 'return':
            if len(ctx.children) == 2:
                self.builder.ret_void()
            else:
                ret_val = self.visit(ctx.expression())
                # TODO: cast type
                self.builder.ret(ret_val)
        elif jump_instru == 'continue':
            if self.continue_block is None:
                raise Exception("continue can not be used here", ctx)
            self.builder.branch(self.continue_block)
        elif jump_instru == 'break':
            # Mirrors `continue`; assumes the grammar routes `break` through
            # this rule the same way -- TODO confirm.
            if self.break_block is None:
                raise Exception("break can not be used here", ctx)
            self.builder.branch(self.break_block)
        else:
            raise Exception('not implemented')

    # Visit a parse tree produced by tinycParser#expressionStatement.
    def visitExpressionStatement(self,
                                 ctx: tinycParser.ExpressionStatementContext):
        """Visit the children of an expression statement and return the result."""
        outcome = self.visitChildren(ctx)
        return outcome

    # Visit a parse tree produced by tinycParser#declarator.
    def visitDeclarator(self, ctx: tinycParser.DeclaratorContext):
        """Resolve a declarator to ``(kind, name, llvm_type, extra)``.

        The tuple comes from the direct declarator.  For functions it is
        ``(_, func_name, func_type, arg_names)`` and is passed through
        unchanged, as is the tuple of a plain identifier.  For arrays the
        collected dimension sizes are folded into nested array types and
        ``extra`` becomes an empty list.
        """
        print('visitDeclarator:', ctx.children, ctx.getText())
        tpe, name, llvm_type, arg = self.visit(ctx.directDeclarator())
        print('after decl: tpe name ', tpe, name, llvm_type, arg)
        if tpe != self.ARRAY:
            return tpe, name, llvm_type, arg

        # Wrap the last dimension first so that sizes [2, 3] give
        # [2 x [3 x T]].  All dimensions must be applied before returning.
        for size in reversed(arg or []):
            print('before llvm_type:', llvm_type, arg)
            llvm_type = ir.ArrayType(element=llvm_type, count=size)
            print('llvm_type:', llvm_type, arg)
        return tpe, name, llvm_type, []

    def visitDirectDeclarator(self, ctx: tinycParser.DirectDeclaratorContext):
        """Classify a direct declarator.

        Returns a ``(kind, name, llvm_type, extra)`` tuple:

        * single child: ``(BASE, identifier, self.cur_type, [])``
        * ``name(...)``: ``(FUNC, name, function type, argument names)``
        * ``inner[N]``: ``(ARRAY, name, element type, sizes so far + [N])``
        * ``name[]``: ``(ARRAY, name, pointer to the inner type, None)``

        Any other second child falls through and returns ``None``.

        Raises:
            Exception: if an array size cannot be read as an integer literal.
        """
        print('visitDirectDeclarator', ctx.getText())
        if len(ctx.children) == 1:
            # :   IDENTIFIER
            # TODO: check the return value here
            return self.BASE, ctx.IDENTIFIER().getText(), self.cur_type, []

        op = ctx.children[1].getText()
        # Captured now: visiting the parameters below overwrites cur_type.
        old_type = self.cur_type
        if op == '(':
            # Function
            func_name = ctx.children[0].getText()
            if len(ctx.children) == 4:
                # A parameter list is present.
                (arg_names,
                 arg_types), var_arg = self.visit(ctx.parameterTypeList())
                new_llvm_type = ir.FunctionType(old_type,
                                                arg_types,
                                                var_arg=var_arg)
                return self.FUNC, func_name, new_llvm_type, arg_names
            return self.FUNC, func_name, ir.FunctionType(old_type, []), []
        elif op == '[':
            tpe, arrayname, old_type, array_nums = self.visit(
                ctx.directDeclarator())
            print('arrayname', arrayname, ctx.children, len(ctx.children),
                  tpe, old_type, array_nums)
            if len(ctx.children) >= 4:
                try:
                    array_size = int(ctx.constantExpression().getText())
                    array_nums.append(array_size)
                except Exception as err:
                    # Keep the original message but preserve the cause, and
                    # let KeyboardInterrupt / SystemExit pass through.
                    raise Exception(
                        'only constant value are supported') from err
                print('return ARRAY', self.ARRAY, arrayname, old_type,
                      array_nums)
                return self.ARRAY, arrayname, old_type, array_nums
            # Empty brackets: treated as a pointer to the inner type.
            arrayname = ctx.children[0].getText()
            print("current type", old_type)
            llvm_type = ir.PointerType(old_type)
            print("param type", llvm_type)
            return self.ARRAY, arrayname, llvm_type, None

    # Visit a parse tree produced by tinycParser#constantExpression.
    def visitConstantExpression(self,
                                ctx: tinycParser.ConstantExpressionContext):
        """Return only the value of the ``(value, address)`` pair produced by
        the conditional expression."""
        value, _address = self.visit(ctx.conditionalExpression())
        return value

    # Visit a parse tree produced by tinycParser#parameterTypeList.
    def visitParameterTypeList(self,
                               ctx: tinycParser.ParameterTypeListContext):
        """Return ``(parameter list result, var_arg flag)``.

        The flag is true when the node has exactly three children
        (presumably the list followed by ``, ...``).
        """
        var_arg = len(ctx.children) == 3
        return self.visit(ctx.parameterList()), var_arg

    # Visit a parse tree produced by tinycParser#parameterList.
    def visitParameterList(self, ctx: tinycParser.ParameterListContext):
        """Return ``(names, types)`` of the parameters, in source order.

        Every even-indexed child is visited and must yield a
        ``(name, type)`` pair; odd-indexed children are skipped.
        """
        # Original note: C pushes its arguments from right to left.
        visited = (self.visit(ctx.getChild(position))
                   for position in range(0, ctx.getChildCount(), 2))
        pairs = [(name, tpe) for name, tpe in visited]
        names = [name for name, _ in pairs]
        types = [tpe for _, tpe in pairs]
        return names, types

    # Visit a parse tree produced by tinycParser#parameterDeclaration.
    def visitParameterDeclaration(
            self, ctx: tinycParser.ParameterDeclarationContext):
        """Return ``(name, type)`` for a single parameter.

        The parameter's base type is stored in ``self.cur_type`` before the
        declarator is visited.
        """
        # TODO: parameter generation is not quite right yet
        base_type = self.visit(ctx.typeSpecifier())
        self.cur_type = base_type
        _kind, argname, argtype, _extra = self.visit(ctx.declarator())
        return argname, argtype

    # Visit a parse tree produced by tinycParser#statement.
    def visitStatement(self, ctx: tinycParser.StatementContext):
        """Visit the children of a statement and return the result."""
        outcome = self.visitChildren(ctx)
        return outcome

    def visitExpression(self, ctx: tinycParser.ExpressionContext):
        """Visit every even-indexed child in order and return the result of
        the last one; odd-indexed children are skipped."""
        child_count = ctx.getChildCount()
        result = self.visit(ctx.getChild(0))
        position = 2
        while position < child_count:
            result = self.visit(ctx.getChild(position))
            position += 2
        return result

    # Visit a parse tree produced by tinycParser#selectionStatement.
    def visitSelectionStatement(self,
                                ctx: tinycParser.SelectionStatementContext):
        """Generate IR for an ``if`` statement, with or without ``else``.

        Child 2 is the condition, child 4 the then-branch and, when the node
        has more than five children, child 6 the else-branch.  Both branches
        share one symbol-table scope.  Statements whose first child is not
        ``if`` are ignored.
        """
        if ctx.children[0].getText() != 'if':
            return

        truth = whether_is_true(self.builder, self.visit(ctx.children[2]))
        reference = truth.get_reference()
        condition = ir.IntType(1)(reference)

        self.symbol_table.enterScope()
        has_else = len(ctx.children) > 5
        if not has_else:
            with self.builder.if_then(condition):
                self.visit(ctx.children[4])
        else:
            with self.builder.if_else(condition) as (then, otherwise):
                with then:
                    self.visit(ctx.children[4])
                with otherwise:
                    self.visit(ctx.children[6])
        self.symbol_table.exitScope()

    # Visit a parse tree produced by tinycParser#assignmentExpression.
    def visitAssignmentExpression(
            self, ctx: tinycParser.AssignmentExpressionContext):
        """Generate IR for an assignment and return the assigned value.

        A node without exactly three children is a plain conditional
        expression; only its value is returned.  Otherwise the target is
        visited for ``(current value, address)``, the right-hand side is
        visited, and the result is stored at the address.  For the compound
        operators the result is ``current <op> rhs``.

        Raises:
            Exception: for an assignment operator that is not supported.
        """
        if len(ctx.children) != 3:
            val, _addr = self.visit(ctx.conditionalExpression())
            return val

        var, addr = self.visit(ctx.unaryExpression())
        op = ctx.getChild(1).getText()
        val = self.visit(ctx.assignmentExpression())

        if op == "=":
            self.builder.store(val, addr)
            print('val, addr', val, '\n', addr)
            return val

        compound_ops = {
            "+=": self.builder.add,
            "-=": self.builder.sub,
            "*=": self.builder.mul,
            "/=": self.builder.sdiv,
            "%=": self.builder.srem,
        }
        combine = compound_ops.get(op)
        if combine is None:
            raise Exception('not implemented')

        result = combine(var, val)
        if op == "+=":
            print('add, addr, var', result, '\n', addr, '\n', var)
        self.builder.store(result, addr)
        # The value of a compound assignment is the value just stored, not
        # the value the target held before.
        return result

    def visitArgumentExpressionList(
            self, ctx: tinycParser.ArgumentExpressionListContext):
        """Collect the argument values of a (recursive) argument list.

        With more than one child the nested list is visited first and this
        node's own argument is appended to its result.
        """
        is_first = len(ctx.children) == 1
        collected = [] if is_first else self.visit(ctx.argumentExpressionList())
        latest = self.visit(ctx.assignmentExpression())
        print('arg is ', latest)
        collected.append(latest)
        return collected

    # Visit a parse tree produced by tinycParser#postfixExpression.
    def visitPostfixExpression(self,
                               ctx: tinycParser.PostfixExpressionContext):
        """Generate IR for a postfix expression; returns ``(value, address)``.

        * one child: the primary expression's result is returned as is
        * ``f(...)``: emits a call; array-typed arguments whose parameter is a
          pointer go through ``arr_to_llvm_pointer``; address is ``None``
        * ``a[i]``: emits a GEP on the operand's address plus a load
        * ``x++`` / ``x--``: stores the changed value when an address is
          available and returns the value from before the change

        Any other operator falls through and returns ``None``.
        """
        if len(ctx.children) == 1:
            return self.visit(ctx.primaryExpression())

        operand, addr = self.visit(ctx.postfixExpression())
        op = ctx.children[1].getText()

        if op == '(':
            if len(ctx.children) == 4:
                args = self.visit(ctx.argumentExpressionList())
            else:
                args = []
            # TODO: type conversion
            call_args = []
            for arg, param in zip(args, operand.args):
                array_to_pointer = (type(arg.type) is ir.ArrayType
                                    and type(param.type) is ir.PointerType)
                if array_to_pointer:
                    arg = arr_to_llvm_pointer(self.builder, arg)
                call_args.append(arg)
            if len(call_args) < len(args):  # account for variadic arguments
                call_args += args[len(operand.args):]
            return self.builder.call(operand, call_args), None

        if op == '[':
            idx = self.visit(ctx.expression())
            zero = ir.Constant(LLVMTypes.int, 0)
            # A function argument is indexed directly; everything else gets
            # a leading zero index.
            indices = [idx] if type(operand) is ir.Argument else [zero, idx]
            print("postif []the val is ", idx, "left ", addr)
            addr = self.builder.gep(addr, indices)
            print("addr is ", addr)
            return self.builder.load(addr), addr

        if op == '++':
            one = operand.type(1)
            print('one', one)
            bumped = self.builder.add(operand, one)
            print('++', addr, bumped)
            if addr:
                self.builder.store(bumped, addr)
            print('++', addr)
            return operand, addr

        if op == '--':
            one = operand.type(1)
            print('one', one)
            lowered = self.builder.sub(operand, one)
            if addr:  # the operator may also follow an rvalue
                self.builder.store(lowered, addr)
            return operand, addr

    def visitUnaryExpression(self, ctx: tinycParser.UnaryExpressionContext):
        """Generate IR for a unary expression; returns ``(value, address)``.

        Prefix ``++`` / ``--`` store and return the changed value together
        with the operand's address.  ``+`` returns the operand's value and
        ``-`` its negation, both without an address.  Any other prefix falls
        through and returns ``None``.
        """
        if len(ctx.children) == 1:
            return self.visit(ctx.children[0])

        prefix = ctx.children[0].getText()
        val, addr = self.visit(ctx.unaryExpression())

        if prefix in ('++', '--'):
            one = val.type(1)
            step = self.builder.add if prefix == '++' else self.builder.sub
            res = step(val, one)
            self.builder.store(res, addr)
            return res, addr
        if prefix == '+':
            return val, None
        if prefix == '-':
            neg = self.builder.neg(val)
            print(
                '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!neg is ',
                neg)
            return neg, None

    # Visit a parse tree produced by tinycParser#multiplicativeExpression.
    def visitMultiplicativeExpression(
            self, ctx: tinycParser.MultiplicativeExpressionContext):
        """Emit ``mul`` for ``*``, ``sdiv`` for ``/`` and ``srem`` for any
        other operator; returns ``(value, None)``.  A one-child node is
        passed straight through."""
        if len(ctx.children) == 1:
            return self.visit(ctx.children[0])

        lhs, _ = self.visit(ctx.children[0])
        rhs, _ = self.visit(ctx.children[2])
        op = ctx.children[1].getText()
        if op == '*':
            emit = self.builder.mul
        elif op == '/':
            emit = self.builder.sdiv
        else:
            emit = self.builder.srem
        return emit(lhs, rhs), None

    # Visit a parse tree produced by tinycParser#additiveExpression.
    def visitAdditiveExpression(self,
                                ctx: tinycParser.AdditiveExpressionContext):
        """Emit ``add`` for ``+`` and ``sub`` for any other operator; returns
        ``(value, None)``.  A one-child node is passed straight through."""
        if len(ctx.children) == 1:
            return self.visit(ctx.multiplicativeExpression())

        lhs, _ = self.visit(ctx.children[0])
        rhs, _ = self.visit(ctx.children[2])
        is_addition = ctx.children[1].getText() == '+'
        emit = self.builder.add if is_addition else self.builder.sub
        return emit(lhs, rhs), None

    # Visit a parse tree produced by tinycParser#shiftExpression.
    def visitShiftExpression(self, ctx: tinycParser.ShiftExpressionContext):
        """Emit ``shl`` for ``<<`` and ``ashr`` for any other operator;
        returns ``(value, None)``.  A one-child node is passed through."""
        if len(ctx.children) == 1:
            return self.visit(ctx.additiveExpression())

        lhs, _ = self.visit(ctx.children[0])
        rhs, _ = self.visit(ctx.children[2])
        is_left_shift = ctx.children[1].getText() == '<<'
        emit = self.builder.shl if is_left_shift else self.builder.ashr
        return emit(lhs, rhs), None

    # Visit a parse tree produced by tinycParser#relationalExpression.
    def visitRelationalExpression(
            self, ctx: tinycParser.RelationalExpressionContext):
        """Emit a signed integer comparison using the operator text as the
        comparison op; returns ``(value, None)``.  A one-child node is passed
        straight through."""
        if len(ctx.children) == 1:
            return self.visit(ctx.shiftExpression())

        lhs, _ = self.visit(ctx.children[0])
        rhs, _ = self.visit(ctx.children[2])
        comparison = self.builder.icmp_signed(
            cmpop=ctx.children[1].getText(), lhs=lhs, rhs=rhs)
        print('relation build', comparison, lhs, rhs)
        return comparison, None

    # Visit a parse tree produced by tinycParser#equalityExpression.
    def visitEqualityExpression(self,
                                ctx: tinycParser.EqualityExpressionContext):
        """Emit a signed integer comparison using the operator text as the
        comparison op; returns ``(value, None)``.  A one-child node is passed
        straight through."""
        if len(ctx.children) == 1:
            return self.visit(ctx.relationalExpression())

        lhs, _ = self.visit(ctx.children[0])
        rhs, _ = self.visit(ctx.children[2])
        cmpop = ctx.children[1].getText()
        comparison = self.builder.icmp_signed(cmpop=cmpop, lhs=lhs, rhs=rhs)
        return comparison, None

    # Visit a parse tree produced by tinycParser#andExpression.
    def visitAndExpression(self, ctx: tinycParser.AndExpressionContext):
        """Emit ``and`` of both operands; returns ``(value, None)``.  A
        one-child node is passed straight through."""
        if len(ctx.children) == 1:
            return self.visit(ctx.equalityExpression())

        left, _ = self.visit(ctx.children[0])
        right, _ = self.visit(ctx.children[2])
        return self.builder.and_(left, right), None

    # Visit a parse tree produced by tinycParser#exclusiveOrExpression.
    def visitExclusiveOrExpression(
            self, ctx: tinycParser.ExclusiveOrExpressionContext):
        """Emit ``xor`` of both operands; returns ``(value, None)``.  A
        one-child node is passed straight through."""
        if len(ctx.children) == 1:
            return self.visit(ctx.andExpression())

        left, _ = self.visit(ctx.children[0])
        right, _ = self.visit(ctx.children[2])
        return self.builder.xor(left, right), None

    # Visit a parse tree produced by tinycParser#inclusiveOrExpression.
    def visitInclusiveOrExpression(
            self, ctx: tinycParser.InclusiveOrExpressionContext):
        """Emit ``or`` of both operands; returns ``(value, None)``.  A
        one-child node is passed straight through."""
        if len(ctx.children) == 1:
            return self.visit(ctx.exclusiveOrExpression())

        left, _ = self.visit(ctx.children[0])
        right, _ = self.visit(ctx.children[2])
        return self.builder.or_(left, right), None

    # Visit a parse tree produced by tinycParser#logicalAndExpression.
    def visitLogicalAndExpression(
            self, ctx: tinycParser.LogicalAndExpressionContext):
        """Generate ``&&``; returns ``(loaded result, result slot)``.

        The right operand is only visited inside the branch taken when the
        left operand is true; otherwise 0 is stored in the one-bit slot.
        A one-child node is passed straight through.
        """
        if len(ctx.children) == 1:
            return self.visit(ctx.inclusiveOrExpression())

        left_value = self.visit(ctx.children[0])[0]
        slot = self.builder.alloca(ir.IntType(1))
        left_truth = whether_is_true(self.builder, left_value)
        cond = LLVMTypes.bool(left_truth.get_reference())
        with self.builder.if_else(cond) as (when_true, when_false):
            with when_true:
                right_value, _ = self.visit(ctx.inclusiveOrExpression())
                right_truth = whether_is_true(self.builder, right_value)
                self.builder.store(right_truth, slot)
            with when_false:
                self.builder.store(ir.IntType(1)(0), slot)
        return self.builder.load(slot), slot

    # Visit a parse tree produced by tinycParser#logicalOrExpression.
    def visitLogicalOrExpression(self,
                                 ctx: tinycParser.LogicalOrExpressionContext):
        """Generate ``||``; returns ``(loaded result, result slot)``.

        When the left operand is true, 1 is stored in the one-bit slot; the
        right operand is only visited inside the other branch.  A one-child
        node is passed straight through.
        """
        if len(ctx.children) == 1:
            return self.visit(ctx.logicalAndExpression())

        left_value = self.visit(ctx.children[0])[0]
        slot = self.builder.alloca(ir.IntType(1))
        left_truth = whether_is_true(self.builder, left_value)
        cond = LLVMTypes.bool(left_truth.get_reference())
        with self.builder.if_else(cond) as (when_true, when_false):
            with when_true:
                self.builder.store(ir.IntType(1)(1), slot)
            with when_false:
                right_value, _ = self.visit(ctx.logicalAndExpression())
                right_truth = whether_is_true(self.builder, right_value)
                self.builder.store(right_truth, slot)
        return self.builder.load(slot), slot

    # Visit a parse tree produced by tinycParser#conditionalExpression.
    def visitConditionalExpression(
            self, ctx: tinycParser.ConditionalExpressionContext):
        """Return the result of visiting the logical-or expression."""
        inner = ctx.logicalOrExpression()
        return self.visit(inner)

    # Visit a parse tree produced by tinycParser#assignmentOperator.
    def visitAssignmentOperator(self,
                                ctx: tinycParser.AssignmentOperatorContext):
        """Return the operator's source text."""
        operator_text = ctx.getText()
        return operator_text

    # Visit a parse tree produced by tinycParser#primaryExpression.
    def visitPrimaryExpression(self,
                               ctx: tinycParser.PrimaryExpressionContext):
        """Return a ``(value, address)`` pair for a primary expression.

        * identifier: looked up in the symbol table.  Function arguments and
          functions are returned as both value and address; a pointer to an
          array yields a GEP to its first element; anything else is loaded.
        * string literal: the constant built from the text without its first
          and last character, returned as both value and address
        * constant: an ``int`` constant, no address
        * nested expression: its value, no address

        Raises:
            Exception: for an unknown identifier or an unsupported form.
        """
        if ctx.IDENTIFIER():
            text = ctx.getText()
            addr = self.symbol_table.getSymbol(text)
            if not addr:
                raise Exception('the identifier should be defined first')
            print("primary of addr", type(addr), addr)
            if type(addr) in [ir.Argument, ir.Function]:
                print("why it is not ", addr)
                # TODO: here is a function parameter bug
                return addr, addr
            if isinstance(addr.type.pointee, ir.ArrayType):
                zero = ir.Constant(LLVMTypes.int, 0)
                return self.builder.gep(addr, [zero, zero]), addr
            print(f"{text}addr is ", addr)
            return self.builder.load(addr), addr

        if ctx.mString():
            literal = self.visit(ctx.mString())
            self.constants += 1
            # Drop the first and last character (presumably the quotes).
            body = literal[1:-1]
            print(f'strlen {len(body) + 1}')
            string = get_const_from_str('string', body)
            return string, string

        if ctx.CONSTANT():
            text = ctx.getText()
            print('const', text)
            return get_const_from_str('int', text), None

        if ctx.expression():
            return self.visit(ctx.expression()), None

        raise Exception('not supported')

    def visitMString(self, ctx: tinycParser.MStringContext):
        """Return the node's text after passing it through ``formatString``.

        Original note: this fixes up the backslash-n sequences inside a
        string or char literal.
        """
        return formatString(ctx.getText())