Ejemplo n.º 1
0
class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
    """Tests for `training_utils_v1.verify_dataset_shuffled`."""

    @parameterized.named_parameters(
        # Each case is (name, dataset_fn[, expect_shuffled]); a missing third
        # element falls back to the `expect_shuffled=False` default below.
        # pylint: disable=g-long-lambda
        ('Batch', lambda: dataset_ops.Dataset.range(5).batch(2)),
        ('Cache', lambda: dataset_ops.Dataset.range(5).cache()),
        ('Concatenate', lambda: dataset_ops.Dataset.range(5).concatenate(
            dataset_ops.Dataset.range(5))),
        ('FlatMap', lambda: dataset_ops.Dataset.range(5).flat_map(
            lambda _: dataset_ops.Dataset.from_tensors(0))),
        ('FlatMap_Shuffle', lambda: dataset_ops.Dataset.range(5).flat_map(
            lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1)), True),
        ('Filter',
         lambda: dataset_ops.Dataset.range(5).filter(lambda _: True)),
        ('FixedLengthRecordDatasetV2',
         lambda: readers.FixedLengthRecordDatasetV2([], 42)),
        ('FromTensors', lambda: dataset_ops.Dataset.from_tensors(0)),
        ('FromTensorSlices',
         lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0])),
        ('Interleave', lambda: dataset_ops.Dataset.range(5).interleave(
            lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1)),
        ('Interleave_Shuffle', lambda: dataset_ops.Dataset.range(5).interleave(
            lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1),
            cycle_length=1), True),
        ('Map', lambda: dataset_ops.Dataset.range(5).map(lambda x: x)),
        ('Options', lambda: dataset_ops.Dataset.range(5).with_options(
            dataset_ops.Options())),
        ('PaddedBatch',
         lambda: dataset_ops.Dataset.range(5).padded_batch(2, [])),
        ('ParallelInterleave', lambda: dataset_ops.Dataset.range(5).interleave(
            lambda _: dataset_ops.Dataset.from_tensors(0),
            cycle_length=1,
            num_parallel_calls=1)),
        ('ParallelMap', lambda: dataset_ops.Dataset.range(5).map(
            lambda x: x, num_parallel_calls=1)),
        ('Prefetch', lambda: dataset_ops.Dataset.range(5).prefetch(1)),
        ('Range', lambda: dataset_ops.Dataset.range(0)),
        ('Repeat', lambda: dataset_ops.Dataset.range(0).repeat(0)),
        ('Shuffle', lambda: dataset_ops.Dataset.range(5).shuffle(1), True),
        ('Skip', lambda: dataset_ops.Dataset.range(5).skip(2)),
        ('Take', lambda: dataset_ops.Dataset.range(5).take(2)),
        ('TextLineDataset', lambda: readers.TextLineDatasetV2([])),
        ('TFRecordDataset', lambda: readers.TFRecordDatasetV2([])),
        ('Window', lambda: dataset_ops.Dataset.range(5).window(2)),
        ('Zip', lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5))),
        # pylint: enable=g-long-lambda
    )
    def test_verify_dataset_shuffled(self, dataset_fn, expect_shuffled=False):
        """Checks the shuffle detection result and the accompanying warning.

        Args:
          dataset_fn: Zero-argument callable that builds the dataset under test.
          expect_shuffled: True for the cases above whose pipeline calls
            `shuffle()` (directly, or inside a `flat_map`/`interleave`
            function); for those the helper must return True and must not
            emit the "not shuffled" warning. Otherwise it must return False
            and emit that warning.
        """
        # Escape the period so it matches literally instead of any character.
        not_shuffled_warning = r'input dataset `x` is not shuffled\.'
        dataset = dataset_fn()

        # Capture `logging.warning` in both branches so the shuffled case can
        # also verify that the "not shuffled" warning was not emitted.
        with test.mock.patch.object(logging, 'warning') as mock_log:
            shuffled = training_utils_v1.verify_dataset_shuffled(dataset)
        # Search every recorded warning call, not only the most recent one.
        logged_warnings = str(mock_log.call_args_list)

        if expect_shuffled:
            self.assertTrue(shuffled)
            self.assertNotRegex(logged_warnings, not_shuffled_warning)
        else:
            self.assertFalse(shuffled)
            self.assertRegex(logged_warnings, not_shuffled_warning)
Ejemplo n.º 2
0
class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
    """Tests for `training_utils.assert_not_batched` and `assert_not_shuffled`.

    Every parameterized case is `(name, dataset_fn[, expected_error])`:
    `dataset_fn` builds a small dataset pipeline, and `expected_error` is the
    exception type the helper under test must raise for it. When the third
    element is omitted, the helper must complete without raising.
    """

    def _check_assertion(self, assertion, dataset_fn, expected_error):
        """Runs `assertion` on a freshly built dataset and checks the outcome.

        Args:
          assertion: The helper under test; called with the built dataset.
          dataset_fn: Zero-argument callable that builds the dataset.
          expected_error: Exception type `assertion` must raise, or None when
            it must complete without raising.
        """
        if expected_error is None:
            assertion(dataset_fn())
            return
        # The dataset is built inside the context, mirroring the call site.
        with self.assertRaises(expected_error):
            assertion(dataset_fn())

    @parameterized.named_parameters(
        # pylint: disable=g-long-lambda
        ('Batch',
         lambda: dataset_ops.Dataset.range(5).batch(2),
         ValueError),
        ('Cache',
         lambda: dataset_ops.Dataset.range(5).cache()),
        ('Concatenate',
         lambda: dataset_ops.Dataset.range(5).concatenate(
             dataset_ops.Dataset.range(5))),
        ('FlatMap',
         lambda: dataset_ops.Dataset.range(5).flat_map(
             lambda _: dataset_ops.Dataset.from_tensors(0)),
         ValueError),
        ('Filter',
         lambda: dataset_ops.Dataset.range(5).filter(lambda _: True)),
        ('FixedLengthRecordDatasetV2',
         lambda: readers.FixedLengthRecordDatasetV2([], 42)),
        ('FromTensors',
         lambda: dataset_ops.Dataset.from_tensors(0)),
        ('FromTensorSlices',
         lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0])),
        ('Interleave',
         lambda: dataset_ops.Dataset.range(5).interleave(
             lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1),
         ValueError),
        ('ParallelInterleave',
         lambda: dataset_ops.Dataset.range(5).interleave(
             lambda _: dataset_ops.Dataset.from_tensors(0),
             cycle_length=1,
             num_parallel_calls=1),
         ValueError),
        ('Map',
         lambda: dataset_ops.Dataset.range(5).map(lambda x: x)),
        ('Options',
         lambda: dataset_ops.Dataset.range(5).with_options(
             dataset_ops.Options())),
        ('PaddedBatch',
         lambda: dataset_ops.Dataset.range(5).padded_batch(2, []),
         ValueError),
        ('ParallelMap',
         lambda: dataset_ops.Dataset.range(5).map(
             lambda x: x, num_parallel_calls=1)),
        ('Prefetch',
         lambda: dataset_ops.Dataset.range(5).prefetch(1)),
        ('Range',
         lambda: dataset_ops.Dataset.range(0)),
        ('Repeat',
         lambda: dataset_ops.Dataset.range(0).repeat(0)),
        ('Shuffle',
         lambda: dataset_ops.Dataset.range(5).shuffle(1)),
        ('Skip',
         lambda: dataset_ops.Dataset.range(5).skip(2)),
        ('Take',
         lambda: dataset_ops.Dataset.range(5).take(2)),
        ('TextLineDataset',
         lambda: readers.TextLineDatasetV2([])),
        ('TFRecordDataset',
         lambda: readers.TFRecordDatasetV2([])),
        ('Window',
         lambda: dataset_ops.Dataset.range(5).window(2),
         ValueError),
        ('Zip',
         lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5))),
        # pylint: enable=g-long-lambda
    )
    def test_assert_not_batched(self, dataset_fn, expected_error=None):
        """Checks `training_utils.assert_not_batched` against each pipeline."""
        self._check_assertion(
            training_utils.assert_not_batched, dataset_fn, expected_error)

    @parameterized.named_parameters(
        # pylint: disable=g-long-lambda
        ('Batch',
         lambda: dataset_ops.Dataset.range(5).batch(2)),
        ('Cache',
         lambda: dataset_ops.Dataset.range(5).cache()),
        ('Concatenate',
         lambda: dataset_ops.Dataset.range(5).concatenate(
             dataset_ops.Dataset.range(5))),
        ('FlatMap',
         lambda: dataset_ops.Dataset.range(5).flat_map(
             lambda _: dataset_ops.Dataset.from_tensors(0)),
         ValueError),
        ('Filter',
         lambda: dataset_ops.Dataset.range(5).filter(lambda _: True)),
        ('FixedLengthRecordDatasetV2',
         lambda: readers.FixedLengthRecordDatasetV2([], 42)),
        ('FromTensors',
         lambda: dataset_ops.Dataset.from_tensors(0)),
        ('FromTensorSlices',
         lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0])),
        ('Interleave',
         lambda: dataset_ops.Dataset.range(5).interleave(
             lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1),
         ValueError),
        ('Map',
         lambda: dataset_ops.Dataset.range(5).map(lambda x: x)),
        ('Options',
         lambda: dataset_ops.Dataset.range(5).with_options(
             dataset_ops.Options())),
        ('PaddedBatch',
         lambda: dataset_ops.Dataset.range(5).padded_batch(2, [])),
        ('ParallelInterleave',
         lambda: dataset_ops.Dataset.range(5).interleave(
             lambda _: dataset_ops.Dataset.from_tensors(0),
             cycle_length=1,
             num_parallel_calls=1),
         ValueError),
        ('ParallelMap',
         lambda: dataset_ops.Dataset.range(5).map(
             lambda x: x, num_parallel_calls=1)),
        ('Prefetch',
         lambda: dataset_ops.Dataset.range(5).prefetch(1)),
        ('Range',
         lambda: dataset_ops.Dataset.range(0)),
        ('Repeat',
         lambda: dataset_ops.Dataset.range(0).repeat(0)),
        ('Shuffle',
         lambda: dataset_ops.Dataset.range(5).shuffle(1),
         ValueError),
        ('Skip',
         lambda: dataset_ops.Dataset.range(5).skip(2)),
        ('Take',
         lambda: dataset_ops.Dataset.range(5).take(2)),
        ('TextLineDataset',
         lambda: readers.TextLineDatasetV2([])),
        ('TFRecordDataset',
         lambda: readers.TFRecordDatasetV2([])),
        ('Window',
         lambda: dataset_ops.Dataset.range(5).window(2)),
        ('Zip',
         lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5))),
        # pylint: enable=g-long-lambda
    )
    def test_assert_not_shuffled(self, dataset_fn, expected_error=None):
        """Checks `training_utils.assert_not_shuffled` against each pipeline."""
        self._check_assertion(
            training_utils.assert_not_shuffled, dataset_fn, expected_error)