from fuel.datasets import TextFile
from fuel.schemes import ConstantScheme
from fuel.transformers import (
    Batch, Filter, Mapping, Merge, Padding, SortMapping, Unpack)


def load_parallel_data(src_file, tgt_file, batch_size, sort_k_batches,
                       dictionary, training=False):
    """Build a character-level Fuel stream over a parallel src/tgt corpus."""
    def preproc(s):
        s = s.replace('``', '"')
        s = s.replace('\'\'', '"')
        return s

    enc_dset = TextFile(files=[src_file], dictionary=dictionary,
                        bos_token=None, eos_token=None,
                        unk_token=CHAR_UNK_TOK, level='character',
                        preprocess=preproc)
    dec_dset = TextFile(files=[tgt_file], dictionary=dictionary,
                        bos_token=CHAR_SOS_TOK, eos_token=CHAR_EOS_TOK,
                        unk_token=CHAR_UNK_TOK, level='character',
                        preprocess=preproc)
    # NOTE merge encoder and decoder setup together
    stream = Merge([enc_dset.get_example_stream(),
                    dec_dset.get_example_stream()],
                   ('source', 'target'))
    if training:
        # filter sequences that are too long
        stream = Filter(stream, predicate=TooLong(seq_len=CHAR_MAX_SEQ_LEN))
        # batch and read k batches ahead
        stream = Batch(stream,
                       iteration_scheme=ConstantScheme(batch_size * sort_k_batches))
        # sort all samples in the read-ahead batch by target length
        stream = Mapping(stream, SortMapping(lambda x: len(x[1])))
        # turn back into a stream of single examples
        stream = Unpack(stream)
    # batch again
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))
    masked_stream = Padding(stream)
    return masked_stream
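# TooLong is used above but not defined in this excerpt. The sketch below is an
# assumption about its shape, not the repo's actual class: Fuel's Filter keeps
# examples for which the predicate returns True, so the predicate should return
# True only when every sequence in the (source, target) pair fits within seq_len.
class TooLong(object):
    def __init__(self, seq_len):
        self.seq_len = seq_len

    def __call__(self, sentence_pair):
        # Keep the pair only if both source and target are short enough.
        return all(len(sentence) <= self.seq_len for sentence in sentence_pair)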
def load_data(src_file, tgt_file, batch_size, sort_k_batches, training=False):
    """Build a word-level Fuel stream over a parallel src/tgt corpus."""
    src_dict, tgt_dict = load_dictionaries()
    src_dset = TextFile(files=[src_file], dictionary=src_dict,
                        bos_token=None, eos_token=None,
                        unk_token=WORD_UNK_TOK)
    tgt_dset = TextFile(files=[tgt_file], dictionary=tgt_dict,
                        bos_token=WORD_EOS_TOK, eos_token=WORD_EOS_TOK,
                        unk_token=WORD_UNK_TOK)
    stream = Merge([src_dset.get_example_stream(),
                    tgt_dset.get_example_stream()],
                   ('source', 'target'))
    # filter sequences that are too long
    if training:
        stream = Filter(stream, predicate=TooLong(seq_len=WORD_MAX_SEQ_LEN))
        # batch and read k batches ahead
        stream = Batch(stream,
                       iteration_scheme=ConstantScheme(batch_size * sort_k_batches))
        # sort all samples in the read-ahead batch by target length
        stream = Mapping(stream, SortMapping(lambda x: len(x[1])))
        # turn back into a stream of single examples
        stream = Unpack(stream)
    # batch again
    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size))
    # NOTE pads with zeros so eos_idx should be 0
    masked_stream = Padding(stream)
    return masked_stream, src_dict, tgt_dict
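# Hedged usage sketch (the file names and hyperparameters below are made up, not
# from the repo): Padding interleaves a mask after each source, so the stream
# yields (source, source_mask, target, target_mask) batches, each padded with
# zeros to the longest sequence in the batch.
if __name__ == '__main__':
    train_stream, src_dict, tgt_dict = load_data(
        'train.src', 'train.tgt', batch_size=32, sort_k_batches=12, training=True)
    for src, src_mask, tgt, tgt_mask in train_stream.get_epoch_iterator():
        # src/tgt: (batch, max_len) integer arrays; masks: same shape, 1.0 for
        # real tokens and 0.0 for padding.
        print(src.shape, tgt.shape)
        break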