def run(self):
    """Compile ``self.input`` end to end and write the generated code.

    The SeeDot source is lexed and parsed into a concrete syntax tree,
    converted to an AST by ``ASTBuilder``, annotated by ``InferType`` and
    compiled to IR by ``self.compile``. The result is handed to the Arduino
    or x86 code generator (chosen by ``forArduino()``/``forX86()``), which
    writes to ``self.outputFile``.
    """
    # Front end: source text -> token stream -> concrete syntax tree.
    syntax_tree = SeeDotParser(CommonTokenStream(SeeDotLexer(self.input))).expr()

    # Middle: CST -> AST, type inference, then lowering to IR.
    program_ast = ASTBuilder.ASTBuilder().visit(syntax_tree)
    InferType().visit(program_ast)
    IRUtil.init()
    res, state = self.compile(program_ast)

    # Back end: select the generator class for the configured target.
    writer = Writer(self.outputFile)
    if forArduino():
        generator_cls = ArduinoCodegen
    elif forX86():
        generator_cls = X86Codegen
    else:
        assert False
    generator_cls(writer, *state).printAll(*res)
    writer.close()
def printVarDecls(self):
    """Declare every non-global variable in ``self.decls``.

    Each variable is defined in the main output (``self.out``) as
    ``<type> vars_<version>::<name><dims>;`` and declared ``extern`` inside
    ``namespace vars_<version>`` in a freshly written ``vars_<version>.h``
    under ``self.outputDir``. The debug program is generated afterwards.
    """
    header = Writer(os.path.join(self.outputDir, "vars_" + getVersion() + ".h"))
    header.printf("#pragma once\n\n")
    header.printf("#include \"datatypes.h\"\n\n")
    header.printf("namespace vars_%s {\n" % (getVersion()))
    header.increaseIndent()

    for name in self.decls:
        if name in self.globalVars:
            continue

        # In float mode, non-internal variables are floats; everything else
        # uses the integer type.
        use_float = forFloat() and name not in self.internalVars
        c_type = IR.DataType.getFloatStr() if use_float else IR.DataType.getIntStr()

        decl_type = self.decls[name]
        if Type.isInt(decl_type):
            dims = ''
        elif Type.isTensor(decl_type):
            dims = ''.join('[' + str(extent) + ']' for extent in decl_type.shape)
        # NOTE(review): for any other type ``dims`` keeps the previous
        # iteration's value (or is unbound on the first) — confirm intended.

        self.out.printf('%s vars_%s::%s%s;\n', c_type, getVersion(), name, dims, indent=True)
        header.printf('extern %s %s%s;\n', c_type, name, dims, indent=True)

    self.out.printf('\n')
    header.decreaseIndent()
    header.printf("}\n")
    header.close()
    self.generateDebugProgram()
def generateDebugProgram(self):
    """Write ``debug.cpp``, which compares float and fixed-point variables.

    Does nothing unless ``self.generateAllFiles`` is set. The emitted
    ``debug()`` body is empty unless ``debugMode()`` and ``forFixed()`` both
    hold. In that case it contains one ``diff(...)`` call per non-global,
    scaled tensor variable that is not of shape one.
    """
    if not self.generateAllFiles:
        return

    dbg = Writer(os.path.join(self.outputDir, "debug.cpp"))
    for preamble_line in ("#include <iostream>\n\n",
                          "#include \"datatypes.h\"\n",
                          "#include \"profile.h\"\n",
                          "#include \"vars_fixed.h\"\n",
                          "#include \"vars_float.h\"\n\n",
                          "using namespace std;\n\n",
                          "void debug() {\n\n"):
        dbg.printf(preamble_line)

    if debugMode() and forFixed():
        dbg.increaseIndent()
        for name in self.decls:
            if name in self.globalVars:
                continue
            decl_type = self.decls[name]
            comparable = (name in self.scales
                          and isinstance(decl_type, Type.Tensor)
                          and not decl_type.isShapeOne())
            if not comparable:
                continue
            # diff() receives the address of the first element of each
            # copy, the fixed-point scale and the tensor's dimensions.
            first_elem = name + "[0]" * decl_type.dim
            dims = ', '.join(str(extent) for extent in decl_type.shape)
            dbg.printf(
                "diff(&vars_float::%s, &vars_fixed::%s, %d, %s);\n\n" %
                (first_elem, first_elem, self.scales[name], dims),
                indent=True)
        dbg.decreaseIndent()

    dbg.printf("}\n")
    dbg.close()
class X86(CodegenBase):
    """x86 C++ code generator (older interface, without variable bit-width).

    Writes the predictor to ``seedot_<version>.cpp`` in ``outputDir``.
    ``printVarDecls`` additionally writes ``vars_<version>.h`` and
    ``generateDebugProgram`` writes ``debug.cpp`` in the same directory.
    """

    def __init__(self, outputDir, decls, scales, intvs, cnsts, expTables, globalVars, internalVars, floatConstants):
        self.outputDir = outputDir
        # Main translation unit; it is closed at the end of printSuffix().
        cppFile = os.path.join(self.outputDir, "seedot_" + getVersion() + ".cpp")
        self.out = Writer(cppFile)
        self.decls = decls
        self.scales = scales
        self.intvs = intvs
        self.cnsts = cnsts
        self.expTables = expTables
        self.globalVars = globalVars
        self.internalVars = internalVars
        self.floatConstants = floatConstants

    def printPrefix(self):
        """Emit everything that precedes the function body's commands."""
        self.printCincludes()
        self.printExpTables()
        self.printVarDecls()
        self.printCHeader()
        self.printConstDecls()
        self.out.printf('\n')

    def printCincludes(self):
        """Emit the #include lines and using-directives of the .cpp file."""
        self.out.printf('#include <iostream>\n', indent=True)
        self.out.printf('#include <cstring>\n', indent=True)
        self.out.printf('#include <cmath>\n\n', indent=True)
        self.out.printf('#include "datatypes.h"\n', indent=True)
        self.out.printf('#include "predictors.h"\n', indent=True)
        self.out.printf('#include "profile.h"\n', indent=True)
        self.out.printf('#include "library_%s.h"\n' % (getVersion()), indent=True)
        self.out.printf('#include "model_%s.h"\n' % (getVersion()), indent=True)
        self.out.printf('#include "vars_%s.h"\n\n' % (getVersion()), indent=True)
        self.out.printf('using namespace std;\n', indent=True)
        self.out.printf('using namespace seedot_%s;\n' % (getVersion()), indent=True)
        self.out.printf('using namespace vars_%s;\n\n' % (getVersion()), indent=True)

    def printExpTables(self):
        """Emit both halves of every generated exponentiation look-up table."""
        for exp, [table, [tableVarA, tableVarB]] in self.expTables.items():
            self.printExpTable(table[0], tableVarA)
            self.printExpTable(table[1], tableVarB)
            self.out.printf('\n')

    def printExpTable(self, table_row, var):
        """Emit ``table_row`` as a ``const MYINT`` array named ``var.idf``."""
        self.out.printf('const MYINT %s[%d] = {\n' % (var.idf, len(table_row)), indent=True)
        self.out.increaseIndent()
        # Indent once, then write all entries on a single line.
        self.out.printf('', indent=True)
        for i in range(len(table_row)):
            self.out.printf('%d, ' % table_row[i])
        self.out.decreaseIndent()
        self.out.printf('\n};\n')

    def printCHeader(self):
        """Open ``int seedotFloat(float **X)`` or ``int seedotFixed(MYINT **X)``."""
        if forFloat():
            func = "Float"
            type = "float"
        else:
            func = "Fixed"
            type = "MYINT"
        self.out.printf('int seedot%s(%s **X) {\n' % (func, type), indent=True)
        self.out.increaseIndent()

    def printVarDecls(self):
        """Define non-global variables here and declare them extern in vars_<version>.h."""
        varsFilePath = os.path.join(self.outputDir, "vars_" + getVersion() + ".h")
        varsFile = Writer(varsFilePath)
        varsFile.printf("#pragma once\n\n")
        varsFile.printf("#include \"datatypes.h\"\n\n")
        varsFile.printf("namespace vars_%s {\n" % (getVersion()))
        varsFile.increaseIndent()
        for decl in self.decls:
            if decl in self.globalVars:
                continue
            # In float mode, non-internal variables are floats; all else uses the int type.
            if forFloat() and decl not in self.internalVars:
                typ_str = IR.DataType.getFloatStr()
            else:
                typ_str = IR.DataType.getIntStr()
            idf_str = decl
            type = self.decls[decl]
            if Type.isInt(type):
                shape_str = ''
            elif Type.isTensor(type):
                shape_str = ''.join(['[' + str(n) + ']' for n in type.shape])
            # NOTE(review): any other type reuses the previous shape_str (or is unbound).
            self.out.printf('%s vars_%s::%s%s;\n', typ_str, getVersion(), idf_str, shape_str, indent=True)
            varsFile.printf('extern %s %s%s;\n', typ_str, idf_str, shape_str, indent=True)
        self.out.printf('\n')
        varsFile.decreaseIndent()
        varsFile.printf("}\n")
        varsFile.close()
        self.generateDebugProgram()

    def generateDebugProgram(self):
        """Write debug.cpp; its debug() body holds diff() calls only in fixed-point debug mode."""
        debugFilePath = os.path.join(self.outputDir, "debug.cpp")
        debugFile = Writer(debugFilePath)
        debugFile.printf("#include <iostream>\n\n")
        debugFile.printf("#include \"datatypes.h\"\n")
        debugFile.printf("#include \"profile.h\"\n")
        debugFile.printf("#include \"vars_fixed.h\"\n")
        debugFile.printf("#include \"vars_float.h\"\n\n")
        debugFile.printf("using namespace std;\n\n")
        debugFile.printf("void debug() {\n\n")
        if debugMode() and forFixed():
            debugFile.increaseIndent()
            for decl in self.decls:
                if decl in self.globalVars:
                    continue
                type = self.decls[decl]
                # Only scaled tensors that are not of shape one are compared.
                if decl not in self.scales or not isinstance(
                        type, Type.Tensor) or type.isShapeOne():
                    continue
                scale = self.scales[decl]
                # Address of the first element, e.g. "A[0][0]" for a 2-D tensor.
                s = decl + "[0]" * type.dim
                shape_str = ''.join([str(n) + ', ' for n in type.shape])
                shape_str = shape_str.rstrip(', ')
                debugFile.printf(
                    "diff(&vars_float::%s, &vars_fixed::%s, %d, %s);\n\n" %
                    (s, s, scale, shape_str),
                    indent=True)
            debugFile.decreaseIndent()
        debugFile.printf("}\n")
        debugFile.close()

    def printSuffix(self, expr: IR.Expr):
        """Emit the result handling for ``expr``, close the function and the output file.

        An Int result is returned. A Tensor result is printed to ``cout`` as
        float(value) * 2**scale; for tensors with dim > 0 this is done
        element by element in generated loops over i, j, ...
        """
        self.out.printf('\n')
        type = self.decls[expr.idf]
        if Type.isInt(type):
            self.out.printf('return ', indent=True)
            self.print(expr)
            self.out.printf(';\n')
        elif Type.isTensor(type):
            idfr = expr.idf
            exponent = self.scales[expr.idf]
            num = 2**exponent
            if type.dim == 0:
                self.out.printf('cout << ', indent=True)
                self.out.printf('float(' + idfr + ')*' + str(num))
                self.out.printf(' << endl;\n')
            else:
                # Loop variables are named 'i', 'j', 'k', ... one per dimension.
                iters = []
                for i in range(type.dim):
                    s = chr(ord('i') + i)
                    tempVar = IR.Var(s)
                    iters.append(tempVar)
                expr_1 = IRUtil.addIndex(expr, iters)
                cmds = IRUtil.loop(type.shape, iters, [IR.PrintAsFloat(expr_1, exponent)])
                self.print(IR.Prog(cmds))
            # NOTE(review): the generated function is declared to return int but this
            # branch emits no return statement — confirm callers ignore the value.
        else:
            assert False
        self.out.decreaseIndent()
        self.out.printf('}\n', indent=True)
        self.out.close()
class Arduino(CodegenBase):
    """Arduino code generator (older interface, without scratch-space management).

    Emits the predictor as ``int predict()`` into ``predict.cpp`` under
    ``outputDir``. Overrides several ``CodegenBase`` printers with
    Arduino-specific output: program-memory reads, ``Serial`` printing and
    fetching the input through ``getIntFeature``/``getFloatFeature``.
    """

    def __init__(self, outputDir, decls, scales, intvs, cnsts, expTables, globalVars, internalVars, floatConstants):
        outputFile = os.path.join(outputDir, "predict.cpp")
        self.out = Writer(outputFile)
        self.decls = decls
        self.scales = scales
        self.intvs = intvs
        self.cnsts = cnsts
        self.expTables = expTables
        self.globalVars = globalVars
        self.internalVars = internalVars
        self.floatConstants = floatConstants

    def printPrefix(self):
        """Emit includes, exponent tables, the ``predict()`` header and declarations."""
        self.printArduinoIncludes()
        self.printExpTables()
        self.printArduinoHeader()
        self.printVarDecls()
        self.printConstDecls()
        self.out.printf('\n')

    def printArduinoIncludes(self):
        """Emit the #include lines and the ``using namespace model`` directive."""
        self.out.printf('#include <Arduino.h>\n\n', indent=True)
        self.out.printf('#include "config.h"\n', indent=True)
        self.out.printf('#include "predict.h"\n', indent=True)
        self.out.printf('#include "library.h"\n', indent=True)
        self.out.printf('#include "model.h"\n\n', indent=True)
        self.out.printf('using namespace model;\n\n', indent=True)

    def printExpTables(self):
        """Dump both halves of every generated look-up table for computing exponentials."""
        for _, [table, [tableVarA, tableVarB]] in self.expTables.items():
            self.printExpTable(table[0], tableVarA)
            self.printExpTable(table[1], tableVarB)
            self.out.printf('\n')

    def printExpTable(self, table_row, var):
        """Emit ``table_row`` as a ``const PROGMEM MYINT`` array named ``var.idf``."""
        self.out.printf('const PROGMEM MYINT %s[%d] = {\n' % (var.idf, len(table_row)), indent=True)
        self.out.increaseIndent()
        # Indent once, then write all entries on a single line.
        self.out.printf('', indent=True)
        for value in table_row:
            self.out.printf('%d, ' % value)
        self.out.decreaseIndent()
        self.out.printf('\n};\n')

    def printArduinoHeader(self):
        """Open the generated ``int predict()`` function."""
        self.out.printf('int predict() {\n', indent=True)
        self.out.increaseIndent()

    def printSuffix(self, expr: IR.Expr):
        """Emit the result handling for ``expr``, close ``predict()`` and the file.

        * Int result: emitted as ``return <expr>;``.
        * Tensor of dim 0: ``float(value) * 2**scale`` is printed with
          ``Serial.println(..., 6)``.
        * Tensor of dim > 0: every element is printed the same way from
          generated loops over i, j, ...

        NOTE(review): the tensor paths emit no return statement although
        ``predict()`` is declared to return int — confirm this is intended.
        """
        self.out.printf('\n')
        type = self.decls[expr.idf]
        if Type.isInt(type):
            self.out.printf('return ', indent=True)
            self.print(expr)
            self.out.printf(';\n')
        elif Type.isTensor(type):
            idfr = expr.idf
            exponent = self.scales[expr.idf]
            num = 2**exponent
            if type.dim == 0:
                self.out.printf('Serial.println(', indent=True)
                self.out.printf('float(' + idfr + ')*' + str(num))
                self.out.printf(', 6);\n')
            else:
                # Loop variables are named 'i', 'j', 'k', ... one per dimension.
                iters = [IR.Var(chr(ord('i') + i)) for i in range(type.dim)]
                expr_1 = IRUtil.addIndex(expr, iters)
                cmds = IRUtil.loop(type.shape, iters, [IR.PrintAsFloat(expr_1, exponent)])
                self.print(IR.Prog(cmds))
        else:
            assert False
        self.out.decreaseIndent()
        self.out.printf('}\n', indent=True)
        self.out.close()

    # The functions below override their counterparts in codegenBase.py
    # with Arduino-specific printing.

    def printVar(self, ir):
        """Print a variable reference; input variables are read via pgm_read_*."""
        if ir.inputVar:
            if config.wordLength == 16:
                self.out.printf('((MYINT) pgm_read_word_near(&')
            elif config.wordLength == 32:
                self.out.printf('((MYINT) pgm_read_dword_near(&')
            else:
                assert False
        self.out.printf('%s', ir.idf)
        for e in ir.idx:
            self.out.printf('[')
            self.print(e)
            self.out.printf(']')
        if ir.inputVar:
            self.out.printf('))')

    def printAssn(self, ir):
        """Print an assignment; reads from 'X' become getIntFeature/getFloatFeature(i0) calls.

        The data point X is obtained through these functions instead of being
        indexed directly.
        """
        if isinstance(ir.e, IR.Var) and ir.e.idf == "X":
            self.out.printf("", indent=True)
            self.print(ir.var)
            if forFixed():
                self.out.printf(" = getIntFeature(i0);\n")
            else:
                self.out.printf(" = getFloatFeature(i0);\n")
        else:
            super().printAssn(ir)

    def printFuncCall(self, ir):
        """Print a library call, passing tensors by the address of their first element.

        The 'X' argument is omitted because its values are fetched via
        getIntFeature(). Separators are written before each printed argument
        except the first, so omitting 'X' never leaves a dangling ", ".
        """
        self.out.printf("%s(" % ir.name, indent=True)
        needSeparator = False
        for arg in list(ir.argList):
            if isinstance(arg, IR.Var) and arg.idf == 'X':
                continue
            if needSeparator:
                self.out.printf(", ")
            needSeparator = True

            # x counts the decorations printed around the argument:
            #   x > 0  : tensor, printed as &A[0]...[0] with x "[0]" suffixes
            #            (e.g. A[10][10] -> &A[0][0], x == 2)
            #   x == -1: integer / 0-dim tensor, printed as &A
            #   x == 0 : anything else (e.g. constants), printed as-is
            if isinstance(arg, IR.Var) and arg.idf in self.decls.keys():
                type = self.decls[arg.idf]
                if isinstance(type, Type.Tensor):
                    if type.dim == 0:
                        x = -1
                    else:
                        x = type.dim - len(arg.idx)
                else:
                    x = -1
            else:
                x = 0
            if x != 0:
                self.out.printf("&")
            self.print(arg)
            if x != 0 and x != -1:
                self.out.printf("[0]" * x)
        self.out.printf(");\n\n")

    def printPrint(self, ir):
        """Print ``ir.expr`` to the serial port."""
        self.out.printf('Serial.println(', indent=True)
        self.print(ir.expr)
        self.out.printf(');\n')

    def printPrintAsFloat(self, ir):
        """Print ``ir.expr`` converted to float, i.e. ``float(expr) * 2**ir.expnt``."""
        self.out.printf('Serial.println(float(', indent=True)
        self.print(ir.expr)
        # Terminate the line like printPrint does.
        self.out.printf(') * ' + str(2**ir.expnt) + ', 6);\n')
class Arduino(CodegenBase):
    """Arduino code generator with variable bit-width and scratch-space support.

    Emits ``int predict()`` into ``predict.cpp`` and writes
    ``compileConfig.h`` under ``outputDir``. It also tracks a RAM-usage
    estimate (``currentRAMestimate``/``maxRAMestimate``), which is written to
    ``ram.usage`` once the suffix has been printed.
    """

    def __init__(self, outputDir, decls, localDecls, scales, intvs, cnsts, expTables, globalVars, internalVars, floatConstants, substitutions, demotedVarsOffsets, varsForBitwidth, varLiveIntervals, notScratch):
        outputFile = os.path.join(outputDir, "predict.cpp")
        self.outputDir = outputDir
        self.out = Writer(outputFile)
        self.decls = decls
        self.localDecls = localDecls
        self.scales = scales
        self.intvs = intvs
        self.cnsts = cnsts
        self.expTables = expTables
        self.globalVars = globalVars
        self.internalVars = internalVars
        self.floatConstants = floatConstants
        self.demotedVarsOffsets = demotedVarsOffsets
        self.varsForBitwidth = varsForBitwidth
        self.varLiveIntervals = varLiveIntervals
        # Byte offset into the per-loop "scratch" buffer for each variable
        # placed there by printFor(); consulted by printFuncCall().
        self.scratchSubs = {}
        self.notScratch = notScratch
        self.currentRAMestimate = 0
        self.maxRAMestimate = 0

    def printCompilerConfig(self):
        """Write compileConfig.h with #define switches mirroring the compiler configuration."""
        configFile = os.path.join(self.outputDir, "compileConfig.h")
        with open(configFile, "w") as file:
            file.write("// The datatype of the fixed-point representation is specified below\n")
            file.write("#define INT%d\n" % config.wordLength)
            if forFloat():
                file.write("#define XFLOAT\n")
            elif config.vbwEnabled:
                file.write("#define XINT%d\n" % self.varsForBitwidth['X'])
            else:
                file.write("#define XINT%d\n" % config.wordLength)
            # Disabled options are still written, commented out.
            if isSaturate():
                file.write("#define SATURATE\n")
            else:
                file.write("//#define SATURATE\n")
            if isfastApprox():
                file.write("#define FASTAPPROX\n")
            else:
                file.write("//#define FASTAPPROX\n")
            if useMathExp() or (useNewTableExp()):
                file.write("#define FLOATEXP\n")
            else:
                file.write("//#define FLOATEXP\n")

    def printPrefix(self):
        """Emit the config header, includes, exponent tables, ``predict()`` header and declarations."""
        self.printCompilerConfig()
        self.printArduinoIncludes()
        self.printExpTables()
        self.printArduinoHeader()
        self.printVarDecls()
        self.printConstDecls()
        self.out.printf('\n')

    def printArduinoIncludes(self):
        """Emit the #include lines and the ``using namespace model`` directive."""
        self.out.printf('#include <Arduino.h>\n\n', indent=True)
        self.out.printf('#include "config.h"\n', indent=True)
        self.out.printf('#include "predict.h"\n', indent=True)
        self.out.printf('#include "library.h"\n', indent=True)
        self.out.printf('#include "model.h"\n\n', indent=True)
        self.out.printf('using namespace model;\n\n', indent=True)

    def printExpTables(self):
        """Dump both halves of every generated look-up table for computing exponentials."""
        for _, [table, [tableVarA, tableVarB]] in self.expTables.items():
            self.printExpTable(table[0], tableVarA)
            self.printExpTable(table[1], tableVarB)
            self.out.printf('\n')

    def printExpTable(self, table_row, var):
        """Emit ``table_row`` as a ``const PROGMEM MYINT`` array named ``var.idf``."""
        self.out.printf('const PROGMEM MYINT %s[%d] = {\n' % (var.idf, len(table_row)), indent=True)
        self.out.increaseIndent()
        # Indent once, then write all entries on a single line.
        self.out.printf('', indent=True)
        for value in table_row:
            self.out.printf('%d, ' % value)
        self.out.decreaseIndent()
        self.out.printf('\n};\n')

    def printArduinoHeader(self):
        """Open the generated ``int predict()`` function."""
        self.out.printf('int predict() {\n', indent=True)
        self.out.increaseIndent()

    def printSuffix(self, expr: IR.Expr):
        """Emit the result handling for ``expr``, close ``predict()`` and write ram.usage.

        * Int result: emitted as ``return <expr>;``.
        * Tensor of dim 0: ``float(value) * 2**scale`` is printed with
          ``Serial.println(..., 6)``.
        * Tensor of dim > 0: every element is printed the same way from
          generated loops over i, j, ...

        NOTE(review): the tensor paths emit no return statement although
        ``predict()`` is declared to return int — confirm this is intended.
        """
        self.out.printf('\n')
        type = self.decls[expr.idf]
        if Type.isInt(type):
            self.out.printf('return ', indent=True)
            self.print(expr)
            self.out.printf(';\n')
        elif Type.isTensor(type):
            idfr = expr.idf
            exponent = self.scales[expr.idf]
            num = 2**exponent
            if type.dim == 0:
                self.out.printf('Serial.println(', indent=True)
                self.out.printf('float(' + idfr + ')*' + str(num))
                self.out.printf(', 6);\n')
            else:
                # Loop variables are named 'i', 'j', 'k', ... one per dimension.
                iters = [IR.Var(chr(ord('i') + i)) for i in range(type.dim)]
                expr_1 = IRUtil.addIndex(expr, iters)
                cmds = IRUtil.loop(type.shape, iters, [IR.PrintAsFloat(expr_1, exponent)])
                self.print(IR.Prog(cmds))
        else:
            assert False
        self.out.decreaseIndent()
        self.out.printf('}\n', indent=True)
        self.out.close()
        with open(os.path.join(self.outputDir, "ram.usage"), "w") as f:
            f.write("Estimate RAM usage :: %d bytes" % (self.maxRAMestimate))

    # The functions below override their counterparts in codegenBase.py
    # with Arduino-specific printing.

    def printVar(self, ir):
        """Print a variable reference; input variables go through pgm_read_*.

        The accessor width follows ``config.wordLength`` (8, 16 or 32 bits).
        Other word lengths are rejected.
        """
        if ir.inputVar:
            # A single if/elif chain: previously the 8-bit case fell through
            # to a separate `if` and then hit `assert False`.
            if config.wordLength == 8:
                self.out.printf('((MYINT) pgm_read_byte_near(&')
            elif config.wordLength == 16:
                self.out.printf('((MYINT) pgm_read_word_near(&')
            elif config.wordLength == 32:
                self.out.printf('((MYINT) pgm_read_dword_near(&')
            else:
                assert False
        self.out.printf('%s', ir.idf)
        for e in ir.idx:
            self.out.printf('[')
            self.print(e)
            self.out.printf(']')
        if ir.inputVar:
            self.out.printf('))')

    def printFor(self, ir):
        """Print a for loop, packing its local variables into one scratch buffer.

        In float mode the base implementation is used. Otherwise each local
        variable of the loop (except those in ``self.notScratch``) receives a
        byte offset in a ``char scratch[]`` array, recorded in
        ``self.scratchSubs``. Variables are processed in order of their live
        intervals. The slot size is a power-of-two multiple or fraction of
        the most common variable size (``mode``). Each variable takes the
        first aligned slot that overlaps no variable still live.
        """
        if forFloat():
            super().printFor(ir)
            return

        self.printForHeader(ir)
        self.out.increaseIndent()

        varToLiveRange = []
        for var in ir.varDecls.keys():
            size = np.prod(self.localDecls[var].shape)
            varToLiveRange.append((self.varLiveIntervals[var], var, size, self.varsForBitwidth[var]))
        varToLiveRange.sort()

        # var -> (last instruction using it, (first byte, last byte) in scratch)
        usedSpaceMap = {}
        totalScratchSize = -1
        listOfDimensions = [size for (_, _, size, _) in varToLiveRange]
        # Most frequent element count among the loop's local variables.
        mode = (lambda x: np.bincount(x).argmax())(listOfDimensions) if len(listOfDimensions) > 0 else None

        for ([startIns, endIns], var, size, atomSize) in varToLiveRange:
            if var in self.notScratch:
                continue
            spaceNeeded = size * atomSize // 8

            # Free the slots of variables whose live range ended before this one starts.
            varsToKill = [activeVar for activeVar in usedSpaceMap.keys()
                          if usedSpaceMap[activeVar][0] < startIns]
            for tbk in varsToKill:
                del usedSpaceMap[tbk]

            i = 0
            if spaceNeeded >= mode:
                blockSize = int(2**np.ceil(np.log2(spaceNeeded / mode))) * mode
            else:
                blockSize = mode / int(2**np.floor(np.log2(mode // spaceNeeded)))

            # First-fit search over aligned slots of blockSize bytes.
            breakOutOfWhile = True
            while True:
                potentialStart = int(blockSize * i)
                potentialEnd = int(blockSize * (i + 1)) - 1
                for activeVar in usedSpaceMap.keys():
                    (locationOccupiedStart, locationOccupiedEnd) = usedSpaceMap[activeVar][1]
                    if not (locationOccupiedStart > potentialEnd or locationOccupiedEnd < potentialStart):
                        # Overlaps a live variable: try the next slot.
                        i += 1
                        breakOutOfWhile = False
                        break
                    else:
                        breakOutOfWhile = True
                        continue
                if breakOutOfWhile:
                    break

            usedSpaceMap[var] = (endIns, (potentialStart, potentialEnd))
            totalScratchSize = max(totalScratchSize, potentialEnd)
            self.scratchSubs[var] = potentialStart

        self.currentRAMestimate += (totalScratchSize + 1)
        self.maxRAMestimate = max(self.currentRAMestimate, self.maxRAMestimate)
        self.out.printf("char scratch[%d];\n" % (totalScratchSize + 1), indent=True)
        self.printLocalVarDecls(ir)
        for cmd in ir.cmd_l:
            self.print(cmd)
        self.out.decreaseIndent()
        self.out.printf('}\n', indent=True)
        self.updateRAMafterDealloc(ir)
        self.currentRAMestimate -= (totalScratchSize + 1)

    def printAssn(self, ir):
        """Print an assignment; reads from 'X' become getIntFeature/getFloatFeature calls.

        The data point X is not indexed directly. Its flat row-major offset,
        computed from X's indices and shape, is passed to
        getIntFeature() (fixed point) or getFloatFeature() (float).
        """
        if isinstance(ir.e, IR.Var) and ir.e.idf == "X":
            self.out.printf("", indent=True)
            self.print(ir.var)
            indices = [index.idf for index in ir.e.idx]
            sizes = self.localDecls[ir.e.idf].shape if ir.e.idf in self.localDecls else self.decls[ir.e.idf].shape
            assert len(indices) == len(sizes), "Illegal state"
            # Row-major flattening: idx0 * (s1*s2*...) + idx1 * (s2*...) + ...
            prod = functools.reduce(operator.mul, sizes)
            dereferenceString = ""
            for index, extent in zip(indices, sizes):
                prod = prod // extent
                dereferenceString += ("%s * %d + " % (index, prod))
            dereferenceString = dereferenceString[:-3]
            if forFixed():
                self.out.printf(" = getIntFeature(%s);\n" % (dereferenceString))
            else:
                self.out.printf(" = getFloatFeature(%s);\n" % (dereferenceString))
        else:
            super().printAssn(ir)

    def printFuncCall(self, ir):
        """Print a library call inside its own block, with local declarations.

        Tensors are passed by the address of their first element. Variables
        placed in scratch space are passed as ``(scratch + offset)``. In
        fixed-point mode, tensor pointers are cast to their assigned
        bit-width. The 'X' argument is omitted because its values are fetched
        via getIntFeature(). Separators are written before each printed
        argument except the first, so omitting 'X' never leaves a dangling
        ", ".
        """
        self.out.printf("{\n", indent=True)
        self.out.increaseIndent()
        self.printLocalVarDecls(ir)
        self.out.printf("%s(" % ir.name, indent=True)
        needSeparator = False
        for arg in list(ir.argList):
            if isinstance(arg, IR.Var) and arg.idf == 'X':
                continue
            if needSeparator:
                self.out.printf(", ")
            needSeparator = True

            # x counts the decorations printed around the argument:
            #   x > 0  : tensor, printed as &A[0]...[0] with x "[0]" suffixes
            #            (e.g. A[10][10] -> &A[0][0], x == 2)
            #   x == -1: integer / 0-dim tensor, printed as &A
            #   x == 0 : anything else (e.g. constants), printed as-is
            if isinstance(arg, IR.Var) and (arg.idf in self.decls.keys() or arg.idf in self.localDecls.keys()):
                type = self.decls[arg.idf] if arg.idf in self.decls else self.localDecls[arg.idf]
                if isinstance(type, Type.Tensor):
                    if type.dim == 0:
                        x = -1
                    else:
                        x = type.dim - len(arg.idx)
                else:
                    x = -1
            else:
                x = 0

            if forFixed():
                typeCast = "(int%d_t*)" % self.varsForBitwidth[arg.idf] if x > 0 else ""
                self.out.printf(typeCast)

            if not (isinstance(arg, IR.Var) and arg.idf in self.scratchSubs):
                if x != 0:
                    self.out.printf("&")
                self.print(arg)
                if x != 0 and x != -1:
                    self.out.printf("[0]" * x)
            else:
                self.out.printf("(scratch + %d)" % (self.scratchSubs[arg.idf]))
        self.out.printf(");\n")
        self.out.decreaseIndent()
        self.out.printf("}\n", indent=True)
        self.updateRAMafterDealloc(ir)

    def printPrint(self, ir):
        """Print ``ir.expr`` to the serial port."""
        self.out.printf('Serial.println(', indent=True)
        self.print(ir.expr)
        self.out.printf(');\n')

    def printPrintAsFloat(self, ir):
        """Print ``ir.expr`` converted to float, i.e. ``float(expr) * 2**ir.expnt``."""
        self.out.printf('Serial.println(float(', indent=True)
        self.print(ir.expr)
        # Terminate the line like printPrint does.
        self.out.printf(') * ' + str(2**ir.expnt) + ', 6);\n')
def printVarDecls(self, globalVarDecl=True):
    """Emit declarations for every non-global variable in ``self.decls``.

    Args:
        globalVarDecl: Consulted only when variable bit-width (VBW) is
            enabled. If True, variables are defined at namespace scope as
            ``vars_<version>::<name>``. In that case every bit-width in
            ``config.availableBitwidths`` is emitted for fixed-point "tmp"
            variables. If False, plain declarations are emitted, and a "tmp"
            variable gets only its own bit-width as ``int<bw>_t <name>_<bw>``.

    When ``self.generateAllFiles`` is set, matching ``extern`` declarations
    are also written to ``vars_<version>.h`` in ``self.outputDir``, and the
    debug program is generated afterwards.
    """
    if self.generateAllFiles:
        varsFilePath = os.path.join(self.outputDir, "vars_" + getVersion() + ".h")
        varsFile = Writer(varsFilePath)
        varsFile.printf("#pragma once\n\n")
        varsFile.printf("#include \"datatypes.h\"\n\n")
        varsFile.printf("namespace vars_%s {\n" % (getVersion()))
        varsFile.increaseIndent()

    for decl in self.decls:
        if decl in self.globalVars:
            continue

        # C type: float for user-visible float variables; int<bw>_t for
        # user-visible fixed-point variables under VBW; otherwise MYINT-style.
        if forFloat() and decl not in self.internalVars:
            typ_str = IR.DataType.getFloatStr()
        elif forFixed() and decl not in self.internalVars and config.vbwEnabled:
            typ_str = "int%d_t" % self.varsForBitwidth.get(decl, config.wordLength)
        else:
            typ_str = IR.DataType.getIntStr()

        idf_str = decl
        type = self.decls[decl]
        if Type.isInt(type):
            shape_str = ''
        elif Type.isTensor(type):
            shape_str = ''.join(['[' + str(n) + ']' for n in type.shape])

        if not config.vbwEnabled:
            self.out.printf('%s vars_%s::%s%s;\n', typ_str, getVersion(), idf_str, shape_str, indent=True)
            if self.generateAllFiles:
                varsFile.printf('extern %s %s%s;\n', typ_str, idf_str, shape_str, indent=True)
            continue

        # Fixed-point temporaries exist in one copy per bit-width, suffixed "_<bw>".
        isBitwidthTemp = forFixed() and idf_str in self.varsForBitwidth and idf_str[:3] == "tmp"
        if isBitwidthTemp:
            if globalVarDecl:
                for bw in config.availableBitwidths:
                    self.out.printf("int%d_t vars_%s::%s_%d%s;\n", bw, getVersion(), idf_str, bw, shape_str, indent=True)
            else:
                # Use the variable's own bit-width for both the type and the
                # name suffix (previously the suffix came from a stale `bw`).
                ownBw = self.varsForBitwidth[idf_str]
                self.out.printf("int%d_t %s_%d%s;\n", ownBw, idf_str, ownBw, shape_str, indent=True)
        elif globalVarDecl:
            self.out.printf("%s vars_%s::%s%s;\n", typ_str, getVersion(), idf_str, shape_str, indent=True)
        else:
            self.out.printf("%s %s%s;\n", typ_str, idf_str, shape_str, indent=True)

        if self.generateAllFiles:
            if isBitwidthTemp:
                for bw in config.availableBitwidths:
                    varsFile.printf("extern int%d_t %s_%d%s;\n", bw, idf_str, bw, shape_str, indent=True)
            else:
                varsFile.printf("extern %s %s%s;\n", typ_str, idf_str, shape_str, indent=True)

    self.out.printf('\n')
    if self.generateAllFiles:
        varsFile.decreaseIndent()
        varsFile.printf("}\n")
        varsFile.close()
        self.generateDebugProgram()
class X86(CodegenBase):
    """x86 C++ code generator with variable bit-width (VBW) support.

    Writes to ``seedot_<version>.cpp`` in ``outputDir``. If
    ``generateAllFiles`` is set, the file is created fresh and the includes,
    global declarations, ``vars_<version>.h`` and ``debug.cpp`` are also
    produced. Otherwise the file is opened in append mode and only another
    ``seedotFixed<idStr>`` function is added.
    """

    def __init__(self, outputDir, generateAllFiles, printSwitch, idStr, decls, localDecls, scales, intvs, cnsts, expTables, globalVars, internalVars, floatConstants, substitutions, demotedVarsOffsets, varsForBitwidth, varLiveIntervals, notScratch):
        self.outputDir = outputDir
        cppFile = os.path.join(self.outputDir, "seedot_" + getVersion() + ".cpp")
        if generateAllFiles:
            self.out = Writer(cppFile)
        else:
            if debugCompiler():
                print("Opening file to output cpp code: ID" + idStr)
            # The OS may refuse to open the file; retry with a growing back-off.
            lastError = None
            for i in range(3):
                if debugCompiler():
                    print("Try %d" % (i + 1))
                try:
                    self.out = Writer(cppFile, "a")
                except OSError as e:
                    lastError = e
                    if debugCompiler():
                        print("OS prevented file from opening. Sleeping for %d seconds" % (i + 1))
                    time.sleep(i + 1)
                else:
                    if debugCompiler():
                        print("Opened")
                    break
            else:
                # Previously this fell through with self.out unset, causing a
                # confusing AttributeError much later.
                raise OSError("Could not open %s for appending after 3 attempts" % cppFile) from lastError
        self.decls = decls
        self.localDecls = localDecls
        self.scales = scales
        self.intvs = intvs
        self.cnsts = cnsts
        self.expTables = expTables
        self.globalVars = globalVars
        self.internalVars = internalVars
        self.floatConstants = floatConstants
        self.generateAllFiles = generateAllFiles
        self.idStr = idStr
        self.printSwitch = printSwitch
        self.demotedVarsOffsets = demotedVarsOffsets
        self.varsForBitwidth = varsForBitwidth
        self.varLiveIntervals = varLiveIntervals
        self.notScratch = notScratch

    def printPrefix(self):
        """Emit file-level content (first pass only), then the function header and locals."""
        if self.generateAllFiles:
            self.printCincludes()
            self.printExpTables()
            self.printVarDecls()
        self.printCHeader()
        self.printModelParamsWithBitwidth()
        self.printVarDecls(globalVarDecl=False)
        self.printConstDecls()
        self.out.printf('\n')

    def printCincludes(self):
        """Emit the #include lines and using-directives of the .cpp file."""
        self.out.printf('#include <iostream>\n', indent=True)
        self.out.printf('#include <cstring>\n', indent=True)
        self.out.printf('#include <cmath>\n\n', indent=True)
        self.out.printf('#include "datatypes.h"\n', indent=True)
        self.out.printf('#include "predictors.h"\n', indent=True)
        self.out.printf('#include "profile.h"\n', indent=True)
        self.out.printf('#include "library_%s.h"\n' % (getVersion()), indent=True)
        self.out.printf('#include "model_%s.h"\n' % (getVersion()), indent=True)
        self.out.printf('#include "vars_%s.h"\n\n' % (getVersion()), indent=True)
        self.out.printf('using namespace std;\n', indent=True)
        self.out.printf('using namespace seedot_%s;\n' % (getVersion()), indent=True)

    def printExpTables(self):
        """Emit both halves of every generated exponentiation look-up table."""
        for _, [table, [tableVarA, tableVarB]] in self.expTables.items():
            self.printExpTable(table[0], tableVarA)
            self.printExpTable(table[1], tableVarB)
            self.out.printf('\n')

    def printExpTable(self, table_row, var):
        """Emit ``table_row`` as a ``const MYINT`` array named ``var.idf``."""
        self.out.printf('const MYINT %s[%d] = {\n' % (var.idf, len(table_row)), indent=True)
        self.out.increaseIndent()
        # Indent once, then write all entries on a single line.
        self.out.printf('', indent=True)
        for value in table_row:
            self.out.printf('%d, ' % value)
        self.out.decreaseIndent()
        self.out.printf('\n};\n')

    def printCHeader(self):
        """Open the generated entry point.

        Float mode: ``int seedotFloat(float **X)``. Fixed mode:
        ``int seedotFixed<idStr>(MYINT **X)``, where ``<idStr>`` is used only
        when appending. Under VBW the parameter is ``X_temp`` and is converted
        into ``X`` by printModelParamsWithBitwidth().
        """
        if forFloat():
            self.out.printf('int seedot%s(%s **X) {\n' % ("Float", "float"), indent=True)
        else:
            self.out.printf(
                'int seedot%s%s(%s **X%s) {\n' %
                ("Fixed", self.idStr if not self.generateAllFiles else "", "MYINT",
                 "_temp" if config.vbwEnabled else ""),
                indent=True)
        self.out.increaseIndent()

    def printModelParamsWithBitwidth(self):
        """Under fixed-point VBW, create bit-width-specific local copies of model parameters.

        For each global variable (skipping ``v`` when both ``v + "idx"`` and
        ``v + "val"`` exist), a local array of type ``int<bw>_t`` is declared.
        X is declared as a pointer and allocated with ``new``; it is freed in
        printSuffix(). Each element is filled from ``<var>_temp`` divided by
        2**(wordLength - bw [+ demotion offset]). Index arrays and X are
        copied unscaled.
        NOTE(review): ``<var>_temp`` for non-X parameters is assumed to come
        from the model header — confirm.
        """
        if not (config.vbwEnabled and forFixed()):
            return
        for var in self.globalVars:
            if var + "idx" in self.globalVars and var + "val" in self.globalVars:
                continue
            bw = self.varsForBitwidth[var]
            typ_str = "int%d_t" % bw
            size = self.decls[var].shape
            sizestr = ''.join(["[%d]" % (i) for i in size])

            Xindexstr = ''
            Xintstar = ''.join(["*" for i in size])

            if var != 'X':
                self.out.printf(typ_str + " " + var + sizestr + ";\n", indent=True)
            else:
                self.out.printf(typ_str + Xintstar + " " + var + ";\n", indent=True)

            for i in range(len(size)):
                Xindexstr += ("[i" + str(i - 1) + "]" if i > 0 else "")
                if var == 'X':
                    # Allocate one pointer level per dimension before looping over it.
                    Xintstar = Xintstar[:-1]
                    self.out.printf("X%s = new %s%s[%d];\n" % (Xindexstr, typ_str, Xintstar, size[i]), indent=True)
                self.out.printf("for (int i%d = 0; i%d < %d; i%d ++) {\n" % (i, i, size[i], i), indent=True)
                self.out.increaseIndent()

            indexstr = ''.join("[i" + str(i) + "]" for i in range(len(size)))
            if var[-3:] != "idx" and var != "X":
                # Scale down from full word length to the variable's bit-width;
                # demoted variables carry an extra per-variable offset.
                shift = config.wordLength - bw
                if bw != config.wordLength:
                    shift += self.demotedVarsOffsets.get(var, 0)
                divide = int(round(np.ldexp(1, shift)))
            else:
                divide = 1
            self.out.printf(var + indexstr + " = " + var + "_temp" + indexstr + "/" + str(divide) + ";\n", indent=True)

            for i in range(len(size)):
                self.out.decreaseIndent()
                self.out.printf("}\n", indent=True)

    def printVarDecls(self, globalVarDecl=True):
        """Emit declarations for every non-global variable in ``self.decls``.

        Args:
            globalVarDecl: Consulted only when VBW is enabled. If True,
                namespace-scope ``vars_<version>::`` definitions are emitted,
                with every available bit-width for fixed-point "tmp"
                variables. If False, plain local declarations are emitted,
                with only the variable's own bit-width for "tmp" variables.

        With ``self.generateAllFiles`` set, matching ``extern`` declarations
        go to ``vars_<version>.h`` and the debug program is generated.
        """
        if self.generateAllFiles:
            varsFilePath = os.path.join(self.outputDir, "vars_" + getVersion() + ".h")
            varsFile = Writer(varsFilePath)
            varsFile.printf("#pragma once\n\n")
            varsFile.printf("#include \"datatypes.h\"\n\n")
            varsFile.printf("namespace vars_%s {\n" % (getVersion()))
            varsFile.increaseIndent()

        for decl in self.decls:
            if decl in self.globalVars:
                continue

            if forFloat() and decl not in self.internalVars:
                typ_str = IR.DataType.getFloatStr()
            elif forFixed() and decl not in self.internalVars and config.vbwEnabled:
                typ_str = "int%d_t" % self.varsForBitwidth.get(decl, config.wordLength)
            else:
                typ_str = IR.DataType.getIntStr()

            idf_str = decl
            type = self.decls[decl]
            if Type.isInt(type):
                shape_str = ''
            elif Type.isTensor(type):
                shape_str = ''.join(['[' + str(n) + ']' for n in type.shape])

            if not config.vbwEnabled:
                self.out.printf('%s vars_%s::%s%s;\n', typ_str, getVersion(), idf_str, shape_str, indent=True)
                if self.generateAllFiles:
                    varsFile.printf('extern %s %s%s;\n', typ_str, idf_str, shape_str, indent=True)
                continue

            # Fixed-point temporaries exist in one copy per bit-width, suffixed "_<bw>".
            isBitwidthTemp = forFixed() and idf_str in self.varsForBitwidth and idf_str[:3] == "tmp"
            if isBitwidthTemp:
                if globalVarDecl:
                    for bw in config.availableBitwidths:
                        self.out.printf("int%d_t vars_%s::%s_%d%s;\n", bw, getVersion(), idf_str, bw, shape_str, indent=True)
                else:
                    # Type and name suffix must both use the variable's own bit-width
                    # (previously the suffix came from a stale `bw`).
                    ownBw = self.varsForBitwidth[idf_str]
                    self.out.printf("int%d_t %s_%d%s;\n", ownBw, idf_str, ownBw, shape_str, indent=True)
            elif globalVarDecl:
                self.out.printf("%s vars_%s::%s%s;\n", typ_str, getVersion(), idf_str, shape_str, indent=True)
            else:
                self.out.printf("%s %s%s;\n", typ_str, idf_str, shape_str, indent=True)

            if self.generateAllFiles:
                if isBitwidthTemp:
                    for bw in config.availableBitwidths:
                        varsFile.printf("extern int%d_t %s_%d%s;\n", bw, idf_str, bw, shape_str, indent=True)
                else:
                    varsFile.printf("extern %s %s%s;\n", typ_str, idf_str, shape_str, indent=True)

        self.out.printf('\n')
        if self.generateAllFiles:
            varsFile.decreaseIndent()
            varsFile.printf("}\n")
            varsFile.close()
            self.generateDebugProgram()

    def generateDebugProgram(self):
        """Write debug.cpp (first pass only); diff() calls are emitted in fixed-point debug mode."""
        if not self.generateAllFiles:
            return
        debugFilePath = os.path.join(self.outputDir, "debug.cpp")
        debugFile = Writer(debugFilePath)
        debugFile.printf("#include <iostream>\n\n")
        debugFile.printf("#include \"datatypes.h\"\n")
        debugFile.printf("#include \"profile.h\"\n")
        debugFile.printf("#include \"vars_fixed.h\"\n")
        debugFile.printf("#include \"vars_float.h\"\n\n")
        debugFile.printf("using namespace std;\n\n")
        debugFile.printf("void debug() {\n\n")
        if debugMode() and forFixed():
            debugFile.increaseIndent()
            for decl in self.decls:
                if decl in self.globalVars:
                    continue
                type = self.decls[decl]
                # Only scaled tensors that are not of shape one are compared.
                if decl not in self.scales or not isinstance(type, Type.Tensor) or type.isShapeOne():
                    continue
                scale = self.scales[decl]
                s = decl + "[0]" * type.dim
                shape_str = ''.join([str(n) + ', ' for n in type.shape])
                shape_str = shape_str.rstrip(', ')
                debugFile.printf(
                    "diff(&vars_float::%s, &vars_fixed::%s, %d, %s);\n\n" % (s, s, scale, shape_str),
                    indent=True)
            debugFile.decreaseIndent()
        debugFile.printf("}\n")
        debugFile.close()

    def printSuffix(self, expr: IR.Expr):
        """Finish the generated function and close the output file.

        Steps:
        1. Under fixed-point VBW, free the X buffer allocated in
           printModelParamsWithBitwidth().
        2. Return an Int result, or print a Tensor result to ``cout`` as
           float(value) * 2**scale.
        3. Close the function.
        4. If ``printSwitch`` parses as a non-negative int, emit
           ``seedotFixedSwitch``, which dispatches to
           ``seedotFixed1..seedotFixed<printSwitch>``.
        """
        self.out.printf('\n')

        if config.vbwEnabled and forFixed():
            size = self.decls['X'].shape
            # Index strings are trimmed four characters ("[iK]") at a time,
            # which only works for single-digit loop indices.
            assert len(size) < 10, "Too simple logic for printing indices used, cannot handle 10+ Dim Tensors"
            Xindexstr = ''
            for i in range(len(size)):
                Xindexstr += (("[i%d]" % (i - 1)) if i > 0 else "")
                self.out.printf("for (int i%d = 0; i%d < %d; i%d ++ ){\n" % (i, i, size[i], i), indent=True)
                self.out.increaseIndent()
            # Free the innermost pointer level first, closing one loop per level.
            for i in range(len(size) - 1, -1, -1):
                self.out.decreaseIndent()
                self.out.printf("}\n", indent=True)
                self.out.printf("delete[] X%s;\n" % (Xindexstr), indent=True)
                Xindexstr = Xindexstr[:-4] if len(Xindexstr) > 0 else Xindexstr

        type = self.decls[expr.idf]
        if Type.isInt(type):
            self.out.printf('return ', indent=True)
            self.print(expr)
            self.out.printf(';\n')
        elif Type.isTensor(type):
            idfr = expr.idf
            exponent = self.scales[expr.idf]
            num = 2**exponent
            if type.dim == 0:
                self.out.printf('cout << ', indent=True)
                self.out.printf('float(' + idfr + ')*' + str(num))
                self.out.printf(' << endl;\n')
            else:
                # Loop variables are named 'i', 'j', 'k', ... one per dimension.
                iters = [IR.Var(chr(ord('i') + i)) for i in range(type.dim)]
                expr_1 = IRUtil.addIndex(expr, iters)
                cmds = IRUtil.loop(type.shape, iters, [IR.PrintAsFloat(expr_1, exponent)])
                self.print(IR.Prog(cmds))
        else:
            assert False

        self.out.decreaseIndent()
        self.out.printf('}\n', indent=True)

        def isInt(a):
            try:
                int(a)
                return True
            except (TypeError, ValueError):
                return False

        if forFixed():
            if (int(self.printSwitch) if isInt(self.printSwitch) else -2) > -1:
                self.out.printf("const int switches = %d;\n" % (int(self.printSwitch)), indent=True)
                self.out.printf('void seedotFixedSwitch(int i, MYINT **X_temp, int& res) {\n', indent=True)
                self.out.increaseIndent()
                self.out.printf('switch(i) {\n', indent=True)
                self.out.increaseIndent()
                for i in range(int(self.printSwitch)):
                    self.out.printf('case %d: res = seedotFixed%d(X_temp); return;\n' % (i, i + 1), indent=True)
                self.out.printf('default: res = -1; return;\n', indent=True)
                self.out.decreaseIndent()
                self.out.printf('}\n', indent=True)
                self.out.decreaseIndent()
                self.out.printf('}\n', indent=True)
        if debugCompiler():
            print("Closing File after outputting cpp code: ID " + self.idStr)
        self.out.close()