class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
  """Checks `verify_dataset_shuffled` across the dataset-op zoo.

  Each named case builds a dataset via a zero-arg factory; cases whose
  pipeline contains a shuffle anywhere pass `True` as the third tuple
  element so the test expects `verify_dataset_shuffled` to return True
  without warning.
  """

  @parameterized.named_parameters(
      # pylint: disable=g-long-lambda
      ('Batch', lambda: dataset_ops.Dataset.range(5).batch(2)),
      ('Cache', lambda: dataset_ops.Dataset.range(5).cache()),
      ('Concatenate', lambda: dataset_ops.Dataset.range(5).concatenate(
          dataset_ops.Dataset.range(5))),
      ('FlatMap', lambda: dataset_ops.Dataset.range(5).flat_map(
          lambda _: dataset_ops.Dataset.from_tensors(0))),
      ('FlatMap_Shuffle', lambda: dataset_ops.Dataset.range(5).flat_map(
          lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1)), True),
      ('Filter', lambda: dataset_ops.Dataset.range(5).filter(lambda _: True)),
      ('FixedLengthRecordDatasetV2',
       lambda: readers.FixedLengthRecordDatasetV2([], 42)),
      ('FromTensors', lambda: dataset_ops.Dataset.from_tensors(0)),
      ('FromTensorSlices',
       lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0])),
      ('Interleave', lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1)),
      ('Interleave_Shuffle', lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1),
          cycle_length=1), True),
      ('Map', lambda: dataset_ops.Dataset.range(5).map(lambda x: x)),
      ('Options', lambda: dataset_ops.Dataset.range(5).with_options(
          dataset_ops.Options())),
      ('PaddedBatch', lambda: dataset_ops.Dataset.range(5).padded_batch(2, [])),
      ('ParallelInterleave', lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1,
          num_parallel_calls=1)),
      ('ParallelMap', lambda: dataset_ops.Dataset.range(5).map(
          lambda x: x, num_parallel_calls=1)),
      ('Prefetch', lambda: dataset_ops.Dataset.range(5).prefetch(1)),
      ('Range', lambda: dataset_ops.Dataset.range(0)),
      ('Repeat', lambda: dataset_ops.Dataset.range(0).repeat(0)),
      ('Shuffle', lambda: dataset_ops.Dataset.range(5).shuffle(1), True),
      ('Skip', lambda: dataset_ops.Dataset.range(5).skip(2)),
      ('Take', lambda: dataset_ops.Dataset.range(5).take(2)),
      ('TextLineDataset', lambda: readers.TextLineDatasetV2([])),
      ('TFRecordDataset', lambda: readers.TFRecordDatasetV2([])),
      ('Window', lambda: dataset_ops.Dataset.range(5).window(2)),
      ('Zip', lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5))),
      # pylint: enable=g-long-lambda
  )
  def test_verify_dataset_shuffled(self, dataset_fn, expect_shuffled=False):
    """Shuffled pipelines verify True; unshuffled ones warn and return False."""
    ds = dataset_fn()

    if expect_shuffled:
      self.assertTrue(training_utils_v1.verify_dataset_shuffled(ds))
      return

    # Unshuffled input: the helper should log a warning mentioning `x`
    # and report False rather than raising.
    with test.mock.patch.object(logging, 'warning') as warn_mock:
      result = training_utils_v1.verify_dataset_shuffled(ds)
      self.assertRegex(
          str(warn_mock.call_args), 'input dataset `x` is not shuffled.')
    self.assertFalse(result)
# Renamed from `DatasetUtilsTest`: a second class with the identical name
# would shadow the first `DatasetUtilsTest` in this module, so its
# `test_verify_dataset_shuffled` cases were silently never collected.
class DatasetAssertUtilsTest(test.TestCase, parameterized.TestCase):
  """Checks `assert_not_batched` / `assert_not_shuffled` across dataset ops.

  Each named case supplies a zero-arg dataset factory plus, optionally,
  the error type the assertion is expected to raise for that pipeline.
  NOTE(review): these call `training_utils.*` while the shuffled-verify
  test above uses `training_utils_v1.*` — confirm the intended module.
  """

  def _check_assertion(self, assert_fn, dataset_fn, expected_error):
    """Runs `assert_fn` on the built dataset, expecting `expected_error` or none."""
    if expected_error is None:
      assert_fn(dataset_fn())
    else:
      with self.assertRaises(expected_error):
        assert_fn(dataset_fn())

  @parameterized.named_parameters(
      # pylint: disable=g-long-lambda
      ('Batch', lambda: dataset_ops.Dataset.range(5).batch(2), ValueError),
      ('Cache', lambda: dataset_ops.Dataset.range(5).cache()),
      ('Concatenate', lambda: dataset_ops.Dataset.range(5).concatenate(
          dataset_ops.Dataset.range(5))),
      ('FlatMap', lambda: dataset_ops.Dataset.range(5).flat_map(
          lambda _: dataset_ops.Dataset.from_tensors(0)), ValueError),
      ('Filter', lambda: dataset_ops.Dataset.range(5).filter(lambda _: True)),
      ('FixedLengthRecordDatasetV2',
       lambda: readers.FixedLengthRecordDatasetV2([], 42)),
      ('FromTensors', lambda: dataset_ops.Dataset.from_tensors(0)),
      ('FromTensorSlices',
       lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0])),
      ('Interleave', lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0),
          cycle_length=1), ValueError),
      ('ParallelInterleave', lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1,
          num_parallel_calls=1), ValueError),
      ('Map', lambda: dataset_ops.Dataset.range(5).map(lambda x: x)),
      ('Options', lambda: dataset_ops.Dataset.range(5).with_options(
          dataset_ops.Options())),
      ('PaddedBatch', lambda: dataset_ops.Dataset.range(5).padded_batch(
          2, []), ValueError),
      ('ParallelMap', lambda: dataset_ops.Dataset.range(5).map(
          lambda x: x, num_parallel_calls=1)),
      ('Prefetch', lambda: dataset_ops.Dataset.range(5).prefetch(1)),
      ('Range', lambda: dataset_ops.Dataset.range(0)),
      ('Repeat', lambda: dataset_ops.Dataset.range(0).repeat(0)),
      ('Shuffle', lambda: dataset_ops.Dataset.range(5).shuffle(1)),
      ('Skip', lambda: dataset_ops.Dataset.range(5).skip(2)),
      ('Take', lambda: dataset_ops.Dataset.range(5).take(2)),
      ('TextLineDataset', lambda: readers.TextLineDatasetV2([])),
      ('TFRecordDataset', lambda: readers.TFRecordDatasetV2([])),
      ('Window', lambda: dataset_ops.Dataset.range(5).window(2), ValueError),
      ('Zip', lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5))),
      # pylint: enable=g-long-lambda
  )
  def test_assert_not_batched(self, dataset_fn, expected_error=None):
    """Batching ops (batch/padded_batch/window, and batching via map-like
    transforms whose inner structure is opaque) must raise ValueError."""
    self._check_assertion(
        training_utils.assert_not_batched, dataset_fn, expected_error)

  @parameterized.named_parameters(
      # pylint: disable=g-long-lambda
      ('Batch', lambda: dataset_ops.Dataset.range(5).batch(2)),
      ('Cache', lambda: dataset_ops.Dataset.range(5).cache()),
      ('Concatenate', lambda: dataset_ops.Dataset.range(5).concatenate(
          dataset_ops.Dataset.range(5))),
      ('FlatMap', lambda: dataset_ops.Dataset.range(5).flat_map(
          lambda _: dataset_ops.Dataset.from_tensors(0)), ValueError),
      ('Filter', lambda: dataset_ops.Dataset.range(5).filter(lambda _: True)),
      ('FixedLengthRecordDatasetV2',
       lambda: readers.FixedLengthRecordDatasetV2([], 42)),
      ('FromTensors', lambda: dataset_ops.Dataset.from_tensors(0)),
      ('FromTensorSlices',
       lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0])),
      ('Interleave', lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0),
          cycle_length=1), ValueError),
      ('Map', lambda: dataset_ops.Dataset.range(5).map(lambda x: x)),
      ('Options', lambda: dataset_ops.Dataset.range(5).with_options(
          dataset_ops.Options())),
      ('PaddedBatch', lambda: dataset_ops.Dataset.range(5).padded_batch(2, [])),
      ('ParallelInterleave', lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1,
          num_parallel_calls=1), ValueError),
      ('ParallelMap', lambda: dataset_ops.Dataset.range(5).map(
          lambda x: x, num_parallel_calls=1)),
      ('Prefetch', lambda: dataset_ops.Dataset.range(5).prefetch(1)),
      ('Range', lambda: dataset_ops.Dataset.range(0)),
      ('Repeat', lambda: dataset_ops.Dataset.range(0).repeat(0)),
      ('Shuffle', lambda: dataset_ops.Dataset.range(5).shuffle(1), ValueError),
      ('Skip', lambda: dataset_ops.Dataset.range(5).skip(2)),
      ('Take', lambda: dataset_ops.Dataset.range(5).take(2)),
      ('TextLineDataset', lambda: readers.TextLineDatasetV2([])),
      ('TFRecordDataset', lambda: readers.TFRecordDatasetV2([])),
      ('Window', lambda: dataset_ops.Dataset.range(5).window(2)),
      ('Zip', lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5))),
      # pylint: enable=g-long-lambda
  )
  def test_assert_not_shuffled(self, dataset_fn, expected_error=None):
    """Shuffling ops (shuffle, and opaque flat_map/interleave inner funcs)
    must raise ValueError; everything else passes cleanly."""
    self._check_assertion(
        training_utils.assert_not_shuffled, dataset_fn, expected_error)