Example 1
    def crf_decode(self, batch_data) -> List[List[int]]:
        """Viterbi-decode the best NER tag sequence for each sentence."""
        chars = group_fields(batch_data, keys="chars")
        token_reprs = self.compute_token_reprs(batch_data)
        # Per-token emission scores over the NER tag set.
        scores = self.ner_score(token_reprs)
        # Mask out padding positions before decoding.
        masks = torch.tensor(batch_mask(chars, mask_zero=True),
                             dtype=torch.uint8, device=self.device)
        results = self.ner_crf.decode(scores, masks)
        return results
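
For context, here is a minimal standalone sketch of the same decode pattern, assuming the pytorch-crf package (torchcrf), whose decode/forward signatures match the calls above; the tag count, shapes, and tensors are illustrative only:

import torch
from torchcrf import CRF  # pytorch-crf (assumed dependency)

num_tags, batch, seq_len = 5, 2, 4
crf = CRF(num_tags, batch_first=True)

emissions = torch.randn(batch, seq_len, num_tags)  # per-token tag scores
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 1, 0]], dtype=torch.uint8)  # 2nd sentence has 3 tokens

best_paths = crf.decode(emissions, mask)  # List[List[int]], one path per sentence
print(best_paths)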
Example 2
    def compute_token_reprs(self, batch_data):
        chars, bichars, segs, poss = group_fields(
            batch_data,
            keys=["chars", "bichars", "segs", "poss"]
        )
        text_lens = batch_lens(chars)

        # Pad every field to the batch's maximum length with its <pad> index.
        pad_chars = batch_pad(chars, self.char2idx[Sp.pad])
        pad_bichars = batch_pad(bichars, self.bichar2idx[Sp.pad])
        pad_segs = batch_pad(segs, self.seg2idx[Sp.pad])
        pad_poss = batch_pad(poss, self.pos2idx[Sp.pad])

        pad_chars_tensor = torch.tensor(pad_chars, device=self.device)
        pad_bichars_tensor = torch.tensor(pad_bichars, device=self.device)
        pad_segs_tensor = torch.tensor(pad_segs, device=self.device)
        pad_poss_tensor = torch.tensor(pad_poss, device=self.device)

        # Combine char / bichar / seg / POS embeddings into one vector per token.
        input_embs = self.embeds(pad_chars_tensor,
                                 pad_bichars_tensor,
                                 pad_segs_tensor,
                                 pad_poss_tensor)

        # Contextualize with the configured encoder: an RNN, a Transformer
        # ('tfer'), or no encoder at all (raw embeddings).
        if config.token_type == 'rnn':
            token_reprs = self.token_encoder(input_embs, text_lens)
        elif config.token_type == 'tfer':
            masks = self.token_encoder.gen_masks(pad_chars_tensor)
            token_reprs = self.token_encoder(input_embs, masks, text_lens)
        else:
            token_reprs = input_embs
        return token_reprs
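
compute_token_reprs pads each field, embeds the padded tensors, and passes the embeddings plus the true lengths to an encoder. Below is a minimal sketch of the RNN branch with a single char field, assuming padding index 0; SimpleTokenEncoder is an illustrative stand-in, not the repo's module:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class SimpleTokenEncoder(nn.Module):
    """Illustrative stand-in for a length-aware RNN token encoder."""
    def __init__(self, vocab_size, emb_dim=32, hidden=32):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.rnn = nn.LSTM(emb_dim, hidden, batch_first=True,
                           bidirectional=True)

    def forward(self, pad_chars, text_lens):
        embs = self.emb(pad_chars)
        # Pack so the LSTM never reads padding positions.
        packed = pack_padded_sequence(embs, text_lens, batch_first=True,
                                      enforce_sorted=False)
        out, _ = self.rnn(packed)
        out, _ = pad_packed_sequence(out, batch_first=True)
        return out  # (batch, max_real_len, 2 * hidden)

enc = SimpleTokenEncoder(vocab_size=100)
pad_chars = torch.tensor([[5, 7, 2, 0], [3, 9, 0, 0]])  # 0 = <pad>
reprs = enc(pad_chars, text_lens=[3, 2])
print(reprs.shape)  # torch.Size([2, 3, 64]) — padded to the longest real length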
Example 3
    def crf_nll(self, batch_data) -> torch.Tensor:
        """Mean negative log-likelihood of the gold NER tag sequences."""
        gold_tags = group_fields(batch_data, keys="ners")
        token_reprs = self.compute_token_reprs(batch_data)
        scores = self.ner_score(token_reprs)
        # Pad the gold tags and move them to the same device as the scores.
        gold_tags = torch.tensor(batch_pad(gold_tags, 0), device=self.device)
        masks = torch.tensor(batch_mask(gold_tags, mask_zero=True),
                             dtype=torch.uint8, device=self.device)
        # The CRF returns the log-likelihood, so negate it for a loss.
        crf_loss = self.ner_crf(scores, gold_tags, masks, reduction="mean")
        return -crf_loss
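
crf_nll is the training-side counterpart of crf_decode: under the same pytorch-crf assumption, the CRF's forward pass returns the sequence log-likelihood, so its negation is the loss. A minimal sketch with dummy data:

import torch
from torchcrf import CRF  # pytorch-crf (assumed dependency)

num_tags, batch, seq_len = 5, 2, 4
crf = CRF(num_tags, batch_first=True)

emissions = torch.randn(batch, seq_len, num_tags, requires_grad=True)
gold_tags = torch.tensor([[1, 2, 2, 3],
                          [1, 3, 0, 0]])          # 0 pads the short sentence
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]], dtype=torch.uint8)

log_likelihood = crf(emissions, gold_tags, mask, reduction="mean")
loss = -log_likelihood  # negative log-likelihood, as in crf_nll
loss.backward()
print(loss.item())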
Example 4
def gen_lexicon_vocab(*data_paths, word2vec_path, out_folder, use_cache=False):
    if use_cache and os.path.exists("{}/lexicon.vocab".format(out_folder)):
        log("cache for lexicon vocab exists.")
        return
    # Collect every word that has a pretrained embedding.
    words = set()
    with open(word2vec_path, encoding="utf8", errors="ignore") as f_vec:
        for line in f_vec:
            word = re.split(r"\s+", line.strip())[0]
            words.add(word)

    # Reserve indices for the special tokens, then add every fragment
    # (up to config.max_span_length chars) that matches a pretrained word.
    lexicon = {Sp.pad: 0, Sp.oov: 1, Sp.non: 2, Sp.sos: 3, Sp.eos: 4}
    for data_path in data_paths:
        log("Gen lexicon for {}".format(data_path))
        sentences = load_sentences(data_path)
        for sentence in sentences:
            chars = group_fields(sentence, indices=0)
            for i, j in fragments(len(chars), config.max_span_length):
                frag = "".join(chars[i:j + 1])
                if frag not in lexicon and frag in words:
                    lexicon[frag] = len(lexicon)
    create_folder(out_folder)
    with open("{}/lexicon.vocab".format(out_folder), "w", encoding="utf8") as f_out:
        for k, v in lexicon.items():
            f_out.write("{} {}\n".format(k, v))
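
gen_lexicon_vocab relies on a fragments helper to enumerate candidate spans. Here is a minimal re-implementation consistent with how it is used above (inclusive end indices, spans of up to max_len chars); the repo's own helper may differ in details:

def fragments(n, max_len):
    """Yield (i, j) index pairs for all spans chars[i:j + 1] of
    length <= max_len in a length-n sentence."""
    for i in range(n):
        for j in range(i, min(i + max_len, n)):
            yield i, j

# All spans of length <= 2 in a 3-char sentence:
print(list(fragments(3, 2)))
# [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2)]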
Example 5
    def get_span_score_tags(self, batch_data, lex_att=False):
        chars = group_fields(batch_data, keys='chars')
        text_lens = batch_lens(chars)
        labels = group_fields(batch_data, keys='labels')
        token_reprs = self.compute_token_reprs(batch_data)
        # Filled only by the lexicon-attention branch below; initialized
        # here so a lex_att=True call cannot hit an unbound local.
        lex_att_score = None

        if config.frag_type != "off":
            frag_reprs = self.fragment_encoder(token_reprs, text_lens,
                                               sos_repr=self.sos_token,
                                               eos_repr=self.eos_token)
            if config.frag_att_type != "off":
                # Let each fragment attend over the tokens of its own
                # sentence; fragments are stored flat across the batch,
                # so walk them with a running offset.
                d_frag = frag_reprs.size(1)
                att_frag_reprs = []
                offset = 0
                for i, text_len in enumerate(text_lens):
                    q = frag_reprs[offset: offset + span_num(text_len)].unsqueeze(0)
                    k = token_reprs[i][:text_len].unsqueeze(0)
                    v = k
                    mask = gen_att_mask(k_lens=(text_len,),
                                        max_k_len=text_len, max_q_len=q.size(1),
                                        device=self.device)
                    att_frag_repr, att_score = self.multi_att(q, k, v, mask)
                    att_frag_reprs.append(att_frag_repr.squeeze(0))
                    offset += text_len
                att_frag_reprs = torch.cat(att_frag_reprs)

                # Normalize and rescale so the attended vectors stay on
                # the same scale as frag_reprs.
                att_frag_reprs = self.att_norm(att_frag_reprs) / math.sqrt(d_frag)
                if config.frag_att_type == 'cat':
                    frag_reprs = torch.cat([frag_reprs, att_frag_reprs], dim=1)
                elif config.frag_att_type == 'add':
                    frag_reprs = (frag_reprs + att_frag_reprs) / math.sqrt(2)
                else:
                    raise ValueError("unknown frag_att_type: {}".format(config.frag_att_type))
        else:
            frag_reprs = None

        if config.ctx_type in ['include', 'exclude']:
            left_ctx_reprs, right_ctx_reprs = self.context_encoder(token_reprs, text_lens)
            if frag_reprs is not None:
                frag_reprs = torch.cat([frag_reprs, left_ctx_reprs, right_ctx_reprs], dim=1)
            else:
                frag_reprs = torch.cat([left_ctx_reprs, right_ctx_reprs], dim=1)
        elif config.ctx_type == 'off':
            pass
        else:
            raise ValueError("unknown ctx_type: {}".format(config.ctx_type))

        if config.lexicon_emb_pretrain != "off":
            if config.match_mode in ["naive", "middle", "mix"]:
                # Flatten the per-sentence match lists so they align
                # one-to-one with the flat fragment representations.
                lexmatches = group_fields(batch_data, keys='lexmatches')
                lexmatches = [item[1] for sublist in lexmatches for item in sublist]
                assert len(lexmatches) == frag_reprs.size(0)
                frag_match_lexicons, frag_match_types = [], []
                for it_match in lexmatches:
                    if len(it_match) == 0:
                        frag_match_lexicons.append([])
                        frag_match_types.append([])
                    else:
                        frag_match_lexicons.append(group_fields(it_match, indices=0))
                        frag_match_types.append(group_fields(it_match, indices=1))

                mask = gen_att_mask(batch_lens(frag_match_lexicons), max_match_num, 1).to(self.device)
                frag_match_lexicons = batch_pad(frag_match_lexicons, pad_len=max_match_num)
                frag_match_types = batch_pad(frag_match_types, pad_len=max_match_num)
                frag_match_lexicons = torch.tensor(frag_match_lexicons, dtype=torch.long, device=self.device)
                frag_match_types = torch.tensor(frag_match_types, dtype=torch.long, device=self.device)
                mem_lexicon = self.lexicon_embeds(frag_match_lexicons)
                mem_match = self.match_embeds(frag_match_types)
                # Memory = matched-word embedding plus match-type embedding.
                memory = torch.cat([mem_lexicon, mem_match], dim=2)
                # With match_head == 0 the attention is called with memory
                # serving as both key and value; with match_head > 0 key
                # and value are passed explicitly.
                if config.match_head == 0:
                    att_word, lex_att_score = self.lexicon_attention(frag_reprs.unsqueeze(1), memory, mask)
                elif config.match_head > 0:
                    att_word, lex_att_score = self.lexicon_attention(frag_reprs.unsqueeze(1), memory, memory, mask)
                else:
                    raise ValueError("match_head must be >= 0, got {}".format(config.match_head))
                frag_reprs = torch.cat([frag_reprs, att_word.squeeze(1)], dim=1)
            elif config.match_mode == "off":
                pass
            else:
                raise ValueError("unknown match_mode: {}".format(config.match_mode))

        span_ys = self.gen_span_ys(chars, labels)
        # Score each span over the label set (a linear map,
        # frag_reprs @ W + b).
        score = self.scorer(frag_reprs)
        if lex_att:
            return score, span_ys, lex_att_score
        else:
            return score, span_ys
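
Both attention steps above depend on masks that hide padded key positions. Below is a minimal scaled dot-product attention with a key-padding mask, in the spirit of gen_att_mask plus self.multi_att; the mask convention used here (True at padded positions) is an assumption, not the repo's exact API:

import math
import torch

def masked_attention(q, k, v, key_pad_mask):
    """q: (B, Lq, D), k/v: (B, Lk, D),
    key_pad_mask: (B, Lk) bool, True at padded key positions."""
    scores = q @ k.transpose(1, 2) / math.sqrt(q.size(-1))  # (B, Lq, Lk)
    scores = scores.masked_fill(key_pad_mask.unsqueeze(1), float("-inf"))
    att = torch.softmax(scores, dim=-1)
    return att @ v, att

B, Lq, Lk, D = 1, 3, 4, 8
q, k = torch.randn(B, Lq, D), torch.randn(B, Lk, D)
mask = torch.tensor([[False, False, False, True]])  # last key is padding
out, att = masked_attention(q, k, k, mask)
print(out.shape, att[0, :, -1])  # padded key gets zero attention weight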
Example 6
    def __init__(self,
                 data_path,
                 lexicon2idx,
                 char2idx,
                 bichar2idx,
                 seg2idx,
                 pos2idx,
                 ner2idx,
                 label2idx,
                 ignore_pos_bmes=False,
                 max_text_len=19260817,  # sentinel: effectively no limit
                 max_span_len=19260817,  # sentinel: effectively no limit
                 sort_by_length=False):
        super(ConllDataSet, self).__init__()
        self.data = []  # list of Datum, filled below
        if config.lexicon_emb_pretrain != "off":
            self.word2idx = lexicon2idx
            self.idx2word = {v: k for k, v in self.word2idx.items()}

        sentences = load_sentences(data_path)

        self.__max_text_len = max_text_len
        self.__max_span_len = max_span_len
        self.__longest_text_len = -1
        self.__longest_span_len = -1

        __span_length_count = defaultdict(int)
        __sentence_length_count = defaultdict(int)

        for sid in range(len(sentences)):
            chars, bichars, segs, labels, poss, ners = [], [], [], [], [], []

            sen = sentences[sid]
            sen_len = len(sen)

            for cid in range(sen_len):
                char = sen[cid][0]
                chars.append(char2idx.get(char, char2idx[Sp.oov]))

                # Bigram of this char and the next; the last char pairs
                # with the <eos> symbol instead.
                bichar = char + (sen[cid + 1][0] if cid < sen_len - 1 else Sp.eos)
                bichars.append(bichar2idx.get(bichar, bichar2idx[Sp.oov]))

                segs.append(seg2idx[sen[cid][1]])

                pos = sen[cid][2]
                if ignore_pos_bmes:
                    # Strip the "B-"/"M-"/"E-"/"S-" prefix from the POS tag.
                    pos = pos[2:]
                poss.append(pos2idx[pos])

                ners.append(ner2idx[sen[cid][3]])

                # A span starts at a B- or S- tag. For B-, scan forward
                # to the matching E- tag to find the span end; an S- tag
                # is a single-char span (b == e already).
                if re.match(r"^[BS]", sen[cid][3]):
                    state, label = sen[cid][3].split("-")
                    label_b = cid
                    label_e = cid
                    label_y = label2idx[label]
                    if state == 'B':
                        while True:
                            next_state, _ = sen[label_e][3].split("-")
                            if next_state == "E":
                                break
                            label_e += 1

                    __span_length_count[label_e - label_b + 1] += 1
                    if label_e - label_b + 1 <= max_span_len:
                        labels.append(
                            SpanLabel(b=label_b, e=label_e, y=label_y))
                        self.__longest_span_len = max(self.__longest_span_len,
                                                      label_e - label_b + 1)

            __sentence_length_count[len(chars)] += 1
            if len(chars) < max_text_len:
                if config.lexicon_emb_pretrain != "off":
                    # Attach lexicon matches according to the configured
                    # matching strategy.
                    sen_chars = group_fields(sen, indices=0)
                    if config.match_mode == "naive":
                        lexmatches = match_lex_naive(sen_chars,
                                                     lexicon2idx=lexicon2idx)
                    elif config.match_mode == "middle":
                        lexmatches = match_lex_middle(sen_chars,
                                                      lexicon2idx=lexicon2idx)
                    elif config.match_mode == "mix":
                        lexmatches = match_lex_mix(sen_chars,
                                                   lexicon2idx=lexicon2idx)
                    elif config.match_mode == "off":
                        lexmatches = None
                    else:
                        raise ValueError("unknown match_mode: {}".format(config.match_mode))
                else:
                    lexmatches = None

                self.data.append(
                    Datum(chars=chars,
                          bichars=bichars,
                          segs=segs,
                          poss=poss,
                          ners=ners,
                          labels=labels,
                          lexmatches=lexmatches))
                self.__longest_text_len = max(self.__longest_text_len,
                                              len(chars))

        if sort_by_length:
            # Longest sentences first (Datum[0] is the char list).
            self.data.sort(key=lambda x: len(x[0]), reverse=True)
        log("Dataset statistics for {}".format(data_path))
        log("Sentence")
        analyze_length_count(__sentence_length_count)
        log("Span")
        analyze_length_count(__span_length_count)
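
The trickiest part of __init__ is the BMES span scan. Here is a standalone sketch of the same logic over a toy tag sequence (the tag set is illustrative):

import re

def extract_spans(tags):
    """Return (begin, end, label) triples from a BMES tag sequence.
    Mirrors the scan in ConllDataSet.__init__: a span starts at B- or
    S-; for B- we advance until the matching E- tag."""
    spans = []
    for cid, tag in enumerate(tags):
        if re.match(r"^[BS]", tag):
            state, label = tag.split("-")
            b = e = cid
            if state == "B":
                while tags[e].split("-")[0] != "E":
                    e += 1
            spans.append((b, e, label))
    return spans

tags = ["B-PER", "M-PER", "E-PER", "O", "S-LOC"]
print(extract_spans(tags))  # [(0, 2, 'PER'), (4, 4, 'LOC')]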