class Data():
    def __init__(self):
        """Create the empty alphabets and reset every flag, pad id and size."""
        # Vocabularies for the raw input units.
        self.word_alphabet = Alphabet('word')
        self.char_alphabet = Alphabet('char')
        self.bichar_alphabet = Alphabet('bichar')

        # Alphabets created with is_label=True.
        self.pos_alphabet = Alphabet('pos', is_label=True)
        self.char_type_alphabet = Alphabet('type', is_label=True)

        # "ext" alphabets: these are the ones whose word2id maps are passed
        # to the pretrained char / bichar embedding loaders.
        self.extchar_alphabet = Alphabet('extchar')
        self.extbichar_alphabet = Alphabet('extbichar')

        self.segpos_alphabet = Alphabet('segpos', is_label=True)
        self.wordlen_alphabet = Alphabet('wordlen')

        # Normalisation switches, all off by default.
        self.number_normalized = False
        self.norm_word_emb = self.norm_char_emb = False
        self.norm_bichar_emb = False

        # Pretrained embeddings, filled in by the build_*_pretrain_emb methods.
        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None
        self.pretrain_bichar_embedding = None

        # Padding / action ids; get_instance overwrites them with real lookups.
        self.wordPadID = self.charPadID = self.bicharPadID = 0
        self.charTypePadID = self.wordlenPadID = self.posPadID = 0
        self.appID = self.actionPadID = 0

        # Alphabet sizes, set by fix_alphabet / fix_static_alphabet.
        self.word_num = self.char_num = self.pos_num = 0
        self.bichar_num = self.segpos_num = self.wordlen_num = 0
        self.char_type_num = 0

        self.extchar_num = self.extbichar_num = 0

        self.wordlen_max = 7

    def build_alphabet(self, word_counter, char_counter, extchar_counter,
                       bichar_counter, extbichar_counter, char_type_counter,
                       pos_counter, wordlen_counter, gold_counter,
                       shrink_feature_threshold):
        """Fill every alphabet from its frequency counter.

        Entries are added in ``Counter.most_common()`` order.  Only the char
        and bichar alphabets apply the frequency cut-off (an entry is kept
        when its count is strictly greater than ``shrink_feature_threshold``);
        all other alphabets take every entry.

        Args:
            word_counter, char_counter, extchar_counter, bichar_counter,
            extbichar_counter, char_type_counter, pos_counter,
            wordlen_counter: counters feeding the alphabet of the same name.
            gold_counter: counter feeding ``segpos_alphabet``.
            shrink_feature_threshold: cut-off for the char / bichar alphabets.
        """
        # (counter, target alphabet, does the frequency cut-off apply?)
        fill_plan = (
            (word_counter, self.word_alphabet, False),
            (char_counter, self.char_alphabet, True),
            (extchar_counter, self.extchar_alphabet, False),
            (bichar_counter, self.bichar_alphabet, True),
            (extbichar_counter, self.extbichar_alphabet, False),
            (char_type_counter, self.char_type_alphabet, False),
            (pos_counter, self.pos_alphabet, False),
            (wordlen_counter, self.wordlen_alphabet, False),
            (gold_counter, self.segpos_alphabet, False),
        )
        for counter, alphabet, use_threshold in fill_plan:
            for item, freq in counter.most_common():
                if use_threshold and not freq > shrink_feature_threshold:
                    continue
                alphabet.add(item, freq)

        # Sanity check: the three parallel containers must have equal sizes.
        # A mismatch is only reported, never raised.
        checks = (
            (self.word_alphabet,
             'there are errors in building word alphabet.'),
            (self.char_alphabet,
             'there are errors in building char alphabet.'),
        )
        for alphabet, message in checks:
            expected = len(alphabet.id2word)
            if (len(alphabet.word2id) != expected
                    or len(alphabet.id2count) != expected):
                print(message)

    def fix_alphabet(self):
        """Close the trainable alphabets and record what each close() returns.

        For every prefix below, ``self.<prefix>_alphabet.close()`` is called
        and its return value is stored in ``self.<prefix>_num``.
        """
        prefixes = ('word', 'char', 'pos', 'bichar', 'segpos', 'wordlen',
                    'char_type')
        for prefix in prefixes:
            alphabet = getattr(self, prefix + '_alphabet')
            setattr(self, prefix + '_num', alphabet.close())

    def fix_static_alphabet(self):
        """Close the two "ext" alphabets and store the values close() returns."""
        extchar_size = self.extchar_alphabet.close()
        self.extchar_num = extchar_size
        extbichar_size = self.extbichar_alphabet.close()
        self.extbichar_num = extbichar_size

    def get_instance(self, file, run_insts, shrink_feature_threshold):
        """Read a segmented, POS-tagged corpus and return indexed instances.

        Every non-blank line of ``file`` is one sentence: tokens are separated
        by a single space and each token is split on ``'_'`` into a word
        (first field) and a POS tag (second field).  For each sentence an
        ``Instance`` is filled with words, characters, character types,
        left/right character bigrams, gold actions and the "previous word"
        features, and the matching frequency counters are updated.

        If the word alphabet is not yet fixed, all alphabets are built from
        the counters first.  Afterwards every instance receives its ``*_index``
        lists and the pad / append ids on ``self`` are refreshed.

        Blank lines are skipped (they used to raise ``IndexError``), and the
        file is streamed line by line instead of being loaded at once.

        Args:
            file: path of the UTF-8 corpus file.
            run_insts: stop after this many sentences; a value that is never
                reached (e.g. -1) reads the whole file.
            shrink_feature_threshold: forwarded to ``build_alphabet``.

        Returns:
            The list of populated ``Instance`` objects, in file order.

        Raises:
            IndexError: if a token contains no ``'_'`` separator.
        """
        wordlen_cap = 7  # every length above 6 is folded into this one value

        insts = []
        word_counter = Counter()
        char_counter = Counter()
        bichar_counter = Counter()
        char_type_counter = Counter()
        gold_counter = Counter()
        pos_counter = Counter()
        extchar_counter = Counter()
        extbichar_counter = Counter()
        wordlen_counter = Counter()
        count = 0
        with open(file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                if run_insts == count:
                    break
                sentence = raw_line.strip()
                if not sentence:
                    # A blank line carries no tokens; skip it instead of
                    # failing on the missing POS field.
                    continue
                inst = Instance()
                char_num = 0
                start = 0
                for word_idx, token in enumerate(sentence.split(' ')):
                    fields = token.split('_')
                    word = fields[0]
                    if self.number_normalized:
                        word = utils.normalize_word(word)
                    inst.words.append(word)
                    word_counter[word] += 1
                    pos = fields[1]
                    inst.gold_pos.append(pos)
                    pos_counter[pos] += 1

                    cur_len = min(len(word), wordlen_cap)
                    inst.word_len.append(cur_len)
                    wordlen_counter[cur_len] += 1

                    for char_idx, char in enumerate(word):
                        if char_idx != 0:
                            # Inside a word: length seen so far and the
                            # current word's own tag.
                            inst.last_wordlen.append(
                                min(char_idx, wordlen_cap))
                            inst.last_pos.append(pos)
                        elif word_idx == 0:
                            # Very first character: nothing precedes it.
                            inst.last_wordlen.append('<pad>')
                            inst.last_pos.append('<pad>')
                        else:
                            # First character of a later word.  The current
                            # word / tag are already appended, so the
                            # previous ones sit at index -2.
                            inst.last_wordlen.append(
                                min(len(inst.words[-2]), wordlen_cap))
                            inst.last_pos.append(inst.gold_pos[-2])

                        inst.chars.append(char)
                        inst.extchars.append(char)
                        char_counter[char] += 1
                        extchar_counter[char] += 1

                        char_type = instance.char_type(char)
                        inst.char_type.append(char_type)
                        char_type_counter[char_type] += 1

                        if char_idx == 0:
                            action = sep + '#' + pos
                            start = char_num
                        else:
                            action = app
                        inst.gold_action.append(action)
                        gold_counter[action] += 1
                        char_num += 1

                    span = '[' + str(start) + ',' + str(char_num) + ']'
                    inst.word_seg.append(span)
                    inst.word_seg_pos.append(span + pos)

                # Character bigrams; the sentence edges use nullsymbol.
                chars = inst.chars
                last_idx = len(chars) - 1
                right_bichars = []
                for char_idx, char in enumerate(chars):
                    left_neighbour = (chars[char_idx - 1]
                                      if char_idx > 0 else nullsymbol)
                    right_neighbour = (chars[char_idx + 1]
                                       if char_idx < last_idx else nullsymbol)
                    left = left_neighbour + char
                    right = char + right_neighbour
                    inst.left_bichars.append(left)
                    inst.extleft_bichars.append(left)
                    right_bichars.append(right)
                    bichar_counter[left] += 1
                    bichar_counter[right] += 1
                    extbichar_counter[left] += 1
                    extbichar_counter[right] += 1
                inst.right_bichars = right_bichars
                # Separate copy so the two attributes never alias each other.
                inst.extright_bichars = list(right_bichars)

                count += 1
                insts.append(inst)

        if not self.word_alphabet.fix_flag:
            self.build_alphabet(word_counter, char_counter, extchar_counter,
                                bichar_counter, extbichar_counter,
                                char_type_counter, pos_counter,
                                wordlen_counter, gold_counter,
                                shrink_feature_threshold)

        for inst in insts:
            inst.words_index = [
                self.word_alphabet.get_index(w) for w in inst.words
            ]
            inst.chars_index = [
                self.char_alphabet.get_index(c) for c in inst.chars
            ]
            inst.extchars_index = [
                self.extchar_alphabet.get_index(ec) for ec in inst.extchars
            ]
            inst.left_bichar_index = [
                self.bichar_alphabet.get_index(b) for b in inst.left_bichars
            ]
            inst.right_bichar_index = [
                self.bichar_alphabet.get_index(b) for b in inst.right_bichars
            ]
            inst.extleft_bichar_index = [
                self.extbichar_alphabet.get_index(eb)
                for eb in inst.extleft_bichars
            ]
            inst.extright_bichar_index = [
                self.extbichar_alphabet.get_index(eb)
                for eb in inst.extright_bichars
            ]
            inst.pos_index = [
                self.pos_alphabet.get_index(p) for p in inst.gold_pos
            ]
            inst.char_type_index = [
                self.char_type_alphabet.get_index(t) for t in inst.char_type
            ]
            inst.segpos_index = [
                self.segpos_alphabet.get_index(g) for g in inst.gold_action
            ]
            inst.word_len_index = [
                self.wordlen_alphabet.get_index(w) for w in inst.word_len
            ]
            inst.last_wordlen_index = [
                self.wordlen_alphabet.get_index(n) for n in inst.last_wordlen
            ]
            inst.last_pos_index = [
                self.pos_alphabet.get_index(p) for p in inst.last_pos
            ]

        self.wordPadID = self.word_alphabet.get_index(pad)
        self.charPadID = self.char_alphabet.get_index(pad)
        self.bicharPadID = self.bichar_alphabet.get_index(pad)
        self.charTypePadID = self.char_type_alphabet.get_index(pad)
        self.wordlenPadID = self.wordlen_alphabet.get_index(pad)
        self.posPadID = self.pos_alphabet.get_index(pad)
        self.appID = self.segpos_alphabet.get_index(app)
        self.actionPadID = self.segpos_alphabet.get_index(pad)

        return insts

    def build_word_pretrain_emb(self, emb_path, word_dims):
        """Load pretrained word embeddings into ``pretrain_word_embedding``.

        Delegates to ``utils.load_pretrained_emb_uniform`` with the word
        alphabet's ``word2id`` map and the ``norm_word_emb`` flag.

        Args:
            emb_path: location of the embedding file, passed through as is.
            word_dims: embedding dimension, passed through as is.
        """
        vocab = self.word_alphabet.word2id
        normalize = self.norm_word_emb
        embedding = utils.load_pretrained_emb_uniform(emb_path, vocab,
                                                      word_dims, normalize)
        self.pretrain_word_embedding = embedding

    # Could be optimized.
    def build_char_pretrain_emb(self, emb_path, char_dims):
        """Load pretrained char embeddings into ``pretrain_char_embedding``.

        Delegates to ``utils.load_pretrained_emb_uniform``.  Note that the
        vocabulary used is ``extchar_alphabet.word2id``, not the trainable
        ``char_alphabet``.

        Args:
            emb_path: location of the embedding file, passed through as is.
            char_dims: embedding dimension, passed through as is.
        """
        vocab = self.extchar_alphabet.word2id
        normalize = self.norm_char_emb
        embedding = utils.load_pretrained_emb_uniform(emb_path, vocab,
                                                      char_dims, normalize)
        self.pretrain_char_embedding = embedding

    # Could be optimized.
    def build_bichar_pretrain_emb(self, emb_path, bichar_dims):
        """Load pretrained bichar embeddings into ``pretrain_bichar_embedding``.

        Delegates to ``utils.load_pretrained_emb_uniform``.  Note that the
        vocabulary used is ``extbichar_alphabet.word2id``, not the trainable
        ``bichar_alphabet``.

        Args:
            emb_path: location of the embedding file, passed through as is.
            bichar_dims: embedding dimension, passed through as is.
        """
        vocab = self.extbichar_alphabet.word2id
        normalize = self.norm_bichar_emb
        embedding = utils.load_pretrained_emb_uniform(emb_path, vocab,
                                                      bichar_dims, normalize)
        self.pretrain_bichar_embedding = embedding

    def generate_batch_buckets(self, batch_size, insts):
        """Split ``insts`` into consecutive batches and pad each batch.

        The instances are cut into chunks of ``batch_size`` (the last chunk
        may be shorter).  Each chunk is passed through
        ``utils.sort_instances`` and then padded by ``_pad_batch``.

        Args:
            batch_size: number of instances per batch; must be a positive int.
            insts: sliceable sequence of indexed instances, as returned by
                ``get_instance``.

        Returns:
            A list with one bucket per batch.  Each bucket is a list of 15
            lists; see ``_pad_batch`` for the slot layout.  An empty ``insts``
            yields an empty list.

        Raises:
            ZeroDivisionError: if ``batch_size`` is 0.
        """
        batch_num = int(math.ceil(len(insts) / batch_size))
        buckets = []
        for batch_idx in range(batch_num):
            begin = batch_idx * batch_size
            batch = insts[begin:begin + batch_size]
            buckets.append(self._pad_batch(utils.sort_instances(batch)))
        return buckets

    def _pad_batch(self, inst_sorted):
        """Pad one sorted batch to the length of its first instance.

        The target length is ``len(inst_sorted[0].chars_index)``, so the first
        instance is expected to be the longest one
        (presumably guaranteed by utils.sort_instances -- TODO confirm).

        Slot layout of the returned bucket:
            0 chars_index, 1 left_bichar_index, 2 right_bichar_index,
            3 extchars_index, 4 extleft_bichar_index,
            5 extright_bichar_index, 6 char_type_index, 7 segpos_index,
            8 last_wordlen_index, 9 last_pos_index, 10 gold_action,
            11 gold_pos, 12 chars, 13 words, 14 mask (1 = real, 0 = padding).

        Every slot receives the same number of pad items
        (target length minus the instance's character count), including
        ``words`` and ``gold_pos`` whose own length is the word count.
        """
        max_length = len(inst_sorted[0].chars_index)
        # (Instance attribute, pad value), listed in bucket-slot order.
        fields = (
            ('chars_index', self.char_alphabet.word2id['<pad>']),
            ('left_bichar_index', self.bichar_alphabet.word2id['<pad>']),
            ('right_bichar_index', self.bichar_alphabet.word2id['<pad>']),
            ('extchars_index', self.extchar_alphabet.word2id['<pad>']),
            ('extleft_bichar_index',
             self.extbichar_alphabet.word2id['<pad>']),
            ('extright_bichar_index',
             self.extbichar_alphabet.word2id['<pad>']),
            ('char_type_index', self.char_type_alphabet.word2id['<pad>']),
            ('segpos_index', self.segpos_alphabet.word2id['<pad>']),
            ('last_wordlen_index', self.wordlen_alphabet.word2id['<pad>']),
            ('last_pos_index', self.pos_alphabet.word2id['<pad>']),
            ('gold_action', '<pad>'),
            ('gold_pos', '<pad>'),
            ('chars', '<pad>'),
            ('words', '<pad>'),
        )
        bucket = [[] for _ in range(len(fields) + 1)]
        for inst in inst_sorted:
            cur_length = len(inst.chars_index)
            pad_count = max_length - cur_length
            for slot, (attr, pad_value) in enumerate(fields):
                bucket[slot].append(
                    getattr(inst, attr) + [pad_value] * pad_count)
            bucket[-1].append([1] * cur_length + [0] * pad_count)
        return bucket

    def generate_batch_buckets_save(self, batch_size, insts, char=False):
        """Pad list-style instances into per-batch buckets.

        Each instance is indexable: ``inst[0]`` holds word ids, ``inst[-1]``
        label ids and, when ``char`` is true, ``inst[1]`` a list of per-word
        char-id lists.  The padded length of a batch is taken from its first
        instance (word count + 1), so instances are expected to arrive
        longest-first within each batch.

        NOTE(review): this method reads ``self.label_alphabet``, which the
        ``__init__`` of this class does not create; it must be set elsewhere
        before calling.

        Args:
            batch_size: number of instances per batch.
            insts: sized iterable of instances as described above.
            char: when true, add a slot with padded char ids.

        Returns:
            One bucket per batch: ``[words, labels, mask]`` or, with ``char``,
            ``[words, labels, chars, mask]``.  Labels are prefixed with the
            ``'<start>'`` id.
        """
        batch_num = int(math.ceil(len(insts) / batch_size))
        slot_count = 4 if char else 3
        buckets = [[[] for _ in range(slot_count)] for _ in range(batch_num)]

        max_length = 0
        for position, inst in enumerate(insts):
            bucket = buckets[position // batch_size]
            word_ids = inst[0]
            cur_length = len(word_ids)
            if position % batch_size == 0:
                # First instance of a batch fixes the padded length.
                max_length = cur_length + 1
            pad_count = max_length - cur_length

            word_pad = self.word_alphabet.word2id['<pad>']
            bucket[0].append(word_ids + [word_pad] * pad_count)

            label_start = [self.label_alphabet.word2id['<start>']]
            label_pad = [self.label_alphabet.word2id['<pad>']]
            bucket[1].append(label_start + inst[-1] +
                             label_pad * (pad_count - 1))

            if char:
                char_length = len(inst[1][0])
                pad_row = [self.char_alphabet.word2id['<pad>']] * char_length
                bucket[2].append(inst[1] + [pad_row] * pad_count)

            real = cur_length + 1  # the '<start>' label counts as real
            bucket[-1].append([1] * real + [0] * (max_length - real))

        return buckets
# Example no. 2
# 0
class Data():
    def __init__(self):
        """Create the empty alphabets and reset flags, embeddings and sizes."""
        self.word_alphabet = Alphabet('word')
        self.char_alphabet = Alphabet('char')
        self.label_alphabet = Alphabet('label', is_label=True)

        # Words are passed through utils.normalize_word by default.
        self.number_normalized = True
        self.norm_word_emb = self.norm_char_emb = False

        # Pretrained embeddings, filled in by the build_*_pretrain_emb methods.
        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None

        # Longest word (in characters) seen so far by get_instance.
        self.max_char_length = 0

        # Alphabet sizes, set by fix_alphabet.
        self.word_num = self.char_num = self.label_num = 0

    def build_alphabet(self,
                       word_counter,
                       char_counter,
                       label_counter,
                       shrink_feature_threshold,
                       char=False):
        """Fill the word, label and (optionally) char alphabets.

        Entries are added in ``Counter.most_common()`` order and only when
        their count is strictly greater than ``shrink_feature_threshold``.
        After filling, the sizes of ``word2id``, ``id2word`` and ``id2count``
        are compared; a mismatch is printed, not raised.

        Args:
            word_counter: frequencies for the word alphabet.
            char_counter: frequencies for the char alphabet (used only when
                ``char`` is true).
            label_counter: frequencies for the label alphabet.
            shrink_feature_threshold: frequency cut-off for all alphabets.
            char: whether the char alphabet is built as well.
        """

        def fill(alphabet, counter):
            for item, freq in counter.most_common():
                if freq > shrink_feature_threshold:
                    alphabet.add(item, freq)

        def report_if_inconsistent(alphabet, message):
            expected = len(alphabet.id2word)
            if (len(alphabet.word2id) != expected
                    or len(alphabet.id2count) != expected):
                print(message)

        fill(self.word_alphabet, word_counter)
        fill(self.label_alphabet, label_counter)

        report_if_inconsistent(self.word_alphabet,
                               'there are errors in building word alphabet.')
        report_if_inconsistent(self.label_alphabet,
                               'there are errors in building label alphabet.')

        if not char:
            return
        fill(self.char_alphabet, char_counter)
        report_if_inconsistent(self.char_alphabet,
                               'there are errors in building char alphabet.')

    def fix_alphabet(self):
        """Close the three alphabets and store what each close() returns."""
        for prefix in ('word', 'char', 'label'):
            alphabet = getattr(self, prefix + '_alphabet')
            setattr(self, prefix + '_num', alphabet.close())

    def get_instance(self,
                     file,
                     run_insts,
                     shrink_feature_threshold,
                     char=False,
                     char_padding_symbol='<pad>'):
        """Read a token-per-line corpus and return raw and indexed instances.

        Lines longer than two characters are split on single spaces; when the
        first field is ``'token'`` the second field is the word and the last
        field its label.  Any line of at most two characters closes the
        current sentence.  A sentence still open at end of file is closed as
        well, so a corpus without a trailing separator line no longer loses
        its last sentence.

        With ``char`` enabled, every word is also split into characters and
        the character lists of a sentence are padded with
        ``char_padding_symbol`` to the longest word of that sentence;
        ``self.max_char_length`` tracks the longest word seen overall.

        If the word alphabet is not yet fixed, the alphabets are built from
        the collected counters before indexing.

        Args:
            file: path of the UTF-8 corpus file.
            run_insts: stop after this many sentences; a value that is never
                reached (e.g. -1) reads the whole file.
            shrink_feature_threshold: forwarded to ``build_alphabet``.
            char: whether character features are produced.
            char_padding_symbol: symbol used to pad character lists.

        Returns:
            ``(insts, insts_index)``.  Each entry is ``[words, labels]`` or,
            with ``char``, ``[words, chars, labels]``; ``insts_index`` holds
            the same structure mapped through the alphabets' ``get_index``.
        """
        words = []
        chars = []
        labels = []
        insts = []
        word_counter = Counter()
        char_counter = Counter()
        label_counter = Counter()
        char_length_max = 0
        count = 0

        def pack_sentence(sent_words, sent_chars, sent_labels, longest):
            """Build one instance, padding char lists to ``longest``."""
            if not char:
                return [sent_words, sent_labels]
            chars_padded = []
            for char_list in sent_chars:
                missing = longest - len(char_list)
                if missing > 0:
                    char_list = char_list + [char_padding_symbol] * missing
                    char_counter[char_padding_symbol] += missing
                chars_padded.append(char_list)
                assert len(char_list) == longest
            return [sent_words, chars_padded, sent_labels]

        with open(file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                if run_insts == count:
                    break
                if len(raw_line) > 2:
                    fields = raw_line.strip().split(' ')
                    if fields[0] != 'token':
                        continue
                    word = fields[1]
                    if self.number_normalized:
                        word = utils.normalize_word(word)
                    label = fields[-1]
                    words.append(word)
                    labels.append(label)
                    word_counter[word] += 1
                    label_counter[label] += 1

                    if char:
                        # The loop variable must not reuse the name `char`:
                        # that would overwrite the boolean flag.
                        char_list = []
                        for ch in word:
                            char_list.append(ch)
                            char_counter[ch] += 1
                        chars.append(char_list)
                        char_length = len(char_list)
                        if char_length > char_length_max:
                            char_length_max = char_length
                        if char_length_max > self.max_char_length:
                            self.max_char_length = char_length_max
                else:
                    # Separator line: close the current sentence.
                    insts.append(
                        pack_sentence(words, chars, labels, char_length_max))
                    words = []
                    chars = []
                    labels = []
                    char_length_max = 0
                    count += 1

        if words:
            # File ended without a separator line: keep the open sentence.
            # (After a run_insts break the buffers are always empty.)
            insts.append(pack_sentence(words, chars, labels, char_length_max))

        if not self.word_alphabet.fix_flag:
            self.build_alphabet(word_counter, char_counter, label_counter,
                                shrink_feature_threshold, char)

        insts_index = []
        for inst in insts:
            words_index = [self.word_alphabet.get_index(w) for w in inst[0]]
            labels_index = [
                self.label_alphabet.get_index(lab) for lab in inst[-1]
            ]
            if char:
                chars_index = [[
                    self.char_alphabet.get_index(c) for c in char_list
                ] for char_list in inst[1]]
                insts_index.append([words_index, chars_index, labels_index])
            else:
                insts_index.append([words_index, labels_index])

        return insts, insts_index

    def build_word_pretrain_emb(self, emb_path, word_dims):
        """Load pretrained word vectors and keep them on the instance.

        Delegates to ``utils.load_pretrained_emb_avg``, handing it the word
        vocabulary (``self.word_alphabet.word2id``) and the
        ``self.norm_word_emb`` flag, and stores whatever it returns in
        ``self.pretrain_word_embedding``.

        Args:
            emb_path: location of the pretrained embedding file.
            word_dims: embedding dimensionality forwarded to the loader.
        """
        vocab = self.word_alphabet.word2id
        normalize = self.norm_word_emb
        embedding = utils.load_pretrained_emb_avg(emb_path, vocab, word_dims,
                                                  normalize)
        self.pretrain_word_embedding = embedding

    def build_char_pretrain_emb(self, emb_path, char_dims):
        """Load pretrained character vectors and keep them on the instance.

        Delegates to ``utils.load_pretrained_emb_avg``, handing it the
        character vocabulary (``self.char_alphabet.word2id``) and the
        ``self.norm_char_emb`` flag, and stores whatever it returns in
        ``self.pretrain_char_embedding``.

        Args:
            emb_path: location of the pretrained embedding file.
            char_dims: embedding dimensionality forwarded to the loader.
        """
        vocab = self.char_alphabet.word2id
        normalize = self.norm_char_emb
        embedding = utils.load_pretrained_emb_avg(emb_path, vocab, char_dims,
                                                  normalize)
        self.pretrain_char_embedding = embedding

    def generate_batch_buckets(self, batch_size, insts, char=False):
        """Split indexed instances into padded batches ("buckets").

        Consecutive runs of ``batch_size`` instances form one batch (the last
        batch may be shorter).  Each batch is reordered with
        ``utils.sorted_instances_index`` and padded to the length of the first
        instance of that ordering.

        Args:
            batch_size: number of instances per batch.
            insts: list of indexed instances; ``inst[0]`` holds the word ids,
                ``inst[-1]`` the label ids and, when ``char`` is true,
                ``inst[1]`` the per-word character ids.
            char: whether to emit a padded character slot.

        Returns:
            ``(buckets, labels_raw)``.  Each bucket is
            ``[words, labels, mask]`` or, with ``char``,
            ``[words, labels, chars, mask]``.  ``labels_raw[i]`` holds the
            unpadded label ids of batch ``i`` in the sorted order.
        """
        batch_num = int(math.ceil(len(insts) / batch_size))
        slot_num = 4 if char else 3
        buckets = [[[] for _ in range(slot_num)] for _ in range(batch_num)]
        labels_raw = [[] for _ in range(batch_num)]
        for batch_idx in range(batch_num):
            begin = batch_idx * batch_size
            self._fill_bucket(buckets[batch_idx], labels_raw[batch_idx],
                              insts[begin:begin + batch_size], char)
        return buckets, labels_raw

    def _fill_bucket(self, bucket, raw_labels, batch, char):
        """Sort one non-empty batch and append its padded rows to ``bucket``."""
        inst_sorted = utils.sorted_instances_index(batch)
        # Assumes sorted_instances_index returns the batch longest-first, so
        # the first instance defines the padded length -- TODO confirm.
        max_length = len(inst_sorted[0][0])
        word_pad = self.word_alphabet.word2id['<pad>']
        label_pad = self.label_alphabet.word2id['<pad>']
        if char:
            char_pad = self.char_alphabet.word2id['<pad>']
            char_width = self.max_char_length
        for inst in inst_sorted:
            cur_length = len(inst[0])
            pad_num = max_length - cur_length
            bucket[0].append(inst[0] + [word_pad] * pad_num)
            bucket[1].append(inst[-1] + [label_pad] * pad_num)
            if char:
                # Pad every word by its own width into a fresh list: the
                # caller's instance is left untouched and an empty sentence
                # no longer raises IndexError.
                padded_chars = [
                    ele + [char_pad] * (char_width - len(ele))
                    for ele in inst[1]
                ]
                # Build independent rows for the padding words instead of
                # repeating one shared list object.
                padded_chars.extend(
                    [char_pad] * char_width for _ in range(pad_num))
                bucket[2].append(padded_chars)
            bucket[-1].append([1] * cur_length + [0] * pad_num)
            raw_labels.append(inst[-1])

    def generate_batch_buckets_save(self, batch_size, insts, char=False):
        """Build padded batches with a ``<start>`` label prepended.

        Instances are batched in the order given (no sorting happens here).
        Every row of a batch is padded to one more than the longest word
        sequence in that batch; the extra position accounts for the
        ``<start>`` id placed in front of the labels.

        Args:
            batch_size: number of instances per batch.
            insts: list of indexed instances; ``inst[0]`` holds the word ids,
                ``inst[-1]`` the label ids and, when ``char`` is true,
                ``inst[1]`` the per-word character ids.
            char: whether to emit a padded character slot.

        Returns:
            A list of buckets, each ``[words, labels, mask]`` or, with
            ``char``, ``[words, labels, chars, mask]``.
        """
        batch_num = int(math.ceil(len(insts) / batch_size))
        slot_num = 4 if char else 3
        buckets = [[[] for _ in range(slot_num)] for _ in range(batch_num)]

        for batch_idx in range(batch_num):
            begin = batch_idx * batch_size
            batch = insts[begin:begin + batch_size]
            bucket = buckets[batch_idx]
            # Use the longest instance of the batch rather than the first
            # one, so unsorted input still yields rectangular batches.
            max_length = max(len(inst[0]) for inst in batch) + 1
            word_pad = self.word_alphabet.word2id['<pad>']
            label_pad = self.label_alphabet.word2id['<pad>']
            label_start = self.label_alphabet.word2id['<start>']
            for inst in batch:
                cur_length = len(inst[0])
                pad_num = max_length - cur_length
                bucket[0].append(inst[0] + [word_pad] * pad_num)
                bucket[1].append([label_start] + inst[-1] +
                                 [label_pad] * (pad_num - 1))
                if char:
                    char_pad = self.char_alphabet.word2id['<pad>']
                    # Padding rows take the width of the sentence's first
                    # word; an empty sentence gets zero-width rows.
                    char_length = len(inst[1][0]) if inst[1] else 0
                    bucket[2].append(
                        inst[1] +
                        [[char_pad] * char_length for _ in range(pad_num)])
                bucket[-1].append([1] * (cur_length + 1) + [0] *
                                  (pad_num - 1))

        return buckets
# Example no. 3
# 0
class Data():
    """Alphabets, corpus readers and batch builder for an annotated corpus.

    Holds one ``Alphabet`` per feature (word, category, label, char, pos,
    parent, rel), reads space-separated corpus files into symbolic and
    indexed instances, and groups indexed instances into padded batches.
    """

    def __init__(self):
        self.word_alphabet = Alphabet('word')
        self.category_alphabet = Alphabet('category', is_category=True)
        self.label_alphabet = Alphabet('label', is_label=True)
        self.char_alphabet = Alphabet('char')
        self.pos_alphabet = Alphabet('pos')
        self.parent_alphabet = Alphabet('parent')
        self.rel_alphabet = Alphabet('rel')

        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False

        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None

        self.max_char_length = 0

        self.word_num = 0
        self.char_num = 0
        self.label_num = 0
        self.category_num = 0
        self.parent_num = 0
        self.pos_num = 0
        self.rel_num = 0

    def build_alphabet(self, word_counter, label_counter, category_counter, pos_counter, rel_counter, shrink_feature_threshold, char=False):
        """Populate the alphabets from frequency counters.

        Words are added only when their count exceeds
        ``shrink_feature_threshold``; labels, categories, POS tags and
        relations are always added.  ``char`` is accepted but not used here.
        """
        for word, count in word_counter.most_common():
            if count > shrink_feature_threshold:
                self.word_alphabet.add(word, count)
        for label, count in label_counter.most_common():
            self.label_alphabet.add(label, count)
        for category, count in category_counter.most_common():
            self.category_alphabet.add(category, count)
        for pos, count in pos_counter.most_common():
            self.pos_alphabet.add(pos, count)
        for rel, count in rel_counter.most_common():
            self.rel_alphabet.add(rel, count)

    def fix_alphabet(self):
        """Close every alphabet and record the value each ``close()`` returns."""
        self.word_num = self.word_alphabet.close()
        self.category_num = self.category_alphabet.close()
        self.label_num = self.label_alphabet.close()
        self.parent_num = self.parent_alphabet.close()
        self.pos_num = self.pos_alphabet.close()
        self.rel_num = self.rel_alphabet.close()

    def get_instance(self, file, run_insts, shrink_feature_threshold, char=False, char_padding_symbol='<pad>'):
        """Read a corpus file into symbolic and indexed sentence instances.

        Every line longer than two characters is one token with
        space-separated fields: word, label, category, pos, ..., parent, rel
        (six fields are expected; the index of any other line is printed).
        A shorter line ends the current sentence.  A sentence still open at
        end of file is kept as well.

        Args:
            file: path of the corpus file (UTF-8).
            run_insts: stop after this many sentences; a value that is never
                reached (e.g. ``-1``) reads the whole file.
            shrink_feature_threshold: forwarded to ``build_alphabet``.
            char: forwarded to ``build_alphabet``.
            char_padding_symbol: unused.

        Returns:
            ``(insts, insts_index)``: per sentence
            ``[words, labels, categorys, poss, parents, rels]`` as strings and
            as indices.  Parent indices are ``int(parent) - 1``.
        """
        words, labels, categorys = [], [], []
        poss, parents, rels = [], [], []
        insts = []
        word_counter = Counter()
        label_counter = Counter()
        category_counter = Counter()
        pos_counter = Counter()
        rel_counter = Counter()

        count = 0
        with open(file, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f):
                if run_insts == count:
                    break
                if len(line) > 2:
                    fields = line.strip().split(' ')
                    if '' in fields:
                        fields.remove('')
                    if len(fields) != 6:
                        print(line_no)
                    word = fields[0]
                    if self.number_normalized:
                        word = utils.normalize_word(word)
                    label = fields[1]
                    category = fields[2]
                    pos = fields[3]
                    parent = fields[-2]
                    if ',' in parent:
                        # Several heads listed: keep the first one only.
                        parent = parent.split(',')[0]
                    rel = fields[-1]

                    words.append(word)
                    labels.append(label)
                    categorys.append(category)
                    poss.append(pos)
                    parents.append(parent)
                    rels.append(rel)

                    word_counter[word] += 1
                    label_counter[label] += 1
                    category_counter[category] += 1
                    pos_counter[pos] += 1
                    rel_counter[rel] += 1
                else:
                    insts.append([words, labels, categorys, poss, parents, rels])
                    words, labels, categorys = [], [], []
                    poss, parents, rels = [], [], []
                    count += 1
        if words:
            # The file did not end with a separator line: keep the last
            # sentence instead of silently dropping it.
            insts.append([words, labels, categorys, poss, parents, rels])

        if not self.word_alphabet.fix_flag:
            self.build_alphabet(word_counter, label_counter, category_counter, pos_counter, rel_counter, shrink_feature_threshold, char)

        insts_index = []
        for inst in insts:
            words_index = [self.word_alphabet.get_index(w) for w in inst[0]]
            labels_index = [self.label_alphabet.get_index(l) for l in inst[1]]
            categorys_index = [self.category_alphabet.get_index(c) for c in inst[2]]
            poss_index = [self.pos_alphabet.get_index(p) for p in inst[3]]
            parents_index = [int(p) - 1 for p in inst[-2]]
            rels_index = [self.rel_alphabet.get_index(r) for r in inst[-1]]
            insts_index.append([words_index, labels_index, categorys_index, poss_index, parents_index, rels_index])
        return insts, insts_index

    def get_instance_tree(self, file, run_insts, shrink_feature_threshold, write_word_path, label_path, pos_path, parse_path, category_path, char=False, char_padding_symbol='<pad>', vocab_path=r"C:\Users\song\Desktop\treelstm_word\examples\vocab-2.txt"):
        """Read a corpus file and dump its columns to separate text files.

        Token lines carry space-separated fields: word, label, category, ...,
        pos, parse (pos and parse are the last two).  A line of at most two
        characters ends the sentence; a sentence still open at end of file is
        kept.  One line per sentence is written to each output file: words
        and parses as strings, labels / POS tags / categories as indices.
        The word vocabulary is written to ``vocab_path``.

        Args:
            file: path of the corpus file (UTF-8).
            run_insts: stop after this many sentences.
            shrink_feature_threshold: forwarded to ``build_alphabet``.
            write_word_path, label_path, pos_path, parse_path, category_path:
                output file paths (overwritten).
            char: forwarded to ``build_alphabet``.
            char_padding_symbol: unused.
            vocab_path: where the vocabulary is written; defaults to the
                location that used to be hard-coded.

        Returns:
            ``(insts, insts_index)``: per sentence
            ``[words, labels, categorys, poss, parses]`` as strings and as
            indices (label / category / POS indices are strings).
        """
        words, labels, categorys, poss, parses = [], [], [], [], []
        insts = []
        word_counter = Counter()
        label_counter = Counter()
        category_counter = Counter()
        pos_counter = Counter()
        parse_counter = Counter()

        count = 0
        with open(file, 'r', encoding='utf-8') as f:
            for line in f:
                if run_insts == count:
                    break
                if len(line) > 2:
                    fields = line.strip().split(' ')
                    word = fields[0]
                    if self.number_normalized:
                        word = utils.normalize_word(word)
                    label = fields[1]
                    category = fields[2]
                    pos = fields[-2]
                    parse = fields[-1]

                    words.append(word)
                    labels.append(label)
                    categorys.append(category)
                    parses.append(parse)
                    poss.append(pos)

                    word_counter[word] += 1
                    label_counter[label] += 1
                    category_counter[category] += 1
                    pos_counter[pos] += 1
                    # Count the parse symbol itself (the category used to be
                    # counted here by mistake).
                    parse_counter[parse] += 1
                else:
                    insts.append([words, labels, categorys, poss, parses])
                    words, labels, categorys, poss, parses = [], [], [], [], []
                    count += 1
        if words:
            insts.append([words, labels, categorys, poss, parses])

        if not self.word_alphabet.fix_flag:
            # build_alphabet has no parse slot: the parse counts go through
            # its rel_counter parameter and therefore land in rel_alphabet.
            self.build_alphabet(word_counter, label_counter, category_counter, pos_counter, parse_counter, shrink_feature_threshold, char)

        with open(vocab_path, 'w', encoding='utf-8') as vocab_file:
            id2word = self.word_alphabet.id2word
            for idx in range(len(id2word)):
                vocab_file.write(id2word[idx])
                vocab_file.write('\n')

        with open(write_word_path, 'w', encoding='utf-8') as word_file, \
                open(parse_path, 'w', encoding='utf-8') as parse_file:
            for inst in insts:
                word_file.write(' '.join(inst[0]))
                word_file.write('\n')
                parse_file.write(' '.join(inst[-1]))
                parse_file.write('\n')

        insts_index = []
        with open(label_path, 'w', encoding='utf-8') as label_file, \
                open(pos_path, 'w', encoding='utf-8') as pos_file, \
                open(category_path, 'w', encoding='utf-8') as category_file:
            for inst in insts:
                words_index = [self.word_alphabet.get_index(w) for w in inst[0]]
                labels_index = [str(self.label_alphabet.get_index(l)) for l in inst[1]]
                categorys_index = [str(self.category_alphabet.get_index(c)) for c in inst[2]]
                pos_index = [str(self.pos_alphabet.get_index(p)) for p in inst[-2]]
                # Parse symbols are stored in rel_alphabet (see above); this
                # class never defines a parse_alphabet.
                parses_index = [self.rel_alphabet.get_index(p) for p in inst[-1]]
                insts_index.append([words_index, labels_index, categorys_index, pos_index, parses_index])

                label_file.write(' '.join(labels_index))
                label_file.write('\n')
                pos_file.write(' '.join(pos_index))
                pos_file.write('\n')
                category_file.write(' '.join(categorys_index))
                category_file.write('\n')
        return insts, insts_index

    def build_word_pretrain_emb(self, emb_path, word_dims):
        """Load pretrained word vectors into ``self.pretrain_word_embedding``."""
        self.pretrain_word_embedding = utils.load_pretrained_emb_avg(emb_path, self.word_alphabet.word2id, word_dims, self.norm_word_emb)

    def build_char_pretrain_emb(self, emb_path, char_dims):
        """Load pretrained char vectors into ``self.pretrain_char_embedding``."""
        self.pretrain_char_embedding = utils.load_pretrained_emb_avg(emb_path, self.char_alphabet.word2id, char_dims, self.norm_char_emb)

    def generate_batch_buckets(self, batch_size, insts, char=False):
        """Split indexed instances into sorted, padded batches.

        Consecutive runs of ``batch_size`` instances form one batch (the last
        may be shorter).  Each batch is reordered with
        ``utils.sorted_instances_index`` and padded to the length of the
        first instance of that ordering.  ``char`` is unused.

        Returns:
            ``(buckets, labels_raw, category_raw, target_start, target_end)``.
            Each bucket is ``[words, labels, pos, parents, rels, mask]``.
            The other four lists hold, per batch and per instance: the
            unpadded label ids, the first category id, and the first start /
            end position reported by ``evaluation.extract_target`` (``0``
            when it reports none).
        """
        batch_num = int(math.ceil(len(insts) / batch_size))
        buckets = [[[] for _ in range(6)] for _ in range(batch_num)]
        labels_raw = [[] for _ in range(batch_num)]
        category_raw = [[] for _ in range(batch_num)]
        target_start = [[] for _ in range(batch_num)]
        target_end = [[] for _ in range(batch_num)]

        for batch_idx in range(batch_num):
            begin = batch_idx * batch_size
            self._fill_bucket(insts[begin:begin + batch_size],
                              buckets[batch_idx], labels_raw[batch_idx],
                              category_raw[batch_idx],
                              target_start[batch_idx], target_end[batch_idx])
        return buckets, labels_raw, category_raw, target_start, target_end

    def _fill_bucket(self, batch, bucket, labels_raw, category_raw, target_start, target_end):
        """Sort one non-empty batch and append its padded rows and targets."""
        inst_sorted = utils.sorted_instances_index(batch)
        # Assumes sorted_instances_index returns the batch longest-first, so
        # the first instance defines the padded length -- TODO confirm.
        max_length = len(inst_sorted[0][0])
        word_pad = self.word_alphabet.word2id['<pad>']
        label_pad = self.label_alphabet.word2id['<pad>']
        pos_pad = self.pos_alphabet.word2id['<pad>']
        parent_pad = self.parent_alphabet.word2id['<pad>']
        rel_pad = self.rel_alphabet.word2id['<pad>']
        for inst in inst_sorted:
            cur_length = len(inst[0])
            pad_num = max_length - cur_length
            bucket[0].append(inst[0] + [word_pad] * pad_num)
            bucket[1].append(inst[1] + [label_pad] * pad_num)
            bucket[2].append(inst[-3] + [pos_pad] * pad_num)
            bucket[3].append(inst[-2] + [parent_pad] * pad_num)
            bucket[4].append(inst[-1] + [rel_pad] * pad_num)
            bucket[-1].append([1] * cur_length + [0] * pad_num)
            labels_raw.append(inst[1])

            start, end = evaluation.extract_target(inst[1], self.label_alphabet)
            if start == []:
                # No target found: fall back to position 0.
                start = [0]
                end = [0]
            target_start.append(start[0])
            target_end.append(end[0])
            category_raw.append(inst[2][0])

    def generate_batch_buckets_save(self, batch_size, insts, char=False):
        """Build padded batches with a ``<start>`` label prepended.

        Instances are batched in the order given (no sorting happens here).
        Every row of a batch is padded to one more than the longest word
        sequence in that batch; the extra position accounts for the
        ``<start>`` id placed in front of ``inst[-1]``.

        NOTE(review): this reads ``inst[-1]`` as labels and ``inst[1]`` as
        characters, which does not match the instance layout produced by
        ``get_instance`` in this class -- confirm before using.

        Returns:
            A list of buckets, each ``[words, labels, mask]`` or, with
            ``char``, ``[words, labels, chars, mask]``.
        """
        batch_num = int(math.ceil(len(insts) / batch_size))
        slot_num = 4 if char else 3
        buckets = [[[] for _ in range(slot_num)] for _ in range(batch_num)]

        for batch_idx in range(batch_num):
            begin = batch_idx * batch_size
            batch = insts[begin:begin + batch_size]
            bucket = buckets[batch_idx]
            # Use the longest instance of the batch rather than the first
            # one, so unsorted input still yields rectangular batches.
            max_length = max(len(inst[0]) for inst in batch) + 1
            word_pad = self.word_alphabet.word2id['<pad>']
            label_pad = self.label_alphabet.word2id['<pad>']
            label_start = self.label_alphabet.word2id['<start>']
            for inst in batch:
                cur_length = len(inst[0])
                pad_num = max_length - cur_length
                bucket[0].append(inst[0] + [word_pad] * pad_num)
                bucket[1].append([label_start] + inst[-1] + [label_pad] * (pad_num - 1))
                if char:
                    char_pad = self.char_alphabet.word2id['<pad>']
                    char_length = len(inst[1][0]) if inst[1] else 0
                    bucket[2].append(inst[1] + [[char_pad] * char_length for _ in range(pad_num)])
                bucket[-1].append([1] * (cur_length + 1) + [0] * (pad_num - 1))

        return buckets
# Example no. 4
# 0
class Data():
    """Alphabets, corpus reader and batch builder for a target-tagging corpus.

    Holds word / category / label / char ``Alphabet`` objects, reads a
    two-column corpus file into symbolic and indexed instances, and groups
    indexed instances into padded batches.
    """

    def __init__(self):
        self.word_alphabet = Alphabet('word')
        self.category_alphabet = Alphabet('category', is_category=True)
        self.label_alphabet = Alphabet('label', is_label=True)
        self.char_alphabet = Alphabet('char')

        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False

        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None

        self.max_char_length = 0

        self.word_num = 0
        self.char_num = 0
        self.label_num = 0
        # Set by fix_alphabet(); initialised here so it always exists.
        self.category_num = 0

    def build_alphabet(self,
                       word_counter,
                       label_counter,
                       category_counter,
                       shrink_feature_threshold,
                       char=False):
        """Populate the word, label and category alphabets from counters.

        Words are added only when their count exceeds
        ``shrink_feature_threshold``; labels and categories are always added.
        Afterwards the internal tables of each alphabet are compared and a
        message is printed when their sizes disagree.  ``char`` is unused.
        """
        for word, count in word_counter.most_common():
            if count > shrink_feature_threshold:
                self.word_alphabet.add(word, count)
        for label, count in label_counter.most_common():
            self.label_alphabet.add(label, count)
        for category, count in category_counter.most_common():
            self.category_alphabet.add(category, count)

        # Sanity check: word2id, id2word and id2count must have equal sizes.
        checked = (('word', self.word_alphabet),
                   ('label', self.label_alphabet),
                   ('category', self.category_alphabet))
        for name, alphabet in checked:
            size = len(alphabet.id2word)
            if len(alphabet.word2id) != size or len(
                    alphabet.id2count) != size:
                print('there are errors in building %s alphabet.' % name)

    def fix_alphabet(self):
        """Close every alphabet and record the value each ``close()`` returns."""
        self.word_num = self.word_alphabet.close()
        self.category_num = self.category_alphabet.close()
        self.label_num = self.label_alphabet.close()

    def get_instance(self,
                     file,
                     run_insts,
                     shrink_feature_threshold,
                     char=False,
                     char_padding_symbol='<pad>'):
        """Read a corpus file into symbolic and indexed sentence instances.

        Every line longer than two characters is one token: a word and a tag
        separated by a space.  A tag longer than one character is split on
        ``-`` into a label and a category; a one-character tag is the label
        itself.  ``'-target'`` is appended to every label.  A shorter line
        ends the current sentence; a sentence still open at end of file is
        kept as well.  When a sentence contains more than one ``b`` label the
        count is printed.

        Args:
            file: path of the corpus file (UTF-8).
            run_insts: stop after this many sentences; a value that is never
                reached (e.g. ``-1``) reads the whole file.
            shrink_feature_threshold: forwarded to ``build_alphabet``.
            char: forwarded to ``build_alphabet``.
            char_padding_symbol: unused.

        Returns:
            ``(insts, insts_index)``.  ``insts`` holds
            ``[words, labels, categorys]`` per sentence; ``insts_index`` holds
            ``[words_index, labels_index, categorys_index]`` where the
            category index of the sentence's first category is repeated once
            per token.
        """
        words = []
        labels = []
        categorys = []
        insts = []
        word_counter = Counter()
        label_counter = Counter()
        category_counter = Counter()

        count = 0
        ner_num = 0
        with open(file, 'r', encoding='utf-8') as f:
            for line in f:
                if run_insts == count:
                    break
                if len(line) > 2:
                    fields = line.strip().split(' ')
                    word = fields[0]
                    if self.number_normalized:
                        word = utils.normalize_word(word)
                    tag = fields[1]
                    if len(tag) > 1:
                        parts = tag.split('-')
                        label = parts[0]
                        if label == 'b':
                            ner_num += 1
                        category = parts[1]
                        categorys.append(category)
                        category_counter[category] += 1
                    else:
                        label = tag
                    label = label + '-target'
                    words.append(word)
                    labels.append(label)
                    word_counter[word] += 1
                    label_counter[label] += 1
                else:
                    insts.append([words, labels, categorys])
                    words = []
                    labels = []
                    categorys = []
                    count += 1
                    if ner_num > 1:
                        print(ner_num)
                    ner_num = 0
        if words:
            # The file did not end with a separator line: keep the last
            # sentence instead of silently dropping it.
            insts.append([words, labels, categorys])
            if ner_num > 1:
                print(ner_num)

        if not self.word_alphabet.fix_flag:
            self.build_alphabet(word_counter, label_counter, category_counter,
                                shrink_feature_threshold, char)

        insts_index = []
        for inst in insts:
            words_index = [self.word_alphabet.get_index(w) for w in inst[0]]
            labels_index = [self.label_alphabet.get_index(l) for l in inst[1]]
            # NOTE(review): a sentence without any category makes
            # inst[-1][0] raise IndexError -- confirm the corpus guarantees
            # at least one tagged token per sentence.
            category_index = self.category_alphabet.get_index(inst[-1][0])
            categorys_index = [category_index] * len(labels_index)
            insts_index.append([words_index, labels_index, categorys_index])

        return insts, insts_index

    def build_word_pretrain_emb(self, emb_path, word_dims):
        """Load pretrained word vectors into ``self.pretrain_word_embedding``."""
        self.pretrain_word_embedding = utils.load_pretrained_emb_avg(
            emb_path, self.word_alphabet.word2id, word_dims,
            self.norm_word_emb)

    def build_char_pretrain_emb(self, emb_path, char_dims):
        """Load pretrained char vectors into ``self.pretrain_char_embedding``."""
        self.pretrain_char_embedding = utils.load_pretrained_emb_avg(
            emb_path, self.char_alphabet.word2id, char_dims,
            self.norm_char_emb)

    def generate_batch_buckets(self, batch_size, insts, char=False):
        """Split indexed instances into sorted, padded batches.

        Consecutive runs of ``batch_size`` instances form one batch (the last
        may be shorter).  Each batch is reordered with
        ``utils.sorted_instances_index`` and padded to the length of the
        first instance of that ordering.  ``char`` is unused.

        Returns:
            ``(buckets, labels_raw, category_raw, target_start, target_end)``.
            Each bucket is ``[words, labels, mask]``.  The other four lists
            hold, per batch and per instance: the unpadded label ids, the
            first category id, and the first start / end position reported by
            ``evaluation.extract_target`` (``0`` when it reports none).
        """
        batch_num = int(math.ceil(len(insts) / batch_size))
        buckets = [[[], [], []] for _ in range(batch_num)]
        labels_raw = [[] for _ in range(batch_num)]
        category_raw = [[] for _ in range(batch_num)]
        target_start = [[] for _ in range(batch_num)]
        target_end = [[] for _ in range(batch_num)]

        for batch_idx in range(batch_num):
            begin = batch_idx * batch_size
            self._fill_bucket(insts[begin:begin + batch_size],
                              buckets[batch_idx], labels_raw[batch_idx],
                              category_raw[batch_idx],
                              target_start[batch_idx], target_end[batch_idx])
        return buckets, labels_raw, category_raw, target_start, target_end

    def _fill_bucket(self, batch, bucket, labels_raw, category_raw,
                     target_start, target_end):
        """Sort one non-empty batch and append its padded rows and targets."""
        inst_sorted = utils.sorted_instances_index(batch)
        # Assumes sorted_instances_index returns the batch longest-first, so
        # the first instance defines the padded length -- TODO confirm.
        max_length = len(inst_sorted[0][0])
        word_pad = self.word_alphabet.word2id['<pad>']
        label_pad = self.label_alphabet.word2id['<pad>']
        for inst in inst_sorted:
            cur_length = len(inst[0])
            pad_num = max_length - cur_length
            bucket[0].append(inst[0] + [word_pad] * pad_num)
            bucket[1].append(inst[1] + [label_pad] * pad_num)
            bucket[-1].append([1] * cur_length + [0] * pad_num)
            labels_raw.append(inst[1])

            start, end = evaluation.extract_target(inst[1],
                                                   self.label_alphabet)
            if start == []:
                # No target found: fall back to position 0 instead of
                # failing on start[0].
                start = [0]
                end = [0]
            target_start.append(start[0])
            target_end.append(end[0])
            category_raw.append(inst[-1][0])

    def generate_batch_buckets_save(self, batch_size, insts, char=False):
        """Build padded batches with a ``<start>`` label prepended.

        Instances are batched in the order given (no sorting happens here).
        Every row of a batch is padded to one more than the longest word
        sequence in that batch; the extra position accounts for the
        ``<start>`` id placed in front of ``inst[-1]``.

        NOTE(review): this reads ``inst[-1]`` as labels and ``inst[1]`` as
        characters, which does not match the instance layout produced by
        ``get_instance`` in this class -- confirm before using.

        Returns:
            A list of buckets, each ``[words, labels, mask]`` or, with
            ``char``, ``[words, labels, chars, mask]``.
        """
        batch_num = int(math.ceil(len(insts) / batch_size))
        slot_num = 4 if char else 3
        buckets = [[[] for _ in range(slot_num)] for _ in range(batch_num)]

        for batch_idx in range(batch_num):
            begin = batch_idx * batch_size
            batch = insts[begin:begin + batch_size]
            bucket = buckets[batch_idx]
            # Use the longest instance of the batch rather than the first
            # one, so unsorted input still yields rectangular batches.
            max_length = max(len(inst[0]) for inst in batch) + 1
            word_pad = self.word_alphabet.word2id['<pad>']
            label_pad = self.label_alphabet.word2id['<pad>']
            label_start = self.label_alphabet.word2id['<start>']
            for inst in batch:
                cur_length = len(inst[0])
                pad_num = max_length - cur_length
                bucket[0].append(inst[0] + [word_pad] * pad_num)
                bucket[1].append([label_start] + inst[-1] +
                                 [label_pad] * (pad_num - 1))
                if char:
                    char_pad = self.char_alphabet.word2id['<pad>']
                    char_length = len(inst[1][0]) if inst[1] else 0
                    bucket[2].append(
                        inst[1] +
                        [[char_pad] * char_length for _ in range(pad_num)])
                bucket[-1].append([1] * (cur_length + 1) + [0] *
                                  (pad_num - 1))

        return buckets