def test_value_errors(self):
  """Checks that Task construction rejects invalid preprocessor pipelines."""
  ds_fn = (
      lambda split, shuffle_files: tf.data.Dataset.from_tensors(["test"]))
  features = {
      "inputs": dataset_providers.Feature(test_utils.sentencepiece_vocab())
  }

  # Case 1: more than one `CacheDatasetPlaceholder` in the pipeline.
  expected_msg = (
      "`CacheDatasetPlaceholder` can appear at most once in the "
      "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'.")
  with self.assertRaisesWithLiteralMatch(ValueError, expected_msg):
    dataset_providers.Task(
        "multiple_cache_placeholders",
        source=dataset_providers.FunctionDataSource(
            dataset_fn=ds_fn, splits=["train", "validation"]),
        preprocessors=[
            test_utils.test_text_preprocessor,
            preprocessors.tokenize,
            dataset_providers.CacheDatasetPlaceholder(),
            test_utils.test_token_preprocessor,
            dataset_providers.CacheDatasetPlaceholder(),
        ],
        output_features=features,
        metric_fns=[])

  # Case 2: a preprocessor with a `sequence_length` argument placed before
  # the cache placeholder (sequence length is only known at run time).
  expected_msg = (
      "'test_token_preprocessor' has a `sequence_length` argument but occurs "
      "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
      "is not allowed since the sequence length is specified at run time.")
  with self.assertRaisesWithLiteralMatch(ValueError, expected_msg):
    dataset_providers.Task(
        "sequence_length_pre_cache",
        dataset_providers.FunctionDataSource(
            dataset_fn=ds_fn, splits=["train"]),
        preprocessors=[
            test_utils.test_text_preprocessor,
            preprocessors.tokenize,
            test_utils.test_token_preprocessor,
            dataset_providers.CacheDatasetPlaceholder(),
        ],
        output_features=features,
        metric_fns=[])
def test_supports_caching(self):
  """`supports_caching` is True iff the pipeline contains a cache placeholder."""

  def make_task(name, prep):
    # Local builder so each assertion only spells out the preprocessor list.
    return dataset_providers.Task(
        name,
        source=self.function_source,
        output_features=self.DEFAULT_OUTPUT_FEATURES,
        preprocessors=prep)

  # No placeholder anywhere in the pipeline -> caching unsupported.
  self.assertFalse(make_task("nosupports_cache", []).supports_caching)
  self.assertFalse(
      make_task("nosupports_cache",
                [preprocessors.tokenize]).supports_caching)

  # A placeholder at any position (last, first with required=True, or as the
  # only step) -> caching supported.
  self.assertTrue(
      make_task("supports_cache", [
          preprocessors.tokenize,
          dataset_providers.CacheDatasetPlaceholder()
      ]).supports_caching)
  self.assertTrue(
      make_task("supports_cache", [
          dataset_providers.CacheDatasetPlaceholder(required=True),
          preprocessors.tokenize,
      ]).supports_caching)
  self.assertTrue(
      make_task("supports_cache", [
          dataset_providers.CacheDatasetPlaceholder(),
      ]).supports_caching)
def test_requires_caching(self):
  """`requires_caching` is True only for CacheDatasetPlaceholder(required=True)."""

  def make_task(name, prep):
    # Local builder so each assertion only spells out the preprocessor list.
    return dataset_providers.Task(
        name,
        output_features=self.DEFAULT_OUTPUT_FEATURES,
        source=self.function_source,
        preprocessors=prep)

  # No placeholder, or a non-required one -> caching not required.
  self.assertFalse(
      make_task("nosupports_cache",
                [preprocessors.tokenize]).requires_caching)
  self.assertFalse(
      make_task("supports_cache", [
          preprocessors.tokenize,
          dataset_providers.CacheDatasetPlaceholder()
      ]).requires_caching)

  task = make_task("requires_cache", [
      dataset_providers.CacheDatasetPlaceholder(required=True),
      preprocessors.tokenize,
  ])
  self.assertTrue(task.requires_caching)

  # Asking for the uncached dataset must fail fast with a clear message.
  with self.assertRaisesWithLiteralMatch(
      ValueError,
      "Task 'requires_cache' requires caching, but was called with "
      "`use_cached=False`."):
    task.get_dataset({"inputs": 512, "targets": 512}, use_cached=False)

  # We haven't actually cached the task, so it still fails but with a
  # different error.
  with self.assertRaisesWithLiteralMatch(
      AssertionError,
      "'requires_cache' does not exist in any of the task cache "
      "directories."):
    task.get_dataset({"inputs": 512, "targets": 512}, use_cached=True)
def register_dummy_task(
    task_name: str,
    dataset_fn: Callable[[str, str], tf.data.Dataset],
    output_feature_names: Sequence[str] = ("inputs", "targets")) -> None:
  """Register a dummy task for GetDatasetTest.

  The task wraps `dataset_fn` in a `FunctionDataSource` with train/validation
  splits, caches before `append_eos_after_trim`, and builds one sentencepiece
  `Feature` per requested output feature name.
  """
  features = {
      name: dataset_providers.Feature(test_utils.sentencepiece_vocab())
      for name in output_feature_names
  }
  dataset_providers.TaskRegistry.add(
      task_name,
      source=dataset_providers.FunctionDataSource(
          dataset_fn=dataset_fn, splits=["train", "validation"]),
      preprocessors=[
          dataset_providers.CacheDatasetPlaceholder(),
          preprocessors.append_eos_after_trim,
      ],
      output_features=features,
      metric_fns=[])
def _prepare_sources_and_tasks(self):
  """Clears registries and registers the fixture sources and tasks.

  Side effects: resets the Task/Mixture registries, starts a mock patcher on
  `t5.seqio.utils.LazyTfdsLoader` (stored on `self._tfds_patcher`; presumably
  stopped in tearDown — confirm against the enclosing class), copies test data
  into a fresh temp dir, and dumps fake TSV/TFRecord/cached datasets to disk.
  """
  clear_tasks()
  clear_mixtures()
  # Prepare TfdsSource
  # Note we don't use mock.Mock since they fail to pickle.
  # Two train shards plus a single validation shard; "skip"/"take" mirror the
  # shard-instruction dicts consumed by _load_shard below.
  fake_tfds_paths = {
      "train": [
          {  # pylint:disable=g-complex-comprehension
              "filename": "train.tfrecord-%05d-of-00002" % i,
              "skip": 0,
              "take": -1
          } for i in range(2)
      ],
      "validation": [{
          "filename": "validation.tfrecord-00000-of-00001",
          "skip": 0,
          "take": -1
      }],
  }

  def _load_shard(shard_instruction, shuffle_files, seed):
    # Fake shard loader: splits the 3-example fake train set 2/1 across the
    # two train shard filenames; validation maps to the whole validation set.
    del shuffle_files
    del seed
    fname = shard_instruction["filename"]
    if "train" in fname:
      ds = get_fake_dataset("train")
      if fname.endswith("00000-of-00002"):
        return ds.take(2)
      else:
        return ds.skip(2)
    else:
      return get_fake_dataset("validation")

  fake_tfds = FakeLazyTfds(
      name="fake:0.0.0",
      load=get_fake_dataset,
      load_shard=_load_shard,
      info=FakeTfdsInfo(splits={
          "train": None,
          "validation": None
      }),
      files=fake_tfds_paths.get,
      size=lambda x: 30 if x == "train" else 10)
  # Patch the lazy TFDS loader so TfdsDataSource resolves to the fake above.
  self._tfds_patcher = mock.patch(
      "t5.seqio.utils.LazyTfdsLoader",
      new=mock.Mock(return_value=fake_tfds))
  self._tfds_patcher.start()

  # Set up data directory.
  # NOTE(review): this appears to repeat the data-dir setup also done by the
  # caller's setUp — confirm whether both copies are intentional.
  self.test_tmpdir = self.get_tempdir()
  self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
  shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
  for root, dirs, _ in os.walk(self.test_data_dir):
    for d in dirs + [""]:
      # Make every copied directory (and the root via "") world-writable so
      # the fixtures below can dump files into them.
      os.chmod(os.path.join(root, d), 0o777)

  # Prepare uncached TextLineTask.
  self.tfds_source = dataset_providers.TfdsDataSource(
      tfds_name="fake:0.0.0", splits=("train", "validation"))
  self.add_task("tfds_task", source=self.tfds_source)

  # Prepare TextLineSource.
  _dump_fake_dataset(
      os.path.join(self.test_data_dir, "train.tsv"),
      _FAKE_DATASET["train"], [2, 1], _dump_examples_to_tsv)
  self.text_line_source = dataset_providers.TextLineDataSource(
      split_to_filepattern={
          "train": os.path.join(self.test_data_dir, "train.tsv*"),
      },
      skip_header_lines=1,
  )
  self.add_task(
      "text_line_task",
      source=self.text_line_source,
      preprocessors=(split_tsv_preprocessor,) + self.DEFAULT_PREPROCESSORS)

  # Prepare TFExampleSource.
  _dump_fake_dataset(
      os.path.join(self.test_data_dir, "train.tfrecord"),
      _FAKE_DATASET["train"], [2, 1], _dump_examples_to_tfrecord)
  self.tf_example_source = dataset_providers.TFExampleDataSource(
      split_to_filepattern={
          "train": os.path.join(self.test_data_dir, "train.tfrecord*"),
      },
      feature_description={
          "prefix": tf.io.FixedLenFeature([], tf.string),
          "suffix": tf.io.FixedLenFeature([], tf.string),
      })
  self.add_task("tf_example_task", source=self.tf_example_source)

  # Prepare FunctionDataSource
  self.function_source = dataset_providers.FunctionDataSource(
      dataset_fn=get_fake_dataset, splits=["train", "validation"])
  self.add_task("function_task", source=self.function_source)

  # Prepare Task that is tokenized and preprocessed before caching.
  self.add_task(
      "fully_processed_precache",
      source=self.function_source,
      preprocessors=(
          test_text_preprocessor,
          preprocessors.tokenize,
          token_preprocessor_no_sequence_length,
          dataset_providers.CacheDatasetPlaceholder(),
      ))

  # Prepare Task that is tokenized after caching.
  self.add_task(
      "tokenized_postcache",
      source=self.function_source,
      preprocessors=(
          test_text_preprocessor,
          dataset_providers.CacheDatasetPlaceholder(),
          preprocessors.tokenize,
          token_preprocessor_no_sequence_length,
      ))

  # Prepare Task with randomization.
  self.random_task = self.add_task(
      "random_task",
      source=self.function_source,
      preprocessors=(
          test_text_preprocessor,
          dataset_providers.CacheDatasetPlaceholder(),
          preprocessors.tokenize,
          random_token_preprocessor,
      ))

  self.uncached_task = self.add_task("uncached_task", source=self.tfds_source)

  # Prepare cached task.
  dataset_utils.set_global_cache_dirs([self.test_data_dir])
  self.cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
  _dump_fake_dataset(
      os.path.join(self.cached_task_dir, "train.tfrecord"),
      _FAKE_TOKENIZED_DATASET["train"], [2, 1], _dump_examples_to_tfrecord)
  _dump_fake_dataset(
      os.path.join(self.cached_task_dir, "validation.tfrecord"),
      _FAKE_TOKENIZED_DATASET["validation"], [2], _dump_examples_to_tfrecord)
  self.cached_task = self.add_task("cached_task", source=self.tfds_source)

  # Prepare cached plaintext task.
  _dump_fake_dataset(
      os.path.join(self.test_data_dir, "cached_plaintext_task",
                   "train.tfrecord"),
      _FAKE_PLAINTEXT_TOKENIZED_DATASET["train"], [2, 1],
      _dump_examples_to_tfrecord)
  self.cached_plaintext_task = self.add_task(
      "cached_plaintext_task",
      source=self.tfds_source,
      preprocessors=self.DEFAULT_PREPROCESSORS + (test_token_preprocessor,))
class FakeTaskTest(absltest.TestCase):
  """TestCase that sets up fake cached and uncached tasks."""

  # Standard pipeline used by most fixture tasks: text preprocess, tokenize,
  # cache, then trim-and-append-EOS.
  DEFAULT_PREPROCESSORS = (test_text_preprocessor, preprocessors.tokenize,
                           dataset_providers.CacheDatasetPlaceholder(),
                           preprocessors.append_eos_after_trim)

  # inputs/targets both use the test sentencepiece vocabulary.
  DEFAULT_OUTPUT_FEATURES = {
      "inputs": dataset_providers.Feature(sentencepiece_vocab()),
      "targets": dataset_providers.Feature(sentencepiece_vocab())
  }

  def add_task(
      self,
      name,
      source,
      preprocessors=DEFAULT_PREPROCESSORS,  # pylint:disable=redefined-outer-name
      output_features=None,
      **kwargs):
    """Registers (and returns) a Task, defaulting to fresh SP features.

    Note: builds a new feature dict per call rather than reusing
    DEFAULT_OUTPUT_FEATURES, so each registered task gets its own Feature
    objects.
    """
    if not output_features:
      output_features = {
          "inputs": dataset_providers.Feature(sentencepiece_vocab()),
          "targets": dataset_providers.Feature(sentencepiece_vocab())
      }
    return TaskRegistry.add(
        name,
        source=source,
        preprocessors=preprocessors,
        output_features=output_features,
        **kwargs)

  def get_tempdir(self):
    """Returns a per-test temp dir, initializing absl flags if needed."""
    try:
      flags.FLAGS.test_tmpdir
    except flags.UnparsedFlagAccessError:
      # Need to initialize flags when running `pytest`.
      flags.FLAGS(sys.argv)
    return self.create_tempdir().full_path

  def setUp(self):
    super().setUp()
    self.maxDiff = None  # pylint:disable=invalid-name

    # Set up data directory.
    # NOTE(review): _prepare_sources_and_tasks below repeats this copy into a
    # new temp dir — confirm whether both copies are intentional.
    self.test_tmpdir = self.get_tempdir()
    self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
    shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
    for root, dirs, _ in os.walk(self.test_data_dir):
      for d in dirs + [""]:
        # Make every copied directory (and the root via "") world-writable.
        os.chmod(os.path.join(root, d), 0o777)

    self._prepare_sources_and_tasks()

  def _prepare_sources_and_tasks(self):
    """Clears registries and registers all fixture sources/tasks.

    Starts a mock patcher on `t5.seqio.utils.LazyTfdsLoader` (stored on
    `self._tfds_patcher`, stopped in tearDown) and dumps fake datasets to
    disk for the text-line, tf-example, and cached-task fixtures.
    """
    clear_tasks()
    clear_mixtures()
    # Prepare TfdsSource
    # Note we don't use mock.Mock since they fail to pickle.
    # Two train shards plus one validation shard; dicts mirror the shard
    # instructions consumed by _load_shard below.
    fake_tfds_paths = {
        "train": [
            {  # pylint:disable=g-complex-comprehension
                "filename": "train.tfrecord-%05d-of-00002" % i,
                "skip": 0,
                "take": -1
            } for i in range(2)
        ],
        "validation": [{
            "filename": "validation.tfrecord-00000-of-00001",
            "skip": 0,
            "take": -1
        }],
    }

    def _load_shard(shard_instruction, shuffle_files, seed):
      # Fake shard loader: splits the fake train set 2/1 across the two train
      # shard filenames; any other filename yields the validation set.
      del shuffle_files
      del seed
      fname = shard_instruction["filename"]
      if "train" in fname:
        ds = get_fake_dataset("train")
        if fname.endswith("00000-of-00002"):
          return ds.take(2)
        else:
          return ds.skip(2)
      else:
        return get_fake_dataset("validation")

    fake_tfds = FakeLazyTfds(
        name="fake:0.0.0",
        load=get_fake_dataset,
        load_shard=_load_shard,
        info=FakeTfdsInfo(splits={
            "train": None,
            "validation": None
        }),
        files=fake_tfds_paths.get,
        size=lambda x: 30 if x == "train" else 10)
    # Route TfdsDataSource through the fake loader above.
    self._tfds_patcher = mock.patch(
        "t5.seqio.utils.LazyTfdsLoader",
        new=mock.Mock(return_value=fake_tfds))
    self._tfds_patcher.start()

    # Set up data directory.
    self.test_tmpdir = self.get_tempdir()
    self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
    shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
    for root, dirs, _ in os.walk(self.test_data_dir):
      for d in dirs + [""]:
        os.chmod(os.path.join(root, d), 0o777)

    # Prepare uncached TextLineTask.
    self.tfds_source = dataset_providers.TfdsDataSource(
        tfds_name="fake:0.0.0", splits=("train", "validation"))
    self.add_task("tfds_task", source=self.tfds_source)

    # Prepare TextLineSource.
    _dump_fake_dataset(
        os.path.join(self.test_data_dir, "train.tsv"),
        _FAKE_DATASET["train"], [2, 1], _dump_examples_to_tsv)
    self.text_line_source = dataset_providers.TextLineDataSource(
        split_to_filepattern={
            "train": os.path.join(self.test_data_dir, "train.tsv*"),
        },
        skip_header_lines=1,
    )
    self.add_task(
        "text_line_task",
        source=self.text_line_source,
        preprocessors=(split_tsv_preprocessor,) + self.DEFAULT_PREPROCESSORS)

    # Prepare TFExampleSource.
    _dump_fake_dataset(
        os.path.join(self.test_data_dir, "train.tfrecord"),
        _FAKE_DATASET["train"], [2, 1], _dump_examples_to_tfrecord)
    self.tf_example_source = dataset_providers.TFExampleDataSource(
        split_to_filepattern={
            "train": os.path.join(self.test_data_dir, "train.tfrecord*"),
        },
        feature_description={
            "prefix": tf.io.FixedLenFeature([], tf.string),
            "suffix": tf.io.FixedLenFeature([], tf.string),
        })
    self.add_task("tf_example_task", source=self.tf_example_source)

    # Prepare FunctionDataSource
    self.function_source = dataset_providers.FunctionDataSource(
        dataset_fn=get_fake_dataset, splits=["train", "validation"])
    self.add_task("function_task", source=self.function_source)

    # Prepare Task that is tokenized and preprocessed before caching.
    self.add_task(
        "fully_processed_precache",
        source=self.function_source,
        preprocessors=(
            test_text_preprocessor,
            preprocessors.tokenize,
            token_preprocessor_no_sequence_length,
            dataset_providers.CacheDatasetPlaceholder(),
        ))

    # Prepare Task that is tokenized after caching.
    self.add_task(
        "tokenized_postcache",
        source=self.function_source,
        preprocessors=(
            test_text_preprocessor,
            dataset_providers.CacheDatasetPlaceholder(),
            preprocessors.tokenize,
            token_preprocessor_no_sequence_length,
        ))

    # Prepare Task with randomization.
    self.random_task = self.add_task(
        "random_task",
        source=self.function_source,
        preprocessors=(
            test_text_preprocessor,
            dataset_providers.CacheDatasetPlaceholder(),
            preprocessors.tokenize,
            random_token_preprocessor,
        ))

    self.uncached_task = self.add_task(
        "uncached_task", source=self.tfds_source)

    # Prepare cached task.
    dataset_utils.set_global_cache_dirs([self.test_data_dir])
    self.cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
    _dump_fake_dataset(
        os.path.join(self.cached_task_dir, "train.tfrecord"),
        _FAKE_TOKENIZED_DATASET["train"], [2, 1], _dump_examples_to_tfrecord)
    _dump_fake_dataset(
        os.path.join(self.cached_task_dir, "validation.tfrecord"),
        _FAKE_TOKENIZED_DATASET["validation"], [2], _dump_examples_to_tfrecord)
    self.cached_task = self.add_task("cached_task", source=self.tfds_source)

    # Prepare cached plaintext task.
    _dump_fake_dataset(
        os.path.join(self.test_data_dir, "cached_plaintext_task",
                     "train.tfrecord"),
        _FAKE_PLAINTEXT_TOKENIZED_DATASET["train"], [2, 1],
        _dump_examples_to_tfrecord)
    self.cached_plaintext_task = self.add_task(
        "cached_plaintext_task",
        source=self.tfds_source,
        preprocessors=self.DEFAULT_PREPROCESSORS + (test_token_preprocessor,))

  def tearDown(self):
    super().tearDown()
    # Undo the LazyTfdsLoader patch and clear any global TF random seed.
    self._tfds_patcher.stop()
    tf.random.set_seed(None)

  def verify_task_matches_fake_datasets(  # pylint:disable=dangerous-default-value
      self,
      task_name,
      use_cached,
      token_preprocessed=False,
      splits=("train", "validation"),
      sequence_length=_DEFAULT_SEQUENCE_LENGTH,
      num_shards=None):
    """Assert all splits for both tokenized datasets are correct."""
    task = TaskRegistry.get(task_name)
    for split in splits:
      get_dataset = functools.partial(
          task.get_dataset, sequence_length, split,
          use_cached=use_cached, shuffle=False)
      if num_shards:
        # Concatenate all shards so the sharded read covers the full split.
        ds = get_dataset(
            shard_info=dataset_providers.ShardInfo(0, num_shards))
        for i in range(1, num_shards):
          ds = ds.concatenate(
              get_dataset(shard_info=dataset_providers.ShardInfo(
                  i, num_shards)))
      else:
        ds = get_dataset()
      _assert_compare_to_fake_dataset(
          ds,
          split,
          task.output_features,
          sequence_length,
          token_preprocessed=token_preprocessed,
      )