Ejemplo n.º 1
0
def _strip_special_tokens(token_ids, pred_label_ids, true_label_ids, id2char):
    """Return the sentence-only positions of one tokenized example.

    Walks the three aligned sequences position by position. It skips every
    ``[CLS]`` token and stops at the first ``[SEP]`` token, so the special
    tokens and the trailing padding are excluded from evaluation.

    :param token_ids: token ids of one example (single row, no batch axis).
    :param pred_label_ids: predicted label ids aligned with ``token_ids``.
    :param true_label_ids: gold label ids aligned with ``token_ids``.
    :param id2char: mapping from token id to token string.
    :return: tuple ``(token_ids, pred_label_ids, true_label_ids)`` of lists
        restricted to the kept positions.
    """
    kept_tokens, kept_preds, kept_trues = [], [], []
    for token_id, pred_id, true_id in zip(token_ids, pred_label_ids,
                                          true_label_ids):
        token = id2char[token_id]
        if token == "[CLS]":
            continue
        if token == "[SEP]":
            break
        kept_tokens.append(token_id)
        kept_preds.append(pred_id)
        kept_trues.append(true_id)
    return kept_tokens, kept_preds, kept_trues


def dev_offline(file):
    """Evaluate the restored sequence-labelling model on a dev file.

    Builds the model inside the module-level ``graph`` and restores the latest
    checkpoint from ``args.output_dir`` into the module-level ``sess``. It then
    predicts every sentence of *file* one at a time (batch size 1) and prints
    token-level (``Metrics``) and entity-level (``entity_metrics``) results.

    :param file: path handed to ``parse_file``. Each parsed item must unpack
        into a ``(text, labels)`` pair.
    :return: None; all results are printed.
    """
    # Parse first so an empty dev file is reported before the model is built.
    samples = list(parse_file(file))
    if not samples:
        print('No samples found in {}; nothing to evaluate.'.format(file))
        return
    dev_texts, dev_labels = zip(*samples)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=args.vocab_file, do_lower_case=FLAGS.do_lower_case)
    # Reverse vocabulary: token id -> token string.
    id2char = tokenizer.inv_vocab

    def convert(line, label):
        """Turn one (text, label) pair into four (1, max_seq_length) arrays."""
        feature = convert_single_example_dev(2, line, label, label2id,
                                             FLAGS.max_seq_length, tokenizer)
        shape = (1, FLAGS.max_seq_length)
        return (np.reshape([feature.input_ids], shape),
                np.reshape([feature.input_mask], shape),
                np.reshape([feature.segment_ids], shape),
                np.reshape([feature.label_ids], shape))

    with graph.as_default():
        input_ids_p = tf.placeholder(tf.int32, [1, FLAGS.max_seq_length],
                                     name="input_ids")
        input_mask_p = tf.placeholder(tf.int32, [1, FLAGS.max_seq_length],
                                      name="input_mask")
        label_ids_p = tf.placeholder(tf.int32, [1, FLAGS.max_seq_length],
                                     name="label_ids")
        segment_ids_p = tf.placeholder(tf.int32, [1, FLAGS.max_seq_length],
                                       name="segment_ids")

        bert_config = modeling_bert.BertConfig.from_json_file(
            args.bert_config_file)
        # NOTE(review): args.is_training is passed through unchanged; for
        # evaluation it is presumably expected to be False -- confirm.
        _, _, _, pred_ids = create_model(bert_config, args.is_training,
                                         input_ids_p, input_mask_p,
                                         segment_ids_p, label_ids_p,
                                         num_labels,
                                         args.use_one_hot_embeddings)

        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(args.output_dir))

        start = datetime.now()

        pred_labels_all = []
        true_labels_all = []
        x_all = []
        for text, labels in zip(dev_texts, dev_labels):
            input_ids, input_mask, segment_ids, label_ids = convert(
                str(text), labels)

            feed_dict = {
                input_ids_p: input_ids,
                input_mask_p: input_mask,
                segment_ids_p: segment_ids,
                label_ids_p: label_ids
            }
            # A single (non-list) fetch returns the array itself; row 0 is
            # the only example in the batch.
            pred_row = sess.run(pred_ids, feed_dict)[0]

            # Drop [CLS] / [SEP] / padding to get the real tag span.
            tokens, preds, trues = _strip_special_tokens(
                np.reshape(input_ids, -1), pred_row, label_ids[0], id2char)

            pred_labels_all.append(preds)
            true_labels_all.append(trues)
            x_all.append(tokens)

        print('预测标签与真实标签评价结果......')
        print(pred_labels_all)
        print(len(pred_labels_all))
        print(true_labels_all)
        print(len(true_labels_all))

        metrics = Metrics(true_labels_all,
                          pred_labels_all,
                          id2label,
                          remove_O=True)
        metrics.report_scores()

        print('预测实体与真实实体评价结果......')
        # NOTE(review): Metrics above takes (true, pred), but entity_metrics
        # is called with (pred, true). Other call sites pass (true, pred).
        # Confirm entity_metrics' signature; a swap exchanges precision and
        # recall.
        precision, recall, f1 = entity_metrics(x_all, pred_labels_all,
                                               true_labels_all, id2char,
                                               id2label)
        print("Dev P/R/F1: {} / {} / {}".format(round(precision, 2),
                                                round(recall, 2),
                                                round(f1, 2)))
        # total_seconds() covers the whole interval; .seconds drops days and
        # fractional seconds.
        print('Time used: {:.2f} sec'.format(
            (datetime.now() - start).total_seconds()))
Ejemplo n.º 2
0
                        test_y_pred = dev_step(x_test,
                                               test_seq_len_list,
                                               y_test,
                                               writer=dev_summary_writer)
                        test_labels_all.extend(y_test)
                        test_labels_pred.extend(test_y_pred)
                        x_all.extend(x_test)

                    logger.info('预测标签与真实标签评价结果......')
                    metrics = Metrics(test_labels_all,
                                      test_labels_pred,
                                      id2tag,
                                      remove_O=True)
                    metrics.report_scores()
                    metrics.report_confusion_matrix()

                    logger.info('预测实体与真实实体评价结果......')
                    precision, recall, f1 = entity_metrics(
                        x_all, test_labels_all, test_labels_pred, id2char,
                        id2tag)
                    logger.info("Test P/R/F1: {} / {} / {}".format(
                        round(precision, 2), round(recall, 2), round(f1, 2)))
                    if f1 > best_f1:
                        best_f1 = f1
                    logger.info("")
                if current_step % FLAGS.checkpoint_every == 0:
                    path = saver.save(sess,
                                      checkpoint_prefix,
                                      global_step=current_step)
                    logger.info("Saved model checkpoint to {}\n".format(path))
        logger.info('best_f1: {}'.format(best_f1))