Example #1
class Model(base.Model):
    # overwrite config
    name = 'char2word_to_char/europarl-ende'
    # datasets
    train_x_files = ['data/train/europarl-v7.de-en.en']
    train_t_files = ['data/train/europarl-v7.de-en.de']
    valid_x_files = ['data/valid/devtest2006.en']
    valid_t_files = ['data/valid/devtest2006.de']
    test_x_files = ['data/valid/test2007.en']
    test_t_files = ['data/valid/test2007.de']

    # settings that are local to the model
    alphabet_src = Alphabet('data/alphabet/dict_europarl.de-en.en', eos='*')
    alphabet_tar = Alphabet('data/alphabet/dict_europarl.de-en.de', eos='*', sos='')
Example #2
class Model(base.Model):
    # overwrite config
    name = 'char2word_to_char/wmt-deen'
    train_x_files = [
        'data/train/europarl-v7.de-en.de.tok',
        'data/train/commoncrawl.de-en.de.tok',
        'data/train/news-commentary-v10.de-en.de.tok'
    ]
    train_t_files = [
        'data/train/europarl-v7.de-en.en.tok',
        'data/train/commoncrawl.de-en.en.tok',
        'data/train/news-commentary-v10.de-en.en.tok'
    ]
    valid_x_files = ['data/valid/newstest2013.de.tok']
    valid_t_files = ['data/valid/newstest2013.en.tok']
    test_x_files = ['data/valid/newstest2014.deen.de.tok']
    test_t_files = ['data/valid/newstest2014.deen.en.tok']

    # settings that are local to the model
    alphabet_src = Alphabet('data/alphabet/dict_wmt_tok.de-en.de', eos='*')
    alphabet_tar = Alphabet('data/alphabet/dict_wmt_tok.de-en.en',
                            eos='*',
                            sos='')
Example #3
    def __init__(self):
        self.word_alphabet = Alphabet('word')
        self.category_alphabet = Alphabet('category', is_category=True)
        self.label_alphabet = Alphabet('label', is_label=True)
        self.char_alphabet = Alphabet('char')

        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False

        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None

        self.max_char_length = 0

        self.word_num = 0
        self.char_num = 0
        self.label_num = 0

class Data():
    def __init__(self):
        self.word_alphabet = Alphabet('word')
        self.char_alphabet = Alphabet('char')
        self.bichar_alphabet = Alphabet('bichar')

        self.pos_alphabet = Alphabet('pos', is_label=True)
        self.char_type_alphabet = Alphabet('type', is_label=True)

        self.extchar_alphabet = Alphabet('extchar')
        self.extbichar_alphabet = Alphabet('extbichar')

        self.segpos_alphabet = Alphabet('segpos', is_label=True)
        self.wordlen_alphabet = Alphabet('wordlen')

        self.number_normalized = False
        self.norm_word_emb = False
        self.norm_char_emb = False
        self.norm_bichar_emb = False

        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None
        self.pretrain_bichar_embedding = None

        self.wordPadID = 0
        self.charPadID = 0
        self.bicharPadID = 0
        self.charTypePadID = 0
        self.wordlenPadID = 0
        self.posPadID = 0
        self.appID = 0
        self.actionPadID = 0

        self.word_num = 0
        self.char_num = 0
        self.pos_num = 0
        self.bichar_num = 0
        self.segpos_num = 0
        self.wordlen_num = 0
        self.char_type_num = 0

        self.extchar_num = 0
        self.extbichar_num = 0

        self.wordlen_max = 7

    def build_alphabet(self, word_counter, char_counter, extchar_counter,
                       bichar_counter, extbichar_counter, char_type_counter,
                       pos_counter, wordlen_counter, gold_counter,
                       shrink_feature_threshold):
        # could be optimized, but it would get messier
        for word, count in word_counter.most_common():
            # if count > shrink_feature_threshold:
            self.word_alphabet.add(word, count)
        for char, count in char_counter.most_common():
            if count > shrink_feature_threshold:
                self.char_alphabet.add(char, count)
        for extchar, count in extchar_counter.most_common():
            # if count > shrink_feature_threshold:
            self.extchar_alphabet.add(extchar, count)
        for bichar, count in bichar_counter.most_common():
            if count > shrink_feature_threshold:
                self.bichar_alphabet.add(bichar, count)
        for extbichar, count in extbichar_counter.most_common():
            # if count > shrink_feature_threshold:
            self.extbichar_alphabet.add(extbichar, count)
        for char_type, count in char_type_counter.most_common():
            # if count > shrink_feature_threshold:
            self.char_type_alphabet.add(char_type, count)
        for pos, count in pos_counter.most_common():
            # if count > shrink_feature_threshold:
            self.pos_alphabet.add(pos, count)
        for wordlen, count in wordlen_counter.most_common():
            # if count > shrink_feature_threshold:
            self.wordlen_alphabet.add(wordlen, count)
        for segpos, count in gold_counter.most_common():
            # if count > shrink_feature_threshold:
            self.segpos_alphabet.add(segpos, count)
        # another method
        # reverse = lambda x: dict(zip(x, range(len(x))))
        # self.word_alphabet.word2id = reverse(self.word_alphabet.id2word)
        # self.label_alphabet.word2id = reverse(self.label_alphabet.id2word)

        ##### check
        if len(self.word_alphabet.word2id) != len(
                self.word_alphabet.id2word) or len(
                    self.word_alphabet.id2count) != len(
                        self.word_alphabet.id2word):
            print('there are errors in building word alphabet.')
        if len(self.char_alphabet.word2id) != len(
                self.char_alphabet.id2word) or len(
                    self.char_alphabet.id2count) != len(
                        self.char_alphabet.id2word):
            print('there are errors in building char alphabet.')

    def fix_alphabet(self):
        self.word_num = self.word_alphabet.close()
        self.char_num = self.char_alphabet.close()
        self.pos_num = self.pos_alphabet.close()
        self.bichar_num = self.bichar_alphabet.close()
        self.segpos_num = self.segpos_alphabet.close()
        # self.wordlen_max = self.wordlen_alphabet.size()-2            ######
        # print(self.wordlen_max)
        # self.wordlen_alphabet.add(self.wordlen_max+1)
        self.wordlen_num = self.wordlen_alphabet.close()
        # print(self.wordlen_num)
        self.char_type_num = self.char_type_alphabet.close()

    def fix_static_alphabet(self):
        self.extchar_num = self.extchar_alphabet.close()
        self.extbichar_num = self.extbichar_alphabet.close()

    def get_instance(self, file, run_insts, shrink_feature_threshold):
        insts = []
        word_counter = Counter()
        char_counter = Counter()
        bichar_counter = Counter()
        char_type_counter = Counter()
        gold_counter = Counter()
        pos_counter = Counter()
        extchar_counter = Counter()
        extbichar_counter = Counter()
        wordlen_counter = Counter()
        count = 0
        with open(file, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                if run_insts == count: break
                line = line.strip().split(' ')
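                # each line holds space-separated tokens of the form "word_pos"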
                inst = Instance()
                char_num = 0
                start = 0
                for idx, ele in enumerate(line):
                    word = ele.split('_')[0]
                    if self.number_normalized:
                        word = utils.normalize_word(word)
                    inst.words.append(word)
                    word_counter[word] += 1
                    pos = ele.split('_')[1]
                    inst.gold_pos.append(pos)
                    pos_counter[pos] += 1
                    word_list = list(word)
                    # print(len(word_list))
                    # if len(word_list) > 5:
                    #     print(len(word_list))
                    #     print(word)
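                    # clip the word-length feature at 7 (any word longer than 6 characters maps to 7)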
                    if len(word_list) > 6:
                        # inst.word_len.append(7)
                        cur_len = 7
                    else:
                        # inst.word_len.append(len(word_list))
                        cur_len = len(word_list)
                    inst.word_len.append(cur_len)
                    wordlen_counter[cur_len] += 1
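                    # for each character: at a word start record the previous word's clipped
                    # length and pos tag; inside a word record the offset so far and the current pos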
                    for id, char in enumerate(word):
                        if idx == 0 and id == 0:
                            inst.last_wordlen.append('<pad>')
                            inst.last_pos.append('<pad>')
                        elif idx != 0 and id == 0:
                            last_wordlen_cur = len(inst.words[-2])
                            if last_wordlen_cur > 6: last_wordlen_cur = 7
                            inst.last_wordlen.append(
                                last_wordlen_cur)  ### because the current word and pos have already been appended to the lists
                            inst.last_pos.append(inst.gold_pos[-2])
                        else:
                            last_wordlen_cur = id
                            if last_wordlen_cur > 6: last_wordlen_cur = 7
                            inst.last_wordlen.append(last_wordlen_cur)
                            inst.last_pos.append(inst.gold_pos[-1])
                        inst.chars.append(char)
                        inst.extchars.append(char)
                        char_counter[char] += 1
                        extchar_counter[char] += 1

                        char_type = instance.char_type(char)
                        # print(char_type)
                        inst.char_type.append(char_type)
                        char_type_counter[char_type] += 1
                        if id == 0:
                            inst.gold_action.append(sep + '#' + pos)
                            gold_counter[sep + '#' + pos] += 1
                            start = char_num
                        else:
                            inst.gold_action.append(app)
                            gold_counter[app] += 1
                        char_num += 1
                    inst.word_seg.append('[' + str(start) + ',' +
                                         str(char_num) + ']')
                    inst.word_seg_pos.append('[' + str(start) + ',' +
                                             str(char_num) + ']' + pos)
                right_bichars = []

                if len(inst.chars) == 1:
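                    # with a single-character sentence, `char` still holds that character from the loop above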
                    inst.left_bichars.append(nullsymbol + char)
                    inst.extleft_bichars.append(nullsymbol + char)
                    right_bichars.append(char + nullsymbol)
                    bichar_counter[nullsymbol + char] += 1
                    bichar_counter[char + nullsymbol] += 1
                    extbichar_counter[nullsymbol + char] += 1
                    extbichar_counter[char + nullsymbol] += 1
                else:
                    for id, char in enumerate(inst.chars):
                        if id == 0:
                            inst.left_bichars.append(nullsymbol + char)
                            # inst.right_bichars.append(char+inst.chars[id+1])
                            inst.extleft_bichars.append(nullsymbol + char)
                            # inst.extright_bichars.append(char+inst.chars[id+1])
                            right_bichars.append(char + inst.chars[id + 1])
                            # right_index_mark.append(1)
                            bichar_counter[nullsymbol + char] += 1
                            bichar_counter[char + inst.chars[id + 1]] += 1
                            extbichar_counter[nullsymbol + char] += 1
                            extbichar_counter[char + inst.chars[id + 1]] += 1
                        elif id == (len(inst.chars) - 1):
                            inst.left_bichars.append(inst.chars[id - 1] + char)
                            # inst.right_bichars.append(char+nullsymbol)
                            inst.extleft_bichars.append(inst.chars[id - 1] +
                                                        char)
                            # inst.extright_bichars.append(char + nullsymbol)
                            right_bichars.append(char + nullsymbol)
                            # right_index_mark.append(1)
                            bichar_counter[inst.chars[id - 1] + char] += 1
                            bichar_counter[char + nullsymbol] += 1
                            extbichar_counter[inst.chars[id - 1] + char] += 1
                            extbichar_counter[char + nullsymbol] += 1
                        else:
                            inst.left_bichars.append(inst.chars[id - 1] + char)
                            # inst.right_bichars.append(char+inst.chars[id+1])
                            inst.extleft_bichars.append(inst.chars[id - 1] +
                                                        char)
                            # inst.extright_bichars.append(char + inst.chars[id + 1])
                            right_bichars.append(char + inst.chars[id + 1])
                            # right_index_mark.append(1)
                            bichar_counter[inst.chars[id - 1] + char] += 1
                            bichar_counter[char + inst.chars[id + 1]] += 1
                            extbichar_counter[inst.chars[id - 1] + char] += 1
                            extbichar_counter[char + inst.chars[id + 1]] += 1
                # right_bichars = list(reversed(right_bichars))         #####
                # print(right_bichars)
                # print(right_index_mark)
                # right_index_mark = list(reversed(right_index_mark))
                inst.right_bichars = right_bichars
                inst.extright_bichars = right_bichars
                # print(line)
                # print(inst.words)
                # print(inst.chars)
                # print(inst.char_type)
                # print(inst.word_seg)
                # print(inst.word_seg_pos)
                # print(inst.gold)
                # print(inst.pos)
                # print(inst.left_bichars)
                # print(inst.right_bichars)
                count += 1
                insts.append(inst)

        if not self.word_alphabet.fix_flag:
            self.build_alphabet(word_counter, char_counter, extchar_counter,
                                bichar_counter, extbichar_counter,
                                char_type_counter, pos_counter,
                                wordlen_counter, gold_counter,
                                shrink_feature_threshold)
        # insts_index = []

        for inst in insts:
            inst.words_index = [
                self.word_alphabet.get_index(w) for w in inst.words
            ]
            inst.chars_index = [
                self.char_alphabet.get_index(c) for c in inst.chars
            ]
            inst.extchars_index = [
                self.extchar_alphabet.get_index(ec) for ec in inst.extchars
            ]
            inst.left_bichar_index = [
                self.bichar_alphabet.get_index(b) for b in inst.left_bichars
            ]
            inst.right_bichar_index = [
                self.bichar_alphabet.get_index(b) for b in inst.right_bichars
            ]
            inst.extleft_bichar_index = [
                self.extbichar_alphabet.get_index(eb)
                for eb in inst.extleft_bichars
            ]
            inst.extright_bichar_index = [
                self.extbichar_alphabet.get_index(eb)
                for eb in inst.extright_bichars
            ]
            inst.pos_index = [
                self.pos_alphabet.get_index(p) for p in inst.gold_pos
            ]
            inst.char_type_index = [
                self.char_type_alphabet.get_index(t) for t in inst.char_type
            ]
            # inst.char_type_index = []
            # for t in inst.char_type:
            #     print(t)
            #     print(self.char_type_alphabet.word2id)
            #     temp = self.char_type_alphabet.get_index(t)
            #     print(temp)
            #     inst.char_type_index.append(temp)
            inst.segpos_index = [
                self.segpos_alphabet.get_index(g) for g in inst.gold_action
            ]
            inst.word_len_index = [
                self.wordlen_alphabet.get_index(w) for w in inst.word_len
            ]
            inst.last_wordlen_index = [
                self.wordlen_alphabet.get_index(l) for l in inst.last_wordlen
            ]
            inst.last_pos_index = [
                self.pos_alphabet.get_index(p) for p in inst.last_pos
            ]

        self.wordPadID = self.word_alphabet.get_index(pad)
        self.charPadID = self.char_alphabet.get_index(pad)
        self.bicharPadID = self.bichar_alphabet.get_index(pad)
        self.charTypePadID = self.char_type_alphabet.get_index(pad)
        self.wordlenPadID = self.wordlen_alphabet.get_index(pad)
        self.posPadID = self.pos_alphabet.get_index(pad)
        self.appID = self.segpos_alphabet.get_index(app)
        self.actionPadID = self.segpos_alphabet.get_index(pad)

        ##### sorted sentences
        # insts_sorted, insts_index_sorted = utils.sorted_instances(insts, insts_index)
        return insts

    def build_word_pretrain_emb(self, emb_path, word_dims):
        self.pretrain_word_embedding = utils.load_pretrained_emb_uniform(
            emb_path, self.word_alphabet.word2id, word_dims,
            self.norm_word_emb)

    # could be optimized
    def build_char_pretrain_emb(self, emb_path, char_dims):
        self.pretrain_char_embedding = utils.load_pretrained_emb_uniform(
            emb_path, self.extchar_alphabet.word2id, char_dims,
            self.norm_char_emb)

    # could be optimized
    def build_bichar_pretrain_emb(self, emb_path, bichar_dims):
        self.pretrain_bichar_embedding = utils.load_pretrained_emb_uniform(
            emb_path, self.extbichar_alphabet.word2id, bichar_dims,
            self.norm_bichar_emb)

    def generate_batch_buckets(self, batch_size, insts):
        # insts_length = list(map(lambda t: len(t) + 1, inst[0] for inst in insts))
        # insts_length = list(len(inst[0]+1) for inst in insts)
        # if len(insts) % batch_size == 0:
        #     batch_num = len(insts) // batch_size
        # else:
        #     batch_num = len(insts) // batch_size + 1
        batch_num = int(math.ceil(len(insts) / batch_size))

        buckets = [[[], [], [], [], [], [], [], [], [], [], [], [], [], [], []]
                   for _ in range(batch_num)]
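        # 15 parallel slots per bucket: [0] chars, [1] left bichars, [2] right bichars,
        # [3] ext chars, [4] ext left bichars, [5] ext right bichars, [6] char types,
        # [7] seg+pos actions, [8] last word lengths, [9] last pos tags, [10] raw gold actions,
        # [11] raw gold pos tags, [12] raw chars, [13] raw words, [14] 0/1 mask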
        # labels_raw = [[] for _ in range(batch_num)]
        inst_save = []
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id == 0 or id % batch_size != 0:
                inst_save.append(inst)
            elif id % batch_size == 0:
                assert len(inst_save) == batch_size
                inst_sorted = utils.sort_instances(inst_save)
                max_length = len(inst_sorted[0].chars_index)
                for idy in range(batch_size):
                    cur_length = len(inst_sorted[idy].chars_index)

                    buckets[idx - 1][0].append(
                        inst_sorted[idy].chars_index +
                        [self.char_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][1].append(
                        inst_sorted[idy].left_bichar_index +
                        [self.bichar_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][2].append(
                        inst_sorted[idy].right_bichar_index +
                        [self.bichar_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][3].append(
                        inst_sorted[idy].extchars_index +
                        [self.extchar_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][4].append(
                        inst_sorted[idy].extleft_bichar_index +
                        [self.extbichar_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][5].append(
                        inst_sorted[idy].extright_bichar_index +
                        [self.extbichar_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][6].append(
                        inst_sorted[idy].char_type_index +
                        [self.char_type_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][7].append(
                        inst_sorted[idy].segpos_index +
                        [self.segpos_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][8].append(
                        inst_sorted[idy].last_wordlen_index +
                        [self.wordlen_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx -
                            1][9].append(inst_sorted[idy].last_pos_index +
                                         [self.pos_alphabet.word2id['<pad>']] *
                                         (max_length - cur_length))
                    buckets[idx - 1][10].append(inst_sorted[idy].gold_action +
                                                ['<pad>'] *
                                                (max_length - cur_length))
                    buckets[idx - 1][11].append(inst_sorted[idy].gold_pos +
                                                ['<pad>'] *
                                                (max_length - cur_length))
                    buckets[idx -
                            1][12].append(inst_sorted[idy].chars + ['<pad>'] *
                                          (max_length - cur_length))
                    buckets[idx -
                            1][13].append(inst_sorted[idy].words + ['<pad>'] *
                                          (max_length - cur_length))

                    buckets[idx - 1][-1].append([1] * cur_length + [0] *
                                                (max_length - cur_length))
                    # labels_raw[idx-1].append(inst_sorted[idy][-1])
                inst_save = []
                inst_save.append(inst)
        if inst_save != []:
            inst_sorted = utils.sort_instances(inst_save)
            max_length = len(inst_sorted[0].chars_index)
            for idy in range(len(inst_sorted)):
                cur_length = len(inst_sorted[idy].chars_index)
                buckets[batch_num -
                        1][0].append(inst_sorted[idy].chars_index +
                                     [self.char_alphabet.word2id['<pad>']] *
                                     (max_length - cur_length))
                buckets[batch_num -
                        1][1].append(inst_sorted[idy].left_bichar_index +
                                     [self.bichar_alphabet.word2id['<pad>']] *
                                     (max_length - cur_length))
                buckets[batch_num -
                        1][2].append(inst_sorted[idy].right_bichar_index +
                                     [self.bichar_alphabet.word2id['<pad>']] *
                                     (max_length - cur_length))
                buckets[batch_num -
                        1][3].append(inst_sorted[idy].extchars_index +
                                     [self.extchar_alphabet.word2id['<pad>']] *
                                     (max_length - cur_length))
                buckets[batch_num - 1][4].append(
                    inst_sorted[idy].extleft_bichar_index +
                    [self.extbichar_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][5].append(
                    inst_sorted[idy].extright_bichar_index +
                    [self.extbichar_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][6].append(
                    inst_sorted[idy].char_type_index +
                    [self.char_type_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num -
                        1][7].append(inst_sorted[idy].segpos_index +
                                     [self.segpos_alphabet.word2id['<pad>']] *
                                     (max_length - cur_length))
                buckets[batch_num -
                        1][8].append(inst_sorted[idy].last_wordlen_index +
                                     [self.wordlen_alphabet.word2id['<pad>']] *
                                     (max_length - cur_length))
                buckets[batch_num -
                        1][9].append(inst_sorted[idy].last_pos_index +
                                     [self.pos_alphabet.word2id['<pad>']] *
                                     (max_length - cur_length))
                buckets[batch_num -
                        1][10].append(inst_sorted[idy].gold_action +
                                      ['<pad>'] * (max_length - cur_length))
                buckets[batch_num -
                        1][11].append(inst_sorted[idy].gold_pos + ['<pad>'] *
                                      (max_length - cur_length))
                buckets[batch_num -
                        1][12].append(inst_sorted[idy].chars + ['<pad>'] *
                                      (max_length - cur_length))
                buckets[batch_num -
                        1][13].append(inst_sorted[idy].words + ['<pad>'] *
                                      (max_length - cur_length))

                buckets[batch_num - 1][-1].append([1] * cur_length + [0] *
                                                  (max_length - cur_length))
                # labels_raw[batch_num-1].append(inst_sorted[idy][-1])
        return buckets

    def generate_batch_buckets_save(self, batch_size, insts, char=False):
        # insts_length = list(map(lambda t: len(t) + 1, inst[0] for inst in insts))
        # insts_length = list(len(inst[0]+1) for inst in insts)
        # if len(insts) % batch_size == 0:
        #     batch_num = len(insts) // batch_size
        # else:
        #     batch_num = len(insts) // batch_size + 1
        batch_num = int(math.ceil(len(insts) / batch_size))

        if char:
            buckets = [[[], [], [], []] for _ in range(batch_num)]
        else:
            buckets = [[[], [], []] for _ in range(batch_num)]
        max_length = 0
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id % batch_size == 0:
                max_length = len(inst[0]) + 1
            cur_length = len(inst[0])

            buckets[idx][0].append(inst[0] +
                                   [self.word_alphabet.word2id['<pad>']] *
                                   (max_length - cur_length))
            buckets[idx][1].append([self.label_alphabet.word2id['<start>']] +
                                   inst[-1] +
                                   [self.label_alphabet.word2id['<pad>']] *
                                   (max_length - cur_length - 1))
            if char:
                char_length = len(inst[1][0])
                buckets[idx][2].append(
                    (inst[1] +
                     [[self.char_alphabet.word2id['<pad>']] * char_length] *
                     (max_length - cur_length)))
            buckets[idx][-1].append([1] * (cur_length + 1) + [0] *
                                    (max_length - (cur_length + 1)))

        return buckets
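The Data classes above and in the later examples all rely on an Alphabet with a small vocabulary interface: add(word, count), close(), get_index(word), plus the word2id / id2word / id2count containers and a fix_flag. That implementation is not part of these snippets; the sketch below is only an assumption of that interface, with the '<pad>' and '<unk>' symbols and the closed-alphabet behaviour guessed for illustration. (Examples #1, #2 and #7 use a different, file-based Alphabet and are not covered by this sketch.)

class Alphabet:
    """Minimal sketch of the vocabulary interface assumed by the Data classes.

    Not the original implementation: the special symbols ('<pad>', '<unk>')
    and the behaviour of a closed alphabet are assumptions for illustration.
    """

    def __init__(self, name, is_label=False, is_category=False):
        self.name = name
        self.word2id = {}
        self.id2word = []
        self.id2count = []
        self.fix_flag = False              # set to True by close()
        self.add('<pad>', 0)               # assumed padding symbol
        if not (is_label or is_category):
            self.add('<unk>', 0)           # assumed unknown symbol for open vocabularies

    def add(self, word, count=1):
        if self.fix_flag:                  # a closed alphabet ignores new entries
            return
        if word in self.word2id:
            self.id2count[self.word2id[word]] += count
        else:
            self.word2id[word] = len(self.id2word)
            self.id2word.append(word)
            self.id2count.append(count)

    def get_index(self, word):
        # fall back to '<unk>' (or '<pad>') for items unseen at build time
        if word in self.word2id:
            return self.word2id[word]
        return self.word2id.get('<unk>', self.word2id['<pad>'])

    def close(self):
        """Freeze the alphabet and return its size."""
        self.fix_flag = True
        return len(self.id2word)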
Example #6
class Data():
    def __init__(self):
        self.word_alphabet = Alphabet('word')
        self.category_alphabet = Alphabet('category', is_category=True)
        self.label_alphabet = Alphabet('label', is_label=True)
        self.char_alphabet = Alphabet('char')
        self.pos_alphabet = Alphabet('pos')
        self.parent_alphabet = Alphabet('parent')
        self.rel_alphabet = Alphabet('rel')

        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False

        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None

        self.max_char_length = 0

        self.word_num = 0
        self.char_num = 0
        self.label_num = 0
        self.category_num = 0
        self.parent_num = 0
        self.pos_num = 0
        self.rel_num = 0


    def build_alphabet(self, word_counter, label_counter, category_counter, pos_counter, rel_counter, shrink_feature_threshold, char=False):
        for word, count in word_counter.most_common():
            if count > shrink_feature_threshold:
                self.word_alphabet.add(word, count)
        for label, count in label_counter.most_common():
            self.label_alphabet.add(label, count)
        for category, count in category_counter.most_common():
            self.category_alphabet.add(category, count)
        for pos, count in pos_counter.most_common():
            self.pos_alphabet.add(pos, count)
        # for parent, count in parent_counter.most_common():
        #     self.parent_alphabet.add(parent, count)
        # print(rel_counter)
        for rel, count in rel_counter.most_common():
            self.rel_alphabet.add(rel, count)


    def fix_alphabet(self):
        self.word_num = self.word_alphabet.close()
        self.category_num = self.category_alphabet.close()
        self.label_num = self.label_alphabet.close()
        # self.category_num = self.category_alphabet.close()
        self.parent_num = self.parent_alphabet.close()
        self.pos_num = self.pos_alphabet.close()
        self.rel_num = self.rel_alphabet.close()


    def get_instance(self, file, run_insts, shrink_feature_threshold, char=False, char_padding_symbol='<pad>'):
        words = []
        labels = []
        categorys = []
        poss = []
        parents = []
        rels = []
        insts = []
        word_counter = Counter()
        label_counter = Counter()
        category_counter = Counter()
        pos_counter = Counter()
        # parent_counter = Counter()
        rel_counter = Counter()

        count = 0
        with open(file, 'r', encoding='utf-8') as f:
            for id, line in enumerate(f.readlines()):
                if run_insts == count: break
                if len(line) > 2:
                    line = line.strip().split(' ')
                    if '' in line: line.remove('')
                    if len(line) != 6: print(id)
                    # print(line)
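                    # expected 6 fields per line: word, label, category, pos, parent, rel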
                    word = line[0]
                    if self.number_normalized: word = utils.normalize_word(word)
                    label = line[1]
                    category = line[2]
                    pos = line[3]
                    parent = line[-2]
                    if ',' in parent:
                        # print(parent)
                        parent = parent.split(',')[0]
                    # if parent == '':
                    #     print(id)
                    rel = line[-1]

                    words.append(word)
                    labels.append(label)
                    categorys.append(category)
                    poss.append(pos)
                    parents.append(parent)
                    rels.append(rel)

                    word_counter[word] += 1        #####
                    label_counter[label] += 1
                    category_counter[category] += 1
                    pos_counter[pos] += 1
                    # parent_counter[category] += 1
                    rel_counter[rel] += 1
                else:
                    # print(words)
                    # print(parents)
                    insts.append([words, labels, categorys, poss, parents, rels])
                    words = []
                    labels = []
                    categorys = []
                    poss = []
                    parents = []
                    rels = []
                    count += 1
        if not self.word_alphabet.fix_flag:
            self.build_alphabet(word_counter, label_counter, category_counter, pos_counter, rel_counter, shrink_feature_threshold, char)
        insts_index = []

        for inst in insts:
            words_index = [self.word_alphabet.get_index(w) for w in inst[0]]
            labels_index = [self.label_alphabet.get_index(l) for l in inst[1]]
            categorys_index = [self.category_alphabet.get_index(c) for c in inst[2]]
            poss_index = [self.pos_alphabet.get_index(p) for p in inst[3]]
            # parents_index = [self.parent_alphabet.get_index(p) for p in inst[-2]]
            parents_index = [int(p)-1 for p in inst[-2]]
            rels_index = [self.rel_alphabet.get_index(r) for r in inst[-1]]
            insts_index.append([words_index, labels_index, categorys_index, poss_index, parents_index, rels_index])
        return insts, insts_index


    def get_instance_tree(self, file, run_insts, shrink_feature_threshold, write_word_path, label_path, pos_path, parse_path, category_path, char=False, char_padding_symbol='<pad>'):
        words = []
        labels = []
        categorys = []
        poss = []
        parses = []
        insts = []
        word_counter = Counter()
        label_counter = Counter()
        category_counter = Counter()
        pos_counter = Counter()
        parse_counter = Counter()

        word_file = open(write_word_path, 'w', encoding='utf-8')
        label_file = open(label_path, 'w', encoding='utf-8')
        pos_file = open(pos_path, 'w', encoding='utf-8')
        parse_file = open(parse_path, 'w', encoding='utf-8')
        category_file = open(category_path, 'w', encoding='utf-8')

        count = 0
        with open(file, 'r', encoding='utf-8') as f:
            for id, line in enumerate(f.readlines()):
                if run_insts == count: break
                if len(line) > 2:
                    line = line.strip().split(' ')
                    # print(line)
                    word = line[0]
                    if self.number_normalized: word = utils.normalize_word(word)
                    label = line[1]
                    category = line[2]
                    pos = line[-2]
                    parse = line[-1]

                    words.append(word)
                    labels.append(label)
                    categorys.append(category)
                    parses.append(parse)
                    poss.append(pos)

                    word_counter[word] += 1        #####
                    label_counter[label] += 1
                    category_counter[category] += 1
                    pos_counter[pos] += 1
                    parse_counter[category] += 1
                else:
                    # print(words)
                    word_write = ' '.join(words)
                    word_file.write(word_write)
                    word_file.write('\n')
                    parse_write = ' '.join(parses)
                    parse_file.write(parse_write)
                    parse_file.write('\n')

                    insts.append([words, labels, categorys, poss, parses])
                    words = []
                    labels = []
                    categorys = []
                    poss = []
                    parses = []
                    count += 1
        if not self.word_alphabet.fix_flag:
            self.build_alphabet(word_counter, label_counter, category_counter, pos_counter, parse_counter, shrink_feature_threshold, char)
        insts_index = []

        path = r"C:\Users\song\Desktop\treelstm_word\examples\vocab-2.txt"
        file = open(path, 'w', encoding='utf-8')
        words = self.word_alphabet.id2word
        # print(len(words))       # 6799, 6310
        for id in range(len(words)):
            file.write(words[id])
            file.write('\n')
        file.close()

        for inst in insts:
            words_index = [self.word_alphabet.get_index(w) for w in inst[0]]
            labels_index = [str(self.label_alphabet.get_index(l)) for l in inst[1]]
            categorys_index = [str(self.category_alphabet.get_index(c)) for c in inst[2]]
            pos_index = [str(self.pos_alphabet.get_index(p)) for p in inst[-2]]
            parses_index = [self.rel_alphabet.get_index(p) for p in inst[-1]]  # parse tags were added to rel_alphabet in build_alphabet; this class defines no parse_alphabet
            insts_index.append([words_index, labels_index, categorys_index, pos_index, parses_index])

            label_write = ' '.join(labels_index)
            label_file.write(label_write)
            label_file.write('\n')
            pos_write = ' '.join(pos_index)
            pos_file.write(pos_write)
            pos_file.write('\n')
            category_write = ' '.join(categorys_index)
            category_file.write(category_write)
            category_file.write('\n')
        return insts, insts_index


    def build_word_pretrain_emb(self, emb_path, word_dims):
        self.pretrain_word_embedding = utils.load_pretrained_emb_avg(emb_path, self.word_alphabet.word2id, word_dims, self.norm_word_emb)

    def build_char_pretrain_emb(self, emb_path, char_dims):
        self.pretrain_char_embedding = utils.load_pretrained_emb_avg(emb_path, self.char_alphabet.word2id, char_dims, self.norm_char_emb)


    def generate_batch_buckets(self, batch_size, insts, char=False):
        batch_num = int(math.ceil(len(insts) / batch_size))
        buckets = [[[], [], [], [], [], []] for _ in range(batch_num)]
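        # bucket slots: [0] word ids, [1] label ids, [2] pos ids, [3] parent indices, [4] rel ids, [5] 0/1 mask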
        labels_raw = [[] for _ in range(batch_num)]
        category_raw = [[] for _ in range(batch_num)]
        target_start = [[] for _ in range(batch_num)]
        target_end = [[] for _ in range(batch_num)]

        inst_save = []
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id == 0 or id % batch_size != 0:
                inst_save.append(inst)
            elif id % batch_size == 0:
                assert len(inst_save) == batch_size
                inst_sorted = utils.sorted_instances_index(inst_save)
                max_length = len(inst_sorted[0][0])
                for idy in range(batch_size):
                    cur_length = len(inst_sorted[idy][0])
                    buckets[idx-1][0].append(inst_sorted[idy][0] + [self.word_alphabet.word2id['<pad>']] * (max_length - cur_length))
                    buckets[idx-1][1].append(inst_sorted[idy][1] + [self.label_alphabet.word2id['<pad>']] * (max_length - cur_length))
                    buckets[idx-1][2].append(inst_sorted[idy][-3] + [self.pos_alphabet.word2id['<pad>']] * (max_length - cur_length))
                    buckets[idx-1][3].append(inst_sorted[idy][-2] + [self.parent_alphabet.word2id['<pad>']] * (max_length - cur_length))
                    buckets[idx-1][4].append(inst_sorted[idy][-1] + [self.rel_alphabet.word2id['<pad>']] * (max_length - cur_length))
                    buckets[idx-1][-1].append([1] * cur_length + [0] * (max_length - cur_length))
                    labels_raw[idx-1].append(inst_sorted[idy][1])

                    start, end = evaluation.extract_target(inst_sorted[idy][1], self.label_alphabet)
                    if start == []:
                        start = [0]
                        end = [0]
                    target_start[idx-1].append(start[0])
                    target_end[idx-1].append(end[0])
                    # target_start.extend(start)
                    # target_end.extend(end)
                    category_raw[idx-1].append(inst_sorted[idy][2][0])
                inst_save = []
                inst_save.append(inst)
        if inst_save != []:
            inst_sorted = utils.sorted_instances_index(inst_save)
            max_length = len(inst_sorted[0][0])
            for idy in range(len(inst_sorted)):
                cur_length = len(inst_sorted[idy][0])
                buckets[batch_num-1][0].append(inst_sorted[idy][0] + [self.word_alphabet.word2id['<pad>']] * (max_length - cur_length))
                buckets[batch_num-1][1].append(inst_sorted[idy][1] + [self.label_alphabet.word2id['<pad>']] * (max_length - cur_length))
                buckets[batch_num-1][2].append(inst_sorted[idy][-3] + [self.pos_alphabet.word2id['<pad>']] * (max_length - cur_length))
                buckets[batch_num-1][3].append(inst_sorted[idy][-2] + [self.parent_alphabet.word2id['<pad>']] * (max_length - cur_length))
                buckets[batch_num-1][4].append(inst_sorted[idy][-1] + [self.rel_alphabet.word2id['<pad>']] * (max_length - cur_length))
                buckets[batch_num-1][-1].append([1] * cur_length + [0] * (max_length - cur_length))
                labels_raw[batch_num-1].append(inst_sorted[idy][1])
                category_raw[batch_num-1].append(inst_sorted[idy][2][0])
                start, end = evaluation.extract_target(inst_sorted[idy][1], self.label_alphabet)
                if start == []:
                    start = [0]
                    end = [0]
                target_start[batch_num - 1].append(start[0])
                target_end[batch_num - 1].append(end[0])
        # print(buckets)
        # print(labels_raw)
        # print(category_raw)
        # print(target_start)
        # print(target_end)
        return buckets, labels_raw, category_raw, target_start, target_end

    def generate_batch_buckets_save(self, batch_size, insts, char=False):
        # insts_length = list(map(lambda t: len(t) + 1, inst[0] for inst in insts))
        # insts_length = list(len(inst[0]+1) for inst in insts)
        # if len(insts) % batch_size == 0:
        #     batch_num = len(insts) // batch_size
        # else:
        #     batch_num = len(insts) // batch_size + 1
        batch_num = int(math.ceil(len(insts) / batch_size))

        if char:
            buckets = [[[], [], [], []] for _ in range(batch_num)]
        else:
            buckets = [[[], [], []] for _ in range(batch_num)]
        max_length = 0
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id % batch_size == 0:
                max_length = len(inst[0]) + 1
            cur_length = len(inst[0])

            buckets[idx][0].append(inst[0] + [self.word_alphabet.word2id['<pad>']] * (max_length - cur_length))
            buckets[idx][1].append([self.label_alphabet.word2id['<start>']] + inst[-1] + [self.label_alphabet.word2id['<pad>']] * (max_length - cur_length - 1))
            if char:
                char_length = len(inst[1][0])
                buckets[idx][2].append((inst[1] + [[self.char_alphabet.word2id['<pad>']] * char_length] * (max_length - cur_length)))
            buckets[idx][-1].append([1] * (cur_length + 1) + [0] * (max_length - (cur_length + 1)))

        return buckets
Example #7
        """Add Start Of Sequence character to an array of sequences."""
        sos_col = np.ones([self.latest_batch_size, 1]) * alphabet.sos_id
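        # shift right: prepend an SOS column and drop the last column so sequence length is unchanged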
        return np.concatenate([sos_col, array[:, :-1]], 1)


if __name__ == '__main__':
    from data.alphabet import Alphabet
    SEQ_LEN_X = 250
    SEQ_LEN_T = 500
    BATCH_SIZE = 76800

    text_loader = TextLoader(
        ['data/train/europarl-v7.de-en.en'],
        ['data/train/europarl-v7.de-en.de'], SEQ_LEN_X, SEQ_LEN_T)

    alphabet_src = Alphabet('data/alphabet/dict_wmt_tok.de-en.en', eos='*')
    alphabet_tar = Alphabet('data/alphabet/dict_wmt_tok.de-en.de', eos='*', sos='')

    text_batch_gen = TextBatchGenerator(text_loader,
                                        BATCH_SIZE,
                                        alphabet_src,
                                        alphabet_tar,
                                        use_dynamic_array_sizes=True)

    print("running warmup for 20 iterations, and 180 iterations with bucket")
    line = ""
    for i, batch in enumerate(text_batch_gen.gen_batch(variable_bucket_schedule)):
        print(batch["x_encoded"].shape, batch["t_encoded"].shape)
        if i == 200:
            break
Example #8
class Data():
    def __init__(self):
        self.word_alphabet = Alphabet('word')
        self.char_alphabet = Alphabet('char')
        self.label_alphabet = Alphabet('label', is_label=True)

        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False

        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None

        self.max_char_length = 0

        self.word_num = 0
        self.char_num = 0
        self.label_num = 0

    def build_alphabet(self,
                       word_counter,
                       char_counter,
                       label_counter,
                       shrink_feature_threshold,
                       char=False):
        for word, count in word_counter.most_common():
            if count > shrink_feature_threshold:
                # self.word_alphabet.id2word.append(word)
                # self.word_alphabet.id2count.append(count)
                # if self.number_normalized: word = utils.normalize_word(word)
                self.word_alphabet.add(word, count)
        for label, count in label_counter.most_common():
            if count > shrink_feature_threshold:
                # self.label_alphabet.id2word.append(label)
                # self.label_alphabet.id2count.append(count)
                self.label_alphabet.add(label, count)

        # another method
        # reverse = lambda x: dict(zip(x, range(len(x))))
        # self.word_alphabet.word2id = reverse(self.word_alphabet.id2word)
        # self.label_alphabet.word2id = reverse(self.label_alphabet.id2word)

        ##### check
        if len(self.word_alphabet.word2id) != len(
                self.word_alphabet.id2word) or len(
                    self.word_alphabet.id2count) != len(
                        self.word_alphabet.id2word):
            print('there are errors in building word alphabet.')
        if len(self.label_alphabet.word2id) != len(
                self.label_alphabet.id2word) or len(
                    self.label_alphabet.id2count) != len(
                        self.label_alphabet.id2word):
            print('there are errors in building label alphabet.')

        if char:
            for char, count in char_counter.most_common():
                if count > shrink_feature_threshold:
                    self.char_alphabet.add(char, count)
                    # self.char_alphabet.id2word.append(char)
                    # self.char_alphabet.id2count.append(count)
            # self.char_alphabet.word2id = reverse(self.char_alphabet.id2word)
            if len(self.char_alphabet.word2id) != len(
                    self.char_alphabet.id2word) or len(
                        self.char_alphabet.id2count) != len(
                            self.char_alphabet.id2word):
                print('there are errors in building char alphabet.')

    def fix_alphabet(self):
        self.word_num = self.word_alphabet.close()
        self.char_num = self.char_alphabet.close()
        self.label_num = self.label_alphabet.close()

    def get_instance(self,
                     file,
                     run_insts,
                     shrink_feature_threshold,
                     char=False,
                     char_padding_symbol='<pad>'):
        words = []
        chars = []
        labels = []
        insts = []
        word_counter = Counter()
        char_counter = Counter()
        label_counter = Counter()
        char_length_max = 0
        count = 0
        with open(file, 'r', encoding='utf-8') as f:
            ##### if each sentence is on a single line, you can use this to limit the number of instances for debugging.
            # if run_insts == -1:
            #     fin_lines = f.readlines()
            # else:
            #     fin_lines = f.readlines()[:run_insts]
            # in_lines = open(file, 'r', encoding='utf-8').readlines()
            for line in f.readlines():
                if run_insts == count: break
                if len(line) > 2:
                    line = line.strip().split(' ')
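                    # a token line is expected to look like: "token <word> ... <label>"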
                    if line[0] == 'token':
                        word = line[1]
                        if self.number_normalized:
                            word = utils.normalize_word(word)
                        label = line[-1]
                        words.append(word)
                        labels.append(label)
                        word_counter[word] += 1  #####
                        label_counter[label] += 1

                        if char:
                            char_list = []
                            for c in word:  # use `c` to avoid shadowing the `char` flag argument
                                char_list.append(c)
                                char_counter[c] += 1
                            chars.append(char_list)
                            char_length = len(char_list)
                            if char_length > char_length_max:
                                char_length_max = char_length
                            if char_length_max > self.max_char_length:
                                self.max_char_length = char_length_max
                else:
                    if char:
                        chars_padded = []
                        for index, char_list in enumerate(chars):
                            char_number = len(char_list)
                            if char_number < char_length_max:
                                char_list = char_list + [
                                    char_padding_symbol
                                ] * (char_length_max - char_number)
                                char_counter[char_padding_symbol] += (
                                    char_length_max - char_number)
                            chars_padded.append(char_list)
                            assert (len(char_list) == char_length_max)

                        insts.append([words, chars_padded, labels])
                    else:
                        insts.append([words, labels])
                    words = []
                    chars = []
                    labels = []
                    char_length_max = 0
                    count += 1
        if not self.word_alphabet.fix_flag:
            self.build_alphabet(word_counter, char_counter, label_counter,
                                shrink_feature_threshold, char)
        insts_index = []

        for inst in insts:
            words_index = [self.word_alphabet.get_index(w) for w in inst[0]]
            labels_index = [self.label_alphabet.get_index(l) for l in inst[-1]]
            chars_index = []
            if char:
                # words, chars, labels = inst
                # words_index = [self.word_alphabet.get_index(w) for w in words]
                # labels_index = [self.label_alphabet.get_index(l) for l in labels]
                # char_index = []
                for char_list in inst[1]:  # use `char_list` to avoid shadowing the `char` flag argument
                    char_index = [
                        self.char_alphabet.get_index(c) for c in char_list
                    ]
                    chars_index.append(char_index)
                insts_index.append([words_index, chars_index, labels_index])
            else:
                # words, labels = inst
                # words_index = [self.word_alphabet.get_index(w) for w in words]
                # labels_index = [self.label_alphabet.get_index(l) for l in labels]
                insts_index.append([words_index, labels_index])

        ##### sorted sentences
        # insts_sorted, insts_index_sorted = utils.sorted_instances(insts, insts_index)
        return insts, insts_index

    def build_word_pretrain_emb(self, emb_path, word_dims):
        self.pretrain_word_embedding = utils.load_pretrained_emb_avg(
            emb_path, self.word_alphabet.word2id, word_dims,
            self.norm_word_emb)

    def build_char_pretrain_emb(self, emb_path, char_dims):
        self.pretrain_char_embedding = utils.load_pretrained_emb_avg(
            emb_path, self.char_alphabet.word2id, char_dims,
            self.norm_char_emb)

    def generate_batch_buckets(self, batch_size, insts, char=False):
        # insts_length = list(map(lambda t: len(t) + 1, inst[0] for inst in insts))
        # insts_length = list(len(inst[0]+1) for inst in insts)
        # if len(insts) % batch_size == 0:
        #     batch_num = len(insts) // batch_size
        # else:
        #     batch_num = len(insts) // batch_size + 1
        batch_num = int(math.ceil(len(insts) / batch_size))

        if char:
            buckets = [[[], [], [], []] for _ in range(batch_num)]
        else:
            buckets = [[[], [], []] for _ in range(batch_num)]
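        # bucket slots: [0] word ids, [1] label ids, ([2] char ids when char=True); the last slot is the 0/1 mask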
        labels_raw = [[] for _ in range(batch_num)]
        inst_save = []
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id == 0 or id % batch_size != 0:
                inst_save.append(inst)
            elif id % batch_size == 0:
                assert len(inst_save) == batch_size
                inst_sorted = utils.sorted_instances_index(inst_save)
                max_length = len(inst_sorted[0][0])
                for idy in range(batch_size):
                    cur_length = len(inst_sorted[idy][0])
                    buckets[idx - 1][0].append(
                        inst_sorted[idy][0] +
                        [self.word_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][1].append(
                        inst_sorted[idy][-1] +
                        [self.label_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))

                    if char:
                        cur_char_length = len(inst_sorted[idy][1][0])
                        inst_sorted[idy][1] = [
                            (ele + [self.char_alphabet.word2id['<pad>']] *
                             (self.max_char_length - cur_char_length))
                            for ele in inst_sorted[idy][1]
                        ]
                        buckets[idx - 1][2].append(
                            (inst_sorted[idy][1] +
                             [[self.char_alphabet.word2id['<pad>']] *
                              self.max_char_length] *
                             (max_length - cur_length)))
                    buckets[idx - 1][-1].append([1] * cur_length + [0] *
                                                (max_length - cur_length))
                    labels_raw[idx - 1].append(inst_sorted[idy][-1])
                inst_save = []
                inst_save.append(inst)
        if inst_save != []:
            inst_sorted = utils.sorted_instances_index(inst_save)
            max_length = len(inst_sorted[0][0])
            for idy in range(len(inst_sorted)):
                cur_length = len(inst_sorted[idy][0])
                buckets[batch_num -
                        1][0].append(inst_sorted[idy][0] +
                                     [self.word_alphabet.word2id['<pad>']] *
                                     (max_length - cur_length))
                buckets[batch_num -
                        1][1].append(inst_sorted[idy][-1] +
                                     [self.label_alphabet.word2id['<pad>']] *
                                     (max_length - cur_length))
                if char:
                    cur_char_length = len(inst_sorted[idy][1][0])
                    inst_sorted[idy][1] = [
                        (ele + [self.char_alphabet.word2id['<pad>']] *
                         (self.max_char_length - cur_char_length))
                        for ele in inst_sorted[idy][1]
                    ]
                    buckets[batch_num - 1][2].append(
                        (inst_sorted[idy][1] +
                         [[self.char_alphabet.word2id['<pad>']] *
                          self.max_char_length] * (max_length - cur_length)))
                buckets[batch_num - 1][-1].append([1] * cur_length + [0] *
                                                  (max_length - cur_length))
                labels_raw[batch_num - 1].append(inst_sorted[idy][-1])
        return buckets, labels_raw

    def generate_batch_buckets_save(self, batch_size, insts, char=False):
        """Pad batches in corpus order (no sorting) and prepend '<start>' to the labels."""
        batch_num = int(math.ceil(len(insts) / batch_size))

        if char:
            buckets = [[[], [], [], []] for _ in range(batch_num)]
        else:
            buckets = [[[], [], []] for _ in range(batch_num)]
        max_length = 0
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id % batch_size == 0:
                max_length = len(inst[0]) + 1
            cur_length = len(inst[0])

            buckets[idx][0].append(inst[0] +
                                   [self.word_alphabet.word2id['<pad>']] *
                                   (max_length - cur_length))
            buckets[idx][1].append([self.label_alphabet.word2id['<start>']] +
                                   inst[-1] +
                                   [self.label_alphabet.word2id['<pad>']] *
                                   (max_length - cur_length - 1))
            if char:
                char_length = len(inst[1][0])
                buckets[idx][2].append(
                    (inst[1] +
                     [[self.char_alphabet.word2id['<pad>']] * char_length] *
                     (max_length - cur_length)))
            buckets[idx][-1].append([1] * (cur_length + 1) + [0] *
                                    (max_length - (cur_length + 1)))

        return buckets
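Below is a minimal, hedged usage sketch for the Data class in this example; the indexed-instance layout, the presence of '<pad>' in the alphabets, and all concrete ids are illustrative assumptions, not part of the original code.

# Hedged usage sketch (not part of the original example).
# Assumed: the word and label alphabets were built beforehand and contain '<pad>',
# and utils.sorted_instances_index sorts a batch by descending sentence length.
data = Data()
insts_index = [
    [[4, 7, 2], [1, 1, 2]],   # word ids, label ids (illustrative values)
    [[8, 3], [1, 2]],
    [[5], [2]],
]
buckets, labels_raw = data.generate_batch_buckets(batch_size=2,
                                                  insts=insts_index)
# buckets[i] = [padded word ids, padded label ids, 0/1 mask] for batch i;
# labels_raw keeps the unpadded label sequences for evaluation.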
Example #9
class Data():
    def __init__(self):
        self.word_alphabet = Alphabet('word')
        self.category_alphabet = Alphabet('category', is_category=True)
        self.label_alphabet = Alphabet('label', is_label=True)
        self.char_alphabet = Alphabet('char')

        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False

        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None

        self.max_char_length = 0

        self.word_num = 0
        self.char_num = 0
        self.label_num = 0

    def build_alphabet(self,
                       word_counter,
                       label_counter,
                       category_counter,
                       shrink_feature_threshold,
                       char=False):
        for word, count in word_counter.most_common():
            if count > shrink_feature_threshold:
                self.word_alphabet.add(word, count)
        for label, count in label_counter.most_common():
            self.label_alphabet.add(label, count)
        for category, count in category_counter.most_common():
            self.category_alphabet.add(category, count)

        # sanity check: each alphabet's word2id, id2word and id2count must agree in size
        if len(self.word_alphabet.word2id) != len(
                self.word_alphabet.id2word) or len(
                    self.word_alphabet.id2count) != len(
                        self.word_alphabet.id2word):
            print('there are errors in building word alphabet.')
        if len(self.label_alphabet.word2id) != len(
                self.label_alphabet.id2word) or len(
                    self.label_alphabet.id2count) != len(
                        self.label_alphabet.id2word):
            print('there are errors in building label alphabet.')
        if len(self.category_alphabet.word2id) != len(
                self.category_alphabet.id2word) or len(
                    self.category_alphabet.id2count) != len(
                        self.category_alphabet.id2word):
            print('there are errors in building category alphabet.')

    def fix_alphabet(self):
        self.word_num = self.word_alphabet.close()
        self.category_num = self.category_alphabet.close()
        self.label_num = self.label_alphabet.close()

    def get_instance(self,
                     file,
                     run_insts,
                     shrink_feature_threshold,
                     char=False,
                     char_padding_symbol='<pad>'):
        words = []
        labels = []
        categorys = []
        insts = []
        word_counter = Counter()
        char_counter = Counter()
        label_counter = Counter()
        category_counter = Counter()

        count = 0
        ner_num = 0
        with open(file, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                if run_insts == count: break
                if len(line) > 2:
                    line = line.strip().split(' ')
                    word = line[0]
                    if self.number_normalized:
                        word = utils.normalize_word(word)
                    if len(list(line[1])) > 1:
                        label = line[1].split('-')[0]
                        if label == 'b':
                            ner_num += 1
                        category = line[1].split('-')[1]
                        categorys.append(category)
                        category_counter[category] += 1
                    else:
                        label = line[1]
                    label = label + '-target'
                    words.append(word)
                    labels.append(label)
                    word_counter[word] += 1
                    label_counter[label] += 1
                else:
                    insts.append([words, labels, categorys])
                    words = []
                    labels = []
                    categorys = []
                    count += 1
                    if ner_num > 1: print(ner_num)
                    ner_num = 0
        if not self.word_alphabet.fix_flag:
            self.build_alphabet(word_counter, label_counter, category_counter,
                                shrink_feature_threshold, char)
        insts_index = []

        for inst in insts:
            words_index = [self.word_alphabet.get_index(w) for w in inst[0]]
            labels_index = [self.label_alphabet.get_index(l) for l in inst[1]]
            length = len(labels_index)
            categorys_index = [self.category_alphabet.get_index(inst[-1][0])
                               ] * length
            insts_index.append([words_index, labels_index, categorys_index])

        return insts, insts_index

    def build_word_pretrain_emb(self, emb_path, word_dims):
        self.pretrain_word_embedding = utils.load_pretrained_emb_avg(
            emb_path, self.word_alphabet.word2id, word_dims,
            self.norm_word_emb)

    def build_char_pretrain_emb(self, emb_path, char_dims):
        self.pretrain_char_embedding = utils.load_pretrained_emb_avg(
            emb_path, self.char_alphabet.word2id, char_dims,
            self.norm_char_emb)

    def generate_batch_buckets(self, batch_size, insts, char=False):
        """Group instances into padded, length-sorted batches; also return the raw
        labels, categories, and target start/end positions per batch."""
        batch_num = int(math.ceil(len(insts) / batch_size))
        buckets = [[[], [], []] for _ in range(batch_num)]
        labels_raw = [[] for _ in range(batch_num)]
        category_raw = [[] for _ in range(batch_num)]
        target_start = [[] for _ in range(batch_num)]
        target_end = [[] for _ in range(batch_num)]

        inst_save = []
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id == 0 or id % batch_size != 0:
                inst_save.append(inst)
            elif id % batch_size == 0:
                assert len(inst_save) == batch_size
                inst_sorted = utils.sorted_instances_index(inst_save)
                max_length = len(inst_sorted[0][0])
                for idy in range(batch_size):
                    cur_length = len(inst_sorted[idy][0])
                    buckets[idx - 1][0].append(
                        inst_sorted[idy][0] +
                        [self.word_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][1].append(
                        inst_sorted[idy][1] +
                        [self.label_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][-1].append([1] * cur_length + [0] *
                                                (max_length - cur_length))
                    labels_raw[idx - 1].append(inst_sorted[idy][1])

                    start, end = evaluation.extract_target(
                        inst_sorted[idy][1], self.label_alphabet)
                    target_start[idx - 1].append(start[0])
                    target_end[idx - 1].append(end[0])
                    # target_start.extend(start)
                    # target_end.extend(end)
                    category_raw[idx - 1].append(inst_sorted[idy][-1][0])
                inst_save = []
                inst_save.append(inst)
        if inst_save != []:
            inst_sorted = utils.sorted_instances_index(inst_save)
            max_length = len(inst_sorted[0][0])
            for idy in range(len(inst_sorted)):
                cur_length = len(inst_sorted[idy][0])
                buckets[batch_num -
                        1][0].append(inst_sorted[idy][0] +
                                     [self.word_alphabet.word2id['<pad>']] *
                                     (max_length - cur_length))
                buckets[batch_num -
                        1][1].append(inst_sorted[idy][1] +
                                     [self.label_alphabet.word2id['<pad>']] *
                                     (max_length - cur_length))
                buckets[batch_num - 1][-1].append([1] * cur_length + [0] *
                                                  (max_length - cur_length))
                labels_raw[batch_num - 1].append(inst_sorted[idy][1])
                category_raw[batch_num - 1].append(inst_sorted[idy][-1][0])
                start, end = evaluation.extract_target(inst_sorted[idy][1],
                                                       self.label_alphabet)
                target_start[batch_num - 1].append(start[0])
                target_end[batch_num - 1].append(end[0])
        return buckets, labels_raw, category_raw, target_start, target_end

    def generate_batch_buckets_save(self, batch_size, insts, char=False):
        """Pad batches in corpus order (no sorting) and prepend '<start>' to the labels."""
        batch_num = int(math.ceil(len(insts) / batch_size))

        if char:
            buckets = [[[], [], [], []] for _ in range(batch_num)]
        else:
            buckets = [[[], [], []] for _ in range(batch_num)]
        max_length = 0
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id % batch_size == 0:
                max_length = len(inst[0]) + 1
            cur_length = len(inst[0])

            buckets[idx][0].append(inst[0] +
                                   [self.word_alphabet.word2id['<pad>']] *
                                   (max_length - cur_length))
            buckets[idx][1].append([self.label_alphabet.word2id['<start>']] +
                                   inst[-1] +
                                   [self.label_alphabet.word2id['<pad>']] *
                                   (max_length - cur_length - 1))
            if char:
                char_length = len(inst[1][0])
                buckets[idx][2].append(
                    (inst[1] +
                     [[self.char_alphabet.word2id['<pad>']] * char_length] *
                     (max_length - cur_length)))
            buckets[idx][-1].append([1] * (cur_length + 1) + [0] *
                                    (max_length - (cur_length + 1)))

        return buckets
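Below is a hedged end-to-end sketch for this Data class; the file paths, embedding dimensions, and threshold are placeholders, and it assumes the Alphabet class registers '<pad>' (and '<start>' for labels) when constructed.

# Hedged usage sketch (not part of the original example); paths are placeholders.
data = Data()
# Each line of the file is assumed to be 'word tag' with blank lines between
# sentences; run_insts=-1 reads the whole file because count never reaches -1.
insts, insts_index = data.get_instance('data/train.txt',
                                       run_insts=-1,
                                       shrink_feature_threshold=0)
data.fix_alphabet()
data.build_word_pretrain_emb('data/emb/glove.100d.txt', word_dims=100)
buckets, labels_raw, category_raw, target_start, target_end = \
    data.generate_batch_buckets(batch_size=16, insts=insts_index)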
Example #10
class Model:
    # settings that affect train.py
    batch_size_train = 100000
    batch_size_valid = 128
    seq_len_x = 50
    seq_len_t = 50
    name = None  # (string) For saving logs and checkpoints. (None to disable.)
    visualize_freq = 10000  # Visualize training X, y, and t. (0 to disable.)
    log_freq = 100  # How often to print updates during training.
    save_freq = 1000  # How often to save checkpoints. (0 to disable.)
    valid_freq = 500  # How often to validate.
    iterations = 5 * 32000  # How many iterations to train for before stopping.
    train_feedback = False  # Enable feedback during training?
    tb_log_freq = 500  # How often to save logs for TensorBoard
    max_to_keep = 100

    # datasets
    #train_x_files = ['data/train/europarl-v7.de-en.en']
    #train_t_files = ['data/train/europarl-v7.de-en.de']
    train_x_files = [
        'data/train/europarl-v7.de-en.en.tok',
        'data/train/commoncrawl.de-en.en.tok',
        'data/train/news-commentary-v10.de-en.en.tok'
    ]
    train_t_files = [
        'data/train/europarl-v7.de-en.de.tok',
        'data/train/commoncrawl.de-en.de.tok',
        'data/train/news-commentary-v10.de-en.de.tok'
    ]
    #valid_x_files = ['data/valid/devtest2006.en', 'data/valid/test2006.en',
    #                 'data/valid/test2007.en', 'data/valid/test2008.en']
    #valid_t_files = ['data/valid/devtest2006.de', 'data/valid/test2006.de',
    #                 'data/valid/test2007.de', 'data/valid/test2008.de']
    valid_x_files = ['data/valid/newstest2013.en.tok']
    valid_t_files = ['data/valid/newstest2013.de.tok']
    test_x_files = ['data/valid/newstest2014.deen.en.tok']
    test_t_files = ['data/valid/newstest2014.deen.de.tok']

    # settings that are local to the model
    alphabet_src_size = 310  # size of alphabet
    alphabet_tar_size = 310  # size of alphabet
    alphabet_src = Alphabet('data/alphabet/dict_wmt_tok.de-en.en', eos='*')
    alphabet_tar = Alphabet('data/alphabet/dict_wmt_tok.de-en.de',
                            eos='*',
                            sos='')
    char_encoder_units = 400  # number of units in character-level encoder
    word_encoder_units = 400  # number of units in word-level encoders (forward and backward)
    attn_units = 300  # num units used for attention in the decoder.
    embedd_dims = 256  # size of character embeddings
    learning_rate = 0.001
    reg_scale = 0.000001
    clip_norm = 1

    swap_schedule = {0: 0.0}

    # kwargs for scheduling function
    schedule_kwargs = {'fuzzyness': 3}

    def __init__(self):
        self.max_x_seq_len = self.seq_len_x
        self.max_t_seq_len = self.seq_len_t

        # TF placeholders
        self.setup_placeholders()

        # schedule functions
        self.train_schedule_function = tl.variable_bucket_schedule
        self.valid_schedule_function = None  # falls back to frostings.default_schedule
        self.test_schedule_function = None

        print("Model instantiation")
        self.build()
        self.loss, self.accuracy = self.build_loss(self.out, self.out_tensor)
        self.valid_loss, self.valid_accuracy = self.build_valid_loss()
        self.ys = self.build_prediction(self.out_tensor)
        self.valid_ys = self.build_valid_prediction()
        self.build_training()

        # Create TensorBoard scalar summaries
        tf.scalar_summary('train/loss', self.loss)
        tf.scalar_summary('train/accuracy', self.accuracy)

        # setup batch generators
        self.setup_batch_generators()

    def setup_placeholders(self):
        shape = [None, None]
        self.Xs = tf.placeholder(tf.int32, shape=shape, name='X_input')
        self.ts = tf.placeholder(tf.int32, shape=shape, name='t_input')
        self.ts_go = tf.placeholder(tf.int32, shape=shape, name='t_input_go')
        self.X_len = tf.placeholder(tf.int32, shape=[None], name='X_len')
        self.t_len = tf.placeholder(tf.int32, shape=[None], name='t_len')
        self.feedback = tf.placeholder(tf.bool, name='feedback_indicator')
        self.x_mask = tf.placeholder(tf.float32, shape=shape, name='x_mask')
        self.t_mask = tf.placeholder(tf.float32, shape=shape, name='t_mask')

        shape = [None, None]
        self.X_spaces = tf.placeholder(tf.int32, shape=shape, name='X_spaces')
        self.X_spaces_len = tf.placeholder(tf.int32,
                                           shape=[None],
                                           name='X_spaces_len')

    def build(self):
        print('Building model')
        self.x_embeddings = tf.Variable(tf.random_normal(
            [self.alphabet_src_size, self.embedd_dims], stddev=0.1),
                                        name='x_embeddings')
        self.t_embeddings = tf.Variable(tf.random_normal(
            [self.alphabet_tar_size, self.embedd_dims], stddev=0.1),
                                        name='t_embeddings')

        X_embedded = tf.gather(self.x_embeddings, self.Xs, name='embed_X')
        t_embedded = tf.gather(self.t_embeddings, self.ts_go, name='embed_t')

        with tf.variable_scope('dense_out'):
            W_out = tf.get_variable(
                'W_out', [self.word_encoder_units * 2, self.alphabet_tar_size])
            b_out = tf.get_variable('b_out', [self.alphabet_tar_size])

        # forward encoding
        char_enc_state, char_enc_out = encoder(X_embedded, self.X_len,
                                               'char_encoder',
                                               self.char_encoder_units)
        char2word = _grid_gather(char_enc_out, self.X_spaces)
        char2word.set_shape([None, None, self.char_encoder_units])
        word_enc_state, word_enc_out = encoder(char2word, self.X_spaces_len,
                                               'word_encoder',
                                               self.word_encoder_units)

        # backward encoding words
        char2word = tf.reverse_sequence(char2word,
                                        tf.to_int64(self.X_spaces_len), 1)
        char2word.set_shape([None, None, self.char_encoder_units])
        word_enc_state_bck, word_enc_out_bck = encoder(
            char2word, self.X_spaces_len, 'word_encoder_backwards',
            self.word_encoder_units)
        word_enc_out_bck = tf.reverse_sequence(word_enc_out_bck,
                                               tf.to_int64(self.X_spaces_len),
                                               1)

        word_enc_state = tf.concat(1, [word_enc_state, word_enc_state_bck])
        word_enc_out = tf.concat(2, [word_enc_out, word_enc_out_bck])

        # decoding
        dec_state, dec_out, valid_dec_out, valid_attention_tracker = (
            attention_decoder(word_enc_out, self.X_spaces_len, word_enc_state,
                              t_embedded, self.t_len, self.attn_units,
                              self.t_embeddings, W_out, b_out))

        out_tensor = tf.reshape(dec_out, [-1, self.word_encoder_units * 2])
        out_tensor = tf.matmul(out_tensor, W_out) + b_out
        out_shape = tf.concat(0, [
            tf.expand_dims(tf.shape(self.X_len)[0], 0),
            tf.expand_dims(tf.shape(t_embedded)[1], 0),
            tf.expand_dims(tf.constant(self.alphabet_tar_size), 0)
        ])
        self.valid_attention_tracker = valid_attention_tracker.pack()
        self.out_tensor = tf.reshape(out_tensor, out_shape)
        self.out_tensor.set_shape([None, None, self.alphabet_tar_size])

        valid_out_tensor = tf.reshape(valid_dec_out,
                                      [-1, self.word_encoder_units * 2])
        valid_out_tensor = tf.matmul(valid_out_tensor, W_out) + b_out
        self.valid_out_tensor = tf.reshape(valid_out_tensor, out_shape)

        self.out = None

        # add TensorBoard summaries for all variables
        tf.contrib.layers.summarize_variables()

    def build_loss(self, out, out_tensor):
        """Build a loss function and accuracy for the model."""
        print('  Building loss and accuracy')

        with tf.variable_scope('accuracy'):
            argmax = tf.to_int32(tf.argmax(out_tensor, 2))
            correct = tf.to_float(tf.equal(argmax, self.ts)) * self.t_mask
            accuracy = tf.reduce_sum(correct) / tf.reduce_sum(self.t_mask)

        with tf.variable_scope('loss'):
            loss = sequence_loss_tensor(out_tensor, self.ts, self.t_mask,
                                        self.alphabet_tar_size)

            with tf.variable_scope('regularization'):
                regularize = tf.contrib.layers.l2_regularizer(self.reg_scale)
                params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
                reg_term = sum([regularize(param) for param in params])

            loss += reg_term

        return loss, accuracy

    def build_valid_loss(self):
        return self.build_loss(self.out, self.valid_out_tensor)

    def build_prediction(self, out_tensor):
        print('  Building prediction')
        with tf.variable_scope('prediction'):
            # logits is a list of tensors of shape [batch_size, alphabet_size].
            # We need shape of [batch_size, target_seq_len, alphabet_size].
            return tf.argmax(out_tensor, dimension=2)

    def build_valid_prediction(self):
        return self.build_prediction(self.valid_out_tensor)

    def build_training(self):
        print('  Building training')
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        optimizer = tf.train.AdamOptimizer(self.learning_rate)

        # Do gradient clipping
        # NOTE: this is the correct, but slower, way of clipping by global norm.
        # Maybe it's worth trying the faster tf.clip_by_norm()
        # (See the documentation for tf.clip_by_global_norm() for more info)
        grads_and_vars = optimizer.compute_gradients(self.loss)
        gradients, variables = zip(*grads_and_vars)  # unzip list of tuples
        clipped_gradients, global_norm = (tf.clip_by_global_norm(
            gradients, self.clip_norm))
        clipped_grads_and_vars = zip(clipped_gradients, variables)

        # Create TensorBoard scalar summary for global gradient norm
        tf.scalar_summary('train/global gradient norm', global_norm)

        # Create TensorBoard summaries for gradients
        # for grad, var in grads_and_vars:
        #     # Sparse tensor updates can't be summarized, so avoid doing that:
        #     if isinstance(grad, tf.Tensor):
        #         tf.histogram_summary('grad_' + var.name, grad)

        # make training op for applying the gradients
        self.train_op = optimizer.apply_gradients(clipped_grads_and_vars,
                                                  global_step=self.global_step)

    def setup_batch_generators(self):
        """Load the datasets"""
        self.batch_generator = dict()

        # load training set
        print('Load training set')
        train_loader = tl.TextLoader(paths_X=self.train_x_files,
                                     paths_t=self.train_t_files,
                                     seq_len_x=self.seq_len_x,
                                     seq_len_t=self.seq_len_t)
        self.batch_generator['train'] = tl.TextBatchGenerator(
            loader=train_loader,
            batch_size=self.batch_size_train,
            alphabet_src=self.alphabet_src,
            alphabet_tar=self.alphabet_tar,
            use_dynamic_array_sizes=True,
            **self.schedule_kwargs)

        # load validation set
        print('Load validation set')
        valid_loader = tl.TextLoader(paths_X=self.valid_x_files,
                                     paths_t=self.valid_t_files,
                                     seq_len_x=self.seq_len_x,
                                     seq_len_t=self.seq_len_t)
        self.batch_generator['valid'] = tl.TextBatchGenerator(
            loader=valid_loader,
            batch_size=self.batch_size_valid,
            alphabet_src=self.alphabet_src,
            alphabet_tar=self.alphabet_tar,
            use_dynamic_array_sizes=True)

        # load test set
        print('Load test set')
        test_loader = tl.TextLoader(paths_X=self.test_x_files,
                                    paths_t=self.test_t_files,
                                    seq_len_x=self.seq_len_x,
                                    seq_len_t=self.seq_len_t)
        self.batch_generator['test'] = tl.TextBatchGenerator(
            loader=test_loader,
            batch_size=self.batch_size_valid,
            alphabet_src=self.alphabet_src,
            alphabet_tar=self.alphabet_tar,
            use_dynamic_array_sizes=True)

    def valid_dict(self, batch, feedback=True):
        """ Return feed_dict for validation """
        return {
            self.Xs: batch['x_encoded'],
            self.ts: batch['t_encoded'],
            self.ts_go: batch['t_encoded_go'],
            self.X_len: batch['x_len'],
            self.t_len: batch['t_len'],
            self.x_mask: batch['x_mask'],
            self.t_mask: batch['t_mask'],
            self.feedback: feedback,
            self.X_spaces: batch['x_spaces'],
            self.X_spaces_len: batch['x_spaces_len']
        }

    def train_dict(self, batch):
        """ Return feed_dict for training.
        Reuse validation feed_dict because the only difference is feedback.
        """
        return self.valid_dict(batch, feedback=False)

    def build_feed_dict(self, batch, validate=False):
        return self.valid_dict(batch) if validate else self.train_dict(batch)

    def get_generator(self, split):
        assert split in ['train', 'valid', 'test']
        return self.batch_generator[split].gen_batch

    def next_train_feed(self):
        generator = self.get_generator('train')
        for t_batch in generator(self.train_schedule_function):
            extra = {'t_len': t_batch['t_len']}
            yield (self.build_feed_dict(t_batch), extra)

    def next_valid_feed(self):
        generator = self.get_generator('valid')
        for v_batch in generator(self.valid_schedule_function):
            yield self.build_feed_dict(v_batch, validate=True)

    def next_test_feed(self):
        generator = self.get_generator('test')
        for p_batch in generator(self.test_schedule_function):
            yield self.build_feed_dict(p_batch, validate=True)

    def get_alphabet_src(self):
        return self.batch_generator['train'].alphabet_src

    def get_alphabet_tar(self):
        return self.batch_generator['train'].alphabet_tar
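Finally, a hedged training-loop sketch for driving the Model above; it relies on the same pre-1.0 TensorFlow API the example already uses, and the session handling and logging below are illustrative assumptions rather than part of the original code.

# Hedged training-loop sketch (not part of the original example).
model = Model()
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for feed_dict, extra in model.next_train_feed():
        _, loss, acc, step = sess.run(
            [model.train_op, model.loss, model.accuracy, model.global_step],
            feed_dict=feed_dict)
        if Model.log_freq and step % Model.log_freq == 0:
            print('step %d  loss %.4f  acc %.4f' % (step, loss, acc))
        if step >= Model.iterations:
            break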