def build_save_dataset(corpus_type, fields, opt):
    """Build the dataset for one split ('train' or 'valid') and save it to disk.

    Returns a list containing the path(s) of the saved ``.pt`` file(s).
    """
    assert corpus_type in ['train', 'valid']

    # Resolve the paths of the source, target and structural annotation files
    # for the requested split.
    if corpus_type == 'train':
        paths = (opt.train_src, opt.train_tgt, opt.train_structure,
                 opt.train_mask, opt.train_relation)
    else:
        paths = (opt.valid_src, opt.valid_tgt, opt.valid_structure,
                 opt.valid_mask, opt.valid_relation)
    src_corpus, tgt_corpus, structure_corpus, mask_corpus, relation_corpus = paths

    # Sharded construction when a shard size is configured.
    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(
            src_corpus, tgt_corpus, structure_corpus, mask_corpus,
            relation_corpus, fields, corpus_type, opt)

    # Monolithic dataset: one line-iterator per input file, in the same
    # positional order that build_dataset expects.
    iters = [make_text_iterator_from_file(p) for p in paths]
    dataset = build_dataset(fields, *iters,
                            src_seq_length=opt.src_seq_length,
                            tgt_seq_length=opt.tgt_seq_length,
                            src_seq_length_trunc=opt.src_seq_length_trunc,
                            tgt_seq_length_trunc=opt.tgt_seq_length_trunc)

    # Fields are saved separately in vocab.pt, so drop them before pickling.
    dataset.fields = []

    pt_file = "{:s}_{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
    torch.save(dataset, pt_file)

    return [pt_file]
def main(opt):
    """Translate ``opt.src`` (optionally scored against ``opt.tgt``) and write
    the output to ``opt.output``.
    """
    translator = build_translator(opt)
    src_iter = make_text_iterator_from_file(opt.src)
    tgt_iter = (make_text_iterator_from_file(opt.tgt)
                if opt.tgt is not None else None)
    # Context manager guarantees the output file is closed even when
    # translation raises (the original leaked the handle on error).
    with codecs.open(opt.output, 'w+', 'utf-8') as out_file:
        translator.translate(src_data_iter=src_iter,
                             tgt_data_iter=tgt_iter,
                             batch_size=opt.batch_size,
                             out_file=out_file)
def build_save_dataset(corpus_type, task_type, fields, opt):
    """Build and save the dataset for the given split and task.

    Returns a list containing the path(s) of the saved ``.pt`` file(s).
    """
    assert corpus_type in ['train', 'valid']
    assert task_type in ['task', 'task2']

    # Pick the corpus pair for the requested (split, task) combination;
    # only the chosen branch's opt attributes are touched.
    if corpus_type == 'train':
        if task_type == 'task':
            src_corpus, tgt_corpus = opt.train_src, opt.train_tgt
        else:
            src_corpus, tgt_corpus = opt.train_src2, opt.train_tgt2
    else:
        if task_type == 'task':
            src_corpus, tgt_corpus = opt.valid_src, opt.valid_tgt
        else:
            src_corpus, tgt_corpus = opt.valid_src2, opt.valid_tgt2

    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(src_corpus, tgt_corpus,
                                                      fields, corpus_type,
                                                      task_type, opt)

    # Monolithic dataset (no sharding).
    # NOTE(review): the sharded path passes task_type to build_dataset but
    # this monolithic path does not — confirm whether that is intentional.
    src_iter = make_text_iterator_from_file(src_corpus)
    tgt_iter = make_text_iterator_from_file(tgt_corpus)
    dataset = build_dataset(fields, src_iter, tgt_iter,
                            src_seq_length=opt.src_seq_length,
                            tgt_seq_length=opt.tgt_seq_length,
                            src_seq_length_trunc=opt.src_seq_length_trunc,
                            tgt_seq_length_trunc=opt.tgt_seq_length_trunc)

    # Fields are saved separately in vocab.pt, so drop them before pickling.
    dataset.fields = []

    pt_file = "{:s}_{:s}_{:s}.pt".format(opt.save_data, task_type, corpus_type)
    logger.info(" * saving %s %s dataset to %s."
                % (task_type, corpus_type, pt_file))
    torch.save(dataset, pt_file)

    return [pt_file]
def main(opt, model_path):
    """Translate the CMU-MOSEI emotion test split and write results to
    ``opt.output``.
    """
    translator = build_translator(opt, model_path)
    # Only the test split is used here; train/valid splits are discarded.
    _, _, X_test, _, _, y_test = data_loader.read_cmumosei_emotion_pkl()
    # NOTE(review): X_test / y_test come from a pickle loader, yet are fed to
    # make_text_iterator_from_file, which by name expects file paths — confirm
    # the loader actually returns paths.
    src_iter = make_text_iterator_from_file(X_test)
    tgt_iter = make_text_iterator_from_file(y_test)
    # Context manager fixes the original's leaked file handle on error.
    with codecs.open(opt.output, 'w+', 'utf-8') as out_file:
        translator.translate(src_data_iter=src_iter,
                             tgt_data_iter=tgt_iter,
                             batch_size=opt.batch_size,
                             out_file=out_file)
def build_save_dataset(corpus_type, fields, opt):
    """Build the dataset for one split ('train' or 'valid') of the CMU-MOSEI
    emotion data and save it to disk.

    Returns a list containing the path(s) of the saved ``.pt`` file(s).
    """
    assert corpus_type in ['train', 'valid']

    # Train/valid splits come from the pickle loader instead of opt paths.
    # NOTE(review): these objects are later passed to
    # make_text_iterator_from_file, which by name expects file paths — confirm
    # the loader returns paths.
    X_train, X_valid, X_test, y_train, y_valid, y_test = \
        data_loader.read_cmumosei_emotion_pkl()

    if corpus_type == 'train':
        src_corpus, tgt_corpus = X_train, y_train
    else:
        src_corpus, tgt_corpus = X_valid, y_valid

    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(src_corpus, tgt_corpus,
                                                      fields, corpus_type, opt)

    # Monolithic dataset (no sharding).
    src_iter = make_text_iterator_from_file(src_corpus)
    tgt_iter = make_text_iterator_from_file(tgt_corpus)
    dataset = build_dataset(fields, src_iter, tgt_iter,
                            src_seq_length=opt.src_seq_length,
                            tgt_seq_length=opt.tgt_seq_length,
                            src_seq_length_trunc=opt.src_seq_length_trunc,
                            tgt_seq_length_trunc=opt.tgt_seq_length_trunc)

    # Fields are saved separately in vocab.pt, so drop them before pickling.
    dataset.fields = []

    pt_file = "{:s}_{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
    torch.save(dataset, pt_file)

    return [pt_file]
def main(opt):
    """Translate ``opt.src`` with up to five optional structure inputs and
    write the output to ``opt.output``.
    """
    def _iter_or_none(path):
        # One optional-input pattern instead of six copy-pasted if/else blocks.
        return make_text_iterator_from_file(path) if path is not None else None

    translator = build_translator(opt)
    src_iter = make_text_iterator_from_file(opt.src)
    tgt_iter = _iter_or_none(opt.tgt)
    structure_iters = [_iter_or_none(getattr(opt, 'structure%d' % i))
                       for i in range(1, 6)]
    # Context manager guarantees the output file is closed even when
    # translation raises (the original leaked the handle on error).
    with codecs.open(opt.output, 'w+', 'utf-8') as out_file:
        translator.translate(src_data_iter=src_iter,
                             tgt_data_iter=tgt_iter,
                             structure_iter1=structure_iters[0],
                             structure_iter2=structure_iters[1],
                             structure_iter3=structure_iters[2],
                             structure_iter4=structure_iters[3],
                             structure_iter5=structure_iters[4],
                             batch_size=opt.batch_size,
                             out_file=out_file)
def build_save_in_shards_using_shards_size(src_corpus, tgt_corpus, fields,
                                           corpus_type, task_type, opt):
    """Split src/tgt corpora into shards of ``opt.shard_size`` lines, build a
    dataset per shard and save each as a ``.pt`` file.

    Temporary per-shard text files are written next to the corpora and removed
    after the corresponding dataset is saved.  Returns the list of ``.pt``
    paths.
    """
    def _write_shard(path, lines, index):
        # One temporary shard text file per corpus slice; `with` guarantees
        # the handle is closed (the original open/writelines/close leaked it
        # on error).
        with codecs.open(path + ".{0}.txt".format(index), "w",
                         encoding="utf-8") as f:
            f.writelines(lines)

    src_data = []
    tgt_data = []
    # NOTE(review): reading without an explicit encoding uses the platform
    # default while shards are written as UTF-8 — confirm inputs are UTF-8.
    with open(src_corpus, "r") as src_file, open(tgt_corpus, "r") as tgt_file:
        for s, t in zip(src_file, tgt_file):
            src_data.append(s)
            tgt_data.append(t)

    if len(src_data) != len(tgt_data):
        raise AssertionError("Source and Target should have the same length")

    num_shards = len(src_data) // opt.shard_size
    for x in range(num_shards):
        logger.info("Splitting shard %d." % x)
        lo, hi = x * opt.shard_size, (x + 1) * opt.shard_size
        _write_shard(src_corpus, src_data[lo:hi], x)
        _write_shard(tgt_corpus, tgt_data[lo:hi], x)

    # Trailing partial shard for lines that did not fill a full shard.
    num_written = num_shards * opt.shard_size
    if len(src_data) > num_written:
        logger.info("Splitting shard %d." % num_shards)
        _write_shard(src_corpus, src_data[num_written:], num_shards)
        _write_shard(tgt_corpus, tgt_data[num_written:], num_shards)

    src_list = sorted(glob.glob(src_corpus + '.*.txt'))
    tgt_list = sorted(glob.glob(tgt_corpus + '.*.txt'))

    ret_list = []
    for index, src in enumerate(src_list):
        logger.info("Building shard %d." % index)
        src_iter = make_text_iterator_from_file(src)
        tgt_iter = make_text_iterator_from_file(tgt_list[index])
        dataset = build_dataset(fields, src_iter, tgt_iter,
                                src_seq_length=opt.src_seq_length,
                                tgt_seq_length=opt.tgt_seq_length,
                                src_seq_length_trunc=opt.src_seq_length_trunc,
                                tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
                                task_type=task_type)
        pt_file = "{:s}_{:s}_{:s}.{:d}.pt".format(opt.save_data, task_type,
                                                  corpus_type, index)
        # Fields are saved separately in vocab.pt, so drop them before
        # pickling.
        dataset.fields = []
        logger.info(" * saving %sth %s data shard to %s."
                    % (index, corpus_type, pt_file))
        torch.save(dataset, pt_file)
        ret_list.append(pt_file)
        # Remove the temporary shard text files and release dataset memory
        # before building the next shard.
        os.remove(src)
        os.remove(tgt_list[index])
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()
    return ret_list
def build_save_dataset(corpus_type, fields, opt):
    """Build and save the dataset ('train' or 'valid') with five structure
    annotation streams.

    Returns a list containing the path(s) of the saved ``.pt`` file(s).
    """
    assert corpus_type in ['train', 'valid']

    # Resolve input paths for the chosen split.  Only structure files 1-5 are
    # consumed downstream (iterators 6-8 were commented out); the unused
    # structure_corpus6-8 assignments in the original were dead code and have
    # been removed.
    if corpus_type == 'train':
        src_corpus = opt.train_src
        tgt_corpus = opt.train_tgt
        structure_corpora = [opt.train_structure1, opt.train_structure2,
                             opt.train_structure3, opt.train_structure4,
                             opt.train_structure5]
    else:
        src_corpus = opt.valid_src
        tgt_corpus = opt.valid_tgt
        structure_corpora = [opt.valid_structure1, opt.valid_structure2,
                             opt.valid_structure3, opt.valid_structure4,
                             opt.valid_structure5]

    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(
            src_corpus, tgt_corpus,
            structure_corpora[0], structure_corpora[1], structure_corpora[2],
            structure_corpora[3], structure_corpora[4],
            fields, corpus_type, opt)

    # Monolithic dataset: one line-iterator per input file.
    src_iter = make_text_iterator_from_file(src_corpus)
    tgt_iter = make_text_iterator_from_file(tgt_corpus)
    structure_iters = [make_text_iterator_from_file(p)
                       for p in structure_corpora]

    dataset = build_dataset(fields, src_iter, tgt_iter, *structure_iters,
                            src_seq_length=opt.src_seq_length,
                            tgt_seq_length=opt.tgt_seq_length,
                            src_seq_length_trunc=opt.src_seq_length_trunc,
                            tgt_seq_length_trunc=opt.tgt_seq_length_trunc)

    # Fields are saved separately in vocab.pt, so drop them before pickling.
    dataset.fields = []

    pt_file = "{:s}_{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
    torch.save(dataset, pt_file)

    return [pt_file]
def build_save_in_shards_using_shards_size(
        src_corpus, tgt_corpus, structure_corpus1, structure_corpus2,
        structure_corpus3, structure_corpus4, structure_corpus5,
        fields, corpus_type, opt):
    """Shard src/tgt plus five structure corpora, then build and save one
    dataset per shard.

    Every structure line must contain ``(len(src_line.split()) + 1) ** 2``
    tokens (the original asserted this per line).  Temporary per-shard text
    files are written next to the corpora and removed after the corresponding
    dataset is saved.  Returns the list of ``.pt`` shard paths.
    """
    structure_corpora = [structure_corpus1, structure_corpus2,
                         structure_corpus3, structure_corpus4,
                         structure_corpus5]
    all_corpora = [src_corpus, tgt_corpus] + structure_corpora

    # Read every corpus in lockstep, validating structure-line lengths.
    src_data, tgt_data = [], []
    structure_data = [[] for _ in structure_corpora]
    # NOTE(review): reading without an explicit encoding uses the platform
    # default while shards are written as UTF-8 — confirm inputs are UTF-8.
    files = [open(path, "r") for path in all_corpora]
    try:
        for row in zip(*files):
            s, t = row[0], row[1]
            src_data.append(s)
            tgt_data.append(t)
            for buf, structure in zip(structure_data, row[2:]):
                buf.append(structure)
                # Each structure line holds (len(src)+1)**2 tokens.
                assert (len(s.split()) + 1) ** 2 == len(structure.split())
    finally:
        # try/finally replaces the original's deep `with` nesting while still
        # guaranteeing the handles are closed.
        for fh in files:
            fh.close()

    if len(src_data) != len(tgt_data) or len(tgt_data) != len(structure_data[0]):
        raise AssertionError(
            "Source,Target,structure and index should have the same length")

    def _write_shard(path, lines, index):
        # One temporary shard text file per corpus slice; `with` guarantees
        # the handle is closed.
        with codecs.open(path + ".{0}.txt".format(index), "w",
                         encoding="utf-8") as f:
            f.writelines(lines)

    all_data = [src_data, tgt_data] + structure_data

    num_shards = len(src_data) // opt.shard_size
    for x in range(num_shards):
        logger.info("Splitting shard %d." % x)
        lo, hi = x * opt.shard_size, (x + 1) * opt.shard_size
        for path, data in zip(all_corpora, all_data):
            _write_shard(path, data[lo:hi], x)

    # Trailing partial shard for lines that did not fill a full shard.
    num_written = num_shards * opt.shard_size
    if len(src_data) > num_written:
        logger.info("Splitting shard %d." % num_shards)
        for path, data in zip(all_corpora, all_data):
            _write_shard(path, data[num_written:], num_shards)

    shard_lists = [sorted(glob.glob(path + '.*.txt')) for path in all_corpora]
    src_list, tgt_list = shard_lists[0], shard_lists[1]
    structure_lists = shard_lists[2:]

    ret_list = []
    for i, src in enumerate(src_list):
        logger.info("Building shard %d." % i)
        src_iter = make_text_iterator_from_file(src)
        tgt_iter = make_text_iterator_from_file(tgt_list[i])
        structure_iters = [make_text_iterator_from_file(lst[i])
                           for lst in structure_lists]
        dataset = build_dataset(fields, src_iter, tgt_iter, *structure_iters,
                                src_seq_length=opt.src_seq_length,
                                tgt_seq_length=opt.tgt_seq_length,
                                src_seq_length_trunc=opt.src_seq_length_trunc,
                                tgt_seq_length_trunc=opt.tgt_seq_length_trunc)
        pt_file = "{:s}_{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)
        # Fields are saved separately in vocab.pt, so drop them before
        # pickling.
        dataset.fields = []
        logger.info(" * saving %sth %s data shard to %s."
                    % (i, corpus_type, pt_file))
        torch.save(dataset, pt_file)
        ret_list.append(pt_file)
        # Remove the temporary shard text files and release dataset memory
        # before building the next shard.
        os.remove(src)
        os.remove(tgt_list[i])
        for lst in structure_lists:
            os.remove(lst[i])
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()
    return ret_list
def build_save_in_shards_using_shards_size(src_corpus, tgt_corpus,
                                           structure_corpus, mask_corpus,
                                           relation_corpus, fields,
                                           corpus_type, opt):
    """Shard src/tgt/structure/mask/relation corpora, then build and save one
    dataset per shard.

    Per line, structure must hold ``(len(src)+1)**2`` tokens and mask must
    hold ``len(tgt)**2`` tokens (the original asserted both).  Temporary
    per-shard text files are written next to the corpora and removed after the
    corresponding dataset is saved.  Returns the list of ``.pt`` shard paths.
    """
    all_corpora = [src_corpus, tgt_corpus, structure_corpus,
                   mask_corpus, relation_corpus]
    all_data = [[], [], [], [], []]

    # Read every corpus in lockstep, validating annotation-line lengths.
    # NOTE(review): reading without an explicit encoding uses the platform
    # default while shards are written as UTF-8 — confirm inputs are UTF-8.
    files = [open(path, "r") for path in all_corpora]
    try:
        for s, t, structure, mask, relation in zip(*files):
            for buf, line in zip(all_data, (s, t, structure, mask, relation)):
                buf.append(line)
            # structure is a flattened (len(src)+1)**2 matrix; mask is
            # len(tgt)**2 (no +1 — preserved from the original assertion).
            assert (len(s.split()) + 1) ** 2 == len(structure.split()) \
                and (len(t.split()) ** 2 == len(mask.split()))
    finally:
        # try/finally replaces the original's deep `with` nesting while still
        # guaranteeing the handles are closed.
        for fh in files:
            fh.close()

    src_data, tgt_data, structure_data = all_data[0], all_data[1], all_data[2]
    if len(src_data) != len(tgt_data) or len(tgt_data) != len(structure_data):
        raise AssertionError(
            "Source,Target and structure should have the same length")

    def _write_shard(path, lines, index):
        # One temporary shard text file per corpus slice; `with` guarantees
        # the handle is closed (the original open/writelines/close leaked it
        # on error).
        with codecs.open(path + ".{0}.txt".format(index), "w",
                         encoding="utf-8") as f:
            f.writelines(lines)

    num_shards = len(src_data) // opt.shard_size
    for x in range(num_shards):
        logger.info("Splitting shard %d." % x)
        lo, hi = x * opt.shard_size, (x + 1) * opt.shard_size
        for path, data in zip(all_corpora, all_data):
            _write_shard(path, data[lo:hi], x)

    # Trailing partial shard for lines that did not fill a full shard.
    num_written = num_shards * opt.shard_size
    if len(src_data) > num_written:
        logger.info("Splitting shard %d." % num_shards)
        for path, data in zip(all_corpora, all_data):
            _write_shard(path, data[num_written:], num_shards)

    shard_lists = [sorted(glob.glob(path + '.*.txt')) for path in all_corpora]
    src_list = shard_lists[0]

    ret_list = []
    for index, src in enumerate(src_list):
        logger.info("Building shard %d." % index)
        # One line-iterator per shard file, in build_dataset's positional
        # order: src, tgt, structure, mask, relation.
        iters = [make_text_iterator_from_file(lst[index])
                 for lst in shard_lists]
        dataset = build_dataset(fields, *iters,
                                src_seq_length=opt.src_seq_length,
                                tgt_seq_length=opt.tgt_seq_length,
                                src_seq_length_trunc=opt.src_seq_length_trunc,
                                tgt_seq_length_trunc=opt.tgt_seq_length_trunc)
        pt_file = "{:s}_{:s}.{:d}.pt".format(opt.save_data, corpus_type,
                                             index)
        # Fields are saved separately in vocab.pt, so drop them before
        # pickling.
        dataset.fields = []
        logger.info(" * saving %sth %s data shard to %s."
                    % (index, corpus_type, pt_file))
        torch.save(dataset, pt_file)
        ret_list.append(pt_file)
        # Remove the temporary shard text files and release dataset memory
        # before building the next shard.
        for lst in shard_lists:
            os.remove(lst[index])
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()
    return ret_list