import pytest


def test_get_batch_indices():
    pytest.importorskip('mxnet')
    from sockeye import constants as C
    from sockeye import data_io

    max_bucket_size = 50
    batch_size = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, True, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_type=C.BATCH_TYPE_SENTENCE,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets=buckets,
                                                                 min_count=1,
                                                                 max_count=max_bucket_size))
    indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes)

    # check for valid indices: each (bucket index, start position) pair must
    # leave room for a full batch within its bucket
    for buck_idx, start_pos in indices:
        assert 0 <= buck_idx < len(dataset)
        assert 0 <= start_pos < len(dataset.source[buck_idx]) - batch_size + 1

    # check that all indices are used for a filled-up dataset
    dataset = dataset.fill_up(bucket_batch_sizes)
    indices = data_io.get_batch_indices(dataset, bucket_batch_sizes=bucket_batch_sizes)
    all_bucket_indices = set(range(len(dataset)))
    computed_bucket_indices = {i for i, j in indices}

    assert not all_bucket_indices - computed_bucket_indices
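# The test relies on a _get_random_bucketed_data helper defined elsewhere in this
# test module. Below is a minimal sketch of what such a helper could look like,
# assuming it returns a (source, target) pair of per-bucket MXNet NDArrays with a
# random number of rows in [min_count, max_count] per bucket; the shapes and the
# token-id range used here are illustrative assumptions, not the canonical
# implementation.
import random
from typing import List, Tuple

import mxnet as mx
import numpy as np


def _get_random_bucketed_data(buckets: List[Tuple[int, int]],
                              min_count: int,
                              max_count: int):
    # one random row count per (source_len, target_len) bucket
    counts = [random.randint(min_count, max_count) for _ in buckets]
    # random token ids with a random sequence length within each bucket's bounds
    source = [mx.nd.array(np.random.randint(0, 10, (count, random.randint(1, source_len))))
              for count, (source_len, _) in zip(counts, buckets)]
    target = [mx.nd.array(np.random.randint(0, 10, (count, random.randint(2, target_len))))
              for count, (_, target_len) in zip(counts, buckets)]
    return source, target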