Пример #1
0
    def compile_graph(self, acc):
        """
        Compile the computation graph into a frozen ``.pb`` model for the
        current accuracy. The accuracy is only used as part of the model name.

        The latest checkpoint in ``model_conf.model_root_path`` is restored
        into a fresh graph/session. Its variables are folded into constants,
        keeping the subgraph needed for the ``dense_decoded`` output node. The
        result is written to
        ``<compile_model_path>/<model_name>_<int(acc * 10000)>.pb``. The
        model config is re-emitted under the matching target name. Unless the
        recurrent layer is BiLSTM / BiGRU / BiLSTMcuDNN, ``compile_onnx`` is
        then called with the same session and frozen graph.

        :param acc: accuracy; ``int(acc * 10000)`` becomes the name suffix
            (e.g. 0.9876 -> ``_9876``)
        :return: None
        :raises ValueError: if no checkpoint is found in ``model_root_path``
        """
        # Suffix shared by the .pb file name and the exported config name.
        acc_tag = int(acc * 10000)

        # Reset Keras/TF1 global state and bind a dedicated session/graph.
        input_graph = tf.Graph()
        tf.compat.v1.keras.backend.clear_session()
        tf.compat.v1.reset_default_graph()
        predict_sess = tf.compat.v1.Session(graph=input_graph)
        tf.compat.v1.keras.backend.set_session(predict_sess)

        with predict_sess.graph.as_default():
            model = core.NeuralNetwork(model_conf=self.model_conf,
                                       mode=RunMode.Predict,
                                       backbone=self.model_conf.neu_cnn,
                                       recurrent=self.model_conf.neu_recurrent)
            model.build_graph()
            # NOTE(review): presumably built so the graph's variable set
            # matches what the checkpoint contains — confirm.
            model.build_train_op()
            input_graph_def = predict_sess.graph.as_graph_def()
            saver = tf.compat.v1.train.Saver(
                var_list=tf.compat.v1.global_variables())

            # Resolve the checkpoint once so the logged and restored paths
            # are guaranteed to be the same file.
            checkpoint_path = tf.train.latest_checkpoint(
                self.model_conf.model_root_path)
            tf.compat.v1.logging.info(checkpoint_path)
            if checkpoint_path is None:
                raise ValueError(
                    "No checkpoint found in model root path: {}".format(
                        self.model_conf.model_root_path))
            saver.restore(predict_sess, checkpoint_path)

            output_graph_def = convert_variables_to_constants(
                predict_sess,
                input_graph_def,
                output_node_names=['dense_decoded'])

        os.makedirs(self.model_conf.compile_model_path, exist_ok=True)

        # Build the file name directly. A str.replace over the whole path
        # would also rewrite any ".pb" inside directory or model names.
        last_compile_model_path = os.path.join(
            self.model_conf.compile_model_path,
            "{}_{}.pb".format(self.model_conf.model_name, acc_tag))

        self.model_conf.output_config(target_model_name="{}_{}".format(
            self.model_conf.model_name, acc_tag))
        with tf.io.gfile.GFile(last_compile_model_path, mode='wb') as gf:
            gf.write(output_graph_def.SerializeToString())

        # ONNX export is skipped for the bidirectional recurrent variants.
        if self.model_conf.neu_recurrent not in [
                RecurrentNetwork.BiLSTM,
                RecurrentNetwork.BiGRU,
                RecurrentNetwork.BiLSTMcuDNN,
        ]:
            self.compile_onnx(predict_sess, output_graph_def,
                              last_compile_model_path,
                              self.model_conf.loss_func)
Пример #2
0
def freeze_graph(model_dir, output_node_names):
    """Extract the subgraph defined by the output nodes and convert all of its
    variables into constants.

    The graph is rebuilt from the ``.meta`` file of the checkpoint recorded in
    ``model_dir``'s checkpoint state. Its weights are restored, and the frozen
    ``GraphDef`` is written as ``frozen_model.pb`` in the same directory as
    that checkpoint.

    Args:
        model_dir: the root folder containing the checkpoint state file
        output_node_names: a string containing all the output nodes' names,
            comma separated; whitespace around each name is ignored
    Returns:
        The frozen ``GraphDef``, or ``-1`` when no output node name is given.
    Raises:
        AssertionError: if ``model_dir`` does not exist or contains no
            checkpoint state.
    """
    import os  # local import keeps this helper self-contained

    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            "directory: %s" % model_dir)

    # Tolerate "a, b" style input and stray empty entries.
    node_names = [name.strip() for name in (output_node_names or "").split(",")
                  if name.strip()]
    if not node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Retrieve the full path of the latest checkpoint.
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise AssertionError(
            "No checkpoint state found in directory: %s" % model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # Write the frozen graph next to the checkpoint. os.path.dirname yields ""
    # for a bare relative file name, so this never resolves to the filesystem
    # root.
    output_graph = os.path.join(os.path.dirname(input_checkpoint),
                                "frozen_model.pb")

    # Clear devices so TensorFlow decides where to place the loaded operations.
    clear_devices = True

    # Entering the session also makes its fresh graph the default graph.
    with tf.Session(graph=tf.Graph()) as sess:
        # Import the meta graph into the session's graph.
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                           clear_devices=clear_devices)

        # Restore the weights.
        saver.restore(sess, input_checkpoint)

        # Fold variables into constants, keeping only the nodes needed to
        # compute the requested outputs.
        output_graph_def = tf_graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(),
            node_names,
        )

        # Serialize and dump the frozen graph to the filesystem.
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

    return output_graph_def
Пример #3
0
def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the state of a session into a pruned computation graph.

    Args:
        sess: session whose graph and current variable values are frozen.
        keep_var_names: names of global variables that must NOT be converted
            to constants; every other global variable is frozen.
        output_names: tensor or op names to keep (may be None). A trailing
            ``":<index>"`` is dropped, so ``"logits:0"`` and ``"logits"`` both
            refer to the op ``logits``.
        clear_devices: if True, clear the device placement of every node
            before freezing.
    Returns:
        The frozen ``GraphDef``.
    """
    # Normalise tensor names to op names. This runs after the None default is
    # applied, and it also accepts bare op names that have no ":<index>".
    output_names = [name.split(':')[0] for name in (output_names or [])]
    graph = sess.graph
    with graph.as_default():
        global_var_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(
            set(global_var_names).difference(keep_var_names or []))
        # Also list every global variable's op as an output so it is kept
        # when the graph is pruned to the requested outputs.
        output_names += global_var_names
        input_graph_def = graph.as_graph_def(add_shapes=True)
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(
            sess, input_graph_def, output_names, freeze_var_names)
        return frozen_graph
def freeze_graph(input_saved_model_dir, output_saved_model_dir, node_names):
    """Freeze a SERVING SavedModel's variables into constants and re-export it.

    The SavedModel in ``input_saved_model_dir`` is loaded and its graph is
    frozen. The nodes kept are ``node_names`` plus all ops listed in the
    ``table_initializer`` and ``asset_filepaths`` collections. The result is
    written as a new SavedModel to ``output_saved_model_dir``, replacing any
    existing directory there. Signatures and asset tensors are carried over,
    and the table initializers become the exported ``main_op``.

    Args:
        input_saved_model_dir: directory of the SavedModel to freeze.
        output_saved_model_dir: destination directory (deleted first if it
            already exists).
        node_names: list of output node (op) names to keep.
    """
    with tf.Session() as sess:
        meta_graph_def = tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING],
            input_saved_model_dir)
        collections = meta_graph_def.collection_def
        table_init_names = list(
            collections["table_initializer"].node_list.value)
        asset_tensor_names = list(
            collections["asset_filepaths"].node_list.value)
        # Keep the table initializers and asset-path ops through pruning.
        # Asset entries are tensor names ("x:0", possibly "x:12"), so strip
        # the whole output index to get the op name.
        node_names = (node_names + table_init_names +
                      [re.sub(r':\d+$', "", n) for n in asset_tensor_names])
        frozen_graph_def = tf_graph_util.convert_variables_to_constants(
            sess, meta_graph_def.graph_def, node_names)

    if tf.gfile.IsDirectory(output_saved_model_dir):
        tf.gfile.DeleteRecursively(output_saved_model_dir)
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(frozen_graph_def, name="")
        table_init_ops = [graph.get_operation_by_name(n)
                          for n in table_init_names]
        if not table_init_ops:
            main_op = None
        elif len(table_init_ops) == 1:
            main_op = table_init_ops[0]
        else:
            # Run every table initializer on load, not only the first one.
            main_op = tf.group(*table_init_ops)
        assets = [graph.get_tensor_by_name(n) for n in asset_tensor_names]
        with tf.Session() as sess:
            builder = tf.saved_model.builder.SavedModelBuilder(
                output_saved_model_dir)
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map=meta_graph_def.signature_def,
                assets_collection=assets,
                main_op=main_op)
            builder.save()