Code example #1
File: test_data_io.py  Project: lagka/sockeye
def test_word_based_define_bucket_batch_sizes(length_ratio):
    batch_by_words = True
    batch_num_devices = 1
    batch_size = 200
    max_seq_len = 100
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, length_ratio)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets=buckets,
                                                           batch_size=batch_size,
                                                           batch_by_words=batch_by_words,
                                                           batch_num_devices=batch_num_devices,
                                                           data_target_average_len=[None] * len(buckets))
    max_num_words = 0
    # last bucket batch size is different
    for bbs in bucket_batch_sizes[:-1]:
        target_padded_seq_len = bbs.bucket[1]
        expected_batch_size = round((batch_size / target_padded_seq_len) / batch_num_devices)
        assert bbs.batch_size == expected_batch_size
        expected_average_words_per_batch = expected_batch_size * bbs.bucket[1]
        assert bbs.average_words_per_batch == expected_average_words_per_batch
        max_num_words = max(max_num_words, bbs.batch_size * max(*bbs.bucket))

    last_bbs = bucket_batch_sizes[-1]
    min_expected_batch_size = round((batch_size / last_bbs.bucket[1]) / batch_num_devices)
    assert last_bbs.batch_size >= min_expected_batch_size
    last_bbs_num_words = last_bbs.batch_size * max(*last_bbs.bucket)
    assert last_bbs_num_words >= max_num_words
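
These snippets are test functions excerpted from sockeye's test_data_io.py across several versions of the project. The excerpts omit the test module's header; a minimal sketch of the setup they assume is shown below. The exact import list varies between sockeye versions, so treat this as an assumption rather than the project's verbatim preamble.

# Assumed common header for these excerpts (varies by sockeye version).
import os
import random
from tempfile import TemporaryDirectory

import numpy as np
import pytest

from sockeye import constants as C  # e.g. C.BATCH_TYPE_SENTENCE, C.BATCH_TYPE_WORD
from sockeye import data_io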
Code example #2
def test_sharded_parallel_sample_iter_num_batches():
    num_shards = 2
    batch_size = 2
    num_batches_per_bucket = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    bucket_counts = [batch_size * num_batches_per_bucket for _ in buckets]
    num_batches_per_shard = num_batches_per_bucket * len(buckets)
    num_batches = num_shards * num_batches_per_shard
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets,
        batch_size,
        batch_by_words=False,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))

    dataset1 = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))
    dataset2 = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))
    with TemporaryDirectory() as work_dir:
        shard1_fname = os.path.join(work_dir, 'shard1')
        shard2_fname = os.path.join(work_dir, 'shard2')
        dataset1.save(shard1_fname)
        dataset2.save(shard2_fname)
        shard_fnames = [shard1_fname, shard2_fname]

        it = data_io.ShardedParallelSampleIter(shard_fnames, buckets,
                                               batch_size, bucket_batch_sizes)

        num_batches_seen = 0
        while it.iter_next():
            it.next()
            num_batches_seen += 1
        assert num_batches_seen == num_batches
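
Most of these tests build their fixture data with a module-level helper, _get_random_bucketed_data, which the excerpts do not show. The sketch below is an assumption based on how the tests call it (one random array per bucket for source and target, with optional fixed per-bucket counts); it is not sockeye's exact implementation, which differs between versions (older versions also return a third list of label arrays and use MXNet NDArrays rather than plain numpy arrays).

# Hypothetical sketch of the helper the excerpts rely on.
def _get_random_bucketed_data(buckets, min_count, max_count, bucket_counts=None):
    """Return random (source, target) arrays, one pair per bucket.

    buckets: list of (source_len, target_len) tuples.
    min_count/max_count: bounds for the random number of samples per bucket.
    bucket_counts: optional fixed per-bucket counts (None entries are drawn randomly).
    """
    if bucket_counts is None:
        bucket_counts = [None for _ in buckets]
    counts = [random.randint(min_count, max_count) if count is None else count
              for count in bucket_counts]
    source = [np.random.randint(0, 10, (count, source_len))
              for count, (source_len, _) in zip(counts, buckets)]
    target = [np.random.randint(0, 10, (count, target_len))
              for count, (_, target_len) in zip(counts, buckets)]
    return source, target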
Code example #3
def test_parallel_sample_iter():
    batch_size = 2
    buckets = data_io.define_parallel_buckets(100, 100, 10, True, 1.0)
    # The first bucket is going to be empty:
    bucket_counts = [0] + [None] * (len(buckets) - 1)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets,
        batch_size,
        batch_type=C.BATCH_TYPE_SENTENCE,
        data_target_average_len=[None] * len(buckets))

    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))
    it = data_io.ParallelSampleIter(dataset, buckets, batch_size,
                                    bucket_batch_sizes)

    with TemporaryDirectory() as work_dir:
        # Test 1
        it.next()
        expected_batch = it.next()

        fname = os.path.join(work_dir, "saved_iter")
        it.save_state(fname)

        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size,
                                               bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 2
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)

        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size,
                                               bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)

        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 3
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)
        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size,
                                               bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)

        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        while it.iter_next():
            it.next()
            it_loaded.next()
        assert not it_loaded.iter_next()
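
The iterator tests above and below compare batches with another helper, _data_batches_equal, that is also not part of the excerpts. A hedged sketch follows: it simply compares the source and target arrays of two batches, which is the general idea; the real helper depends on how a batch is represented in the sockeye version at hand (plain arrays, MXNet NDArrays, or an mx.io.DataBatch with data/label lists).

# Hypothetical sketch; assumes each batch exposes `source` and `target` arrays.
def _data_batches_equal(db1, db2) -> bool:
    def _to_numpy(a):
        # MXNet NDArrays expose asnumpy(); plain numpy arrays pass through.
        return a.asnumpy() if hasattr(a, "asnumpy") else np.asarray(a)

    return (np.array_equal(_to_numpy(db1.source), _to_numpy(db2.source)) and
            np.array_equal(_to_numpy(db1.target), _to_numpy(db2.target)))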
Code example #4
def test_word_based_define_bucket_batch_sizes(length_ratio):
    batch_by_words = True
    batch_num_devices = 1
    batch_size = 200
    max_seq_len = 100
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10,
                                              length_ratio)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets=buckets,
        batch_size=batch_size,
        batch_by_words=batch_by_words,
        batch_num_devices=batch_num_devices,
        data_target_average_len=[None] * len(buckets))
    max_num_words = 0
    # last bucket batch size is different
    for bbs in bucket_batch_sizes[:-1]:
        target_padded_seq_len = bbs.bucket[1]
        expected_batch_size = round(
            (batch_size / target_padded_seq_len) / batch_num_devices)
        assert bbs.batch_size == expected_batch_size
        expected_average_words_per_batch = expected_batch_size * bbs.bucket[1]
        assert bbs.average_words_per_batch == expected_average_words_per_batch
        max_num_words = max(max_num_words, bbs.batch_size * max(*bbs.bucket))

    last_bbs = bucket_batch_sizes[-1]
    min_expected_batch_size = round(
        (batch_size / last_bbs.bucket[1]) / batch_num_devices)
    assert last_bbs.batch_size >= min_expected_batch_size
    last_bbs_num_words = last_bbs.batch_size * max(*last_bbs.bucket)
    assert last_bbs_num_words >= max_num_words
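
Examples #1 and #4 are the same test with different formatting, and like several other excerpts here they take arguments (length_ratio, expected_batch_sizes, expected_buckets, ...) that the snippets never define. In the original module these come from @pytest.mark.parametrize decorators that the excerpts drop. The sketch below shows that wiring; the parameter values are hypothetical placeholders, not the cases used in sockeye's test suite.

# Hypothetical parametrization (placeholder values, for illustration only).
@pytest.mark.parametrize("length_ratio", [0.5, 1.0, 2.0])
def test_word_based_define_bucket_batch_sizes(length_ratio):
    ...  # body as shown in examples #1 and #4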
Code example #5
File: test_data_io.py  Project: lagka/sockeye
def test_sharded_parallel_sample_iter_num_batches():
    num_shards = 2
    batch_size = 2
    num_batches_per_bucket = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_counts = [batch_size * num_batches_per_bucket for _ in buckets]
    num_batches_per_shard = num_batches_per_bucket * len(buckets)
    num_batches = num_shards * num_batches_per_shard
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))

    dataset1 = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
                                                                  bucket_counts=bucket_counts))
    dataset2 = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
                                                                  bucket_counts=bucket_counts))
    with TemporaryDirectory() as work_dir:
        shard1_fname = os.path.join(work_dir, 'shard1')
        shard2_fname = os.path.join(work_dir, 'shard2')
        dataset1.save(shard1_fname)
        dataset2.save(shard2_fname)
        shard_fnames = [shard1_fname, shard2_fname]

        it = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes,
                                               'replicate')

        num_batches_seen = 0
        while it.iter_next():
            it.next()
            num_batches_seen += 1
        assert num_batches_seen == num_batches
Code example #6
def test_word_based_define_bucket_batch_sizes(length_ratio,
                                              batch_sentences_multiple_of,
                                              expected_batch_sizes):
    batch_by_words = True
    batch_num_devices = 1
    batch_size = 1000
    max_seq_len = 50
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, 1,
                                              length_ratio)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets=buckets,
        batch_size=batch_size,
        batch_by_words=batch_by_words,
        batch_num_devices=batch_num_devices,
        data_target_average_len=[None] * len(buckets),
        batch_sentences_multiple_of=batch_sentences_multiple_of)
    max_num_words = 0
    # last bucket batch size is different
    for bbs, expected_batch_size in zip(bucket_batch_sizes,
                                        expected_batch_sizes):
        assert bbs.batch_size == expected_batch_size
        expected_average_target_words_per_batch = expected_batch_size * bbs.bucket[1]
        assert bbs.average_target_words_per_batch == expected_average_target_words_per_batch
        max_num_words = max(max_num_words, bbs.batch_size * max(*bbs.bucket))

    last_bbs = bucket_batch_sizes[-1]
    min_expected_batch_size = round(
        (batch_size / last_bbs.bucket[1]) / batch_num_devices)
    assert last_bbs.batch_size >= min_expected_batch_size
    last_bbs_num_words = last_bbs.batch_size * max(*last_bbs.bucket)
    assert last_bbs_num_words >= max_num_words
Code example #7
def test_define_parallel_buckets(max_seq_len_source, max_seq_len_target,
                                 bucket_width, length_ratio, expected_buckets):
    buckets = data_io.define_parallel_buckets(max_seq_len_source,
                                              max_seq_len_target,
                                              bucket_width=bucket_width,
                                              length_ratio=length_ratio)
    assert buckets == expected_buckets
Code example #8
File: test_data_io.py  Project: lagka/sockeye
def test_parallel_data_set_permute():
    batch_size = 5
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5)).fill_up(
        bucket_batch_sizes, 'replicate')

    permutations, inverse_permutations = data_io.get_permutations(dataset.get_bucket_counts())

    assert len(permutations) == len(inverse_permutations) == len(dataset)
    dataset_restored = dataset.permute(permutations).permute(inverse_permutations)
    assert len(dataset) == len(dataset_restored)
    for buck_idx in range(len(dataset)):
        num_samples = dataset.source[buck_idx].shape[0]
        if num_samples:
            assert (dataset.source[buck_idx] == dataset_restored.source[buck_idx]).asnumpy().all()
            assert (dataset.target[buck_idx] == dataset_restored.target[buck_idx]).asnumpy().all()
            assert (dataset.label[buck_idx] == dataset_restored.label[buck_idx]).asnumpy().all()
        else:
            assert not dataset_restored.source[buck_idx]
            assert not dataset_restored.target[buck_idx]
            assert not dataset_restored.label[buck_idx]
Code example #9
def test_get_batch_indices():
    max_bucket_size = 50
    batch_size = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets,
        batch_size,
        batch_by_words=False,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets=buckets, min_count=1, max_count=max_bucket_size))

    indices = data_io.get_batch_indices(dataset,
                                        bucket_batch_sizes=bucket_batch_sizes)

    # check for valid indices
    for buck_idx, start_pos in indices:
        assert 0 <= buck_idx < len(dataset)
        assert 0 <= start_pos < len(dataset.source[buck_idx]) - batch_size + 1

    # check that all indices are used for a filled-up dataset
    dataset = dataset.fill_up(bucket_batch_sizes)
    indices = data_io.get_batch_indices(dataset,
                                        bucket_batch_sizes=bucket_batch_sizes)
    all_bucket_indices = set(list(range(len(dataset))))
    computed_bucket_indices = set([i for i, j in indices])

    assert not all_bucket_indices - computed_bucket_indices
Code example #10
def test_parallel_data_set_permute():
    batch_size = 5
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets,
        batch_size,
        batch_by_words=False,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5)).fill_up(bucket_batch_sizes)

    permutations, inverse_permutations = data_io.get_permutations(
        dataset.get_bucket_counts())

    assert len(permutations) == len(inverse_permutations) == len(dataset)
    dataset_restored = dataset.permute(permutations).permute(
        inverse_permutations)
    assert len(dataset) == len(dataset_restored)
    for buck_idx in range(len(dataset)):
        num_samples = dataset.source[buck_idx].shape[0]
        if num_samples:
            assert (dataset.source[buck_idx] ==
                    dataset_restored.source[buck_idx]).asnumpy().all()
            assert (dataset.target[buck_idx] ==
                    dataset_restored.target[buck_idx]).asnumpy().all()
        else:
            assert not dataset_restored.source[buck_idx]
            assert not dataset_restored.target[buck_idx]
Code example #11
def test_define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width, bucket_scaling, length_ratio,
                                 expected_buckets):
    pytest.importorskip('mxnet')
    from sockeye import data_io
    buckets = data_io.define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width=bucket_width,
                                              bucket_scaling=bucket_scaling, length_ratio=length_ratio)
    assert buckets == expected_buckets
Code example #12
def test_get_batch_indices():
    pytest.importorskip('mxnet')
    from sockeye import data_io
    max_bucket_size = 50
    batch_size = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, True, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_type=C.BATCH_TYPE_SENTENCE,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets=buckets,
                                                                 min_count=1,
                                                                 max_count=max_bucket_size))

    indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes)

    # check for valid indices
    for buck_idx, start_pos in indices:
        assert 0 <= buck_idx < len(dataset)
        assert 0 <= start_pos < len(dataset.source[buck_idx]) - batch_size + 1

    # check that all indices are used for a filled-up dataset
    dataset = dataset.fill_up(bucket_batch_sizes)
    indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes)
    all_bucket_indices = set(list(range(len(dataset))))
    computed_bucket_indices = set([i for i, j in indices])

    assert not all_bucket_indices - computed_bucket_indices
Code example #13
def test_parallel_data_set_permute():
    pytest.importorskip('mxnet')
    from sockeye import data_io
    batch_size = 5
    buckets = data_io.define_parallel_buckets(100, 100, 10, True, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_type=C.BATCH_TYPE_SENTENCE,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5)).fill_up(
        bucket_batch_sizes)

    permutations, inverse_permutations = data_io.get_permutations(dataset.get_bucket_counts())

    assert len(permutations) == len(inverse_permutations) == len(dataset)
    dataset_restored = dataset.permute(permutations).permute(inverse_permutations)
    assert len(dataset) == len(dataset_restored)
    for buck_idx in range(len(dataset)):
        num_samples = dataset.source[buck_idx].shape[0]
        if num_samples:
            assert (dataset.source[buck_idx] == dataset_restored.source[buck_idx]).all()
            assert (dataset.target[buck_idx] == dataset_restored.target[buck_idx]).all()
        else:
            assert not dataset_restored.source[buck_idx]
            assert not dataset_restored.target[buck_idx]
Code example #14
def test_word_based_define_bucket_batch_sizes(batch_num_devices, length_ratio, batch_sentences_multiple_of,
                                              expected_batch_sizes):
    pytest.importorskip('mxnet')
    from sockeye import data_io
    batch_type = C.BATCH_TYPE_WORD
    batch_size = 1000
    max_seq_len = 50
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, True, length_ratio)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets=buckets,
                                                           batch_size=batch_size,
                                                           batch_type=batch_type,
                                                           batch_num_devices=batch_num_devices,
                                                           data_target_average_len=[None] * len(buckets),
                                                           batch_sentences_multiple_of=batch_sentences_multiple_of)
    max_num_words = 0
    # last bucket batch size is different
    for bbs, expected_batch_size in zip(bucket_batch_sizes, expected_batch_sizes):
        assert bbs.batch_size == expected_batch_size
        expected_average_target_words_per_batch = expected_batch_size * bbs.bucket[1]
        assert bbs.average_target_words_per_batch == expected_average_target_words_per_batch
        max_num_words = max(max_num_words, bbs.batch_size * max(*bbs.bucket))

    last_bbs = bucket_batch_sizes[-1]
    min_expected_batch_size = round((batch_size / last_bbs.bucket[1]) / batch_num_devices)
    assert last_bbs.batch_size >= min_expected_batch_size
    last_bbs_num_words = last_bbs.batch_size * max(*last_bbs.bucket)
    assert last_bbs_num_words >= max_num_words
Code example #15
File: test_data_io.py  Project: lagka/sockeye
def test_get_batch_indices():
    max_bucket_size = 50
    batch_size = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets=buckets,
                                                                 min_count=1,
                                                                 max_count=max_bucket_size))

    indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes)

    # check for valid indices
    for buck_idx, start_pos in indices:
        assert 0 <= buck_idx < len(dataset)
        assert 0 <= start_pos < len(dataset.source[buck_idx]) - batch_size + 1

    # check that all indices are used for a filled-up dataset
    dataset = dataset.fill_up(bucket_batch_sizes, fill_up='replicate')
    indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes)
    all_bucket_indices = set(list(range(len(dataset))))
    computed_bucket_indices = set([i for i, j in indices])

    assert not all_bucket_indices - computed_bucket_indices
Code example #16
File: test_data_io.py  Project: lagka/sockeye
def test_parallel_sample_iter():
    batch_size = 2
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    # The first bucket is going to be empty:
    bucket_counts = [0] + [None] * (len(buckets) - 1)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))

    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
                                                                 bucket_counts=bucket_counts))
    it = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)

    with TemporaryDirectory() as work_dir:
        # Test 1
        it.next()
        expected_batch = it.next()

        fname = os.path.join(work_dir, "saved_iter")
        it.save_state(fname)

        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 2
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)

        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)

        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 3
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)
        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)

        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        while it.iter_next():
            it.next()
            it_loaded.next()
        assert not it_loaded.iter_next()
Code example #17
def test_sharded_and_parallel_iter_same_num_batches():
    """ Tests that a sharded data iterator with just a single shard produces as many shards as an iterator directly
    using the same dataset. """
    batch_size = 2
    num_batches_per_bucket = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_counts = [batch_size * num_batches_per_bucket for _ in buckets]
    num_batches = num_batches_per_bucket * len(buckets)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets,
        batch_size,
        batch_by_words=False,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))

    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))

    with TemporaryDirectory() as work_dir:
        shard_fname = os.path.join(work_dir, 'shard1')
        dataset.save(shard_fname)
        shard_fnames = [shard_fname]

        it_sharded = data_io.ShardedParallelSampleIter(shard_fnames, buckets,
                                                       batch_size,
                                                       bucket_batch_sizes,
                                                       'replicate')

        it_parallel = data_io.ParallelSampleIter(dataset, buckets, batch_size,
                                                 bucket_batch_sizes)

        num_batches_seen = 0
        while it_parallel.iter_next():
            assert it_sharded.iter_next()
            it_parallel.next()
            it_sharded.next()
            num_batches_seen += 1
        assert num_batches_seen == num_batches

        print("Resetting...")
        it_sharded.reset()
        it_parallel.reset()

        num_batches_seen = 0
        while it_parallel.iter_next():
            assert it_sharded.iter_next()
            it_parallel.next()
            it_sharded.next()

            num_batches_seen += 1

        assert num_batches_seen == num_batches
Code example #18
def test_sample_based_define_bucket_batch_sizes():
    batch_by_words = False
    batch_size = 32
    max_seq_len = 100
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, 1.5)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets=buckets,
                                                           batch_size=batch_size,
                                                           batch_by_words=batch_by_words,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    for bbs in bucket_batch_sizes:
        assert bbs.batch_size == batch_size
        assert bbs.average_words_per_batch == bbs.bucket[1] * batch_size
Code example #19
def test_sample_based_define_bucket_batch_sizes():
    batch_type = C.BATCH_TYPE_SENTENCE
    batch_size = 32
    max_seq_len = 100
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10,
                                              True, 1.5)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets=buckets,
        batch_size=batch_size,
        batch_type=batch_type,
        data_target_average_len=[None] * len(buckets))
    for bbs in bucket_batch_sizes:
        assert bbs.batch_size == batch_size
        assert bbs.average_target_words_per_batch == bbs.bucket[1] * batch_size
Code example #20
def test_sample_based_define_bucket_batch_sizes():
    pytest.importorskip('mxnet')
    from sockeye import data_io
    batch_type = C.BATCH_TYPE_SENTENCE
    batch_size = 32
    max_seq_len = 100
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, 1, 1.5)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets=buckets,
                                                           batch_size=batch_size,
                                                           batch_type=batch_type,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    for bbs in bucket_batch_sizes:
        assert bbs.batch_size == batch_size
        assert bbs.average_target_words_per_batch == bbs.bucket[1] * batch_size
Code example #21
File: test_data_io.py  Project: lagka/sockeye
def test_sharded_and_parallel_iter_same_num_batches():
    """ Tests that a sharded data iterator with just a single shard produces as many shards as an iterator directly
    using the same dataset. """
    batch_size = 2
    num_batches_per_bucket = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_counts = [batch_size * num_batches_per_bucket for _ in buckets]
    num_batches = num_batches_per_bucket * len(buckets)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))

    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
                                                                 bucket_counts=bucket_counts))

    with TemporaryDirectory() as work_dir:
        shard_fname = os.path.join(work_dir, 'shard1')
        dataset.save(shard_fname)
        shard_fnames = [shard_fname]

        it_sharded = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes,
                                                       'replicate')

        it_parallel = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)

        num_batches_seen = 0
        while it_parallel.iter_next():
            assert it_sharded.iter_next()
            it_parallel.next()
            it_sharded.next()
            num_batches_seen += 1
        assert num_batches_seen == num_batches

        print("Resetting...")
        it_sharded.reset()
        it_parallel.reset()

        num_batches_seen = 0
        while it_parallel.iter_next():
            assert it_sharded.iter_next()
            it_parallel.next()
            it_sharded.next()

            num_batches_seen += 1

        assert num_batches_seen == num_batches
Code example #22
def test_max_word_based_define_bucket_batch_sizes(length_ratio, batch_sentences_multiple_of, expected_batch_sizes):
    batch_type = C.BATCH_TYPE_MAX_WORD
    batch_num_devices = 1
    batch_size = 1000
    max_seq_len = 50
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, True, length_ratio)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets=buckets,
                                                           batch_size=batch_size,
                                                           batch_type=batch_type,
                                                           batch_num_devices=batch_num_devices,
                                                           data_target_average_len=[None] * len(buckets),
                                                           batch_sentences_multiple_of=batch_sentences_multiple_of)
    for bbs, expected_batch_size in zip(bucket_batch_sizes, expected_batch_sizes):
        assert bbs.batch_size == expected_batch_size
        expected_average_target_words_per_batch = expected_batch_size * bbs.bucket[1]
        assert bbs.average_target_words_per_batch == expected_average_target_words_per_batch
Code example #23
def test_parallel_data_set():
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    source, target = _get_random_bucketed_data(buckets, min_count=0, max_count=5)

    def check_equal(arrays1, arrays2):
        assert len(arrays1) == len(arrays2)
        for a1, a2 in zip(arrays1, arrays2):
            assert np.array_equal(a1.asnumpy(), a2.asnumpy())

    with TemporaryDirectory() as work_dir:
        dataset = data_io.ParallelDataSet(source, target)
        fname = os.path.join(work_dir, 'dataset')
        dataset.save(fname)
        dataset_loaded = data_io.ParallelDataSet.load(fname)
        check_equal(dataset.source, dataset_loaded.source)
        check_equal(dataset.target, dataset_loaded.target)
Code example #24
File: test_data_io.py  Project: lagka/sockeye
def test_parallel_data_set():
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    source, target, label = _get_random_bucketed_data(buckets, min_count=0, max_count=5)

    def check_equal(arrays1, arrays2):
        assert len(arrays1) == len(arrays2)
        for a1, a2 in zip(arrays1, arrays2):
            assert np.array_equal(a1.asnumpy(), a2.asnumpy())

    with TemporaryDirectory() as work_dir:
        dataset = data_io.ParallelDataSet(source, target, label)
        fname = os.path.join(work_dir, 'dataset')
        dataset.save(fname)
        dataset_loaded = data_io.ParallelDataSet.load(fname)
        check_equal(dataset.source, dataset_loaded.source)
        check_equal(dataset.target, dataset_loaded.target)
        check_equal(dataset.label, dataset_loaded.label)
Code example #25
def test_parallel_data_set_fill_up():
    batch_size = 32
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_type=C.BATCH_TYPE_SENTENCE,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=1, max_count=5))

    dataset_filled_up = dataset.fill_up(bucket_batch_sizes)
    assert len(dataset_filled_up.source) == len(dataset.source)
    assert len(dataset_filled_up.target) == len(dataset.target)
    for bidx in range(len(dataset)):
        bucket_batch_size = bucket_batch_sizes[bidx].batch_size
        assert dataset_filled_up.source[bidx].shape[0] == bucket_batch_size
        assert dataset_filled_up.target[bidx].shape[0] == bucket_batch_size
Code example #26
def test_word_based_define_bucket_batch_sizes():
    batch_by_words = True
    batch_num_devices = 1
    batch_size = 200
    max_seq_len = 100
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, 1.5)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets=buckets,
                                                           batch_size=batch_size,
                                                           batch_by_words=batch_by_words,
                                                           batch_num_devices=batch_num_devices,
                                                           data_target_average_len=[None] * len(buckets))
    # last bucket batch size is different
    for bbs in bucket_batch_sizes[:-1]:
        expected_batch_size = round((batch_size / bbs.bucket[1]) / batch_num_devices)
        assert bbs.batch_size == expected_batch_size
        expected_average_words_per_batch = expected_batch_size * bbs.bucket[1]
        assert bbs.average_words_per_batch == expected_average_words_per_batch
Code example #27
def test_parallel_data_set_fill_up():
    batch_size = 32
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=1, max_count=5))

    dataset_filled_up = dataset.fill_up(bucket_batch_sizes, 'replicate')
    assert len(dataset_filled_up.source) == len(dataset.source)
    assert len(dataset_filled_up.target) == len(dataset.target)
    assert len(dataset_filled_up.label) == len(dataset.label)
    for bidx in range(len(dataset)):
        bucket_batch_size = bucket_batch_sizes[bidx].batch_size
        assert dataset_filled_up.source[bidx].shape[0] == bucket_batch_size
        assert dataset_filled_up.target[bidx].shape[0] == bucket_batch_size
        assert dataset_filled_up.label[bidx].shape[0] == bucket_batch_size
Code example #28
def test_sharded_parallel_sample_iter():
    batch_size = 2
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    # The first bucket is going to be empty:
    bucket_counts = [0] + [None] * (len(buckets) - 1)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets,
        batch_size,
        batch_by_words=False,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))

    dataset1 = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))
    dataset2 = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))

    with TemporaryDirectory() as work_dir:
        shard1_fname = os.path.join(work_dir, 'shard1')
        shard2_fname = os.path.join(work_dir, 'shard2')
        dataset1.save(shard1_fname)
        dataset2.save(shard2_fname)
        shard_fnames = [shard1_fname, shard2_fname]

        it = data_io.ShardedParallelSampleIter(shard_fnames, buckets,
                                               batch_size, bucket_batch_sizes,
                                               'replicate')

        # Test 1
        it.next()
        expected_batch = it.next()

        fname = os.path.join(work_dir, "saved_iter")
        it.save_state(fname)

        it_loaded = data_io.ShardedParallelSampleIter(shard_fnames, buckets,
                                                      batch_size,
                                                      bucket_batch_sizes,
                                                      'replicate')
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 2
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)

        it_loaded = data_io.ShardedParallelSampleIter(shard_fnames, buckets,
                                                      batch_size,
                                                      bucket_batch_sizes,
                                                      'replicate')
        it_loaded.reset()
        it_loaded.load_state(fname)

        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 3
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)
        it_loaded = data_io.ShardedParallelSampleIter(shard_fnames, buckets,
                                                      batch_size,
                                                      bucket_batch_sizes,
                                                      'replicate')
        it_loaded.reset()
        it_loaded.load_state(fname)

        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        while it.iter_next():
            it.next()
            it_loaded.next()
        assert not it_loaded.iter_next()
Code example #29
File: test_data_io.py  Project: lagka/sockeye
def test_define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width, length_ratio, expected_buckets):
    buckets = data_io.define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width=bucket_width,
                                              length_ratio=length_ratio)
    assert buckets == expected_buckets