def compile_graph(self, acc):
    """Freeze the latest checkpoint into a ``.pb`` model (and ONNX when supported).

    The accuracy is only used as part of the output model name: the frozen
    graph is written to ``<compile_model_path>/<model_name>_<int(acc * 10000)>.pb``,
    matching the ``target_model_name`` passed to ``output_config``.

    :param acc: accuracy of the checkpoint being compiled (used for naming only).
    :return: None
    :raises ValueError: if no checkpoint is found under ``model_conf.model_root_path``.
    """
    acc_tag = int(acc * 10000)
    target_model_name = "{}_{}".format(self.model_conf.model_name, acc_tag)

    input_graph = tf.Graph()
    # Reset Keras/TF global state so the prediction graph is built from scratch.
    tf.compat.v1.keras.backend.clear_session()
    tf.compat.v1.reset_default_graph()
    predict_sess = tf.compat.v1.Session(graph=input_graph)
    tf.compat.v1.keras.backend.set_session(predict_sess)
    try:
        with predict_sess.graph.as_default():
            model = core.NeuralNetwork(
                model_conf=self.model_conf,
                mode=RunMode.Predict,
                backbone=self.model_conf.neu_cnn,
                recurrent=self.model_conf.neu_recurrent,
            )
            model.build_graph()
            # NOTE(review): presumably required so the graph holds every variable
            # saved in the checkpoint (e.g. optimizer slots) before restoring — confirm.
            model.build_train_op()
            input_graph_def = predict_sess.graph.as_graph_def()
            saver = tf.compat.v1.train.Saver(
                var_list=tf.compat.v1.global_variables())

            checkpoint_path = tf.train.latest_checkpoint(
                self.model_conf.model_root_path)
            tf.compat.v1.logging.info(checkpoint_path)
            if checkpoint_path is None:
                raise ValueError("No checkpoint found in {}".format(
                    self.model_conf.model_root_path))
            saver.restore(predict_sess, checkpoint_path)

            # Bake the restored variable values into constants, keeping only the
            # sub-graph needed to compute 'dense_decoded'.
            output_graph_def = convert_variables_to_constants(
                predict_sess,
                input_graph_def,
                output_node_names=['dense_decoded'],
            )

        os.makedirs(self.model_conf.compile_model_path, exist_ok=True)
        # Build the file name directly instead of str.replace() on the whole path,
        # which would also rewrite any ".pb" inside directory or model names.
        last_compile_model_path = os.path.join(
            self.model_conf.compile_model_path,
            "{}.pb".format(target_model_name),
        )

        self.model_conf.output_config(target_model_name=target_model_name)
        with tf.io.gfile.GFile(last_compile_model_path, mode='wb') as gf:
            gf.write(output_graph_def.SerializeToString())

        # ONNX export is skipped for the bidirectional recurrent variants.
        if self.model_conf.neu_recurrent not in (
                RecurrentNetwork.BiLSTM,
                RecurrentNetwork.BiGRU,
                RecurrentNetwork.BiLSTMcuDNN,
        ):
            self.compile_onnx(predict_sess, output_graph_def,
                              last_compile_model_path,
                              self.model_conf.loss_func)
    finally:
        # The session is only needed for this export; release its resources.
        predict_sess.close()
def freeze_graph(model_dir, output_node_names):
    """Extract the sub-graph defined by the output nodes and freeze its variables.

    Loads the latest checkpoint recorded in ``model_dir``, converts every
    variable needed by the requested output nodes into a constant, and writes
    the result as ``frozen_model.pb`` in the checkpoint's directory.

    Args:
        model_dir: the root folder containing the checkpoint state file.
        output_node_names: a string containing all the output nodes' names,
            comma separated (whitespace around each name is ignored).

    Returns:
        The frozen ``GraphDef``, or ``-1`` if no output node name was supplied.

    Raises:
        AssertionError: if ``model_dir`` does not exist or holds no checkpoint.
    """
    import os  # local import: this module's import block is not visible here

    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            "directory: %s" % model_dir)

    # Accept "a,b" as well as "a, b"; stray whitespace would otherwise produce
    # node names that do not exist in the graph.
    node_names = [name.strip() for name in (output_node_names or "").split(",")
                  if name.strip()]
    if not node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Full path of the most recent checkpoint listed in the state file.
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise AssertionError(
            "No checkpoint state found in directory: %s" % model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # Write the frozen graph next to the checkpoint files; os.path handles the
    # platform's separators, unlike splitting on "/".
    output_graph = os.path.join(os.path.dirname(input_checkpoint),
                                "frozen_model.pb")

    # Clear devices so TensorFlow decides on which device to load each op.
    clear_devices = True

    # A fresh graph keeps the caller's default graph out of the export; inside
    # this block it is the default graph.
    with tf.Session(graph=tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                           clear_devices=clear_devices)
        saver.restore(sess, input_checkpoint)
        # The session supplies the weights, the graph_def supplies the nodes, and
        # the output names select the sub-graph to keep.
        output_graph_def = tf_graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            node_names,
        )
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
    return output_graph_def
def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=True):
    """Freezes the state of a session into a pruned computation graph.

    Every global variable of ``sess.graph`` whose op name is not listed in
    ``keep_var_names`` is replaced by a constant holding its current value.

    :param sess: session whose graph and variable values are frozen.
    :param keep_var_names: variable op names to exclude from freezing.
    :param output_names: tensor or op names to keep in the pruned graph; an
        output-index suffix such as ":0" is stripped. ``None`` means no extra
        outputs besides the variables themselves.
    :param clear_devices: if True, clear every node's device placement.
    :return: the frozen ``GraphDef``.
    """
    # Reduce "logits:0" to the op name "logits"; plain op names pass through.
    # Handling None here fixes the TypeError the default argument used to cause.
    output_names = [name.split(':')[0] for name in (output_names or [])]
    graph = sess.graph
    with graph.as_default():
        var_op_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(
            set(var_op_names).difference(keep_var_names or []))
        # Listing the variable ops as outputs keeps their nodes when the
        # sub-graph reachable from the outputs is extracted.
        output_names += var_op_names
        input_graph_def = graph.as_graph_def(add_shapes=True)
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(
            sess, input_graph_def, output_names, freeze_var_names)
    return frozen_graph
def freeze_graph(input_saved_model_dir, output_saved_model_dir, node_names):
    """Freeze a SERVING SavedModel and re-export it as a new SavedModel.

    Loads the model tagged SERVING from ``input_saved_model_dir`` and converts
    its variables into constants. The graph is pruned to ``node_names`` plus the
    table-initializer and asset-path ops. The result is written to
    ``output_saved_model_dir``, whose previous contents are deleted first. The
    original signature defs and assets are reused, and all table initializers
    become the export's ``main_op``.

    :param input_saved_model_dir: directory of the SavedModel to freeze.
    :param output_saved_model_dir: destination directory (replaced if present).
    :param node_names: output op names to keep in the frozen graph.
    :return: None
    """
    # Load into a fresh graph so repeated calls (or unrelated ops in the caller's
    # default graph) cannot clash with / be mistaken for this model's nodes.
    with tf.Session(graph=tf.Graph()) as sess:
        meta_graph_def = tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], input_saved_model_dir)
        table_init_names = list(
            meta_graph_def.collection_def["table_initializer"].node_list.value)
        asset_tensor_names = list(
            meta_graph_def.collection_def["asset_filepaths"].node_list.value)
        # Asset entries are tensor names ("op:N"); strip the output index (any
        # number of digits) to obtain the op names to keep.
        keep_node_names = (
            list(node_names)
            + table_init_names
            + [re.sub(r':\d+$', "", n) for n in asset_tensor_names]
        )
        frozen_graph_def = tf_graph_util.convert_variables_to_constants(
            sess, meta_graph_def.graph_def, keep_node_names)

    if tf.gfile.IsDirectory(output_saved_model_dir):
        tf.gfile.DeleteRecursively(output_saved_model_dir)

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(frozen_graph_def, name="")

        init_ops = [graph.get_operation_by_name(n) for n in table_init_names]
        if not init_ops:
            main_op = None
        elif len(init_ops) == 1:
            main_op = init_ops[0]
        else:
            # Run every table initializer on load, not just the first one.
            main_op = tf.group(*init_ops)

        assets = [graph.get_tensor_by_name(n) for n in asset_tensor_names]

        with tf.Session() as sess:
            builder = tf.saved_model.builder.SavedModelBuilder(
                output_saved_model_dir)
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map=meta_graph_def.signature_def,
                assets_collection=assets,
                main_op=main_op)
            builder.save()