Example 1
def predict_line(param):
    # initialize the logger
    logger = get_logger(param.test_log_file)
    tf_config = tf.ConfigProto()
    # load the character/tag mapping dictionary
    mapping_dict = get_dict(param.dict_file)
    # build the model from the saved configuration
    model = Model(param, mapping_dict)
    # start testing
    with tf.Session(config=tf_config) as sess:
        # first check whether a model checkpoint exists
        ckpt_path = param.ckpt_path
        ckpt = tf.train.get_checkpoint_state(ckpt_path)
        # see whether a trained model is available
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            logger.info("Reading model parameters from {}".format(
                ckpt.model_checkpoint_path))
            # if it exists, restore it
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            logger.info("Cannot find the ckpt files!")
        while True:
            # repeatedly read sentences from stdin and predict
            line = input("Enter a test sentence: ")
            raw_inputs, model_inputs = input_from_line_with_feature(line)
            tag = model.evaluate_line(sess, model_inputs)
            result = result_to_json(raw_inputs, tag)
            # js is assumed to be an alias for the json module (import json as js)
            result = js.dumps(result,
                              ensure_ascii=False,
                              indent=4,
                              separators=(',', ': '))
            with open('./result/result.json', 'w', encoding='utf-8') as f:
                f.write(result)
            print("预测结果为:{}".format(result))
Example 2
    def predict_from_pb(self, document):
        row = {'content': document}
        df = pandas.DataFrame([row])  # one-row frame (DataFrame.append is gone in pandas 2.x)
        filename = "data/{}.csv".format(time.time())
        df.to_csv(filename, index=False, escapechar="\\", columns=['content'])

        with tf.Graph().as_default():
            output_graph_def = tf.GraphDef()
            with open(os.path.join(args.model_folder, 'money_model.pb'), "rb") as f:
                output_graph_def.ParseFromString(f.read())
                tf.import_graph_def(output_graph_def, name="")

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())

                # Get the input placeholders from the graph by name
                char_input = sess.graph.get_tensor_by_name('CharInputs:0')
                seg_input = sess.graph.get_tensor_by_name('SegInputs:0')
                drop_keep_prob = sess.graph.get_tensor_by_name('Dropout:0')
                # Tensors we want to evaluate (outputs)
                lengths = sess.graph.get_tensor_by_name('lengths:0')
                logits = sess.graph.get_tensor_by_name("project/logits_outputs:0")
                trans = sess.graph.get_tensor_by_name("crf_loss/transitions:0")

                # split the document into non-empty lines
                lines = document.split("\n")
                lines = [line for line in lines if len(line) > 0]

                list_original = []
                list_amounts = []
                for line in lines:
                    input_batch = input_from_line(line, self.char_to_id)  # format the test input
                    feed_dict = create_feed_dict(input_batch, char_input, seg_input, drop_keep_prob)  # build the input feed_dict
                    seq_len, scores = sess.run([lengths, logits], feed_dict)
                    print('---')

                    transition_matrix = trans.eval()
                    batch_paths = decode(scores, seq_len, transition_matrix)
                    tags = [self.id_to_tag[str(idx)] for idx in batch_paths[0]]
                    result = result_to_json(input_batch[0][0], tags)
                    original = str(result['string'])
                    entities = result['entities']

                    if len(entities) != 0:
                        list_original.extend([original] * len(entities))
                        for entity in entities:
                            # numeric amounts need extra augmentation logic
                            if digit_regex.match(entity['word']) and len(entity['word']) >= 1:
                                aug_word = augment(entity['word'], line)
                                list_amounts.append(aug_word)
                            else:
                                list_amounts.append(entity['word'])
                return {"answer": list_amounts, "line": list_original}
Example 3
def predict(input_str):
    
    with open(config.map_file, "rb") as f:
        char_to_id, id_to_char, tag_to_id, id_to_tag = pickle.load(f)
    
    """ 用cpu预测 """
    model = torch.load(os.path.join(config.save_dir,"medical_ner_f1_0.976.ckpt"), 
                       map_location="cpu"
    )
    model.eval()
    
    if not input_str:
        input_str = input("请输入文本: ")    
    
    _, char_ids, seg_ids, _ = prepare_dataset([input_str], char_to_id, tag_to_id, test=True)[0]
    char_tensor = torch.LongTensor(char_ids).view(1,-1)
    seg_tensor = torch.LongTensor(seg_ids).view(1,-1)
    
    with torch.no_grad():
        
        """ 得到维特比解码后的路径,并转换为标签 """
        paths = model(char_tensor,seg_tensor)    
        tags = [id_to_tag[idx] for idx in paths[0]]
    
    pprint(result_to_json(input_str, tags))
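
Examples 3 through 7 all preprocess the input with prepare_dataset, which is not shown either. A minimal sketch of the single-sentence, test=True case, assuming a jieba-derived word-boundary feature (0 = single character, 1/2/3 = word begin/middle/end) and an "<UNK>" entry in char_to_id; both details are assumptions about the original helper. (Example 5 unpacks only three fields, so its version evidently omits the raw character list.)

import jieba

def prepare_dataset(sentences, char_to_id, tag_to_id, test=False):
    data = []
    for sentence in sentences:
        chars = list(sentence)
        char_ids = [char_to_id.get(c, char_to_id["<UNK>"]) for c in chars]
        # word-boundary segment features from jieba's segmentation
        seg_ids = []
        for word in jieba.cut(sentence):
            if len(word) == 1:
                seg_ids.append(0)
            else:
                seg_ids.append(1)
                seg_ids.extend([2] * (len(word) - 2))
                seg_ids.append(3)
        tag_ids = []  # no gold tags at inference time (test=True)
        data.append([chars, char_ids, seg_ids, tag_ids])
    return data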
Example 4
def cpu_predict(input_str):
    with open(config.data_proc_file, "rb") as f:
        train_data, dev_data, test_data = pickle.load(f)
        char_to_id, id_to_char, tag_to_id, id_to_tag = pickle.load(f)
        emb_matrix = pickle.load(f)

    with open(config.map_file, "rb") as f:
        char_to_id, id_to_char, tag_to_id, id_to_tag = pickle.load(f)

    # predict on the CPU
    device = torch.device("cpu")
    model = NERLSTM_CRF(config, char_to_id, tag_to_id, emb_matrix, device)
    state_dict = torch.load(os.path.join(config.save_dir, "medical_ner.ckpt"),
                            map_location="cpu")
    model.load_state_dict(state_dict)

    model.eval()
    if not input_str:
        input_str = input("请输入文本: ")

    _, char_ids, seg_ids, _ = prepare_dataset([input_str],
                                              char_to_id,
                                              tag_to_id,
                                              test=True)[0]
    char_tensor = torch.LongTensor(char_ids).view(1, -1)
    seg_tensor = torch.LongTensor(seg_ids).view(1, -1)

    with torch.no_grad():
        """ 得到维特比解码后的路径,并转换为标签 """
        paths = model(char_tensor, seg_tensor)
        tags = [id_to_tag[idx] for idx in paths[0]]

    pprint(result_to_json(input_str, tags))
Example 5
def predict(input_str):

    with open(config.map_file, "rb") as f:
        char_to_id, id_to_char, tag_to_id, id_to_tag = pickle.load(f)

    model = torch.load(os.path.join(config.save_dir,
                                    "medical_ner_0.9723.ckpt"),
                       map_location="cpu")
    model.eval()  # switch to inference mode

    if not input_str:
        input_str = input("Enter text: ")

    char_ids, seg_ids, _ = prepare_dataset([input_str],
                                           char_to_id,
                                           tag_to_id,
                                           test=True)[0]
    char_tensor = torch.LongTensor(char_ids).view(1, -1)
    seg_tensor = torch.LongTensor(seg_ids).view(1, -1)
    """ 得到维特比解码后的路径,并转换为标签 """
    paths = model(char_tensor, seg_tensor)
    tags = [id_to_tag[idx] for idx in paths[0]]
    """ 把开头和结尾的<start>,<end>标签去掉 """
    tags.pop(0)
    tags.pop(-1)

    pprint(result_to_json(input_str, tags))
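
Note the difference in serialization across these examples: Examples 3 and 5 unpickle the entire model object with torch.load, while Example 4 rebuilds NERLSTM_CRF and loads only a state_dict. The state_dict route is generally the more robust choice, since a full-model pickle hard-codes the class's module path and breaks when the code is refactored.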
Example 6
def predict_from_pb():
    with open(FLAGS.json_file, "r") as f:
        char_to_id, id_to_char, tag_to_id, id_to_tag = json.load(f)
        print('json file loaded')
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # Get the input placeholders from the graph by name
            char_input = sess.graph.get_tensor_by_name('CharInputs:0')
            seg_input = sess.graph.get_tensor_by_name('SegInputs:0')
            drop_keep_prob = sess.graph.get_tensor_by_name('Dropout:0')
            # Tensors we want to evaluate (outputs)
            lengths = sess.graph.get_tensor_by_name('lengths:0')
            logits = sess.graph.get_tensor_by_name("project/logits_outputs:0")

            trans = sess.graph.get_tensor_by_name("crf_loss/transitions:0")

            with open(test_data_path, "r", encoding='utf8') as fo:
                all_data = fo.readlines()
            for line in all_data:  # iterate line by line
                input_batch = input_from_line(line, char_to_id)  # format the test input
                feed_dict = create_feed_dict(input_batch, char_input,
                                             seg_input,
                                             drop_keep_prob)  # build the input feed_dict
                seq_len, scores = sess.run([lengths, logits], feed_dict)
                print('---')

                transition_matrix = trans.eval()
                batch_paths = decode(scores, seq_len, transition_matrix)
                tags = [id_to_tag[str(idx)] for idx in batch_paths[0]]
                print(tags)
                result = result_to_json(input_batch[0][0], tags)
                original = str(result['string'])
                entities = result['entities']
                if len(entities) != 0:
                    for entity in entities:
                        print(entity['word'])
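
create_feed_dict, shared with Example 2, just maps the preprocessed batch onto the three placeholders fetched from the graph. A minimal sketch, assuming input_from_line returns [strings, char_id_lists, seg_id_lists, tags] as in the other examples and that dropout should be disabled (keep probability 1.0) at inference:

import numpy as np

def create_feed_dict(input_batch, char_input, seg_input, drop_keep_prob):
    _, chars, segs, _ = input_batch
    return {
        char_input: np.asarray(chars),  # character-id matrix
        seg_input: np.asarray(segs),    # word-boundary feature matrix
        drop_keep_prob: 1.0,            # no dropout at inference
    }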
Example 7
    def predict_text(model, input_str):
        if not input_str:
            input_str = input("Enter text: ")

        _, char_ids, seg_ids, _ = prepare_dataset([input_str],
                                                  char_to_id,
                                                  tag_to_id,
                                                  test=True)[0]
        char_tensor = torch.LongTensor(char_ids).view(1, -1)
        seg_tensor = torch.LongTensor(seg_ids).view(1, -1)

        with torch.no_grad():
            """ 得到维特比解码后的路径,并转换为标签 """
            paths = model(char_tensor, seg_tensor)
            tags = [id_to_tag[idx] for idx in paths[0]]

        return result_to_json(input_str, tags)
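
A hypothetical invocation, assuming char_to_id, tag_to_id and id_to_tag are available in the enclosing scope and a model restored as in Example 4; the input sentence is illustrative only, and the entity fields follow the result_to_json sketch under Example 1:

result = predict_text(model, "患者主诉头痛三天")
for entity in result["entities"]:
    print(entity["word"], entity["type"])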