def _prepare_sources_and_tasks(self):
  clear_tasks()
  clear_mixtures()

  # Prepare TfdsSource.
  # Note we don't use mock.Mock since they fail to pickle.
  fake_tfds_paths = {
      "train": [
          {  # pylint:disable=g-complex-comprehension
              "filename": "train.tfrecord-%05d-of-00002" % i,
              "skip": 0,
              "take": -1
          } for i in range(2)
      ],
      "validation": [{
          "filename": "validation.tfrecord-00000-of-00001",
          "skip": 0,
          "take": -1
      }],
  }

  def _load_shard(shard_instruction, shuffle_files, seed):
    del shuffle_files
    del seed
    fname = shard_instruction["filename"]
    if "train" in fname:
      ds = get_fake_dataset("train")
      if fname.endswith("00000-of-00002"):
        return ds.take(2)
      else:
        return ds.skip(2)
    else:
      return get_fake_dataset("validation")

  fake_tfds = FakeLazyTfds(
      name="fake:0.0.0",
      load=get_fake_dataset,
      load_shard=_load_shard,
      info=FakeTfdsInfo(splits={"train": None, "validation": None}),
      files=fake_tfds_paths.get,
      size=lambda x: 30 if x == "train" else 10)
  self._tfds_patcher = mock.patch(
      "t5.seqio.utils.LazyTfdsLoader",
      new=mock.Mock(return_value=fake_tfds))
  self._tfds_patcher.start()

  # Set up data directory.
  self.test_tmpdir = self.get_tempdir()
  self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
  shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
  for root, dirs, _ in os.walk(self.test_data_dir):
    for d in dirs + [""]:
      os.chmod(os.path.join(root, d), 0o777)

  # Prepare uncached TfdsTask.
  self.tfds_source = dataset_providers.TfdsDataSource(
      tfds_name="fake:0.0.0", splits=("train", "validation"))
  self.add_task("tfds_task", source=self.tfds_source)

  # Prepare TextLineSource.
  _dump_fake_dataset(
      os.path.join(self.test_data_dir, "train.tsv"),
      _FAKE_DATASET["train"], [2, 1], _dump_examples_to_tsv)
  self.text_line_source = dataset_providers.TextLineDataSource(
      split_to_filepattern={
          "train": os.path.join(self.test_data_dir, "train.tsv*"),
      },
      skip_header_lines=1,
  )
  self.add_task(
      "text_line_task",
      source=self.text_line_source,
      preprocessors=(split_tsv_preprocessor,) + self.DEFAULT_PREPROCESSORS)

  # Prepare TFExampleSource.
  _dump_fake_dataset(
      os.path.join(self.test_data_dir, "train.tfrecord"),
      _FAKE_DATASET["train"], [2, 1], _dump_examples_to_tfrecord)
  self.tf_example_source = dataset_providers.TFExampleDataSource(
      split_to_filepattern={
          "train": os.path.join(self.test_data_dir, "train.tfrecord*"),
      },
      feature_description={
          "prefix": tf.io.FixedLenFeature([], tf.string),
          "suffix": tf.io.FixedLenFeature([], tf.string),
      })
  self.add_task("tf_example_task", source=self.tf_example_source)

  # Prepare FunctionDataSource.
  self.function_source = dataset_providers.FunctionDataSource(
      dataset_fn=get_fake_dataset, splits=["train", "validation"])
  self.add_task("function_task", source=self.function_source)

  # Prepare Task that is tokenized and preprocessed before caching.
  self.add_task(
      "fully_processed_precache",
      source=self.function_source,
      preprocessors=(
          test_text_preprocessor,
          preprocessors.tokenize,
          token_preprocessor_no_sequence_length,
          dataset_providers.CacheDatasetPlaceholder(),
      ))

  # Prepare Task that is tokenized after caching.
  self.add_task(
      "tokenized_postcache",
      source=self.function_source,
      preprocessors=(
          test_text_preprocessor,
          dataset_providers.CacheDatasetPlaceholder(),
          preprocessors.tokenize,
          token_preprocessor_no_sequence_length,
      ))

  # Prepare Task with randomization.
  self.random_task = self.add_task(
      "random_task",
      source=self.function_source,
      preprocessors=(
          test_text_preprocessor,
          dataset_providers.CacheDatasetPlaceholder(),
          preprocessors.tokenize,
          random_token_preprocessor,
      ))

  self.uncached_task = self.add_task("uncached_task", source=self.tfds_source)

  # Prepare cached task.
  dataset_utils.set_global_cache_dirs([self.test_data_dir])
  self.cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
  _dump_fake_dataset(
      os.path.join(self.cached_task_dir, "train.tfrecord"),
      _FAKE_TOKENIZED_DATASET["train"], [2, 1], _dump_examples_to_tfrecord)
  _dump_fake_dataset(
      os.path.join(self.cached_task_dir, "validation.tfrecord"),
      _FAKE_TOKENIZED_DATASET["validation"], [2], _dump_examples_to_tfrecord)
  self.cached_task = self.add_task("cached_task", source=self.tfds_source)

  # Prepare cached plaintext task.
  _dump_fake_dataset(
      os.path.join(self.test_data_dir, "cached_plaintext_task",
                   "train.tfrecord"),
      _FAKE_PLAINTEXT_TOKENIZED_DATASET["train"], [2, 1],
      _dump_examples_to_tfrecord)
  self.cached_plaintext_task = self.add_task(
      "cached_plaintext_task",
      source=self.tfds_source,
      preprocessors=self.DEFAULT_PREPROCESSORS + (test_token_preprocessor,))

def test_set_global_cache_dirs(self):
  utils.set_global_cache_dirs([])
  self.assertFalse(self.cached_task.cache_dir)

  utils.set_global_cache_dirs([self.test_data_dir])
  self.assertTrue(self.cached_task.cache_dir)
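
# A minimal usage sketch, not part of the original suite: assuming the
# standard seqio Task API (`get_dataset` with `use_cached=True`), a test built
# on the fixtures above could read the pre-dumped cache directly. The test
# name and `sequence_length` values here are illustrative only.
def test_cached_task_reads_dumped_cache_sketch(self):
  utils.set_global_cache_dirs([self.test_data_dir])
  ds = self.cached_task.get_dataset(
      sequence_length={"inputs": 13, "targets": 13},
      split="train",
      use_cached=True,
      shuffle=False)
  # The cached "train" split above was dumped as two shards holding
  # 2 + 1 examples, so 3 examples should come back.
  self.assertLen(list(ds.as_numpy_iterator()), 3)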