def test_value_errors(self):
  """Invalid preprocessor pipelines raise ValueError at Task construction."""
  fake_dataset_fn = (
      lambda split, shuffle_files: tf.data.Dataset.from_tensors(["test"]))
  features = {
      "inputs": dataset_providers.Feature(test_utils.sentencepiece_vocab())
  }

  # At most one `CacheDatasetPlaceholder` may appear in the pipeline.
  with self.assertRaisesWithLiteralMatch(
      ValueError,
      "`CacheDatasetPlaceholder` can appear at most once in the "
      "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'."):
    dataset_providers.Task(
        "multiple_cache_placeholders",
        source=dataset_providers.FunctionDataSource(
            dataset_fn=fake_dataset_fn, splits=["train", "validation"]),
        preprocessors=[
            test_utils.test_text_preprocessor,
            preprocessors.tokenize,
            dataset_providers.CacheDatasetPlaceholder(),
            test_utils.test_token_preprocessor,
            dataset_providers.CacheDatasetPlaceholder(),
        ],
        output_features=features,
        metric_fns=[])

  # A preprocessor that takes `sequence_length` may not run before the cache
  # placeholder, because caching happens before run time.
  with self.assertRaisesWithLiteralMatch(
      ValueError,
      "'test_token_preprocessor' has a `sequence_length` argument but occurs "
      "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
      "is not allowed since the sequence length is specified at run time."):
    dataset_providers.Task(
        "sequence_length_pre_cache",
        dataset_providers.FunctionDataSource(
            dataset_fn=fake_dataset_fn, splits=["train"]),
        preprocessors=[
            test_utils.test_text_preprocessor,
            preprocessors.tokenize,
            test_utils.test_token_preprocessor,
            dataset_providers.CacheDatasetPlaceholder(),
        ],
        output_features=features,
        metric_fns=[])
def test_supports_caching(self):
  """`supports_caching` is True iff a `CacheDatasetPlaceholder` is present."""
  # (task name, preprocessors, expected supports_caching) triples covering
  # the same five configurations as before.
  cases = [
      ("nosupports_cache", [], False),
      ("nosupports_cache", [preprocessors.tokenize], False),
      ("supports_cache",
       [preprocessors.tokenize,
        dataset_providers.CacheDatasetPlaceholder()], True),
      ("supports_cache",
       [dataset_providers.CacheDatasetPlaceholder(required=True),
        preprocessors.tokenize], True),
      ("supports_cache",
       [dataset_providers.CacheDatasetPlaceholder()], True),
  ]
  for task_name, task_preprocessors, expected in cases:
    task = dataset_providers.Task(
        task_name,
        source=self.function_source,
        output_features=self.DEFAULT_OUTPUT_FEATURES,
        preprocessors=task_preprocessors)
    # bool(...) keeps the truthiness semantics of assertTrue/assertFalse.
    self.assertEqual(expected, bool(task.supports_caching))
def test_requires_caching(self):
  """`requires_caching` holds only for `CacheDatasetPlaceholder(required=True)`."""
  def build_task(task_name, task_preprocessors):
    # Local factory to avoid repeating the common Task kwargs.
    return dataset_providers.Task(
        task_name,
        output_features=self.DEFAULT_OUTPUT_FEATURES,
        source=self.function_source,
        preprocessors=task_preprocessors)

  self.assertFalse(
      build_task("nosupports_cache",
                 [preprocessors.tokenize]).requires_caching)
  self.assertFalse(
      build_task("supports_cache", [
          preprocessors.tokenize,
          dataset_providers.CacheDatasetPlaceholder()
      ]).requires_caching)

  task = build_task("requires_cache", [
      dataset_providers.CacheDatasetPlaceholder(required=True),
      preprocessors.tokenize,
  ])
  self.assertTrue(task.requires_caching)

  # Asking for the uncached dataset must fail loudly.
  with self.assertRaisesWithLiteralMatch(
      ValueError,
      "Task 'requires_cache' requires caching, but was called with "
      "`use_cached=False`."):
    task.get_dataset({"inputs": 512, "targets": 512}, use_cached=False)

  # We haven't actually cached the task, so it still fails but with a
  # different error.
  with self.assertRaisesWithLiteralMatch(
      AssertionError,
      "'requires_cache' does not exist in any of the task cache "
      "directories."):
    task.get_dataset({"inputs": 512, "targets": 512}, use_cached=True)
def _materialize(output):
  """Builds a throwaway Task that emits `output` and fully evaluates it.

  Relies on `self` and `features` from the enclosing test scope; forcing
  the dataset through `as_numpy_iterator` surfaces any feature-validation
  errors.
  """
  task = dataset_providers.Task(
      "feature_validation_task",
      self.function_source,
      output_features=features,
      preprocessors=(lambda _: tf.data.Dataset.from_tensors(output),),
      metric_fns=[],
  )
  ds = task.get_dataset(
      {"inputs": 13, "targets": 13}, "train", use_cached=False)
  # Drain the iterator; the list itself is discarded.
  _ = list(ds.as_numpy_iterator())
def _task_from_tensor_slices(name, tensor_slices, label_classes):
  """Builds a Task over in-memory `tensor_slices` with a class-id postprocessor.

  Args:
    name: str, the task name.
    tensor_slices: structure accepted by `tf.data.Dataset.from_tensor_slices`;
      each example must provide "inputs_lengths", "targets_lengths" and
      "targets_pretokenized" keys (see the map below).
    label_classes: sequence of string labels mapped to class ids by the
      postprocessor.

  Returns:
    A `dataset_providers.Task` with mock vocabularies for both features.
  """
  return dataset_providers.Task(
      name,
      dataset_providers.FunctionDataSource(
          lambda split, shuffle_files: tf.data.Dataset.from_tensor_slices(
              tensor_slices),
          # BUG FIX: `("validation")` is just the string "validation" —
          # parentheses alone don't make a tuple. A one-element tuple needs
          # the trailing comma so `splits` is a collection of split names.
          splits=("validation",)),
      preprocessors=[
          utils.map_over_dataset(
              lambda ex: {
                  "inputs": tf.range(ex["inputs_lengths"]),
                  "targets": tf.range(ex["targets_lengths"]),
                  "targets_pretokenized": ex["targets_pretokenized"],
              })
      ],
      postprocess_fn=functools.partial(
          _string_label_to_class_id_postprocessor,
          label_classes=label_classes),
      output_features={
          "inputs": dataset_providers.Feature(mock.Mock()),
          "targets": dataset_providers.Feature(mock.Mock())
      })
def test_disallow_shuffle(self):
  """A Task built with `shuffle_buffer_size=None` rejects shuffle=True."""
  task = dataset_providers.Task(
      "no_shuffle",
      source=self.function_source,
      output_features=self.DEFAULT_OUTPUT_FEATURES,
      preprocessors=self.DEFAULT_PREPROCESSORS,
      shuffle_buffer_size=None)

  # Same literal message in both failure modes; hoisted to avoid repetition.
  expected_msg = (
      "Shuffling is disallowed for Task 'no_shuffle' since its "
      '`shuffle_buffer_size` was set to `None` on construction.')

  with self.assertRaisesWithLiteralMatch(ValueError, expected_msg):
    task.get_dataset(None, shuffle=True)

  # Passing an explicit buffer size does not override the construction-time
  # prohibition.
  with self.assertRaisesWithLiteralMatch(ValueError, expected_msg):
    task.get_dataset(None, shuffle=True, shuffle_buffer_size=100)

  # Explicitly disabling shuffling is still allowed.
  task.get_dataset(None, shuffle=False)