예제 #1
0
def generateIRCode(graph, extraInfoDict):
    """Lower the nodes of a TF graph into one chain of nested AST.Let nodes.

    Nodes are consumed in the order returned by graph.getAllNodesRef(), which
    is expected to be topological (every input must already be lowered).
    For each output of a node whose AST from generateASTForNode is not None,
    a fresh variable "J<k>" (k = 0, 1, 2, ...) is allocated and the binding is
    hung off the current innermost Let's .expr.

    Returns:
        (program, name_map): program is the outermost Let, or None when no
        output produced an AST; name_map maps each output name to its "J<k>"
        variable name, or to None when its AST was None.
    """
    name_map = {}
    var_prefix = "J"
    metadata_visitor = MtdAST()
    program = None
    tail = None  # innermost Let built so far
    counter = 0

    for node in graph.getAllNodesRef():
        # Topological order guarantees every input was already processed.
        for inp in node.getInputsRef():
            assert inp in name_map, "input={} expected as input but not yet processed".format(inp)

        assigned_var, asts_by_output = generateASTForNode(graph, node, name_map, extraInfoDict)

        for out_name, expr in asts_by_output.items():
            metadata = {
                AST.ASTNode.mtdKeyTFOpName: node.getOp(),
                AST.ASTNode.mtdKeyTFNodeName: out_name,
            }
            if expr is None:
                # No computation emitted for this output; record that explicitly.
                name_map[out_name] = None
                continue

            var_name = var_prefix + str(counter)
            # Prefer the variable chosen by generateASTForNode, if any.
            bound_id = assigned_var or AST.ID(var_name)

            if not program:
                # The very first binding becomes the root of the program.
                root = AST.Let(AST.ID(var_name), expr, bound_id)
                metadata_visitor.visit(root, metadata)
                root.depth = 0
                program = tail = root
            else:
                assert type(tail) is AST.Let
                link = AST.Let(bound_id, expr, bound_id)
                metadata_visitor.visit(link, metadata)
                tail.expr = link
                tail = link

            name_map[out_name] = var_name
            counter += 1

    return program, name_map
예제 #2
0
def addOutputs(program, dictNodeNameToOutVarStr, output_tensors):
    """Attach client-revealed outputs to the end of a Let-chain program.

    The chain is walked to its innermost Let. If output_tensors is None, that
    Let's own bound name becomes the single program output. Otherwise each
    tensor name is mapped through dictNodeNameToOutVarStr and revealed in
    order: every output but the last is wrapped in a Let bound to "O<i>" so
    the chain can continue, and the last AST.Output terminates the chain.

    Returns the same program object, mutated in place.
    """
    visitor = MtdAST()
    assert type(program) is AST.Let
    tail = program
    while type(tail.expr) is AST.Let:
        tail = tail.expr
    assert tail is not None

    if output_tensors is not None:
        final_idx = len(output_tensors) - 1
        for idx, tensor_name in enumerate(output_tensors):
            revealed = AST.Output(AST.ID(dictNodeNameToOutVarStr[tensor_name]), AST.Party.CLIENT)
            if idx == final_idx:
                link = revealed
            else:
                # Intermediate outputs need a Let so further outputs can follow.
                link = AST.Let(AST.ID("O" + str(idx)), revealed, AST.ASTNode())
                visitor.visit(link, {
                    AST.ASTNode.mtdKeyTFOpName: "Output",
                    AST.ASTNode.mtdKeyTFNodeName: tensor_name,
                })
            tail.expr = link
            tail = link
        return program

    # No explicit outputs: reveal whatever the innermost Let binds.
    last_name = tail.name
    print(last_name.name)
    tail.expr = AST.Output(last_name, AST.Party.CLIENT)
    return program
예제 #3
0
파일: Compiler.py 프로젝트: mpc-msri/EzPC
    def run(self):
        """Compile the pickled SeeDot AST into an EzPC program.

        Steps:
          1. Load the AST from Util.Config.astFile.
          2. Optionally run the Relu-maxpool and garbage-collection passes.
          3. Infer types.
          4. Build the IR with IRBuilderCSF.
          5. Apply the output-scale and name fix-ups.
          6. Insert the start/end computation calls.
          7. Write EzPC code to Util.Config.outputFileName.

        Raises:
            AssertionError: if a backend other than EzPC is configured. It is
                raised explicitly so the check also holds under ``python -O``.
        """
        # NOTE(review): pickle.load can execute arbitrary code; astFile must be
        # produced by a trusted frontend run, never taken from untrusted input.
        with open(Util.Config.astFile, "rb") as ff:
            ast = pickle.load(ff)

        if not Util.Config.disableAllOpti:
            if not Util.Config.disableRMO:
                print("Performing Relu-maxpool optimization...")
                ReluMaxpoolOpti.ReluMaxpoolOpti().visit(ast)
                print("Relu-maxpool optimization done.")

            if not Util.Config.disableLivenessOpti:
                print("Performing Garbage collection...")
                mtdAST = MtdAST()
                GC = GarbageCollector.GarbageCollector(ast)
                GC.run([mtdAST])
                print("Garbage collection done.")

        # Perform type inference and annotate nodes with type information
        InferType().visit(ast)

        # AST dumping is deliberately switched off here (it was previously
        # gated on Util.Config.printASTBool); change the condition to re-enable.
        if False:
            PrintAST().visit(ast)
            print("\n")
            sys.stdout.flush()

        IRUtil.init()
        compiler = IRBuilderCSF()
        res = compiler.visit(ast)
        res = self.fixOuputScale(res, compiler)
        res = self.fixNames(res, compiler)

        Util.write_debug_info(compiler.name_mapping)

        # Insert a generic start_computation and end_computation function call after all input IR statements.
        res = self.insertStartEndFunctionCalls(res)

        # Check the backend before creating the output writer. This way an
        # unsupported configuration neither leaks an open writer nor, under
        # -O (where `assert` is stripped), falls through to an unbound `codegen`.
        if not Util.forEzPC():
            raise AssertionError("Only the EzPC backend is supported")

        debugVarEzPCName = (compiler.name_mapping[Util.Config.debugVar]
                            if Util.Config.debugVar in compiler.name_mapping
                            else None)

        writer = Writer(Util.Config.outputFileName)
        try:
            codegen = EzPCCodegen(writer, compiler.globalDecls, debugVarEzPCName)
            codegen.printAll(*res)
        finally:
            # Release the writer even when code generation raises.
            writer.close()
예제 #4
0
def generate_seedot_ast(model, value_info, model_dir):
    """Lower an ONNX model's graph to a SeeDot AST and pickle it.

    Graph inputs are bound first (process_input_variables), then every ONNX
    node is translated (process_onnx_nodes), and finally the graph outputs are
    attached (addOutputs). The program is written to <model_dir>/astOutput.pkl.

    Args:
        model: loaded ONNX model whose .graph is traversed.
        value_info: per-tensor type/shape info forwarded to every pass
            (presumably name -> (elem_type, dims) — confirm against caller).
        model_dir: directory that receives astOutput.pkl.
    """
    graph_def = model.graph
    name_to_var = {}
    metadata_visitor = MtdAST()

    # Bind the graph inputs; this yields the root program, the current
    # innermost Let and the next free variable index.
    program, tail_let, var_count = process_input_variables(
        None, None, name_to_var, 0, metadata_visitor, graph_def, value_info
    )

    # Translate the ONNX nodes, continuing the chain from tail_let.
    process_onnx_nodes(
        tail_let, name_to_var, var_count, metadata_visitor, graph_def, value_info, model
    )

    addOutputs(
        [out.name for out in graph_def.output],
        tail_let,
        name_to_var,
        metadata_visitor,
        value_info,
    )

    if DEBUG:
        PrintAST().visit(program)
        common.write_debug_info(name_to_var)

    ast_path = os.path.join(model_dir, "astOutput.pkl")
    with open(ast_path, "wb") as fh:
        pickle.dump(program, fh)
        print("Dumped SeeDot AST")
예제 #5
0
    def run(self):
        """Compile the pickled SeeDot AST into an EzPC program.

        Steps:
          1. Load the AST from Util.Config.astFile.
          2. Optionally run the Relu-maxpool and liveness passes.
          3. Optionally print the AST.
          4. Infer types.
          5. Build the IR with IRBuilderCSF.
          6. Insert the start/end computation calls.
          7. Write EzPC code to Util.Config.outputFileName.

        Raises:
            AssertionError: if a backend other than EzPC is configured. It is
                raised explicitly so the check also holds under ``python -O``.
        """
        # NOTE(review): pickle.load can execute arbitrary code; astFile must be
        # produced by a trusted frontend run, never taken from untrusted input.
        with open(Util.Config.astFile, 'rb') as ff:
            ast = pickle.load(ff)

        if not Util.Config.disableAllOpti:
            if not Util.Config.disableRMO:
                print("Performing Relu-maxpool optimization...")
                # Perform optimizations on the AST
                ReluMaxpoolOpti.ReluMaxpoolOpti().visit(ast)

            if not Util.Config.disableLivenessOpti:
                print("Performing Liveness Optimization...")
                # Run the liveness analysis pass, then the liveness
                # optimization pass, over the same AST.
                mtdAST = MtdAST()
                LivenessOpti.LivenessAnalysis().visit(ast)
                LivenessOpti.LivenessOpti().visit(ast, [mtdAST, 0, {}])

        if Util.Config.printASTBool:
            PrintAST().visit(ast)
            sys.stdout.flush()

        # Perform type inference
        InferType().visit(ast)

        IRUtil.init()
        compiler = IRBuilderCSF()
        res = compiler.visit(ast)

        Util.write_debug_info(compiler.name_mapping)

        # Insert a generic start_computation and end_computation function call after all input IR statements.
        res = self.insertStartEndFunctionCalls(res)

        # Check the backend before creating the output writer. This way an
        # unsupported configuration neither leaks an open writer nor, under
        # -O (where `assert` is stripped), falls through to an unbound `codegen`.
        if not Util.forEzPC():
            raise AssertionError("Only the EzPC backend is supported")

        debugVarEzPCName = compiler.name_mapping[Util.Config.debugVar] if (
            Util.Config.debugVar in compiler.name_mapping) else None

        writer = Writer(Util.Config.outputFileName)
        try:
            codegen = EzPCCodegen(writer, compiler.decls, debugVarEzPCName)
            codegen.printAll(*res)
        finally:
            # Release the writer even when code generation raises.
            writer.close()
예제 #6
0
def resolve_output_tensor_name(t_name, dictNodeNameToOutVarStr):
    """Map a requested output tensor name onto a key of dictNodeNameToOutVarStr.

    Names already present in the map are returned unchanged. Otherwise the
    name is rewritten to the "_mpc_const_var" key convention:

    * "op"    -> "op_mpc_const_var"
    * "op:0"  -> "op" if "op" is a key of the map, else "op_mpc_const_var"
    * "op:k"  -> "op_mpc_const_var:k" for an integer k != 0

    The returned key is not guaranteed to be in the map; the caller's lookup
    raises KeyError if it is not.

    Raises:
        ValueError: if the name contains ":" but is not of the form "<op>:<int>".
    """
    if t_name in dictNodeNameToOutVarStr:
        return t_name
    if ":" not in t_name:
        return t_name + "_mpc_const_var"
    try:
        # Both a wrong number of ":"-separated parts and a non-integer output
        # index surface as ValueError here.
        op_name, out_n = t_name.split(":")
        out_n = int(out_n)
    except ValueError as err:
        raise ValueError(
            "The tensor name {} looks like a tensor name but is not a valid one"
            .format(t_name)) from err
    if out_n == 0:
        if op_name in dictNodeNameToOutVarStr:
            return op_name
        return op_name + "_mpc_const_var"
    return op_name + "_mpc_const_var" + ":" + str(out_n)


def addOutputs(program, dictNodeNameToOutVarStr, output_tensors):
    """Attach client-revealed outputs to the end of a Let-chain program.

    The chain is walked to its innermost Let.

    If output_tensors is None, the innermost Let's bound name becomes the
    single program output. Its TF node name is read from the Let's decl
    metadata and reported on stdout.

    Otherwise each tensor name is resolved with resolve_output_tensor_name,
    mapped through dictNodeNameToOutVarStr and revealed in order. Every output
    except the last is wrapped in a Let bound to "O<i>" and tagged with
    "Output" metadata. The last AST.Output terminates the chain.

    Returns:
        The same program object, mutated in place.

    Raises:
        ValueError: for a malformed "<op>:<index>" tensor name.
        KeyError: if a resolved tensor name is not in dictNodeNameToOutVarStr.
    """
    mtdAST = MtdAST()
    assert type(program) is AST.Let
    lastLetASTNode = program
    while type(lastLetASTNode.expr) is AST.Let:
        lastLetASTNode = lastLetASTNode.expr

    if output_tensors is None:
        output_name = lastLetASTNode.name
        tf_node_name = lastLetASTNode.decl.metadata[
            AST.ASTNode.mtdKeyTFNodeName]
        print(
            "Output not specified, taking output of ",
            tf_node_name,
            " as program output.",
        )
        output = AST.Output(output_name, AST.Party.CLIENT)
        mtdForCurAST = {
            AST.ASTNode.mtdKeyTFOpName: "Output",
            AST.ASTNode.mtdKeyTFNodeName: tf_node_name,
        }
        mtdAST.visit(output, mtdForCurAST)
        lastLetASTNode.expr = output
        return program

    outVarPrefix = "O"
    last_index = len(output_tensors) - 1
    for i, requested_name in enumerate(output_tensors):
        t_name = resolve_output_tensor_name(requested_name, dictNodeNameToOutVarStr)
        output_name = AST.ID(dictNodeNameToOutVarStr[t_name])
        output = AST.Output(output_name, AST.Party.CLIENT)
        if i == last_index:
            newNode = output
        else:
            # Intermediate outputs need a Let so further outputs can follow.
            let_name_id = AST.ID(outVarPrefix + str(i))
            newNode = AST.Let(let_name_id, output, AST.ASTNode())
            mtdForCurAST = {
                AST.ASTNode.mtdKeyTFOpName: "Output",
                AST.ASTNode.mtdKeyTFNodeName: t_name,
            }
            mtdAST.visit(newNode, mtdForCurAST)
        lastLetASTNode.expr = newNode
        lastLetASTNode = newNode

    return program
예제 #7
0
def main():
    """Compile an ONNX model into a pickled SeeDot AST.

    Usage: pass the model file name (e.g. ``net.onnx``) as the first
    command-line argument; the model is read from ``models/<name>``.

    The lowered AST is pickled to ``debug/<model_name>/<model_name>.pkl``,
    where ``model_name`` is the file name without its extension. That
    directory is created if needed.

    Exits with status 1 when no file name is given.
    """
    import os  # function-scope import keeps this entry point self-contained

    # NOTE(review): presumably raised because the nested Let-chain AST is deep
    # enough to exceed the default limit when visited/pickled — confirm.
    sys.setrecursionlimit(10000)

    # First read the ONNX file
    if len(sys.argv) < 2:
        print("ONNX model file unspecified.", file=sys.stderr)
        sys.exit(1)
    file_name = sys.argv[1]
    file_path = os.path.join('models', file_name)
    # Strip the extension generically instead of assuming a 5-char '.onnx'.
    model_name = os.path.splitext(file_name)[0]

    # load the model and extract the graph
    model = onnx.load(file_path)
    graph_def = model.graph

    if DEBUG:
        print(model.graph.value_info)
    # Before shape inference, value_info should hold the shapes of all
    # variables and constants. Register the first graph input and output
    # here; the initializers are added below.
    model.graph.value_info.append(
        make_tensor_value_info(
            model.graph.input[0].name, TensorProto.FLOAT,
            common.proto_val_to_dimension_tuple(model.graph.input[0])))
    model.graph.value_info.append(
        make_tensor_value_info(
            model.graph.output[0].name, TensorProto.FLOAT,
            common.proto_val_to_dimension_tuple(model.graph.output[0])))
    if DEBUG:
        print(model.graph.value_info)

    for init_vals in model.graph.initializer:
        model.graph.value_info.append(
            make_tensor_value_info(init_vals.name, TensorProto.FLOAT,
                                   tuple(init_vals.dims)))

    if DEBUG:
        print("Shape inference *****************")
        print(model.graph.value_info)

    inferred_model = onnx.shape_inference.infer_shapes(model)

    if DEBUG:
        print("Printing shape ******************")
        print(inferred_model.graph.value_info)
        print("Done ******************")

    # value_info: dictionary of name -> (type, dimension tuple)
    value_info = {
        val.name: (val.type.tensor_type.elem_type,
                   common.proto_val_to_dimension_tuple(val))
        for val in inferred_model.graph.value_info
    }

    # Iterate through the ONNX graph nodes and translate them to SeeDot AST nodes
    program = None
    innermost_let_ast_node = None
    node_name_to_out_var_dict = {}
    out_var_count = 0
    mtdAST = MtdAST()

    (program, innermost_let_ast_node, out_var_count) = process_input_variables(
        program, innermost_let_ast_node, node_name_to_out_var_dict,
        out_var_count, mtdAST, graph_def, value_info)

    process_onnx_nodes(innermost_let_ast_node, node_name_to_out_var_dict,
                       out_var_count, mtdAST, graph_def, value_info)

    # Full AST dump is debug output only, matching generate_seedot_ast.
    if DEBUG:
        PrintAST().visit(program)

    # The debug/<model_name>/ directory may not exist yet; open() would fail.
    pkl_path = os.path.join('debug', model_name, model_name + '.pkl')
    os.makedirs(os.path.dirname(pkl_path), exist_ok=True)

    common.write_debug_info(node_name_to_out_var_dict)

    with open(pkl_path, 'wb') as f:
        pickle.dump(program, f)