예제 #1
0
파일: Compiler.py 프로젝트: mpc-msri/EzPC
    def run(self):
        """Compile the pickled AST into EzPC code.

        Loads the AST from ``Util.Config.astFile``, runs every optimization
        pass that is not disabled in ``Util.Config``, performs type
        inference, builds the IR, post-processes it (``fixOuputScale``,
        ``fixNames``, start/end computation calls) and emits the generated
        code through a ``Writer`` on ``Util.Config.outputFileName``.

        Raises:
            AssertionError: if the configured target is not EzPC.
        """
        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file; astFile must come from a trusted source.
        with open(Util.Config.astFile, "rb") as ff:
            ast = pickle.load(ff)

        if not Util.Config.disableAllOpti:
            if not Util.Config.disableRMO:
                print("Performing Relu-maxpool optimization...")
                ReluMaxpoolOpti.ReluMaxpoolOpti().visit(ast)
                print("Relu-maxpool optimization done.")

            if not Util.Config.disableLivenessOpti:
                print("Performing Garbage collection...")
                mtdAST = MtdAST()
                GC = GarbageCollector.GarbageCollector(ast)
                GC.run([mtdAST])
                print("Garbage collection done.")

        # Perform type inference and annotate nodes with type information
        InferType().visit(ast)

        # NOTE: AST printing (previously guarded by Util.Config.printASTBool)
        # is intentionally disabled in this version of the compiler.

        IRUtil.init()
        compiler = IRBuilderCSF()
        res = compiler.visit(ast)
        res = self.fixOuputScale(res, compiler)
        res = self.fixNames(res, compiler)

        Util.write_debug_info(compiler.name_mapping)

        # Insert a generic start_computation and end_computation function call
        # after all input IR statements.
        res = self.insertStartEndFunctionCalls(res)

        writer = Writer(Util.Config.outputFileName)
        try:
            # Name of the debug variable in the generated code, if it was
            # mapped during IR building; None otherwise.
            debugVarEzPCName = (
                compiler.name_mapping[Util.Config.debugVar]
                if Util.Config.debugVar in compiler.name_mapping
                else None
            )

            if not Util.forEzPC():
                # Raised explicitly (instead of `assert False`) so the check
                # survives `python -O` and never falls through to an
                # undefined `codegen`.
                raise AssertionError(
                    "Unsupported target: only EzPC code generation is implemented."
                )

            codegen = EzPCCodegen(writer, compiler.globalDecls, debugVarEzPCName)
            codegen.printAll(*res)
        finally:
            # Always release the output file, even when codegen fails.
            writer.close()
예제 #2
0
def generate_seedot_ast(model, value_info, model_dir):
    """Build the SeeDot AST for an ONNX model and pickle it.

    The graph inputs, the ONNX nodes and finally the graph outputs are
    handed to the respective helpers, which extend the AST. The resulting
    program is written to ``<model_dir>/astOutput.pkl``.

    Args:
        model: ONNX model whose ``graph`` is translated.
        value_info: per-tensor information forwarded to the helpers.
        model_dir: directory that receives ``astOutput.pkl``.
    """
    onnx_graph = model.graph
    var_for_node = {}
    metadata = MtdAST()

    # Nothing has been built yet: start with no program, no innermost let
    # node and an output-variable counter of zero.
    program, innermost_let, next_var_id = process_input_variables(
        None,
        None,
        var_for_node,
        0,
        metadata,
        onnx_graph,
        value_info,
    )

    # Iterate through the ONNX graph nodes and translate them to SeeDot AST
    # nodes.
    process_onnx_nodes(
        innermost_let,
        var_for_node,
        next_var_id,
        metadata,
        onnx_graph,
        value_info,
        model,
    )

    output_names = [out.name for out in onnx_graph.output]
    addOutputs(output_names, innermost_let, var_for_node, metadata, value_info)

    if DEBUG:
        PrintAST().visit(program)
        common.write_debug_info(var_for_node)

    ast_path = os.path.join(model_dir, "astOutput.pkl")
    with open(ast_path, "wb") as ast_file:
        pickle.dump(program, ast_file)
        print("Dumped SeeDot AST")
예제 #3
0
    def run(self):
        """Compile the pickled AST into EzPC code.

        Loads the AST from ``Util.Config.astFile``, runs every optimization
        pass that is not disabled in ``Util.Config``, optionally prints the
        AST, performs type inference, builds the IR and emits the generated
        code through a ``Writer`` on ``Util.Config.outputFileName``.

        Raises:
            AssertionError: if the configured target is not EzPC.
        """
        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file; astFile must come from a trusted source.
        with open(Util.Config.astFile, 'rb') as ff:
            ast = pickle.load(ff)

        if not Util.Config.disableAllOpti:
            if not Util.Config.disableRMO:
                print("Performing Relu-maxpool optimization...")
                # Perform optimizations on the AST
                ReluMaxpoolOpti.ReluMaxpoolOpti().visit(ast)

            if not Util.Config.disableLivenessOpti:
                print("Performing Liveness Optimization...")
                # Perform liveness analysis optimization on the AST
                mtdAST = MtdAST()
                LivenessOpti.LivenessAnalysis().visit(ast)
                LivenessOpti.LivenessOpti().visit(ast, [mtdAST, 0, {}])

        if Util.Config.printASTBool:
            PrintAST().visit(ast)
            sys.stdout.flush()

        # Perform type inference
        InferType().visit(ast)

        IRUtil.init()
        compiler = IRBuilderCSF()
        res = compiler.visit(ast)

        Util.write_debug_info(compiler.name_mapping)

        # Insert a generic start_computation and end_computation function call
        # after all input IR statements.
        res = self.insertStartEndFunctionCalls(res)

        writer = Writer(Util.Config.outputFileName)
        try:
            # Name of the debug variable in the generated code, if it was
            # mapped during IR building; None otherwise.
            debugVarEzPCName = (
                compiler.name_mapping[Util.Config.debugVar]
                if Util.Config.debugVar in compiler.name_mapping
                else None
            )

            if not Util.forEzPC():
                # Raised explicitly (instead of `assert False`) so the check
                # survives `python -O` and never falls through to an
                # undefined `codegen`.
                raise AssertionError(
                    "Unsupported target: only EzPC code generation is implemented."
                )

            codegen = EzPCCodegen(writer, compiler.decls, debugVarEzPCName)
            codegen.printAll(*res)
        finally:
            # Always release the output file, even when codegen fails.
            writer.close()
예제 #4
0
def main():
    """Translate the ONNX model named on the command line into a pickled AST.

    ``sys.argv[1]`` is the model file name, looked up under ``models/``.
    The model's inputs, outputs and initializers are registered in
    ``value_info``, ONNX shape inference is run, the graph is translated to
    SeeDot AST nodes and the resulting program is pickled to
    ``debug/<model_name>/<model_name>.pkl``.

    Exits with status 1 when no model file name is given.
    """
    # Function-scope import: only needed to strip the file extension.
    import os

    sys.setrecursionlimit(10000)
    # First read the ONNX file
    if len(sys.argv) < 2:
        print("ONNX file unspecified.", file=sys.stderr)
        sys.exit(1)
    file_name = sys.argv[1]
    file_path = 'models/' + file_name
    # Name without the extension ('.onnx'); splitext avoids chopping five
    # characters blindly when the extension differs.
    model_name = os.path.splitext(file_name)[0]

    # load the model and extract the graph
    model = onnx.load(file_path)
    graph_def = model.graph

    print(model.graph.value_info)
    # Before shape inference (model.graph.value_info) should have shapes of all the variables and constants
    model.graph.value_info.append(
        make_tensor_value_info(
            model.graph.input[0].name, TensorProto.FLOAT,
            common.proto_val_to_dimension_tuple(model.graph.input[0])))
    model.graph.value_info.append(
        make_tensor_value_info(
            model.graph.output[0].name, TensorProto.FLOAT,
            common.proto_val_to_dimension_tuple(model.graph.output[0])))

    print(model.graph.value_info)

    for init_vals in model.graph.initializer:
        model.graph.value_info.append(
            make_tensor_value_info(init_vals.name, TensorProto.FLOAT,
                                   tuple(init_vals.dims)))

    if DEBUG:
        print("Shape inference *****************")
        print(model.graph.value_info)

    inferred_model = onnx.shape_inference.infer_shapes(model)

    if DEBUG:
        print("Printing shape ******************")
        print(inferred_model.graph.value_info)
        print("Done ******************")

    # value_info: dictionary of name -> (type, dimension tuple)
    value_info = {}
    for val in inferred_model.graph.value_info:
        value_info[val.name] = (val.type.tensor_type.elem_type,
                                common.proto_val_to_dimension_tuple(val))

    # Iterate through the ONNX graph nodes and translate them to SeeDot AST nodes
    program = None
    innermost_let_ast_node = None
    node_name_to_out_var_dict = {}
    out_var_count = 0
    mtdAST = MtdAST()

    (program, innermost_let_ast_node, out_var_count) = process_input_variables(
        program, innermost_let_ast_node, node_name_to_out_var_dict,
        out_var_count, mtdAST, graph_def, value_info)

    process_onnx_nodes(innermost_let_ast_node, node_name_to_out_var_dict,
                       out_var_count, mtdAST, graph_def, value_info)

    PrintAST().visit(program)

    common.write_debug_info(node_name_to_out_var_dict)

    with open('debug/' + model_name + '/' + model_name + '.pkl', 'wb') as f:
        pickle.dump(program, f)