def translate(self, src_data_iter, tgt_data_iter, structure_iter, batch_size, out_file=None):
    data = build_dataset(self.fields,
                         src_data_iter, tgt_data_iter, None,
                         structure_iter, None, None,
                         use_filter_pred=False)
    # for line in data: print(line.__dict__)  -> {src: ..., indices: ..., structure: ...}

    def sort_translation(indices, translation):  # kept from the original; unused in this variant
        ordered_translation = [None] * len(translation)
        for i, index in enumerate(indices):
            ordered_translation[index] = translation[i]
        return ordered_translation

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    data_iter = OrderedIterator(
        dataset=data, device=cur_device,
        batch_size=batch_size, train=False,
        sort=True, sort_within_batch=True, shuffle=True)

    start_time = time.time()
    print("Begin decoding ...")
    batch_count = 0
    all_translation = []
    for batch in data_iter:
        # batch: torchtext.data.batch.Batch of size 30
        #   .src:       ('[torch.LongTensor of size 4x30]', '[torch.LongTensor of size 30]')
        #   .indices:   [torch.LongTensor of size 30]
        #   .structure: [torch.LongTensor of size 30x4x4]
        hyps, scores = self.translate_batch(batch)
        assert len(batch) == len(hyps)
        batch_translation = []
        for src_idx_seq, tran_idx_seq, score in zip(
                batch.src[0].transpose(0, 1), hyps, scores):
            src_words = self.build_tokens(src_idx_seq, side='src')
            src = ' '.join(src_words)
            tran_words = self.build_tokens(tran_idx_seq, side='tgt')
            tran = ' '.join(tran_words)
            batch_translation.append(tran)
            print("SOURCE: " + src + "\nOUTPUT: " + tran + "\n")
        for index, tran in zip(batch.indices.data, batch_translation):
            while len(all_translation) <= index:
                all_translation.append("")
            all_translation[index] = tran
        batch_count += 1
        print("batch: " + str(batch_count) + "...")
    if out_file is not None:
        for tran in all_translation:
            out_file.write(tran + '\n')
    print('Decoding took %.1f minutes ...' % (float(time.time() - start_time) / 60.))
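
# The loop above restores corpus order by writing each hypothesis into
# `all_translation` at the position given by batch.indices, since the
# OrderedIterator may batch sentences out of order. A minimal, self-contained
# sketch of that reordering idea (plain Python lists stand in for the
# torchtext tensors; names and values here are illustrative only):
def restore_order(indices, hypotheses):
    restored = []
    for index, hyp in zip(indices, hypotheses):
        while len(restored) <= index:   # grow the result list on demand
            restored.append("")
        restored[index] = hyp
    return restored

# Batches were sorted by length, so sentence 2 may be decoded before sentence 0.
print(restore_order([2, 0, 1], ["c", "a", "b"]))   # -> ['a', 'b', 'c']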

def build_save_dataset(corpus_type, fields, opt):
    """ Building and saving the dataset. corpus_type is 'train' or 'valid'. """
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        # Paths of the source, target and structure information files.
        src_corpus = opt.train_src
        tgt_corpus = opt.train_tgt
        stgt_corpus = opt.train_stgt
        structure_corpus = opt.train_structure
        mask_corpus = opt.train_mask
        relation_corpus = opt.train_relation
    else:
        src_corpus = opt.valid_src
        tgt_corpus = opt.valid_tgt
        stgt_corpus = opt.valid_stgt
        structure_corpus = opt.valid_structure
        mask_corpus = opt.valid_mask
        relation_corpus = opt.valid_relation

    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(
            src_corpus, tgt_corpus, stgt_corpus, structure_corpus,
            mask_corpus, relation_corpus, fields, corpus_type, opt)

    # We only build a monolithic dataset here. Since the interfaces are
    # uniform, extending this would not be hard should users need it.
    src_iter = make_text_iterator_from_file(src_corpus)
    tgt_iter = make_text_iterator_from_file(tgt_corpus)
    stgt_iter = make_text_iterator_from_file(stgt_corpus)
    structure_iter = make_text_iterator_from_file(structure_corpus)
    mask_iter = make_text_iterator_from_file(mask_corpus)
    relation_iter = make_text_iterator_from_file(relation_corpus)

    dataset = build_dataset(fields,
                            src_iter, tgt_iter, stgt_iter,
                            structure_iter, mask_iter, relation_iter,
                            src_seq_length=opt.src_seq_length,
                            tgt_seq_length=opt.tgt_seq_length,
                            src_seq_length_trunc=opt.src_seq_length_trunc,
                            tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
                            abundancy=opt.abundancy)

    # We save fields in vocab.pt separately, so make it empty here.
    dataset.fields = []

    pt_file = "{:s}_{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
    torch.save(dataset, pt_file)
    return [pt_file]

def translate(self, src_data_iter, tgt_data_iter, batch_size, out_file=None):
    # `data` yields one example at a time, each carrying example.indices and example.src.
    data = build_dataset(self.fields,
                         src_data_iter=src_data_iter,
                         tgt_data_iter=tgt_data_iter,
                         use_filter_pred=False)

    def sort_translation(indices, translation):
        # `indices` is a 1-D tensor, `translation` a plain list.
        ordered_translation = [None] * len(translation)
        for i, index in enumerate(indices):
            ordered_translation[index] = translation[i]
        return ordered_translation

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    data_iter = OrderedIterator(dataset=data, device=cur_device,
                                batch_size=batch_size, train=False,
                                sort=False, sort_within_batch=True,
                                shuffle=False)

    start_time = time.time()
    print("Begin decoding ...")
    idx = 0  # running count of sentences written so far

    # The source rows inside a batch are not length-aligned here.
    for batch in data_iter:
        # batch.src[0]: (src_len, batch_size) token ids; batch.src[1]: per-sentence lengths.
        # hyps has shape (batch_size, 4) and scores has length batch_size,
        # i.e. every sentence is decoded into four tokens in this setup.
        # The batch fields are read directly below rather than iterated over.
        hyps, scores = self.translate_batch(batch)
        assert len(batch) == len(hyps)
        translation = []
        for idx_seq, score in zip(hyps, scores):
            words = self.build_tokens(idx_seq, side='tgt')
            tran = ' '.join(words)
            translation.append(tran)
        if out_file is not None:
            translation = sort_translation(batch.indices.data - idx, translation)
            for tran in translation:
                out_file.write(tran + '\n')
        idx += len(batch)
        print("sents " + str(idx) + "...")
    print('Decoding took %.1f minutes ...' % (float(time.time() - start_time) / 60.))
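
# sort_translation above receives `batch.indices.data - idx`: with sort=False
# the batches follow corpus order, but sort_within_batch=True reorders the
# sentences inside each batch, so subtracting the running offset `idx` turns
# the global indices into batch-local positions. A small self-contained
# illustration (plain integers stand in for the index tensor; the values are
# made up for clarity):
def sort_translation(indices, translation):
    ordered_translation = [None] * len(translation)
    for i, index in enumerate(indices):
        ordered_translation[index] = translation[i]
    return ordered_translation

idx = 4                          # four sentences already written to out_file
global_indices = [6, 4, 5, 7]    # order in which this batch was decoded
local_indices = [g - idx for g in global_indices]
print(sort_translation(local_indices, ["s6", "s4", "s5", "s7"]))
# -> ['s4', 's5', 's6', 's7'], i.e. corpus order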

def translate(self, src_data_iter, tgt_data_iter, batch_size, out_file=None):
    data = build_dataset(self.fields,
                         src_data_iter=src_data_iter,
                         tgt_data_iter=tgt_data_iter,
                         use_filter_pred=False)

    def sort_translation(indices, translation):
        ordered_translation = [None] * len(translation)
        for i, index in enumerate(indices):
            ordered_translation[index] = translation[i]
        return ordered_translation

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    data_iter = OrderedIterator(dataset=data, device=cur_device,
                                batch_size=batch_size, train=False,
                                sort=False, sort_within_batch=True,
                                shuffle=False)

    start_time = time.time()
    print("Begin decoding ...")
    idx = 0
    for batch in data_iter:
        hyps, scores = self.translate_batch(batch)
        assert len(batch) == len(hyps)
        translation = []
        for idx_seq, score in zip(hyps, scores):
            words = self.build_tokens(idx_seq, side='tgt')
            tran = ' '.join(words)
            translation.append(tran)
        if out_file is not None:
            translation = sort_translation(batch.indices.data - idx, translation)
            for tran in translation:
                out_file.write(tran + '\n')
        idx += len(batch)
        print("sents " + str(idx) + "...")
    print('Decoding took %.1f minutes ...' % (float(time.time() - start_time) / 60.))

def translate(self, src_data_iter, tgt_data_iter, batch_size, out_file=None):
    data = build_dataset(self.fields,
                         src_data_iter=src_data_iter,
                         tgt_data_iter=tgt_data_iter,
                         use_filter_pred=False)

    def sort_translation(indices, translation):  # kept from the original; unused in this variant
        ordered_translation = [None] * len(translation)
        for i, index in enumerate(indices):
            ordered_translation[index] = translation[i]
        return ordered_translation

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    # Originally sort=True, sort_within_batch=True, shuffle=True.
    data_iter = OrderedIterator(
        dataset=data, device=cur_device,
        batch_size=batch_size, train=False,
        sort=False, sort_within_batch=False, shuffle=False)

    start_time = time.time()
    print("Begin decoding ...")
    batch_count = 0
    all_translation = []
    for batch in data_iter:
        hyps, scores = self.translate_batch(batch)
        assert len(batch) == len(hyps)
        batch_translation = []
        for src_idx_seq, tran_idx_seq, score in zip(
                batch.src[0].transpose(0, 1), hyps, scores):
            # src_words = self.build_tokens(src_idx_seq, side='src')
            # src = ' '.join(src_words)
            tran_words = self.build_tokens(tran_idx_seq, side='tgt')
            tran = ' '.join(tran_words)
            batch_translation.append(tran)
            print("SOURCE: " + "Three-modal" + "\nOUTPUT: " + tran + "\n")  # placeholder source string
        # batch.indices maps positions within the batch back to dataset order.
        for index, tran in zip(batch.indices.data, batch_translation):
            while len(all_translation) <= index:
                all_translation.append("")
            all_translation[index] = tran
        batch_count += 1
        print("batch: " + str(batch_count) + "...")
    if out_file is not None:
        for tran in all_translation:
            out_file.write(tran + '\n')
    print('Decoding took %.1f minutes ...' % (float(time.time() - start_time) / 60.))

def build_save_dataset(corpus_type, task_type, fields, opt):
    """ Building and saving the dataset """
    assert corpus_type in ['train', 'valid']
    assert task_type in ['task', 'task2']

    if corpus_type == 'train':
        if task_type == 'task':
            src_corpus = opt.train_src
            tgt_corpus = opt.train_tgt
        else:
            src_corpus = opt.train_src2
            tgt_corpus = opt.train_tgt2
    else:
        if task_type == 'task':
            src_corpus = opt.valid_src
            tgt_corpus = opt.valid_tgt
        else:
            src_corpus = opt.valid_src2
            tgt_corpus = opt.valid_tgt2

    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(src_corpus, tgt_corpus,
                                                      fields, corpus_type,
                                                      task_type, opt)

    # We only build a monolithic dataset here. Since the interfaces are
    # uniform, extending this would not be hard should users need it.
    src_iter = make_text_iterator_from_file(src_corpus)
    tgt_iter = make_text_iterator_from_file(tgt_corpus)

    dataset = build_dataset(fields, src_iter, tgt_iter,
                            src_seq_length=opt.src_seq_length,
                            tgt_seq_length=opt.tgt_seq_length,
                            src_seq_length_trunc=opt.src_seq_length_trunc,
                            tgt_seq_length_trunc=opt.tgt_seq_length_trunc)

    # We save fields in vocab.pt separately, so make it empty here.
    dataset.fields = []

    pt_file = "{:s}_{:s}_{:s}.pt".format(opt.save_data, task_type, corpus_type)
    logger.info(" * saving %s %s dataset to %s." % (task_type, corpus_type, pt_file))
    torch.save(dataset, pt_file)
    return [pt_file]

def build_save_dataset(corpus_type, fields, opt):
    """ Building and saving the dataset """
    assert corpus_type in ['train', 'valid']

    # X_train, X_valid, X_test, y_train, y_valid, y_test = data_loader.test_mosei_emotion_data()
    X_train, X_valid, X_test, y_train, y_valid, y_test = \
        data_loader.read_cmumosei_emotion_pkl()

    if corpus_type == 'train':
        src_corpus = X_train  # opt.train_src
        tgt_corpus = y_train  # opt.train_tgt
    else:
        src_corpus = X_valid  # opt.valid_src
        tgt_corpus = y_valid  # opt.valid_tgt

    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(src_corpus, tgt_corpus,
                                                      fields, corpus_type, opt)

    # We only build a monolithic dataset here. Since the interfaces are
    # uniform, extending this would not be hard should users need it.
    src_iter = make_text_iterator_from_file(src_corpus)
    tgt_iter = make_text_iterator_from_file(tgt_corpus)

    dataset = build_dataset(fields, src_iter, tgt_iter,
                            src_seq_length=opt.src_seq_length,
                            tgt_seq_length=opt.tgt_seq_length,
                            src_seq_length_trunc=opt.src_seq_length_trunc,
                            tgt_seq_length_trunc=opt.tgt_seq_length_trunc)

    # We save fields in vocab.pt separately, so make it empty here.
    dataset.fields = []

    pt_file = "{:s}_{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
    torch.save(dataset, pt_file)
    return [pt_file]
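
# make_text_iterator_from_file is used throughout this section but not defined
# here. A minimal sketch consistent with how it is called (yield one corpus
# line per iteration) might look as follows; this is an assumption, not the
# repository's actual implementation. Note that the CMU-MOSEI variant above
# passes already-loaded arrays instead of file paths, so that code path would
# additionally need the helper (or build_dataset) to accept in-memory data.
import codecs

def make_text_iterator_from_file(path):
    # Stream the corpus line by line so large files never sit fully in memory.
    with codecs.open(path, "r", encoding="utf-8") as corpus_file:
        for line in corpus_file:
            yield line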

def build_save_in_shards_using_shards_size(src_corpus, tgt_corpus, fields,
                                           corpus_type, task_type, opt):
    src_data = []
    tgt_data = []
    with open(src_corpus, "r") as src_file:
        with open(tgt_corpus, "r") as tgt_file:
            for s, t in zip(src_file, tgt_file):
                src_data.append(s)
                tgt_data.append(t)

    if len(src_data) != len(tgt_data):
        raise AssertionError("Source and Target should have the same length")

    num_shards = int(len(src_data) / opt.shard_size)
    for x in range(num_shards):
        logger.info("Splitting shard %d." % x)
        f = codecs.open(src_corpus + ".{0}.txt".format(x), "w", encoding="utf-8")
        f.writelines(src_data[x * opt.shard_size:(x + 1) * opt.shard_size])
        f.close()

        f = codecs.open(tgt_corpus + ".{0}.txt".format(x), "w", encoding="utf-8")
        f.writelines(tgt_data[x * opt.shard_size:(x + 1) * opt.shard_size])
        f.close()

    num_written = num_shards * opt.shard_size
    if len(src_data) > num_written:
        logger.info("Splitting shard %d." % num_shards)
        f = codecs.open(src_corpus + ".{0}.txt".format(num_shards), 'w', encoding="utf-8")
        f.writelines(src_data[num_shards * opt.shard_size:])
        f.close()

        f = codecs.open(tgt_corpus + ".{0}.txt".format(num_shards), 'w', encoding="utf-8")
        f.writelines(tgt_data[num_shards * opt.shard_size:])
        f.close()

    src_list = sorted(glob.glob(src_corpus + '.*.txt'))
    tgt_list = sorted(glob.glob(tgt_corpus + '.*.txt'))

    ret_list = []
    for index, src in enumerate(src_list):
        logger.info("Building shard %d." % index)
        src_iter = make_text_iterator_from_file(src)
        tgt_iter = make_text_iterator_from_file(tgt_list[index])
        dataset = build_dataset(fields, src_iter, tgt_iter,
                                src_seq_length=opt.src_seq_length,
                                tgt_seq_length=opt.tgt_seq_length,
                                src_seq_length_trunc=opt.src_seq_length_trunc,
                                tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
                                task_type=task_type)

        pt_file = "{:s}_{:s}_{:s}.{:d}.pt".format(opt.save_data, task_type,
                                                  corpus_type, index)

        # We save fields in vocab.pt separately, so make it empty here.
        dataset.fields = []

        logger.info(" * saving %sth %s data shard to %s."
                    % (index, corpus_type, pt_file))
        torch.save(dataset, pt_file)

        ret_list.append(pt_file)
        os.remove(src)
        os.remove(tgt_list[index])
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return ret_list
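
# A caveat that applies to every sharding variant in this section:
# sorted(glob.glob(...)) orders shard files lexicographically, so once there
# are ten or more shards "corpus.10.txt" sorts before "corpus.2.txt". The
# source/target/structure lists stay pairwise aligned (they are all mis-sorted
# the same way), but shards are processed out of numeric order and the saved
# .pt indices no longer match the .txt shard numbers. A hedged sketch of a
# numeric sort key (illustrative helper, not part of the original code):
import glob
import re

def numeric_shard_sort(pattern):
    # Sort shard files by their numeric suffix, e.g. ".2.txt" before ".10.txt".
    def shard_number(path):
        match = re.search(r'\.(\d+)\.txt$', path)
        return int(match.group(1)) if match else -1
    return sorted(glob.glob(pattern), key=shard_number)

# e.g. src_list = numeric_shard_sort(src_corpus + '.*.txt')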

def build_save_dataset(corpus_type, fields, opt):
    """ Building and saving the dataset. corpus_type is 'train' or 'valid'. """
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        # Paths of the source, target and structure files.
        src_corpus = opt.train_src
        tgt_corpus = opt.train_tgt
        structure_corpus1 = opt.train_structure1
        structure_corpus2 = opt.train_structure2
        structure_corpus3 = opt.train_structure3
        structure_corpus4 = opt.train_structure4
        structure_corpus5 = opt.train_structure5
        structure_corpus6 = opt.train_structure6
        structure_corpus7 = opt.train_structure7
        structure_corpus8 = opt.train_structure8
    else:
        src_corpus = opt.valid_src
        tgt_corpus = opt.valid_tgt
        structure_corpus1 = opt.valid_structure1
        structure_corpus2 = opt.valid_structure2
        structure_corpus3 = opt.valid_structure3
        structure_corpus4 = opt.valid_structure4
        structure_corpus5 = opt.valid_structure5
        structure_corpus6 = opt.valid_structure6
        structure_corpus7 = opt.valid_structure7
        structure_corpus8 = opt.valid_structure8

    if opt.shard_size > 0:
        # Only structures 1-5 are used below; 6-8 are read from opt but currently unused.
        return build_save_in_shards_using_shards_size(
            src_corpus, tgt_corpus, structure_corpus1, structure_corpus2,
            structure_corpus3, structure_corpus4, structure_corpus5,
            fields, corpus_type, opt)

    # We only build a monolithic dataset here. Since the interfaces are
    # uniform, extending this would not be hard should users need it.
    src_iter = make_text_iterator_from_file(src_corpus)
    tgt_iter = make_text_iterator_from_file(tgt_corpus)
    structure_iter1 = make_text_iterator_from_file(structure_corpus1)
    structure_iter2 = make_text_iterator_from_file(structure_corpus2)
    structure_iter3 = make_text_iterator_from_file(structure_corpus3)
    structure_iter4 = make_text_iterator_from_file(structure_corpus4)
    structure_iter5 = make_text_iterator_from_file(structure_corpus5)
    # structure_iter6 = make_text_iterator_from_file(structure_corpus6)
    # structure_iter7 = make_text_iterator_from_file(structure_corpus7)
    # structure_iter8 = make_text_iterator_from_file(structure_corpus8)

    dataset = build_dataset(fields, src_iter, tgt_iter,
                            structure_iter1, structure_iter2, structure_iter3,
                            structure_iter4, structure_iter5,
                            src_seq_length=opt.src_seq_length,
                            tgt_seq_length=opt.tgt_seq_length,
                            src_seq_length_trunc=opt.src_seq_length_trunc,
                            tgt_seq_length_trunc=opt.tgt_seq_length_trunc)

    # We save fields in vocab.pt separately, so make it empty here.
    dataset.fields = []

    pt_file = "{:s}_{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
    torch.save(dataset, pt_file)
    return [pt_file]

def build_save_in_shards_using_shards_size(
        src_corpus, tgt_corpus, structure_corpus1, structure_corpus2,
        structure_corpus3, structure_corpus4, structure_corpus5,
        fields, corpus_type, opt):
    src_data = []
    tgt_data = []
    structure_data1 = []
    structure_data2 = []
    structure_data3 = []
    structure_data4 = []
    structure_data5 = []
    # structure_data6/7/8 are commented out in the original and unused.

    with open(src_corpus, "r") as src_file, \
            open(tgt_corpus, "r") as tgt_file, \
            open(structure_corpus1, "r") as structure_file1, \
            open(structure_corpus2, "r") as structure_file2, \
            open(structure_corpus3, "r") as structure_file3, \
            open(structure_corpus4, "r") as structure_file4, \
            open(structure_corpus5, "r") as structure_file5:
        for s, t, structure1, structure2, structure3, structure4, structure5 in zip(
                src_file, tgt_file, structure_file1, structure_file2,
                structure_file3, structure_file4, structure_file5):
            src_data.append(s)
            tgt_data.append(t)
            structure_data1.append(structure1)
            structure_data2.append(structure2)
            structure_data3.append(structure3)
            structure_data4.append(structure4)
            structure_data5.append(structure5)
            # Each structure line must flatten a (n + 1) x (n + 1) matrix,
            # where n is the number of source tokens.
            assert (len(s.split()) + 1) ** 2 == len(structure1.split())
            assert (len(s.split()) + 1) ** 2 == len(structure2.split())
            assert (len(s.split()) + 1) ** 2 == len(structure3.split())
            assert (len(s.split()) + 1) ** 2 == len(structure4.split())
            assert (len(s.split()) + 1) ** 2 == len(structure5.split())

    if len(src_data) != len(tgt_data) or len(tgt_data) != len(structure_data1):
        raise AssertionError(
            "Source, target, structure and index should have the same length")

    num_shards = int(len(src_data) / opt.shard_size)
    shard_corpora = [(src_corpus, src_data),
                     (tgt_corpus, tgt_data),
                     (structure_corpus1, structure_data1),
                     (structure_corpus2, structure_data2),
                     (structure_corpus3, structure_data3),
                     (structure_corpus4, structure_data4),
                     (structure_corpus5, structure_data5)]
    # (Corresponding writes for structure 6-8 are commented out in the original.)
    for x in range(num_shards):
        logger.info("Splitting shard %d." % x)
        start, end = x * opt.shard_size, (x + 1) * opt.shard_size
        for corpus, data in shard_corpora:
            f = codecs.open(corpus + ".{0}.txt".format(x), "w", encoding="utf-8")
            f.writelines(data[start:end])
            f.close()

    num_written = num_shards * opt.shard_size
    if len(src_data) > num_written:
        # Handle the last, smaller shard.
        logger.info("Splitting shard %d." % num_shards)
        for corpus, data in shard_corpora:
            f = codecs.open(corpus + ".{0}.txt".format(num_shards), "w", encoding="utf-8")
            f.writelines(data[num_written:])
            f.close()

    src_list = sorted(glob.glob(src_corpus + '.*.txt'))
    tgt_list = sorted(glob.glob(tgt_corpus + '.*.txt'))
    structure_list1 = sorted(glob.glob(structure_corpus1 + '.*.txt'))
    structure_list2 = sorted(glob.glob(structure_corpus2 + '.*.txt'))
    structure_list3 = sorted(glob.glob(structure_corpus3 + '.*.txt'))
    structure_list4 = sorted(glob.glob(structure_corpus4 + '.*.txt'))
    structure_list5 = sorted(glob.glob(structure_corpus5 + '.*.txt'))

    ret_list = []
    for i, src in enumerate(src_list):
        logger.info("Building shard %d." % i)
        # Each iterator yields one line of the corresponding file at a time.
        src_iter = make_text_iterator_from_file(src)
        tgt_iter = make_text_iterator_from_file(tgt_list[i])
        structure_iter1 = make_text_iterator_from_file(structure_list1[i])
        structure_iter2 = make_text_iterator_from_file(structure_list2[i])
        structure_iter3 = make_text_iterator_from_file(structure_list3[i])
        structure_iter4 = make_text_iterator_from_file(structure_list4[i])
        structure_iter5 = make_text_iterator_from_file(structure_list5[i])

        dataset = build_dataset(fields, src_iter, tgt_iter,
                                structure_iter1, structure_iter2, structure_iter3,
                                structure_iter4, structure_iter5,
                                src_seq_length=opt.src_seq_length,
                                tgt_seq_length=opt.tgt_seq_length,
                                src_seq_length_trunc=opt.src_seq_length_trunc,
                                tgt_seq_length_trunc=opt.tgt_seq_length_trunc)

        # e.g. ..../gq_{corpus_type}.{0,1,...}.pt
        pt_file = "{:s}_{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)

        # We save fields in vocab.pt separately, so make it empty here.
        dataset.fields = []

        logger.info(" * saving %sth %s data shard to %s."
                    % (i, corpus_type, pt_file))
        torch.save(dataset, pt_file)

        ret_list.append(pt_file)
        os.remove(src)
        os.remove(tgt_list[i])
        os.remove(structure_list1[i])
        os.remove(structure_list2[i])
        os.remove(structure_list3[i])
        os.remove(structure_list4[i])
        os.remove(structure_list5[i])
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    # Return the list of shard file names.
    return ret_list

def build_save_in_shards_using_shards_size(src_corpus, tgt_corpus, structure_corpus,
                                           mask_corpus, relation_corpus,
                                           fields, corpus_type, opt):
    src_data = []
    tgt_data = []
    structure_data = []
    mask_data = []
    relation_data = []

    with open(src_corpus, "r") as src_file, \
            open(tgt_corpus, "r") as tgt_file, \
            open(structure_corpus, "r") as structure_file, \
            open(mask_corpus, "r") as mask_file, \
            open(relation_corpus, "r") as relation_file:
        for s, t, structure, mask, relation in zip(
                src_file, tgt_file, structure_file, mask_file, relation_file):
            src_data.append(s)
            tgt_data.append(t)
            structure_data.append(structure)
            mask_data.append(mask)
            relation_data.append(relation)
            # structure flattens a (len(src) + 1)^2 matrix; mask flattens a len(tgt)^2 matrix.
            assert (len(s.split()) + 1) ** 2 == len(structure.split()) \
                and len(t.split()) ** 2 == len(mask.split())

    if len(src_data) != len(tgt_data) or len(tgt_data) != len(structure_data):
        raise AssertionError(
            "Source, target and structure should have the same length")

    num_shards = int(len(src_data) / opt.shard_size)
    shard_corpora = [(src_corpus, src_data),
                     (tgt_corpus, tgt_data),
                     (structure_corpus, structure_data),
                     (mask_corpus, mask_data),
                     (relation_corpus, relation_data)]
    for x in range(num_shards):
        logger.info("Splitting shard %d." % x)
        start, end = x * opt.shard_size, (x + 1) * opt.shard_size
        for corpus, data in shard_corpora:
            f = codecs.open(corpus + ".{0}.txt".format(x), "w", encoding="utf-8")
            f.writelines(data[start:end])
            f.close()

    num_written = num_shards * opt.shard_size
    if len(src_data) > num_written:
        # Handle the last, smaller shard.
        logger.info("Splitting shard %d." % num_shards)
        for corpus, data in shard_corpora:
            f = codecs.open(corpus + ".{0}.txt".format(num_shards), "w", encoding="utf-8")
            f.writelines(data[num_written:])
            f.close()

    src_list = sorted(glob.glob(src_corpus + '.*.txt'))
    tgt_list = sorted(glob.glob(tgt_corpus + '.*.txt'))
    structure_list = sorted(glob.glob(structure_corpus + '.*.txt'))
    mask_list = sorted(glob.glob(mask_corpus + '.*.txt'))
    relation_list = sorted(glob.glob(relation_corpus + '.*.txt'))

    ret_list = []
    for index, src in enumerate(src_list):
        logger.info("Building shard %d." % index)
        # Each iterator yields one line of the corresponding file at a time.
        src_iter = make_text_iterator_from_file(src)
        tgt_iter = make_text_iterator_from_file(tgt_list[index])
        structure_iter = make_text_iterator_from_file(structure_list[index])
        mask_iter = make_text_iterator_from_file(mask_list[index])
        relation_iter = make_text_iterator_from_file(relation_list[index])

        dataset = build_dataset(fields, src_iter, tgt_iter,
                                structure_iter, mask_iter, relation_iter,
                                src_seq_length=opt.src_seq_length,
                                tgt_seq_length=opt.tgt_seq_length,
                                src_seq_length_trunc=opt.src_seq_length_trunc,
                                tgt_seq_length_trunc=opt.tgt_seq_length_trunc)

        # e.g. ..../gq_{corpus_type}.{0,1,...}.pt
        pt_file = "{:s}_{:s}.{:d}.pt".format(opt.save_data, corpus_type, index)

        # We save fields in vocab.pt separately, so make it empty here.
        dataset.fields = []

        logger.info(" * saving %sth %s data shard to %s."
                    % (index, corpus_type, pt_file))
        torch.save(dataset, pt_file)

        ret_list.append(pt_file)
        os.remove(src)
        os.remove(tgt_list[index])
        os.remove(structure_list[index])
        os.remove(mask_list[index])
        os.remove(relation_list[index])
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    # Return the list of shard file names.
    return ret_list
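
# The per-line assertions in both sharded builders encode the expected shapes
# of the auxiliary inputs: a structure line flattens a (n + 1) x (n + 1)
# matrix over the n source tokens, and a mask line flattens an m x m matrix
# over the m target tokens. A small self-contained check; the tokens and
# matrix values below are made up purely for illustration:
def check_line_shapes(src_line, structure_line, tgt_line, mask_line):
    n = len(src_line.split())
    m = len(tgt_line.split())
    assert (n + 1) ** 2 == len(structure_line.split())
    assert m ** 2 == len(mask_line.split())

# Two source tokens -> a 3 x 3 = 9-entry structure line;
# three target tokens -> a 3 x 3 = 9-entry mask line.
check_line_shapes("he runs",
                  "0 1 1 1 0 2 1 2 0",
                  "he runs fast",
                  "1 0 0 1 1 0 1 1 1")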