Example #1
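This example exercises two construction-time checks in `dataset_providers.Task`: a `CacheDatasetPlaceholder` may appear at most once in the preprocessing pipeline, and a preprocessor with a `sequence_length` argument may not occur before the placeholder, since the sequence length is only specified at run time.
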
    def test_value_errors(self):
        dataset_fn = (
            lambda split, shuffle_files: tf.data.Dataset.from_tensors(["test"]))
        output_features = {
            "inputs":
            dataset_providers.Feature(test_utils.sentencepiece_vocab())
        }

        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "`CacheDatasetPlaceholder` can appear at most once in the "
                "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'."
        ):
            dataset_providers.Task(
                "multiple_cache_placeholders",
                source=dataset_providers.FunctionDataSource(
                    dataset_fn=dataset_fn, splits=["train", "validation"]),
                preprocessors=[
                    test_utils.test_text_preprocessor, preprocessors.tokenize,
                    dataset_providers.CacheDatasetPlaceholder(),
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder()
                ],
                output_features=output_features,
                metric_fns=[])

        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "'test_token_preprocessor' has a `sequence_length` argument but occurs "
                "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
                "is not allowed since the sequence length is specified at run time."
        ):
            dataset_providers.Task(
                "sequence_length_pre_cache",
                dataset_providers.FunctionDataSource(
                    dataset_fn=dataset_fn,
                    splits=["train"],
                ),
                preprocessors=[
                    test_utils.test_text_preprocessor, preprocessors.tokenize,
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder()
                ],
                output_features=output_features,
                metric_fns=[])

    def test_supports_caching(self):
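        # A Task supports caching iff a CacheDatasetPlaceholder appears among
        # its preprocessors.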
        self.assertFalse(
            dataset_providers.Task(
                "nosupports_cache",
                source=self.function_source,
                output_features=self.DEFAULT_OUTPUT_FEATURES,
                preprocessors=[]).supports_caching)

        self.assertFalse(
            dataset_providers.Task(
                "nosupports_cache",
                source=self.function_source,
                output_features=self.DEFAULT_OUTPUT_FEATURES,
                preprocessors=[preprocessors.tokenize]).supports_caching)

        self.assertTrue(
            dataset_providers.Task(
                "supports_cache",
                source=self.function_source,
                output_features=self.DEFAULT_OUTPUT_FEATURES,
                preprocessors=[
                    preprocessors.tokenize,
                    dataset_providers.CacheDatasetPlaceholder()
                ]).supports_caching)

        self.assertTrue(
            dataset_providers.Task(
                "supports_cache",
                source=self.function_source,
                output_features=self.DEFAULT_OUTPUT_FEATURES,
                preprocessors=[
                    dataset_providers.CacheDatasetPlaceholder(required=True),
                    preprocessors.tokenize,
                ]).supports_caching)

        self.assertTrue(
            dataset_providers.Task(
                "supports_cache",
                source=self.function_source,
                output_features=self.DEFAULT_OUTPUT_FEATURES,
                preprocessors=[
                    dataset_providers.CacheDatasetPlaceholder(),
                ]).supports_caching)

    def test_requires_caching(self):
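        # Caching becomes mandatory only when the placeholder is constructed
        # with required=True.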
        self.assertFalse(
            dataset_providers.Task(
                "nosupports_cache",
                output_features=self.DEFAULT_OUTPUT_FEATURES,
                source=self.function_source,
                preprocessors=[preprocessors.tokenize]).requires_caching)

        self.assertFalse(
            dataset_providers.Task(
                "supports_cache",
                output_features=self.DEFAULT_OUTPUT_FEATURES,
                source=self.function_source,
                preprocessors=[
                    preprocessors.tokenize,
                    dataset_providers.CacheDatasetPlaceholder()
                ]).requires_caching)

        task = dataset_providers.Task(
            "requires_cache",
            output_features=self.DEFAULT_OUTPUT_FEATURES,
            source=self.function_source,
            preprocessors=[
                dataset_providers.CacheDatasetPlaceholder(required=True),
                preprocessors.tokenize,
            ])

        self.assertTrue(task.requires_caching)

        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "Task 'requires_cache' requires caching, but was called with "
                "`use_cached=False`."):
            task.get_dataset({"inputs": 512, "targets": 512}, use_cached=False)

        # We haven't actually cached the task, so it still fails but with a
        # different error.
        with self.assertRaisesWithLiteralMatch(
                AssertionError,
                "'requires_cache' does not exist in any of the task cache "
                "directories."):
            task.get_dataset({"inputs": 512, "targets": 512}, use_cached=True)
def _materialize(output):
  task = dataset_providers.Task(
      "feature_validation_task",
      self.function_source,
      output_features=features,
      preprocessors=(lambda _: tf.data.Dataset.from_tensors(output),),
      metric_fns=[],
  )
  list(
      task.get_dataset(
          {"inputs": 13, "targets": 13}, "train", use_cached=False
      ).as_numpy_iterator()
  )

Example #5
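This helper builds a small classification `Task` from in-memory tensor slices: the preprocessor expands per-example length fields into token ranges, and the postprocessor maps string labels to class ids.
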
def _task_from_tensor_slices(name, tensor_slices, label_classes):
  return dataset_providers.Task(
      name,
      dataset_providers.FunctionDataSource(
          lambda split, shuffle_files:
          tf.data.Dataset.from_tensor_slices(tensor_slices),
          splits=("validation",)),
      preprocessors=[utils.map_over_dataset(lambda ex: {
          "inputs": tf.range(ex["inputs_lengths"]),
          "targets": tf.range(ex["targets_lengths"]),
          "targets_pretokenized": ex["targets_pretokenized"],
      })],
      postprocess_fn=functools.partial(
          _string_label_to_class_id_postprocessor,
          label_classes=label_classes),
      output_features={"inputs": dataset_providers.Feature(mock.Mock()),
                       "targets": dataset_providers.Feature(mock.Mock())}
  )
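
This test verifies that a Task constructed with `shuffle_buffer_size=None` rejects any attempt to shuffle:
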
    def test_disallow_shuffle(self):
        task = dataset_providers.Task(
            "no_shuffle",
            source=self.function_source,
            output_features=self.DEFAULT_OUTPUT_FEATURES,
            preprocessors=self.DEFAULT_PREPROCESSORS,
            shuffle_buffer_size=None)

        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "Shuffling is disallowed for Task 'no_shuffle' since its "
                "`shuffle_buffer_size` was set to `None` on construction."):
            task.get_dataset(None, shuffle=True)

        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "Shuffling is disallowed for Task 'no_shuffle' since its "
                "`shuffle_buffer_size` was set to `None` on construction."):
            task.get_dataset(None, shuffle=True, shuffle_buffer_size=100)

        task.get_dataset(None, shuffle=False)
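
For orientation, here is a minimal sketch of how these pieces fit together, assuming the same `dataset_providers`, `preprocessors`, and `test_utils` modules the tests above import (the task name is hypothetical):

import tensorflow as tf

dataset_fn = lambda split, shuffle_files: tf.data.Dataset.from_tensors(["test"])

task = dataset_providers.Task(
    "cache_sketch",  # hypothetical task name
    source=dataset_providers.FunctionDataSource(
        dataset_fn=dataset_fn, splits=["train"]),
    preprocessors=[
        test_utils.test_text_preprocessor,
        preprocessors.tokenize,
        # Exactly one placeholder is allowed, and any preprocessor that takes
        # a `sequence_length` argument must come after it, since the cache is
        # written before sequence lengths are known.
        dataset_providers.CacheDatasetPlaceholder(),
    ],
    output_features={
        "inputs": dataset_providers.Feature(test_utils.sentencepiece_vocab()),
    },
    metric_fns=[])

# Construction-time properties exercised by the tests above.
assert task.supports_caching
assert not task.requires_caching

# Without required=True on the placeholder, an uncached read is permitted,
# mirroring the get_dataset calls above.
ds = task.get_dataset({"inputs": 512, "targets": 512}, "train", use_cached=False)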