def dev_offline(file):
    """Evaluate the restored model on a labelled dev file, one sentence at a time.

    The file is parsed with ``parse_file`` into (text, labels) pairs. Each
    sentence is converted to a single-example batch and run through the model;
    the token ids, predicted label ids and gold label ids found between
    ``[CLS]`` and ``[SEP]`` are collected. Tag-level scores (``Metrics``) and
    entity-level precision / recall / F1 (``entity_metrics``) are then
    printed, followed by the elapsed time.

    Relies on the module-level ``graph``, ``sess``, ``args``, ``FLAGS``,
    ``label2id``, ``id2label`` and ``num_labels``.

    :param file: dev-set source handed unchanged to ``parse_file``; the
        result must unpack into (text, labels) pairs.
    :return: None. All results are printed.
    """
    seq_len = FLAGS.max_seq_length

    def convert(line, label):
        """Turn one sentence and its labels into four (1, seq_len) arrays."""
        # NOTE(review): the leading 2 looks like an example index -- confirm
        # against convert_single_example_dev.
        feature = convert_single_example_dev(
            2, line, label, label2id, seq_len, tokenizer)
        input_ids = np.reshape([feature.input_ids], (1, seq_len))
        input_mask = np.reshape([feature.input_mask], (1, seq_len))
        segment_ids = np.reshape([feature.segment_ids], (1, seq_len))
        label_ids = np.reshape([feature.label_ids], (1, seq_len))
        return input_ids, input_mask, segment_ids, label_ids

    with graph.as_default():
        # One example is fed per run, hence the fixed batch dimension of 1.
        input_ids_p = tf.placeholder(
            tf.int32, [1, seq_len], name="input_ids")
        input_mask_p = tf.placeholder(
            tf.int32, [1, seq_len], name="input_mask")
        label_ids_p = tf.placeholder(
            tf.int32, [1, seq_len], name="label_ids")
        segment_ids_p = tf.placeholder(
            tf.int32, [1, seq_len], name="segment_ids")

        bert_config = modeling_bert.BertConfig.from_json_file(
            args.bert_config_file)
        # Only the decoded label ids are needed for evaluation.
        _, _, _, pred_ids = create_model(
            bert_config, args.is_training, input_ids_p, input_mask_p,
            segment_ids_p, label_ids_p, num_labels,
            args.use_one_hot_embeddings)

        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(args.output_dir))

        tokenizer = tokenization.FullTokenizer(
            vocab_file=args.vocab_file, do_lower_case=FLAGS.do_lower_case)
        # Vocabulary id -> token lookup.
        id2char = tokenizer.inv_vocab

        dev_texts, dev_labels = zip(*parse_file(file))
        start = datetime.now()
        pred_labels_all = []
        true_labels_all = []
        x_all = []
        for text, gold_labels in zip(dev_texts, dev_labels):
            input_ids, input_mask, segment_ids, label_ids = convert(
                str(text), gold_labels)
            feed_dict = {
                input_ids_p: input_ids,
                input_mask_p: input_mask,
                segment_ids_p: segment_ids,
                label_ids_p: label_ids,
            }
            # Run the session for the current example.
            y_pred = sess.run([pred_ids], feed_dict)
            # Row 0 of the only fetched tensor: one label id per position.
            pred_row = y_pred[0][0]
            gold_row = label_ids[0]

            y_pred_clean = []
            y_true_clean = []
            input_ids_clean = []
            # Drop [CLS] and everything from [SEP] onwards so that only the
            # real token positions are scored. Predictions at the special
            # positions are never looked up in id2label.
            for pos, token_id in enumerate(np.reshape(input_ids, -1)):
                token = id2char[token_id]
                if token == "[CLS]":
                    continue
                if token == "[SEP]":
                    break
                input_ids_clean.append(token_id)
                y_pred_clean.append(pred_row[pos])
                y_true_clean.append(gold_row[pos])

            pred_labels_all.append(y_pred_clean)
            true_labels_all.append(y_true_clean)
            x_all.append(input_ids_clean)

        print('预测标签与真实标签评价结果......')
        print(pred_labels_all)
        print(len(pred_labels_all))
        print(true_labels_all)
        print(len(true_labels_all))
        metrics = Metrics(true_labels_all, pred_labels_all, id2label,
                          remove_O=True)
        metrics.report_scores()

        print('预测实体与真实实体评价结果......')
        precision, recall, f1 = entity_metrics(
            x_all, pred_labels_all, true_labels_all, id2char, id2label)
        print("Dev P/R/F1: {} / {} / {}".format(
            round(precision, 2), round(recall, 2), round(f1, 2)))
        # total_seconds() covers the whole interval; timedelta.seconds alone
        # drops the days component.
        elapsed = int((datetime.now() - start).total_seconds())
        print('Time used: {} sec'.format(elapsed))
def predict_outline():
    """Interactively tag sentences read from stdin, one sentence per model run.

    Builds the model in the module-level ``graph``, restores the latest
    checkpoint from ``args.output_dir`` and then loops: read a sentence,
    predict a label id for every position, keep the positions between
    ``[CLS]`` and ``[SEP]``, and print the tag sequence, the entities returned
    by ``get_entity`` and the elapsed time. Inputs shorter than two characters
    are echoed back and skipped. The loop ends when stdin is closed.

    Relies on the module-level ``graph``, ``sess``, ``args``, ``FLAGS``,
    ``label2id``, ``id2label`` and ``num_labels``.

    :return: None. All results are printed.
    """
    # TODO: support predicting from a file instead of stdin.
    seq_len = FLAGS.max_seq_length
    # Exactly one sentence is fed per run, so the batch dimension is 1
    # regardless of args.batch_size: a single example cannot be reshaped
    # into a larger batch.
    batch = 1

    def convert(line):
        """Turn one sentence into four (1, seq_len) model input arrays."""
        feature = convert_single_example(line, label2id, seq_len, tokenizer)
        input_ids = np.reshape([feature.input_ids], (batch, seq_len))
        input_mask = np.reshape([feature.input_mask], (batch, seq_len))
        segment_ids = np.reshape([feature.segment_ids], (batch, seq_len))
        label_ids = np.reshape([feature.label_ids], (batch, seq_len))
        return input_ids, input_mask, segment_ids, label_ids

    with graph.as_default():
        print("going to restore checkpoint")
        input_ids_p = tf.placeholder(
            tf.int32, [batch, seq_len], name="input_ids")
        input_mask_p = tf.placeholder(
            tf.int32, [batch, seq_len], name="input_mask")
        label_ids_p = tf.placeholder(
            tf.int32, [batch, seq_len], name="label_ids")
        segment_ids_p = tf.placeholder(
            tf.int32, [batch, seq_len], name="segment_ids")

        bert_config = modeling_bert.BertConfig.from_json_file(
            args.bert_config_file)
        # Only the decoded label ids are needed for prediction.
        _, _, _, pred_ids = create_model(
            bert_config, args.is_training, input_ids_p, input_mask_p,
            segment_ids_p, label_ids_p, num_labels,
            args.use_one_hot_embeddings)

        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(args.output_dir))

        tokenizer = tokenization.FullTokenizer(
            vocab_file=args.vocab_file, do_lower_case=FLAGS.do_lower_case)
        # Vocabulary id -> token lookup.
        id2char = tokenizer.inv_vocab

        while True:
            print('input the test sentence:')
            try:
                sentence = input()
            except EOFError:
                # stdin was closed (e.g. piped input is exhausted): stop
                # prompting instead of ending with a traceback.
                break
            start = datetime.now()
            if len(sentence) < 2:
                print(sentence)
                continue

            input_ids, input_mask, segment_ids, label_ids = convert(sentence)
            feed_dict = {
                input_ids_p: input_ids,
                input_mask_p: input_mask,
                segment_ids_p: segment_ids,
                label_ids_p: label_ids,
            }
            # Run the session for the current sentence.
            y_pred = sess.run([pred_ids], feed_dict)
            # Row 0 of the only fetched tensor: one label id per position.
            pred_row = y_pred[0][0]

            tags = []
            y_pred_clean = []
            # Drop [CLS] and everything from [SEP] onwards so that only the
            # real token positions are reported. The label name is resolved
            # after the skip, so predictions at the special positions are
            # never looked up in id2label.
            for pos, token_id in enumerate(np.reshape(input_ids, -1)):
                token = id2char[token_id]
                if token == "[CLS]":
                    continue
                if token == "[SEP]":
                    break
                label_id = pred_row[pos]
                tags.append(id2label[label_id])
                y_pred_clean.append(label_id)

            print(sentence + '\n' + ' '.join(tags))
            entity = get_entity([sentence], [y_pred_clean], id2label)
            print('predict_result:')
            print(entity)
            # total_seconds() covers the whole interval; timedelta.seconds
            # alone drops the days component.
            elapsed = int((datetime.now() - start).total_seconds())
            print('Time used: {} sec'.format(elapsed))