def load_data():
    """Load the token/target vocabularies and the NER training set.

    Returns:
        (train_id_data, token_vocab, target_vocab) — the training data
        already converted to id form, plus both Vocab instances.
    """
    # All data files live in the 'data' directory next to this module.
    data_dir = os.path.join(os.path.dirname(__file__), 'data')

    # One vocabulary for the input tokens, one for the target labels.
    token_vocab = Vocab(os.path.join(data_dir, 'token.vocab.txt'), mode='token')
    target_vocab = Vocab(os.path.join(data_dir, 'target.vocab.txt'), mode='target')

    # Load the raw training text, then convert it to id data.
    train_txt_data = N2NTextData(os.path.join(data_dir, 'ner.n2n.txt'))
    train_id_data = N2NConverter.convert(train_txt_data, target_vocab, token_vocab)

    return train_id_data, token_vocab, target_vocab
def predict(token_vocab, target_vocab, sent):
    """Tag one raw sentence with the frozen NER graph and print the result.

    Args:
        token_vocab: project Vocab mapping input characters/tokens to ids.
        target_vocab: project Vocab mapping target-label ids back to symbols.
        sent: raw sentence string to tag.

    Prints one line per input character: ``char<TAB>label<TAB>probability``.
    Returns None.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force CPU only for prediction
    model_dir = "./trained_models"

    # Prepare sentence conversion: raw sentence -> id data, then wrap it in a
    # dataset (batch size 1, max sequence length 128 — presumably the length
    # the frozen graph was built with; TODO confirm).
    pred_data = N2NTextData(sent, mode='sentence')
    pred_id_data = N2NConverter.convert(pred_data, target_vocab, token_vocab)
    pred_data_set = NERDataset(pred_id_data, 1, 128)

    # Pull the single prediction batch. The first element (named-entity ids)
    # is unused at inference time, hence the underscore name.
    a_batch_data = next(pred_data_set.predict_iterator)
    _b_nes_id, b_token_ids, b_weight = a_batch_data

    # Restore graph. frozen_graph.tf.pb contains the graph definition with
    # parameter values baked in (binary protobuf format).
    _graph_fn = os.path.join(model_dir, 'frozen_graph.tf.pb')
    with tf.gfile.GFile(_graph_fn, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)

    with tf.Session(graph=graph) as sess:
        # To inspect the loaded graph, uncomment:
        # for n in tf.get_default_graph().as_graph_def().node: print(n.name)

        # Input placeholders (names carry the 'import/' prefix added by
        # tf.import_graph_def).
        pl_token = graph.get_tensor_by_name('import/model/pl_tokens:0')
        pl_weight = graph.get_tensor_by_name('import/model/pl_weight:0')
        pl_keep_prob = graph.get_tensor_by_name('import/model/pl_keep_prob:0')

        # Output tensors.
        step_out_preds = graph.get_tensor_by_name(
            'import/model/step_out_preds:0')
        step_out_probs = graph.get_tensor_by_name(
            'import/model/step_out_probs:0')

        # Predict the sentence; keep_prob=1.0 disables dropout at inference.
        b_best_step_pred_indexs, b_step_pred_probs = sess.run(
            [step_out_preds, step_out_probs],
            feed_dict={
                pl_token: b_token_ids,
                pl_weight: b_weight,
                pl_keep_prob: 1.0,
            })

        # Batch size is 1 — take the first (only) example.
        best_step_pred_indexs = b_best_step_pred_indexs[0]
        step_pred_probs = b_step_pred_probs[0]

        # Map each per-step predicted index back to its target symbol and
        # the probability assigned to that symbol.
        step_best_targets = []
        step_best_target_probs = []
        for time_step, best_pred_index in enumerate(best_step_pred_indexs):
            _target_class = target_vocab.get_symbol(best_pred_index)
            step_best_targets.append(_target_class)
            _prob = step_pred_probs[time_step][best_pred_index]
            step_best_target_probs.append(_prob)

        # zip() instead of indexing by enumerate(sent): the original raised
        # IndexError when the sentence was longer than the model's fixed
        # number of time steps; zip truncates safely and is otherwise
        # identical in output.
        for char, target, prob in zip(sent, step_best_targets,
                                      step_best_target_probs):
            print('{}\t{}\t{}'.format(char, target, prob))