Example 1
0
 def dataset_fn(input_context):
   """Build a per-pipeline sharded dataset of numbers parsed from text files."""
   # Shard the file list so each input pipeline reads a disjoint subset of
   # files (shard by num_input_pipelines / input_pipeline_id).
   file_names = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
   sharded_files = file_names.shard(input_context.num_input_pipelines,
                                    input_context.input_pipeline_id)
   # Read each file line-by-line, parse lines as numbers, and batch with the
   # per-replica portion of a global batch size of 4.
   lines = readers.TextLineDatasetV2(sharded_files)
   numbers = lines.map(string_ops.string_to_number)
   return numbers.batch(input_context.get_per_replica_batch_size(4))
Example 2
0
class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
    """Tests for shuffle detection and distributed dataset function tracing."""

    @parameterized.named_parameters(
        # Each entry is (name, dataset-building callable[, expect_shuffled]).
        # A third element of True marks pipelines that contain a shuffle op.
        # pylint: disable=g-long-lambda
        ('Batch', lambda: dataset_ops.Dataset.range(5).batch(2)),
        ('Cache', lambda: dataset_ops.Dataset.range(5).cache()),
        ('Concatenate', lambda: dataset_ops.Dataset.range(5).concatenate(
            dataset_ops.Dataset.range(5))),
        ('FlatMap', lambda: dataset_ops.Dataset.range(5).flat_map(
            lambda _: dataset_ops.Dataset.from_tensors(0))),
        ('FlatMap_Shuffle', lambda: dataset_ops.Dataset.range(5).flat_map(
            lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1)), True),
        ('Filter',
         lambda: dataset_ops.Dataset.range(5).filter(lambda _: True)),
        ('FixedLengthRecordDatasetV2',
         lambda: readers.FixedLengthRecordDatasetV2([], 42)),
        ('FromTensors', lambda: dataset_ops.Dataset.from_tensors(0)),
        ('FromTensorSlices',
         lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0])),
        ('Interleave', lambda: dataset_ops.Dataset.range(5).interleave(
            lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1)),
        ('Interleave_Shuffle', lambda: dataset_ops.Dataset.range(5).interleave(
            lambda _: dataset_ops.Dataset.from_tensors(0).shuffle(1),
            cycle_length=1), True),
        ('Map', lambda: dataset_ops.Dataset.range(5).map(lambda x: x)),
        ('Options', lambda: dataset_ops.Dataset.range(5).with_options(
            dataset_ops.Options())),
        ('PaddedBatch',
         lambda: dataset_ops.Dataset.range(5).padded_batch(2, [])),
        ('ParallelInterleave', lambda: dataset_ops.Dataset.range(5).interleave(
            lambda _: dataset_ops.Dataset.from_tensors(0),
            cycle_length=1,
            num_parallel_calls=1)),
        ('ParallelMap', lambda: dataset_ops.Dataset.range(5).map(
            lambda x: x, num_parallel_calls=1)),
        ('Prefetch', lambda: dataset_ops.Dataset.range(5).prefetch(1)),
        ('Range', lambda: dataset_ops.Dataset.range(0)),
        ('Repeat', lambda: dataset_ops.Dataset.range(0).repeat(0)),
        ('Shuffle', lambda: dataset_ops.Dataset.range(5).shuffle(1), True),
        ('Skip', lambda: dataset_ops.Dataset.range(5).skip(2)),
        ('Take', lambda: dataset_ops.Dataset.range(5).take(2)),
        ('TextLineDataset', lambda: readers.TextLineDatasetV2([])),
        ('TFRecordDataset', lambda: readers.TFRecordDatasetV2([])),
        ('Window', lambda: dataset_ops.Dataset.range(5).window(2)),
        ('Zip', lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5))),
        # pylint: enable=g-long-lambda
    )
    def test_verify_dataset_shuffled(self, dataset_fn, expect_shuffled=False):
        """Checks verify_dataset_shuffled classifies each pipeline correctly.

        Args:
          dataset_fn: Zero-argument callable building the dataset under test.
          expect_shuffled: True when the pipeline contains a shuffle transform
            (anywhere, including inside flat_map/interleave sub-pipelines).
        """
        dataset = dataset_fn()

        if not expect_shuffled:
            # Unshuffled pipelines must return False AND log a warning
            # mentioning the unshuffled input.
            with test.mock.patch.object(logging, 'warning') as mock_log:
                shuffled = training_utils_v1.verify_dataset_shuffled(dataset)
                self.assertRegex(str(mock_log.call_args),
                                 'input dataset `x` is not shuffled.')
                self.assertFalse(shuffled)
        else:
            self.assertTrue(training_utils_v1.verify_dataset_shuffled(dataset))

    def testFromDatasetFileShardingDoesNotTriggerFunctionTracing(
            self, distribution, drop_remainder):
        """Verifies file-sharded distributed iteration retraces tf.function once.

        NOTE(review): `distribution` and `drop_remainder` are presumably
        injected by a `@combinations.generate` decorator not visible in this
        fragment — confirm against the original file. `_create_text_file` is
        a module-level helper defined elsewhere.
        """
        # Create files that produce partial/empty batches at different batch.
        fname1 = os.path.join(self.get_temp_dir(), "1.txt")
        _create_text_file(fname1, 5)
        fname2 = os.path.join(self.get_temp_dir(), "2.txt")
        _create_text_file(fname2, 9)

        # Counts how many times the tf.function below is (re)traced.
        self.trace_count = 0

        @def_function.function
        def f(v):
            del v
            self.trace_count += 1

        distribution.extended.experimental_enable_get_next_as_optional = True
        dataset = readers.TextLineDatasetV2([fname1, fname2]).batch(
            4, drop_remainder=drop_remainder)
        dataset = distribution.experimental_distribute_dataset(dataset)
        for v in iter(dataset):
            f(v)
        # Every per-step value must share one signature: exactly one trace.
        self.assertEqual(self.trace_count, 1)
Example 4
0
class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
    """Tests for assert_not_batched / assert_not_shuffled dataset checks."""

    @parameterized.named_parameters(
        # Each entry is (name, dataset-building callable[, expected_error]).
        # ValueError marks pipelines the checker must reject as batched (or
        # as opaque transforms it cannot see through, e.g. flat_map).
        # pylint: disable=g-long-lambda
        ('Batch', lambda: dataset_ops.Dataset.range(5).batch(2), ValueError),
        ('Cache', lambda: dataset_ops.Dataset.range(5).cache()),
        ('Concatenate', lambda: dataset_ops.Dataset.range(5).concatenate(
            dataset_ops.Dataset.range(5))),
        ('FlatMap', lambda: dataset_ops.Dataset.range(5).flat_map(
            lambda _: dataset_ops.Dataset.from_tensors(0)), ValueError),
        ('Filter',
         lambda: dataset_ops.Dataset.range(5).filter(lambda _: True)),
        ('FixedLengthRecordDatasetV2',
         lambda: readers.FixedLengthRecordDatasetV2([], 42)),
        ('FromTensors', lambda: dataset_ops.Dataset.from_tensors(0)),
        ('FromTensorSlices',
         lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0])),
        ('Interleave', lambda: dataset_ops.Dataset.range(5).interleave(
            lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1),
         ValueError),
        ('ParallelInterleave', lambda: dataset_ops.Dataset.range(5).interleave(
            lambda _: dataset_ops.Dataset.from_tensors(0),
            cycle_length=1,
            num_parallel_calls=1), ValueError),
        ('Map', lambda: dataset_ops.Dataset.range(5).map(lambda x: x)),
        ('Options', lambda: dataset_ops.Dataset.range(5).with_options(
            dataset_ops.Options())),
        ('PaddedBatch',
         lambda: dataset_ops.Dataset.range(5).padded_batch(2, []), ValueError),
        ('ParallelMap', lambda: dataset_ops.Dataset.range(5).map(
            lambda x: x, num_parallel_calls=1)),
        ('Prefetch', lambda: dataset_ops.Dataset.range(5).prefetch(1)),
        ('Range', lambda: dataset_ops.Dataset.range(0)),
        ('Repeat', lambda: dataset_ops.Dataset.range(0).repeat(0)),
        ('Shuffle', lambda: dataset_ops.Dataset.range(5).shuffle(1)),
        ('Skip', lambda: dataset_ops.Dataset.range(5).skip(2)),
        ('Take', lambda: dataset_ops.Dataset.range(5).take(2)),
        ('TextLineDataset', lambda: readers.TextLineDatasetV2([])),
        ('TFRecordDataset', lambda: readers.TFRecordDatasetV2([])),
        ('Window', lambda: dataset_ops.Dataset.range(5).window(2), ValueError),
        ('Zip', lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5))),
        # pylint: enable=g-long-lambda
    )
    def test_assert_not_batched(self, dataset_fn, expected_error=None):
        """assert_not_batched passes silently or raises the expected error.

        Args:
          dataset_fn: Zero-argument callable building the dataset under test.
          expected_error: Exception type the check should raise, or None if
            the dataset must be accepted as unbatched.
        """
        if expected_error is None:
            training_utils.assert_not_batched(dataset_fn())
        else:
            with self.assertRaises(expected_error):
                training_utils.assert_not_batched(dataset_fn())

    @parameterized.named_parameters(
        # Same (name, callable[, expected_error]) layout as above, but here
        # ValueError marks pipelines the checker must reject as shuffled (or
        # as transforms it cannot inspect).
        # pylint: disable=g-long-lambda
        ('Batch', lambda: dataset_ops.Dataset.range(5).batch(2)),
        ('Cache', lambda: dataset_ops.Dataset.range(5).cache()),
        ('Concatenate', lambda: dataset_ops.Dataset.range(5).concatenate(
            dataset_ops.Dataset.range(5))),
        ('FlatMap', lambda: dataset_ops.Dataset.range(5).flat_map(
            lambda _: dataset_ops.Dataset.from_tensors(0)), ValueError),
        ('Filter',
         lambda: dataset_ops.Dataset.range(5).filter(lambda _: True)),
        ('FixedLengthRecordDatasetV2',
         lambda: readers.FixedLengthRecordDatasetV2([], 42)),
        ('FromTensors', lambda: dataset_ops.Dataset.from_tensors(0)),
        ('FromTensorSlices',
         lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0])),
        ('Interleave', lambda: dataset_ops.Dataset.range(5).interleave(
            lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1),
         ValueError),
        ('Map', lambda: dataset_ops.Dataset.range(5).map(lambda x: x)),
        ('Options', lambda: dataset_ops.Dataset.range(5).with_options(
            dataset_ops.Options())),
        ('PaddedBatch',
         lambda: dataset_ops.Dataset.range(5).padded_batch(2, [])),
        ('ParallelInterleave', lambda: dataset_ops.Dataset.range(5).interleave(
            lambda _: dataset_ops.Dataset.from_tensors(0),
            cycle_length=1,
            num_parallel_calls=1), ValueError),
        ('ParallelMap', lambda: dataset_ops.Dataset.range(5).map(
            lambda x: x, num_parallel_calls=1)),
        ('Prefetch', lambda: dataset_ops.Dataset.range(5).prefetch(1)),
        ('Range', lambda: dataset_ops.Dataset.range(0)),
        ('Repeat', lambda: dataset_ops.Dataset.range(0).repeat(0)),
        ('Shuffle', lambda: dataset_ops.Dataset.range(5).shuffle(1),
         ValueError),
        ('Skip', lambda: dataset_ops.Dataset.range(5).skip(2)),
        ('Take', lambda: dataset_ops.Dataset.range(5).take(2)),
        ('TextLineDataset', lambda: readers.TextLineDatasetV2([])),
        ('TFRecordDataset', lambda: readers.TFRecordDatasetV2([])),
        ('Window', lambda: dataset_ops.Dataset.range(5).window(2)),
        ('Zip', lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5))),
        # pylint: enable=g-long-lambda
    )
    def test_assert_not_shuffled(self, dataset_fn, expected_error=None):
        """assert_not_shuffled passes silently or raises the expected error.

        Args:
          dataset_fn: Zero-argument callable building the dataset under test.
          expected_error: Exception type the check should raise, or None if
            the dataset must be accepted as unshuffled.
        """
        if expected_error is None:
            training_utils.assert_not_shuffled(dataset_fn())
        else:
            with self.assertRaises(expected_error):
                training_utils.assert_not_shuffled(dataset_fn())