import random as rd

from nltk import word_tokenize

# Project-local modules; these import paths are assumptions based on the
# identifiers used below (Vocab, Rule, constant, load_mappers, data_utils)
# and may need adjusting to the actual repo layout.
from data_generator.vocab import Vocab
from data_generator.rule import Rule
from util import constant
from util.map_util import load_mappers
from util import data_utils


class TrainData:
    def __init__(self, model_config):
        self.model_config = model_config

        vocab_simple_path = self.model_config.vocab_simple
        vocab_complex_path = self.model_config.vocab_complex
        vocab_all_path = self.model_config.vocab_all
        if self.model_config.subword_vocab_size > 0:
            vocab_simple_path = self.model_config.subword_vocab_simple
            vocab_complex_path = self.model_config.subword_vocab_complex
            vocab_all_path = self.model_config.subword_vocab_all

        data_simple_path = self.model_config.train_dataset_simple
        data_complex_path = self.model_config.train_dataset_complex

        if (self.model_config.tie_embedding == 'none'
                or self.model_config.tie_embedding == 'dec_out'):
            self.vocab_simple = Vocab(model_config, vocab_simple_path)
            self.vocab_complex = Vocab(model_config, vocab_complex_path)
        elif (self.model_config.tie_embedding == 'all'
                or self.model_config.tie_embedding == 'enc_dec'):
            self.vocab_simple = Vocab(model_config, vocab_all_path)
            self.vocab_complex = Vocab(model_config, vocab_all_path)

        self.size = self.get_size(data_complex_path)
        if self.model_config.use_dataset2:
            self.size2 = self.get_size(
                self.model_config.train_dataset_complex2)

        # Populate basic complex-simple pairs, either fully in memory or
        # lazily through a sampling iterator.
        if not self.model_config.it_train:
            self.data = self.populate_data(
                data_complex_path, data_simple_path,
                self.vocab_complex, self.vocab_simple, True)
        else:
            self.data_it = self.get_data_sample_it(
                data_simple_path, data_complex_path)
        print('Use Train Dataset: \n Simple\t %s. \n Complex\t %s. \n Size\t %d.'
              % (data_simple_path, data_complex_path, self.size))

        if 'rule' in self.model_config.memory or 'rule' in self.model_config.rl_configs:
            self.vocab_rule = Rule(model_config, self.model_config.vocab_rules)
            self.rules_target, self.rules_align = self.populate_rules(
                self.model_config.train_dataset_complex_ppdb, self.vocab_rule)
            assert len(self.rules_align) == self.size
            assert len(self.rules_target) == self.size
            print('Populate Rule with size:%s'
                  % self.vocab_rule.get_rule_size())

    def process_line(self, line, vocab, max_len, need_raw=False):
        # Tokenize, map tokens to ids, add boundary symbols, then pad or
        # truncate to exactly max_len ids.
        if self.model_config.tokenizer == 'split':
            words = line.split()
        elif self.model_config.tokenizer == 'nltk':
            words = word_tokenize(line)
        else:
            raise Exception('Unknown tokenizer.')
        words = [Vocab.process_word(word, self.model_config) for word in words]

        if need_raw:
            words_raw = [constant.SYMBOL_START] + words + [constant.SYMBOL_END]
        else:
            words_raw = None

        if self.model_config.subword_vocab_size > 0:
            words = [constant.SYMBOL_START] + words + [constant.SYMBOL_END]
            words = vocab.encode(' '.join(words))
        else:
            words = [vocab.encode(word) for word in words]
            # Encode the boundary symbols with the same vocab as the words.
            words = ([vocab.encode(constant.SYMBOL_START)] + words
                     + [vocab.encode(constant.SYMBOL_END)])

        if self.model_config.subword_vocab_size > 0:
            pad_id = vocab.encode(constant.SYMBOL_PAD)
        else:
            pad_id = [vocab.encode(constant.SYMBOL_PAD)]

        if len(words) < max_len:
            num_pad = max_len - len(words)
            words.extend(num_pad * pad_id)
        else:
            words = words[:max_len]
        return words, words_raw

    def get_size(self, data_complex_path):
        with open(data_complex_path, encoding='utf-8') as f:
            return len(f.readlines())

    def get_data_sample_it(self, data_simple_path, data_complex_path):
        f_simple = open(data_simple_path, encoding='utf-8')
        f_complex = open(data_complex_path, encoding='utf-8')
        i = 0
        while True:
            if i >= self.size:
                # End of epoch: reopen both files and restart from the top.
                f_simple = open(data_simple_path, encoding='utf-8')
                f_complex = open(data_complex_path, encoding='utf-8')
                i = 0
            line_complex = f_complex.readline()
            line_simple = f_simple.readline()
            if rd.random() < 0.5 or i >= self.size:
                # Randomly skip roughly half the pairs to decorrelate batches.
                i += 1
                continue
            words_complex, words_raw_comp = self.process_line(
                line_complex, self.vocab_complex,
                self.model_config.max_complex_sentence, True)
            words_simple, words_raw_simp = self.process_line(
                line_simple, self.vocab_simple,
                self.model_config.max_simple_sentence, True)

            supplement = {}
            if 'rule' in self.model_config.memory:
                supplement['rules_target'] = self.rules_target[i]
                supplement['rules_align'] = self.rules_align[i]

            obj = {
                'words_comp': words_complex,
                'words_simp': words_simple,
                'words_raw_comp': words_raw_comp,
                'words_raw_simp': words_raw_simp,
            }
            yield i, obj, supplement
            i += 1

    def populate_rules(self, rule_path, vocab_rule):
        data_target, data_align = [], []
        for line in open(rule_path, encoding='utf-8'):
            cur_rules = line.split('\t')
            tmp, tmp_align = [], []
            for cur_rule in cur_rules:
                rule_id, rule_origins, rule_targets = vocab_rule.encode(cur_rule)
                if rule_targets is not None and rule_origins is not None:
                    tmp.append((rule_id,
                                [self.vocab_simple.encode(rule_target)
                                 for rule_target in rule_targets]))
                    # Only single-token-to-single-token rules yield alignments.
                    if len(rule_origins) == 1 and len(rule_targets) == 1:
                        tmp_align.append(
                            (self.vocab_complex.encode(rule_origins[0]),
                             self.vocab_simple.encode(rule_targets[0])))
            data_target.append(tmp)
            data_align.append(tmp_align)
        return data_target, data_align

    def populate_data(self, data_path_comp, data_path_simp,
                      vocab_comp, vocab_simp, need_raw=False):
        # Populate all complex-simple pairs into memory.
        data = []
        lines_comp = open(data_path_comp, encoding='utf-8').readlines()
        lines_simp = open(data_path_simp, encoding='utf-8').readlines()
        assert len(lines_comp) == len(lines_simp)
        for line_id in range(len(lines_comp)):
            obj = {}
            line_comp = lines_comp[line_id]
            line_simp = lines_simp[line_id]
            words_comp, words_raw_comp = self.process_line(
                line_comp, vocab_comp,
                self.model_config.max_complex_sentence, need_raw)
            words_simp, words_raw_simp = self.process_line(
                line_simp, vocab_simp,
                self.model_config.max_simple_sentence, need_raw)
            obj['words_comp'] = words_comp
            obj['words_simp'] = words_simp
            if need_raw:
                obj['words_raw_comp'] = words_raw_comp
                obj['words_raw_simp'] = words_raw_simp
            data.append(obj)
        return data

    def get_data_sample(self):
        i = rd.sample(range(self.size), 1)[0]
        supplement = {}
        if 'rule' in self.model_config.memory:
            supplement['rules_target'] = self.rules_target[i]
            supplement['rules_align'] = self.rules_align[i]
        return i, self.data[i], supplement
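
# Usage sketch (not part of the original file): a minimal sketch of driving
# TrainData, assuming a hypothetical `model_config` object that carries the
# paths and flags read above; this is not the repo's actual training loop.
def _example_train_sampling(model_config):
    train_data = TrainData(model_config)
    if model_config.it_train:
        # Streaming mode: data_it is an endless generator over the corpus.
        for _ in range(3):
            idx, obj, supplement = next(train_data.data_it)
            print(idx, len(obj['words_comp']), len(obj['words_simp']))
    else:
        # In-memory mode: draw one uniformly random complex-simple pair.
        idx, obj, supplement = train_data.get_data_sample()
        print(idx, obj['words_comp'][:5])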

class ValData:
    def __init__(self, model_config):
        self.model_config = model_config

        vocab_simple_path = self.model_config.vocab_simple
        vocab_complex_path = self.model_config.vocab_complex
        vocab_all_path = self.model_config.vocab_all
        if self.model_config.subword_vocab_size > 0:
            vocab_simple_path = self.model_config.subword_vocab_simple
            vocab_complex_path = self.model_config.subword_vocab_complex
            vocab_all_path = self.model_config.subword_vocab_all

        if (self.model_config.tie_embedding == 'none'
                or self.model_config.tie_embedding == 'dec_out'):
            self.vocab_simple = Vocab(model_config, vocab_simple_path)
            self.vocab_complex = Vocab(model_config, vocab_complex_path)
        elif (self.model_config.tie_embedding == 'all'
                or self.model_config.tie_embedding == 'enc_dec'):
            self.vocab_simple = Vocab(model_config, vocab_all_path)
            self.vocab_complex = Vocab(model_config, vocab_all_path)

        # Populate basic complex-simple pairs.
        self.data = self.populate_data(self.vocab_complex, self.vocab_simple, True)
        self.data_complex_raw_lines = self.populate_data_rawfile(
            self.model_config.val_dataset_complex_rawlines_file)
        # Populate simple references.
        self.data_references_raw_lines = []
        for i in range(self.model_config.num_refs):
            ref_tmp_rawlines = self.populate_data_rawfile(
                self.model_config.val_dataset_simple_folder
                + self.model_config.val_dataset_simple_rawlines_file_references)
            self.data_references_raw_lines.append(ref_tmp_rawlines)

        if self.model_config.replace_ner:
            self.mapper = load_mappers(self.model_config.val_mapper,
                                       self.model_config.lower_case)
            while len(self.mapper) < len(self.data):
                self.mapper.append({})

        assert len(self.data_complex_raw_lines) == len(self.data)
        assert len(self.mapper) == len(self.data)
        for i in range(self.model_config.num_refs):
            assert len(self.data_references_raw_lines[i]) == len(self.data)
        print('Use Val Dataset: \n Simple\t %s. \n Complex\t %s. \n Size\t %d'
              % (self.model_config.val_dataset_simple_folder
                 + self.model_config.val_dataset_simple_file,
                 self.model_config.val_dataset_complex, len(self.data)))

        if 'rule' in self.model_config.memory:
            self.vocab_rule = Rule(model_config, self.model_config.vocab_rules)
            self.rules = self.populate_rules(
                self.model_config.val_dataset_complex_ppdb, self.vocab_rule)
            print('Populate Rule with size:%s'
                  % self.vocab_rule.get_rule_size())

    def populate_rules(self, rule_path, vocab_rule):
        data = []
        for line in open(rule_path, encoding='utf-8'):
            cur_rules = line.split('\t')
            tmp = []
            for cur_rule in cur_rules:
                rule_id, _, rule_targets = vocab_rule.encode(cur_rule)
                if rule_targets is not None:
                    tmp.append((rule_id,
                                [self.vocab_simple.encode(rule_target)
                                 for rule_target in rule_targets]))
            data.append(tmp)
        return data

    def populate_data_rawfile(self, data_path):
        """Populate raw data lines into memory."""
        data = []
        for line in open(data_path, encoding='utf-8'):
            data.append(line.strip())
        return data

    def process_line(self, line, vocab, max_len, need_raw=False):
        # Tokenize, map tokens to ids, add boundary symbols, then pad or
        # truncate to exactly max_len ids.
        if self.model_config.tokenizer == 'split':
            words = line.split()
        elif self.model_config.tokenizer == 'nltk':
            words = word_tokenize(line)
        else:
            raise Exception('Unknown tokenizer.')
        words = [Vocab.process_word(word, self.model_config) for word in words]

        if need_raw:
            words_raw = [constant.SYMBOL_START] + words + [constant.SYMBOL_END]
        else:
            words_raw = None

        if self.model_config.subword_vocab_size > 0:
            words = [constant.SYMBOL_START] + words + [constant.SYMBOL_END]
            words = vocab.encode(' '.join(words))
        else:
            words = [vocab.encode(word) for word in words]
            # Encode the boundary symbols with the same vocab as the words.
            words = ([vocab.encode(constant.SYMBOL_START)] + words
                     + [vocab.encode(constant.SYMBOL_END)])

        if self.model_config.subword_vocab_size > 0:
            pad_id = vocab.encode(constant.SYMBOL_PAD)
        else:
            pad_id = [vocab.encode(constant.SYMBOL_PAD)]

        if len(words) < max_len:
            num_pad = max_len - len(words)
            words.extend(num_pad * pad_id)
        else:
            words = words[:max_len]
        return words, words_raw

    def populate_data(self, vocab_comp, vocab_simp, need_raw=False):
        # Populate all complex-simple pairs into memory.
        data = []
        lines_comp = open(self.model_config.val_dataset_complex,
                          encoding='utf-8').readlines()
        lines_simp = open(self.model_config.val_dataset_simple_folder
                          + self.model_config.val_dataset_simple_file,
                          encoding='utf-8').readlines()
        assert len(lines_comp) == len(lines_simp)
        for line_id in range(len(lines_comp)):
            obj = {}
            line_comp = lines_comp[line_id]
            line_simp = lines_simp[line_id]
            words_comp, words_raw_comp = self.process_line(
                line_comp, vocab_comp,
                self.model_config.max_complex_sentence, need_raw)
            words_simp, words_raw_simp = self.process_line(
                line_simp, vocab_simp,
                self.model_config.max_simple_sentence, need_raw)
            obj['words_comp'] = words_comp
            obj['words_simp'] = words_simp
            if need_raw:
                obj['words_raw_comp'] = words_raw_comp
                obj['words_raw_simp'] = words_raw_simp
            data.append(obj)
        return data

    def get_data_iter(self):
        i = 0
        while True:
            if i % 100 == 0:
                print('Processed ' + str(i) + ' examples so far')
            ref_rawlines_batch = [self.data_references_raw_lines[j][i]
                                  for j in range(self.model_config.num_refs)]
            supplement = {}
            if 'rule' in self.model_config.memory:
                try:
                    supplement['mem'] = self.rules[i]
                except IndexError:
                    print('****INDEX ERROR: ' + str(i))
                    yield None
            obj = {
                'sentence_simple': self.data[i]['words_simp'],
                'sentence_complex': self.data[i]['words_comp'],
                'sentence_complex_raw': self.data[i]['words_raw_comp'],
                'sentence_simple_raw': self.data[i]['words_raw_simp'],
                'sentence_complex_raw_lines': self.data_complex_raw_lines[i],
                'mapper': self.mapper[i],
                'ref_raw_lines': ref_rawlines_batch,
                'sup': supplement,
            }
            yield obj
            i += 1
            if i == len(self.data):
                # End-of-epoch sentinel.
                yield None
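
# Padding sketch (not part of the original file): process_line above always
# returns exactly max_len ids, right-padding short sentences with PAD and
# truncating long ones. A minimal standalone re-statement of that contract,
# with a hypothetical pad id of 0:
def _pad_or_truncate(ids, max_len, pad_id=0):
    """Return `ids` padded with `pad_id` or truncated to exactly `max_len`."""
    if len(ids) < max_len:
        return ids + [pad_id] * (max_len - len(ids))
    return ids[:max_len]

assert _pad_or_truncate([5, 6, 7], 5) == [5, 6, 7, 0, 0]
assert _pad_or_truncate([5, 6, 7, 8, 9, 10], 5) == [5, 6, 7, 8, 9]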

class ValData:
    """Validation data with optional rule/direct memory, style features,
    per-reference files, lower-cased raw lines, and subword segment ids
    (an extended variant of the ValData above)."""

    def __init__(self, model_config):
        self.model_config = model_config

        vocab_simple_path = self.model_config.vocab_simple
        vocab_complex_path = self.model_config.vocab_complex
        vocab_all_path = self.model_config.vocab_all
        if self.model_config.subword_vocab_size > 0:
            vocab_simple_path = self.model_config.subword_vocab_simple
            vocab_complex_path = self.model_config.subword_vocab_complex
            vocab_all_path = self.model_config.subword_vocab_all

        if (self.model_config.tie_embedding == 'none'
                or self.model_config.tie_embedding == 'dec_out'):
            self.vocab_simple = Vocab(model_config, vocab_simple_path)
            self.vocab_complex = Vocab(model_config, vocab_complex_path)
        elif (self.model_config.tie_embedding == 'all'
                or self.model_config.tie_embedding == 'enc_dec'):
            self.vocab_simple = Vocab(model_config, vocab_all_path)
            self.vocab_complex = Vocab(model_config, vocab_all_path)

        # Populate basic complex-simple pairs.
        self.data = self.populate_data(self.vocab_complex, self.vocab_simple, True)
        self.data_complex_raw_lines = self.populate_data_rawfile(
            self.model_config.val_dataset_complex_rawlines_file,
            self.model_config.lower_case)
        # Populate simple references, one raw-line file per reference index.
        self.data_references_raw_lines = []
        for i in range(self.model_config.num_refs):
            ref_tmp_rawlines = self.populate_data_rawfile(
                self.model_config.val_dataset_simple_folder
                + self.model_config.val_dataset_simple_rawlines_file_references
                + str(i),
                self.model_config.lower_case)
            self.data_references_raw_lines.append(ref_tmp_rawlines)

        if self.model_config.replace_ner:
            self.mapper = load_mappers(self.model_config.val_mapper,
                                       self.model_config.lower_case)
            while len(self.mapper) < len(self.data):
                self.mapper.append({})

        assert len(self.data_complex_raw_lines) == len(self.data)
        assert len(self.mapper) == len(self.data)
        for i in range(self.model_config.num_refs):
            assert len(self.data_references_raw_lines[i]) == len(self.data)
        print('Use Val Dataset: \n Simple\t %s. \n Complex\t %s. \n Size\t %d'
              % (self.model_config.val_dataset_simple_folder
                 + self.model_config.val_dataset_simple_file,
                 self.model_config.val_dataset_complex, len(self.data)))

        if 'rule' in self.model_config.memory or 'direct' in self.model_config.memory:
            self.vocab_rule = Rule(model_config, self.model_config.vocab_rules)
            self.rules = self.populate_rules(
                self.model_config.val_dataset_complex_ppdb, self.vocab_rule)
            print('Populate Rule with size:%s'
                  % self.vocab_rule.get_rule_size())

        if self.model_config.tune_style:
            self.comp_features = self.populate_comp_features(
                self.model_config.val_dataset_complex_features)

    def populate_comp_features(self, feature_path):
        data = []
        for line in open(feature_path, encoding='utf-8'):
            items = line.split('\t')
            data.append((float(items[0]), float(items[1])))
        return data

    def populate_rules(self, rule_path, vocab_rule):
        data = []
        for line in open(rule_path, encoding='utf-8'):
            cur_rules = line.split('\t')
            tmp = []
            for cur_rule in cur_rules:
                rule_id, _, rule_targets = vocab_rule.encode(cur_rule)
                if rule_targets is not None:
                    if (self.model_config.subword_vocab_size
                            or 'bert_token' in self.model_config.bert_mode):
                        tmp.append((rule_id,
                                    self.vocab_simple.encode(rule_targets)))
                    else:
                        tmp.append((rule_id,
                                    [self.vocab_simple.encode(rule_target)
                                     for rule_target in rule_targets]))
            data.append(tmp)
        return data

    def populate_data_rawfile(self, data_path, lower_case=True):
        """Populate raw data lines into memory."""
        data = []
        for line in open(data_path, encoding='utf-8'):
            if lower_case:
                line = line.lower()
            data.append(line.strip())
        return data

    def populate_data(self, vocab_comp, vocab_simp, need_raw=False):
        # Populate all complex-simple pairs into memory.
        data = []
        lines_comp = open(self.model_config.val_dataset_complex,
                          encoding='utf-8').readlines()
        lines_simp = open(self.model_config.val_dataset_simple_folder
                          + self.model_config.val_dataset_simple_file,
                          encoding='utf-8').readlines()
        assert len(lines_comp) == len(lines_simp)
        for line_id in range(len(lines_comp)):
            obj = {}
            line_comp = lines_comp[line_id]
            line_simp = lines_simp[line_id]
            words_comp, words_raw_comp, obj_comp = data_utils.process_line(
                line_comp, vocab_comp,
                self.model_config.max_complex_sentence, self.model_config,
                need_raw, self.model_config.lower_case)
            words_simp, words_raw_simp, obj_simp = data_utils.process_line(
                line_simp, vocab_simp,
                self.model_config.max_simple_sentence, self.model_config,
                need_raw, self.model_config.lower_case)
            obj['words_comp'] = words_comp
            obj['words_simp'] = words_simp
            if need_raw:
                obj['words_raw_comp'] = words_raw_comp
                obj['words_raw_simp'] = words_raw_simp
            if self.model_config.subword_vocab_size and self.model_config.seg_mode:
                obj['line_comp_segids'] = obj_comp['segment_idxs']
                obj['line_simp_segids'] = obj_simp['segment_idxs']
            data.append(obj)
        return data

    def get_data_iter(self):
        i = 0
        while True:
            if i >= len(self.data):
                # End-of-epoch sentinel.
                yield None
            else:
                ref_rawlines_batch = [self.data_references_raw_lines[j][i]
                                      for j in range(self.model_config.num_refs)]
                supplement = {}
                if 'rule' in self.model_config.memory or 'direct' in self.model_config.memory:
                    supplement['mem'] = self.rules[i]
                if self.model_config.tune_style:
                    supplement['comp_features'] = self.comp_features[i]
                obj = {
                    'sentence_simple': self.data[i]['words_simp'],
                    'sentence_complex': self.data[i]['words_comp'],
                    'sentence_complex_raw': self.data[i]['words_raw_comp'],
                    'sentence_simple_raw': self.data[i]['words_raw_simp'],
                    'sentence_complex_raw_lines': self.data_complex_raw_lines[i],
                    'mapper': self.mapper[i],
                    'ref_raw_lines': ref_rawlines_batch,
                    'sup': supplement,
                }
                if self.model_config.subword_vocab_size and self.model_config.seg_mode:
                    obj['line_comp_segids'] = self.data[i]['line_comp_segids']
                    obj['line_simp_segids'] = self.data[i]['line_simp_segids']
                yield obj
                i += 1
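
# Usage sketch (not part of the original file): get_data_iter yields one dict
# per validation example and then None as an end-of-epoch sentinel, so a
# consumer loops until it sees None. `model_config` is a hypothetical config
# object as assumed above.
def _example_val_epoch(model_config):
    val_data = ValData(model_config)
    data_iter = val_data.get_data_iter()
    count = 0
    while True:
        obj = next(data_iter)
        if obj is None:
            break  # one full pass over the validation set is done
        # obj['sup'] may carry rule memory ('mem') and, under tune_style,
        # the per-example complexity features ('comp_features').
        count += 1
    print('validated %d examples' % count)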