def test_extra_ids(self):
    """Round-trips text containing sentinel `<extra_id_N>` tokens.

    With 10 extra ids appended to a 26-token base vocab, the size is 36 and
    the extra ids occupy the top of the id space (extra_id_0 == 35, counting
    down). Both the pure-Python and the TF encode/decode paths must agree.
    """
    vocab = test_utils.sentencepiece_vocab(extra_ids=10)
    self.assertEqual(36, vocab.vocab_size)
    # Plain-vocabulary piece still decodes normally.
    self.assertEqual("v", vocab.decode([25]))

    expected_text = "<extra_id_0> <extra_id_1> v <extra_id_9>"
    expected_ids = (35, 34, 3, 25, 26)

    # Decode: Python and TF implementations.
    self.assertEqual(expected_text, vocab.decode(expected_ids))
    self.assertEqual(expected_text, _decode_tf(vocab, expected_ids))
    # Encode: Python and TF implementations.
    self.assertSequenceEqual(expected_ids, vocab.encode(expected_text))
    tf_ids = tuple(vocab.encode_tf(expected_text).numpy())
    self.assertSequenceEqual(expected_ids, tf_ids)
def test_no_eos(self):
    """A task may mix per-feature EOS settings (inputs yes, targets no)."""
    vocab = test_utils.sentencepiece_vocab()
    features = {
        "inputs": dataset_providers.Feature(add_eos=True, vocabulary=vocab),
        "targets": dataset_providers.Feature(add_eos=False, vocabulary=vocab),
    }
    self.add_task(
        "task_no_eos", self.function_source, output_features=features)
    self.verify_task_matches_fake_datasets("task_no_eos", use_cached=False)
def test_feature_validation(self):
    """Validates feature presence, dtype, and rank after preprocessing."""
    vocab = test_utils.sentencepiece_vocab()
    features = {
        "inputs": dataset_providers.Feature(vocabulary=vocab, required=False),
        "targets": dataset_providers.Feature(vocabulary=vocab, required=True),
    }

    def _materialize(output):
        # Build a throwaway task whose preprocessing emits exactly `output`,
        # then iterate the dataset so the validation actually runs.
        task = dataset_providers.Task(
            "feature_validation_task",
            self.function_source,
            output_features=features,
            preprocessors=(lambda _: tf.data.Dataset.from_tensors(output),),
            metric_fns=[],
        )
        sequence_length = {"inputs": 13, "targets": 13}
        ds = task.get_dataset(sequence_length, "train", use_cached=False)
        list(ds.as_numpy_iterator())

    # Omitting the optional "inputs" feature is fine.
    _materialize({"targets": [0]})

    # Omitting the required "targets" feature must fail.
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "Task dataset is missing expected output feature after preprocessing: "
        "targets"):
        _materialize({"inputs": [0]})

    # A feature with the wrong dtype must fail.
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "Task dataset has incorrect type for feature 'targets' after "
        "preprocessing: Got string, expected int32"):
        _materialize({"targets": ["wrong type"]})

    # A feature with the wrong rank (scalar vs vector) must fail.
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "Task dataset has incorrect rank for feature 'targets' after "
        "preprocessing: Got 0, expected 1"):
        _materialize({"targets": 0})
def test_value_errors(self):
    """Task construction rejects invalid preprocessing pipelines."""
    dataset_fn = (
        lambda split, shuffle_files: tf.data.Dataset.from_tensors(["test"]))
    output_features = {
        "inputs": dataset_providers.Feature(test_utils.sentencepiece_vocab())
    }

    # More than one cache placeholder in the pipeline is rejected.
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "`CacheDatasetPlaceholder` can appear at most once in the "
        "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'."):
        dataset_providers.Task(
            "multiple_cache_placeholders",
            source=dataset_providers.FunctionDataSource(
                dataset_fn=dataset_fn, splits=["train", "validation"]),
            preprocessors=[
                test_utils.test_text_preprocessor,
                preprocessors.tokenize,
                dataset_providers.CacheDatasetPlaceholder(),
                test_utils.test_token_preprocessor,
                dataset_providers.CacheDatasetPlaceholder(),
            ],
            output_features=output_features,
            metric_fns=[])

    # A sequence-length-dependent preprocessor cannot run before caching,
    # since the length is only known at run time.
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        "'test_token_preprocessor' has a `sequence_length` argument but occurs "
        "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
        "is not allowed since the sequence length is specified at run time."):
        dataset_providers.Task(
            "sequence_length_pre_cache",
            dataset_providers.FunctionDataSource(
                dataset_fn=dataset_fn, splits=["train"]),
            preprocessors=[
                test_utils.test_text_preprocessor,
                preprocessors.tokenize,
                test_utils.test_token_preprocessor,
                dataset_providers.CacheDatasetPlaceholder(),
            ],
            output_features=output_features,
            metric_fns=[])
def test_append_eos(self):
    """Covers append_eos and append_eos_after_trim with mixed features.

    - 'inputs' has add_eos=False, so it is never touched.
    - 'targets'/'arrows' have add_eos=True, so EOS (id 1) is appended,
      optionally after trimming to sequence_length - 1.
    - 'bows' has no output feature at all, so it passes through untouched.
    """
    source_ds = tf.data.Dataset.from_tensors({
        'inputs': [1, 2, 3],
        'targets': [4, 5, 6, 7],
        'arrows': [8, 9, 10, 11],
        'bows': [12, 13],
    })
    vocab = test_utils.sentencepiece_vocab()
    output_features = {
        'inputs': Feature(vocab, add_eos=False),
        'targets': Feature(vocab, add_eos=True),
        'arrows': Feature(vocab, add_eos=True),
    }
    sequence_length = {'inputs': 4, 'targets': 3, 'arrows': 5, 'bows': 1}

    # append_eos: EOS added, no trimming at all.
    assert_dataset(
        preprocessors.append_eos(source_ds, output_features),
        {
            'inputs': [1, 2, 3],
            'targets': [4, 5, 6, 7, 1],
            'arrows': [8, 9, 10, 11, 1],
            'bows': [12, 13],
        })

    # append_eos_after_trim with lengths: trim first, then add EOS.
    trimmed = preprocessors.append_eos_after_trim(
        source_ds,
        output_features=output_features,
        sequence_length=sequence_length)
    assert_dataset(
        trimmed,
        {
            'inputs': [1, 2, 3],
            'targets': [4, 5, 1],
            'arrows': [8, 9, 10, 11, 1],
            'bows': [12, 13],
        })

    # append_eos_after_trim without lengths: behaves like plain append_eos.
    untrimmed = preprocessors.append_eos_after_trim(
        source_ds, output_features=output_features)
    assert_dataset(
        untrimmed,
        {
            'inputs': [1, 2, 3],
            'targets': [4, 5, 6, 7, 1],
            'arrows': [8, 9, 10, 11, 1],
            'bows': [12, 13],
        })
def register_dummy_task(
    task_name: str,
    dataset_fn: Callable[[str, str], tf.data.Dataset],
    output_feature_names: Sequence[str] = ("inputs", "targets")) -> None:
    """Register a dummy task for GetDatasetTest."""
    # Every requested feature shares the same test sentencepiece vocabulary.
    features = {
        name: dataset_providers.Feature(test_utils.sentencepiece_vocab())
        for name in output_feature_names
    }
    source = dataset_providers.FunctionDataSource(
        dataset_fn=dataset_fn, splits=["train", "validation"])
    dataset_providers.TaskRegistry.add(
        task_name,
        source=source,
        preprocessors=[
            dataset_providers.CacheDatasetPlaceholder(),
            preprocessors.append_eos_after_trim,
        ],
        output_features=features,
        metric_fns=[])
def test_dtype(self):
    """Features may declare non-default dtypes (targets as int64 here)."""
    vocab = test_utils.sentencepiece_vocab()
    features = {
        # No explicit dtype: defaults to int32.
        "inputs": dataset_providers.Feature(vocabulary=vocab),
        "targets": dataset_providers.Feature(dtype=tf.int64, vocabulary=vocab),
    }

    # Extra preprocessor that upcasts only the "targets" feature to int64 so
    # the produced dataset matches the declared feature dtypes.
    cast_targets = utils.map_over_dataset(
        lambda x: {  # pylint:disable=g-long-lambda
            k: tf.cast(v, tf.int64) if k == "targets" else v
            for k, v in x.items()
        })

    self.add_task(
        "task_dtypes",
        self.function_source,
        preprocessors=self.DEFAULT_PREPROCESSORS + (cast_targets,),
        output_features=features)
    self.verify_task_matches_fake_datasets("task_dtypes", use_cached=False)
def test_not_equal(self):
    """Vocabularies built with different extra_ids counts are not equal."""
    self.assertNotEqual(
        test_utils.sentencepiece_vocab(),
        test_utils.sentencepiece_vocab(10))