def crf_decode(self, batch_data) -> List[List[int]]:
    # Viterbi-decode the best tag path for each sentence in the batch.
    chars = group_fields(batch_data, keys="chars")
    token_reprs = self.compute_token_reprs(batch_data)
    scores = self.ner_score(token_reprs)
    masks = torch.tensor(batch_mask(chars, mask_zero=True),
                         dtype=torch.uint8, device=self.device)
    results = self.ner_crf.decode(scores, masks)
    return results
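# A minimal, self-contained decoding sketch, assuming self.ner_crf is a
# torchcrf.CRF (pytorch-crf) built with batch_first=True; the toy shapes
# stand in for the real emissions produced by self.ner_score.
import torch
from torchcrf import CRF

_crf = CRF(num_tags=5, batch_first=True)
_emissions = torch.randn(2, 4, 5)                        # (batch, seq_len, num_tags)
_mask = torch.tensor([[1, 1, 1, 1],
                      [1, 1, 0, 0]], dtype=torch.uint8)  # 0 marks padding
_paths = _crf.decode(_emissions, _mask)                  # List[List[int]], one tag path per sentence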
def compute_token_reprs(self, batch_data):
    chars, bichars, segs, poss = group_fields(
        batch_data, keys=["chars", "bichars", "segs", "poss"]
    )
    text_lens = batch_lens(chars)

    # Pad every field to the batch maximum and move it onto the model device.
    pad_chars = batch_pad(chars, self.char2idx[Sp.pad])
    pad_bichars = batch_pad(bichars, self.bichar2idx[Sp.pad])
    pad_segs = batch_pad(segs, self.seg2idx[Sp.pad])
    pad_poss = batch_pad(poss, self.pos2idx[Sp.pad])
    pad_chars_tensor = torch.tensor(pad_chars, device=self.device)
    pad_bichars_tensor = torch.tensor(pad_bichars, device=self.device)
    pad_segs_tensor = torch.tensor(pad_segs, device=self.device)
    pad_poss_tensor = torch.tensor(pad_poss, device=self.device)

    input_embs = self.embeds(pad_chars_tensor, pad_bichars_tensor,
                             pad_segs_tensor, pad_poss_tensor)

    # Contextualize with an RNN or Transformer encoder, or fall back to
    # the raw embeddings.
    if config.token_type == 'rnn':
        token_reprs = self.token_encoder(input_embs, text_lens)
    elif config.token_type == 'tfer':
        masks = self.token_encoder.gen_masks(pad_chars_tensor)
        token_reprs = self.token_encoder(input_embs, masks, text_lens)
    else:
        token_reprs = input_embs
    return token_reprs
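# Hypothetical reference semantics for the batching helpers used above;
# batch_lens and batch_pad are assumed to behave like these two-liners.
def _batch_lens(seqs):
    return [len(s) for s in seqs]

def _batch_pad(seqs, pad_value=0, pad_len=None):
    width = pad_len if pad_len is not None else max(len(s) for s in seqs)
    return [list(s) + [pad_value] * (width - len(s)) for s in seqs]

assert _batch_pad([[3, 1], [7]], 0) == [[3, 1], [7, 0]]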
def crf_nll(self, batch_data) -> torch.Tensor:
    # Negative log-likelihood of the gold tag sequences under the CRF.
    gold_tags = group_fields(batch_data, keys="ners")
    token_reprs = self.compute_token_reprs(batch_data)
    scores = self.ner_score(token_reprs)
    # Build the masks from the raw (unpadded) tag lists, as in crf_decode,
    # before converting the padded tags to a tensor.
    masks = torch.tensor(batch_mask(gold_tags, mask_zero=True),
                         dtype=torch.uint8, device=self.device)
    gold_tags = torch.tensor(batch_pad(gold_tags, 0), device=self.device)
    crf_loss = self.ner_crf(scores, gold_tags, masks, reduction="mean")
    return -crf_loss
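# A self-contained sketch of the loss above, again assuming pytorch-crf:
# CRF.forward returns the (mean) log-likelihood, so the training loss is
# its negation, and gradients flow back into the emission scores.
import torch
from torchcrf import CRF

_crf = CRF(num_tags=5, batch_first=True)
_emissions = torch.randn(2, 4, 5, requires_grad=True)
_tags = torch.randint(0, 5, (2, 4))
_mask = torch.tensor([[1, 1, 1, 1],
                      [1, 1, 0, 0]], dtype=torch.uint8)
_loss = -_crf(_emissions, _tags, _mask, reduction="mean")
_loss.backward()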
def gen_lexicon_vocab(*data_paths, word2vec_path, out_folder, use_cache=False):
    if use_cache and os.path.exists("{}/lexicon.vocab".format(out_folder)):
        log("cache for lexicon vocab exists.")
        return

    # Collect every word that has a pretrained embedding.
    words = set()
    for line in open(word2vec_path, encoding="utf8", errors="ignore"):
        word = re.split(r"\s+", line.strip())[0]
        words.add(word)

    # Reserve the special symbols, then register every data fragment of at
    # most config.max_span_length characters that the embeddings cover.
    lexicon = {Sp.pad: 0, Sp.oov: 1, Sp.non: 2, Sp.sos: 3, Sp.eos: 4}
    for data_path in data_paths:
        log("Gen lexicon for {}".format(data_path))
        sentences = load_sentences(data_path)
        for sid, sentence in enumerate(sentences):
            chars = group_fields(sentence, indices=0)
            for i, j in fragments(len(chars), config.max_span_length):
                frag = "".join(chars[i:j + 1])
                if frag not in lexicon and frag in words:
                    lexicon[frag] = len(lexicon)

    create_folder(out_folder)
    with open("{}/lexicon.vocab".format(out_folder), "w", encoding='utf8') as f_out:
        for k, v in lexicon.items():
            f_out.write("{} {}\n".format(k, v))
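# Hypothetical reference for the fragments() enumerator assumed above:
# every (start, end) pair with inclusive end and span length <= max_len,
# matching the chars[i:j + 1] slicing in gen_lexicon_vocab.
def _fragments(n, max_len):
    for i in range(n):
        for j in range(i, min(i + max_len, n)):
            yield i, j

assert list(_fragments(3, 2)) == [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2)]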
def get_span_score_tags(self, batch_data, lex_att=False):
    chars = group_fields(batch_data, keys='chars')
    text_lens = batch_lens(chars)
    labels = group_fields(batch_data, keys='labels')
    token_reprs = self.compute_token_reprs(batch_data)

    if config.frag_type != "off":
        frag_reprs = self.fragment_encoder(token_reprs, text_lens,
                                           sos_repr=self.sos_token,
                                           eos_repr=self.eos_token)
        if config.frag_att_type != "off":
            # Let each fragment attend over the tokens of its own sentence.
            d_frag = frag_reprs.size(1)
            att_frag_reprs = []
            offset = 0
            for i, text_len in enumerate(text_lens):
                q = frag_reprs[offset: offset + span_num(text_len)].unsqueeze(0)
                k = token_reprs[i][:text_len].unsqueeze(0)
                v = k
                mask = gen_att_mask(k_lens=(text_len,), max_k_len=text_len,
                                    max_q_len=q.size(1), device=self.device)
                att_frag_repr, att_score = self.multi_att(q, k, v, mask)
                att_frag_reprs.append(att_frag_repr.squeeze(0))
                # frag_reprs holds span_num(text_len) rows per sentence,
                # so advance by that amount rather than by text_len.
                offset += span_num(text_len)
            att_frag_reprs = torch.cat(att_frag_reprs)
            att_frag_reprs = self.att_norm(att_frag_reprs) / math.sqrt(d_frag)
            if config.frag_att_type == 'cat':
                frag_reprs = torch.cat([frag_reprs, att_frag_reprs], dim=1)
            elif config.frag_att_type == 'add':
                frag_reprs = (frag_reprs + att_frag_reprs) / math.sqrt(2)
            else:
                raise ValueError("unknown frag_att_type: {}".format(config.frag_att_type))
    else:
        frag_reprs = None

    # Optionally concatenate left/right context representations.
    if config.ctx_type in ['include', 'exclude']:
        left_ctx_reprs, right_ctx_reprs = self.context_encoder(token_reprs, text_lens)
        if frag_reprs is not None:
            frag_reprs = torch.cat([frag_reprs, left_ctx_reprs, right_ctx_reprs], dim=1)
        else:
            frag_reprs = torch.cat([left_ctx_reprs, right_ctx_reprs], dim=1)
    elif config.ctx_type == 'off':
        pass
    else:
        raise ValueError("unknown ctx_type: {}".format(config.ctx_type))

    if config.lexicon_emb_pretrain != "off":
        if config.match_mode in ["naive", "middle", "mix"]:
            # Flatten the per-sentence lexicon matches into one entry per span.
            lexmatches = group_fields(batch_data, keys='lexmatches')
            lexmatches = [item[1] for sublist in lexmatches for item in sublist]
            assert len(lexmatches) == frag_reprs.size(0)
            frag_match_lexicons, frag_match_types = [], []
            for it_match in lexmatches:
                if len(it_match) == 0:
                    frag_match_lexicons.append([])
                    frag_match_types.append([])
                else:
                    frag_match_lexicons.append(group_fields(it_match, indices=0))
                    frag_match_types.append(group_fields(it_match, indices=1))
            # Pad the matched-lexicon lists to the batch maximum before batching.
            max_match_num = max(batch_lens(frag_match_lexicons))
            mask = gen_att_mask(batch_lens(frag_match_lexicons), max_match_num, 1).to(self.device)
            frag_match_lexicons = batch_pad(frag_match_lexicons, pad_len=max_match_num)
            frag_match_types = batch_pad(frag_match_types, pad_len=max_match_num)
            frag_match_lexicons = torch.tensor(frag_match_lexicons, dtype=torch.long, device=self.device)
            frag_match_types = torch.tensor(frag_match_types, dtype=torch.long, device=self.device)
            # Attend from each span over its matched (lexicon, match-type) memory.
            mem_lexicon = self.lexicon_embeds(frag_match_lexicons)
            mem_match = self.match_embeds(frag_match_types)
            memory = torch.cat([mem_lexicon, mem_match], dim=2)
            if config.match_head == 0:
                att_word, lex_att_score = self.lexicon_attention(frag_reprs.unsqueeze(1), memory, mask)
            elif config.match_head > 0:
                att_word, lex_att_score = self.lexicon_attention(frag_reprs.unsqueeze(1), memory, memory, mask)
            else:
                raise ValueError("match_head must be >= 0")
            frag_reprs = torch.cat([frag_reprs, att_word.squeeze(1)], dim=1)
        elif config.match_mode == "off":
            pass
        else:
            raise ValueError("unknown match_mode: {}".format(config.match_mode))

    span_ys = self.gen_span_ys(chars, labels)
    score = self.scorer(frag_reprs)
    if lex_att:
        return score, span_ys, lex_att_score
    else:
        return score, span_ys
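# Hypothetical span_num() consistent with the fragment slicing above:
# the number of enumerable spans of length <= max_len in a sentence of
# text_len characters (max_len stands in for config.max_span_length).
def _span_num(text_len, max_len=4):
    return sum(min(max_len, text_len - i) for i in range(text_len))

assert _span_num(3, max_len=2) == 5   # agrees with len(list(_fragments(3, 2)))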
def __init__(self, data_path, lexicon2idx,
             char2idx, bichar2idx, seg2idx, pos2idx, ner2idx, label2idx,
             ignore_pos_bmes=False,
             max_text_len=19260817,
             max_span_len=19260817,
             sort_by_length=False):
    super(ConllDataSet, self).__init__()

    self.data = []
    if config.lexicon_emb_pretrain != "off":
        self.word2idx = lexicon2idx
        self.idx2word = {v: k for k, v in self.word2idx.items()}

    sentences = load_sentences(data_path)

    self.__max_text_len = max_text_len
    self.__max_span_len = max_span_len
    self.__longest_text_len = -1
    self.__longest_span_len = -1

    __span_length_count = defaultdict(lambda: 0)
    __sentence_length_count = defaultdict(lambda: 0)

    for sid in range(len(sentences)):
        chars, bichars, segs, labels, poss, ners = [], [], [], [], [], []
        sen = sentences[sid]
        sen_len = len(sen)
        for cid in range(sen_len):
            char = sen[cid][0]
            chars.append(char2idx[char] if char in char2idx else char2idx[Sp.oov])

            bichar = char + sen[cid + 1][0] if cid < sen_len - 1 else char + Sp.eos
            bichars.append(bichar2idx[bichar] if bichar in bichar2idx else bichar2idx[Sp.oov])

            segs.append(seg2idx[sen[cid][1]])

            pos = sen[cid][2]
            if ignore_pos_bmes:
                pos = pos[2:]   # strip the "B-"/"M-"/"E-"/"S-" prefix
            poss.append(pos2idx[pos])

            ners.append(ner2idx[sen[cid][3]])

            # A span starts at every "B-*" or "S-*" tag; for "B-*", scan
            # forward to the matching "E-*" tag to find where it ends.
            if re.match(r"^[BS]", sen[cid][3]):
                state, label = sen[cid][3].split("-")
                label_b = cid
                label_e = cid
                label_y = label2idx[label]
                if state == 'B':
                    while True:
                        next_state, _ = sen[label_e][3].split("-")
                        if next_state == "E":
                            break
                        label_e += 1
                __span_length_count[label_e - label_b + 1] += 1
                if label_e - label_b + 1 <= max_span_len:
                    labels.append(SpanLabel(b=label_b, e=label_e, y=label_y))
                    self.__longest_span_len = max(self.__longest_span_len,
                                                  label_e - label_b + 1)

        __sentence_length_count[len(chars)] += 1
        if len(chars) < max_text_len:
            if config.lexicon_emb_pretrain != "off":
                if config.match_mode == "naive":
                    lexmatches = match_lex_naive(group_fields(sen, indices=0),
                                                 lexicon2idx=lexicon2idx)
                elif config.match_mode == "middle":
                    lexmatches = match_lex_middle(group_fields(sen, indices=0),
                                                  lexicon2idx=lexicon2idx)
                elif config.match_mode == "mix":
                    lexmatches = match_lex_mix(group_fields(sen, indices=0),
                                               lexicon2idx=lexicon2idx)
                elif config.match_mode == "off":
                    lexmatches = None
                else:
                    raise ValueError("unknown match_mode: {}".format(config.match_mode))
            else:
                lexmatches = None
            self.data.append(Datum(chars=chars, bichars=bichars, segs=segs,
                                   poss=poss, ners=ners, labels=labels,
                                   lexmatches=lexmatches))
            self.__longest_text_len = max(self.__longest_text_len, len(chars))

    if sort_by_length:
        self.data = sorted(self.data, key=lambda x: len(x[0]), reverse=True)

    log("Dataset statistics for {}".format(data_path))
    log("Sentence")
    analyze_length_count(__sentence_length_count)
    log("Span")
    analyze_length_count(__span_length_count)
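# A usage sketch, assuming ConllDataSet exposes self.data through the
# usual Dataset __getitem__/__len__ protocol; the path and vocab dicts
# are placeholders for the ones built elsewhere in the repo.
from torch.utils.data import DataLoader

train_set = ConllDataSet("path/to/train.conll", lexicon2idx,
                         char2idx, bichar2idx, seg2idx, pos2idx,
                         ner2idx, label2idx,
                         max_text_len=200, max_span_len=10,
                         sort_by_length=True)
loader = DataLoader(train_set, batch_size=32,
                    collate_fn=lambda batch: batch)  # keep raw Datum records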