# Imports assumed from the Fuel library and NumPy; the original import block
# was not included in this excerpt.
import numpy
from numpy.testing import assert_equal, assert_raises

from fuel.datasets import IterableDataset
from fuel.schemes import ConstantScheme
from fuel.streams import DataStream
from fuel.transformers import Batch, Cache


def test_cache():
    dataset = IterableDataset(range(100))
    stream = DataStream(dataset)
    batched_stream = Batch(stream, ConstantScheme(11))
    cached_stream = Cache(batched_stream, ConstantScheme(7))
    epoch = cached_stream.get_epoch_iterator()

    # Make sure that the cache is filled as expected
    for (features,), cache_size in zip(epoch,
                                       [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4]):
        assert len(cached_stream.cache[0]) == cache_size

    # Make sure that the epoch finishes correctly
    for (features,) in cached_stream.get_epoch_iterator():
        pass
    assert len(features) == 100 % 7
    assert not cached_stream.cache[0]

    # Ensure that the epoch transition is correct
    cached_stream = Cache(batched_stream, ConstantScheme(7, times=3))
    for _, epoch in zip(range(2), cached_stream.iterate_epochs()):
        cache_sizes = [4, 8, 1]
        for i, (features,) in enumerate(epoch):
            assert len(cached_stream.cache[0]) == cache_sizes[i]
            assert len(features) == 7
            assert numpy.all(list(range(100))[i * 7:(i + 1) * 7] == features)
        assert i == 2
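# The magic numbers above follow from 100 examples arriving in batches of 11
# and being consumed 7 at a time. The helper below is purely illustrative
# (not part of the original test suite) and reproduces the expected cache
# sizes with plain arithmetic.
def _expected_cache_sizes(total=100, batch=11, request=7):
    # Refill the cache from the batched stream whenever fewer than `request`
    # examples remain, then hand out `request` examples (or whatever is left).
    sizes, cache, remaining = [], 0, total
    while remaining or cache:
        while cache < request and remaining:
            taken = min(batch, remaining)
            cache += taken
            remaining -= taken
        cache -= min(request, cache)
        sizes.append(cache)
    return sizes


assert _expected_cache_sizes()[:12] == [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4]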
def test_epoch_transition(self):
    cached_stream = Cache(self.stream, ConstantScheme(7, times=3))
    for _, epoch in zip(range(2), cached_stream.iterate_epochs()):
        cache_sizes = [4, 8, 1]
        for i, (features,) in enumerate(epoch):
            assert_equal(len(cached_stream.cache[0]), cache_sizes[i])
            assert_equal(len(features), 7)
            assert_equal(list(range(100))[i * 7:(i + 1) * 7], features)
        assert_equal(i, 2)
def test_epoch_finishes_correctly(self):
    cached_stream = Cache(self.stream, ConstantScheme(7))
    data = list(cached_stream.get_epoch_iterator())
    assert_equal(len(data[-1][0]), 100 % 7)
    assert not cached_stream.cache[0]

    # The whole epoch fits into a single batch request (3200 > 3000)
    stream = Batch(DataStream(IterableDataset(range(3000))),
                   ConstantScheme(3200))
    cached_stream = Cache(stream, ConstantScheme(64))
    data = list(cached_stream.get_epoch_iterator())
    assert_equal(len(data[-1][0]), 3000 % 64)
    assert not cached_stream.cache[0]
def test_cache_fills_correctly(self):
    cached_stream = Cache(self.stream, ConstantScheme(7))
    epoch = cached_stream.get_epoch_iterator()
    for (features,), cache_size in zip(epoch,
                                       [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4]):
        assert_equal(len(cached_stream.cache[0]), cache_size)
def test_axis_labels_passed_on_by_default(self):
    self.stream.axis_labels = {'features': ('batch', 'index')}
    cached_stream = Cache(self.stream, ConstantScheme(7))
    assert_equal(cached_stream.axis_labels, self.stream.axis_labels)
def test_value_error_on_none_request(self):
    cached_stream = Cache(self.stream, ConstantScheme(7))
    cached_stream.get_epoch_iterator()
    assert_raises(ValueError, cached_stream.get_data, None)
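# The methods above belong to a test class whose setUp is not included in
# this excerpt. A minimal fixture consistent with the expected cache sizes
# (100 examples, batched in groups of 11, re-requested 7 at a time) might
# look like the sketch below; the class name and setUp body are assumptions.
class TestCache(object):
    def setUp(self):
        dataset = IterableDataset(range(100))
        self.stream = Batch(DataStream(dataset), ConstantScheme(11))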
def get_stream(source, target, source_input_dict, target_label_dict,
               batch_size, buffer_multiplier=100, input_token_level='word',
               n_input_tokens=0, n_labels=0, reverse_labels=False,
               max_input_length=None, max_label_length=None, pad_labels=True,
               is_sort=True):
    """Returns a stream over sentence pairs.

    Parameters
    ----------
    source : list
        A list of files to read source languages from.
    target : list
        A list of corresponding files in the target language.
    source_input_dict : str
        Path to a tab-delimited text file whose last column contains the
        vocabulary.
    target_label_dict : str
        See `source_input_dict`.
    batch_size : int
        The minibatch size.
    buffer_multiplier : int
        The number of batches to load, concatenate, sort by length of
        source sentence, and split again; this makes batches more uniform
        in their sentence length and hence more computationally efficient.
    input_token_level : str
        Tokenization level for the source `TextFile` ('word' by default).
    n_input_tokens : int
        The number of tokens in the source vocabulary. Pass 0 (default) to
        use the entire vocabulary.
    n_labels : int
        See `n_input_tokens`.
    reverse_labels : bool
        Passed as `reverse` to `load_dict` for the label dictionary; if
        ``True``, the merged stream is additionally wrapped in `SortLabels`.
    max_input_length : int
        If given, drop pairs whose source side is longer than this (one
        extra position is allowed for the appended EOS token).
    max_label_length : int
        See `max_input_length`, for the target labels.
    pad_labels : bool
        Whether to pad the target labels (unused in this implementation).
    is_sort : bool
        If ``True``, sort the large buffered batches by source length
        before re-splitting them into minibatches.

    """
    if len(source) != len(target):
        raise ValueError("number of source and target files don't match")

    # Read the dictionaries
    dicts = [
        load_dict(source_input_dict, dict_size=n_input_tokens),
        load_dict(target_label_dict, dict_size=n_labels,
                  reverse=reverse_labels, include_unk=False)
    ]

    # Open the two sets of files and merge them
    streams = [
        TextFile(source, dicts[0], level=input_token_level, bos_token=None,
                 eos_token=EOS_TOKEN, encoding='utf-8').get_example_stream(),
        TextFile(target, dicts[1], level='word', bos_token=None,
                 unk_token=None, eos_token=EOS_TOKEN,
                 encoding='utf-8').get_example_stream()
    ]
    merged = Merge(streams, ('source_input_tokens', 'target_labels'))
    if reverse_labels:
        merged = SortLabels(merged)

    # Filter sentence lengths
    if max_input_length or max_label_length:
        def filter_pair(pair):
            src_input_tokens, trg_labels = pair
            src_input_ok = (not max_input_length) or \
                len(src_input_tokens) <= (max_input_length + 1)
            trg_label_ok = (not max_label_length) or \
                len(trg_labels) <= (max_label_length + 1)
            return src_input_ok and trg_label_ok
        merged = Filter(merged, filter_pair)

    # Batches of approximately uniform size
    large_batches = Batch(
        merged,
        iteration_scheme=ConstantScheme(batch_size * buffer_multiplier))
    if is_sort:
        sorted_batches = Mapping(large_batches, SortMapping(_source_length))
    else:
        sorted_batches = large_batches
    batches = Cache(sorted_batches, ConstantScheme(batch_size))
    mask_sources = ('source_input_tokens', 'target_labels')
    masked_batches = Padding(batches, mask_sources=mask_sources)

    return masked_batches
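# A usage sketch for the get_stream variant above. All file paths, corpus
# layout and vocabulary sizes below are hypothetical; only the get_stream
# signature comes from the code above.
train_stream = get_stream(
    source=['data/train.tokens'],           # hypothetical token file
    target=['data/train.labels'],           # hypothetical label file
    source_input_dict='data/tokens.dict',   # tab-delimited vocabulary files
    target_label_dict='data/labels.dict',
    batch_size=32,
    n_input_tokens=30000,
    n_labels=64,
    max_input_length=50)

# Padding adds a mask source for every padded source, so each batch carries
# 'source_input_tokens', 'target_labels' and their '*_mask' counterparts.
for batch in train_stream.get_epoch_iterator(as_dict=True):
    inputs, input_mask = (batch['source_input_tokens'],
                          batch['source_input_tokens_mask'])
    break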
def get_stream(source, target, source_dict, target_dict, batch_size,
               buffer_multiplier=100, n_words_source=0, n_words_target=0,
               max_src_length=None, max_trg_length=None):
    """Returns a stream over sentence pairs.

    Parameters
    ----------
    source : list
        A list of files to read source languages from.
    target : list
        A list of corresponding files in the target language.
    source_dict : str
        Path to a tab-delimited text file whose last column contains the
        vocabulary.
    target_dict : str
        See `source_dict`.
    batch_size : int
        The minibatch size.
    buffer_multiplier : int
        The number of batches to load, concatenate, sort by length of
        source sentence, and split again; this makes batches more uniform
        in their sentence length and hence more computationally efficient.
    n_words_source : int
        The number of words in the source vocabulary. Pass 0 (default) to
        use the entire vocabulary.
    n_words_target : int
        See `n_words_source`.
    max_src_length : int
        If given, drop pairs whose source sentence has this many tokens or
        more.
    max_trg_length : int
        See `max_src_length`, for the target sentence.

    """
    if len(source) != len(target):
        raise ValueError("number of source and target files don't match")

    # Read the dictionaries
    dicts = [
        load_dict(source_dict, n_words=n_words_source),
        load_dict(target_dict, n_words=n_words_target)
    ]

    # Open the two sets of files and merge them
    streams = [
        TextFile(source, dicts[0], bos_token=None,
                 eos_token=EOS_TOKEN).get_example_stream(),
        TextFile(target, dicts[1], bos_token=None,
                 eos_token=EOS_TOKEN).get_example_stream()
    ]
    merged = Merge(streams, ('source', 'target'))

    # Filter sentence lengths
    if max_src_length or max_trg_length:
        def filter_pair(pair):
            src, trg = pair
            src_ok = (not max_src_length) or len(src) < max_src_length
            trg_ok = (not max_trg_length) or len(trg) < max_trg_length
            return src_ok and trg_ok
        merged = Filter(merged, filter_pair)

    # Batches of approximately uniform size
    large_batches = Batch(
        merged,
        iteration_scheme=ConstantScheme(batch_size * buffer_multiplier))
    sorted_batches = Mapping(large_batches, SortMapping(_source_length))
    batches = Cache(sorted_batches, ConstantScheme(batch_size))
    shuffled_batches = Shuffle(batches, buffer_multiplier)
    masked_batches = Padding(shuffled_batches)

    return masked_batches
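# A self-contained sketch of the buffer-sort-split pattern that both
# get_stream variants rely on: load buffer_multiplier minibatches at once,
# sort the buffer by sequence length, re-split it with Cache, and pad.
# Everything below is illustrative (toy data, assumed Fuel imports) and,
# like Fuel's Batch transformer itself, relies on a NumPy version that
# still builds ragged object arrays.
import random

from fuel.datasets import IterableDataset
from fuel.schemes import ConstantScheme
from fuel.streams import DataStream
from fuel.transformers import Batch, Cache, Mapping, Padding, SortMapping

# Toy "sentences" of random length
toy_sentences = [[0] * random.randint(1, 20) for _ in range(1000)]
stream = DataStream(IterableDataset(toy_sentences))

# Load 8 minibatches worth of examples at once ...
large_batches = Batch(stream, iteration_scheme=ConstantScheme(16 * 8))
# ... sort each buffer by sentence length ...
sorted_batches = Mapping(large_batches,
                         SortMapping(lambda example: len(example[0])))
# ... then re-split into minibatches of 16 and pad each one. Within a
# minibatch the lengths are now nearly uniform, so little padding is wasted.
batches = Cache(sorted_batches, ConstantScheme(16))
padded = Padding(batches)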