class Data():
    """Corpus container for transition-based joint word segmentation + POS tagging.

    Reads sentences in ``word_POS word_POS ...`` format, builds the alphabets
    (word / char / bichar / char-type / POS / seg-POS action / word-length,
    plus the "ext" alphabets that back pretrained embeddings), converts each
    instance to index form, and packs instances into padded batch buckets.

    NOTE(review): relies on module-level names not visible in this chunk:
    ``Alphabet``, ``Instance``, ``Counter``, ``math``, ``utils``, the
    ``instance`` module, and the symbols ``sep``, ``app``, ``nullsymbol``,
    ``pad`` — confirm they are defined/imported at file scope.
    """

    def __init__(self):
        # Alphabets for trainable embeddings / label sets.
        self.word_alphabet = Alphabet('word')
        self.char_alphabet = Alphabet('char')
        self.bichar_alphabet = Alphabet('bichar')
        self.pos_alphabet = Alphabet('pos', is_label=True)
        self.char_type_alphabet = Alphabet('type', is_label=True)
        # "ext" alphabets are closed separately (fix_static_alphabet) and are
        # the lookup tables used for pretrained char/bichar embeddings.
        self.extchar_alphabet = Alphabet('extchar')
        self.extbichar_alphabet = Alphabet('extbichar')
        self.segpos_alphabet = Alphabet('segpos', is_label=True)
        self.wordlen_alphabet = Alphabet('wordlen')
        self.number_normalized = False
        self.norm_word_emb = False
        self.norm_char_emb = False
        self.norm_bichar_emb = False
        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None
        self.pretrain_bichar_embedding = None
        # Pad / APP-action ids; overwritten by get_instance() once the
        # alphabets are populated.
        self.wordPadID = 0
        self.charPadID = 0
        self.bicharPadID = 0
        self.charTypePadID = 0
        self.wordlenPadID = 0
        self.posPadID = 0
        self.appID = 0
        self.actionPadID = 0
        # Alphabet sizes; filled by fix_alphabet() / fix_static_alphabet().
        self.word_num = 0
        self.char_num = 0
        self.pos_num = 0
        self.bichar_num = 0
        self.segpos_num = 0
        self.wordlen_num = 0
        self.char_type_num = 0
        self.extchar_num = 0
        self.extbichar_num = 0
        # Word lengths are clipped: any length > 6 is recorded as 7.
        self.wordlen_max = 7

    def build_alphabet(self, word_counter, char_counter, extchar_counter,
                       bichar_counter, extbichar_counter, char_type_counter,
                       pos_counter, wordlen_counter, gold_counter,
                       shrink_feature_threshold):
        """Populate every alphabet from its frequency Counter.

        Only chars and bichars are pruned by ``shrink_feature_threshold``;
        the remaining alphabets keep every observed symbol (the commented
        threshold checks were deliberately disabled).
        """
        # could be optimized, but it would get messy
        for word, count in word_counter.most_common():
            # if count > shrink_feature_threshold:
            self.word_alphabet.add(word, count)
        for char, count in char_counter.most_common():
            if count > shrink_feature_threshold:
                self.char_alphabet.add(char, count)
        for extchar, count in extchar_counter.most_common():
            # if count > shrink_feature_threshold:
            self.extchar_alphabet.add(extchar, count)
        for bichar, count in bichar_counter.most_common():
            if count > shrink_feature_threshold:
                self.bichar_alphabet.add(bichar, count)
        for extbichar, count in extbichar_counter.most_common():
            # if count > shrink_feature_threshold:
            self.extbichar_alphabet.add(extbichar, count)
        for char_type, count in char_type_counter.most_common():
            # if count > shrink_feature_threshold:
            self.char_type_alphabet.add(char_type, count)
        for pos, count in pos_counter.most_common():
            # if count > shrink_feature_threshold:
            self.pos_alphabet.add(pos, count)
        for wordlen, count in wordlen_counter.most_common():
            # if count > shrink_feature_threshold:
            self.wordlen_alphabet.add(wordlen, count)
        for segpos, count in gold_counter.most_common():
            # if count > shrink_feature_threshold:
            self.segpos_alphabet.add(segpos, count)
        # another method
        # reverse = lambda x: dict(zip(x, range(len(x))))
        # self.word_alphabet.word2id = reverse(self.word_alphabet.id2word)
        # self.label_alphabet.word2id = reverse(self.label_alphabet.id2word)
        ##### check: the three parallel structures inside an Alphabet must
        ##### stay the same length.
        if len(self.word_alphabet.word2id) != len(
                self.word_alphabet.id2word) or len(
                    self.word_alphabet.id2count) != len(
                        self.word_alphabet.id2word):
            print('there are errors in building word alphabet.')
        if len(self.char_alphabet.word2id) != len(
                self.char_alphabet.id2word) or len(
                    self.char_alphabet.id2count) != len(
                        self.char_alphabet.id2word):
            print('there are errors in building char alphabet.')

    def fix_alphabet(self):
        """Close the trainable alphabets and record their final sizes."""
        self.word_num = self.word_alphabet.close()
        self.char_num = self.char_alphabet.close()
        self.pos_num = self.pos_alphabet.close()
        self.bichar_num = self.bichar_alphabet.close()
        self.segpos_num = self.segpos_alphabet.close()
        # self.wordlen_max = self.wordlen_alphabet.size()-2 ######
        # print(self.wordlen_max)
        # self.wordlen_alphabet.add(self.wordlen_max+1)
        self.wordlen_num = self.wordlen_alphabet.close()
        # print(self.wordlen_num)
        self.char_type_num = self.char_type_alphabet.close()

    def fix_static_alphabet(self):
        """Close the pretrained-embedding ("ext") alphabets."""
        self.extchar_num = self.extchar_alphabet.close()
        self.extbichar_num = self.extbichar_alphabet.close()

    def get_instance(self, file, run_insts, shrink_feature_threshold):
        """Read up to ``run_insts`` sentences (one per line, ``word_POS`` tokens)
        and return the list of indexed Instance objects.

        On the first call (alphabets not yet fixed) also builds all alphabets
        and the pad/app id attributes. ``run_insts == -1`` reads everything.
        """
        insts = []
        word_counter = Counter()
        char_counter = Counter()
        bichar_counter = Counter()
        char_type_counter = Counter()
        gold_counter = Counter()
        pos_counter = Counter()
        extchar_counter = Counter()
        extbichar_counter = Counter()
        wordlen_counter = Counter()
        count = 0
        with open(file, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                if run_insts == count:
                    break
                line = line.strip().split(' ')
                inst = Instance()
                char_num = 0
                start = 0
                for idx, ele in enumerate(line):
                    word = ele.split('_')[0]
                    if self.number_normalized:
                        word = utils.normalize_word(word)
                    inst.words.append(word)
                    word_counter[word] += 1
                    pos = ele.split('_')[1]
                    inst.gold_pos.append(pos)
                    pos_counter[pos] += 1
                    word_list = list(word)
                    # print(len(word_list))
                    # if len(word_list) > 5:
                    #     print(len(word_list))
                    #     print(word)
                    # Clip word length to 7 (see self.wordlen_max).
                    if len(word_list) > 6:
                        # inst.word_len.append(7)
                        cur_len = 7
                    else:
                        # inst.word_len.append(len(word_list))
                        cur_len = len(word_list)
                    inst.word_len.append(cur_len)
                    wordlen_counter[cur_len] += 1
                    for id, char in enumerate(word):
                        # For each char, record the length and POS of the
                        # *previous* word (or of the prefix consumed so far
                        # when inside a word).
                        if idx == 0 and id == 0:
                            inst.last_wordlen.append('<pad>')
                            inst.last_pos.append('<pad>')
                        elif idx != 0 and id == 0:
                            last_wordlen_cur = len(inst.words[-2])
                            if last_wordlen_cur > 6:
                                last_wordlen_cur = 7
                            inst.last_wordlen.append(
                                last_wordlen_cur)  ### because the current word and pos have already been appended to the lists
                            inst.last_pos.append(inst.gold_pos[-2])
                        else:
                            last_wordlen_cur = id
                            if last_wordlen_cur > 6:
                                last_wordlen_cur = 7
                            inst.last_wordlen.append(last_wordlen_cur)
                            inst.last_pos.append(inst.gold_pos[-1])
                        inst.chars.append(char)
                        inst.extchars.append(char)
                        char_counter[char] += 1
                        extchar_counter[char] += 1
                        char_type = instance.char_type(char)
                        # print(char_type)
                        inst.char_type.append(char_type)
                        char_type_counter[char_type] += 1
                        # Gold action: SEP#pos at a word start, APP inside.
                        if id == 0:
                            inst.gold_action.append(sep + '#' + pos)
                            gold_counter[sep + '#' + pos] += 1
                            start = char_num
                        else:
                            inst.gold_action.append(app)
                            gold_counter[app] += 1
                        char_num += 1
                    # Span annotations in char offsets: [start, end) and
                    # [start, end)POS.
                    inst.word_seg.append('[' + str(start) + ',' +
                                         str(char_num) + ']')
                    inst.word_seg_pos.append('[' + str(start) + ',' +
                                             str(char_num) + ']' + pos)
                # Build left/right bichar features over the whole sentence.
                # NOTE(review): `char` here is the last character left over
                # from the loops above — only valid because the single-char
                # branch implies that char IS the whole sentence.
                right_bichars = []
                if len(inst.chars) == 1:
                    inst.left_bichars.append(nullsymbol + char)
                    inst.extleft_bichars.append(nullsymbol + char)
                    right_bichars.append(char + nullsymbol)
                    bichar_counter[nullsymbol + char] += 1
                    bichar_counter[char + nullsymbol] += 1
                    extbichar_counter[nullsymbol + char] += 1
                    extbichar_counter[char + nullsymbol] += 1
                else:
                    for id, char in enumerate(inst.chars):
                        if id == 0:
                            # Sentence start: left neighbour is the null symbol.
                            inst.left_bichars.append(nullsymbol + char)
                            # inst.right_bichars.append(char+inst.chars[id+1])
                            inst.extleft_bichars.append(nullsymbol + char)
                            # inst.extright_bichars.append(char+inst.chars[id+1])
                            right_bichars.append(char + inst.chars[id + 1])
                            # right_index_mark.append(1)
                            bichar_counter[nullsymbol + char] += 1
                            bichar_counter[char + inst.chars[id + 1]] += 1
                            extbichar_counter[nullsymbol + char] += 1
                            extbichar_counter[char + inst.chars[id + 1]] += 1
                        elif id == (len(inst.chars) - 1):
                            # Sentence end: right neighbour is the null symbol.
                            inst.left_bichars.append(inst.chars[id - 1] + char)
                            # inst.right_bichars.append(char+nullsymbol)
                            inst.extleft_bichars.append(inst.chars[id - 1] +
                                                        char)
                            # inst.extright_bichars.append(char + nullsymbol)
                            right_bichars.append(char + nullsymbol)
                            # right_index_mark.append(1)
                            bichar_counter[inst.chars[id - 1] + char] += 1
                            bichar_counter[char + nullsymbol] += 1
                            extbichar_counter[inst.chars[id - 1] + char] += 1
                            extbichar_counter[char + nullsymbol] += 1
                        else:
                            inst.left_bichars.append(inst.chars[id - 1] + char)
                            # inst.right_bichars.append(char+inst.chars[id+1])
                            inst.extleft_bichars.append(inst.chars[id - 1] +
                                                        char)
                            # inst.extright_bichars.append(char + inst.chars[id + 1])
                            right_bichars.append(char + inst.chars[id + 1])
                            # right_index_mark.append(1)
                            bichar_counter[inst.chars[id - 1] + char] += 1
                            bichar_counter[char + inst.chars[id + 1]] += 1
                            extbichar_counter[inst.chars[id - 1] + char] += 1
                            extbichar_counter[char + inst.chars[id + 1]] += 1
                # right_bichars = list(reversed(right_bichars)) #####
                # print(right_bichars)
                # print(right_index_mark)
                # right_index_mark = list(reversed(right_index_mark))
                inst.right_bichars = right_bichars
                inst.extright_bichars = right_bichars
                # print(line)
                # print(inst.words)
                # print(inst.chars)
                # print(inst.char_type)
                # print(inst.word_seg)
                # print(inst.word_seg_pos)
                # print(inst.gold)
                # print(inst.pos)
                # print(inst.left_bichars)
                # print(inst.right_bichars)
                count += 1
                insts.append(inst)
        # Only build alphabets from the first (training) file read.
        if not self.word_alphabet.fix_flag:
            self.build_alphabet(word_counter, char_counter, extchar_counter,
                                bichar_counter, extbichar_counter,
                                char_type_counter, pos_counter,
                                wordlen_counter, gold_counter,
                                shrink_feature_threshold)
        # insts_index = []
        # Convert every surface feature of every instance to alphabet ids.
        for inst in insts:
            inst.words_index = [
                self.word_alphabet.get_index(w) for w in inst.words
            ]
            inst.chars_index = [
                self.char_alphabet.get_index(c) for c in inst.chars
            ]
            inst.extchars_index = [
                self.extchar_alphabet.get_index(ec) for ec in inst.extchars
            ]
            inst.left_bichar_index = [
                self.bichar_alphabet.get_index(b) for b in inst.left_bichars
            ]
            inst.right_bichar_index = [
                self.bichar_alphabet.get_index(b) for b in inst.right_bichars
            ]
            inst.extleft_bichar_index = [
                self.extbichar_alphabet.get_index(eb)
                for eb in inst.extleft_bichars
            ]
            inst.extright_bichar_index = [
                self.extbichar_alphabet.get_index(eb)
                for eb in inst.extright_bichars
            ]
            inst.pos_index = [
                self.pos_alphabet.get_index(p) for p in inst.gold_pos
            ]
            inst.char_type_index = [
                self.char_type_alphabet.get_index(t) for t in inst.char_type
            ]
            # inst.char_type_index = []
            # for t in inst.char_type:
            #     print(t)
            #     print(self.char_type_alphabet.word2id)
            #     temp = self.char_type_alphabet.get_index(t)
            #     print(temp)
            #     inst.char_type_index.append(temp)
            inst.segpos_index = [
                self.segpos_alphabet.get_index(g) for g in inst.gold_action
            ]
            inst.word_len_index = [
                self.wordlen_alphabet.get_index(w) for w in inst.word_len
            ]
            inst.last_wordlen_index = [
                self.wordlen_alphabet.get_index(l) for l in inst.last_wordlen
            ]
            inst.last_pos_index = [
                self.pos_alphabet.get_index(p) for p in inst.last_pos
            ]
        # Cache the special ids used when padding batches / decoding.
        self.wordPadID = self.word_alphabet.get_index(pad)
        self.charPadID = self.char_alphabet.get_index(pad)
        self.bicharPadID = self.bichar_alphabet.get_index(pad)
        self.charTypePadID = self.char_type_alphabet.get_index(pad)
        self.wordlenPadID = self.wordlen_alphabet.get_index(pad)
        self.posPadID = self.pos_alphabet.get_index(pad)
        self.appID = self.segpos_alphabet.get_index(app)
        self.actionPadID = self.segpos_alphabet.get_index(pad)
        ##### sorted sentences
        # insts_sorted, insts_index_sorted = utils.sorted_instances(insts, insts_index)
        return insts

    def build_word_pretrain_emb(self, emb_path, word_dims):
        """Load pretrained word embeddings keyed by the word alphabet."""
        self.pretrain_word_embedding = utils.load_pretrained_emb_uniform(
            emb_path, self.word_alphabet.word2id, word_dims,
            self.norm_word_emb)

    # could be optimized
    def build_char_pretrain_emb(self, emb_path, char_dims):
        """Load pretrained char embeddings keyed by the ext-char alphabet."""
        self.pretrain_char_embedding = utils.load_pretrained_emb_uniform(
            emb_path, self.extchar_alphabet.word2id, char_dims,
            self.norm_char_emb)

    # could be optimized
    def build_bichar_pretrain_emb(self, emb_path, bichar_dims):
        """Load pretrained bichar embeddings keyed by the ext-bichar alphabet."""
        self.pretrain_bichar_embedding = utils.load_pretrained_emb_uniform(
            emb_path, self.extbichar_alphabet.word2id, bichar_dims,
            self.norm_bichar_emb)

    def generate_batch_buckets(self, batch_size, insts):
        """Pack instances into ``batch_num`` buckets of 15 parallel lists.

        Bucket layout per batch: [0] chars, [1] left bichars, [2] right
        bichars, [3] ext chars, [4] ext left bichars, [5] ext right bichars,
        [6] char types, [7] seg-POS actions, [8] last word lengths,
        [9] last POS (all as padded id sequences), [10] raw gold actions,
        [11] raw gold POS, [12] raw chars, [13] raw words (string lists padded
        with '<pad>'), [-1] 1/0 mask. Each batch is sorted by length
        (utils.sort_instances) and padded to its own max length.
        """
        # insts_length = list(map(lambda t: len(t) + 1, inst[0] for inst in insts))
        # insts_length = list(len(inst[0]+1) for inst in insts)
        # if len(insts) % batch_size == 0:
        #     batch_num = len(insts) // batch_size
        # else:
        #     batch_num = len(insts) // batch_size + 1
        batch_num = int(math.ceil(len(insts) / batch_size))
        buckets = [[[], [], [], [], [], [], [], [], [], [], [], [], [], [],
                    []] for _ in range(batch_num)]
        # labels_raw = [[] for _ in range(batch_num)]
        inst_save = []
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id == 0 or id % batch_size != 0:
                # Still filling the current batch.
                inst_save.append(inst)
            elif id % batch_size == 0:
                # A full batch is collected: sort, pad and emit it, then
                # start the next batch with the current instance.
                assert len(inst_save) == batch_size
                inst_sorted = utils.sort_instances(inst_save)
                max_length = len(inst_sorted[0].chars_index)
                for idy in range(batch_size):
                    cur_length = len(inst_sorted[idy].chars_index)
                    buckets[idx - 1][0].append(
                        inst_sorted[idy].chars_index +
                        [self.char_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][1].append(
                        inst_sorted[idy].left_bichar_index +
                        [self.bichar_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][2].append(
                        inst_sorted[idy].right_bichar_index +
                        [self.bichar_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][3].append(
                        inst_sorted[idy].extchars_index +
                        [self.extchar_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][4].append(
                        inst_sorted[idy].extleft_bichar_index +
                        [self.extbichar_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][5].append(
                        inst_sorted[idy].extright_bichar_index +
                        [self.extbichar_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][6].append(
                        inst_sorted[idy].char_type_index +
                        [self.char_type_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][7].append(
                        inst_sorted[idy].segpos_index +
                        [self.segpos_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][8].append(
                        inst_sorted[idy].last_wordlen_index +
                        [self.wordlen_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][9].append(
                        inst_sorted[idy].last_pos_index +
                        [self.pos_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][10].append(inst_sorted[idy].gold_action +
                                                ['<pad>'] *
                                                (max_length - cur_length))
                    buckets[idx - 1][11].append(inst_sorted[idy].gold_pos +
                                                ['<pad>'] *
                                                (max_length - cur_length))
                    buckets[idx - 1][12].append(inst_sorted[idy].chars +
                                                ['<pad>'] *
                                                (max_length - cur_length))
                    buckets[idx - 1][13].append(inst_sorted[idy].words +
                                                ['<pad>'] *
                                                (max_length - cur_length))
                    buckets[idx - 1][-1].append([1] * cur_length + [0] *
                                                (max_length - cur_length))
                    # labels_raw[idx-1].append(inst_sorted[idy][-1])
                inst_save = []
                inst_save.append(inst)
        # Emit the final (possibly short) batch.
        if inst_save != []:
            inst_sorted = utils.sort_instances(inst_save)
            max_length = len(inst_sorted[0].chars_index)
            for idy in range(len(inst_sorted)):
                cur_length = len(inst_sorted[idy].chars_index)
                buckets[batch_num - 1][0].append(
                    inst_sorted[idy].chars_index +
                    [self.char_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][1].append(
                    inst_sorted[idy].left_bichar_index +
                    [self.bichar_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][2].append(
                    inst_sorted[idy].right_bichar_index +
                    [self.bichar_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][3].append(
                    inst_sorted[idy].extchars_index +
                    [self.extchar_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][4].append(
                    inst_sorted[idy].extleft_bichar_index +
                    [self.extbichar_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][5].append(
                    inst_sorted[idy].extright_bichar_index +
                    [self.extbichar_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][6].append(
                    inst_sorted[idy].char_type_index +
                    [self.char_type_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][7].append(
                    inst_sorted[idy].segpos_index +
                    [self.segpos_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][8].append(
                    inst_sorted[idy].last_wordlen_index +
                    [self.wordlen_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][9].append(
                    inst_sorted[idy].last_pos_index +
                    [self.pos_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][10].append(inst_sorted[idy].gold_action +
                                                  ['<pad>'] *
                                                  (max_length - cur_length))
                buckets[batch_num - 1][11].append(inst_sorted[idy].gold_pos +
                                                  ['<pad>'] *
                                                  (max_length - cur_length))
                buckets[batch_num - 1][12].append(inst_sorted[idy].chars +
                                                  ['<pad>'] *
                                                  (max_length - cur_length))
                buckets[batch_num - 1][13].append(inst_sorted[idy].words +
                                                  ['<pad>'] *
                                                  (max_length - cur_length))
                buckets[batch_num - 1][-1].append([1] * cur_length + [0] *
                                                  (max_length - cur_length))
                # labels_raw[batch_num-1].append(inst_sorted[idy][-1])
        return buckets

    def generate_batch_buckets_save(self, batch_size, insts, char=False):
        """Legacy bucket packer over ``[words_index, (chars_index,) labels_index]``
        tuples: pads to the first-in-batch length + 1 and prepends '<start>'
        to the label sequence.

        NOTE(review): uses ``self.label_alphabet``, which this class's
        __init__ does not create — appears to be dead code carried over from
        another Data variant; confirm before calling.
        """
        # insts_length = list(map(lambda t: len(t) + 1, inst[0] for inst in insts))
        # insts_length = list(len(inst[0]+1) for inst in insts)
        # if len(insts) % batch_size == 0:
        #     batch_num = len(insts) // batch_size
        # else:
        #     batch_num = len(insts) // batch_size + 1
        batch_num = int(math.ceil(len(insts) / batch_size))
        if char:
            buckets = [[[], [], [], []] for _ in range(batch_num)]
        else:
            buckets = [[[], [], []] for _ in range(batch_num)]
        max_length = 0
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id % batch_size == 0:
                # Pad the whole batch to the first instance's length + 1.
                max_length = len(inst[0]) + 1
            cur_length = len(inst[0])
            buckets[idx][0].append(inst[0] +
                                   [self.word_alphabet.word2id['<pad>']] *
                                   (max_length - cur_length))
            buckets[idx][1].append([self.label_alphabet.word2id['<start>']] +
                                   inst[-1] +
                                   [self.label_alphabet.word2id['<pad>']] *
                                   (max_length - cur_length - 1))
            if char:
                char_length = len(inst[1][0])
                buckets[idx][2].append(
                    (inst[1] + [[self.char_alphabet.word2id['<pad>']] *
                                char_length] * (max_length - cur_length)))
            buckets[idx][-1].append([1] * (cur_length + 1) + [0] *
                                    (max_length - (cur_length + 1)))
        return buckets
class Data():
    """Corpus container for sequence labeling (words + optional chars + labels).

    Reads a column-formatted file (lines starting with 'token'; blank-ish
    lines end a sentence), builds word/char/label alphabets, indexes the
    instances and packs them into padded batch buckets.

    NOTE(review): depends on module-level names not visible in this chunk:
    ``Alphabet``, ``Counter``, ``math``, ``utils`` — confirm they are
    defined/imported at file scope.
    """

    def __init__(self):
        self.word_alphabet = Alphabet('word')
        self.char_alphabet = Alphabet('char')
        self.label_alphabet = Alphabet('label', is_label=True)
        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False
        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None
        # Longest char sequence seen over ALL files read so far; used as the
        # fixed char-padding width in generate_batch_buckets.
        self.max_char_length = 0
        # Alphabet sizes; filled by fix_alphabet().
        self.word_num = 0
        self.char_num = 0
        self.label_num = 0

    def build_alphabet(self, word_counter, char_counter, label_counter,
                       shrink_feature_threshold, char=False):
        """Populate the word/label (and optionally char) alphabets, keeping
        only symbols whose count exceeds ``shrink_feature_threshold``."""
        for word, count in word_counter.most_common():
            if count > shrink_feature_threshold:
                # self.word_alphabet.id2word.append(word)
                # self.word_alphabet.id2count.append(count)
                # if self.number_normalized:
                #     word = utils.normalize_word(word)
                self.word_alphabet.add(word, count)
        for label, count in label_counter.most_common():
            if count > shrink_feature_threshold:
                # self.label_alphabet.id2word.append(label)
                # self.label_alphabet.id2count.append(count)
                self.label_alphabet.add(label, count)
        # another method
        # reverse = lambda x: dict(zip(x, range(len(x))))
        # self.word_alphabet.word2id = reverse(self.word_alphabet.id2word)
        # self.label_alphabet.word2id = reverse(self.label_alphabet.id2word)
        ##### check: internal alphabet structures must stay the same length.
        if len(self.word_alphabet.word2id) != len(
                self.word_alphabet.id2word) or len(
                    self.word_alphabet.id2count) != len(
                        self.word_alphabet.id2word):
            print('there are errors in building word alphabet.')
        if len(self.label_alphabet.word2id) != len(
                self.label_alphabet.id2word) or len(
                    self.label_alphabet.id2count) != len(
                        self.label_alphabet.id2word):
            print('there are errors in building label alphabet.')
        if char:
            # NOTE(review): the loop variable shadows the `char` flag
            # parameter; harmless here because the loop only runs when the
            # flag is truthy, but worth renaming eventually.
            for char, count in char_counter.most_common():
                if count > shrink_feature_threshold:
                    self.char_alphabet.add(char, count)
                    # self.char_alphabet.id2word.append(char)
                    # self.char_alphabet.id2count.append(count)
            # self.char_alphabet.word2id = reverse(self.char_alphabet.id2word)
            if len(self.char_alphabet.word2id) != len(
                    self.char_alphabet.id2word) or len(
                        self.char_alphabet.id2count) != len(
                            self.char_alphabet.id2word):
                print('there are errors in building char alphabet.')

    def fix_alphabet(self):
        """Close all alphabets and record their final sizes."""
        self.word_num = self.word_alphabet.close()
        self.char_num = self.char_alphabet.close()
        self.label_num = self.label_alphabet.close()

    def get_instance(self, file, run_insts, shrink_feature_threshold,
                     char=False, char_padding_symbol='<pad>'):
        """Read up to ``run_insts`` sentences from a column file and return
        ``(insts, insts_index)``.

        Each inst is ``[words, labels]`` or ``[words, chars_padded, labels]``
        when ``char`` is true (char lists padded to the sentence's longest
        word). Builds the alphabets on the first (training) file read.
        """
        words = []
        chars = []
        labels = []
        insts = []
        word_counter = Counter()
        char_counter = Counter()
        label_counter = Counter()
        char_length_max = 0
        count = 0
        with open(file, 'r', encoding='utf-8') as f:
            ##### if one sentence is a line, you can use the method to control instances for debug.
            # if run_insts == -1:
            #     fin_lines = f.readlines()
            # else:
            #     fin_lines = f.readlines()[:run_insts]
            # in_lines = open(file, 'r', encoding='utf-8').readlines()
            for line in f.readlines():
                if run_insts == count:
                    break
                if len(line) > 2:
                    # Token line: "token <word> ... <label>".
                    line = line.strip().split(' ')
                    if line[0] == 'token':
                        word = line[1]
                        if self.number_normalized:
                            word = utils.normalize_word(word)
                        label = line[-1]
                        words.append(word)
                        labels.append(label)
                        word_counter[word] += 1  #####
                        label_counter[label] += 1
                        if char:
                            # NOTE(review): loop variable shadows the `char`
                            # flag; the last character of `word` is truthy so
                            # later `if char:` checks still behave as "True".
                            char_list = []
                            for char in word:
                                char_list.append(char)
                                char_counter[char] += 1
                            chars.append(char_list)
                            char_length = len(char_list)
                            if char_length > char_length_max:
                                char_length_max = char_length
                            if char_length_max > self.max_char_length:
                                self.max_char_length = char_length_max
                else:
                    # Sentence separator: flush the accumulated sentence.
                    if char:
                        # Right-pad every word's char list to the longest
                        # word in this sentence.
                        chars_padded = []
                        for index, char_list in enumerate(chars):
                            char_number = len(char_list)
                            if char_number < char_length_max:
                                char_list = char_list + [
                                    char_padding_symbol
                                ] * (char_length_max - char_number)
                                char_counter[char_padding_symbol] += (
                                    char_length_max - char_number)
                            chars_padded.append(char_list)
                            assert (len(char_list) == char_length_max)
                        insts.append([words, chars_padded, labels])
                    else:
                        insts.append([words, labels])
                    words = []
                    chars = []
                    labels = []
                    char_length_max = 0
                    count += 1
        # Only build alphabets from the first (training) file read.
        if not self.word_alphabet.fix_flag:
            self.build_alphabet(word_counter, char_counter, label_counter,
                                shrink_feature_threshold, char)
        insts_index = []
        for inst in insts:
            words_index = [self.word_alphabet.get_index(w) for w in inst[0]]
            labels_index = [self.label_alphabet.get_index(l) for l in inst[-1]]
            chars_index = []
            if char:
                # words, chars, labels = inst
                # words_index = [self.word_alphabet.get_index(w) for w in words]
                # labels_index = [self.label_alphabet.get_index(l) for l in labels]
                # char_index = []
                for char in inst[1]:
                    char_index = [
                        self.char_alphabet.get_index(c) for c in char
                    ]
                    chars_index.append(char_index)
                insts_index.append([words_index, chars_index, labels_index])
            else:
                # words, labels = inst
                # words_index = [self.word_alphabet.get_index(w) for w in words]
                # labels_index = [self.label_alphabet.get_index(l) for l in labels]
                insts_index.append([words_index, labels_index])
        ##### sorted sentences
        # insts_sorted, insts_index_sorted = utils.sorted_instances(insts, insts_index)
        return insts, insts_index

    def build_word_pretrain_emb(self, emb_path, word_dims):
        """Load pretrained word embeddings keyed by the word alphabet."""
        self.pretrain_word_embedding = utils.load_pretrained_emb_avg(
            emb_path, self.word_alphabet.word2id, word_dims,
            self.norm_word_emb)

    def build_char_pretrain_emb(self, emb_path, char_dims):
        """Load pretrained char embeddings keyed by the char alphabet."""
        self.pretrain_char_embedding = utils.load_pretrained_emb_avg(
            emb_path, self.char_alphabet.word2id, char_dims,
            self.norm_char_emb)

    def generate_batch_buckets(self, batch_size, insts, char=False):
        """Pack indexed instances into length-sorted, padded batch buckets.

        Bucket layout per batch: [0] word ids, [1] label ids, ([2] char id
        matrices padded to ``self.max_char_length`` when ``char``), [-1] 1/0
        mask. Returns ``(buckets, labels_raw)`` where labels_raw holds the
        unpadded label id sequences per batch.
        """
        # insts_length = list(map(lambda t: len(t) + 1, inst[0] for inst in insts))
        # insts_length = list(len(inst[0]+1) for inst in insts)
        # if len(insts) % batch_size == 0:
        #     batch_num = len(insts) // batch_size
        # else:
        #     batch_num = len(insts) // batch_size + 1
        batch_num = int(math.ceil(len(insts) / batch_size))
        if char:
            buckets = [[[], [], [], []] for _ in range(batch_num)]
        else:
            buckets = [[[], [], []] for _ in range(batch_num)]
        labels_raw = [[] for _ in range(batch_num)]
        inst_save = []
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id == 0 or id % batch_size != 0:
                # Still filling the current batch.
                inst_save.append(inst)
            elif id % batch_size == 0:
                # Full batch collected: sort by length, pad, emit, and start
                # the next batch with the current instance.
                assert len(inst_save) == batch_size
                inst_sorted = utils.sorted_instances_index(inst_save)
                max_length = len(inst_sorted[0][0])
                for idy in range(batch_size):
                    cur_length = len(inst_sorted[idy][0])
                    buckets[idx - 1][0].append(
                        inst_sorted[idy][0] +
                        [self.word_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][1].append(
                        inst_sorted[idy][-1] +
                        [self.label_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    if char:
                        # Widen each word's char row to the global
                        # max_char_length, then pad missing words with
                        # all-<pad> rows.
                        cur_char_length = len(inst_sorted[idy][1][0])
                        inst_sorted[idy][1] = [
                            (ele + [self.char_alphabet.word2id['<pad>']] *
                             (self.max_char_length - cur_char_length))
                            for ele in inst_sorted[idy][1]
                        ]
                        buckets[idx - 1][2].append(
                            (inst_sorted[idy][1] +
                             [[self.char_alphabet.word2id['<pad>']] *
                              self.max_char_length] *
                             (max_length - cur_length)))
                    buckets[idx - 1][-1].append([1] * cur_length + [0] *
                                                (max_length - cur_length))
                    labels_raw[idx - 1].append(inst_sorted[idy][-1])
                inst_save = []
                inst_save.append(inst)
        # Emit the final (possibly short) batch.
        if inst_save != []:
            inst_sorted = utils.sorted_instances_index(inst_save)
            max_length = len(inst_sorted[0][0])
            for idy in range(len(inst_sorted)):
                cur_length = len(inst_sorted[idy][0])
                buckets[batch_num - 1][0].append(
                    inst_sorted[idy][0] +
                    [self.word_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][1].append(
                    inst_sorted[idy][-1] +
                    [self.label_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                if char:
                    cur_char_length = len(inst_sorted[idy][1][0])
                    inst_sorted[idy][1] = [
                        (ele + [self.char_alphabet.word2id['<pad>']] *
                         (self.max_char_length - cur_char_length))
                        for ele in inst_sorted[idy][1]
                    ]
                    buckets[batch_num - 1][2].append(
                        (inst_sorted[idy][1] +
                         [[self.char_alphabet.word2id['<pad>']] *
                          self.max_char_length] * (max_length - cur_length)))
                buckets[batch_num - 1][-1].append([1] * cur_length + [0] *
                                                  (max_length - cur_length))
                labels_raw[batch_num - 1].append(inst_sorted[idy][-1])
        return buckets, labels_raw

    def generate_batch_buckets_save(self, batch_size, insts, char=False):
        """Older bucket packer: pads each batch to the FIRST instance's
        length + 1 (assumes pre-sorted input) and prepends the '<start>'
        label id to every label sequence."""
        # insts_length = list(map(lambda t: len(t) + 1, inst[0] for inst in insts))
        # insts_length = list(len(inst[0]+1) for inst in insts)
        # if len(insts) % batch_size == 0:
        #     batch_num = len(insts) // batch_size
        # else:
        #     batch_num = len(insts) // batch_size + 1
        batch_num = int(math.ceil(len(insts) / batch_size))
        if char:
            buckets = [[[], [], [], []] for _ in range(batch_num)]
        else:
            buckets = [[[], [], []] for _ in range(batch_num)]
        max_length = 0
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id % batch_size == 0:
                # Pad the whole batch to the first instance's length + 1.
                max_length = len(inst[0]) + 1
            cur_length = len(inst[0])
            buckets[idx][0].append(inst[0] +
                                   [self.word_alphabet.word2id['<pad>']] *
                                   (max_length - cur_length))
            buckets[idx][1].append([self.label_alphabet.word2id['<start>']] +
                                   inst[-1] +
                                   [self.label_alphabet.word2id['<pad>']] *
                                   (max_length - cur_length - 1))
            if char:
                char_length = len(inst[1][0])
                buckets[idx][2].append(
                    (inst[1] + [[self.char_alphabet.word2id['<pad>']] *
                                char_length] * (max_length - cur_length)))
            buckets[idx][-1].append([1] * (cur_length + 1) + [0] *
                                    (max_length - (cur_length + 1)))
        return buckets
class Data():
    """Corpus container for targeted sentiment / dependency-annotated data.

    Reads 6-column sentences (word, label, category, pos, parent, rel),
    builds the corresponding alphabets, indexes instances, and packs them
    into padded, length-sorted batch buckets together with target spans
    extracted from the label sequence.

    Fixes vs. the previous revision:
      * ``get_instance_tree`` indexed parses via ``self.parse_alphabet``,
        which ``__init__`` never creates (AttributeError). The parse counter
        is fed into ``build_alphabet``'s ``rel_counter`` slot, so the parse
        symbols live in ``self.rel_alphabet`` — index through that.
      * ``get_instance_tree`` counted ``parse_counter[category]`` instead of
        the parse token itself, so the alphabet never contained the parse
        symbols being indexed.
      * The five output file handles in ``get_instance_tree`` were never
        closed (leak / possibly unflushed output); they are now closed.

    NOTE(review): depends on module-level names not visible in this chunk:
    ``Alphabet``, ``Counter``, ``math``, ``utils``, ``evaluation`` — confirm
    they are defined/imported at file scope.
    """

    def __init__(self):
        self.word_alphabet = Alphabet('word')
        self.category_alphabet = Alphabet('category', is_category=True)
        self.label_alphabet = Alphabet('label', is_label=True)
        self.char_alphabet = Alphabet('char')
        self.pos_alphabet = Alphabet('pos')
        self.parent_alphabet = Alphabet('parent')
        self.rel_alphabet = Alphabet('rel')
        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False
        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None
        self.max_char_length = 0
        # Alphabet sizes; filled by fix_alphabet().
        self.word_num = 0
        self.char_num = 0
        self.label_num = 0
        self.category_num = 0
        self.parent_num = 0
        self.pos_num = 0
        self.rel_num = 0

    def build_alphabet(self, word_counter, label_counter, category_counter,
                       pos_counter, rel_counter, shrink_feature_threshold,
                       char=False):
        """Populate the alphabets from frequency Counters.

        Only words are pruned by ``shrink_feature_threshold``; labels,
        categories, POS tags and relations keep every observed symbol.
        """
        for word, count in word_counter.most_common():
            if count > shrink_feature_threshold:
                self.word_alphabet.add(word, count)
        for label, count in label_counter.most_common():
            self.label_alphabet.add(label, count)
        for category, count in category_counter.most_common():
            self.category_alphabet.add(category, count)
        for pos, count in pos_counter.most_common():
            self.pos_alphabet.add(pos, count)
        # for parent, count in parent_counter.most_common():
        #     self.parent_alphabet.add(parent, count)
        # print(rel_counter)
        for rel, count in rel_counter.most_common():
            self.rel_alphabet.add(rel, count)

    def fix_alphabet(self):
        """Close all alphabets and record their final sizes."""
        self.word_num = self.word_alphabet.close()
        self.category_num = self.category_alphabet.close()
        self.label_num = self.label_alphabet.close()
        self.parent_num = self.parent_alphabet.close()
        self.pos_num = self.pos_alphabet.close()
        self.rel_num = self.rel_alphabet.close()

    def get_instance(self, file, run_insts, shrink_feature_threshold,
                     char=False, char_padding_symbol='<pad>'):
        """Read up to ``run_insts`` six-column sentences and return
        ``(insts, insts_index)``.

        Each inst is ``[words, labels, categorys, poss, parents, rels]``;
        parent indices are converted with ``int(p) - 1`` (0-based head
        positions), not through an alphabet.
        """
        words = []
        labels = []
        categorys = []
        poss = []
        parents = []
        rels = []
        insts = []
        word_counter = Counter()
        label_counter = Counter()
        category_counter = Counter()
        pos_counter = Counter()
        # parent_counter = Counter()
        rel_counter = Counter()
        count = 0
        with open(file, 'r', encoding='utf-8') as f:
            for id, line in enumerate(f.readlines()):
                if run_insts == count:
                    break
                if len(line) > 2:
                    line = line.strip().split(' ')
                    if '' in line:
                        line.remove('')
                    if len(line) != 6:
                        # Malformed row: report the line number for debugging.
                        print(id)
                        # print(line)
                    word = line[0]
                    if self.number_normalized:
                        word = utils.normalize_word(word)
                    label = line[1]
                    category = line[2]
                    pos = line[3]
                    parent = line[-2]
                    if ',' in parent:
                        # Multi-head annotation: keep only the first head.
                        # print(parent)
                        parent = parent.split(',')[0]
                    # if parent == '':
                    #     print(id)
                    rel = line[-1]
                    words.append(word)
                    labels.append(label)
                    categorys.append(category)
                    poss.append(pos)
                    parents.append(parent)
                    rels.append(rel)
                    word_counter[word] += 1  #####
                    label_counter[label] += 1
                    category_counter[category] += 1
                    pos_counter[pos] += 1
                    # parent_counter[category] += 1
                    rel_counter[rel] += 1
                else:
                    # Sentence separator: flush the accumulated sentence.
                    # print(words)
                    # print(parents)
                    insts.append([words, labels, categorys, poss, parents,
                                  rels])
                    words = []
                    labels = []
                    categorys = []
                    poss = []
                    parents = []
                    rels = []
                    count += 1
        # Only build alphabets from the first (training) file read.
        if not self.word_alphabet.fix_flag:
            self.build_alphabet(word_counter, label_counter, category_counter,
                                pos_counter, rel_counter,
                                shrink_feature_threshold, char)
        insts_index = []
        for inst in insts:
            words_index = [self.word_alphabet.get_index(w) for w in inst[0]]
            labels_index = [self.label_alphabet.get_index(l) for l in inst[1]]
            categorys_index = [
                self.category_alphabet.get_index(c) for c in inst[2]
            ]
            poss_index = [self.pos_alphabet.get_index(p) for p in inst[3]]
            # parents_index = [self.parent_alphabet.get_index(p) for p in inst[-2]]
            parents_index = [int(p) - 1 for p in inst[-2]]
            rels_index = [self.rel_alphabet.get_index(r) for r in inst[-1]]
            insts_index.append([words_index, labels_index, categorys_index,
                                poss_index, parents_index, rels_index])
        return insts, insts_index

    def get_instance_tree(self, file, run_insts, shrink_feature_threshold,
                          write_word_path, label_path, pos_path, parse_path,
                          category_path, char=False,
                          char_padding_symbol='<pad>'):
        """Read sentences annotated with constituency/parse strings and dump
        words / labels / POS / parses / categories to the given output files
        (one space-joined sentence per line), returning ``(insts, insts_index)``.

        The parse symbols are registered through ``build_alphabet``'s
        ``rel_counter`` slot and therefore live in ``self.rel_alphabet``.
        """
        words = []
        labels = []
        categorys = []
        poss = []
        parses = []
        insts = []
        word_counter = Counter()
        label_counter = Counter()
        category_counter = Counter()
        pos_counter = Counter()
        parse_counter = Counter()
        word_file = open(write_word_path, 'w', encoding='utf-8')
        label_file = open(label_path, 'w', encoding='utf-8')
        pos_file = open(pos_path, 'w', encoding='utf-8')
        parse_file = open(parse_path, 'w', encoding='utf-8')
        category_file = open(category_path, 'w', encoding='utf-8')
        count = 0
        with open(file, 'r', encoding='utf-8') as f:
            for id, line in enumerate(f.readlines()):
                if run_insts == count:
                    break
                if len(line) > 2:
                    line = line.strip().split(' ')
                    # print(line)
                    word = line[0]
                    if self.number_normalized:
                        word = utils.normalize_word(word)
                    label = line[1]
                    category = line[2]
                    pos = line[-2]
                    parse = line[-1]
                    words.append(word)
                    labels.append(label)
                    categorys.append(category)
                    parses.append(parse)
                    poss.append(pos)
                    word_counter[word] += 1  #####
                    label_counter[label] += 1
                    category_counter[category] += 1
                    pos_counter[pos] += 1
                    # BUGFIX: previously counted `category` here, so the
                    # parse symbols never entered the alphabet they are
                    # later indexed against.
                    parse_counter[parse] += 1
                else:
                    # Sentence separator: write out words/parses and flush.
                    # print(words)
                    word_write = ' '.join(words)
                    word_file.write(word_write)
                    word_file.write('\n')
                    parse_write = ' '.join(parses)
                    parse_file.write(parse_write)
                    parse_file.write('\n')
                    insts.append([words, labels, categorys, poss, parses])
                    words = []
                    labels = []
                    categorys = []
                    poss = []
                    parses = []
                    count += 1
        if not self.word_alphabet.fix_flag:
            # parse_counter fills build_alphabet's rel_counter slot, i.e.
            # parses are stored in self.rel_alphabet.
            self.build_alphabet(word_counter, label_counter, category_counter,
                                pos_counter, parse_counter,
                                shrink_feature_threshold, char)
        insts_index = []
        # NOTE(review): hard-coded absolute vocab dump path kept from the
        # original code — parameterize before reuse on another machine.
        path = r"C:\Users\song\Desktop\treelstm_word\examples\vocab-2.txt"
        vocab_file = open(path, 'w', encoding='utf-8')
        words = self.word_alphabet.id2word
        # print(len(words))  # 6799, 6310
        for id in range(len(words)):
            vocab_file.write(words[id])
            vocab_file.write('\n')
        vocab_file.close()
        for inst in insts:
            words_index = [self.word_alphabet.get_index(w) for w in inst[0]]
            labels_index = [
                str(self.label_alphabet.get_index(l)) for l in inst[1]
            ]
            categorys_index = [
                str(self.category_alphabet.get_index(c)) for c in inst[2]
            ]
            pos_index = [
                str(self.pos_alphabet.get_index(p)) for p in inst[-2]
            ]
            # BUGFIX: was `self.parse_alphabet`, an attribute that is never
            # created (AttributeError). The parses were added to
            # rel_alphabet above, so index through it.
            parses_index = [self.rel_alphabet.get_index(p) for p in inst[-1]]
            insts_index.append([words_index, labels_index, categorys_index,
                                pos_index, parses_index])
            label_write = ' '.join(labels_index)
            label_file.write(label_write)
            label_file.write('\n')
            pos_write = ' '.join(pos_index)
            pos_file.write(pos_write)
            pos_file.write('\n')
            category_write = ' '.join(categorys_index)
            category_file.write(category_write)
            category_file.write('\n')
        # BUGFIX: close (and flush) all output files; they previously leaked.
        word_file.close()
        label_file.close()
        pos_file.close()
        parse_file.close()
        category_file.close()
        return insts, insts_index

    def build_word_pretrain_emb(self, emb_path, word_dims):
        """Load pretrained word embeddings keyed by the word alphabet."""
        self.pretrain_word_embedding = utils.load_pretrained_emb_avg(
            emb_path, self.word_alphabet.word2id, word_dims,
            self.norm_word_emb)

    def build_char_pretrain_emb(self, emb_path, char_dims):
        """Load pretrained char embeddings keyed by the char alphabet."""
        self.pretrain_char_embedding = utils.load_pretrained_emb_avg(
            emb_path, self.char_alphabet.word2id, char_dims,
            self.norm_char_emb)

    def generate_batch_buckets(self, batch_size, insts, char=False):
        """Pack indexed instances into length-sorted, padded batch buckets.

        Bucket layout per batch: [0] word ids, [1] label ids, [2] POS ids,
        [3] parent indices, [4] rel ids, [-1] 1/0 mask. Also returns, per
        batch: raw label sequences, the sentence category (first token's),
        and the first target span's start/end extracted via
        ``evaluation.extract_target`` (0/0 when no target is found).
        """
        batch_num = int(math.ceil(len(insts) / batch_size))
        buckets = [[[], [], [], [], [], []] for _ in range(batch_num)]
        labels_raw = [[] for _ in range(batch_num)]
        category_raw = [[] for _ in range(batch_num)]
        target_start = [[] for _ in range(batch_num)]
        target_end = [[] for _ in range(batch_num)]
        inst_save = []
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id == 0 or id % batch_size != 0:
                # Still filling the current batch.
                inst_save.append(inst)
            elif id % batch_size == 0:
                # Full batch collected: sort by length, pad, emit, and start
                # the next batch with the current instance.
                assert len(inst_save) == batch_size
                inst_sorted = utils.sorted_instances_index(inst_save)
                max_length = len(inst_sorted[0][0])
                for idy in range(batch_size):
                    cur_length = len(inst_sorted[idy][0])
                    buckets[idx - 1][0].append(
                        inst_sorted[idy][0] +
                        [self.word_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][1].append(
                        inst_sorted[idy][1] +
                        [self.label_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][2].append(
                        inst_sorted[idy][-3] +
                        [self.pos_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][3].append(
                        inst_sorted[idy][-2] +
                        [self.parent_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][4].append(
                        inst_sorted[idy][-1] +
                        [self.rel_alphabet.word2id['<pad>']] *
                        (max_length - cur_length))
                    buckets[idx - 1][-1].append([1] * cur_length + [0] *
                                                (max_length - cur_length))
                    labels_raw[idx - 1].append(inst_sorted[idy][1])
                    start, end = evaluation.extract_target(
                        inst_sorted[idy][1], self.label_alphabet)
                    if start == []:
                        # No target span found: fall back to position 0.
                        start = [0]
                        end = [0]
                    target_start[idx - 1].append(start[0])
                    target_end[idx - 1].append(end[0])
                    # target_start.extend(start)
                    # target_end.extend(end)
                    category_raw[idx - 1].append(inst_sorted[idy][2][0])
                inst_save = []
                inst_save.append(inst)
        # Emit the final (possibly short) batch.
        if inst_save != []:
            inst_sorted = utils.sorted_instances_index(inst_save)
            max_length = len(inst_sorted[0][0])
            for idy in range(len(inst_sorted)):
                cur_length = len(inst_sorted[idy][0])
                buckets[batch_num - 1][0].append(
                    inst_sorted[idy][0] +
                    [self.word_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][1].append(
                    inst_sorted[idy][1] +
                    [self.label_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][2].append(
                    inst_sorted[idy][-3] +
                    [self.pos_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][3].append(
                    inst_sorted[idy][-2] +
                    [self.parent_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][4].append(
                    inst_sorted[idy][-1] +
                    [self.rel_alphabet.word2id['<pad>']] *
                    (max_length - cur_length))
                buckets[batch_num - 1][-1].append([1] * cur_length + [0] *
                                                  (max_length - cur_length))
                labels_raw[batch_num - 1].append(inst_sorted[idy][1])
                category_raw[batch_num - 1].append(inst_sorted[idy][2][0])
                start, end = evaluation.extract_target(
                    inst_sorted[idy][1], self.label_alphabet)
                if start == []:
                    start = [0]
                    end = [0]
                target_start[batch_num - 1].append(start[0])
                target_end[batch_num - 1].append(end[0])
        # print(buckets)
        # print(labels_raw)
        # print(category_raw)
        # print(target_start)
        # print(target_end)
        return buckets, labels_raw, category_raw, target_start, target_end

    def generate_batch_buckets_save(self, batch_size, insts, char=False):
        """Older bucket packer: pads each batch to the FIRST instance's
        length + 1 (assumes pre-sorted input) and prepends the '<start>'
        label id to every label sequence."""
        batch_num = int(math.ceil(len(insts) / batch_size))
        if char:
            buckets = [[[], [], [], []] for _ in range(batch_num)]
        else:
            buckets = [[[], [], []] for _ in range(batch_num)]
        max_length = 0
        for id, inst in enumerate(insts):
            idx = id // batch_size
            if id % batch_size == 0:
                # Pad the whole batch to the first instance's length + 1.
                max_length = len(inst[0]) + 1
            cur_length = len(inst[0])
            buckets[idx][0].append(inst[0] +
                                   [self.word_alphabet.word2id['<pad>']] *
                                   (max_length - cur_length))
            buckets[idx][1].append([self.label_alphabet.word2id['<start>']] +
                                   inst[-1] +
                                   [self.label_alphabet.word2id['<pad>']] *
                                   (max_length - cur_length - 1))
            if char:
                char_length = len(inst[1][0])
                buckets[idx][2].append(
                    (inst[1] + [[self.char_alphabet.word2id['<pad>']] *
                                char_length] * (max_length - cur_length)))
            buckets[idx][-1].append([1] * (cur_length + 1) + [0] *
                                    (max_length - (cur_length + 1)))
        return buckets
class Data:
    """Vocabulary container and batcher for word/label/category tagging data.

    Reads CoNLL-style files (``word label`` per line, blank line between
    sentences), builds word/label/category alphabets, loads pretrained
    embeddings, and pads instances into mini-batches.
    """

    def __init__(self):
        # Vocabulary mappings.
        self.word_alphabet = Alphabet('word')
        self.category_alphabet = Alphabet('category', is_category=True)
        self.label_alphabet = Alphabet('label', is_label=True)
        self.char_alphabet = Alphabet('char')
        # Normalize digits inside words while reading (see get_instance).
        self.number_normalized = True
        self.norm_word_emb = False
        self.norm_char_emb = False
        self.pretrain_word_embedding = None
        self.pretrain_char_embedding = None
        self.max_char_length = 0
        self.word_num = 0
        self.char_num = 0
        self.label_num = 0
        # Fix: category_num used to be created only inside fix_alphabet();
        # initialize it here so the attribute always exists.
        self.category_num = 0

    def build_alphabet(self, word_counter, label_counter, category_counter,
                       shrink_feature_threshold, char=False):
        """Populate the word/label/category alphabets from frequency counters.

        Words occurring at most ``shrink_feature_threshold`` times are
        dropped; labels and categories are always kept.  ``char`` is accepted
        for interface compatibility but unused here.
        """
        for word, count in word_counter.most_common():
            if count > shrink_feature_threshold:
                self.word_alphabet.add(word, count)
        for label, count in label_counter.most_common():
            self.label_alphabet.add(label, count)
        for category, count in category_counter.most_common():
            self.category_alphabet.add(category, count)
        # Sanity check: the three index structures of each alphabet must agree.
        for name, alphabet in (('word', self.word_alphabet),
                               ('label', self.label_alphabet),
                               ('category', self.category_alphabet)):
            if (len(alphabet.word2id) != len(alphabet.id2word)
                    or len(alphabet.id2count) != len(alphabet.id2word)):
                print('there are errors in building %s alphabet.' % name)

    def fix_alphabet(self):
        """Close all alphabets and record their final sizes."""
        self.word_num = self.word_alphabet.close()
        self.category_num = self.category_alphabet.close()
        self.label_num = self.label_alphabet.close()

    def get_instance(self, file, run_insts, shrink_feature_threshold,
                     char=False, char_padding_symbol='<pad>'):
        """Read up to ``run_insts`` sentences from ``file``.

        Each non-blank line is ``word label``; a composite label ``b-CAT`` /
        ``i-CAT`` carries a target category.  Blank lines close a sentence.
        Returns ``(insts, insts_index)``: raw [words, labels, categorys]
        triples and their id-mapped counterparts.
        """
        words, labels, categorys = [], [], []
        insts = []
        word_counter = Counter()
        char_counter = Counter()
        label_counter = Counter()
        category_counter = Counter()
        count = 0    # sentences read so far
        ner_num = 0  # 'b-' labels seen in the current sentence
        with open(file, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                if run_insts == count:
                    break
                if len(line) > 2:
                    line = line.strip().split(' ')
                    word = line[0]
                    if self.number_normalized:
                        word = utils.normalize_word(word)
                    if len(list(line[1])) > 1:
                        # Composite label 'b-CAT' / 'i-CAT': split off the
                        # category part.
                        label = line[1].split('-')[0]
                        if label == 'b':
                            ner_num += 1
                        category = line[1].split('-')[1]
                        categorys.append(category)
                        category_counter[category] += 1
                    else:
                        label = line[1]
                    # Collapse every label onto the '-target' scheme
                    # (e.g. b-target / i-target / o-target).
                    label = label + '-target'
                    words.append(word)
                    labels.append(label)
                    word_counter[word] += 1
                    label_counter[label] += 1
                else:
                    # Blank line: close the current sentence.
                    insts.append([words, labels, categorys])
                    words, labels, categorys = [], [], []
                    count += 1
                    if ner_num > 1:
                        # Diagnostic: sentence contained multiple targets.
                        print(ner_num)
                    ner_num = 0
        # NOTE(review): a final sentence not followed by a blank line at EOF
        # is silently dropped -- confirm data files always end with one.
        if not self.word_alphabet.fix_flag:
            self.build_alphabet(word_counter, label_counter, category_counter,
                                shrink_feature_threshold, char)
        insts_index = []
        for inst in insts:
            words_index = [self.word_alphabet.get_index(w) for w in inst[0]]
            labels_index = [self.label_alphabet.get_index(lab)
                            for lab in inst[1]]
            length = len(labels_index)
            # One category per sentence, repeated to sequence length.
            categorys_index = [self.category_alphabet.get_index(inst[-1][0])
                               ] * length
            insts_index.append([words_index, labels_index, categorys_index])
        return insts, insts_index

    def build_word_pretrain_emb(self, emb_path, word_dims):
        """Load pretrained word vectors for the word alphabet."""
        self.pretrain_word_embedding = utils.load_pretrained_emb_avg(
            emb_path, self.word_alphabet.word2id, word_dims,
            self.norm_word_emb)

    def build_char_pretrain_emb(self, emb_path, char_dims):
        """Load pretrained character vectors for the char alphabet."""
        self.pretrain_char_embedding = utils.load_pretrained_emb_avg(
            emb_path, self.char_alphabet.word2id, char_dims,
            self.norm_char_emb)

    def _flush_batch(self, bucket_id, inst_sorted, buckets, labels_raw,
                     category_raw, target_start, target_end):
        """Pad one sorted batch into ``buckets[bucket_id]`` and record its
        raw labels, categories and target-span boundaries."""
        max_length = len(inst_sorted[0][0])
        word_pad = self.word_alphabet.word2id['<pad>']
        label_pad = self.label_alphabet.word2id['<pad>']
        for idy in range(len(inst_sorted)):
            cur_length = len(inst_sorted[idy][0])
            n_pad = max_length - cur_length
            buckets[bucket_id][0].append(inst_sorted[idy][0] + [word_pad] * n_pad)
            buckets[bucket_id][1].append(inst_sorted[idy][1] + [label_pad] * n_pad)
            buckets[bucket_id][-1].append([1] * cur_length + [0] * n_pad)
            labels_raw[bucket_id].append(inst_sorted[idy][1])
            category_raw[bucket_id].append(inst_sorted[idy][-1][0])
            start, end = evaluation.extract_target(inst_sorted[idy][1],
                                                   self.label_alphabet)
            # Fix (mirrors the sibling implementation in this file): guard
            # against sentences with no extracted target span, which would
            # otherwise raise IndexError on start[0].
            if not start:
                start = [0]
                end = [0]
            target_start[bucket_id].append(start[0])
            target_end[bucket_id].append(end[0])

    def generate_batch_buckets(self, batch_size, insts, char=False):
        """Sort-and-pad ``insts`` into mini-batches.

        Returns (buckets, labels_raw, category_raw, target_start, target_end);
        each bucket is [padded word ids, padded label ids, 0/1 mask].
        """
        batch_num = int(math.ceil(len(insts) / batch_size))
        buckets = [[[], [], []] for _ in range(batch_num)]
        labels_raw = [[] for _ in range(batch_num)]
        category_raw = [[] for _ in range(batch_num)]
        target_start = [[] for _ in range(batch_num)]
        target_end = [[] for _ in range(batch_num)]
        inst_save = []
        for inst_id, inst in enumerate(insts):
            if inst_id == 0 or inst_id % batch_size != 0:
                # Still filling the current batch.
                inst_save.append(inst)
            else:
                # A full batch is buffered: flush it, then start the next
                # batch with the current instance.
                idx = inst_id // batch_size
                assert len(inst_save) == batch_size
                self._flush_batch(idx - 1,
                                  utils.sorted_instances_index(inst_save),
                                  buckets, labels_raw, category_raw,
                                  target_start, target_end)
                inst_save = [inst]
        if inst_save:
            # Flush the final (possibly smaller) batch.
            self._flush_batch(batch_num - 1,
                              utils.sorted_instances_index(inst_save),
                              buckets, labels_raw, category_raw,
                              target_start, target_end)
        return buckets, labels_raw, category_raw, target_start, target_end

    def generate_batch_buckets_save(self, batch_size, insts, char=False):
        """Pad ``insts`` into sequential (unsorted) batches, prepending the
        <start> label id to every label sequence.

        NOTE(review): max_length is taken from the first instance of each
        batch, so this assumes instances arrive longest-first within a batch;
        otherwise the padding multipliers go negative -- TODO confirm.
        """
        batch_num = int(math.ceil(len(insts) / batch_size))
        if char:
            # Extra slot for per-word character id sequences.
            buckets = [[[], [], [], []] for _ in range(batch_num)]
        else:
            buckets = [[[], [], []] for _ in range(batch_num)]
        max_length = 0
        for inst_id, inst in enumerate(insts):
            idx = inst_id // batch_size
            if inst_id % batch_size == 0:
                # +1 leaves room for the prepended <start> label.
                max_length = len(inst[0]) + 1
            cur_length = len(inst[0])
            buckets[idx][0].append(
                inst[0]
                + [self.word_alphabet.word2id['<pad>']]
                * (max_length - cur_length))
            buckets[idx][1].append(
                [self.label_alphabet.word2id['<start>']] + inst[-1]
                + [self.label_alphabet.word2id['<pad>']]
                * (max_length - cur_length - 1))
            if char:
                char_length = len(inst[1][0])
                buckets[idx][2].append(
                    inst[1]
                    + [[self.char_alphabet.word2id['<pad>']] * char_length]
                    * (max_length - cur_length))
            buckets[idx][-1].append([1] * (cur_length + 1)
                                    + [0] * (max_length - (cur_length + 1)))
        return buckets