コード例 #1
0
def model_init(model_name, do_lower_case=None):
    """Build the NER inference graph for ``model_name`` and restore its weights.

    The model and BERT directories are chosen from a hard-coded Windows or
    Linux layout depending on ``os.name``. The graph uses fixed-shape inputs
    (batch_size=1, max_seq_length=500) and lives in its own tf.Graph/Session.

    :param model_name: name of the trained model directory; the checkpoint,
        ``label2id.pkl`` and ``label_list.pkl`` are read from
        ``<base>/<model_name>/model``.
    :param do_lower_case: lower-casing flag for the tokenizer. ``None`` (the
        default) keeps the previous behaviour of reading the module-level
        ``args.do_lower_case``.
    :return: ``(model_dir, batch_size, id2label, label_list, graph,
        input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length)``
    :raises FileNotFoundError: if ``<model_dir>/checkpoint`` does not exist.
    """
    if os.name == 'nt':  # Windows path config
        base_dir = 'C:/ChineseNLP/NERmaster'
    else:  # Linux path config
        base_dir = '/home/yjy/project/deeplearning/nlp_demo'
    model_dir = '%s/%s/model' % (base_dir, model_name)
    bert_dir = base_dir + '/bert_model_info/chinese_L-12_H-768_A-12'

    batch_size = 1
    max_seq_length = 500

    checkpoint_file = os.path.join(model_dir, "checkpoint")
    print('checkpoint path:{}'.format(checkpoint_file))
    if not os.path.exists(checkpoint_file):
        # FileNotFoundError subclasses Exception, so existing handlers still
        # catch it; the old message ("going to return") was misleading.
        raise FileNotFoundError(
            "failed to get checkpoint: %s does not exist" % checkpoint_file)

    # Load the label -> id mapping and build its inverse.
    # NOTE: pickle can execute arbitrary code; only load trusted model files.
    with open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
        label2id = pickle.load(rf)
    id2label = {value: key for key, value in label2id.items()}

    with open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
        label_list = pickle.load(rf)
    # One extra class on top of the saved labels (presumably id 0 / padding;
    # confirm against the training code).
    num_labels = len(label_list) + 1

    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True  # allocate GPU memory on demand
    graph = tf.Graph()
    sess = tf.Session(graph=graph, config=gpu_config)
    try:
        with graph.as_default():
            print("going to restore checkpoint")
            input_ids_p = tf.placeholder(tf.int32, [batch_size, max_seq_length],
                                         name="input_ids")
            input_mask_p = tf.placeholder(tf.int32, [batch_size, max_seq_length],
                                          name="input_mask")

            bert_config = modeling.BertConfig.from_json_file(
                os.path.join(bert_dir, 'bert_config.json'))
            (_total_loss, _logits, _trans,
             pred_ids) = create_model(bert_config=bert_config,
                                      is_training=False,
                                      input_ids=input_ids_p,
                                      input_mask=input_mask_p,
                                      segment_ids=None,
                                      labels=None,
                                      num_labels=num_labels,
                                      use_one_hot_embeddings=False,
                                      dropout_rate=1.0)

            saver = tf.train.Saver()
            saver.restore(sess, tf.train.latest_checkpoint(model_dir))

        if do_lower_case is None:
            do_lower_case = args.do_lower_case
        tokenizer = tokenization.FullTokenizer(
            vocab_file=os.path.join(bert_dir, 'vocab.txt'),
            do_lower_case=do_lower_case)
    except Exception:
        # Don't leak the session (and any device memory it holds) on failure.
        sess.close()
        raise

    return (model_dir, batch_size, id2label, label_list, graph, input_ids_p,
            input_mask_p, pred_ids, tokenizer, sess, max_seq_length)
コード例 #2
0
def optimize_ner_model(args, num_labels,  logger=None):
    """Freeze the Chinese NER checkpoint into a single ``ner_model.pb`` graph.

    If the .pb already exists it is reused. Otherwise the inference graph is
    rebuilt, the latest checkpoint under ``args.model_dir`` is restored, the
    variables are converted to constants for output node ``pred_ids``, and
    the result is written to disk.

    :param args: namespace read for model_pb_dir, max_seq_len, bert_model_dir,
        model_dir, lstm_size and verbose.
    :param num_labels: number of labels passed to ``create_model``.
    :param logger: optional logger; one is created with ``set_logger`` if omitted.
    :return: path of the .pb file, or None if freezing failed (the error is
        logged, not raised).
    """
    if not logger:
        logger = set_logger(colored('NER_MODEL, Lodding...', 'cyan'), args.verbose)
    try:
        # Reuse an existing .pb if present; otherwise freeze the checkpoint
        # into one and return its path.
        if args.model_pb_dir is None:
            # Default to ./predict_optimizer under the current working directory.
            tmp_file = os.path.join(os.getcwd(), 'predict_optimizer')
        else:
            tmp_file = args.model_pb_dir
        # Create the directory in both cases: a user-supplied model_pb_dir that
        # did not exist yet used to make the final write fail.
        os.makedirs(tmp_file, exist_ok=True)
        pb_file = os.path.join(tmp_file, 'ner_model.pb')
        if os.path.exists(pb_file):
            print('pb_file exits', pb_file)
            return pb_file

        import tensorflow as tf
        from tensorflow.python.framework import graph_util
        from bert_base.train.models import create_model

        graph = tf.Graph()
        with graph.as_default():
            # Building ops needs no session; the former throwaway tf.Session
            # here only allocated device resources.
            input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_ids')
            input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_mask')

            bert_config = modeling.BertConfig.from_json_file(
                os.path.join(args.bert_model_dir, 'bert_config.json'))
            _, _, _, pred_ids = create_model(
                bert_config=bert_config, is_training=False, input_ids=input_ids,
                input_mask=input_mask, segment_ids=None, labels=None,
                num_labels=num_labels, use_one_hot_embeddings=False,
                dropout_rate=1.0, lstm_size=args.lstm_size)
            # Give the output a fixed node name so it can be referenced as
            # 'pred_ids' when freezing below.
            pred_ids = tf.identity(pred_ids, 'pred_ids')
            print("server.graph.py_line290: ", pred_ids.shape)
            saver = tf.train.Saver()

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                saver.restore(sess, tf.train.latest_checkpoint(args.model_dir))
                logger.info('freeze...')
                # Replace variables with constants holding their restored values.
                frozen_graph_def = graph_util.convert_variables_to_constants(
                    sess, graph.as_graph_def(), ['pred_ids'])
                logger.info('model cut finished !!!')

        # Write the serialized binary GraphDef to disk.
        logger.info('write graph to a tmp file: %s' % pb_file)
        with tf.gfile.GFile(pb_file, 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())
        return pb_file
    except Exception as e:
        # Best effort: log the failure and signal it with None.
        logger.error('fail to optimize the graph! %s' % e, exc_info=True)
        return None
コード例 #3
0
ファイル: bert_ner.py プロジェクト: monkeyfx/bert-django
def optimize_ner_model(model_pb_file='', max_seq_len=128, bert_model_dir='',
                       model_dir='', num_labels=None):
    """Freeze the NER checkpoint into the .pb file at ``model_pb_file``.

    If the .pb already exists its path is returned unchanged. Otherwise the
    inference graph is rebuilt, the latest checkpoint in ``model_dir`` is
    restored, the variables are converted to constants for output node
    ``pred_ids``, and the result is written to ``model_pb_file``.

    :param model_pb_file: path of the NER .pb model file.
    :param max_seq_len: maximum sequence length of the input placeholders.
    :param bert_model_dir: directory containing ``bert_config.json``.
    :param model_dir: directory containing the trained NER checkpoint.
    :param num_labels: forwarded to ``create_model`` as ``num_labels``
        (originally documented as the label list; presumably the label count
        -- confirm against callers). Defaults to None instead of a shared
        mutable ``[]``.
    :return: path of the .pb file, or None on failure (errors are logged).
    """
    lg.info('NER_MODEL, Loading...')
    if not model_pb_file:
        # An empty path used to build and restore the whole graph before
        # failing on the write; fail fast with the same None result.
        lg.error('fail to optimize the graph! model_pb_file is empty')
        return None
    try:
        # Reuse an existing .pb; otherwise freeze the checkpoint into one.
        pb_file = model_pb_file
        if os.path.exists(pb_file):
            print('pb_file exits', pb_file)
            return pb_file
        lg.info('%s dont exist,need create and save it! ' % pb_file)
        import tensorflow as tf
        from tensorflow.python.framework import graph_util
        from bert_base.train.models import create_model

        graph = tf.Graph()
        with graph.as_default():
            # Building ops needs no session; the former throwaway tf.Session
            # here only allocated device resources.
            input_ids = tf.placeholder(tf.int32, (None, max_seq_len), 'input_ids')
            input_mask = tf.placeholder(tf.int32, (None, max_seq_len), 'input_mask')

            bert_config = modeling.BertConfig.from_json_file(
                os.path.join(bert_model_dir, 'bert_config.json'))
            _, _, _, pred_ids = create_model(
                bert_config=bert_config, is_training=False, input_ids=input_ids,
                input_mask=input_mask, segment_ids=None, labels=None,
                num_labels=num_labels, use_one_hot_embeddings=False,
                dropout_rate=1.0)
            # Fixed node name so 'pred_ids' can be kept when freezing.
            pred_ids = tf.identity(pred_ids, 'pred_ids')
            saver = tf.train.Saver()

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                saver.restore(sess, tf.train.latest_checkpoint(model_dir))
                lg.info('freeze...')
                frozen_graph_def = graph_util.convert_variables_to_constants(
                    sess, graph.as_graph_def(), ['pred_ids'])
                lg.info('model cut finished !!!')

        # Make sure the target directory exists before writing the binary model.
        pb_dir = os.path.dirname(pb_file)
        if pb_dir:
            os.makedirs(pb_dir, exist_ok=True)
        lg.info('write graph to a tmp file: %s' % pb_file)
        with tf.gfile.GFile(pb_file, 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())
        return pb_file
    except Exception as e:
        # Best effort: log the failure and signal it with None.
        lg.error('fail to optimize the graph! %s' % e, exc_info=True)
        return None
コード例 #4
0
# Load the label list saved with the model; num_labels is its length + 1
# (the extra slot is presumably id 0 / padding -- confirm against training).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
    label_list = pickle.load(rf)
num_labels = len(label_list) + 1


# Build the inference graph in the current default graph and restore the
# latest checkpoint from model_dir into `sess` (presumably created earlier in
# this script, outside the visible lines).
# NOTE(review): `graph` already is the default graph, so as_default() is a
# no-op wrapper here.
graph = tf.get_default_graph()
with graph.as_default():
    print("going to restore checkpoint")
    #sess.run(tf.global_variables_initializer())
    # Fixed-shape inputs: one [batch_size, max_seq_length] batch per call.
    input_ids_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_ids")
    input_mask_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_mask")

    bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json'))
    (total_loss, logits, trans, pred_ids) = create_model(
        bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p, segment_ids=None,
        labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0)

    # Restore all graph variables from the newest checkpoint in model_dir.
    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(model_dir))

# Tokenizer built from bert_dir/vocab.txt, with lower-casing enabled.
tokenizer = tokenization.FullTokenizer(
        vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=True)

# Flask application serving the prediction endpoint(s) defined below.
app = flask.Flask(__name__)


@app.route('/ner_predict_service', methods=['GET'])
def ner_predict_service():
    """
    do online prediction. each time make prediction for one instance.
コード例 #5
0
    def model_fn(features, labels, mode, params):
        """Estimator model_fn for the BERT NER model.

        Builds the network via ``create_model`` from the tensors in
        ``features``. ``labels`` and ``params`` are ignored; the tag ids come
        from ``features["label_ids"]``. The function warm-starts from
        ``init_checkpoint`` when one is set and returns the EstimatorSpec for
        the current mode:

        * TRAIN   - optimizer step plus a hook that logs the loss and global
                    step every ``args.save_summary_steps`` iterations;
        * EVAL    - ``eval_loss`` = streaming MSE between label ids and
                    predicted ids;
        * PREDICT - the raw ``pred_ids`` tensor.
        """
        tf.logging.info("*** Features ***")
        for feature_name in sorted(features):
            tf.logging.info("  name = %s, shape = %s" %
                            (feature_name, features[feature_name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        print('shape of input_ids', input_ids.shape)

        training = mode == tf.estimator.ModeKeys.TRAIN

        # input_ids are the token ids of the samples, label_ids the tag ids.
        total_loss, logits, trans, pred_ids = create_model(
            bert_config, training, input_ids, input_mask, segment_ids,
            label_ids, num_labels, False, args.dropout_rate, args.lstm_size,
            args.cell, args.num_layers)

        # Warm-start the BERT weights from the pre-trained checkpoint.
        if init_checkpoint:
            assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
                tf.trainable_variables(), init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        if training:
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps,
                False)
            logging_hook = tf.train.LoggingTensorHook(
                {'loss': total_loss,
                 'global_steps': tf.train.get_or_create_global_step()},
                every_n_iter=args.save_summary_steps)
            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=total_loss,
                                              train_op=train_op,
                                              training_hooks=[logging_hook])

        if mode == tf.estimator.ModeKeys.EVAL:
            # NER-specific evaluation: MSE between gold and predicted tag ids.
            eval_loss = tf.metrics.mean_squared_error(labels=label_ids,
                                                      predictions=pred_ids)
            return tf.estimator.EstimatorSpec(
                mode=mode, loss=total_loss,
                eval_metric_ops={"eval_loss": eval_loss})

        return tf.estimator.EstimatorSpec(mode=mode, predictions=pred_ids)
コード例 #6
0
def main(max_seq_len, model_dir, num_labels, bert_config_file=None,
         export_model_dir=None, model_version=None):
    """Export the latest NER checkpoint in ``model_dir`` as a TF SavedModel.

    The exported signature ``'result'`` takes ``input_ids`` and
    ``input_mask`` (int32, shape ``(None, max_seq_len)``) and returns
    ``pred_label``. Any existing export at the target path is deleted first.

    :param max_seq_len: sequence length of the input placeholders.
    :param model_dir: directory holding the trained checkpoint.
    :param num_labels: number of labels passed to ``create_model``.
    :param bert_config_file: path to ``bert_config.json``; defaults to the
        previously hard-coded path.
    :param export_model_dir: export base directory; defaults to
        ``FLAGS.export_model_dir``.
    :param model_version: version sub-directory name; defaults to
        ``FLAGS.model_version``.
    """
    from bert_base.bert import modeling
    from bert_base.train.models import create_model

    if bert_config_file is None:
        # The only option before; kept as the default for compatibility.
        bert_config_file = ("D:\\Program Files\\JetBrains\\PyCharm 2017.2.4"
                            "\\bert_ner_3\\cased_L-12_H-768_A-12\\bert_config.json")

    # Build the inference graph in the default graph. Op construction needs no
    # session; the former throwaway tf.Session here only allocated resources.
    input_ids = tf.placeholder(tf.int32, (None, max_seq_len), 'input_ids')
    input_mask = tf.placeholder(tf.int32, (None, max_seq_len), 'input_mask')

    bert_config = modeling.BertConfig.from_json_file(bert_config_file)
    _, _, _, pred_ids = create_model(bert_config=bert_config,
                                     is_training=False,
                                     input_ids=input_ids,
                                     input_mask=input_mask,
                                     segment_ids=None,
                                     labels=None,
                                     num_labels=num_labels,
                                     use_one_hot_embeddings=False)
    # Fixed node name for the output tensor.
    pred_ids = tf.identity(pred_ids, 'pred_ids')
    print("server.graph.py_line290: ", pred_ids.shape)
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, tf.train.latest_checkpoint(model_dir))

        if export_model_dir is None:
            export_model_dir = FLAGS.export_model_dir
        if model_version is None:
            model_version = FLAGS.model_version
        # Export to <export_model_dir>/<model_version>, replacing any old export.
        export_path = os.path.join(
            tf.compat.as_bytes(export_model_dir),
            tf.compat.as_bytes(str(model_version)))
        print('Exporting trained model to', export_path)
        if os.path.exists(export_path):
            shutil.rmtree(export_path)

        builder = tf.saved_model.builder.SavedModelBuilder(export_path)
        # TensorInfo protos describing the signature's inputs and output.
        input_ids_info = tf.saved_model.utils.build_tensor_info(input_ids)
        input_mask_info = tf.saved_model.utils.build_tensor_info(input_mask)
        pred_ids_info = tf.saved_model.utils.build_tensor_info(pred_ids)

        # Predict-API signature: token ids + input mask in, predicted ids out.
        prediction_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs={
                    'input_ids': input_ids_info,
                    'input_mask': input_mask_info
                },
                outputs={'pred_label': pred_ids_info},
                method_name=tf.saved_model.signature_constants.
                PREDICT_METHOD_NAME))

        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'result': prediction_signature,
            })
        # Binary saved_model.pb; builder.save(as_text=True) would write .pbtxt.
        builder.save()
        print('Done exporting!')
コード例 #7
0
    print("going to restore checkpoint")
    #sess.run(tf.global_variables_initializer())
    input_ids_p = tf.placeholder(tf.int32, [batch_size, max_seq_length],
                                 name="input_ids")
    input_mask_p = tf.placeholder(tf.int32, [batch_size, max_seq_length],
                                  name="input_mask")

    bert_config = modeling.BertConfig.from_json_file(
        os.path.join(bert_dir, 'bert_config.json'))
    (total_loss, logits, trans,
     pred_ids) = create_model(bert_config=bert_config,
                              is_training=False,
                              input_ids=input_ids_p,
                              input_mask=input_mask_p,
                              segment_ids=None,
                              labels=None,
                              num_labels=num_labels,
                              use_one_hot_embeddings=False,
                              dropout_rate=1.0,
                              num_layers=1,
                              lstm_size=128)

    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(model_dir))

# Tokenizer built from bert_dir/vocab.txt, with lower-casing always enabled.
tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(
    bert_dir, 'vocab.txt'),
                                       do_lower_case=True)


@app.route('/ner_predicts_service', methods=['POST'])
コード例 #8
0
# Build the inference graph inside `graph` and restore the latest checkpoint
# from model_dir into `sess` (both presumably created earlier in this script,
# outside the visible lines).
with graph.as_default():
    print("going to restore checkpoint")
    # Fixed-shape inputs: [batch_size, args.max_seq_length].
    input_ids_p = tf.placeholder(tf.int32, [batch_size, args.max_seq_length],
                                 name="input_ids")
    input_mask_p = tf.placeholder(tf.int32, [batch_size, args.max_seq_length],
                                  name="input_mask")

    bert_config = modeling.BertConfig.from_json_file(
        os.path.join(bert_dir, 'bert_config.json'))
    # lstm_size / crf_only / is_add_self_attention presumably must match the
    # values used at training time, or the checkpoint restore will not line up.
    (total_loss, logits, trans,
     pred_ids) = create_model(bert_config=bert_config,
                              is_training=False,
                              input_ids=input_ids_p,
                              input_mask=input_mask_p,
                              segment_ids=None,
                              labels=None,
                              num_labels=num_labels,
                              use_one_hot_embeddings=False,
                              lstm_size=args.lstm_size,
                              dropout_rate=1.0,
                              crf_only=args.crf_only,
                              is_add_self_attention=args.is_add_self_attention)

    # Restore all graph variables from the newest checkpoint in model_dir.
    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(model_dir))

# Tokenizer built from bert_dir/vocab.txt; lower-casing follows args.
tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(
    bert_dir, 'vocab.txt'),
                                       do_lower_case=args.do_lower_case)


def predict_online(sentence):
コード例 #9
0
    def model_fn(features, labels, mode, params):
        """Estimator model_fn for BERT-based NER with optional layer freezing.

        Builds the network with ``create_model``, which here returns
        ``(total_loss, logits, pred_ids)``. It then optionally warm-starts
        from ``init_checkpoint``, optionally keeps only the layers selected by
        ``args.trainable_last_layers`` trainable, and returns an EstimatorSpec
        for TRAIN, EVAL (masked MSE / accuracy / F1) or PREDICT.

        ``labels`` and ``params`` are unused; the tag ids come from
        ``features["label_ids"]``. The remaining inputs (bert_config,
        num_labels, label_list, args, init_checkpoint, learning_rate,
        num_train_steps, num_warmup_steps, logger) come from the enclosing
        scope.
        """
        logger.info("*** Features ***")
        for name in sorted(features.keys()):
            logger.info("  name = %s, shape = %s" %
                        (name, features[name].shape))
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]

        print('shape of input_ids', input_ids.shape)
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        # Build the model: input_ids are the token ids of the samples,
        # label_ids the tag ids.
        total_loss, logits, pred_ids = create_model(
            bert_config, is_training, input_ids, input_mask, segment_ids,
            label_ids, num_labels, False, args.dropout_rate, args.lstm_size,
            args.cell, args.num_layers, args.crf_only, args.lstm_only)

        tvars = tf.trainable_variables()
        # Must exist even without init_checkpoint: the verbose listing below
        # reads it (previously a NameError in that case).
        initialized_variable_names = {}
        # Warm-start BERT weights from the pre-trained checkpoint.
        if init_checkpoint:
            (assignment_map, initialized_variable_names) = \
                modeling.get_assignment_map_from_checkpoint(tvars,
                                                            init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        # Keep only the selected last layers trainable; freeze the others.
        if args.trainable_last_layers:
            last_layers = [
                'finetune', 'bert/encoder/layer_11', 'bert/encoder/layer_10',
                'bert/encoder/layer_9', 'bert/encoder/layer_8',
                'bert/encoder/layer_7'
            ]
            # Comma-separated indices into last_layers, e.g. "0,1".
            trainable_layers_index = {
                int(ind) for ind in args.trainable_last_layers.split(',')
            }
            trainable_layers = [
                name for index, name in enumerate(last_layers)
                if index in trainable_layers_index
            ]
            tf.logging.info("Only allow last n layers are trainable: " +
                            str(trainable_layers))
            filter_trainable_variables(trainable_layers)
        # List the (possibly filtered) trainable variables.
        if args.verbose:
            tvars = tf.trainable_variables()
            logger.info("**** Trainable Variables ****")
            for var in tvars:
                init_string = ""
                if var.name in initialized_variable_names:
                    init_string = ", *INIT_FROM_CKPT*"
                logger.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, False)
            hook_dict = {}
            hook_dict['loss'] = total_loss
            hook_dict['global_steps'] = tf.train.get_or_create_global_step()
            logging_hook = tf.train.LoggingTensorHook(
                hook_dict, every_n_iter=args.save_summary_steps)

            # Without the hook these values are not shown during training.
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=[logging_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # NER-specific evaluation on non-padding positions only.
            def metric_fn(label_ids, pred_ids, mask):
                # Boolean flags for the real (non-padding) token positions.
                mask = tf.cast(tf.reshape(mask, [-1]), dtype=tf.bool)
                # tf.Tensor does not support NumPy-style boolean indexing
                # (t[bool_tensor] raises TypeError); tf.boolean_mask drops
                # the padding positions instead.
                pred_ids = tf.boolean_mask(
                    tf.to_int32(tf.reshape(pred_ids, [-1])), mask)
                label_ids = tf.boolean_mask(
                    tf.to_int32(tf.reshape(label_ids, [-1])), mask)
                loss = tf.metrics.mean_squared_error(labels=label_ids,
                                                     predictions=pred_ids)
                accuracy = tf.metrics.accuracy(label_ids, pred_ids)
                # Ids 1..len(label_list) count as positive classes; id 0 is
                # excluded (presumably the padding id).
                pos_indices = list(range(1, len(label_list) + 1))
                f1 = tf_metrics.f1(label_ids,
                                   pred_ids,
                                   num_labels,
                                   pos_indices=pos_indices)
                return {
                    "eval_loss": loss,
                    "eval_accuracy": accuracy,
                    "eval_f1": f1,
                }

            eval_metrics = metric_fn(label_ids, pred_ids, input_mask)
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=total_loss, eval_metric_ops=eval_metrics)
        else:
            output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                     predictions=pred_ids)
        return output_spec
コード例 #10
0
    def model_fn(features, labels, mode, params):
        """Estimator model_fn for the BERT NER model.

        Reads input_ids / input_mask / segment_ids / label_ids from
        ``features`` (``labels`` and ``params`` are unused), builds the
        network with ``create_model``, optionally warm-starts from
        ``init_checkpoint``, and returns an EstimatorSpec for TRAIN (loss +
        optimizer), EVAL (tf_metrics precision / recall / F plus an MSE
        "eval_loss" over label ids) or PREDICT (``pred_ids``).

        Relies on names from the enclosing scope: bert_config, num_labels,
        args, init_checkpoint, learning_rate, num_train_steps,
        num_warmup_steps.
        """
        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]

        print('shape of input_ids', input_ids.shape)
        # label_mask = features["label_mask"]
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        # Build the model: input_ids are the token ids of the samples,
        # label_ids the tag ids.
        (total_loss, logits, trans,
         pred_ids) = create_model(bert_config, is_training, input_ids,
                                  input_mask, segment_ids, label_ids,
                                  num_labels, False, args.dropout_rate,
                                  args.lstm_size, args.cell, args.num_layers)

        tvars = tf.trainable_variables()
        # Warm-start BERT weights from the pre-trained checkpoint.
        if init_checkpoint:
            (assignment_map, initialized_variable_names) = \
                 modeling.get_assignment_map_from_checkpoint(tvars,
                                                             init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        # Print variable names (disabled).
        # logger.info("**** Trainable Variables ****")
        #
        # # Print the parameters of the loaded model
        # for var in tvars:
        #     init_string = ""
        #     if var.name in initialized_variable_names:
        #         init_string = ", *INIT_FROM_CKPT*"
        #     logger.info("  name = %s, shape = %s%s", var.name, var.shape,
        #                     init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            #train_op = optimizer.optimizer(total_loss, learning_rate, num_train_steps)
            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, False)
            output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                     loss=total_loss,
                                                     train_op=train_op)
        elif mode == tf.estimator.ModeKeys.EVAL:
            # Modified for NER.
            def metric_fn(label_ids, logits, trans):
                # Original intent: Viterbi (CRF) decoding of the results first.
                # NOTE(review): `logits` and `trans` are accepted but unused;
                # the metrics below use the enclosing `pred_ids` from
                # create_model, so no decoding happens here.

                # NOTE(review): tf.sequence_mask on a scalar yields an all-True
                # vector of length max_seq_length, so padding positions are NOT
                # excluded from these metrics. Passing input_mask was probably
                # intended -- confirm tf_metrics' expected weight shape first.
                weight = tf.sequence_mask(args.max_seq_length)
                precision = tf_metrics.precision(label_ids, pred_ids,
                                                 num_labels, None, weight)
                recall = tf_metrics.recall(label_ids, pred_ids, num_labels,
                                           None, weight)
                f = tf_metrics.f1(label_ids, pred_ids, num_labels, None,
                                  weight)

                return {
                    "eval_precision":
                    precision,
                    "eval_recall":
                    recall,
                    "eval_f":
                    f,
                    "eval_loss":
                    tf.metrics.mean_squared_error(labels=label_ids,
                                                  predictions=pred_ids),
                }

            eval_metrics = metric_fn(label_ids, logits, trans)
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=total_loss, eval_metric_ops=eval_metrics)
        else:
            output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                     predictions=pred_ids)
        return output_spec
コード例 #11
0
    input_mask_p = tf.placeholder(tf.int32, [batch_size, args.max_seq_length],
                                  name="input_mask")
    pos_ids_p = tf.placeholder(tf.int32, [batch_size, args.max_seq_length],
                               name="pos_ids")

    bert_config = modeling.BertConfig.from_json_file(
        os.path.join(bert_dir, 'bert_config.json'))

    # jzhang: 注意,如果加载的模型用到了lstm,则一定要设置lstm_size与加载模型的lstm_size相等,不然会报错
    (total_loss, logits, trans, pred_ids, best_score,
     lstm_output) = create_model(bert_config=bert_config,
                                 is_training=False,
                                 input_ids=input_ids_p,
                                 input_mask=input_mask_p,
                                 segment_ids=None,
                                 labels=None,
                                 num_labels=num_labels,
                                 use_one_hot_embeddings=False,
                                 pos_ids=pos_ids_p,
                                 dropout_rate=1.0,
                                 lstm_size=args.lstm_size)

    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(model_dir))

# Tokenizer built from bert_dir/vocab.txt; lower-casing follows
# args.do_lower_case (presumably must match the setting used in training).
tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(
    bert_dir, 'vocab.txt'),
                                       do_lower_case=args.do_lower_case)


def predict_batch(input_txts):