def get_entity_result(feature, id2char, id2label, y_pred, sentence=None):
    """Extract entities from one BERT-encoded example and its predicted labels.

    Only the positions strictly between ``[CLS]`` and the first ``[SEP]`` are
    kept. The ``[CLS]`` marker, the ``[SEP]`` marker and anything after it
    (e.g. padding) are dropped, so the remaining labels line up with the text.

    :param feature: encoded example exposing ``input_ids`` (token ids), e.g.
        the result of ``convert_single_example``.
    :param id2char: mapping from token id to token string
        (e.g. ``tokenizer.inv_vocab``).
    :param id2label: mapping from label id to tag string.
    :param y_pred: prediction output; ``y_pred[0]`` is expected to hold one
        label id per position of ``feature.input_ids``.
    :param sentence: the original input text. It is optional for backward
        compatibility. When omitted, it is rebuilt by concatenating the kept
        tokens with ``##`` word-piece prefixes stripped. The rebuilt text may
        differ from the raw input (lower-casing, ``[UNK]`` tokens).
    :return: the value of ``get_entity([sentence], [label_ids], id2label)``
        for the cleaned label ids.
    """
    # Materialize the predictions once instead of on every loop step.
    label_ids = list(y_pred[0])

    kept_tokens = []
    kept_tags = []
    kept_label_ids = []
    # Skip [CLS] and stop at [SEP] to get the tag range that matches the text.
    for token_id, label_id in zip(feature.input_ids, label_ids):
        char = id2char[token_id]
        if char == "[CLS]":
            continue
        if char == "[SEP]":
            break
        kept_tokens.append(char)
        kept_tags.append(id2label[label_id])
        kept_label_ids.append(label_id)

    if sentence is None:
        # No text supplied by the caller: recover it from the tokens.
        sentence = ''.join(
            tok[2:] if tok.startswith('##') else tok for tok in kept_tokens)

    print(sentence + '\n' + ' '.join(kept_tags))
    entity = get_entity([sentence], [kept_label_ids], id2label)
    print('predict_result:')
    print(entity)
    return entity
def get_result(sess, model, max_len=15):
    """Interactively tag stdin input with the char-level model and print entities.

    Each iteration does the following:

    1. Reads one line from stdin.
    2. Splits it into segments on the listed Chinese/ASCII punctuation and
       drops empty segments.
    3. Runs ``model.outputs`` on all segments as one batch.
    4. Prints the prediction time, the raw prediction, the entities from
       ``get_entity``, and each segment next to its own tags.

    The loop ends when stdin is closed (EOF).

    :param sess: session used to run ``model.outputs``.
    :param model: model exposing the ``input_x``, ``input_x_len``,
        ``dropout_keep_prob`` and ``lr`` placeholders plus ``outputs``.
    :param max_len: length every segment is padded to by ``pad_sequences``.
        It defaults to the previously hard-coded value of 15.
    """
    import time  # perf_counter measures elapsed time correctly, incl. > 1 s

    while True:
        try:
            raw_text = input("Enter your input: ")
        except EOFError:
            break
        # Leading, trailing or consecutive delimiters make re.split produce
        # empty strings; drop them so they are not fed to the model.
        text = [seg for seg in re.split(u'[,。!?、‘’“”()]', raw_text) if seg]
        print('text:')
        print(text)
        if not text:
            continue

        seqs = [sentence2id(sent, char2id) for sent in text]
        seq_list, seq_len_list = pad_sequences(seqs, max_len=max_len)
        feed_dict = {
            model.input_x: seq_list,
            model.input_x_len: seq_len_list,
            model.dropout_keep_prob: 1.0,
            model.lr: FLAGS.learning_rate,
        }
        time_start = time.perf_counter()
        y_pred = sess.run([model.outputs], feed_dict)
        # timedelta.microseconds was only the sub-second part; use the total.
        elapsed_ms = round((time.perf_counter() - time_start) * 1000, 3)
        print('每条数据预测时间耗时约:{} ms '.format(elapsed_ms))
        print(y_pred)

        entity = get_entity(text, y_pred[0], id2tag)
        print('predict_result:')
        print(entity)
        # Show every segment with its own tags, trimmed to its length.
        # (Previously: only the first segment's padded tags, shown under the
        # whole raw input.)
        for seg, tags, seg_len in zip(text, y_pred[0], seq_len_list):
            seg_tags = ' '.join(id2tag[tag_id] for tag_id in list(tags)[:seg_len])
            print(seg + '\n' + seg_tags)
        print('entity_result:')
        for item in entity:
            print(item)
def predict_outline():
    """Do interactive prediction, one instance at a time (could be batched).

    Inside the module-level ``graph``, this function:

    1. Builds the input placeholders and the ``create_model`` graph.
    2. Restores the latest checkpoint from ``args.output_dir`` into the
       module-level ``sess``.
    3. Repeatedly reads a sentence from stdin and prints its per-token tags,
       the entities returned by ``get_entity``, and the elapsed time.

    Inputs shorter than two characters are echoed and skipped. The function
    returns when stdin is closed (EOF).
    """
    # TODO: support predicting from a file.
    import time  # perf_counter measures sub-second elapsed time

    def convert(line):
        """Encode ``line`` and reshape each field to (batch_size, max_seq_length)."""
        feature = convert_single_example(line, label2id, FLAGS.max_seq_length, tokenizer)
        shape = (args.batch_size, FLAGS.max_seq_length)
        input_ids = np.reshape([feature.input_ids], shape)
        input_mask = np.reshape([feature.input_mask], shape)
        segment_ids = np.reshape([feature.segment_ids], shape)
        label_ids = np.reshape([feature.label_ids], shape)
        return input_ids, input_mask, segment_ids, label_ids

    global graph
    with graph.as_default():
        print("going to restore checkpoint")
        input_ids_p = tf.placeholder(tf.int32, [args.batch_size, FLAGS.max_seq_length], name="input_ids")
        input_mask_p = tf.placeholder(tf.int32, [args.batch_size, FLAGS.max_seq_length], name="input_mask")
        label_ids_p = tf.placeholder(tf.int32, [args.batch_size, FLAGS.max_seq_length], name="label_ids")
        segment_ids_p = tf.placeholder(tf.int32, [args.batch_size, FLAGS.max_seq_length], name="segment_ids")

        bert_config = modeling_bert.BertConfig.from_json_file(args.bert_config_file)
        (total_loss, logits, trans, pred_ids) = create_model(
            bert_config, args.is_training, input_ids_p, input_mask_p, segment_ids_p,
            label_ids_p, num_labels, args.use_one_hot_embeddings)

        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(args.output_dir))

        tokenizer = tokenization.FullTokenizer(
            vocab_file=args.vocab_file, do_lower_case=FLAGS.do_lower_case)
        # Token id -> token string, used to locate [CLS]/[SEP].
        id2char = tokenizer.inv_vocab

        while True:
            print('input the test sentence:')
            try:
                sentence = str(input())
            except EOFError:
                break
            start = time.perf_counter()
            if len(sentence) < 2:
                print(sentence)
                continue

            input_ids, input_mask, segment_ids, label_ids = convert(sentence)
            feed_dict = {
                input_ids_p: input_ids,
                input_mask_p: input_mask,
                segment_ids_p: segment_ids,
                label_ids_p: label_ids,
            }
            # Run the session to get predictions for the current feed_dict.
            y_pred = sess.run([pred_ids], feed_dict)
            # Materialize the first instance's predictions once, not per token.
            pred_label_ids = list(y_pred[0][0])

            sent_tag = []
            y_pred_clean = []
            # Skip [CLS] and stop at [SEP] to get the tag range matching the text.
            for token_id, label_id in zip(np.reshape(input_ids, -1), pred_label_ids):
                char = id2char[token_id]
                if char == "[CLS]":
                    continue
                if char == "[SEP]":
                    break
                sent_tag.append(id2label[label_id])
                y_pred_clean.append(label_id)

            print(sentence + '\n' + ' '.join(sent_tag))
            entity = get_entity([sentence], [y_pred_clean], id2label)
            print('predict_result:')
            print(entity)
            # timedelta.seconds truncated to whole seconds (usually 0).
            print('Time used: {} sec'.format(round(time.perf_counter() - start, 3)))