Example #1
def dev_offline(file):
    """
    Do offline evaluation on a dev file, predicting one instance at a time
    (you can switch to batched prediction if you want).

    :param file: path to the dev file; parse_file(file) yields (text, label) pairs.
    :return: None; prints tag-level and entity-level evaluation results.
    """
    def convert(line, label):
        # tokenizer and label2id are resolved from the enclosing scope at call time
        feature = convert_single_example_dev(2, line, label, label2id,
                                             FLAGS.max_seq_length, tokenizer)
        # reshape flat feature lists into a single-instance batch
        input_ids = np.reshape([feature.input_ids], (1, FLAGS.max_seq_length))
        input_mask = np.reshape([feature.input_mask],
                                (1, FLAGS.max_seq_length))
        segment_ids = np.reshape([feature.segment_ids],
                                 (1, FLAGS.max_seq_length))
        label_ids = np.reshape([feature.label_ids], (1, FLAGS.max_seq_length))
        return input_ids, input_mask, segment_ids, label_ids

    global graph
    with graph.as_default():
        # sess.run(tf.global_variables_initializer())
        input_ids_p = tf.placeholder(tf.int32, [1, FLAGS.max_seq_length],
                                     name="input_ids")
        input_mask_p = tf.placeholder(tf.int32, [1, FLAGS.max_seq_length],
                                      name="input_mask")
        label_ids_p = tf.placeholder(tf.int32, [1, FLAGS.max_seq_length],
                                     name="label_ids")
        segment_ids_p = tf.placeholder(tf.int32, [1, FLAGS.max_seq_length],
                                       name="segment_ids")

        bert_config = modeling_bert.BertConfig.from_json_file(
            args.bert_config_file)
        # rebuild the prediction graph; pred_ids holds the decoded tag ids
        # (trans is presumably the CRF transition matrix)
        (total_loss, logits, trans,
         pred_ids) = create_model(bert_config, args.is_training, input_ids_p,
                                  input_mask_p, segment_ids_p, label_ids_p,
                                  num_labels, args.use_one_hot_embeddings)

        saver = tf.train.Saver()
        # restore weights from the latest fine-tuned checkpoint in output_dir
        saver.restore(sess, tf.train.latest_checkpoint(args.output_dir))

        tokenizer = tokenization.FullTokenizer(
            vocab_file=args.vocab_file, do_lower_case=FLAGS.do_lower_case)
        # get the id2char dictionary (token id -> character)
        id2char = tokenizer.inv_vocab

        dev_texts, dev_labels = zip(*parse_file(file))
        start = datetime.now()

        pred_labels_all = []
        true_labels_all = []
        x_all = []
        for index, text in enumerate(dev_texts):
            sentence = str(text)
            input_ids, input_mask, segment_ids, label_ids = convert(
                sentence, dev_labels[index])

            feed_dict = {
                input_ids_p: input_ids,
                input_mask_p: input_mask,
                segment_ids_p: segment_ids,
                label_ids_p: label_ids
            }
            # run the session to get predictions for the current feed_dict
            y_pred = sess.run([pred_ids], feed_dict)
            # print(list(y_pred[0][0]))
            # print(len(list(y_pred[0][0])))

            sent_tag = []
            y_pred_clean = []
            input_ids_clean = []
            y_true_clean = []
            # strip [CLS] and [SEP] to recover the correct tag span
            for index_b, token_id in enumerate(list(np.reshape(input_ids, -1))):
                char = id2char[token_id]
                tag = id2label[list(y_pred[0][0])[index_b]]
                if char == "[CLS]":
                    continue
                if char == "[SEP]":
                    break
                input_ids_clean.append(token_id)
                sent_tag.append(tag)
                y_pred_clean.append(list(y_pred[0][0])[index_b])
                y_true_clean.append(label_ids[0][index_b])

            pred_labels_all.append(y_pred_clean)
            true_labels_all.append(y_true_clean)
            x_all.append(input_ids_clean)

        print('Evaluation of predicted vs. true labels......')
        print(pred_labels_all)
        print(len(pred_labels_all))
        print(true_labels_all)
        print(len(true_labels_all))

        metrics = Metrics(true_labels_all,
                          pred_labels_all,
                          id2label,
                          remove_O=True)
        metrics.report_scores()
        # metrics.report_confusion_matrix()

        print('Evaluation of predicted vs. true entities......')
        precision, recall, f1 = entity_metrics(x_all, pred_labels_all,
                                               true_labels_all, id2char,
                                               id2label)
        print("Dev P/R/F1: {} / {} / {}".format(round(precision, 2),
                                                round(recall, 2), round(f1,
                                                                        2)))
        print('Time used: {} sec'.format((datetime.now() - start).seconds))
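
Both examples rely on module-level state (graph, sess, args, FLAGS, label2id, id2label, num_labels) defined elsewhere in the script. A minimal sketch of that assumed setup for TF 1.x might look like the following; the label list here is a hypothetical placeholder, not the original author's:

import tensorflow as tf

# shared graph/session that dev_offline() and predict_outline() reuse
graph = tf.Graph()
sess = tf.Session(graph=graph)

# hypothetical label inventory; the real one comes from the training data
label_list = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "[CLS]", "[SEP]"]
label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for label, i in label2id.items()}
num_labels = len(label_list)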
Example #2
def predict_outline():
    """
    Do online prediction from console input, one instance at a time
    (you can switch to batched prediction if you want).
    """

    # TODO: predict results from a file instead of stdin
    def convert(line):
        feature = convert_single_example(line, label2id, FLAGS.max_seq_length,
                                         tokenizer)
        # note: reshaping a single feature to (batch_size, max_seq_length)
        # only works when args.batch_size == 1
        input_ids = np.reshape([feature.input_ids],
                               (args.batch_size, FLAGS.max_seq_length))
        input_mask = np.reshape([feature.input_mask],
                                (args.batch_size, FLAGS.max_seq_length))
        segment_ids = np.reshape([feature.segment_ids],
                                 (args.batch_size, FLAGS.max_seq_length))
        label_ids = np.reshape([feature.label_ids],
                               (args.batch_size, FLAGS.max_seq_length))
        return input_ids, input_mask, segment_ids, label_ids

    global graph
    with graph.as_default():
        print("going to restore checkpoint")
        # sess.run(tf.global_variables_initializer())
        input_ids_p = tf.placeholder(tf.int32,
                                     [args.batch_size, FLAGS.max_seq_length],
                                     name="input_ids")
        input_mask_p = tf.placeholder(tf.int32,
                                      [args.batch_size, FLAGS.max_seq_length],
                                      name="input_mask")
        label_ids_p = tf.placeholder(tf.int32,
                                     [args.batch_size, FLAGS.max_seq_length],
                                     name="label_ids")
        segment_ids_p = tf.placeholder(tf.int32,
                                       [args.batch_size, FLAGS.max_seq_length],
                                       name="segment_ids")

        bert_config = modeling_bert.BertConfig.from_json_file(
            args.bert_config_file)
        (total_loss, logits, trans,
         pred_ids) = create_model(bert_config, args.is_training, input_ids_p,
                                  input_mask_p, segment_ids_p, label_ids_p,
                                  num_labels, args.use_one_hot_embeddings)

        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(args.output_dir))

        tokenizer = tokenization.FullTokenizer(
            vocab_file=args.vocab_file, do_lower_case=FLAGS.do_lower_case)
        # get the id2char dictionary (token id -> character)
        id2char = tokenizer.inv_vocab

        # TODO: predict results from a file instead of stdin
        while True:
            print('input the test sentence:')
            sentence = str(input())
            start = datetime.now()
            # skip inputs that are too short to tag
            if len(sentence) < 2:
                print(sentence)
                continue
            # print('your input is:{}'.format(sentence))
            input_ids, input_mask, segment_ids, label_ids = convert(sentence)

            feed_dict = {
                input_ids_p: input_ids,
                input_mask_p: input_mask,
                segment_ids_p: segment_ids,
                label_ids_p: label_ids
            }
            # run the session to get predictions for the current feed_dict
            y_pred = sess.run([pred_ids], feed_dict)

            sent_tag = []
            y_pred_clean = []
            input_ids_clean = []
            # strip [CLS] and [SEP] to recover the correct tag span
            for index, token_id in enumerate(list(np.reshape(input_ids, -1))):
                char = id2char[token_id]
                tag = id2label[list(y_pred[0][0])[index]]
                if char == "[CLS]":
                    continue
                if char == "[SEP]":
                    break
                input_ids_clean.append(token_id)
                sent_tag.append(tag)
                y_pred_clean.append(list(y_pred[0][0])[index])

            sent_tag = ' '.join(sent_tag)
            print(sentence + '\n' + sent_tag)
            entity = get_entity([sentence], [y_pred_clean], id2label)
            print('predict_result:')
            print(entity)
            print('Time used: {} sec'.format((datetime.now() - start).seconds))
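
Both docstrings note that prediction could be batched. A hedged sketch of that change, reusing convert_single_example, tokenizer, label2id, and FLAGS from the snippets above; the batching itself is an assumption, not the original author's code:

import numpy as np
import tensorflow as tf

def convert_batch(lines):
    # stack per-sentence features into (batch, max_seq_length) arrays
    feats = [convert_single_example(line, label2id,
                                    FLAGS.max_seq_length, tokenizer)
             for line in lines]
    input_ids = np.array([f.input_ids for f in feats], dtype=np.int32)
    input_mask = np.array([f.input_mask for f in feats], dtype=np.int32)
    segment_ids = np.array([f.segment_ids for f in feats], dtype=np.int32)
    label_ids = np.array([f.label_ids for f in feats], dtype=np.int32)
    return input_ids, input_mask, segment_ids, label_ids

# the placeholders would then use a dynamic batch dimension, e.g.:
# input_ids_p = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length],
#                              name="input_ids")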