import os
from tempfile import TemporaryDirectory

import numpy as np
import pytest

from sockeye import constants as C
from sockeye import data_io

# Imports above are inferred from usage in the tests below. The helper functions
# _get_random_bucketed_data and _data_batches_equal are assumed to be defined
# elsewhere in this test module.


def test_word_based_define_bucket_batch_sizes(length_ratio):
    batch_by_words = True
    batch_num_devices = 1
    batch_size = 200
    max_seq_len = 100
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, length_ratio)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets=buckets,
        batch_size=batch_size,
        batch_by_words=batch_by_words,
        batch_num_devices=batch_num_devices,
        data_target_average_len=[None] * len(buckets))
    max_num_words = 0
    # last bucket batch size is different
    for bbs in bucket_batch_sizes[:-1]:
        target_padded_seq_len = bbs.bucket[1]
        expected_batch_size = round((batch_size / target_padded_seq_len) / batch_num_devices)
        assert bbs.batch_size == expected_batch_size
        expected_average_words_per_batch = expected_batch_size * bbs.bucket[1]
        assert bbs.average_words_per_batch == expected_average_words_per_batch
        max_num_words = max(max_num_words, bbs.batch_size * max(*bbs.bucket))
    last_bbs = bucket_batch_sizes[-1]
    min_expected_batch_size = round((batch_size / last_bbs.bucket[1]) / batch_num_devices)
    assert last_bbs.batch_size >= min_expected_batch_size
    last_bbs_num_words = last_bbs.batch_size * max(*last_bbs.bucket)
    assert last_bbs_num_words >= max_num_words

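# The tests in this file rely on a module-level helper, _get_random_bucketed_data, that is
# not shown here. The sketch below is only a hypothetical illustration of what such a helper
# could look like, assuming it returns one random token-id array per bucket for the source
# and target sides; the real helper in the Sockeye test suite may differ (some test variants
# also expect a third `label` list and/or MXNet NDArrays instead of NumPy arrays).
def _get_random_bucketed_data_sketch(buckets, min_count, max_count, bucket_counts=None):
    import random
    if bucket_counts is None:
        bucket_counts = [None] * len(buckets)
    # For buckets without an explicit count, draw a random number of samples.
    counts = [random.randint(min_count, max_count) if count is None else count
              for count in bucket_counts]
    # One (num_samples, seq_len) array of random token ids per bucket.
    source = [np.random.randint(0, 10, (count, bucket[0])) for count, bucket in zip(counts, buckets)]
    target = [np.random.randint(0, 10, (count, bucket[1])) for count, bucket in zip(counts, buckets)]
    return source, target
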
def test_sharded_parallel_sample_iter_num_batches():
    num_shards = 2
    batch_size = 2
    num_batches_per_bucket = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    bucket_counts = [batch_size * num_batches_per_bucket for _ in buckets]
    num_batches_per_shard = num_batches_per_bucket * len(buckets)
    num_batches = num_shards * num_batches_per_shard
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_by_words=False, batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset1 = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))
    dataset2 = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))
    with TemporaryDirectory() as work_dir:
        shard1_fname = os.path.join(work_dir, 'shard1')
        shard2_fname = os.path.join(work_dir, 'shard2')
        dataset1.save(shard1_fname)
        dataset2.save(shard2_fname)
        shard_fnames = [shard1_fname, shard2_fname]

        it = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes)

        num_batches_seen = 0
        while it.iter_next():
            it.next()
            num_batches_seen += 1
        assert num_batches_seen == num_batches

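# _data_batches_equal is another helper assumed to be defined in the surrounding test module.
# The sketch below is hypothetical: it assumes a batch object exposes numpy-compatible
# `source` and `target` arrays; the real batches produced by the iterators may be different
# objects (e.g. MXNet DataBatch) with other attribute names.
def _data_batches_equal_sketch(batch1, batch2) -> bool:
    # Two batches are considered equal if their source and target contents match exactly.
    return (np.array_equal(np.asarray(batch1.source), np.asarray(batch2.source)) and
            np.array_equal(np.asarray(batch1.target), np.asarray(batch2.target)))
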
def test_parallel_sample_iter():
    batch_size = 2
    buckets = data_io.define_parallel_buckets(100, 100, 10, True, 1.0)
    # The first bucket is going to be empty:
    bucket_counts = [0] + [None] * (len(buckets) - 1)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_type=C.BATCH_TYPE_SENTENCE,
        data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))
    it = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
    with TemporaryDirectory() as work_dir:
        # Test 1
        it.next()
        expected_batch = it.next()

        fname = os.path.join(work_dir, "saved_iter")
        it.save_state(fname)

        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 2
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)

        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 3
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)

        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        while it.iter_next():
            it.next()
            it_loaded.next()
        assert not it_loaded.iter_next()

def test_sharded_parallel_sample_iter_num_batches():
    num_shards = 2
    batch_size = 2
    num_batches_per_bucket = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_counts = [batch_size * num_batches_per_bucket for _ in buckets]
    num_batches_per_shard = num_batches_per_bucket * len(buckets)
    num_batches = num_shards * num_batches_per_shard
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_by_words=False, batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset1 = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))
    dataset2 = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))
    with TemporaryDirectory() as work_dir:
        shard1_fname = os.path.join(work_dir, 'shard1')
        shard2_fname = os.path.join(work_dir, 'shard2')
        dataset1.save(shard1_fname)
        dataset2.save(shard2_fname)
        shard_fnames = [shard1_fname, shard2_fname]

        it = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes,
                                               'replicate')

        num_batches_seen = 0
        while it.iter_next():
            it.next()
            num_batches_seen += 1
        assert num_batches_seen == num_batches

def test_word_based_define_bucket_batch_sizes(length_ratio, batch_sentences_multiple_of, expected_batch_sizes):
    batch_by_words = True
    batch_num_devices = 1
    batch_size = 1000
    max_seq_len = 50
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, 1, length_ratio)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets=buckets,
        batch_size=batch_size,
        batch_by_words=batch_by_words,
        batch_num_devices=batch_num_devices,
        data_target_average_len=[None] * len(buckets),
        batch_sentences_multiple_of=batch_sentences_multiple_of)
    max_num_words = 0
    # last bucket batch size is different
    for bbs, expected_batch_size in zip(bucket_batch_sizes, expected_batch_sizes):
        assert bbs.batch_size == expected_batch_size
        expected_average_target_words_per_batch = expected_batch_size * bbs.bucket[1]
        assert bbs.average_target_words_per_batch == expected_average_target_words_per_batch
        max_num_words = max(max_num_words, bbs.batch_size * max(*bbs.bucket))
    last_bbs = bucket_batch_sizes[-1]
    min_expected_batch_size = round((batch_size / last_bbs.bucket[1]) / batch_num_devices)
    assert last_bbs.batch_size >= min_expected_batch_size
    last_bbs_num_words = last_bbs.batch_size * max(*last_bbs.bucket)
    assert last_bbs_num_words >= max_num_words

def test_define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width, length_ratio,
                                 expected_buckets):
    buckets = data_io.define_parallel_buckets(max_seq_len_source, max_seq_len_target,
                                              bucket_width=bucket_width, length_ratio=length_ratio)
    assert buckets == expected_buckets

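# Illustration (assumed behavior, not asserted in this excerpt): with bucket_width=10 and
# length_ratio=1.0, define_parallel_buckets(100, 100, ...) is expected to return a list of
# (source_length, target_length) pairs growing in steps of the bucket width, roughly
# [(10, 10), (20, 20), ..., (100, 100)].
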
def test_parallel_data_set_permute():
    batch_size = 5
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_by_words=False, batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5)).fill_up(bucket_batch_sizes, 'replicate')

    permutations, inverse_permutations = data_io.get_permutations(dataset.get_bucket_counts())
    assert len(permutations) == len(inverse_permutations) == len(dataset)

    dataset_restored = dataset.permute(permutations).permute(inverse_permutations)
    assert len(dataset) == len(dataset_restored)
    for buck_idx in range(len(dataset)):
        num_samples = dataset.source[buck_idx].shape[0]
        if num_samples:
            assert (dataset.source[buck_idx] == dataset_restored.source[buck_idx]).asnumpy().all()
            assert (dataset.target[buck_idx] == dataset_restored.target[buck_idx]).asnumpy().all()
            assert (dataset.label[buck_idx] == dataset_restored.label[buck_idx]).asnumpy().all()
        else:
            assert not dataset_restored.source[buck_idx]
            assert not dataset_restored.target[buck_idx]
            assert not dataset_restored.label[buck_idx]

def test_get_batch_indices():
    max_bucket_size = 50
    batch_size = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_by_words=False, batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets=buckets, min_count=1, max_count=max_bucket_size))

    indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes)

    # check for valid indices
    for buck_idx, start_pos in indices:
        assert 0 <= buck_idx < len(dataset)
        assert 0 <= start_pos < len(dataset.source[buck_idx]) - batch_size + 1

    # check that all indices are used for a filled-up dataset
    dataset = dataset.fill_up(bucket_batch_sizes)
    indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes)
    all_bucket_indices = set(list(range(len(dataset))))
    computed_bucket_indices = set([i for i, j in indices])

    assert not all_bucket_indices - computed_bucket_indices

def test_parallel_data_set_permute():
    batch_size = 5
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_by_words=False, batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5)).fill_up(bucket_batch_sizes)

    permutations, inverse_permutations = data_io.get_permutations(dataset.get_bucket_counts())
    assert len(permutations) == len(inverse_permutations) == len(dataset)

    dataset_restored = dataset.permute(permutations).permute(inverse_permutations)
    assert len(dataset) == len(dataset_restored)
    for buck_idx in range(len(dataset)):
        num_samples = dataset.source[buck_idx].shape[0]
        if num_samples:
            assert (dataset.source[buck_idx] == dataset_restored.source[buck_idx]).asnumpy().all()
            assert (dataset.target[buck_idx] == dataset_restored.target[buck_idx]).asnumpy().all()
        else:
            assert not dataset_restored.source[buck_idx]
            assert not dataset_restored.target[buck_idx]

def test_define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width, bucket_scaling,
                                 length_ratio, expected_buckets):
    pytest.importorskip('mxnet')
    from sockeye import data_io
    buckets = data_io.define_parallel_buckets(max_seq_len_source, max_seq_len_target,
                                              bucket_width=bucket_width,
                                              bucket_scaling=bucket_scaling,
                                              length_ratio=length_ratio)
    assert buckets == expected_buckets

def test_get_batch_indices():
    pytest.importorskip('mxnet')
    from sockeye import data_io
    max_bucket_size = 50
    batch_size = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, True, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_type=C.BATCH_TYPE_SENTENCE, batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets=buckets, min_count=1, max_count=max_bucket_size))

    indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes)

    # check for valid indices
    for buck_idx, start_pos in indices:
        assert 0 <= buck_idx < len(dataset)
        assert 0 <= start_pos < len(dataset.source[buck_idx]) - batch_size + 1

    # check that all indices are used for a filled-up dataset
    dataset = dataset.fill_up(bucket_batch_sizes)
    indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes)
    all_bucket_indices = set(list(range(len(dataset))))
    computed_bucket_indices = set([i for i, j in indices])

    assert not all_bucket_indices - computed_bucket_indices

def test_parallel_data_set_permute():
    pytest.importorskip('mxnet')
    from sockeye import data_io
    batch_size = 5
    buckets = data_io.define_parallel_buckets(100, 100, 10, True, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_type=C.BATCH_TYPE_SENTENCE, batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5)).fill_up(bucket_batch_sizes)

    permutations, inverse_permutations = data_io.get_permutations(dataset.get_bucket_counts())
    assert len(permutations) == len(inverse_permutations) == len(dataset)

    dataset_restored = dataset.permute(permutations).permute(inverse_permutations)
    assert len(dataset) == len(dataset_restored)
    for buck_idx in range(len(dataset)):
        num_samples = dataset.source[buck_idx].shape[0]
        if num_samples:
            assert (dataset.source[buck_idx] == dataset_restored.source[buck_idx]).all()
            assert (dataset.target[buck_idx] == dataset_restored.target[buck_idx]).all()
        else:
            assert not dataset_restored.source[buck_idx]
            assert not dataset_restored.target[buck_idx]

def test_word_based_define_bucket_batch_sizes(batch_num_devices, length_ratio, batch_sentences_multiple_of,
                                              expected_batch_sizes):
    pytest.importorskip('mxnet')
    from sockeye import data_io
    batch_type = C.BATCH_TYPE_WORD
    batch_size = 1000
    max_seq_len = 50
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, True, length_ratio)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets=buckets,
        batch_size=batch_size,
        batch_type=batch_type,
        batch_num_devices=batch_num_devices,
        data_target_average_len=[None] * len(buckets),
        batch_sentences_multiple_of=batch_sentences_multiple_of)
    max_num_words = 0
    # last bucket batch size is different
    for bbs, expected_batch_size in zip(bucket_batch_sizes, expected_batch_sizes):
        assert bbs.batch_size == expected_batch_size
        expected_average_target_words_per_batch = expected_batch_size * bbs.bucket[1]
        assert bbs.average_target_words_per_batch == expected_average_target_words_per_batch
        max_num_words = max(max_num_words, bbs.batch_size * max(*bbs.bucket))
    last_bbs = bucket_batch_sizes[-1]
    min_expected_batch_size = round((batch_size / last_bbs.bucket[1]) / batch_num_devices)
    assert last_bbs.batch_size >= min_expected_batch_size
    last_bbs_num_words = last_bbs.batch_size * max(*last_bbs.bucket)
    assert last_bbs_num_words >= max_num_words

def test_get_batch_indices():
    max_bucket_size = 50
    batch_size = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_by_words=False, batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets=buckets, min_count=1, max_count=max_bucket_size))

    indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes)

    # check for valid indices
    for buck_idx, start_pos in indices:
        assert 0 <= buck_idx < len(dataset)
        assert 0 <= start_pos < len(dataset.source[buck_idx]) - batch_size + 1

    # check that all indices are used for a filled-up dataset
    dataset = dataset.fill_up(bucket_batch_sizes, fill_up='replicate')
    indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes)
    all_bucket_indices = set(list(range(len(dataset))))
    computed_bucket_indices = set([i for i, j in indices])

    assert not all_bucket_indices - computed_bucket_indices

def test_parallel_sample_iter():
    batch_size = 2
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    # The first bucket is going to be empty:
    bucket_counts = [0] + [None] * (len(buckets) - 1)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_by_words=False, batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))
    it = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
    with TemporaryDirectory() as work_dir:
        # Test 1
        it.next()
        expected_batch = it.next()

        fname = os.path.join(work_dir, "saved_iter")
        it.save_state(fname)

        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 2
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)

        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 3
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)

        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        while it.iter_next():
            it.next()
            it_loaded.next()
        assert not it_loaded.iter_next()

def test_sharded_and_parallel_iter_same_num_batches():
    """
    Tests that a sharded data iterator with just a single shard produces as many batches as an
    iterator directly using the same dataset.
    """
    batch_size = 2
    num_batches_per_bucket = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_counts = [batch_size * num_batches_per_bucket for _ in buckets]
    num_batches = num_batches_per_bucket * len(buckets)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_by_words=False, batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))
    with TemporaryDirectory() as work_dir:
        shard_fname = os.path.join(work_dir, 'shard1')
        dataset.save(shard_fname)
        shard_fnames = [shard_fname]

        it_sharded = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes,
                                                       'replicate')

        it_parallel = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)

        num_batches_seen = 0
        while it_parallel.iter_next():
            assert it_sharded.iter_next()
            it_parallel.next()
            it_sharded.next()
            num_batches_seen += 1
        assert num_batches_seen == num_batches

        print("Resetting...")
        it_sharded.reset()
        it_parallel.reset()

        num_batches_seen = 0
        while it_parallel.iter_next():
            assert it_sharded.iter_next()
            it_parallel.next()
            it_sharded.next()
            num_batches_seen += 1
        assert num_batches_seen == num_batches

def test_sample_based_define_bucket_batch_sizes():
    batch_by_words = False
    batch_size = 32
    max_seq_len = 100
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, 1.5)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets=buckets,
        batch_size=batch_size,
        batch_by_words=batch_by_words,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    for bbs in bucket_batch_sizes:
        assert bbs.batch_size == batch_size
        assert bbs.average_words_per_batch == bbs.bucket[1] * batch_size

def test_sample_based_define_bucket_batch_sizes():
    batch_type = C.BATCH_TYPE_SENTENCE
    batch_size = 32
    max_seq_len = 100
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, True, 1.5)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets=buckets,
        batch_size=batch_size,
        batch_type=batch_type,
        data_target_average_len=[None] * len(buckets))
    for bbs in bucket_batch_sizes:
        assert bbs.batch_size == batch_size
        assert bbs.average_target_words_per_batch == bbs.bucket[1] * batch_size

def test_sample_based_define_bucket_batch_sizes():
    pytest.importorskip('mxnet')
    from sockeye import data_io
    batch_type = C.BATCH_TYPE_SENTENCE
    batch_size = 32
    max_seq_len = 100
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, 1, 1.5)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets=buckets,
        batch_size=batch_size,
        batch_type=batch_type,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    for bbs in bucket_batch_sizes:
        assert bbs.batch_size == batch_size
        assert bbs.average_target_words_per_batch == bbs.bucket[1] * batch_size

def test_max_word_based_define_bucket_batch_sizes(length_ratio, batch_sentences_multiple_of, expected_batch_sizes):
    batch_type = C.BATCH_TYPE_MAX_WORD
    batch_num_devices = 1
    batch_size = 1000
    max_seq_len = 50
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, True, length_ratio)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets=buckets,
        batch_size=batch_size,
        batch_type=batch_type,
        batch_num_devices=batch_num_devices,
        data_target_average_len=[None] * len(buckets),
        batch_sentences_multiple_of=batch_sentences_multiple_of)
    for bbs, expected_batch_size in zip(bucket_batch_sizes, expected_batch_sizes):
        assert bbs.batch_size == expected_batch_size
        expected_average_target_words_per_batch = expected_batch_size * bbs.bucket[1]
        assert bbs.average_target_words_per_batch == expected_average_target_words_per_batch

def test_parallel_data_set():
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    source, target = _get_random_bucketed_data(buckets, min_count=0, max_count=5)

    def check_equal(arrays1, arrays2):
        assert len(arrays1) == len(arrays2)
        for a1, a2 in zip(arrays1, arrays2):
            assert np.array_equal(a1.asnumpy(), a2.asnumpy())

    with TemporaryDirectory() as work_dir:
        dataset = data_io.ParallelDataSet(source, target)
        fname = os.path.join(work_dir, 'dataset')
        dataset.save(fname)
        dataset_loaded = data_io.ParallelDataSet.load(fname)
        check_equal(dataset.source, dataset_loaded.source)
        check_equal(dataset.target, dataset_loaded.target)

def test_parallel_data_set():
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    source, target, label = _get_random_bucketed_data(buckets, min_count=0, max_count=5)

    def check_equal(arrays1, arrays2):
        assert len(arrays1) == len(arrays2)
        for a1, a2 in zip(arrays1, arrays2):
            assert np.array_equal(a1.asnumpy(), a2.asnumpy())

    with TemporaryDirectory() as work_dir:
        dataset = data_io.ParallelDataSet(source, target, label)
        fname = os.path.join(work_dir, 'dataset')
        dataset.save(fname)
        dataset_loaded = data_io.ParallelDataSet.load(fname)
        check_equal(dataset.source, dataset_loaded.source)
        check_equal(dataset.target, dataset_loaded.target)
        check_equal(dataset.label, dataset_loaded.label)

def test_parallel_data_set_fill_up():
    batch_size = 32
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_type=C.BATCH_TYPE_SENTENCE, batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=1, max_count=5))

    dataset_filled_up = dataset.fill_up(bucket_batch_sizes)
    assert len(dataset_filled_up.source) == len(dataset.source)
    assert len(dataset_filled_up.target) == len(dataset.target)

    for bidx in range(len(dataset)):
        bucket_batch_size = bucket_batch_sizes[bidx].batch_size
        assert dataset_filled_up.source[bidx].shape[0] == bucket_batch_size
        assert dataset_filled_up.target[bidx].shape[0] == bucket_batch_size

def test_word_based_define_bucket_batch_sizes():
    batch_by_words = True
    batch_num_devices = 1
    batch_size = 200
    max_seq_len = 100
    buckets = data_io.define_parallel_buckets(max_seq_len, max_seq_len, 10, 1.5)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets=buckets,
        batch_size=batch_size,
        batch_by_words=batch_by_words,
        batch_num_devices=batch_num_devices,
        data_target_average_len=[None] * len(buckets))
    # last bucket batch size is different
    for bbs in bucket_batch_sizes[:-1]:
        expected_batch_size = round((batch_size / bbs.bucket[1]) / batch_num_devices)
        assert bbs.batch_size == expected_batch_size
        expected_average_words_per_batch = expected_batch_size * bbs.bucket[1]
        assert bbs.average_words_per_batch == expected_average_words_per_batch

def test_parallel_data_set_fill_up():
    batch_size = 32
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_by_words=False, batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=1, max_count=5))

    dataset_filled_up = dataset.fill_up(bucket_batch_sizes, 'replicate')
    assert len(dataset_filled_up.source) == len(dataset.source)
    assert len(dataset_filled_up.target) == len(dataset.target)
    assert len(dataset_filled_up.label) == len(dataset.label)

    for bidx in range(len(dataset)):
        bucket_batch_size = bucket_batch_sizes[bidx].batch_size
        assert dataset_filled_up.source[bidx].shape[0] == bucket_batch_size
        assert dataset_filled_up.target[bidx].shape[0] == bucket_batch_size
        assert dataset_filled_up.label[bidx].shape[0] == bucket_batch_size

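# Note on fill_up (assumed semantics): filling up pads each bucket so that its number of
# samples becomes a multiple of that bucket's batch size; the 'replicate' policy used in
# older API versions does this by duplicating randomly chosen samples from the same bucket,
# which is what the shape assertions above rely on.
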
def test_sharded_parallel_sample_iter():
    batch_size = 2
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    # The first bucket is going to be empty:
    bucket_counts = [0] + [None] * (len(buckets) - 1)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets, batch_size, batch_by_words=False, batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dataset1 = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))
    dataset2 = data_io.ParallelDataSet(*_get_random_bucketed_data(
        buckets, min_count=0, max_count=5, bucket_counts=bucket_counts))
    with TemporaryDirectory() as work_dir:
        shard1_fname = os.path.join(work_dir, 'shard1')
        shard2_fname = os.path.join(work_dir, 'shard2')
        dataset1.save(shard1_fname)
        dataset2.save(shard2_fname)
        shard_fnames = [shard1_fname, shard2_fname]

        it = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes,
                                               'replicate')

        # Test 1
        it.next()
        expected_batch = it.next()

        fname = os.path.join(work_dir, "saved_iter")
        it.save_state(fname)

        it_loaded = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes,
                                                      'replicate')
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 2
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)

        it_loaded = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes,
                                                      'replicate')
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 3
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)

        it_loaded = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes,
                                                      'replicate')
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        while it.iter_next():
            it.next()
            it_loaded.next()
        assert not it_loaded.iter_next()