Example #1
import math

import torch


def get_f1(model, config, test=True):
    """Compute entity-level precision, recall and F1 on the test or validation split."""
    if test:
        # Reload the saved checkpoint for final evaluation on the test set
        model = BiLSTM_CRF(config)
        model.load_state_dict(torch.load(config.model_save_path))
        model.eval()
        x, y = x_test, y_test
    else:
        x, y = x_valid, y_valid

    n_batch = math.ceil(len(x) / config.batch_size)
    entity_pred, entity_true = [], []

    for i in range(n_batch):
        start = i * config.batch_size
        end = (i + 1) * config.batch_size if i != (n_batch - 1) else len(x)
        batch_ids, batch_inputs, batch_outputs, masks, length = random_batch(embeddings,
                                                                             x[start:end],
                                                                             y[start:end],
                                                                             end - start,
                                                                             False)
        scores, sequences = model(batch_inputs, masks, length)
        entity_pred += retrieve_entity(batch_ids, sequences, masks, id2tag, id2word)
        entity_true += retrieve_entity(batch_ids, batch_outputs.numpy(), masks, id2tag, id2word)

    # Correctly predicted entities: intersection of predictions and gold entities
    correct = [i for i in entity_pred if i in entity_true]
    precision = len(correct) / len(entity_pred) if entity_pred else 0.0
    recall = len(correct) / len(entity_true) if entity_true else 0.0
    f1_score = 2 * precision * recall / (precision + recall) if correct else 0.0

    return entity_pred, f1_score, precision, recall
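
# --- Usage sketch (not part of the original snippet) ---
# Assumes `config` and `model` were already created during training and that the
# global test split (x_test, y_test) used inside get_f1 is loaded; get_f1 reloads
# the checkpoint itself when test=True.
entity_pred, f1, precision, recall = get_f1(model, config, test=True)
print(f"test  P={precision:.4f}  R={recall:.4f}  F1={f1:.4f}")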
Example #2
import json

import torch


def singel_predict(model_path, content, char_to_id_json_path, batch_size,
                   embedding_dim, hidden_dim, num_layers, sentence_length,
                   offset, target_type_list, tag2id):
    """Extract entities of the target types from `content` with a saved BiLSTM-CRF model."""
    # Load the character-to-id vocabulary
    char_to_id = json.load(
        open(char_to_id_json_path, mode="r", encoding="utf-8"))
    # Convert the input string into a list of vocabulary ids
    char_ids = content_to_id(content, char_to_id)
    # Pack the ids into batch_size * sentence_length tensors and keep a parallel
    # map from each tensor position back to its original character
    model_inputs_list, model_input_map_list = build_model_input_list(
        content, char_ids, batch_size, sentence_length, offset)
    # Build the model with the same hyperparameters used during training
    model = BiLSTM_CRF(vocab_size=len(char_to_id),
                       tag_to_ix=tag2id,
                       embedding_dim=embedding_dim,
                       hidden_dim=hidden_dim,
                       batch_size=batch_size,
                       num_layers=num_layers,
                       sequence_length=sentence_length)
    # Load the saved parameters and switch to evaluation mode
    model.load_state_dict(torch.load(model_path))
    model.eval()

    # Map tag id -> tag name, keeping only the B-/I- tags of the requested entity types
    tag_id_dict = {
        v: k
        for k, v in tag2id.items() if k[2:] in target_type_list
    }
    # Collected entity strings to return
    entities = []
    with torch.no_grad():
        for step, model_inputs in enumerate(model_inputs_list):
            prediction_value = model(model_inputs)
            # Walk through the predicted tag sequence of every line in the batch
            for line_no, line_value in enumerate(prediction_value):
                # Entity currently being assembled
                entity = None
                # Check the predicted tag of each character in the line
                for char_idx, tag_id in enumerate(line_value):
                    # The predicted tag belongs to one of the target entity types
                    if tag_id in tag_id_dict:
                        # First character of the tag name, i.e. "B" or "I"
                        tag_index = tag_id_dict[tag_id][0]
                        # Recover the original character at this position
                        current_char = model_input_map_list[step][line_no][
                            char_idx]
                        # "B" starts a new entity
                        if tag_index == "B":
                            entity = current_char
                        # "I" extends the entity currently being assembled
                        elif tag_index == "I" and entity:
                            entity += current_char
                    # An "O" tag right after an entity marks its end
                    if tag_id == tag2id["O"] and entity:
                        entities.append(entity)
                        # Reset for the next entity
                        entity = None
    return entities
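
# --- Usage sketch (not part of the original snippet) ---
# Every path, hyperparameter and tag below is a placeholder; they must match the
# vocabulary and tag scheme the checkpoint was actually trained with.
tag2id = {"O": 0, "B-Drug": 1, "I-Drug": 2, "<START>": 3, "<STOP>": 4}
entities = singel_predict(model_path="model/bilstm_crf.pth",
                          content="患者长期服用二甲双胍控制血糖",
                          char_to_id_json_path="data/char_to_id.json",
                          batch_size=8,
                          embedding_dim=200,
                          hidden_dim=100,
                          num_layers=1,
                          sentence_length=20,
                          offset=10,
                          target_type_list=["Drug"],
                          tag2id=tag2id)
print(entities)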
Example #3
    'Frequency': 6,
    'Amount': 7,
    'Method': 8,
    'Treatment': 9,
    'Operation': 10,
    'Anatomy': 11,
    'Level': 12,
    'Duration': 13,
    'SideEff': 14,
    'O': 15,
    START_TAG: 16,
    STOP_TAG: 17
}

model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM)
model.load_state_dict(
    torch.load('../model_dict/hidim512_1_0.6659692762422823_params.pkl'))
model.cuda()  # move the model to the GPU

base_path = '/home/lingang/chris/knowledge_graph'
# base_path = '/media/chris/D/challenge/knowledge_graph_rename'

file_txt_path = os.path.join(base_path, 'dataset', 'test_data')

result_dict = {}

with torch.no_grad():
    # Run inference over every pickled file in the test set
    for p in os.listdir(file_txt_path):
        p1 = os.path.join(file_txt_path, p)
        output = []
        with open(p1, 'rb') as f:
            ans = pickle.load(f)