def decode(self, results: List[LubanSpan], threshold=0, verbose=False):
    # Start from every span whose predicted label is not NONE.
    keep_flags = [results[i].pred_label != 'NONE' for i in range(len(results))]

    # Drop spans whose predicted probability falls below the threshold.
    for i, span_pred in enumerate(results):
        if span_pred.pred_prob < threshold:
            keep_flags[i] = False

    # When consecutive spans share the same begin index, keep only the last listed one
    # ("bigger is better").
    for i in range(len(results) - 1):
        if results[i + 1].bid == results[i].bid:
            keep_flags[i] = False

    # Resolve overlapping spans: of two conflicting spans, keep the more probable one.
    while True:
        no_conflict = True
        for i in range(len(results)):
            if keep_flags[i]:
                next_id = i + 1
                while next_id != len(results) and not keep_flags[next_id]:
                    next_id += 1
                if next_id == len(results):
                    continue
                if results[next_id].bid <= results[i].eid:
                    if results[next_id].pred_prob > results[i].pred_prob:
                        keep_flags[i] = False
                    else:
                        keep_flags[next_id] = False
                    no_conflict = False
        if no_conflict:
            break

    if verbose:
        for i in range(len(results)):
            if keep_flags[i] or results[i].gold_label != 'NONE':
                log("{} {:>4}/{:4} {}".format(
                    "+" if keep_flags[i] else "-",
                    results[i].pred_label,
                    results[i].gold_label,
                    results[i].fragment))

    # Update the per-label TP/FP/FN counters and the global pred/gold/correct counts.
    for i, span_pred in enumerate(results):
        if keep_flags[i] and span_pred.gold_label == span_pred.pred_label:
            self.TP[span_pred.gold_label] += 1
            self.corr_num += 1
            self.pred_num += 1
            self.gold_num += 1
        if keep_flags[i] and span_pred.gold_label != span_pred.pred_label:
            self.FP[span_pred.pred_label] += 1
            self.FN[span_pred.gold_label] += 1
            self.pred_num += 1
            if span_pred.gold_label != 'NONE':
                self.gold_num += 1
        if not keep_flags[i] and span_pred.gold_label != 'NONE':
            self.FN[span_pred.gold_label] += 1
            self.gold_num += 1
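# Illustrative only: a minimal sketch of how the greedy span decoding above behaves,
# assuming LubanEvaluator owns decode()/prf() and LubanSpan takes the fields used in
# decode_log() below. The two spans here are made up for demonstration; the second
# overlaps the first with a lower probability, so it should be discarded.
demo_spans = [
    LubanSpan(bid=0, eid=1, lid=0, pred_prob=0.9, gold_prob=1.0,
              pred_label='GPE', gold_label='GPE', fragment='北京'),
    LubanSpan(bid=1, eid=2, lid=0, pred_prob=0.4, gold_prob=0.0,
              pred_label='PER', gold_label='NONE', fragment='京人'),
]
evaluator = LubanEvaluator()
evaluator.decode(demo_spans, threshold=0, verbose=True)
print(evaluator.prf(verbose=True))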
def decode_log(file_path="lstm.json.logs/last.task-4.txt",
               threshold=-4,
               verbose=True,
               epoch_id=29,
               valid_set="dev_set"):
    # Replay a training log and re-run span decoding/evaluation for one validation pass.
    file = open(file_path)
    decoder = LubanEvaluator()
    results: List[LubanSpan] = []
    flag = False
    while True:
        line = file.readline()
        if line == '':
            break
        line = line.strip("\n")
        if line == ">>> epoch {} validation on {}".format(epoch_id, valid_set):
            flag = True
        if line == "<<< epoch {} validation on {}".format(epoch_id, valid_set):
            break
        if flag:
            # A "[...]" header marks a new sentence; flush the spans of the previous one.
            if re.match(r"\[[^\d]*\d+\].*", line):
                if len(results) != 0:
                    decoder.decode(results, threshold, verbose)
                    results = []
                if verbose:
                    log(line)
            found = re.search(
                r"\s+(\d+)~(\d+)\s+(\d+)\s+(\d*\.\d*)/(\d*\.\d*)\s*([A-Z]+)/([A-Z]+)\s*([^\s]*)",
                line)
            if found:
                results.append(
                    LubanSpan(bid=int(found.group(1)),
                              eid=int(found.group(2)),
                              lid=int(found.group(3)),
                              pred_prob=float(found.group(4)),
                              gold_prob=float(found.group(5)),
                              pred_label=found.group(6),
                              gold_label=found.group(7),
                              fragment=found.group(8)))
    file.close()
    # Flush the final sentence, whose spans are not followed by another "[...]" header.
    if len(results) != 0:
        decoder.decode(results, threshold, verbose)
    # print(results)
    return decoder.prf(verbose)
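# Illustrative only: replaying a finished training log to re-score one validation pass.
# The arguments below are just the function's defaults; decode_log() returns the micro
# precision/recall/F1 tuple produced by LubanEvaluator.prf().
p, r, f1 = decode_log(file_path="lstm.json.logs/last.task-4.txt",
                      threshold=-4, verbose=False,
                      epoch_id=29, valid_set="dev_set")
log("replayed dev_set: pre {:.4f} rec {:.4f} f1 {:.4f}".format(p, r, f1))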
def show_mean_std(self):
    log("Embedding Info")
    log("\t[char] mean {} std {}".format(
        torch.mean(self.char_embeds.weight),
        torch.std(self.char_embeds.weight),
    ))
    if self.bichar_embeds:
        log("\t[bichar] mean {} std {}".format(
            torch.mean(self.bichar_embeds.weight),
            torch.std(self.bichar_embeds.weight),
        ))
    if self.seg_embeds:
        log("\t[seg] mean {} std {}".format(
            torch.mean(self.seg_embeds.weight),
            torch.std(self.seg_embeds.weight),
        ))
    if self.pos_embeds:
        log("\t[pos] mean {} std {}".format(
            torch.mean(self.pos_embeds.weight),
            torch.std(self.pos_embeds.weight),
        ))
def gen_lexicon_vocab(*data_paths, word2vec_path, out_folder, use_cache=False):
    if use_cache and os.path.exists("{}/lexicon.vocab".format(out_folder)):
        log("cache for lexicon vocab exists.")
        return
    # Collect every word that has a pretrained embedding.
    words = set()
    for line in open(word2vec_path, encoding="utf8", errors="ignore"):
        word = re.split(r"\s+", line.strip())[0]
        words.add(word)
    # Keep only the fragments that occur in the data and are covered by the embeddings.
    lexicon = {Sp.pad: 0, Sp.oov: 1, Sp.non: 2, Sp.sos: 3, Sp.eos: 4}
    for data_path in data_paths:
        print("Gen lexicon for", data_path)
        sentences = load_sentences(data_path)
        for sid, sentence in enumerate(sentences):
            chars = group_fields(sentence, indices=0)
            for i, j in fragments(len(chars), config.max_span_length):
                frag = "".join(chars[i:j + 1])
                if frag not in lexicon and frag in words:
                    lexicon[frag] = len(lexicon)
    create_folder(out_folder)
    f_out = open("{}/lexicon.vocab".format(out_folder), "w", encoding='utf8')
    for k, v in lexicon.items():
        f_out.write("{} {}\n".format(k, v))
    f_out.close()
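# Illustrative only: building the span lexicon from the data splits and a pretrained
# word-embedding file. The paths follow the OntoNotes4 layout used elsewhere in this
# repo but are placeholders here.
gen_lexicon_vocab("dataset/OntoNotes4/train.char.bmes",
                  "dataset/OntoNotes4/dev.char.bmes",
                  "dataset/OntoNotes4/test.char.bmes",
                  word2vec_path="dataset/word2vec/word.vec",
                  out_folder="dataset/OntoNotes4/vocab",
                  use_cache=True)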
from buff import log, log_config

log_config("small_train.goldseg.bmes", default_target="cf")

if __name__ == '__main__':
    # Convert word-level BMES annotations into char-level segmentation tags:
    # a single-char word gets "S"; a multi-char word gets "B", "M"..., "E".
    file_name = "dataset/OntoNotes4/small_train.word.bmes"
    file = open(file_name)
    while True:
        line = file.readline()
        if line == "":
            break
        if line == "\n":
            log()
        else:
            word = line.split(" ")[0]
            word_len = len(word)
            for i, char in enumerate(word):
                if word_len == 1:
                    log(char, "S")
                else:
                    if i == 0:
                        log(char, "B")
                    elif i == word_len - 1:
                        log(char, "E")
                    else:
                        log(char, "M")
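# Illustrative only: the same word-to-character BMES expansion as the script above,
# applied to an in-memory example instead of the OntoNotes file. The helper name and
# the sample word are made up for demonstration.
def word_to_char_seg_tags(word):
    if len(word) == 1:
        return [(word, "S")]
    return [(c, "B" if i == 0 else "E" if i == len(word) - 1 else "M")
            for i, c in enumerate(word)]

assert word_to_char_seg_tags("北") == [("北", "S")]
assert word_to_char_seg_tags("北京大学") == [("北", "B"), ("京", "M"), ("大", "M"), ("学", "E")]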
def prf(self, verbose=False):
    if verbose:
        log("TP", self.TP)
        log("FP", self.FP)
        log("FN", self.FN)
    mi_pre = (0, 0)
    mi_rec = (0, 0)
    ma_pre = 0.
    ma_rec = 0.
    for entype in self.tags:
        if entype == "NONE":
            continue
        pre = self.TP[entype] / (self.TP[entype] + self.FP[entype] + 1e-10)
        rec = self.TP[entype] / (self.TP[entype] + self.FN[entype] + 1e-10)
        ma_pre += pre
        ma_rec += rec
        mi_pre = (mi_pre[0] + self.TP[entype],
                  mi_pre[1] + self.TP[entype] + self.FP[entype])
        mi_rec = (mi_rec[0] + self.TP[entype],
                  mi_rec[1] + self.TP[entype] + self.FN[entype])
        f1 = 2 * pre * rec / (pre + rec + 1e-10)
        if verbose:
            log("{} pre: {:.4f} rec: {:.4f} f1: {:.4f}".format(entype, pre, rec, f1))
    mi_pre = mi_pre[0] / (mi_pre[1] + 1e-10)
    mi_rec = mi_rec[0] / (mi_rec[1] + 1e-10)
    mi_f1 = 2 * mi_pre * mi_rec / (mi_pre + mi_rec + 1e-10)
    # Macro-average over the entity types, excluding NONE.
    num_types = sum(1 for t in self.tags if t != "NONE")
    ma_pre /= num_types
    ma_rec /= num_types
    ma_f1 = 2 * ma_pre * ma_rec / (ma_pre + ma_rec + 1e-10)
    if verbose:
        log("micro pre: {:.4f} rec: {:.4f} f1: {:.4f}".format(mi_pre, mi_rec, mi_f1))
        log("macro pre: {:.4f} rec: {:.4f} f1: {:.4f}".format(ma_pre, ma_rec, ma_f1))
    # log("ignore-class pre: {:.4f} rec: {:.4f} f1: {:.4f}".format(
    #     self.corr_num / (self.pred_num + 1e-10),
    #     self.corr_num / (self.gold_num + 1e-10),
    #     2 / (self.pred_num / (self.corr_num + 1e-10) + self.gold_num / (self.corr_num + 1e-10))))
    return mi_pre, mi_rec, mi_f1
def gen_vocab(data_path, out_folder, char_count_gt=2, bichar_count_gt=2,
              ignore_tag_bmes=False, use_cache=False):
    if use_cache and os.path.exists(out_folder):
        log("cache for vocab exists.")
        return
    sentences = load_sentences(data_path)
    char_count = defaultdict(lambda: 0)
    bichar_count = defaultdict(lambda: 0)
    ner_labels = []  # BMES-prefixed labels, e.g. B-GPE
    pos_vocab = {Sp.pad: 0}
    for sentence in sentences:
        for line_idx, line in enumerate(sentence):
            char_count[line[0]] += 1
            if line[3] not in ner_labels:
                ner_labels.append(line[3])
            pos = line[2]
            if ignore_tag_bmes:
                pos = pos[2:]
            if pos not in pos_vocab:
                pos_vocab[pos] = len(pos_vocab)
            if line_idx < len(sentence) - 1:
                bichar_count[line[0] + sentence[line_idx + 1][0]] += 1
            else:
                bichar_count[line[0] + Sp.eos] += 1
    char_count = dict(
        sorted(char_count.items(), key=lambda x: x[1], reverse=True))
    bichar_count = dict(
        sorted(bichar_count.items(), key=lambda x: x[1], reverse=True))

    # gen char vocab
    char_vocab = {Sp.pad: 0, Sp.oov: 1, Sp.sos: 2, Sp.eos: 3}
    for i, k in enumerate(char_count.keys()):
        if char_count[k] > char_count_gt:
            char_vocab[k] = len(char_vocab)
    analyze_vocab_count(char_count)

    # gen bichar vocab
    bichar_vocab = {Sp.pad: 0, Sp.oov: 1}
    for i, k in enumerate(bichar_count.keys()):
        if bichar_count[k] > bichar_count_gt:
            bichar_vocab[k] = len(bichar_vocab)
    analyze_vocab_count(bichar_count)

    # seg vocab
    seg_vocab = {Sp.pad: 0, "B": 1, "M": 2, "E": 3, "S": 4}

    # ner vocab / BMES mode
    ner_vocab = {Sp.pad: 0, Sp.sos: 1, Sp.eos: 2}
    for tag in ner_labels:
        ner_vocab[tag] = len(ner_vocab)

    # gen label vocab (entity types with the BMES prefix stripped)
    label_vocab = {"NONE": 0}
    for label in ner_labels:
        found = re.search(".*-(.*)", label)
        if found:
            if found.group(1) not in label_vocab:
                label_vocab[found.group(1)] = len(label_vocab)

    # write to file
    create_folder(out_folder)
    for file_name, vocab in {
        "char.vocab": char_vocab,
        "bichar.vocab": bichar_vocab,
        "seg.vocab": seg_vocab,
        "pos.vocab": pos_vocab,
        "ner.vocab": ner_vocab,
        "label.vocab": label_vocab,
    }.items():
        f_out = open("{}/{}".format(out_folder, file_name), "w", encoding='utf8')
        for k, v in vocab.items():
            f_out.write("{} {}\n".format(k, v))
        f_out.close()
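# Illustrative only: generating every vocab file from the training split. The paths are
# placeholders following the OntoNotes4 layout used elsewhere in this repo.
gen_vocab("dataset/OntoNotes4/train.char.bmes",
          out_folder="dataset/OntoNotes4/vocab",
          char_count_gt=2, bichar_count_gt=2,
          ignore_tag_bmes=False, use_cache=True)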
def __init__(self, data_path, lexicon2idx,
             char2idx, bichar2idx, seg2idx, pos2idx, ner2idx, label2idx,
             ignore_pos_bmes=False,
             max_text_len=19260817,
             max_span_len=19260817,
             sort_by_length=False):
    super(ConllDataSet, self).__init__()

    if config.lexicon_emb_pretrain != "off":
        self.word2idx = lexicon2idx
        self.idx2word = {v: k for k, v in self.word2idx.items()}

    sentences = load_sentences(data_path)

    self.data = []  # one Datum per kept sentence
    self.__max_text_len = max_text_len
    self.__max_span_len = max_span_len
    self.__longest_text_len = -1
    self.__longest_span_len = -1

    __span_length_count = defaultdict(lambda: 0)
    __sentence_length_count = defaultdict(lambda: 0)

    for sid in range(len(sentences)):
        chars, bichars, segs, labels, poss, ners = [], [], [], [], [], []

        sen = sentences[sid]
        sen_len = len(sen)
        for cid in range(sen_len):
            char = sen[cid][0]
            chars.append(char2idx[char] if char in char2idx else char2idx[Sp.oov])

            bichar = (char + sen[cid + 1][0]) if cid < sen_len - 1 else (char + Sp.eos)
            bichars.append(bichar2idx[bichar] if bichar in bichar2idx else bichar2idx[Sp.oov])

            segs.append(seg2idx[sen[cid][1]])

            pos = sen[cid][2]
            if ignore_pos_bmes:
                pos = pos[2:]
            poss.append(pos2idx[pos])

            ners.append(ner2idx[sen[cid][3]])

            # A B-* or S-* tag opens a gold entity span.
            if re.match(r"^[BS]", sen[cid][3]):
                state, label = sen[cid][3].split("-")
                label_b = cid
                label_e = cid
                label_y = label2idx[label]
                if state == 'B':
                    # Walk forward until the matching E-* tag closes the span.
                    while True:
                        next_state, _ = sen[label_e][3].split("-")
                        if next_state == "E":
                            break
                        label_e += 1
                # An S-* tag is a single-character span: label_b == label_e already.
                __span_length_count[label_e - label_b + 1] += 1
                if label_e - label_b + 1 <= max_span_len:
                    labels.append(SpanLabel(b=label_b, e=label_e, y=label_y))
                    self.__longest_span_len = max(self.__longest_span_len,
                                                  label_e - label_b + 1)

        __sentence_length_count[len(chars)] += 1
        if len(chars) < max_text_len:
            if config.lexicon_emb_pretrain != "off":
                if config.match_mode == "naive":
                    lexmatches = match_lex_naive(group_fields(sen, indices=0),
                                                 lexicon2idx=lexicon2idx)
                    # for ele in lexmatches:
                    #     print("".join(group_fields(sen, indices=0)[ele[0][0]: ele[0][1] + 1]))
                    #     for word_idx, match_type in ele[1]:
                    #         print(">>\t", self.idx2word[word_idx], idx2match_naive[match_type])
                elif config.match_mode == "middle":
                    lexmatches = match_lex_middle(group_fields(sen, indices=0),
                                                  lexicon2idx=lexicon2idx)
                    # for ele in lexmatches:
                    #     print("".join(group_fields(sen, indices=0)[ele[0][0]: ele[0][1] + 1]))
                    #     for word_idx, match_type in ele[1]:
                    #         print(">>\t", self.idx2word[word_idx], idx2match_middle[match_type])
                elif config.match_mode == "mix":
                    lexmatches = match_lex_mix(group_fields(sen, indices=0),
                                               lexicon2idx=lexicon2idx)
                    # for ele in lexmatches:
                    #     print("".join(group_fields(sen, indices=0)[ele[0][0]: ele[0][1] + 1]))
                    #     for word_idx, match_type in ele[1]:
                    #         print(">>\t", self.idx2word[word_idx], idx2match_mix[match_type])
                elif config.match_mode == "off":
                    lexmatches = None
                else:
                    raise Exception("unknown match_mode: {}".format(config.match_mode))
            else:
                lexmatches = None
            self.data.append(
                Datum(chars=chars, bichars=bichars, segs=segs,
                      poss=poss, ners=ners, labels=labels,
                      lexmatches=lexmatches))
            self.__longest_text_len = max(self.__longest_text_len, len(chars))

    if sort_by_length:
        self.data = sorted(self.data, key=lambda x: len(x[0]), reverse=True)

    log("Dataset statistics for {}".format(data_path))
    log("Sentence")
    analyze_length_count(__sentence_length_count)
    log("Span")
    analyze_length_count(__span_length_count)
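# Illustrative only: constructing the dataset from the vocab files written by gen_vocab()
# and gen_lexicon_vocab(). load_vocab() is a hypothetical helper that reads back the
# "<token> <index>" lines those functions write; the paths and max_text_len are placeholders.
def load_vocab(path):
    vocab = {}
    for line in open(path, encoding="utf8"):
        token, idx = line.rstrip("\n").rsplit(" ", 1)
        vocab[token] = int(idx)
    return vocab

vocab_folder = "dataset/OntoNotes4/vocab"
train_set = ConllDataSet(
    data_path="dataset/OntoNotes4/train.char.bmes",
    lexicon2idx=load_vocab("{}/lexicon.vocab".format(vocab_folder)),
    char2idx=load_vocab("{}/char.vocab".format(vocab_folder)),
    bichar2idx=load_vocab("{}/bichar.vocab".format(vocab_folder)),
    seg2idx=load_vocab("{}/seg.vocab".format(vocab_folder)),
    pos2idx=load_vocab("{}/pos.vocab".format(vocab_folder)),
    ner2idx=load_vocab("{}/ner.vocab".format(vocab_folder)),
    label2idx=load_vocab("{}/label.vocab".format(vocab_folder)),
    ignore_pos_bmes=False,
    max_text_len=300,
    max_span_len=config.max_span_length,
    sort_by_length=True)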