# Example #1
# (score: 0)
def compile(model_fname, input_t_info, output_t_names, scaling_factor,
            save_weights):
    """Compile a frozen TF graph (.pb) for codegen, optionally dumping weights.

    Loads the graph, infers or applies input shapes, runs grappler
    optimizations, dumps the optimized graph def and size info, and — when
    requested — dumps the model weights as fixed-point values.

    Args:
        model_fname: Path to the frozen-graph ``.pb`` file (assumed to end
            in ".pb"; the suffix is stripped to form the model name).
        input_t_info: Dict mapping input tensor names to shapes. If empty,
            the inputs are inferred from the graph instead.
        output_t_names: List of output tensor names; must exist in the graph.
        scaling_factor: Fixed-point scale used when dumping weights.
        save_weights: If True, dump model weights to an ``.inp`` file next to
            the model and return its path.

    Returns:
        Absolute path of the dumped weights file, or "" when
        ``save_weights`` is False.
    """
    # NOTE(review): shadows the builtin `compile`; kept for caller
    # compatibility. Name assumes a ".pb" suffix — strip the last 3 chars.
    model_name = os.path.basename(model_fname)[:-3]
    print("Loading tf graph ", model_fname)
    graph = tf_graph_io.load_pb(model_fname)
    assert tensors_exist(graph, output_t_names)

    if not input_t_info:
        input_t_info = infer_input_info(graph)
    else:
        # BUGFIX: the check's result was previously discarded; assert it,
        # matching the output-tensor check above.
        assert tensors_exist(graph, list(input_t_info.keys()))
        graph = set_input_shapes(graph, input_t_info)
    input_t_names = list(input_t_info.keys())
    graph_def = grappler.optimize(graph, input_t_names, output_t_names)
    graph_def = grappler.convert_consts_to_var(graph_def)
    graph = get_graph_from(graph_def)

    # Zero-filled inputs are sufficient for size inference / weight dumping.
    feed_dict = {
        get_tensor(graph, name): np.zeros(shape)
        for name, shape in input_t_info.items()
    }

    cwd = os.getcwd()
    weights_path = ""
    # BUGFIX: restore the working directory even if compilation raises;
    # previously an exception inside the session left the cwd changed.
    try:
        with graph.as_default():
            with tf.compat.v1.Session() as sess:
                # Run initializers generated by preprocessing
                if check_operation_exists(graph, "init_constvars"):
                    sess.run(graph.get_operation_by_name("init_constvars"))
                sess.run(tf.compat.v1.global_variables_initializer())
                model_dir = os.path.realpath(os.path.dirname(model_fname))
                os.chdir(model_dir)

                # At this stage the graph still has constants embedded in it
                # in the assign nodes for variables. We cannot execute the graph without
                # these constants. We strip them away in a new graph def which is amenable
                # to codegen but leave them in the graph.
                optimized_graph_def = DumpTFMtData.strip_variable_init_constants(
                    graph_def, input_t_names, output_t_names)

                tf_graph_io.dump_graph_def_pb(optimized_graph_def,
                                              "optimised_" + model_name + ".pb")
                DumpTFMtData.save_graphdef(optimized_graph_def)
                DumpTFMtData.save_sizeinfo(optimized_graph_def, sess, feed_dict)
                print("Model compilation done.")
                if save_weights:
                    weights_fname = (model_name + "_input_weights_fixedpt_scale_" +
                                     str(scaling_factor) + ".inp")
                    print(
                        "\nDumping model weights in ",
                        model_dir + "/" + weights_fname,
                        ".\nThese are to be used as input for party which owns the model\n",
                    )
                    DumpTFMtData.save_weights(optimized_graph_def, sess, feed_dict,
                                              weights_fname, scaling_factor)
                    weights_path = os.path.join(model_dir, weights_fname)
    finally:
        os.chdir(cwd)
    return weights_path
if __name__ == "__main__":
    # CLI entry point: replace the ops of the named nodes with Identity,
    # then dump the pruned graph next to the original.
    if len(sys.argv) != 2:
        print("""Usage: python remove_node.py config.json
config.json should have the following fields.
{
  "model_name" : "model.pb",
  "nodes_to_replace" : ["loss", "model_outputs"]
}

""")
        sys.exit()
    config_path = sys.argv[1]
    with open(config_path) as f:
        try:
            config = json.load(f)
        # Use the json-module attribute so no extra import is required.
        except json.JSONDecodeError as e:
            sys.exit("Error while parsing the config json:\n" + e.msg +
                     " at line no. " + str(e.lineno))
    model_name = config["model_name"]
    # BUGFIX: the usage text documents the key "nodes_to_replace" but the
    # code only read "nodes_to_remove". Accept both, preferring the
    # documented key, so existing configs keep working.
    if "nodes_to_replace" in config:
        nodes_to_replace = config["nodes_to_replace"]
    else:
        nodes_to_replace = config["nodes_to_remove"]
    gd = load_graph_def_pb(model_name)
    # Turn each targeted node into a pass-through Identity op.
    # (Removed an unused `to_replace` list that duplicated this scan.)
    for n in gd.node:
        if n.name in nodes_to_replace:
            n.op = "Identity"

    new_graph_name = "processed_" + model_name
    dump_graph_def_pb(gd, new_graph_name)
    print("Pruned graph is dumped in {}".format(new_graph_name))