Ejemplo n.º 1
0
def main(infile_path: str, outfile_path: str):
    """Translate every ``.vm`` file in a directory into one assembly file.

    Args:
        infile_path: Directory whose ``*.vm`` files are translated.
        outfile_path: Path handed to ``CodeWriter`` for the generated output.

    The bootstrap code from ``write_init`` is written first; then each VM
    file is parsed command by command and each command is forwarded to the
    matching ``CodeWriter`` method.
    """
    import os  # local import: the file-level import block is not visible here

    print(infile_path)
    # glob() yields paths in filesystem order; sort them so the generated
    # assembly is identical between runs and machines.
    vm_files = sorted(glob.glob(os.path.join(infile_path, "*.vm")))
    code_writer = CodeWriter(filepath=outfile_path)
    try:
        code_writer.write_init()
        for filepath in vm_files:
            print(filepath)
            parser = Parser(filepath=filepath)
            code_writer.set_file_path(filepath)
            while parser.hasMoreCommands():
                cmd = parser.commandType()
                if cmd == C_ARITHMETIC:
                    code_writer.writeArithmetic(parser.arithmetic())
                elif cmd in (C_PUSH, C_POP):
                    code_writer.writePushPop(cmd,
                                             parser.arg1(),
                                             index=int(parser.arg2()))
                elif cmd == C_LABEL:
                    code_writer.writeLabel(parser.arg1())
                elif cmd == C_GOTO:
                    code_writer.writeGoto(parser.arg1())
                elif cmd == C_IF:
                    code_writer.writeIf(parser.arg1())
                elif cmd == C_FUNCTION:
                    code_writer.writeFunction(parser.arg1(), parser.arg2())
                elif cmd == C_RETURN:
                    code_writer.writeReturn()
                elif cmd == C_CALL:
                    code_writer.writeCall(parser.arg1(), parser.arg2())
                parser.advance()
    finally:
        # Close the output even when translation fails part-way, so the
        # file handle is not leaked.
        code_writer.close()
Ejemplo n.º 2
0
def main():
    '''Main entry point for the script.

    Translates the ``.vm`` file or directory named by ``sys.argv[1]`` into a
    single ``.asm`` file:

    * a directory ``Foo/`` produces ``Foo/Foo.asm`` from every ``.vm`` file
      directly inside it;
    * a file ``Foo.vm`` produces ``Foo.asm`` next to it.

    Invalid arguments print a message and return without translating.
    '''
    # print(...) with a single argument behaves the same on Python 2 and 3.
    if len(sys.argv) < 2:
        print("usage: 'python <file.vm> or <dirname/>'")
        return

    source = sys.argv[1]
    vmfiles = []

    if os.path.isdir(source):
        # Name the output after the directory itself: Foo/ -> Foo/Foo.asm
        finame = os.path.basename(os.path.normpath(source)) + ".asm"
        cw = CodeWriter(os.path.join(source, finame))
        # os.path.join works with or without a trailing separator on the
        # argument; sorting makes the file order independent of the OS.
        vmfiles = [os.path.join(source, name)
                   for name in sorted(os.listdir(source))
                   if name.endswith(".vm")]

    elif os.path.isfile(source):
        if not source.endswith(".vm"):
            print("invalid filetype: only input .vm files")
            return
        vmfiles.append(source)
        # Replace only the extension; str.replace("vm", "asm") would also
        # rewrite any "vm" inside directory or file names.
        cw = CodeWriter(os.path.splitext(source)[0] + ".asm")

    else:
        print("usage: 'python <file.vm> or <dirname/>'")
        return

    out = cw.constructor()
    cw.writeInit(out)

    with out as outfile:

        for files in vmfiles:

            # setFileName hands back the parser for this .vm file
            # (presumably a new Parser instance, per the original comment).
            p = cw.setFileName(files)

            with p.constructor() as infile:

                for line in infile:

                    # Classify the line once instead of once per branch.
                    command = p.commandType(line)

                    if command == "comments":
                        continue

                    elif command == "C_ARITHMETIC":
                        cw.writeArithmetic(outfile, p.args(line)[0])

                    elif command == "C_IF":
                        # Handle if-goto command
                        cw.writeIf(outfile, p.args(line)[1])

                    elif command == "C_GOTO":
                        # Handle goto command
                        cw.writeGoto(outfile, p.args(line)[1])

                    elif command == "C_RETURN":
                        # Return function result
                        cw.writeReturn(outfile)

                    elif command == "C_LABEL":
                        # Set label address
                        cw.writeLabel(outfile, p.args(line)[1])

                    elif command == "C_CALL":
                        # Handle function calls
                        args = p.args(line)
                        cw.writeCall(outfile, args[1], args[2])

                    elif command == "C_FUNCTION":
                        args = p.args(line)
                        cw.writeFunction(outfile, args[1], args[2])

                    # Explicit membership test: `x == "C_PUSH" or "C_POP"`
                    # is always true and would swallow unknown lines.
                    elif command in ("C_PUSH", "C_POP"):
                        args = p.args(line)
                        cw.writePushPop(outfile, command, args[1], args[2])
class jackVisitor(jackGrammarVisitor):
    """Parse-tree visitor that emits VM code for a Jack class.

    Declarations are recorded in a ``SymbolTable``; VM commands are written
    through a ``CodeWriter`` whose accumulated text (``vmWriter.vm``) is
    saved to disk by ``crearArchivo``.
    """

    def __init__(self):
        """Create the symbol table, the code writer and helper state."""
        self.symbolTable = SymbolTable()
        # Label counters; start at -1 so the first if/while uses index 0.
        self.contWhile = -1
        self.contIf = -1
        self.nombreClase = ""   # name of the class being compiled
        self.kindMetodo = ""    # subroutine keyword, e.g. 'constructor' or 'method'
        self.nombreMetodo = ""  # name of the subroutine being compiled
        self.vmWriter = CodeWriter()
        self.vmWriter.vm = ""
        # Argument count of the most recently completed expression list.
        self.nArgs = 0

    def visitClasses(self, ctx):
        """Record the name of the class currently being compiled."""
        self.nombreClase = ctx.children[1].children[0].getText()
        return self.visitChildren(ctx)

    def visitClassVarDec(self, ctx):
        """Register every field/static variable declared in this statement."""
        kind = ctx.children[0].getText()
        tipo = ctx.children[1].children[0].getText()
        i = 2
        # After the type come: name (',' name)* ';'
        while ctx.children[i].getText() != ';':
            name = ctx.children[i].getText()
            if name != ',':
                self.symbolTable.define(name, tipo, kind)
            i += 1
        return self.visitChildren(ctx)

    def visitTypes(self, ctx):
        return self.visitChildren(ctx)

    def visitSubroutineDec(self, ctx):
        """Start a new subroutine scope; methods get ``this`` as argument 0."""
        self.kindMetodo = ctx.children[0].getText()
        self.nombreMetodo = ctx.children[2].children[0].getText()
        self.symbolTable.startSubroutine()
        if self.kindMetodo == 'method':
            # The hidden receiver's type is the enclosing class, not the
            # method's own name.
            self.symbolTable.define('this', self.nombreClase, 'argument')
        return self.visitChildren(ctx)

    def visitParameterList(self, ctx):
        """Register each declared parameter as an 'argument' variable."""
        if ctx.getChildCount() > 0:
            tipo = ctx.children[0].children[0].getText()
            nombre = ctx.children[1].children[0].getText()
            self.symbolTable.define(nombre, tipo, 'argument')
            i = 2
            # Remaining children come in triples: ',' type name
            while i < len(ctx.children) - 1 and ctx.children[i].getText() != ')':
                tipo = ctx.children[i + 1].getText()
                nombre = ctx.children[i + 2].getText()
                self.symbolTable.define(nombre, tipo, 'argument')
                i += 3
        return self.visitChildren(ctx)

    def visitSubroutineBody(self, ctx):
        """Write the ``function`` command, set up ``this``, then the statements.

        Local declarations are visited first because the ``function``
        command needs the number of locals. Constructors allocate the
        object (one word per field) and methods take ``this`` from
        argument 0 before the statements are compiled.
        """
        i = 1
        count = ctx.getChildCount()
        while i < count:
            child = ctx.children[i]
            # Guard: a child with no children (e.g. an empty statements
            # node or the closing '}') cannot start a 'var' declaration.
            if child.getChildCount() == 0 or child.children[0].getText() != "var":
                break
            self.visit(child)
            i += 1
        funcion = self.nombreClase + '.' + self.nombreMetodo
        numLcl = self.symbolTable.varCount('local')
        self.vmWriter.writeFunction(funcion, numLcl)
        if self.kindMetodo == 'constructor':
            numFields = self.symbolTable.varCount('field')
            self.vmWriter.writePush('constant', numFields)
            self.vmWriter.writeCall('Memory.alloc', 1)
            self.vmWriter.writePop('pointer', 0)
        elif self.kindMetodo == 'method':
            self.vmWriter.writePush('argument', 0)
            self.vmWriter.writePop('pointer', 0)
        while i < count:
            self.visit(ctx.children[i])
            i += 1

    def visitVarDec(self, ctx):
        """Register every local variable declared in this ``var`` statement."""
        tipo = ctx.children[1].children[0].getText()
        nombre = ctx.children[2].getText()
        self.symbolTable.define(nombre, tipo, 'local')
        i = 3
        while ctx.children[i].getText() != ';':
            nombre = ctx.children[i].getText()
            if nombre != ',':
                self.symbolTable.define(nombre, tipo, 'local')
            i += 1
        return self.visitChildren(ctx)

    # Rules below need no VM code of their own; just walk their children.
    def visitClassName(self, ctx):
        return self.visitChildren(ctx)

    def visitSubroutineName(self, ctx):
        return self.visitChildren(ctx)

    def visitVarName(self, ctx):
        return self.visitChildren(ctx)

    def visitStatements(self, ctx):
        return self.visitChildren(ctx)

    def visitStatement(self, ctx):
        return self.visitChildren(ctx)

    def visitLetStatement(self, ctx):
        """Compile ``let name = expr;`` or ``let name[idx] = expr;``."""
        nombre = ctx.children[1].getText()
        tipo = self.symbolTable.kindOf(nombre)
        index = self.symbolTable.indexOf(nombre)
        if ctx.children[2].getText() == '[':
            # Target address = base + idx. The right-hand side is compiled
            # before 'pointer 1' is set because it may itself change THAT;
            # temp 0 holds its value while the address is popped.
            self.visit(ctx.children[3])
            self.vmWriter.writePush(tipo, index)
            self.vmWriter.writeArithmetic('add')
            self.visit(ctx.children[6])
            self.vmWriter.writePop('temp', 0)
            self.vmWriter.writePop('pointer', 1)
            self.vmWriter.writePush('temp', 0)
            self.vmWriter.writePop('that', 0)
        else:
            self.visit(ctx.children[3])
            self.vmWriter.writePop(tipo, index)

    def visitIfStatement(self, ctx):
        """Compile an if/else using IF_TRUE/IF_FALSE/IF_END labels."""
        self.contIf += 1
        cont = self.contIf
        self.visit(ctx.children[2])
        self.vmWriter.writeIf('IF_TRUE' + str(cont))
        self.vmWriter.writeGoto('IF_FALSE' + str(cont))
        self.vmWriter.writeLabel('IF_TRUE' + str(cont))
        self.visit(ctx.children[5])
        if ctx.getChildCount() > 7:
            if str(ctx.children[7]) == 'else':
                self.vmWriter.writeGoto('IF_END' + str(cont))
                self.vmWriter.writeLabel('IF_FALSE' + str(cont))
                self.visit(ctx.children[9])
                self.vmWriter.writeLabel('IF_END' + str(cont))
        else:
            self.vmWriter.writeLabel('IF_FALSE' + str(cont))

    def visitWhileStatement(self, ctx):
        """Compile a while loop: re-test the condition, exit when false."""
        self.contWhile += 1
        contW = self.contWhile
        self.vmWriter.writeLabel('WHILE_EXP' + str(contW))
        self.visit(ctx.children[2])
        self.vmWriter.writeArithmetic('not')
        self.vmWriter.writeIf('WHILE_END' + str(contW))
        self.visit(ctx.children[5])
        self.vmWriter.writeGoto('WHILE_EXP' + str(contW))
        self.vmWriter.writeLabel('WHILE_END' + str(contW))

    def visitDoStatement(self, ctx):
        """Compile the call, then discard its return value into temp 0."""
        self.visitChildren(ctx)
        self.vmWriter.writePop('temp', 0)

    def visitReturnStatement(self, ctx):
        """Push the return value (constant 0 when there is none) and return."""
        if ctx.children[1].getText() != ';':
            self.visit(ctx.children[1])
        else:
            self.vmWriter.writePush('constant', 0)
        self.vmWriter.writeReturn()

    def visitExpression(self, ctx):
        """Compile ``term (op term)*`` in postfix order: operands, then op."""
        self.visit(ctx.children[0])
        i = 2
        while i < ctx.getChildCount():
            self.visit(ctx.children[i])
            self.visit(ctx.children[i - 1])
            i += 2

    def visitTerm(self, ctx):
        """Compile a single term.

        Handles integer and string constants, keyword constants, variables,
        array elements ``a[i]``, parenthesised expressions and unary
        operations; anything else (e.g. a subroutine call) is delegated to
        the children.
        """
        term = ctx.children[0].getText()
        if ctx.getChildCount() == 1:
            if term.isdigit():
                self.vmWriter.writePush('constant', term)
            elif term.startswith('"'):
                # String constant: build it at run time with String.new and
                # one String.appendChar call per character.
                term = term.strip('"')
                self.vmWriter.writePush('constant', len(term))
                self.vmWriter.writeCall('String.new', 1)
                for char in term:
                    self.vmWriter.writePush('constant', ord(char))
                    self.vmWriter.writeCall('String.appendChar', 2)
            elif term in ['true', 'false', 'null', 'this']:
                self.visitChildren(ctx)
            elif (term in self.symbolTable.subrutina.keys()
                  or term in self.symbolTable.clase.keys()):
                # Variable: push it from the segment/index in the table.
                tipo = self.symbolTable.kindOf(term)
                index = self.symbolTable.indexOf(term)
                self.vmWriter.writePush(tipo, index)
            else:
                self.visitChildren(ctx)
        else:
            if ctx.children[1].getText() == '[':
                # Array element: base + index -> THAT, then read that 0.
                index = self.symbolTable.indexOf(term)
                segment = self.symbolTable.kindOf(term)
                self.visit(ctx.children[2])
                self.vmWriter.writePush(segment, index)
                self.vmWriter.writeArithmetic('add')
                self.vmWriter.writePop('pointer', 1)
                self.vmWriter.writePush('that', 0)
            elif term == '(':
                self.visitChildren(ctx)
            elif term in ('-', '~'):
                # Unary operator: compile the operand, then the operator.
                self.visit(ctx.children[1])
                self.visit(ctx.children[0])

    def visitSubroutineCall(self, ctx):
        """Compile a call, pushing the receiver first for method calls.

        * ``var.sub(...)``  -> push var, call ``Type.sub`` with n+1 args
        * ``Name.sub(...)`` -> call ``Name.sub`` with n args
        * ``sub(...)``      -> push pointer 0 (this), call ``Class.sub``
        """
        nombre = ctx.children[0].children[0].getText()
        funcion = nombre
        args = 0
        if ctx.children[1].getText() == '.':
            nombreSubrutina = ctx.children[2].children[0].getText()
            tipo = self.symbolTable.typeOf(nombre)
            if tipo is not None:
                kind = self.symbolTable.kindOf(nombre)
                index = self.symbolTable.indexOf(nombre)
                self.vmWriter.writePush(kind, index)
                funcion = tipo + '.' + nombreSubrutina
                args += 1
            else:
                funcion = nombre + '.' + nombreSubrutina
        elif ctx.children[1].getText() == '(':
            funcion = self.nombreClase + '.' + nombre
            args += 1
            self.vmWriter.writePush('pointer', 0)
        # Compiles the argument expressions; visitExpressionList leaves
        # this call's argument count in self.nArgs.
        self.visitChildren(ctx)
        args += self.nArgs
        self.vmWriter.writeCall(funcion, args)

    def visitExpressionList(self, ctx):
        """Compile each argument expression and record how many there were.

        The count is kept in a local and stored in ``self.nArgs`` only after
        every expression is compiled: a nested call inside an argument runs
        its own expression list, which would otherwise overwrite the outer
        count (e.g. ``f(g(a, b))`` compiled as ``call f 2``).
        """
        count = 0
        if ctx.getChildCount() > 0:
            count = 1
            self.visit(ctx.children[0])
            i = 2
            while i < ctx.getChildCount():
                self.visit(ctx.children[i])
                self.visit(ctx.children[i - 1])
                count += 1
                i += 2
        self.nArgs = count

    def visitOp(self, ctx):
        """Emit the VM command for a binary operator."""
        op = ctx.children[0].getText()
        if op == "+":
            self.vmWriter.writeArithmetic('add')
        elif op == "-":
            self.vmWriter.writeArithmetic('sub')
        elif op == "*":
            # Multiplication/division are OS calls, so use writeCall like
            # every other call in this visitor.
            self.vmWriter.writeCall('Math.multiply', 2)
        elif op == "/":
            self.vmWriter.writeCall('Math.divide', 2)
        elif op == "&":
            self.vmWriter.writeArithmetic('and')
        elif op == "|":
            self.vmWriter.writeArithmetic('or')
        elif op == ">":
            self.vmWriter.writeArithmetic('gt')
        elif op == "<":
            self.vmWriter.writeArithmetic('lt')
        elif op == "=":
            self.vmWriter.writeArithmetic('eq')
        return self.visitChildren(ctx)

    def visitUnaryop(self, ctx):
        """Emit the VM command for a unary operator (``~`` or ``-``)."""
        op = ctx.children[0].getText()
        if op == "~":
            self.vmWriter.writeArithmetic('not')
        elif op == "-":
            self.vmWriter.writeArithmetic('neg')

    def visitKeywordconstant(self, ctx):
        """Push the value of a keyword constant (this/true/false/null)."""
        keyword = ctx.children[0].getText()
        if keyword == 'this':
            self.vmWriter.writePush('pointer', 0)
        elif keyword in ['false', 'null']:
            self.vmWriter.writePush('constant', 0)
        elif keyword == 'true':
            # true is -1 (all bits set): push 0 then bitwise not.
            self.vmWriter.writePush('constant', 0)
            self.vmWriter.writeArithmetic('not')
        return self.visitChildren(ctx)

    def crearArchivo(self, path):
        """Write the generated VM code to the ``.vm`` file matching *path*.

        ``Foo.jack`` becomes ``Foo.vm``; a path without a trailing ``.jack``
        simply gets ``.vm`` appended. Exits with status 1 if the target
        directory does not exist.
        """
        # Strip only a trailing ".jack"; splitting on every ".jack" would
        # truncate paths such as "my.jack_files/Main.jack".
        base = path[:-len('.jack')] if path.endswith('.jack') else path
        archivo = base + '.vm'
        try:
            # 'with' guarantees the file is flushed and closed.
            with open(archivo, 'w') as file:
                file.write(self.vmWriter.vm)
        except FileNotFoundError:
            print('ERROR:No hay directorio existente para escribir')
            raise SystemExit(1)
Ejemplo n.º 4
0
class CodeWriterTest(unittest.TestCase):
    """Golden-output tests for ``CodeWriter``.

    Each test runs in a fresh temporary directory, drives one ``CodeWriter``
    method, and compares the generated assembly file line by line with the
    expected Hack instructions.
    """

    def setUp(self):
        self.tmpdir = TemporaryDirectory()
        self.addCleanup(self.tmpdir.cleanup)
        # Work inside the temp dir so the relative output file lands there.
        # Cleanups run LIFO, so the original cwd is restored before the
        # directory is removed (deterministic, unlike __del__).
        self.addCleanup(os.chdir, os.getcwd())
        os.chdir(self.tmpdir.name)
        self.assembly_filename = 'test_output'
        self.code_writer = CodeWriter(self.assembly_filename)
        self.maxDiff = None


    def test_add(self):
        self.code_writer.writeArithmetic('add')
        self.assertGeneratedAssemblyEqual([
            '@SP',
            'M=M-1',
            'A=M',
            'D=M',
            '@SP',
            'M=M-1',
            'A=M',
            'D=D+M',
            'M=D',
            '@SP',
            'M=M+1'
        ])


    def test_sub(self):
        self.code_writer.writeArithmetic('sub')
        self.assertGeneratedAssemblyEqual([
            '@SP',
            'M=M-1',
            'A=M',
            'D=M',
            '@SP',
            'M=M-1',
            'A=M',
            'D=M-D',
            'M=D',
            '@SP',
            'M=M+1'
        ])


    def test_neg(self):
        self.code_writer.writeArithmetic('neg')
        self.assertGeneratedAssemblyEqual([
            '@SP',
            'M=M-1',
            'A=M',
            'M=-M',
            '@SP',
            'M=M+1'
        ])


    def test_and(self):
        self.code_writer.writeArithmetic('and')
        self.assertGeneratedAssemblyEqual([
            '@SP',
            'M=M-1',
            'A=M',
            'D=M',
            '@SP',
            'M=M-1',
            'A=M',
            'M=M&D',
            '@SP',
            'M=M+1'
        ])


    def test_or(self):
        self.code_writer.writeArithmetic('or')
        self.assertGeneratedAssemblyEqual([
            '@SP',
            'M=M-1',
            'A=M',
            'D=M',
            '@SP',
            'M=M-1',
            'A=M',
            'M=M|D',
            '@SP',
            'M=M+1'
        ])


    def test_not(self):
        self.code_writer.writeArithmetic('not')
        self.assertGeneratedAssemblyEqual([
            '@SP',
            'M=M-1',
            'A=M',
            'M=!M',
            '@SP',
            'M=M+1'
        ])


    def test_eq(self):
        self.code_writer.writeArithmetic('eq')
        self.assertGeneratedAssemblyEqual([
            '@SP',
            'M=M-1',
            'A=M',
            'D=M',
            '@SP',
            'M=M-1',
            'A=M',
            'D=D-M',
            '@test_output.1',
            'D;JEQ',

            '@test_output.0',
            '0;JMP',

            '(test_output.1)',
            '@SP',
            'A=M',
            'M=-1',
            '@test_output.2',
            '0;JMP',

            '(test_output.0)',
            '@SP',
            'A=M',
            'M=0',
            '@test_output.2',
            '0;JMP',

            '(test_output.2)',
            '@SP',
            'M=M+1'
        ])


    def test_lt(self):
        self.code_writer.writeArithmetic('lt')
        self.assertGeneratedAssemblyEqual([
            '@SP',
            'M=M-1',
            'A=M',
            'D=M',
            '@SP',
            'M=M-1',
            'A=M',
            'D=M-D',
            '@test_output.1',
            'D;JLT',

            '@test_output.0',
            '0;JMP',

            '(test_output.1)',
            '@SP',
            'A=M',
            'M=-1',
            '@test_output.2',
            '0;JMP',

            '(test_output.0)',
            '@SP',
            'A=M',
            'M=0',
            '@test_output.2',
            '0;JMP',

            '(test_output.2)',
            '@SP',
            'M=M+1'
        ])


    def test_gt(self):
        self.code_writer.writeArithmetic('gt')
        self.assertGeneratedAssemblyEqual([
            '@SP',
            'M=M-1',
            'A=M',
            'D=M',
            '@SP',
            'M=M-1',
            'A=M',
            'D=D-M',
            '@test_output.1',
            'D;JLT',

            '@test_output.0',
            '0;JMP',

            '(test_output.1)',
            '@SP',
            'A=M',
            'M=-1',
            '@test_output.2',
            '0;JMP',

            '(test_output.0)',
            '@SP',
            'A=M',
            'M=0',
            '@test_output.2',
            '0;JMP',

            '(test_output.2)',
            '@SP',
            'M=M+1'
        ])


    def test_write_push_pop_ignores_trailing_comments(self):
        self.code_writer.writePushPop('push constant 37 // this is a comment')
        self.assertGeneratedAssemblyEqual([
            '@37',
            'D=A',
            '@SP',
            'A=M',
            'M=D',
            '@SP',
            'M=M+1'
        ])

    def test_push_constant(self):
        self.code_writer.writePushPop('push constant 37')
        self.assertGeneratedAssemblyEqual([
            '@37',
            'D=A',
            '@SP',
            'A=M',
            'M=D',
            '@SP',
            'M=M+1'
        ])

    def test_push_static(self):
        self.code_writer.writePushPop('push static 37')
        self.assertGeneratedAssemblyEqual([
            '@test_output.37',
            'D=M',
            '@SP',
            'A=M',
            'M=D',
            '@SP',
            'M=M+1'
        ])

    def test_push_argument(self):
        self.code_writer.writePushPop('push argument 800')
        self.assertGeneratedAssemblyEqual([
            '@ARG',
            'D=M',
            '@800',
            'D=A+D',
            'A=D',
            'D=M',
            '@SP',
            'A=M',
            'M=D',
            '@SP',
            'M=M+1'
        ])

    def test_push_temp(self):
        self.code_writer.writePushPop('push temp 876')
        self.assertGeneratedAssemblyEqual([
            '@5',
            'D=A',
            '@876',
            'D=A+D',
            'A=D',
            'D=M',
            '@SP',
            'A=M',
            'M=D',
            '@SP',
            'M=M+1'
        ])

    def test_push_pointer_0(self):
        self.code_writer.writePushPop('push pointer 0')
        self.assertGeneratedAssemblyEqual([
            '@THIS',
            'D=M',
            '@SP',
            'A=M',
            'M=D',
            '@SP',
            'M=M+1'
        ])

    def test_push_pointer_1(self):
        self.code_writer.writePushPop('push pointer 1')
        self.assertGeneratedAssemblyEqual([
            '@THAT',
            'D=M',
            '@SP',
            'A=M',
            'M=D',
            '@SP',
            'M=M+1'
        ])


    def test_pop_static(self):
        self.code_writer.writePushPop('pop static 37')
        self.assertGeneratedAssemblyEqual([
            '@SP',
            'M=M-1',
            'A=M',
            'D=M',
            '@test_output.37',
            'M=D'
        ])

    def test_pop_argument(self):
        self.code_writer.writePushPop('pop argument 800')
        self.assertGeneratedAssemblyEqual([
            '@ARG',
            'D=M',
            '@800',
            'D=A+D',
            '@R13',
            'M=D',
            '@SP',
            'M=M-1',
            'A=M',
            'D=M',
            '@R13',
            'A=M',
            'M=D'
        ])

    def test_pop_temp(self):
        self.code_writer.writePushPop('pop temp 876')
        self.assertGeneratedAssemblyEqual([
            '@5',
            'D=A',
            '@876',
            'D=A+D',
            '@R13',
            'M=D',
            '@SP',
            'M=M-1',
            'A=M',
            'D=M',
            '@R13',
            'A=M',
            'M=D'
        ])

    def test_pop_pointer_0(self):
        self.code_writer.writePushPop('pop pointer 0')
        self.assertGeneratedAssemblyEqual([
            '@SP',
            'M=M-1',
            'A=M',
            'D=M',
            '@THIS',
            'M=D'
        ])

    def test_pop_pointer_1(self):
        self.code_writer.writePushPop('pop pointer 1')
        self.assertGeneratedAssemblyEqual([
            '@SP',
            'M=M-1',
            'A=M',
            'D=M',
            '@THAT',
            'M=D'
        ])


    def test_write_function(self):
        self.code_writer.writeFunction('function SimpleFunction.test 2')
        self.assertGeneratedAssemblyEqual([
            '(test_output.SimpleFunction.test)',
            '@SP',
            'D=M',
            'A=D',
            'M=0',
            '@SP',
            'M=M+1',
            '@SP',
            'D=M',
            'A=D',
            'M=0',
            '@SP',
            'M=M+1'
        ])


    def test_write_return(self):
        self.code_writer.writeReturn()
        self.assertGeneratedAssemblyEqual([
            '@LCL',
            'D=M',
            '@R13',
            'M=D',
            'D=D-1',
            'D=D-1',
            'D=D-1',
            'D=D-1',
            'D=D-1',
            'A=D',
            'D=M',
            '@R14',
            'M=D',
            '@SP',
            'M=M-1',
            'A=M',
            'D=M',
            '@ARG',
            'A=M',
            'M=D',
            '@ARG',
            'D=M+1',
            '@SP',
            'M=D',
            '@R13',
            'M=M-1',
            'A=M',
            'D=M',
            '@THAT',
            'M=D',
            '@R13',
            'M=M-1',
            'A=M',
            'D=M',
            '@THIS',
            'M=D',
            '@R13',
            'M=M-1',
            'A=M',
            'D=M',
            '@ARG',
            'M=D',
            '@R13',
            'M=M-1',
            'A=M',
            'D=M',
            '@LCL',
            'M=D',
            '@R14',
            'A=M',
            '0;JMP'
        ])


    def test_write_call(self):
        self.code_writer.writeCall('call Sys.init 0')
        self.assertGeneratedAssemblyEqual([
            '@test_output.Sys.init.return',
            'D=A',
            '@SP',
            'A=M',
            'M=D',
            '@SP',
            'M=M+1',
            '@LCL',
            'D=M',
            '@SP',
            'A=M',
            'M=D',
            '@SP',
            'M=M+1',
            '@ARG',
            'D=M',
            '@SP',
            'A=M',
            'M=D',
            '@SP',
            'M=M+1',
            '@THIS',
            'D=M',
            '@SP',
            'A=M',
            'M=D',
            '@SP',
            'M=M+1',
            '@THAT',
            'D=M',
            '@SP',
            'A=M',
            'M=D',
            '@SP',
            'M=M+1',
            '@SP',
            'D=M',
            'D=D-1',
            'D=D-1',
            'D=D-1',
            'D=D-1',
            'D=D-1',
            '@ARG',
            'M=D',
            '@SP',
            'D=M',
            '@LCL',
            'M=D',
            '@test_output.Sys.init',
            '0;JMP',
            '@9999',
            '(test_output.Sys.init.return)'
        ])


    def test_write_bootstrap(self):
        self.code_writer.writeBootstrap()
        self.assertGeneratedAssemblyEqual([
            '@256',
            'D=A',
            '@SP',
            'M=D'
        ])


    def assertGeneratedAssemblyEqual(self, assembly=None):
        """Assert the writer's output file contains exactly *assembly*."""
        self.assertAssemblyEqual(self.assembly_filename, assembly)


    def assertAssemblyEqual(self, file_name, assembly=None):
        """Assert *file_name* holds *assembly*, ignoring trailing whitespace.

        ``None`` (the default) means "expect an empty file"; a fresh list is
        built each call instead of sharing a mutable default.
        """
        expected = [] if assembly is None else assembly
        with open(file_name, 'r') as f:
            self.assertEqual([line.rstrip() for line in f], expected)
Ejemplo n.º 5
0
def main(path):
    """Entry point for the vm translator.

    Translates a single ``.vm`` file, or every ``.vm`` file returned by
    ``get_vm_files`` for a directory, into one assembly file:

    * ``Foo.vm`` -> ``Foo.asm`` (no bootstrap code);
    * ``Foo`` or ``Foo/`` (directory) -> ``Foo/Foo.asm``, starting with the
      bootstrap code from ``writeInit``.

    Exits with status 1 if *path* does not exist.
    """
    # Validate first so get_vm_files never sees a missing path.
    if not os.path.exists(path):
        print("invalid file path")
        sys.exit(1)

    vm_files_paths = get_vm_files(path)

    if os.path.isdir(path):
        # abspath drops any trailing separator, so "Foo" and "Foo/" both
        # name the output Foo/Foo; os.path.dirname("Foo") would be "".
        dirname = os.path.abspath(path)
        name = os.path.basename(dirname)
        path = os.path.join(dirname, name)
        isdir = True
    else:
        path = os.path.splitext(path)[0]
        isdir = False

    # Create single code write module
    code_writer = CodeWriter(f"{path}.asm")

    if isdir:
        code_writer.writeInit()

    for vm_file_path in vm_files_paths:
        # The stream is closed right after construction (as before), so the
        # Parser is expected to read everything it needs up front.
        with open(vm_file_path, "r") as filestream:
            parser = Parser(filestream)

        # write to assembly file
        code_writer.setFileName(os.path.basename(vm_file_path))

        while parser.hasMoreCommands():
            parser.advance()
            command_type = parser.commandType()

            if command_type in (CommandType.C_PUSH, CommandType.C_POP):
                code_writer.writePushPop(
                    command_type,
                    parser.arg1(),
                    int(parser.arg2())
                )
            elif command_type == CommandType.C_ARITHMETIC:
                code_writer.writeArithmetic(parser.arg1())
            elif command_type == CommandType.C_LABEL:
                code_writer.writeLabel(parser.arg1())
            elif command_type == CommandType.C_GOTO:
                code_writer.writeGoto(parser.arg1())
            elif command_type == CommandType.C_IF:
                code_writer.writeIf(parser.arg1())
            elif command_type == CommandType.C_FUNCTION:
                label = parser.arg1()
                number_of_locals = int(parser.arg2())
                code_writer.writeFunction(label, number_of_locals)
            elif command_type == CommandType.C_RETURN:
                code_writer.writeReturn()
            elif command_type == CommandType.C_CALL:
                function_name = parser.arg1()
                number_of_args = int(parser.arg2())
                code_writer.writeCall(function_name, number_of_args)

        # Diagnostic output kept from the original implementation.
        print(code_writer.filestream.get_global_counter())
    code_writer.close()