def compile(model_fname, input_t_info, output_t_names, scaling_factor, save_weights):
    """Compile a TensorFlow ``.pb`` model and dump the artifacts needed for codegen.

    Loads the graph, checks the requested tensors exist, runs the grappler
    passes, and then, from inside the model's directory, writes
    ``optimised_<model_name>.pb`` plus the graph-def and size-info dumps
    produced by ``DumpTFMtData``. When ``save_weights`` is true the model
    weights are dumped as well.

    Args:
        model_fname: Path to the model file. Output file names are derived
            from its basename with the extension stripped.
        input_t_info: Mapping of input tensor name -> shape. If empty (or
            ``None``), input names/shapes are obtained from
            ``infer_input_info(graph)`` instead.
        output_t_names: Names of the output tensors; all must exist in the graph.
        scaling_factor: Passed to ``DumpTFMtData.save_weights`` and embedded
            in the weights file name.
        save_weights: Whether to dump the weights file.

    Returns:
        Absolute path of the dumped weights file, or ``""`` when
        ``save_weights`` is false.

    Raises:
        AssertionError: if ``tensors_exist`` reports that an output tensor,
            or a user-supplied input tensor, is missing from the graph.
    """
    # splitext handles any extension; for "*.pb" it matches stripping 3 chars.
    model_name = os.path.splitext(os.path.basename(model_fname))[0]
    print("Loading tf graph ", model_fname)
    graph = tf_graph_io.load_pb(model_fname)
    assert tensors_exist(graph, output_t_names), (
        "Output tensors not found in graph: {}".format(output_t_names)
    )

    if not input_t_info:
        input_t_info = infer_input_info(graph)
    else:
        # Validate user-supplied input names the same way as the outputs;
        # the check's result must not be discarded.
        input_names = list(input_t_info.keys())
        assert tensors_exist(graph, input_names), (
            "Input tensors not found in graph: {}".format(input_names)
        )
        graph = set_input_shapes(graph, input_t_info)

    input_t_names = list(input_t_info.keys())
    graph_def = grappler.optimize(graph, input_t_names, output_t_names)
    graph_def = grappler.convert_consts_to_var(graph_def)
    graph = get_graph_from(graph_def)

    # Every input is fed with an all-zeros array of its recorded shape.
    feed_dict = {
        get_tensor(graph, name): np.zeros(shape)
        for name, shape in input_t_info.items()
    }

    model_dir = os.path.realpath(os.path.dirname(model_fname))
    weights_path = ""
    cwd = os.getcwd()
    try:
        with graph.as_default(), tf.compat.v1.Session() as sess:
            # Run initializers generated by preprocessing.
            if check_operation_exists(graph, "init_constvars"):
                sess.run(graph.get_operation_by_name("init_constvars"))
            sess.run(tf.compat.v1.global_variables_initializer())
            # The output files below use relative names, so write them next
            # to the model.
            os.chdir(model_dir)

            # At this stage the graph still has constants embedded in it
            # in the assign nodes for variables. We cannot execute the graph
            # without these constants. We strip them away in a new graph def
            # which is amenable to codegen but leave them in the graph.
            optimized_graph_def = DumpTFMtData.strip_variable_init_constants(
                graph_def, input_t_names, output_t_names
            )
            tf_graph_io.dump_graph_def_pb(
                optimized_graph_def, "optimised_" + model_name + ".pb"
            )
            DumpTFMtData.save_graphdef(optimized_graph_def)
            DumpTFMtData.save_sizeinfo(optimized_graph_def, sess, feed_dict)
            print("Model compilation done.")

            if save_weights:
                weights_fname = (
                    model_name
                    + "_input_weights_fixedpt_scale_"
                    + str(scaling_factor)
                    + ".inp"
                )
                print(
                    "\nDumping model weights in ",
                    model_dir + "/" + weights_fname,
                    ".\nThese are to be used as input for party which owns the model\n",
                )
                DumpTFMtData.save_weights(
                    optimized_graph_def, sess, feed_dict, weights_fname, scaling_factor
                )
                weights_path = os.path.join(model_dir, weights_fname)
    finally:
        # Always restore the caller's working directory, even on failure.
        os.chdir(cwd)
    return weights_path
if __name__ == "__main__":
    # Script entry point: rewrite the listed nodes of a graph as Identity ops
    # and dump the result as "processed_<model file name>" next to the input.
    import os  # for splitting/joining the model path

    if len(sys.argv) != 2:
        print(
            """Usage: python remove_node.py config.json
config.json should have the following fields.
{
  "model_name" : "model.pb",
  "nodes_to_replace" : ["loss", "model_outputs"]
}
"""
        )
        # Wrong usage is an error, so exit with a non-zero status.
        sys.exit(1)

    config_path = sys.argv[1]
    with open(config_path) as f:
        try:
            config = json.load(f)
        except json.JSONDecodeError as e:
            sys.exit(
                "Error while parsing the config json:\n"
                + e.msg
                + " at line no. "
                + str(e.lineno)
            )

    model_name = config["model_name"]
    # The usage text documents "nodes_to_replace", while older configs use
    # "nodes_to_remove"; accept both (legacy key wins if both are present).
    if "nodes_to_remove" in config:
        nodes_to_replace = set(config["nodes_to_remove"])
    elif "nodes_to_replace" in config:
        nodes_to_replace = set(config["nodes_to_replace"])
    else:
        sys.exit('config.json must contain a "nodes_to_replace" list')

    gd = load_graph_def_pb(model_name)
    for node in gd.node:
        if node.name in nodes_to_replace:
            node.op = "Identity"

    # Prefix only the file name so a model path with directories still works.
    model_dir, model_file = os.path.split(model_name)
    new_graph_name = os.path.join(model_dir, "processed_" + model_file)
    dump_graph_def_pb(gd, new_graph_name)
    print("Pruned graph is dumped in {}".format(new_graph_name))