def reset(self):
    """Rewind the input streams so the iterator can start a new epoch.

    With shuffling enabled, the current handles are closed. The original
    corpora (the handle names with ``.shuf`` stripped) are passed to
    ``shuffle.main``, and the ``.shuf`` files are reopened with ``fopen``.
    Otherwise every handle is rewound in place with ``seek(0)``.

    If the instance also carries an ``hter`` handle, it is processed
    together with source and target, so all streams are rewound or
    re-shuffled at the same time.
    """
    # Optional third stream; absent on iterators built without an hter file.
    has_hter = getattr(self, 'hter', None) is not None
    attrs = ['source', 'target'] + (['hter'] if has_hter else [])

    if not self.shuffle:
        for attr in attrs:
            getattr(self, attr).seek(0)
        return

    names = [getattr(self, attr).name for attr in attrs]
    # Close the old handles before replacing them; previously each epoch
    # leaked one open file per stream.
    for attr in attrs:
        getattr(self, attr).close()
    # All streams go to shuffle.main in one call, mirroring how __init__
    # invokes it (presumably so they are shuffled with one shared order).
    shuffle.main([name.replace('.shuf', '') for name in names])
    for attr, name in zip(attrs, names):
        setattr(self, attr, fopen(name))
def __init__(self, source, target, hter,
             source_dict, target_dict,
             batch_size=128,
             maxlen=100,
             n_words_source=-1,
             n_words_target=-1,
             shuffle_each_epoch=False,
             sort_by_length=False,
             maxibatch_size=20):
    """Open the source/target/hter input files and load the vocabularies.

    Args:
        source: path of the source-side text file.
        target: path of the target-side text file.
        hter: path of the HTER file; presumably one score per sentence
            pair, line-aligned with source/target (TODO confirm).
        source_dict: vocabulary file passed to ``load_dict``.
        target_dict: vocabulary file passed to ``load_dict``.
        batch_size: stored as ``self.batch_size``.
        maxlen: stored as ``self.maxlen``; presumably a sentence-length
            limit applied elsewhere.
        n_words_source: if > 0, entries of the source vocabulary whose
            index is >= this value are removed.
        n_words_target: if > 0, entries of the target vocabulary whose
            index is >= this value are removed.
        shuffle_each_epoch: if true, ``shuffle.main`` is run on the three
            files and their ``.shuf`` copies are opened instead of the
            originals. Stored as ``self.shuffle``.
        sort_by_length: stored as ``self.sort_by_length``.
        maxibatch_size: multiplied by ``batch_size`` to give ``self.k``.
    """
    if shuffle_each_epoch:
        # Shuffle the three files together up front and read the '.shuf'
        # copies that shuffle.main is expected to produce.
        shuffle.main([source, target, hter])
        self.source = fopen(source + '.shuf')
        self.target = fopen(target + '.shuf')
        self.hter = fopen(hter + '.shuf')
    else:
        self.source = fopen(source)
        self.target = fopen(target)
        self.hter = fopen(hter)

    self.source_dict = load_dict(source_dict)
    self.target_dict = load_dict(target_dict)

    self.batch_size = batch_size
    self.maxlen = maxlen

    self.n_words_source = n_words_source
    self.n_words_target = n_words_target

    # Collect the keys to drop before deleting: removing entries from a
    # dict while iterating its live items() view raises RuntimeError on
    # Python 3. Deleting in place keeps the dict object load_dict returned.
    if self.n_words_source > 0:
        for key in [k for k, idx in self.source_dict.items()
                    if idx >= self.n_words_source]:
            del self.source_dict[key]

    if self.n_words_target > 0:
        for key in [k for k, idx in self.target_dict.items()
                    if idx >= self.n_words_target]:
            del self.target_dict[key]

    self.shuffle = shuffle_each_epoch
    self.sort_by_length = sort_by_length

    self.source_buffer = []
    self.target_buffer = []
    self.hter_buffer = []
    self.k = batch_size * maxibatch_size

    self.end_of_data = False
# Command-line entry point: forwards its arguments to nematus.shuffle.main.
import os
import sys

# Make the ``nematus`` package importable. The parent directory was
# previously taken as os.path.abspath('../'), which is resolved against the
# current working directory and so only works when the script is launched
# from its own directory. The script-relative parent is now inserted ahead
# of it. The CWD-relative entry is kept as a fallback for backward
# compatibility.
sys.path.insert(1, os.path.abspath('../'))
sys.path.insert(1, os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)))

from nematus import shuffle  # noqa: E402 -- must follow the sys.path change

if __name__ == "__main__":
    shuffle.main(sys.argv[1:])