def test_num_epochs_with_mixture_fails(self):
  """Requesting a fixed number of epochs while mixing datasets must raise."""
  expected_message = (
      "Using fixed number of epochs is not allowed when mixing datasets.")
  with self.assertRaisesRegex(ValueError, expected_message):
    index_dataset.create_index_dataset([4, 8], num_epochs=2)
def test_sharding_no_drop_remainder(self):
  """Uneven sharding keeps the leftover records when not dropping remainder.

  8 records are split over 3 shards with ``sharding_drop_remainder=False``.
  Per the expected values below, shards 0 and 1 each own three record keys
  and shard 2 owns the remaining two. Within a shard the global index
  advances by ``num_shards`` and record keys wrap around once exhausted.
  """
  # (index, record_key) pairs for the first 4 elements of each shard,
  # listed in shard_id order.
  # pyformat: disable
  expected_by_shard = [
      [(0, 0), (3, 1), (6, 2), (9, 0)],
      [(1, 3), (4, 4), (7, 5), (10, 3)],
      [(2, 6), (5, 7), (8, 6), (11, 7)],
  ]
  # pyformat: enable
  for shard_id, expected_pairs in enumerate(expected_by_shard):
    dataset = index_dataset.create_index_dataset(
        8, shard_id=shard_id, num_shards=3, sharding_drop_remainder=False)
    values = list(dataset.take(4).as_numpy_iterator())
    expected = [{"index": index, "record_key": record_key}
                for index, record_key in expected_pairs]
    self.assertEqual(values, expected)
def test_determinism(self, records_per_dataset, proportions, shuffle: bool,
                     num_shards: int):
  """Building the same index dataset twice yields identical elements."""
  # A fixed seed is supplied only when shuffling; otherwise None is passed.
  seed = (3, 84) if shuffle else None
  creation_kwargs = {
      "proportions": proportions,
      "shuffle": shuffle,
      "seed": seed,
      "num_shards": num_shards,
  }
  runs = []
  for _ in range(2):
    dataset = index_dataset.create_index_dataset(records_per_dataset,
                                                 **creation_kwargs)
    runs.append(list(dataset.take(50).as_numpy_iterator()))
  first_run, second_run = runs
  self.assertAllEqual(first_run, second_run)
def test_start_index(self, records_per_dataset, proportions, shuffle: bool,
                     num_shards: int):
  """Starting at any index yields the matching suffix of the full dataset.

  The first 50 elements of a dataset created without ``start_index`` serve as
  the reference. For each ``start_index`` in [1, 30), a fresh dataset created
  with that ``start_index`` must reproduce ``reference[start_index:]``.
  """
  # A fixed seed is supplied only when shuffling; otherwise None is passed.
  seed = (3, 84) if shuffle else None
  dataset = index_dataset.create_index_dataset(
      records_per_dataset,
      proportions=proportions,
      shuffle=shuffle,
      seed=seed,
      num_shards=num_shards)
  all_values = list(dataset.take(50).as_numpy_iterator())
  for start_index in range(1, 30):
    # subTest makes a failure report which start_index broke.
    with self.subTest(start_index=start_index):
      dataset = index_dataset.create_index_dataset(
          records_per_dataset,
          proportions=proportions,
          shuffle=shuffle,
          seed=seed,
          num_shards=num_shards,
          start_index=start_index)
      values = list(dataset.take(50 - start_index).as_numpy_iterator())
      self.assertAllEqual(all_values[start_index:], values)
def test_num_epochs(self):
  """Setting ``num_epochs`` makes the index dataset finite."""
  dataset = index_dataset.create_index_dataset(4, num_epochs=2)
  # Request more elements than exist: only 2 epochs x 4 records come back.
  values = list(dataset.take(15).as_numpy_iterator())
  self.assertEqual(dataset.cardinality(), 8)
  # pyformat: disable
  self.assertEqual(
      values,
      [
          # First epoch.
          {"index": 0, "record_key": 0},
          {"index": 1, "record_key": 1},
          {"index": 2, "record_key": 2},
          {"index": 3, "record_key": 3},
          # Second epoch.
          {"index": 4, "record_key": 0},
          {"index": 5, "record_key": 1},
          {"index": 6, "record_key": 2},
          {"index": 7, "record_key": 3},
      ])
  # pyformat: enable
def test_mixing_and_sharding(self):
  """Mixing two datasets (4 and 6 records) split across 2 shards.

  Per the expected values, each shard sees every other global index,
  alternates between the two datasets, and cycles only through its own slice
  of each dataset's record keys.
  """
  # (index, record_key, dataset_id) triples for the first 10 elements of
  # each shard, listed in shard_id order.
  # pyformat: disable
  expected_by_shard = [
      [
          (0, 0, 0), (2, 0, 1), (4, 1, 0), (6, 1, 1),
          # Second epoch of first dataset starts.
          (8, 0, 0), (10, 2, 1), (12, 1, 0),
          # Second epoch of second dataset starts.
          (14, 0, 1), (16, 0, 0), (18, 1, 1),
      ],
      [
          (1, 2, 0), (3, 3, 1), (5, 3, 0), (7, 4, 1),
          # Second epoch of first dataset starts.
          (9, 2, 0), (11, 5, 1), (13, 3, 0),
          # Second epoch of second dataset starts.
          (15, 3, 1), (17, 2, 0), (19, 4, 1),
      ],
  ]
  # pyformat: enable
  for shard_id, expected_triples in enumerate(expected_by_shard):
    dataset = index_dataset.create_index_dataset(
        [4, 6], shard_id=shard_id, num_shards=2)
    values = list(dataset.take(10).as_numpy_iterator())
    expected = [
        {"index": index, "record_key": record_key, "dataset_id": dataset_id}
        for index, record_key, dataset_id in expected_triples
    ]
    self.assertEqual(values, expected)
def test_shuffle_and_sharding(self):
  """Shuffling a 6-record dataset that is split across 2 shards."""
  # NOTE(review): the shards are built with different seeds, (32, 73) for
  # shard 0 and (42, 73) for shard 1. The expected orders below are pinned to
  # those seeds; confirm whether a shared seed was intended before changing
  # either one.
  # pyformat: disable
  cases = [
      # (shard_id, seed, [(index, record_key), ...]) for the first 6 elements.
      (0, (32, 73), [(0, 2), (2, 0), (4, 1),
                     # Second epoch.
                     (6, 2), (8, 0), (10, 1)]),
      (1, (42, 73), [(1, 4), (3, 5), (5, 3),
                     # Second epoch.
                     (7, 3), (9, 4), (11, 5)]),
  ]
  # pyformat: enable
  for shard_id, seed, expected_pairs in cases:
    dataset = index_dataset.create_index_dataset(
        6, shuffle=True, seed=seed, shard_id=shard_id, num_shards=2)
    values = list(dataset.take(6).as_numpy_iterator())
    self.assertEqual(
        values,
        [{"index": index, "record_key": record_key}
         for index, record_key in expected_pairs])
def test_mixing_with_float_proportions(self):
  """Test mixing 2 datasets with 3 and 4 elements in proportions 0.2/0.8.

  With these proportions the expected stream takes one element from dataset 0
  followed by four from dataset 1, so dataset 1 wraps into new epochs well
  before dataset 0 does.
  """
  dataset = index_dataset.create_index_dataset([3, 4], proportions=[0.2, 0.8])
  values = list(dataset.take(16).as_numpy_iterator())
  # pyformat: disable
  self.assertEqual(
      values,
      [
          # First epoch for both datasets.
          {"index": 0, "record_key": 0, "dataset_id": 0},
          {"index": 1, "record_key": 0, "dataset_id": 1},
          {"index": 2, "record_key": 1, "dataset_id": 1},
          {"index": 3, "record_key": 2, "dataset_id": 1},
          {"index": 4, "record_key": 3, "dataset_id": 1},
          {"index": 5, "record_key": 1, "dataset_id": 0},
          # Second dataset is finished and starts second epoch.
          {"index": 6, "record_key": 0, "dataset_id": 1},
          {"index": 7, "record_key": 1, "dataset_id": 1},
          {"index": 8, "record_key": 2, "dataset_id": 1},
          {"index": 9, "record_key": 3, "dataset_id": 1},
          {"index": 10, "record_key": 2, "dataset_id": 0},
          # Second dataset starts third epoch.
          {"index": 11, "record_key": 0, "dataset_id": 1},
          {"index": 12, "record_key": 1, "dataset_id": 1},
          {"index": 13, "record_key": 2, "dataset_id": 1},
          {"index": 14, "record_key": 3, "dataset_id": 1},
          # First dataset is finished and starts second epoch.
          {"index": 15, "record_key": 0, "dataset_id": 0},
      ])
  # pyformat: enable
def test_mixing_equal_probability(self):
  """Test mixing 2 datasets with 4 and 6 elements with equal proportions.

  Without explicit proportions the expected stream alternates between
  dataset 0 and dataset 1. Each dataset restarts its record keys when
  exhausted.
  """
  dataset = index_dataset.create_index_dataset([4, 6])
  values = list(dataset.take(16).as_numpy_iterator())
  # pyformat: disable
  self.assertEqual(
      values,
      [
          # First epoch for both datasets.
          {"index": 0, "record_key": 0, "dataset_id": 0},
          {"index": 1, "record_key": 0, "dataset_id": 1},
          {"index": 2, "record_key": 1, "dataset_id": 0},
          {"index": 3, "record_key": 1, "dataset_id": 1},
          {"index": 4, "record_key": 2, "dataset_id": 0},
          {"index": 5, "record_key": 2, "dataset_id": 1},
          {"index": 6, "record_key": 3, "dataset_id": 0},
          {"index": 7, "record_key": 3, "dataset_id": 1},
          # First dataset is finished and starts second epoch.
          {"index": 8, "record_key": 0, "dataset_id": 0},
          {"index": 9, "record_key": 4, "dataset_id": 1},
          {"index": 10, "record_key": 1, "dataset_id": 0},
          {"index": 11, "record_key": 5, "dataset_id": 1},
          {"index": 12, "record_key": 2, "dataset_id": 0},
          # Second dataset also starts second epoch.
          {"index": 13, "record_key": 0, "dataset_id": 1},
          {"index": 14, "record_key": 3, "dataset_id": 0},
          {"index": 15, "record_key": 1, "dataset_id": 1},
      ])
  # pyformat: enable
def test_simple(self):
  """Plain indexing: no shuffling, no sharding, no mixing.

  A 6-record dataset keeps repeating record keys 0..5 while the global index
  keeps increasing.
  """
  dataset = index_dataset.create_index_dataset(6)
  values = list(dataset.take(15).as_numpy_iterator())
  # pyformat: disable
  self.assertEqual(
      values,
      [
          # First epoch.
          {"index": 0, "record_key": 0},
          {"index": 1, "record_key": 1},
          {"index": 2, "record_key": 2},
          {"index": 3, "record_key": 3},
          {"index": 4, "record_key": 4},
          {"index": 5, "record_key": 5},
          # Second epoch.
          {"index": 6, "record_key": 0},
          {"index": 7, "record_key": 1},
          {"index": 8, "record_key": 2},
          {"index": 9, "record_key": 3},
          {"index": 10, "record_key": 4},
          {"index": 11, "record_key": 5},
          # Third epoch.
          {"index": 12, "record_key": 0},
          {"index": 13, "record_key": 1},
          {"index": 14, "record_key": 2},
      ])
  # pyformat: enable