def test_v3_value_errors(self):
  """Verifies TaskV3 rejects invalid `CacheDatasetPlaceholder` pipelines."""

  # Trivial one-example source; the errors under test are raised at Task
  # construction time, so the actual data is irrelevant.
  # (PEP 8: use `def` instead of binding a lambda to a name.)
  def dataset_fn(split, shuffle_files):
    del split, shuffle_files
    return tf.data.Dataset.from_tensors(["test"])

  output_features = {
      "inputs": dataset_providers.Feature(test_utils.sentencepiece_vocab())
  }

  # A preprocessing pipeline may contain at most one cache placeholder.
  with self.assertRaisesRegex(
      ValueError,
      "`CacheDatasetPlaceholder` can appear at most once in the "
      "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'."):
    dataset_providers.TaskV3(
        "multiple_cache_placeholders",
        source=dataset_providers.FunctionSource(
            dataset_fn=dataset_fn,
            splits=["train", "validation"]),
        preprocessors=[
            test_utils.test_text_preprocessor,
            preprocessors.tokenize,
            dataset_providers.CacheDatasetPlaceholder(),
            test_utils.test_token_preprocessor,
            dataset_providers.CacheDatasetPlaceholder()
        ],
        output_features=output_features,
        metric_fns=[])

  # Preprocessors that take `sequence_length` must come after the cache
  # placeholder since sequence length is only known at run time.
  with self.assertRaisesRegex(
      ValueError,
      "'test_token_preprocessor' has a `sequence_length` argument but occurs "
      "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
      "is not allowed since the sequence length is specified at run time."):
    dataset_providers.TaskV3(
        "sequence_length_pre_cache",
        # Pass `source` by keyword for consistency with the call above.
        source=dataset_providers.FunctionSource(
            dataset_fn=dataset_fn,
            splits=["train"],
        ),
        preprocessors=[
            test_utils.test_text_preprocessor,
            preprocessors.tokenize,
            test_utils.test_token_preprocessor,
            dataset_providers.CacheDatasetPlaceholder()
        ],
        output_features=output_features,
        metric_fns=[])
def register_dummy_task(
    task_name: str,
    dataset_fn: Callable[[str, str], tf.data.Dataset],
    output_feature_names: Sequence[str] = ("inputs", "targets")) -> None:
  """Register a dummy task for GetDatasetTest.

  Args:
    task_name: name under which to register the task.
    dataset_fn: function producing the task's dataset, called with
      (split, shuffle_files).
    output_feature_names: names of the output features to declare; each
      gets its own sentencepiece vocabulary feature.
  """
  # Build the keyword arguments up front rather than inline in the call.
  task_source = dataset_providers.FunctionSource(
      dataset_fn=dataset_fn, splits=["train", "validation"])
  # One Feature (with its own vocab instance) per requested feature name.
  task_features = {}
  for feature_name in output_feature_names:
    task_features[feature_name] = dataset_providers.Feature(
        test_utils.sentencepiece_vocab())
  dataset_providers.TaskRegistry.add(
      task_name,
      dataset_providers.TaskV3,
      source=task_source,
      preprocessors=[dataset_providers.CacheDatasetPlaceholder()],
      output_features=task_features,
      metric_fns=[])
def setUp(self):
  """Mocks TFDS, stages cached fixture data, and registers the test Tasks."""
  super().setUp()
  self.maxDiff = None  # pylint:disable=invalid-name

  # Mock TFDS
  # Note we don't use mock.Mock since they fail to pickle.
  # Shard listing that the fake loader reports: two train shards and one
  # validation shard ("skip"/"take" mirror TFDS file-instruction fields).
  fake_tfds_paths = {
      "train": [
          {  # pylint:disable=g-complex-comprehension
              "filename": "train.tfrecord-%05d-of-00002" % i,
              "skip": 0,
              "take": -1
          } for i in range(2)
      ],
      "validation": [{
          "filename": "validation.tfrecord-00000-of-00001",
          "skip": 0,
          "take": -1
      }],
  }

  def _load_shard(shard_instruction, shuffle_files, seed):
    # Returns the slice of the fake dataset that the named shard holds:
    # the first train shard gets the first 2 examples, the second train
    # shard the remainder, and any other shard the full validation set.
    del shuffle_files
    del seed
    fname = shard_instruction["filename"]
    if "train" in fname:
      if fname.endswith("00000-of-00002"):
        return get_fake_dataset("train").take(2)
      else:
        return get_fake_dataset("train").skip(2)
    else:
      return get_fake_dataset("validation")

  fake_tfds = FakeLazyTfds(
      name="fake:0.0.0",
      load=get_fake_dataset,
      load_shard=_load_shard,
      info=FakeTfdsInfo(splits={
          "train": None,
          "validation": None
      }),
      files=fake_tfds_paths.get,
      # Reported per-split cardinality — presumably chosen to match the
      # fixtures used by size-dependent tests; TODO confirm.
      size=lambda x: 30 if x == "train" else 10)
  self._tfds_patcher = mock.patch(
      "t5.data.utils.LazyTfdsLoader",
      new=mock.Mock(return_value=fake_tfds))
  self._tfds_patcher.start()

  # Set up data directory.
  self.test_tmpdir = self.get_tempdir()
  self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
  shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
  # Make every directory in the copied tree writable; `dirs + [""]` also
  # chmods each walked directory itself (the "" entry), including the root.
  for root, dirs, _ in os.walk(self.test_data_dir):
    for d in dirs + [""]:
      os.chmod(os.path.join(root, d), 0o777)

  # Register a cached test Task.
  dataset_utils.set_global_cache_dirs([self.test_data_dir])
  clear_tasks()
  add_tfds_task("cached_task", token_preprocessor=test_token_preprocessor)
  add_tfds_task("cached_task_no_token_prep")

  # Prepare cached tasks.
  self.cached_task = TaskRegistry.get("cached_task")
  cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
  _dump_fake_dataset(
      os.path.join(cached_task_dir, "train.tfrecord"),
      _FAKE_TOKENIZED_DATASET["train"], [2, 1], _dump_examples_to_tfrecord)
  _dump_fake_dataset(
      os.path.join(cached_task_dir, "validation.tfrecord"),
      _FAKE_TOKENIZED_DATASET["validation"], [2], _dump_examples_to_tfrecord)
  # The no-token-prep variant reuses the same cached files verbatim.
  shutil.copytree(
      cached_task_dir,
      os.path.join(self.test_data_dir, "cached_task_no_token_prep"))

  # Prepare uncached TfdsTask.
  add_tfds_task("uncached_task", token_preprocessor=test_token_preprocessor)
  add_tfds_task("uncached_task_no_token_prep")
  self.uncached_task = TaskRegistry.get("uncached_task")

  # Prepare uncached, random TfdsTask
  add_tfds_task(
      "uncached_random_task", token_preprocessor=random_token_preprocessor)
  self.uncached_random_task = TaskRegistry.get("uncached_random_task")

  # Prepare uncached TextLineTask.
  _dump_fake_dataset(
      os.path.join(self.test_data_dir, "train.tsv"),
      _FAKE_DATASET["train"], [2, 1], _dump_examples_to_tsv)
  TaskRegistry.add(
      "text_line_task",
      dataset_providers.TextLineTask,
      split_to_filepattern={
          "train": os.path.join(self.test_data_dir, "train.tsv*"),
      },
      skip_header_lines=1,
      text_preprocessor=[_split_tsv_preprocessor, test_text_preprocessor],
      output_features=dataset_providers.Feature(sentencepiece_vocab()),
      metric_fns=[])
  self.text_line_task = TaskRegistry.get("text_line_task")

  # Prepare uncached TFExampleTask
  _dump_fake_dataset(
      os.path.join(self.test_data_dir, "train.tfrecord"),
      _FAKE_DATASET["train"], [2, 1], _dump_examples_to_tfrecord)
  TaskRegistry.add(
      "tf_example_task",
      dataset_providers.TFExampleTask,
      split_to_filepattern={
          "train": os.path.join(self.test_data_dir, "train.tfrecord*"),
      },
      feature_description={
          "prefix": tf.io.FixedLenFeature([], tf.string),
          "suffix": tf.io.FixedLenFeature([], tf.string),
      },
      text_preprocessor=[test_text_preprocessor],
      output_features=dataset_providers.Feature(sentencepiece_vocab()),
      metric_fns=[])
  self.tf_example_task = TaskRegistry.get("tf_example_task")

  # Prepare uncached Task.
  def _dataset_fn(split,
                  shuffle_files,
                  filepattern=os.path.join(self.test_data_dir,
                                           "train.tsv*")):
    # Reads the TSV fixture directly, skipping each file's header line.
    del split
    files = tf.data.Dataset.list_files(filepattern, shuffle=shuffle_files)
    return files.interleave(
        lambda f: tf.data.TextLineDataset(f).skip(1),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

  TaskRegistry.add(
      "general_task",
      dataset_providers.Task,
      dataset_fn=_dataset_fn,
      splits=["train"],
      text_preprocessor=[_split_tsv_preprocessor, test_text_preprocessor],
      output_features=dataset_providers.Feature(sentencepiece_vocab()),
      metric_fns=[])
  self.general_task = TaskRegistry.get("general_task")

  # Prepare uncached TaskV3.
  TaskRegistry.add(
      "task_v3",
      dataset_providers.TaskV3,
      source=dataset_providers.FunctionSource(
          dataset_fn=get_fake_dataset,
          splits=["train", "validation"]),
      preprocessors=[
          test_text_preprocessor,
          preprocessors.tokenize,
          token_preprocessor_no_sequence_length,
          dataset_providers.CacheDatasetPlaceholder(),
      ],
      output_features={
          "inputs": dataset_providers.Feature(sentencepiece_vocab()),
          "targets": dataset_providers.Feature(sentencepiece_vocab()),
      },
      metric_fns=[])
  self.task_v3 = TaskRegistry.get("task_v3")

  # Prepare uncached TaskV3 with no caching before tokenization.
  TaskRegistry.add(
      "task_v3_tokenized_postcache",
      dataset_providers.TaskV3,
      source=dataset_providers.FunctionSource(
          dataset_fn=get_fake_dataset,
          splits=["train", "validation"]),
      preprocessors=[
          test_text_preprocessor,
          dataset_providers.CacheDatasetPlaceholder(),
          preprocessors.tokenize,
          token_preprocessor_no_sequence_length,
      ],
      output_features={
          "inputs": dataset_providers.Feature(sentencepiece_vocab()),
          "targets": dataset_providers.Feature(sentencepiece_vocab()),
      },
      metric_fns=[])