def test_sharded_parallel_sample_iter_num_batches():
    """Two shards, each holding `batches_per_bucket` full batches per bucket, must
    yield exactly num_shards * len(buckets) * batches_per_bucket batches overall."""
    num_shards = 2
    batch_size = 2
    batches_per_bucket = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    counts_per_bucket = [batch_size * batches_per_bucket for _ in buckets]
    expected_num_batches = num_shards * batches_per_bucket * len(buckets)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    with TemporaryDirectory() as work_dir:
        # Build and persist each shard with identically shaped random data.
        shard_fnames = []
        for shard_idx in range(num_shards):
            shard = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets,
                                                                       min_count=0,
                                                                       max_count=5,
                                                                       bucket_counts=counts_per_bucket))
            fname = os.path.join(work_dir, 'shard%d' % (shard_idx + 1))
            shard.save(fname)
            shard_fnames.append(fname)

        it = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes)

        # Drain the iterator and count every batch it produces.
        batches_seen = 0
        while it.iter_next():
            it.next()
            batches_seen += 1
        assert batches_seen == expected_num_batches
def test_sharded_and_parallel_iter_same_num_batches():
    """
    Tests that a sharded data iterator with just a single shard produces as many
    batches as an iterator directly using the same dataset.
    """
    batch_size = 2
    num_batches_per_bucket = 10
    # Fix: use the same 5-argument bucket definition as the other sharded-iterator
    # tests in this file; the previous 4-argument call (100, 100, 10, 1.0) bound
    # 1.0 to the wrong positional parameter.
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    bucket_counts = [batch_size * num_batches_per_bucket for _ in buckets]
    num_batches = num_batches_per_bucket * len(buckets)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets,
                                                                 min_count=0,
                                                                 max_count=5,
                                                                 bucket_counts=bucket_counts))
    with TemporaryDirectory() as work_dir:
        shard_fname = os.path.join(work_dir, 'shard1')
        dataset.save(shard_fname)
        shard_fnames = [shard_fname]

        it_sharded = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size,
                                                       bucket_batch_sizes, 'replicate')
        it_parallel = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)

        # Both iterators must advance in lockstep and yield the same batch count.
        num_batches_seen = 0
        while it_parallel.iter_next():
            assert it_sharded.iter_next()
            it_parallel.next()
            it_sharded.next()
            num_batches_seen += 1
        assert num_batches_seen == num_batches

        # After a reset, both iterators must again produce a full epoch in lockstep.
        print("Resetting...")
        it_sharded.reset()
        it_parallel.reset()
        num_batches_seen = 0
        while it_parallel.iter_next():
            assert it_sharded.iter_next()
            it_parallel.next()
            it_sharded.next()
            num_batches_seen += 1
        assert num_batches_seen == num_batches
def test_sharded_parallel_sample_iter():
    """Save/load-state round-trips of a two-shard iterator must resume at the
    exact batch where the state was saved (mid-epoch and right after a reset)."""
    batch_size = 2
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    # The first bucket is going to be empty:
    bucket_counts = [0] + [None] * (len(buckets) - 1)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    dataset1 = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets,
                                                                  min_count=0,
                                                                  max_count=5,
                                                                  bucket_counts=bucket_counts))
    dataset2 = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets,
                                                                  min_count=0,
                                                                  max_count=5,
                                                                  bucket_counts=bucket_counts))
    with TemporaryDirectory() as work_dir:
        shard1_fname = os.path.join(work_dir, 'shard1')
        shard2_fname = os.path.join(work_dir, 'shard2')
        dataset1.save(shard1_fname)
        dataset2.save(shard2_fname)
        shard_fnames = [shard1_fname, shard2_fname]

        it = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size,
                                               bucket_batch_sizes, 'replicate')

        def _save_load_and_compare():
            # Save `it`'s state after one batch, restore it into a fresh
            # iterator, and check the restored iterator yields the same next batch.
            expected_batch = it.next()
            fname = os.path.join(work_dir, "saved_iter")
            it.save_state(fname)
            it_loaded = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size,
                                                          bucket_batch_sizes, 'replicate')
            it_loaded.reset()
            it_loaded.load_state(fname)
            loaded_batch = it_loaded.next()
            assert _data_batches_equal(expected_batch, loaded_batch)
            return it_loaded

        # Test 1: save/load mid-epoch, after two batches have been consumed.
        it.next()
        it_loaded = _save_load_and_compare()

        # Tests 2 & 3: save/load immediately after a reset (exercised twice, as
        # reset may reshuffle shard/bucket order).
        for _ in range(2):
            it.reset()
            it_loaded = _save_load_and_compare()

        # Drain the original and the last restored iterator in lockstep;
        # both must be exhausted at the same point.
        while it.iter_next():
            it.next()
            it_loaded.next()
        assert not it_loaded.iter_next()