Exemplo n.º 1
0
def test_cache():
    """Exercise ``Cache`` re-batching a stream of 11-element batches into 7s.

    Checks, in order: the cache length after each request during an epoch,
    that the final request of an epoch returns the leftover ``100 % 7``
    elements and leaves the cache empty, and that with a request scheme
    limited to three requests per epoch every epoch restarts from the
    beginning of the data.
    """
    dataset = IterableDataset(range(100))
    stream = DataStream(dataset)
    batched_stream = Batch(stream, ConstantScheme(11))
    cached_stream = Cache(batched_stream, ConstantScheme(7))
    epoch = cached_stream.get_epoch_iterator()

    # Make sure that cache is filled as expected. ``zip`` stops at the
    # shorter input, so count the checks to make sure an epoch that ends
    # early cannot pass silently.
    expected_cache_sizes = [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4]
    checked = 0
    for (features,), cache_size in zip(epoch, expected_cache_sizes):
        assert len(cached_stream.cache[0]) == cache_size
        checked += 1
    assert checked == len(expected_cache_sizes)

    # Make sure that the epoch finishes correctly. Materialise the epoch
    # rather than relying on the loop variable leaking out of a ``for``:
    # if the epoch were empty that would re-check the previous loop's data.
    data = list(cached_stream.get_epoch_iterator())
    assert data
    (features,) = data[-1]
    assert len(features) == 100 % 7
    assert not cached_stream.cache[0]

    # Ensure that the epoch transition is correct. The batch counter is
    # reset for every epoch so that an empty second epoch cannot pass by
    # reusing the index left over from the first one.
    cached_stream = Cache(batched_stream, ConstantScheme(7, times=3))
    cache_sizes = [4, 8, 1]
    n_epochs = 0
    for _, epoch in zip(range(2), cached_stream.iterate_epochs()):
        n_batches = 0
        for i, (features,) in enumerate(epoch):
            assert len(cached_stream.cache[0]) == cache_sizes[i]
            assert len(features) == 7
            expected = list(range(100))[i * 7:(i + 1) * 7]
            assert numpy.array_equal(expected, features)
            n_batches += 1
        assert n_batches == len(cache_sizes)
        n_epochs += 1
    assert n_epochs == 2
Exemplo n.º 2
0
 def test_epoch_transition(self):
     """Each epoch of a request-limited ``Cache`` restarts from the start.

     With ``ConstantScheme(7, times=3)`` every epoch must yield exactly
     three 7-element batches. The batch and epoch counters are reset
     explicitly so that an empty (or missing) second epoch cannot pass by
     reusing the loop index left over from the first one.
     """
     cached_stream = Cache(self.stream, ConstantScheme(7, times=3))
     cache_sizes = [4, 8, 1]
     n_epochs = 0
     for _, epoch in zip(range(2), cached_stream.iterate_epochs()):
         n_batches = 0
         for i, (features,) in enumerate(epoch):
             assert_equal(len(cached_stream.cache[0]), cache_sizes[i])
             assert_equal(len(features), 7)
             assert_equal(list(range(100))[i * 7:(i + 1) * 7], features)
             n_batches += 1
         assert_equal(n_batches, len(cache_sizes))
         n_epochs += 1
     assert_equal(n_epochs, 2)
Exemplo n.º 3
0
 def test_epoch_transition(self):
     """Each epoch of a request-limited ``Cache`` restarts from the start.

     With ``ConstantScheme(7, times=3)`` every epoch must yield exactly
     three 7-element batches. The batch and epoch counters are reset
     explicitly so that an empty (or missing) second epoch cannot pass by
     reusing the loop index left over from the first one.
     """
     cached_stream = Cache(self.stream, ConstantScheme(7, times=3))
     cache_sizes = [4, 8, 1]
     n_epochs = 0
     for _, epoch in zip(range(2), cached_stream.iterate_epochs()):
         n_batches = 0
         for i, (features,) in enumerate(epoch):
             assert_equal(len(cached_stream.cache[0]), cache_sizes[i])
             assert_equal(len(features), 7)
             assert_equal(list(range(100))[i * 7:(i + 1) * 7], features)
             n_batches += 1
         assert_equal(n_batches, len(cache_sizes))
         n_epochs += 1
     assert_equal(n_epochs, 2)
Exemplo n.º 4
0
    def test_epoch_finishes_correctly(self):
        """The last batch of an epoch holds the remainder; the cache empties.

        Checked on the fixture stream (presumably 100 examples, per the
        ``100 % 7`` expectation) re-batched in sevens, and on one oversized
        upstream batch of 3000 examples re-batched in sixty-fours.
        """
        def check_last_batch(upstream, request_size, num_examples):
            # Drain one full epoch through a fresh Cache wrapper.
            cache_stream = Cache(upstream, ConstantScheme(request_size))
            batches = list(cache_stream.get_epoch_iterator())
            last_features = batches[-1][0]
            assert_equal(len(last_features), num_examples % request_size)
            assert not cache_stream.cache[0]

        check_last_batch(self.stream, 7, 100)

        single_big_batch = Batch(DataStream(IterableDataset(range(3000))),
                                 ConstantScheme(3200))
        check_last_batch(single_big_batch, 64, 3000)
Exemplo n.º 5
0
def test_cache():
    """Exercise ``Cache`` re-batching a stream of 11-element batches into 7s.

    Checks, in order: the cache length after each request during an epoch,
    that the final request of an epoch returns the leftover ``100 % 7``
    elements and leaves the cache empty, and that with a request scheme
    limited to three requests per epoch every epoch restarts from the
    beginning of the data.
    """
    dataset = IterableDataset(range(100))
    stream = DataStream(dataset)
    batched_stream = Batch(stream, ConstantScheme(11))
    cached_stream = Cache(batched_stream, ConstantScheme(7))
    epoch = cached_stream.get_epoch_iterator()

    # Make sure that cache is filled as expected. ``zip`` stops at the
    # shorter input, so count the checks to make sure an epoch that ends
    # early cannot pass silently.
    expected_cache_sizes = [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4]
    checked = 0
    for (features, ), cache_size in zip(epoch, expected_cache_sizes):
        assert len(cached_stream.cache[0]) == cache_size
        checked += 1
    assert checked == len(expected_cache_sizes)

    # Make sure that the epoch finishes correctly. Materialise the epoch
    # rather than relying on the loop variable leaking out of a ``for``:
    # if the epoch were empty that would re-check the previous loop's data.
    data = list(cached_stream.get_epoch_iterator())
    assert data
    (features, ) = data[-1]
    assert len(features) == 100 % 7
    assert not cached_stream.cache[0]

    # Ensure that the epoch transition is correct. The batch counter is
    # reset for every epoch so that an empty second epoch cannot pass by
    # reusing the index left over from the first one.
    cached_stream = Cache(batched_stream, ConstantScheme(7, times=3))
    cache_sizes = [4, 8, 1]
    n_epochs = 0
    for _, epoch in zip(range(2), cached_stream.iterate_epochs()):
        n_batches = 0
        for i, (features, ) in enumerate(epoch):
            assert len(cached_stream.cache[0]) == cache_sizes[i]
            assert len(features) == 7
            expected = list(range(100))[i * 7:(i + 1) * 7]
            assert numpy.array_equal(expected, features)
            n_batches += 1
        assert n_batches == len(cache_sizes)
        n_epochs += 1
    assert n_epochs == 2
Exemplo n.º 6
0
    def test_epoch_finishes_correctly(self):
        """The last batch of an epoch holds the remainder; the cache empties.

        Checked on the fixture stream (presumably 100 examples, per the
        ``100 % 7`` expectation) re-batched in sevens, and on one oversized
        upstream batch of 3000 examples re-batched in sixty-fours.
        """
        def check_last_batch(upstream, request_size, num_examples):
            # Drain one full epoch through a fresh Cache wrapper.
            cache_stream = Cache(upstream, ConstantScheme(request_size))
            batches = list(cache_stream.get_epoch_iterator())
            last_features = batches[-1][0]
            assert_equal(len(last_features), num_examples % request_size)
            assert not cache_stream.cache[0]

        check_last_batch(self.stream, 7, 100)

        single_big_batch = Batch(DataStream(IterableDataset(range(3000))),
                                 ConstantScheme(3200))
        check_last_batch(single_big_batch, 64, 3000)
Exemplo n.º 7
0
 def test_epoch_finishes_correctly(self):
     """Draining an epoch yields the remainder last and empties the cache."""
     request_size = 7
     cache_stream = Cache(self.stream, ConstantScheme(request_size))
     batches = list(cache_stream.get_epoch_iterator())
     last_batch = batches[-1]
     assert_equal(len(last_batch[0]), 100 % request_size)
     # Nothing may remain buffered once the epoch has been fully consumed.
     assert not cache_stream.cache[0]
Exemplo n.º 8
0
 def test_cache_fills_correctly(self):
     """The cache length after each request follows the expected sequence.

     ``zip`` stops at the shorter input, so the number of checks is counted
     to make sure an epoch that ends early cannot pass silently.
     """
     cached_stream = Cache(self.stream, ConstantScheme(7))
     epoch = cached_stream.get_epoch_iterator()
     expected_cache_sizes = [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4]
     checked = 0
     for (features,), cache_size in zip(epoch, expected_cache_sizes):
         assert_equal(len(cached_stream.cache[0]), cache_size)
         checked += 1
     assert_equal(checked, len(expected_cache_sizes))
Exemplo n.º 9
0
 def test_axis_labels_passed_on_by_default(self):
     """``Cache`` exposes the upstream stream's axis labels by default."""
     labels = {'features': ('batch', 'index')}
     self.stream.axis_labels = labels
     wrapped = Cache(self.stream, ConstantScheme(7))
     assert_equal(wrapped.axis_labels, self.stream.axis_labels)
Exemplo n.º 10
0
 def test_value_error_on_none_request(self):
     """``Cache.get_data`` rejects a ``None`` request with ``ValueError``."""
     stream_under_test = Cache(self.stream, ConstantScheme(7))
     # Open an epoch before issuing the invalid request.
     stream_under_test.get_epoch_iterator()
     assert_raises(ValueError, stream_under_test.get_data, None)
Exemplo n.º 11
0
 def test_value_error_on_none_request(self):
     """``Cache.get_data`` rejects a ``None`` request with ``ValueError``."""
     stream_under_test = Cache(self.stream, ConstantScheme(7))
     # Open an epoch before issuing the invalid request.
     stream_under_test.get_epoch_iterator()
     assert_raises(ValueError, stream_under_test.get_data, None)
Exemplo n.º 12
0
 def test_epoch_finishes_correctly(self):
     """Draining an epoch yields the remainder last and empties the cache."""
     request_size = 7
     cache_stream = Cache(self.stream, ConstantScheme(request_size))
     batches = list(cache_stream.get_epoch_iterator())
     last_batch = batches[-1]
     assert_equal(len(last_batch[0]), 100 % request_size)
     # Nothing may remain buffered once the epoch has been fully consumed.
     assert not cache_stream.cache[0]
Exemplo n.º 13
0
 def test_cache_fills_correctly(self):
     """The cache length after each request follows the expected sequence.

     ``zip`` stops at the shorter input, so the number of checks is counted
     to make sure an epoch that ends early cannot pass silently.
     """
     cached_stream = Cache(self.stream, ConstantScheme(7))
     epoch = cached_stream.get_epoch_iterator()
     expected_cache_sizes = [4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 0, 4]
     checked = 0
     for (features,), cache_size in zip(epoch, expected_cache_sizes):
         assert_equal(len(cached_stream.cache[0]), cache_size)
         checked += 1
     assert_equal(checked, len(expected_cache_sizes))
Exemplo n.º 14
0
def get_stream(source,
               target,
               source_input_dict,
               target_label_dict,
               batch_size,
               buffer_multiplier=100,
               input_token_level='word',
               n_input_tokens=0,
               n_labels=0,
               reverse_labels=False,
               max_input_length=None,
               max_label_length=None,
               pad_labels=True,
               is_sort=True):
    """Returns a stream over padded minibatches of (input, label) pairs.

    Parameters
    ----------
    source : list
        A list of files to read the input token sequences from.
    target : list
        A list of corresponding files to read the label sequences from.
        Must have the same length as `source`.
    source_input_dict : str
        Path to a tab-delimited text file whose last column contains the
        input vocabulary.
    target_label_dict : str
        See `source_input_dict`. The label vocabulary is loaded with
        ``include_unk=False``.
    batch_size : int
        The minibatch size.
    buffer_multiplier : int
        The number of batches to load, concatenate, (optionally) sort by
        length of the input sequence, and split again; this makes batches
        more uniform in their sequence length and hence more
        computationally efficient.
    input_token_level : str
        Tokenisation level passed to `TextFile` for the input files
        (default ``'word'``). Label files are always read with
        ``level='word'``.
    n_input_tokens : int
        The number of entries in the input vocabulary. Pass 0 (default) to
        use the entire vocabulary.
    n_labels : int
        See `n_input_tokens`.
    reverse_labels : bool
        Passed as ``reverse`` to `load_dict` for the label vocabulary; when
        true the merged stream is additionally wrapped in `SortLabels`.
    max_input_length, max_label_length : int, optional
        Drop pairs whose input / label sequence has more than this many
        tokens plus one (the ``+ 1`` presumably allows for the appended EOS
        token). ``None`` or 0 disables the corresponding check.
    pad_labels : bool
        NOTE(review): currently unused -- labels are always padded and
        masked. Kept for interface compatibility.
    is_sort : bool
        Whether to sort each large batch by input length before splitting
        it into minibatches.

    Returns
    -------
    A `Padding` stream over the minibatches, with masks for the
    ``'source_input_tokens'`` and ``'target_labels'`` sources.

    Raises
    ------
    ValueError
        If `source` and `target` have different lengths.

    """
    if len(source) != len(target):
        raise ValueError("number of source and target files don't match")

    # Read the dictionaries
    dicts = [
        load_dict(source_input_dict, dict_size=n_input_tokens),
        load_dict(target_label_dict,
                  dict_size=n_labels,
                  reverse=reverse_labels,
                  include_unk=False)
    ]

    # Open the two sets of files and merge them
    streams = [
        TextFile(source,
                 dicts[0],
                 level=input_token_level,
                 bos_token=None,
                 eos_token=EOS_TOKEN,
                 encoding='utf-8').get_example_stream(),
        TextFile(target,
                 dicts[1],
                 level='word',
                 bos_token=None,
                 unk_token=None,
                 eos_token=EOS_TOKEN,
                 encoding='utf-8').get_example_stream()
    ]
    merged = Merge(streams, ('source_input_tokens', 'target_labels'))
    if reverse_labels:
        merged = SortLabels(merged)

    # Filter sentence lengths
    if max_input_length or max_label_length:

        def filter_pair(pair):
            src_input_tokens, trg_labels = pair
            src_input_ok = (not max_input_length) or \
                len(src_input_tokens) <= (max_input_length + 1)
            trg_label_ok = (not max_label_length) or \
                len(trg_labels) <= (max_label_length + 1)

            return src_input_ok and trg_label_ok

        merged = Filter(merged, filter_pair)

    # Load buffer_multiplier batches at once so that sorting by length can
    # group similarly sized sequences before they are split again.
    large_batches = Batch(merged,
                          iteration_scheme=ConstantScheme(batch_size *
                                                          buffer_multiplier))
    if is_sort:
        sorted_batches = Mapping(large_batches, SortMapping(_source_length))
    else:
        sorted_batches = large_batches
    batches = Cache(sorted_batches, ConstantScheme(batch_size))
    mask_sources = ('source_input_tokens', 'target_labels')
    masked_batches = Padding(batches, mask_sources=mask_sources)

    return masked_batches
Exemplo n.º 15
0
def get_stream(source,
               target,
               source_dict,
               target_dict,
               batch_size,
               buffer_multiplier=100,
               n_words_source=0,
               n_words_target=0,
               max_src_length=None,
               max_trg_length=None):
    """Build a shuffled, padded minibatch stream over parallel sentences.

    Parameters
    ----------
    source : list
        Files to read the source-language sentences from.
    target : list
        Files with the corresponding target-language sentences; must have
        the same length as `source`.
    source_dict : str
        Path to a tab-delimited text file whose last column contains the
        vocabulary.
    target_dict : str
        See `source_dict`.
    batch_size : int
        The minibatch size.
    buffer_multiplier : int
        How many minibatches are loaded together, sorted by source length
        and split again (making per-batch lengths more uniform). Also used
        as the `Shuffle` buffer argument.
    n_words_source : int
        Size of the source vocabulary; 0 (default) keeps all of it.
    n_words_target : int
        See `n_words_source`.
    max_src_length, max_trg_length : int, optional
        Keep only pairs whose source / target sequence is strictly shorter
        than this. ``None`` or 0 disables the corresponding check.

    Returns
    -------
    A `Padding` stream over the shuffled minibatches.

    Raises
    ------
    ValueError
        If `source` and `target` have different lengths.

    """
    if len(source) != len(target):
        raise ValueError("number of source and target files don't match")

    source_vocab = load_dict(source_dict, n_words=n_words_source)
    target_vocab = load_dict(target_dict, n_words=n_words_target)

    # One example stream per side, both configured with no BOS token and
    # EOS_TOKEN as the end marker.
    example_streams = [
        TextFile(files, vocab, bos_token=None,
                 eos_token=EOS_TOKEN).get_example_stream()
        for files, vocab in ((source, source_vocab), (target, target_vocab))
    ]
    pairs = Merge(example_streams, ('source', 'target'))

    if max_src_length or max_trg_length:

        def _short_enough(sequence, limit):
            # A falsy limit means "no limit" for this side.
            return not limit or len(sequence) < limit

        def filter_pair(pair):
            src, trg = pair
            return all([_short_enough(src, max_src_length),
                        _short_enough(trg, max_trg_length)])

        pairs = Filter(pairs, filter_pair)

    # Large batch -> sort by source length -> re-split -> shuffle -> pad.
    pipeline = Batch(pairs,
                     iteration_scheme=ConstantScheme(batch_size *
                                                     buffer_multiplier))
    pipeline = Mapping(pipeline, SortMapping(_source_length))
    pipeline = Cache(pipeline, ConstantScheme(batch_size))
    pipeline = Shuffle(pipeline, buffer_multiplier)
    return Padding(pipeline)