示例#1
0
    def decode(self,
               results: List[LubanSpan],
               threshold: float = 0,
               verbose: bool = False) -> None:
        """Pick a consistent set of predicted spans for one sentence and score it.

        ``keep_flags[i]`` marks whether ``results[i]`` is kept as a prediction.
        Spans are discarded in three passes (by label/probability, by shared
        begin index, and by overlap). The survivors are then compared with
        their gold labels to update ``self.TP`` / ``self.FP`` / ``self.FN``
        (indexed by label) and the ``corr_num`` / ``pred_num`` / ``gold_num``
        totals.

        Args:
            results: candidate spans of one sentence. The neighbour-based passes
                below presumably expect them ordered by ``bid`` and then by
                ``eid`` -- TODO confirm against the producer of ``results``.
            threshold: spans whose ``pred_prob`` is below this value are dropped.
            verbose: if true, log every kept span and every span whose gold label
                is not 'NONE', prefixed with "+" (kept) or "-" (dropped).
        """
        # Spans predicted as 'NONE' are never kept.
        keep_flags = [results[i].pred_label != 'NONE' for i in range(len(results))]

        # By threshold
        for i, span_pred in enumerate(results):
            if span_pred.pred_prob < threshold:
                keep_flags[i] = False

        # Bigger is better
        # When the next span starts at the same position, drop this one in
        # favour of it (a longer span, if the ordering assumption holds).
        # NOTE(review): keep_flags[i + 1] is not consulted, so span i is dropped
        # even when span i + 1 was already discarded -- confirm this is intended.
        for i in range(len(results) - 1):
            if results[i + 1].bid == results[i].bid:
                keep_flags[i] = False

        # Overlapping
        # For each kept span, find the next kept span after it. If they overlap
        # (next.bid <= current.eid), drop the one with the lower pred_prob; on a
        # tie the later span is dropped. Repeat full passes until one finds no
        # conflict. Every conflict clears one flag, so the loop terminates.
        while True:
            no_conflict = True
            for i in range(len(results)):
                if keep_flags[i]:
                    next_id = i + 1
                    while not next_id == len(results) and not keep_flags[next_id]:
                        next_id += 1
                    if next_id == len(results):
                        continue
                    if results[next_id].bid <= results[i].eid:
                        if results[next_id].pred_prob > results[i].pred_prob:
                            keep_flags[i] = False
                        else:
                            keep_flags[next_id] = False
                        no_conflict = False
            if no_conflict:
                break

        if verbose:
            for i in range(len(results)):
                if keep_flags[i] or results[i].gold_label != 'NONE':
                    log("{} {:>4}/{:4} {}".format(
                        "+" if keep_flags[i] else "-",
                        results[i].pred_label, results[i].gold_label, results[i].fragment))

        for i, span_pred in enumerate(results):
            # Kept and correct: true positive for that label.
            if keep_flags[i] and span_pred.gold_label == span_pred.pred_label:
                self.TP[span_pred.gold_label] += 1
                self.corr_num += 1
                self.pred_num += 1
                self.gold_num += 1
            # Kept but wrong: false positive for the predicted label and false
            # negative for the gold label (FN['NONE'] is also incremented here).
            if keep_flags[i] and span_pred.gold_label != span_pred.pred_label:
                self.FP[span_pred.pred_label] += 1
                self.FN[span_pred.gold_label] += 1
                self.pred_num += 1
                if span_pred.gold_label != 'NONE':
                    self.gold_num += 1
            # A real gold span that was not kept: false negative.
            if not keep_flags[i] and span_pred.gold_label != 'NONE':
                self.FN[span_pred.gold_label] += 1
                self.gold_num += 1
示例#2
0
def decode_log(file_path="lstm.json.logs/last.task-4.txt",
               threshold=-4,
               verbose=True,
               epoch_id=29,
               valid_set="dev_set"):
    """Re-score span predictions that were dumped into a training log.

    The function reads ``file_path`` and keeps only the lines between
    ``">>> epoch {epoch_id} validation on {valid_set}"`` and
    ``"<<< epoch {epoch_id} validation on {valid_set}"``. Inside that section,
    a line matching ``[...N...]`` starts a new sentence. Each following span
    line ``bid~eid lid pred_prob/gold_prob PRED/GOLD fragment`` becomes one
    ``LubanSpan``. Each sentence's spans are passed to
    ``LubanEvaluator.decode``.

    Args:
        file_path: log file to read.
        threshold: probability threshold forwarded to ``decode``.
        verbose: forwarded to ``decode`` and ``prf``; also logs sentence headers.
        epoch_id: epoch whose validation section is evaluated.
        valid_set: name of the validation set in the section markers.

    Returns:
        The result of ``decoder.prf(verbose)`` after all sentences are decoded.
    """
    start_marker = ">>> epoch {} validation on {}".format(epoch_id, valid_set)
    end_marker = "<<< epoch {} validation on {}".format(epoch_id, valid_set)
    header_pattern = re.compile(r"\[[^\d]*\d+\].*")
    span_pattern = re.compile(
        r"\s+(\d+)~(\d+)\s+(\d+)\s+(\d*\.\d*)/(\d*\.\d*)\s*([A-Z]+)/([A-Z]+)\s*([^\s]*)")

    decoder = LubanEvaluator()

    results: List[LubanSpan] = []
    in_section = False
    with open(file_path) as file:
        for line in file:
            line = line.strip("\n")
            if line == start_marker:
                in_section = True
            if line == end_marker:
                break
            if not in_section:
                continue
            if header_pattern.match(line):
                # A new sentence header: score the spans collected so far.
                if results:
                    decoder.decode(results, threshold, verbose)
                results = []
                if verbose:
                    log(line)
            found = span_pattern.search(line)
            if found:
                results.append(
                    LubanSpan(bid=int(found.group(1)),
                              eid=int(found.group(2)),
                              lid=int(found.group(3)),
                              pred_prob=float(found.group(4)),
                              gold_prob=float(found.group(5)),
                              pred_label=found.group(6),
                              gold_label=found.group(7),
                              fragment=found.group(8)))

    # The last sentence has no following header to trigger its decode; flush
    # it here so its spans are not silently left out of the metrics.
    if results:
        decoder.decode(results, threshold, verbose)
    return decoder.prf(verbose)
示例#3
0
 def show_mean_std(self):
     """Log the mean and standard deviation of each embedding weight matrix.

     ``char_embeds`` is always reported. ``bichar_embeds``, ``seg_embeds``
     and ``pos_embeds`` are reported only when the attribute is truthy.
     """
     log("Embedding Info")
     # (attribute name, message template, report even when falsy)
     reports = (
         ("char_embeds", "\t[char] mean {} std {}", True),
         ("bichar_embeds", "\t[bichar] mean {} std {}", False),
         ("seg_embeds", "\t[seg] mean {} std {}", False),
         ("pos_embeds", "\t[pos] mean {} std {}", False),
     )
     for attr_name, template, mandatory in reports:
         embeds = getattr(self, attr_name)
         if not (mandatory or embeds):
             continue
         weight = embeds.weight
         log(template.format(torch.mean(weight), torch.std(weight)))
示例#4
0
def gen_lexicon_vocab(*data_paths, word2vec_path, out_folder, use_cache=False):
    """Write ``<out_folder>/lexicon.vocab``: dataset fragments that are embedding words.

    For every sentence in ``data_paths``, each fragment ``chars[i:j + 1]``
    for the ``(i, j)`` pairs yielded by
    ``fragments(len(chars), config.max_span_length)`` is added to the lexicon
    when it is also a word of ``word2vec_path``. Special tokens take ids 0-4.
    New words get the next free id in order of first occurrence. Each output
    line is ``"<word> <id>"``.

    Args:
        *data_paths: dataset files read with ``load_sentences``.
        word2vec_path: embedding text file. Only the first whitespace-separated
            field of each line is used as the word.
        out_folder: directory for ``lexicon.vocab`` (passed to ``create_folder``).
        use_cache: if true and ``lexicon.vocab`` already exists, return at once.
    """
    vocab_path = "{}/lexicon.vocab".format(out_folder)
    if use_cache and os.path.exists(vocab_path):
        log("cache for lexicon vocab exists.")
        return

    # Only the first field of each embedding line (the word itself) is needed.
    with open(word2vec_path, encoding="utf8", errors="ignore") as f_in:
        words = {re.split(r"\s+", line.strip())[0] for line in f_in}

    lexicon = {Sp.pad: 0, Sp.oov: 1, Sp.non: 2, Sp.sos: 3, Sp.eos: 4}
    for data_path in data_paths:
        print("Gen lexicon for", data_path)
        for sentence in load_sentences(data_path):
            chars = group_fields(sentence, indices=0)
            for i, j in fragments(len(chars), config.max_span_length):
                frag = "".join(chars[i:j + 1])
                if frag not in lexicon and frag in words:
                    lexicon[frag] = len(lexicon)

    create_folder(out_folder)
    with open(vocab_path, "w", encoding='utf8') as f_out:
        for word, idx in lexicon.items():
            f_out.write("{} {}\n".format(word, idx))
示例#5
0
from buff import log, log_config

log_config("small_train.goldseg.bmes", default_target="cf")


def _bmes_tag(position, length):
    """Return the BMES tag of character ``position`` in a word of ``length`` chars."""
    if length == 1:
        return "S"
    if position == 0:
        return "B"
    if position == length - 1:
        return "E"
    return "M"


if __name__ == '__main__':
    # Turn a word-per-line BMES file into a character-per-line one. Each
    # character of the first space-separated field is logged with its
    # segmentation tag, and blank lines (sentence breaks) become empty log
    # lines.
    file_name = "dataset/OntoNotes4/small_train.word.bmes"
    with open(file_name) as file:
        for line in file:
            if line == "\n":
                log()
                continue
            word = line.split(" ")[0]
            word_len = len(word)
            for i, char in enumerate(word):
                log(char, _bmes_tag(i, word_len))
示例#6
0
    def prf(self, verbose=False):
        """Compute precision / recall / F1 from the accumulated TP/FP/FN counters.

        Every tag in ``self.tags`` except ``"NONE"`` is scored.
        - Micro scores pool the counts of all scored tags.
        - Macro precision and recall are the unweighted means of the per-tag
          values over the scored tags.
        - Macro F1 is the harmonic mean of macro precision and recall.
        Each denominator carries a ``1e-10`` term to avoid division by zero.

        Args:
            verbose: if true, log the raw counters, per-tag scores and the
                micro/macro summaries.

        Returns:
            Tuple ``(micro_precision, micro_recall, micro_f1)``.
        """
        if verbose:
            log("TP", self.TP)
            log("FP", self.FP)
            log("FN", self.FN)
        mi_tp = 0
        mi_pred = 0  # TP + FP summed over the scored tags
        mi_gold = 0  # TP + FN summed over the scored tags
        ma_pre = 0.
        ma_rec = 0.
        num_types = 0
        for entype in self.tags:
            if entype == "NONE":
                continue
            num_types += 1
            tp, fp, fn = self.TP[entype], self.FP[entype], self.FN[entype]
            pre = tp / (tp + fp + 1e-10)
            rec = tp / (tp + fn + 1e-10)
            ma_pre += pre
            ma_rec += rec
            mi_tp += tp
            mi_pred += tp + fp
            mi_gold += tp + fn
            f1 = 2 * pre * rec / (pre + rec + 1e-10)
            if verbose:
                log("{} pre: {:.4f} rec: {:.4f} f1:  {:.4f}".format(
                    entype, pre, rec, f1
                ))
        mi_pre = mi_tp / (mi_pred + 1e-10)
        mi_rec = mi_tp / (mi_gold + 1e-10)
        mi_f1 = 2 * mi_pre * mi_rec / (mi_pre + mi_rec + 1e-10)
        # Average over the number of scored entity types, not a fixed 4, so
        # the macro scores are right for datasets with any number of types.
        ma_pre /= max(num_types, 1)
        ma_rec /= max(num_types, 1)
        ma_f1 = 2 * ma_pre * ma_rec / (ma_pre + ma_rec + 1e-10)
        if verbose:
            log("micro pre: {:.4f} rec: {:.4f} f1:  {:.4f}".format(mi_pre, mi_rec, mi_f1))
            log("macro pre: {:.4f} rec: {:.4f} f1:  {:.4f}".format(ma_pre, ma_rec, ma_f1))

            # log("ignore-class pre: {:.4f} rec: {:.4f} f1:  {:.4f}".format(
            #     self.corr_num / (self.pred_num + 1e-10),
            #     self.corr_num / (self.gold_num + 1e-10),
            #     2 / (self.pred_num / (self.corr_num + 1e-10) + self.gold_num / (self.corr_num + 1e-10))
            # ))
        return mi_pre, mi_rec, mi_f1
示例#7
0
def gen_vocab(data_path,
              out_folder,
              char_count_gt=2,
              bichar_count_gt=2,
              ignore_tag_bmes=False,
              use_cache=False):
    """Build the token and label vocabularies of a dataset and write them to disk.

    Columns ``[0]`` (character), ``[2]`` (POS tag) and ``[3]`` (NER tag) of
    each row returned by ``load_sentences`` are used. The following files are
    written into ``out_folder``, one ``"<token> <id>"`` pair per line:

    - ``char.vocab`` / ``bichar.vocab``: special tokens, then items seen more
      than ``char_count_gt`` / ``bichar_count_gt`` times, most frequent first.
    - ``seg.vocab``: fixed B/M/E/S tag set.
    - ``pos.vocab``: POS tags in first-seen order.
    - ``ner.vocab``: special tokens, then NER tags in first-seen order.
    - ``label.vocab``: ``"NONE"`` plus the text after the last ``-`` of each
      NER tag, deduplicated.

    Args:
        data_path: dataset file passed to ``load_sentences``.
        out_folder: output directory (passed to ``create_folder``).
        char_count_gt: a character needs a count strictly above this to be kept.
        bichar_count_gt: a bigram needs a count strictly above this to be kept.
        ignore_tag_bmes: if true, strip the first two characters of each POS tag.
        use_cache: if true and ``out_folder`` exists, return immediately.
            Only the folder's existence is checked, not the individual files.
    """
    if use_cache and os.path.exists(out_folder):
        log("cache for vocab exists.")
        return
    sentences = load_sentences(data_path)

    char_count = defaultdict(lambda: 0)
    bichar_count = defaultdict(lambda: 0)
    ner_labels = []  # NER tags in first-seen order, e.g. B-*/E-*
    pos_vocab = {Sp.pad: 0}
    for sentence in sentences:
        for line_idx, line in enumerate(sentence):
            char_count[line[0]] += 1
            if line[3] not in ner_labels:
                ner_labels.append(line[3])
            pos = line[2]
            if ignore_tag_bmes:
                # Presumably drops a "B-"-style position prefix -- confirm
                # against the data format.
                pos = pos[2:]
            if pos not in pos_vocab:
                pos_vocab[pos] = len(pos_vocab)
            # Bigram with the next character; the last one pairs with Sp.eos.
            if line_idx < len(sentence) - 1:
                bichar_count[line[0] + sentence[line_idx + 1][0]] += 1
            else:
                bichar_count[line[0] + Sp.eos] += 1

    # Most frequent first, so frequent tokens receive the smaller ids below.
    char_count = dict(
        sorted(char_count.items(), key=lambda x: x[1], reverse=True))
    bichar_count = dict(
        sorted(bichar_count.items(), key=lambda x: x[1], reverse=True))

    # gen char vocab
    char_vocab = {Sp.pad: 0, Sp.oov: 1, Sp.sos: 2, Sp.eos: 3}
    for char, count in char_count.items():
        if count > char_count_gt:
            char_vocab[char] = len(char_vocab)
    analyze_vocab_count(char_count)

    # gen bichar vocab
    bichar_vocab = {Sp.pad: 0, Sp.oov: 1}
    for bichar, count in bichar_count.items():
        if count > bichar_count_gt:
            bichar_vocab[bichar] = len(bichar_vocab)
    analyze_vocab_count(bichar_count)

    # seg vocab
    seg_vocab = {Sp.pad: 0, "B": 1, "M": 2, "E": 3, "S": 4}

    # ner vocab / BMES mode
    ner_vocab = {Sp.pad: 0, Sp.sos: 1, Sp.eos: 2}
    for tag in ner_labels:
        ner_vocab[tag] = len(ner_vocab)

    # gen label vocab: entity type = text after the last '-' (greedy regex)
    label_vocab = {"NONE": 0}
    for tag in ner_labels:
        found = re.search(".*-(.*)", tag)
        if found and found.group(1) not in label_vocab:
            label_vocab[found.group(1)] = len(label_vocab)

    # write to file
    create_folder(out_folder)
    vocabs = {
        "char.vocab": char_vocab,
        "bichar.vocab": bichar_vocab,
        "seg.vocab": seg_vocab,
        "pos.vocab": pos_vocab,
        "ner.vocab": ner_vocab,
        "label.vocab": label_vocab,
    }
    for file_name, vocab in vocabs.items():
        with open("{}/{}".format(out_folder, file_name), "w", encoding='utf8') as f_out:
            for token, idx in vocab.items():
                f_out.write("{} {}\n".format(token, idx))
示例#8
0
    def __init__(self,
                 data_path,
                 lexicon2idx,
                 char2idx,
                 bichar2idx,
                 seg2idx,
                 pos2idx,
                 ner2idx,
                 label2idx,
                 ignore_pos_bmes=False,
                 max_text_len=19260817,
                 max_span_len=19260817,
                 sort_by_length=False):
        """Load ``data_path`` and convert every sentence into a ``Datum``.

        Each sentence returned by ``load_sentences`` is a sequence of rows.
        Their columns are read here as: ``[0]`` character, ``[1]``
        segmentation tag, ``[2]`` POS tag and ``[3]`` NER tag. Every column is
        mapped to ids with the given ``*2idx`` dicts.
        - Characters and bigrams missing from their vocab fall back to
          ``Sp.oov``.
        - Seg/POS/NER lookups have no fallback, so an unknown tag raises
          ``KeyError``.
        Entity spans are recovered from the NER tags: a span opens on a tag
        starting with ``B`` or ``S``. For ``B`` it closes on the next tag
        whose state is ``E``.

        Args:
            data_path: dataset file passed to ``load_sentences``.
            lexicon2idx: word -> id map. It is only read when
                ``config.lexicon_emb_pretrain != "off"``.
            char2idx, bichar2idx, seg2idx, pos2idx, ner2idx, label2idx: vocab
                maps for the corresponding fields.
            ignore_pos_bmes: if true, strip the first two characters of each POS
                tag before lookup (presumably a "B-"-style prefix -- confirm).
            max_text_len: sentences with ``len(chars) >= max_text_len`` are skipped.
            max_span_len: spans longer than this are counted in the statistics
                but not stored as labels.
            sort_by_length: if true, sort ``self.data`` by descending length of
                each datum's first field (presumably ``chars`` -- confirm the
                ``Datum`` field order).
        """
        super(ConllDataSet, self).__init__()
        # Keep the lexicon maps only when lexicon embeddings are enabled.
        if config.lexicon_emb_pretrain != "off":
            self.word2idx = lexicon2idx
            self.idx2word = {v: k for k, v in self.word2idx.items()}

        sentences = load_sentences(data_path)

        self.__max_text_len = max_text_len
        self.__max_span_len = max_span_len
        self.__longest_text_len = -1
        self.__longest_span_len = -1

        # Histograms (length -> count) reported at the end of construction.
        __span_length_count = defaultdict(lambda: 0)
        __sentence_length_count = defaultdict(lambda: 0)

        for sid in range(len(sentences)):
            chars, bichars, segs, labels, poss, ners = [], [], [], [], [], []

            sen = sentences[sid]
            sen_len = len(sen)

            for cid in range(sen_len):
                char = sen[cid][0]
                chars.append(char2idx[char] if char in
                             char2idx else char2idx[Sp.oov])

                # Bigram with the next character; the last one pairs with Sp.eos.
                bichar = char + sen[
                    cid + 1][0] if cid < sen_len - 1 else char + Sp.eos
                bichars.append(bichar2idx[bichar] if bichar in
                               bichar2idx else bichar2idx[Sp.oov])

                segs.append(seg2idx[sen[cid][1]])

                pos = sen[cid][2]
                if ignore_pos_bmes:
                    pos = pos[2:]
                poss.append(pos2idx[pos])

                ners.append(ner2idx[sen[cid][3]])

                # A tag starting with B or S opens an entity span.
                # NOTE(review): split("-") assumes exactly one '-' in the tag.
                if re.match(r"^[BS]", sen[cid][3]):
                    state, label = sen[cid][3].split("-")
                    label_b = cid
                    label_e = cid
                    label_y = label2idx[label]
                    if state == 'B':
                        # Scan forward to the closing E tag.
                        # NOTE(review): no bounds check -- a B without a later E
                        # raises IndexError.
                        while True:
                            next_state, _ = sen[label_e][3].split("-")
                            if next_state == "E":
                                break
                            label_e += 1
                    if state == 'S':
                        pass

                    __span_length_count[label_e - label_b + 1] += 1
                    if label_e - label_b + 1 <= max_span_len:
                        labels.append(
                            SpanLabel(b=label_b, e=label_e, y=label_y))
                        self.__longest_span_len = max(self.__longest_span_len,
                                                      label_e - label_b + 1)

            __sentence_length_count[len(chars)] += 1
            # NOTE(review): strict '<' here (vs '<=' for spans) also drops
            # sentences of length exactly max_text_len -- confirm intended.
            if len(chars) < max_text_len:
                # Lexicon matches are computed by the helper chosen with
                # config.match_mode.
                if config.lexicon_emb_pretrain != "off":
                    if config.match_mode == "naive":
                        lexmatches = match_lex_naive(group_fields(sen,
                                                                  indices=0),
                                                     lexicon2idx=lexicon2idx)
                        # for ele in lexmatches:
                        #     print("".join(group_fields(sen, indices=0)[ele[0][0]: ele[0][1]+ 1]))
                        #     for word_idx, match_type in ele[1]:
                        #         print(">>\t" ,self.idx2word[word_idx], idx2match_naive[match_type])
                    elif config.match_mode == "middle":
                        lexmatches = match_lex_middle(group_fields(sen,
                                                                   indices=0),
                                                      lexicon2idx=lexicon2idx)
                        # for ele in lexmatches:
                        #     print("".join(group_fields(sen, indices=0)[ele[0][0]: ele[0][1] + 1]))
                        #     for word_idx, match_type in ele[1]:
                        #         print(">>\t", self.idx2word[word_idx], idx2match_middle[match_type])
                    elif config.match_mode == "mix":
                        lexmatches = match_lex_mix(group_fields(sen,
                                                                indices=0),
                                                   lexicon2idx=lexicon2idx)
                        # for ele in lexmatches:
                        #     print("".join(group_fields(sen, indices=0)[ele[0][0]: ele[0][1] + 1]))
                        #     for word_idx, match_type in ele[1]:
                        #         print(">>\t", self.idx2word[word_idx], idx2match_mix[match_type])
                    elif config.match_mode == "off":
                        lexmatches = None
                    else:
                        # Unknown config.match_mode value.
                        raise Exception
                else:
                    lexmatches = None

                # self.data is presumably initialised by the parent class.
                self.data.append(
                    Datum(chars=chars,
                          bichars=bichars,
                          segs=segs,
                          poss=poss,
                          ners=ners,
                          labels=labels,
                          lexmatches=lexmatches))
                self.__longest_text_len = max(self.__longest_text_len,
                                              len(chars))

        if sort_by_length:
            self.data = sorted(self.data,
                               key=lambda x: len(x[0]),
                               reverse=True)
        log("Dataset statistics for {}".format(data_path))
        log("Sentence")
        analyze_length_count(__sentence_length_count)
        log("Span")
        analyze_length_count(__span_length_count)