def test_optional_features(self):
  def _dummy_preprocessor(output):
    return lambda _: tf.data.Dataset.from_tensors(output)

  default_vocab = test_utils.sentencepiece_vocab()
  features = {
      "inputs":
          dataset_providers.Feature(vocabulary=default_vocab, required=False),
      "targets":
          dataset_providers.Feature(vocabulary=default_vocab, required=True),
  }

  test_utils.add_task(
      "text_missing_optional_feature",
      test_utils.get_fake_dataset,
      output_features=features,
      text_preprocessor=_dummy_preprocessor({"targets": "a"}))
  TaskRegistry.get_dataset(
      "text_missing_optional_feature", {"targets": 13},
      "train", use_cached=False)

  test_utils.add_task(
      "text_missing_required_feature",
      test_utils.get_fake_dataset,
      output_features=features,
      text_preprocessor=_dummy_preprocessor({"inputs": "a"}))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset is missing expected output feature after preprocessing: "
      "targets"):
    TaskRegistry.get_dataset(
        "text_missing_required_feature", {"inputs": 13},
        "train", use_cached=False)
def test_vocab(self):
  vocab = test_utils.sentencepiece_vocab()
  self.assertEqual(26, vocab.vocab_size)
  self.assertSequenceEqual(_TEST_TOKENS, vocab.encode(_TEST_STRING))
  self.assertEqual(_TEST_STRING,
                   tf.compat.as_bytes(vocab.decode(_TEST_TOKENS)))
  self.assertSequenceEqual(_TEST_TOKENS,
                           tuple(vocab.encode_tf(_TEST_STRING).numpy()))
  self.assertEqual(_TEST_STRING, vocab.decode_tf(_TEST_TOKENS).numpy())
def test_vocab(self):
  vocab = test_utils.sentencepiece_vocab()
  self.assertEqual(26, vocab.vocab_size)
  self.assertSequenceEqual(self.TEST_TOKENS, vocab.encode(self.TEST_STRING))
  self.assertEqual(self.TEST_STRING, vocab.decode(self.TEST_TOKENS))
  self.assertSequenceEqual(
      self.TEST_TOKENS,
      tuple(vocab.encode_tf(self.TEST_STRING).numpy()))
  self.assertEqual(self.TEST_STRING, _decode_tf(vocab, self.TEST_TOKENS))
def test_no_eos(self):
  default_vocab = test_utils.sentencepiece_vocab()
  features = {
      "inputs": utils.Feature(add_eos=True, vocabulary=default_vocab),
      "targets": utils.Feature(add_eos=False, vocabulary=default_vocab),
  }
  test_utils.add_task(
      "task_no_eos", test_utils.get_fake_dataset, output_features=features)
  fn_task = TaskRegistry.get("task_no_eos")
  test_utils.verify_task_matches_fake_datasets(fn_task, use_cached=False)
def test_triviaqa_truncate_text(self):
  vocab = test_utils.sentencepiece_vocab()

  def tokenize_and_prepare_dataset(inputs, targets):
    tokenized_inputs = vocab.encode(inputs)
    tokenized_targets = vocab.encode(targets)

    dataset = tf.data.Dataset.from_tensors({
        'inputs': tokenized_inputs,
        'targets': tokenized_targets,
    })

    return dataset, tokenized_targets

  inputs = 'This is a very very long string which must contain the answer.'
  targets = 'long string'

  og_dataset, tokenized_targets = tokenize_and_prepare_dataset(
      inputs, targets)

  for _ in range(0, 10):
    dataset = prep.trivia_qa_truncate_inputs(
        og_dataset, output_features=None, sequence_length={'inputs': 20})

    for data in test_utils.dataset_as_text(dataset):
      self.assertLen(data['inputs'], 20)
      self.assertContainsSubset(tokenized_targets, data['inputs'])

  # Dummy input which exists in the vocab to be able to compare strings after
  # decoding.
  inputs = 'w h d n r t v'
  targets = 'h d'

  og_dataset, _ = tokenize_and_prepare_dataset(inputs, targets)

  for _ in range(0, 5):
    dataset = prep.trivia_qa_truncate_inputs(
        og_dataset, output_features=None, sequence_length={'inputs': 5})

    for data in test_utils.dataset_as_text(dataset):
      self.assertLen(data['inputs'], 5)
      truncated_inputs = vocab.decode(data['inputs'].tolist())
      new_targets = vocab.decode(data['targets'].tolist())
      self.assertRegex(truncated_inputs, '.*' + targets + '.*')
      self.assertEqual(targets, new_targets)
def test_dtype(self):
  default_vocab = test_utils.sentencepiece_vocab()
  features = {
      "inputs":
          # defaults to int32
          dataset_providers.Feature(vocabulary=default_vocab),
      "targets":
          dataset_providers.Feature(dtype=tf.int64, vocabulary=default_vocab),
  }
  test_utils.add_task(
      "task_dtypes", test_utils.get_fake_dataset, output_features=features)
  dtype_task = TaskRegistry.get("task_dtypes")
  test_utils.verify_task_matches_fake_datasets(dtype_task, use_cached=False)
def register_dummy_task(
    task_name: str,
    dataset_fn: Callable[[str, str], tf.data.Dataset],
    output_feature_names: Sequence[str] = ("inputs", "targets")) -> None:
  """Register a dummy task for GetDatasetTest."""
  dataset_providers.TaskRegistry.add(
      task_name,
      dataset_providers.TaskV3,
      source=dataset_providers.FunctionSource(
          dataset_fn=dataset_fn, splits=["train", "validation"]),
      preprocessors=[dataset_providers.CacheDatasetPlaceholder()],
      output_features={
          feat: dataset_providers.Feature(test_utils.sentencepiece_vocab())
          for feat in output_feature_names
      },
      metric_fns=[])
def test_v3_value_errors(self):
  dataset_fn = (
      lambda split, shuffle_files: tf.data.Dataset.from_tensors(["test"]))
  output_features = {
      "inputs": dataset_providers.Feature(test_utils.sentencepiece_vocab())
  }

  with self.assertRaisesRegex(
      ValueError,
      "`CacheDatasetPlaceholder` can appear at most once in the "
      "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'."):
    dataset_providers.TaskV3(
        "multiple_cache_placeholders",
        source=dataset_providers.FunctionSource(
            dataset_fn=dataset_fn, splits=["train", "validation"]),
        preprocessors=[
            test_utils.test_text_preprocessor,
            preprocessors.tokenize,
            dataset_providers.CacheDatasetPlaceholder(),
            test_utils.test_token_preprocessor,
            dataset_providers.CacheDatasetPlaceholder()
        ],
        output_features=output_features,
        metric_fns=[])

  with self.assertRaisesRegex(
      ValueError,
      "'test_token_preprocessor' has a `sequence_length` argument but occurs "
      "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
      "is not allowed since the sequence length is specified at run time."):
    dataset_providers.TaskV3(
        "sequence_length_pre_cache",
        dataset_providers.FunctionSource(
            dataset_fn=dataset_fn, splits=["train"],
        ),
        preprocessors=[
            test_utils.test_text_preprocessor,
            preprocessors.tokenize,
            test_utils.test_token_preprocessor,
            dataset_providers.CacheDatasetPlaceholder()
        ],
        output_features=output_features,
        metric_fns=[])
def test_denoise_nested_decorators(self):
  """Test whether gin and utils.map_over_dataset decorators are compatible."""
  bindings = """
      preprocessors.unsupervised.preprocessors = [@preprocessors.denoise]
      preprocessors.denoise.noise_density = 0.15
      preprocessors.denoise.noise_mask_fn = @preprocessors.iid_noise_mask
      preprocessors.denoise.inputs_fn = @noise_token_to_sentinel
  """
  gin.parse_config(bindings)
  og_dataset = tf.data.Dataset.from_tensor_slices({'targets': [1, 2, 3]})
  output_features = {
      'targets': Feature(test_utils.sentencepiece_vocab())
  }
  # Test denoise function when it is used as a gin-configurable of another
  # gin-configurable, prep.unsupervised.
  dataset = prep.unsupervised(og_dataset, output_features=output_features)
  self.assertIsInstance(dataset, tf.data.Dataset)
def test_denoise(self):
  tf.set_random_seed(55)

  vocab = test_utils.sentencepiece_vocab()
  target_tokens = vocab.encode('The quick brown fox.')

  # This is what it encodes to.
  self.assertEqual(
      target_tokens,
      [3, 2, 20, 4, 3, 2, 8, 13, 2, 3, 2, 23, 7, 19, 22, 3, 2, 7, 2])

  og_dataset = tf.data.Dataset.from_tensor_slices({
      'targets': [target_tokens],
  })

  output_features = {
      'targets': utils.Feature(vocab),
  }

  # These are the parameters of denoise in the operative config of 'base'.
  # Except noise_density, bumped up from 0.15 to 0.3 in order to demonstrate
  # multiple corrupted spans.
  denoised_dataset = prep.denoise(
      og_dataset,
      output_features,
      noise_density=0.3,
      noise_mask_fn=prep.random_spans_noise_mask,
      inputs_fn=prep.noise_span_to_unique_sentinel,
      targets_fn=prep.nonnoise_span_to_unique_sentinel)

  # Two spans corrupted, [2] and [22, 3, 2, 7, 2], replaced by unique
  # sentinels 25 and 24 respectively.
  assert_dataset(denoised_dataset, [
      {
          'inputs': [3, 25, 20, 4, 3, 2, 8, 13, 2, 3, 2, 23, 7, 19, 24],
          'targets': [25, 2, 24, 22, 3, 2, 7, 2],
      },
  ])
def test_prefix_lm(self):
  vocab = test_utils.sentencepiece_vocab()
  inp = list(range(1, 101))
  og_dataset = tf.data.Dataset.from_tensor_slices({'targets': [inp]})
  og_dataset = og_dataset.repeat(100)
  output_features = {'targets': Feature(vocab)}
  output_dataset = prep.prefix_lm(
      og_dataset,
      {'inputs': 100, 'targets': 100},
      output_features,
  )
  input_lengths = set()
  for ex in output_dataset.as_numpy_iterator():
    self.assertListEqual(
        ex['inputs'].tolist() + ex['targets'].tolist(), inp)
    input_lengths.add(len(ex['inputs']))
  self.assertGreater(len(input_lengths), 1)
def test_not_equal(self):
  vocab1 = test_utils.sentencepiece_vocab()
  vocab2 = test_utils.sentencepiece_vocab(10)
  self.assertNotEqual(vocab1, vocab2)
def test_extra_ids(self): vocab = test_utils.sentencepiece_vocab(extra_ids=10) self.assertEqual(36, vocab.vocab_size) self.assertEqual("v", vocab.decode([25])) self.assertEqual(_UNK_STRING, tf.compat.as_bytes(vocab.decode([35]))) self.assertEqual(_UNK_STRING, vocab.decode_tf([35]).numpy())
def test_extra_ids(self): vocab = test_utils.sentencepiece_vocab(extra_ids=10) self.assertEqual(36, vocab.vocab_size) self.assertEqual("v", vocab.decode([25])) self.assertEqual(self.UNK_STRING, vocab.decode([35])) self.assertEqual(self.UNK_STRING, _decode_tf(vocab, [35]))
def test_get_sentencepiece_model_path(self):
  self.assertEqual(
      test_utils.sentencepiece_vocab().sentencepiece_model_file,
      mesh_transformer.get_sentencepiece_model_path("cached_mixture"))