def match_entities(self, tags_lists):
    """
    Return, for every sentence in the batch, all entity pairs that may hold a relation.

    The highest-scoring tag along dimension 2 of ``tags_lists`` is decoded
    through ``self.id2tag`` and the entity spans are extracted with
    ``get_entities``. An ordered pair (ent1, ent2) is kept when the two spans
    differ and the type of ent2 is listed in ``self.type2types`` for the type
    of ent1.

    Args:
        tags_lists: tensor of tag scores; presumably shaped
            (batch, seq_len, num_tags) - TODO confirm against the caller.

    Returns:
        One list per sentence, each of the form
        [[s1, e1, s2, e2, 0],
         [s1, e1, s2, e2, 0],
         ...]
        where the trailing 0 is the relation slot, always emitted as 0 here.
    """
    best_tag_ids = torch.max(tags_lists, 2)[1]
    if self.opt.use_gpu:
        best_tag_ids = best_tag_ids.cpu()

    batch_pairs = []
    for tag_ids in best_tag_ids.tolist():
        # id2tag is keyed by the string form of the tag id.
        labels = [self.id2tag[str(tag_id)] for tag_id in tag_ids]
        spans = get_entities(labels)

        sentence_pairs = []
        for head in spans:
            allowed_tail_types = self.type2types.get(head[0], [])
            for tail in spans:
                if tail == head or tail[0] not in allowed_tail_types:
                    continue
                sentence_pairs.append([head[1], head[2], tail[1], tail[2], 0])
        batch_pairs.append(sentence_pairs)
    return batch_pairs
예제 #2
0
 def get_sample_all_entity2relation(self, tags_list, golden_map):
     """
     Build the candidate entity pairs of one sentence together with their relation ids.

     Tags are decoded through ``self.id2tag`` and entity spans are extracted
     with ``metrics.get_entities``. An ordered pair (ent1, ent2) is a candidate
     when the two spans differ and the type of ent2 is listed in
     ``self.type2types`` for the type of ent1.

     Args:
         tags_list: tag ids of one sentence (keys of ``self.id2tag``).
         golden_map: mapping from ``(s1, e1, s2, e2)`` tuples to a relation id;
             pairs missing from it get the ``'NA'`` relation id.

     Returns:
         A tuple ``(all_entity, rel_num, NA_num)``:
         * ``all_entity``: ``[[s1, e1, s2, e2, r], ...]`` - every pair with a
           non-NA relation, followed by the NA pairs that were kept.
         * ``rel_num``: number of non-NA pairs.
         * ``NA_num``: number of NA pairs actually appended to ``all_entity``.
     """
     na_id = self.r2id['NA']
     labels = [self.id2tag[i] for i in tags_list]
     spans = metrics.get_entities(labels)

     related_pairs = []
     na_pairs = []
     for head in spans:
         allowed_tail_types = self.type2types.get(head[0], [])
         for tail in spans:
             if tail == head or tail[0] not in allowed_tail_types:
                 continue
             entity_tuple = (head[1], head[2], tail[1], tail[2])
             relation = golden_map.get(entity_tuple, na_id)
             sample = list(entity_tuple) + [relation]
             if relation == na_id:
                 na_pairs.append(sample)
             else:
                 related_pairs.append(sample)

     rel_num = len(related_pairs)
     NA_num = 0
     if na_pairs:
         # The kept NA pairs and the reported count must share one limit.
         # Previously at most 2 pairs were appended while the count was capped
         # by opt.naNum, so NA_num could disagree with the returned list.
         na_limit = max(opt.naNum, 0)
         kept_na = na_pairs[:na_limit]
         related_pairs.extend(kept_na)
         NA_num = len(kept_na)
     return related_pairs, rel_num, NA_num
예제 #3
0
def interAct(model, data_iterator, params, mark='Interactive', verbose=False):
    """Predict the entities of the next batch produced by `data_iterator`.

    Exactly one batch is consumed from the iterator; no loss or metric is
    computed.

    Args:
        model: callable model; called with ``(batch_data, batch_token_starts)``
            plus ``token_type_ids`` and ``attention_mask`` keyword arguments.
            The first element of its output is used as the tag scores.
        data_iterator: iterator yielding ``(batch_data, batch_token_starts)``.
        params: object exposing ``idx2tag``, a dict from tag index to tag string.
        mark: unused; kept for interface compatibility.
        verbose: unused; kept for interface compatibility.

    Returns:
        Whatever ``get_entities`` returns for the predicted tag sequences
        (one tag list per sentence of the batch).
    """
    # set model to evaluation mode
    model.eval()

    idx2tag = params.idx2tag

    batch_data, batch_token_starts = next(data_iterator)
    # Positions with id > 0 are attended to; presumably id 0 is padding - TODO confirm.
    batch_masks = batch_data.gt(0)

    outputs = model((batch_data, batch_token_starts),
                    token_type_ids=None,
                    attention_mask=batch_masks)
    # presumably shaped (batch_size, max_len, num_labels), per the original note
    logits = outputs[0].detach().cpu().numpy()

    pred_tags = [[idx2tag.get(idx) for idx in indices]
                 for indices in np.argmax(logits, axis=2)]

    return get_entities(pred_tags)
예제 #4
0
def postprocess(params):
    """Turn the predicted tag file into per-sample lists of (label, entity text).

    Sentences are read from ``test.bio`` and predictions from
    ``test_tags_pre.txt`` under ``params.data_dir``. Every sentence owns
    ``len(IO2STR)`` consecutive tag lines; the entities found on those lines
    are pooled into one list per sentence. A trailing group with fewer than
    ``len(IO2STR)`` lines is not emitted.

    Returns:
        List with one entry per complete sample, each a list of
        ``(IO2STR[entity_type], entity_text)`` tuples.
    """
    # get text
    with open(params.data_dir / 'test.bio', 'r', encoding='utf-8') as text_file:
        sentences = [row.strip().split(' ') for row in text_file]

    # predicted tags
    result = []
    current_sample = []
    with open(params.data_dir / 'test_tags_pre.txt', 'r') as tag_file:
        for idx, row in enumerate(tag_file):
            # get BIO-tag
            for entity in get_entities(row.strip().split(' ')):
                label_type = IO2STR[entity[0]]
                first, last = entity[1], entity[2]
                # slice the entity tokens out of the owning sentence
                tokens = sentences[idx // len(IO2STR)][first:last + 1]
                current_sample.append((label_type, ''.join(tokens)))
            # last tag line of this sample: flush it
            if idx % len(IO2STR) == len(IO2STR) - 1:
                result.append(current_sample)
                current_sample = []
    return result
예제 #5
0
def analyze_result(params):
    """Turn the predicted tag file of the current mode into (label, text) lists.

    The JSON file ``{args.mode}.data`` under ``params.data_dir`` supplies the
    text: each sample's ``'context'`` string is stripped and split into single
    characters. Predictions come from ``{args.mode}_tags_pre.txt``. Every
    sample owns ``len(IO2STR)`` consecutive tag lines; entities from those
    lines are pooled into one list per sample. A trailing group with fewer
    than ``len(IO2STR)`` lines is not emitted.

    Note: the mode is read from the module-level ``args``, not from ``params``.

    Returns:
        List with one entry per complete sample, each a list of
        ``(IO2STR[entity_type], entity_text)`` tuples.
    """
    # get text
    with open(params.data_dir / f'{args.mode}.data', 'r',
              encoding='utf-8') as data_file:
        samples = json.load(data_file)
    sentences = [list(sample['context'].strip()) for sample in samples]

    # predicted tags
    result = []
    pooled = []
    with open(params.data_dir / f'{args.mode}_tags_pre.txt', 'r') as tag_file:
        for idx, tag_line in enumerate(tag_file):
            # get BIO-tag
            for entity in get_entities(tag_line.strip().split(' ')):
                label_type = IO2STR[entity[0]]
                begin, finish = entity[1], entity[2]
                # characters of the entity inside the owning sentence
                chars = sentences[idx // len(IO2STR)][begin:finish + 1]
                pooled.append((label_type, ''.join(chars)))
            # one sample is complete after len(IO2STR) tag lines
            if (idx + 1) % len(IO2STR) != 0:
                continue
            result.append(pooled)
            pooled = []

    return result
예제 #6
0
def apply_fn(group):
    """Collect every entity of a group, with positions mapped through `split_to_ori`.

    Args:
        group: object exposing ``tags`` and ``split_to_ori`` iterables that are
            walked in lockstep (presumably a pandas group - TODO confirm).
            Each ``tags`` element is a string that evaluates to a tag
            sequence; each ``split_to_ori`` element is a string that evaluates
            to an indexable mapping from split position to original position.

    Returns:
        List of ``(entity_type, mapped_start, mapped_end)`` tuples.
    """
    result = []
    # gather all entities of this group
    for tags, s2o in zip(group.tags, group.split_to_ori):
        # NOTE(review): eval() executes arbitrary code. If these columns can
        # come from an untrusted source, switch to ast.literal_eval.
        entities = get_entities(eval(tags))
        if not entities:
            continue
        # Parse the position map once per row instead of twice per entity.
        split_to_ori = eval(s2o)
        result.extend(
            (entity[0], split_to_ori[entity[1]], split_to_ori[entity[2]])
            for entity in entities
        )
    return result
예제 #7
0
 def merged_slot(self, tokens, pred_lbls):
     """Collect the predicted slots of one utterance into a ``{tag: text}`` dict.

     Args:
         tokens: token strings of the utterance.
         pred_lbls: predicted label sequence passed to ``get_entities``.

     Returns:
         Dict mapping each slot tag to the concatenated tokens of its span.
         When a tag occurs more than once, '#' characters are appended to the
         later keys until they are unique, so no value is overwritten.
     """
     slots = {}
     for chunk in get_entities(pred_lbls):
         key = chunk[0]
         first, last = chunk[1], chunk[2]
         text = ''.join(tokens[first:last + 1])
         # make the key unique instead of overwriting an earlier slot
         while key in slots:
             key = key + '#'
         slots[key] = text
     return slots
예제 #8
0
def pretty_print(tokens, pred_lbls, pred_cls):
    """Print the raw tokens/labels followed by the predicted intent and slots.

    Args:
        tokens: token strings of the utterance.
        pred_lbls: predicted label strings, one per token.
        pred_cls: predicted intent, printed as-is.

    Returns:
        None; everything is written to stdout with ``flush=True``.
    """
    raw_tokens = ' '.join(tokens)
    raw_labels = ' '.join(pred_lbls)
    print('\n==============RAW==================', flush=True)
    print('{0}\n{1}'.format(raw_tokens, raw_labels), flush=True)

    chunks = get_entities(pred_lbls)
    print('\n===================================', flush=True)
    print('Intent\n\t', pred_cls, flush=True)
    print('Slots', flush=True)

    # one '<tag>: text' entry per extracted span
    slot_lines = [
        '<{0}>: {1}'.format(chunk[0], ''.join(tokens[chunk[1]:chunk[2] + 1]))
        for chunk in chunks
    ]

    print('\t' + '\n\t'.join(slot_lines), flush=True)
    print('===================================', flush=True)
예제 #9
0
def get_type_entity(f, sentences):
    """Extract the type and text of every entity listed in a tag file.
    :param f: tag file; iterated line by line, each line a space-separated tag sequence
    :param sentences (List[List[str]]): text; ``sentences[i]`` pairs with line ``i`` of ``f``
    :return: result: one list per line, holding a ``{IO2str[type]: text}`` dict per entity
    """
    result = []
    for line_no, tag_line in enumerate(f):
        # get BIO-tag
        found = []
        for entity in get_entities(tag_line.strip().split(' ')):
            label_type = IO2str[entity[0]]
            first, last = entity[1], entity[2]
            span_tokens = sentences[line_no][first:last + 1]
            found.append({label_type: ''.join(span_tokens)})
        result.append(found)
    return result
예제 #10
0
def get_submit():
    """Convert the predicted tag file into the ``submit.json`` submission file.

    Reads ``test/tags_pre.txt`` under the module-level ``params.data_dir``.
    Line ``i`` (1-based) becomes the entry ``validate_V2_{i}.json``, a list
    with one dict per entity holding its mapped label type, a constant
    ``overlap`` of 0, and 1-based start/end positions. The result is written
    as indented JSON to ``submit.json`` in the same directory.

    Returns:
        None.
    """
    # read tags.txt to dict
    submit = {}
    with open(params.data_dir / 'test/tags_pre.txt', 'r') as tag_file:
        for line_no, tag_line in enumerate(tag_file, start=1):
            submit[f"validate_V2_{line_no}.json"] = [
                {
                    'label_type': IO2str[entity[0].strip()],
                    'overlap': 0,
                    # positions are shifted from 0-based to 1-based
                    'start_pos': entity[1] + 1,
                    'end_pos': entity[2] + 1,
                }
                for entity in get_entities(tag_line.strip().split(' '))
            ]

    # convert dict to json
    with open(params.data_dir / 'submit.json', 'w', encoding='utf-8') as out:
        out.write(json.dumps(submit, indent=4, ensure_ascii=False))
예제 #11
0
def generate_report_txt(
        model,
        dl_test,
        save_dir,
        criterion_clsf=nn.CrossEntropyLoss().to(device),
        criterion_tgt=nn.CrossEntropyLoss(ignore_index=PAD).to(device),
        verbose=False):
    """Run the model over the test batches and write per-sentence text reports.

    Each test sentence is rendered as an expected/predicted (intent, slots)
    record and sorted into one of three buckets: intent wrong, intent right
    but slots different, or exact match. The buckets and a score summary are
    written to ``<save_dir>/reports/``.

    Args:
        model: called as ``model(enc, enc_self_attn_mask)``; must return
            ``(logits_tgt, logits_clsf)``.
        dl_test: sliceable sequence of ``(enc, tgt, cls)`` tensor batches.
        save_dir: directory prefix, joined by plain string concatenation, so
            it must end with a path separator. ``idx2lbl.json``,
            ``idx2cls.json`` and ``TestDataSentence.txt`` are loaded from it.
        criterion_clsf: loss for the sentence-level (intent) logits.
        criterion_tgt: loss for the token-level (slot) logits.
        verbose: unused; kept for interface compatibility.

    Returns:
        None. Side effects: prints progress and (re)writes
        ``reports_correct.txt``, ``reports_intent_error.txt``,
        ``reports_slot_error.txt`` and ``scores.txt``.
    """
    idx2lbl = load_obj(save_dir + 'idx2lbl.json')
    idx2cls = load_obj(save_dir + "idx2cls.json")
    sents = load_obj(save_dir + "TestDataSentence.txt")
    # valid slots per intent; loaded once, not once per batch
    idx_mask = load_mask(save_dir)

    # The summed loss is computed as before but is not reported anywhere.
    loss_test = 0
    pred_tags, true_tags = [], []
    pred_clss, true_clss = [], []

    model.eval()
    for enc, tgt, cls in dl_test[:]:
        with torch.no_grad():
            enc = enc.to(device)
            tgt = tgt.to(device)
            cls = cls.to(device)
            # Tensor.to() is not in-place: keep the returned tensor.
            enc_self_attn_mask = get_attn_pad_mask(enc, enc).to(device)

            logits_tgt, logits_clsf = model(enc, enc_self_attn_mask)
            loss_tgt = criterion_tgt(logits_tgt.transpose(1, 2), tgt)
            loss_tgt = loss_tgt.float().mean()
            loss_clsf = criterion_clsf(logits_clsf, cls)
            loss_test += loss_clsf + loss_tgt

        # number of positions holding id 0 in each sentence
        # (presumably 0 is the padding id - TODO confirm)
        pad_counts = enc.eq(0).sum(dim=1).tolist()

        _, cls_idx = torch.max(logits_clsf, dim=-1)
        # get valid slot for a specific intent
        masked_logits_tgt = softmax_mask(logits_tgt, cls_idx, idx_mask)
        _, tgt_idx = torch.max(masked_logits_tgt, dim=-1)

        for pre, true, pad_num in zip(tgt_idx, tgt, pad_counts):
            # Trim with an explicit length: seq[0:-pad_num] is EMPTY when
            # pad_num == 0, which dropped every tag of unpadded sentences.
            pred_tags.append(pre[:len(pre) - pad_num].tolist())
            true_tags.append(true[:len(true) - pad_num].tolist())

        pred_clss += cls_idx.tolist()
        true_clss += cls.tolist()

    print("Prediction completed", flush=True)

    lines_correct = []
    lines_intent_error = []
    lines_slot_error = []
    for idx, (cls_true, cls_pred) in enumerate(zip(true_clss, pred_clss)):
        tokens = sents[idx].split(' ')
        true_lbls = [idx2lbl[str(t)] for t in true_tags[idx]]
        pred_lbls = [idx2lbl[str(p)] for p in pred_tags[idx]]

        # Format both sides independently. Zipping true and predicted entities
        # truncated to the shorter list, hiding missing or extra slots.
        slots_true = _format_slots(tokens, true_lbls)
        slots_pred = _format_slots(tokens, pred_lbls)

        intent_true = idx2cls[str(cls_true)]
        intent_pred = idx2cls[str(cls_pred)]

        line = "Sentence:{0:}\nExpect: \t{1}\t{2}\nPredict:\t{3}\t{4}\n".format(
            sents[idx], intent_true, slots_true, intent_pred, slots_pred)
        if intent_true != intent_pred:
            lines_intent_error.append(line)
        elif slots_true != slots_pred:
            lines_slot_error.append(line)
        else:
            lines_correct.append(line)

    correct_num = len(lines_correct)
    intent_w_num = len(lines_intent_error)
    slot_w_num = len(lines_slot_error)
    total_line = correct_num + intent_w_num + slot_w_num

    score1 = 'total line = {0}; Exact match = {1}, with intent fail = {2}, with slot fail = {3};'.format(
        total_line, correct_num, intent_w_num, slot_w_num)
    # An empty test set would otherwise raise ZeroDivisionError.
    accuracy = correct_num / total_line if total_line else 0.0
    score2 = 'Accuracy = {0:.4f}'.format(accuracy)
    scores = [score1, score2]

    # saving reports
    print("Saving reports...", flush=True)
    report_dir = os.path.join(save_dir, 'reports', '')
    create_dir(report_dir)

    reports = (
        ('reports_correct.txt', lines_correct),
        ('reports_intent_error.txt', lines_intent_error),
        ('reports_slot_error.txt', lines_slot_error),
        ('scores.txt', scores),
    )
    for file_name, _ in reports:
        remove_old_file(report_dir + file_name)
    for file_name, lines in reports:
        _write_report_lines(report_dir + file_name, lines)


def _format_slots(tokens, labels):
    """Render every entity found in `labels` as a '<tag>: text' string."""
    return [
        '<{0}>: {1}'.format(chunk[0], ''.join(tokens[chunk[1]:chunk[2] + 1]))
        for chunk in get_entities(labels)
    ]


def _write_report_lines(path, lines):
    """Write `lines` to `path` as UTF-8, each followed by a newline."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(line + '\n' for line in lines)