def generateIRCode(graph, extraInfoDict):
    """Translate the nodes of ``graph`` into a single right-nested chain of ``AST.Let`` nodes.

    Nodes are visited in ``graph.getAllNodesRef()`` order. For every output of
    a node whose AST (from ``generateASTForNode``) is not ``None``, a fresh
    variable string ``J<k>`` is allocated (``k`` counts emitted ASTs). The first
    emitted AST becomes the root ``Let`` (depth 0); each later one is spliced
    in as the ``expr`` of the previous innermost ``Let``.

    Args:
        graph: graph object exposing ``getAllNodesRef()``; every node's inputs
            must already have been processed (topological order).
        extraInfoDict: forwarded unchanged to ``generateASTForNode``.

    Returns:
        ``(program, dictNodeNameToOutVarStr)``: ``program`` is the root
        ``AST.Let`` (``None`` if nothing was emitted). The dict maps each output
        name to its ``J<k>`` string, or to ``None`` when its AST was ``None``.
    """
    program = None
    tailLet = None
    nameToVar = {}
    emitted = 0
    mtdVisitor = MtdAST()

    for node in graph.getAllNodesRef():
        # Consequence of topological sorting of the TF graph: every input
        # must already have an entry.
        for inputName in node.getInputsRef():
            assert inputName in nameToVar, "input={} expected as input but not yet processed".format(inputName)

        assignedVar, astsByOutput = generateASTForNode(graph, node, nameToVar, extraInfoDict)

        for outName, outAst in astsByOutput.items():
            metadata = {
                AST.ASTNode.mtdKeyTFOpName: node.getOp(),
                AST.ASTNode.mtdKeyTFNodeName: outName,
            }
            if outAst is None:
                nameToVar[outName] = None
                continue

            varStr = "J" + str(emitted)
            boundVar = assignedVar if assignedVar else AST.ID(varStr)

            if not program:
                # First emitted AST: it becomes the root of the Let chain.
                tailLet = AST.Let(AST.ID(varStr), outAst, boundVar)
                mtdVisitor.visit(tailLet, metadata)
                tailLet.depth = 0
                program = tailLet
            else:
                assert type(tailLet) is AST.Let
                # Splice the new Let in as the body of the current innermost Let.
                nextLet = AST.Let(boundVar, outAst, boundVar)
                mtdVisitor.visit(nextLet, metadata)
                tailLet.expr = nextLet
                tailLet = nextLet

            nameToVar[outName] = varStr
            emitted += 1

    return (program, nameToVar)
def addOutputs(program, dictNodeNameToOutVarStr, output_tensors):
    """Append client-visible ``AST.Output`` statements at the end of ``program``.

    ``program`` is a right-nested chain of ``AST.Let`` nodes. The outputs are
    spliced into the ``expr`` slot of the innermost ``Let``:

    * If ``output_tensors`` is ``None`` or empty, the variable bound by the
      innermost ``Let`` is output.
    * Otherwise each named tensor is output in order. Every output except the
      last is wrapped in a ``Let`` bound to ``O<i>`` so that the chain continues.

    Args:
        program: root ``AST.Let`` of the program; mutated in place.
        dictNodeNameToOutVarStr: maps tensor names to SeeDot variable strings.
        output_tensors: tensor names to output, or ``None``.

    Returns:
        ``program`` (the same object, now ending in output statements).

    Raises:
        KeyError: if a requested tensor has no entry in ``dictNodeNameToOutVarStr``.
    """
    mtdAST = MtdAST()
    assert type(program) is AST.Let

    # Walk down to the innermost Let; its expr is where outputs get attached.
    lastLetASTNode = program
    while type(lastLetASTNode.expr) is AST.Let:
        lastLetASTNode = lastLetASTNode.expr

    if not output_tensors:
        output_name = lastLetASTNode.name
        # Prefer the TF node name recorded on the decl; fall back to the
        # variable name when no such metadata was attached.
        tf_node_name = output_name.name
        decl_metadata = getattr(lastLetASTNode.decl, "metadata", None)
        if decl_metadata and AST.ASTNode.mtdKeyTFNodeName in decl_metadata:
            tf_node_name = decl_metadata[AST.ASTNode.mtdKeyTFNodeName]
        print(
            "Output not specified, taking output of ",
            tf_node_name,
            " as program output.",
        )
        output = AST.Output(output_name, AST.Party.CLIENT)
        mtdAST.visit(
            output,
            {
                AST.ASTNode.mtdKeyTFOpName: "Output",
                AST.ASTNode.mtdKeyTFNodeName: tf_node_name,
            },
        )
        lastLetASTNode.expr = output
        return program

    last_index = len(output_tensors) - 1
    for i, t_name in enumerate(output_tensors):
        if t_name not in dictNodeNameToOutVarStr:
            raise KeyError(
                "Output tensor {} was not produced by the graph".format(t_name)
            )
        output = AST.Output(
            AST.ID(dictNodeNameToOutVarStr[t_name]), AST.Party.CLIENT
        )
        if i == last_index:
            newNode = output
        else:
            # Placeholder expr; it is overwritten when the next output is
            # spliced in on the following iteration.
            newNode = AST.Let(AST.ID("O" + str(i)), output, AST.ASTNode())
        mtdAST.visit(
            newNode,
            {
                AST.ASTNode.mtdKeyTFOpName: "Output",
                AST.ASTNode.mtdKeyTFNodeName: t_name,
            },
        )
        lastLetASTNode.expr = newNode
        lastLetASTNode = newNode
    return program
def run(self):
    """Compile the pickled SeeDot AST at ``Util.Config.astFile`` into EzPC code.

    Steps, in order:
      1. Unpickle the AST from ``Util.Config.astFile``.
      2. Unless ``Util.Config.disableAllOpti`` is set, run the Relu-maxpool
         pass (skipped if ``disableRMO``) and the garbage-collection pass
         (skipped if ``disableLivenessOpti``).
      3. Run type inference, lower the AST with ``IRBuilderCSF``, post-process
         the IR with ``self.fixOuputScale`` and ``self.fixNames``, and dump the
         compiler's name mapping via ``Util.write_debug_info``.
      4. Insert start/end computation calls and emit EzPC code to
         ``Util.Config.outputFileName``.

    Raises:
        NotImplementedError: if ``Util.forEzPC()`` is false; EzPC is the only
            code generator wired up here.
    """
    # NOTE(review): pickle.load can execute arbitrary code; astFile must come
    # from a trusted producer (normally this compiler's own frontend).
    with open(Util.Config.astFile, "rb") as ff:
        ast = pickle.load(ff)

    if not Util.Config.disableAllOpti:
        if not Util.Config.disableRMO:
            print("Performing Relu-maxpool optimization...")
            ReluMaxpoolOpti.ReluMaxpoolOpti().visit(ast)
            print("Relu-maxpool optimization done.")

        if not Util.Config.disableLivenessOpti:
            print("Performing Garbage collection...")
            mtdAST = MtdAST()
            GC = GarbageCollector.GarbageCollector(ast)
            GC.run([mtdAST])
            print("Garbage collection done.")

    # Perform type inference and annotate nodes with type information
    InferType().visit(ast)

    IRUtil.init()
    compiler = IRBuilderCSF()
    res = compiler.visit(ast)
    res = self.fixOuputScale(res, compiler)
    res = self.fixNames(res, compiler)

    Util.write_debug_info(compiler.name_mapping)

    # Insert a generic start_computation and end_computation function call after all input IR statements.
    res = self.insertStartEndFunctionCalls(res)

    # Check the backend before creating the output file. An explicit raise is
    # used because a bare `assert` is stripped under `python -O`, which would
    # leave `codegen` unbound.
    if not Util.forEzPC():
        raise NotImplementedError("Only EzPC code generation is supported.")

    debugVarEzPCName = (
        compiler.name_mapping[Util.Config.debugVar]
        if Util.Config.debugVar in compiler.name_mapping
        else None
    )

    writer = Writer(Util.Config.outputFileName)
    try:
        codegen = EzPCCodegen(writer, compiler.globalDecls, debugVarEzPCName)
        codegen.printAll(*res)
    finally:
        # Always release the output file, even if code generation fails.
        writer.close()
def generate_seedot_ast(model, value_info, model_dir):
    """Translate an ONNX ``model`` into a SeeDot AST and pickle it into ``model_dir``.

    Input variables are processed first; they seed the Let chain and the
    name -> variable map. The graph's nodes are then processed, and the
    graph's declared outputs are attached with ``addOutputs``. The resulting
    program is written to ``<model_dir>/astOutput.pkl``, and the name map is
    dumped via ``common.write_debug_info``.

    Args:
        model: loaded ONNX model; only ``model.graph`` is read here directly,
            and the model is also forwarded to ``process_onnx_nodes``.
        value_info: tensor metadata forwarded to the processing helpers.
        model_dir: directory that receives ``astOutput.pkl``.
    """
    graph = model.graph
    name_to_var = {}
    metadata_visitor = MtdAST()

    # Start from an empty program, no innermost Let, and zero variables.
    program, tail_let, var_count = process_input_variables(
        None,
        None,
        name_to_var,
        0,
        metadata_visitor,
        graph,
        value_info,
    )

    # The return value is discarded. Any effect on the AST has to come through
    # the objects passed in.
    # NOTE(review): confirm process_onnx_nodes mutates tail_let / name_to_var in place.
    process_onnx_nodes(
        tail_let,
        name_to_var,
        var_count,
        metadata_visitor,
        graph,
        value_info,
        model,
    )

    requested_outputs = [out.name for out in graph.output]
    addOutputs(
        requested_outputs,
        tail_let,
        name_to_var,
        metadata_visitor,
        value_info,
    )

    if DEBUG:
        PrintAST().visit(program)

    common.write_debug_info(name_to_var)

    ast_path = os.path.join(model_dir, "astOutput.pkl")
    with open(ast_path, "wb") as sink:
        pickle.dump(program, sink)
        print("Dumped SeeDot AST")
def run(self):
    """Compile the pickled SeeDot AST at ``Util.Config.astFile`` into EzPC code.

    Steps, in order:
      1. Unpickle the AST from ``Util.Config.astFile``.
      2. Unless ``Util.Config.disableAllOpti`` is set, run the Relu-maxpool
         pass (skipped if ``disableRMO``) and the liveness passes (skipped if
         ``disableLivenessOpti``).
      3. Print the AST when ``Util.Config.printASTBool`` is set.
      4. Run type inference, lower the AST with ``IRBuilderCSF``, dump the
         compiler's name mapping, and insert start/end computation calls.
      5. Emit EzPC code to ``Util.Config.outputFileName``.

    Raises:
        NotImplementedError: if ``Util.forEzPC()`` is false; EzPC is the only
            code generator wired up here.
    """
    # NOTE(review): pickle.load can execute arbitrary code; astFile must come
    # from a trusted producer (normally this compiler's own frontend).
    with open(Util.Config.astFile, 'rb') as ff:
        ast = pickle.load(ff)

    if not Util.Config.disableAllOpti:
        if not Util.Config.disableRMO:
            print("Performing Relu-maxpool optimization...")
            # Perform optimizations on the AST
            ReluMaxpoolOpti.ReluMaxpoolOpti().visit(ast)

        if not Util.Config.disableLivenessOpti:
            print("Performing Liveness Optimization...")
            # Perform liveness analysis optimization on the AST
            mtdAST = MtdAST()
            LivenessOpti.LivenessAnalysis().visit(ast)
            LivenessOpti.LivenessOpti().visit(ast, [mtdAST, 0, {}])

    if Util.Config.printASTBool:
        PrintAST().visit(ast)
        sys.stdout.flush()

    # Perform type inference
    InferType().visit(ast)

    IRUtil.init()
    compiler = IRBuilderCSF()
    res = compiler.visit(ast)

    Util.write_debug_info(compiler.name_mapping)

    # Insert a generic start_computation and end_computation function call after all input IR statements.
    res = self.insertStartEndFunctionCalls(res)

    # Check the backend before creating the output file. An explicit raise is
    # used because a bare `assert` is stripped under `python -O`, which would
    # leave `codegen` unbound.
    if not Util.forEzPC():
        raise NotImplementedError("Only EzPC code generation is supported.")

    debugVarEzPCName = (
        compiler.name_mapping[Util.Config.debugVar]
        if Util.Config.debugVar in compiler.name_mapping
        else None
    )

    writer = Writer(Util.Config.outputFileName)
    try:
        codegen = EzPCCodegen(writer, compiler.decls, debugVarEzPCName)
        codegen.printAll(*res)
    finally:
        # Always release the output file, even if code generation fails.
        writer.close()
def _resolve_output_tensor_name(t_name, dictNodeNameToOutVarStr):
    """Map a requested output tensor name to the key used in ``dictNodeNameToOutVarStr``.

    The rules are tried in order:

    * A name already present in the dict is returned unchanged.
    * A plain name ``op`` becomes ``op_mpc_const_var``.
    * A tensor name ``op:0`` becomes ``op`` if that key exists, otherwise
      ``op_mpc_const_var``.
    * A tensor name ``op:N`` with ``N != 0`` becomes ``op_mpc_const_var:N``.

    Raises:
        ValueError: if the name contains ``:`` but is not of the form ``op:<int>``.
    """
    if t_name in dictNodeNameToOutVarStr:
        return t_name
    if ":" not in t_name:
        return t_name + "_mpc_const_var"
    try:
        op_name, out_n = t_name.split(":")
        out_n = int(out_n)
    except ValueError as err:
        raise ValueError(
            "The tensor name {} looks like a tensor name but is not a valid one"
            .format(t_name)) from err
    if out_n == 0:
        if op_name in dictNodeNameToOutVarStr:
            return op_name
        return op_name + "_mpc_const_var"
    return op_name + "_mpc_const_var" + ":" + str(out_n)


def addOutputs(program, dictNodeNameToOutVarStr, output_tensors):
    """Append client-visible ``AST.Output`` statements at the end of ``program``.

    ``program`` is a right-nested chain of ``AST.Let`` nodes. The outputs are
    spliced into the ``expr`` slot of the innermost ``Let``:

    * If ``output_tensors`` is ``None`` or empty, the variable bound by the
      innermost ``Let`` is output, and its TF node name is reported on stdout.
    * Otherwise each requested tensor name is resolved with
      ``_resolve_output_tensor_name`` and output in order. Every output except
      the last is wrapped in a ``Let`` bound to ``O<i>`` so that the chain
      continues.

    Args:
        program: root ``AST.Let`` of the program; mutated in place.
        dictNodeNameToOutVarStr: maps tensor names to SeeDot variable strings.
        output_tensors: tensor names to output, or ``None``.

    Returns:
        ``program`` (the same object, now ending in output statements).

    Raises:
        ValueError: if a requested name contains ``:`` but is malformed.
        KeyError: if a resolved name has no entry in ``dictNodeNameToOutVarStr``.
    """
    mtdAST = MtdAST()
    assert type(program) is AST.Let

    # Walk down to the innermost Let; its expr is where outputs get attached.
    lastLetASTNode = program
    while type(lastLetASTNode.expr) is AST.Let:
        lastLetASTNode = lastLetASTNode.expr

    if not output_tensors:
        output_name = lastLetASTNode.name
        tf_node_name = lastLetASTNode.decl.metadata[
            AST.ASTNode.mtdKeyTFNodeName]
        print(
            "Output not specified, taking output of ",
            tf_node_name,
            " as program output.",
        )
        output = AST.Output(output_name, AST.Party.CLIENT)
        mtdForCurAST = {
            AST.ASTNode.mtdKeyTFOpName: "Output",
            AST.ASTNode.mtdKeyTFNodeName: tf_node_name,
        }
        mtdAST.visit(output, mtdForCurAST)
        lastLetASTNode.expr = output
        return program

    last_index = len(output_tensors) - 1
    for i, requested in enumerate(output_tensors):
        t_name = _resolve_output_tensor_name(requested, dictNodeNameToOutVarStr)
        if t_name not in dictNodeNameToOutVarStr:
            raise KeyError(
                "Output tensor {} (looked up as {}) was not produced by the graph"
                .format(requested, t_name))
        output = AST.Output(
            AST.ID(dictNodeNameToOutVarStr[t_name]), AST.Party.CLIENT)
        if i == last_index:
            newNode = output
        else:
            # Placeholder expr; it is overwritten when the next output is
            # spliced in on the following iteration.
            newNode = AST.Let(AST.ID("O" + str(i)), output, AST.ASTNode())
        mtdForCurAST = {
            AST.ASTNode.mtdKeyTFOpName: "Output",
            AST.ASTNode.mtdKeyTFNodeName: t_name,
        }
        mtdAST.visit(newNode, mtdForCurAST)
        lastLetASTNode.expr = newNode
        lastLetASTNode = newNode
    return program
def main():
    """Convert ``models/<argv[1]>`` (an ONNX file) into a pickled SeeDot AST.

    Steps:
      1. Record the shapes of the first graph input, the first graph output,
         and all initializers in ``value_info``.
      2. Run ONNX shape inference.
      3. Translate the inputs and nodes into a SeeDot Let chain.
      4. Print the AST and dump debug info.
      5. Write the AST to ``debug/<model_name>/<model_name>.pkl``.

    Exits with status 1 if no model file is given on the command line.
    """
    import os

    # The AST is a deeply nested Let chain; printing and pickling it
    # presumably recurse through that nesting, hence the raised limit.
    sys.setrecursionlimit(10000)

    # First read the ONNX file
    if len(sys.argv) < 2:
        print("ONNX model file unspecified.", file=sys.stderr)
        sys.exit(1)

    file_name = sys.argv[1]
    file_path = os.path.join("models", file_name)
    model_name = os.path.splitext(file_name)[0]  # name without the '.onnx' extension

    # load the model and extract the graph
    model = onnx.load(file_path)
    graph_def = model.graph

    if DEBUG:
        print(model.graph.value_info)
    # Before shape inference (model.graph.value_info) should have shapes of all the variables and constants
    model.graph.value_info.append(
        make_tensor_value_info(
            model.graph.input[0].name, TensorProto.FLOAT,
            common.proto_val_to_dimension_tuple(model.graph.input[0])))
    model.graph.value_info.append(
        make_tensor_value_info(
            model.graph.output[0].name, TensorProto.FLOAT,
            common.proto_val_to_dimension_tuple(model.graph.output[0])))
    if DEBUG:
        print(model.graph.value_info)

    for init_vals in model.graph.initializer:
        model.graph.value_info.append(
            make_tensor_value_info(init_vals.name, TensorProto.FLOAT,
                                   tuple(init_vals.dims)))

    if DEBUG:
        print("Shape inference *****************")
        print(model.graph.value_info)

    inferred_model = onnx.shape_inference.infer_shapes(model)

    if DEBUG:
        print("Printing shape ******************")
        print(inferred_model.graph.value_info)
        print("Done ******************")

    # value_info: dictionary of name -> (type, dimension tuple)
    value_info = {
        val.name: (val.type.tensor_type.elem_type,
                   common.proto_val_to_dimension_tuple(val))
        for val in inferred_model.graph.value_info
    }

    # Iterate through the ONNX graph nodes and translate them to SeeDot AST nodes
    node_name_to_out_var_dict = {}
    mtdAST = MtdAST()
    (program, innermost_let_ast_node, out_var_count) = process_input_variables(
        None, None, node_name_to_out_var_dict, 0, mtdAST, graph_def, value_info)

    process_onnx_nodes(innermost_let_ast_node, node_name_to_out_var_dict,
                       out_var_count, mtdAST, graph_def, value_info)

    PrintAST().visit(program)

    common.write_debug_info(node_name_to_out_var_dict)

    # Create the target directory first; open() would fail if it is missing.
    ast_path = os.path.join("debug", model_name, model_name + ".pkl")
    os.makedirs(os.path.dirname(ast_path), exist_ok=True)
    with open(ast_path, "wb") as f:
        pickle.dump(program, f)