def test_tfds_source_splits(self):
    default_splits_src = dataset_providers.TfdsDataSource("fake:0.0.0")
    self.assertSameElements(["train", "validation"], default_splits_src.splits)

    validation_split_src = dataset_providers.TfdsDataSource(
        "fake:0.0.0", splits=["validation"])
    self.assertSameElements(["validation"], validation_split_src.splits)

    sliced_split_src = dataset_providers.TfdsDataSource(
        "fake:0.0.0", splits={"validation": "train[0:1%]"})
    self.assertSameElements(["validation"], sliced_split_src.splits)
def test_tfds_splits(self):
    self.assertSameElements(
        ["train", "validation"],
        dataset_providers.TfdsDataSource(tfds_name="fake:0.0.0").splits)
    self.assertSameElements(
        ["validation"],
        dataset_providers.TfdsDataSource(
            tfds_name="fake:0.0.0", splits=["validation"]).splits)
    self.assertSameElements(
        ["validation"],
        dataset_providers.TfdsDataSource(
            tfds_name="fake:0.0.0", splits={"validation": "train"}).splits)
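The two tests above exercise the splits argument of TfdsDataSource: passing nothing exposes every split of the TFDS dataset, a sequence restricts which split names are exposed, and a dict remaps a task-level split name onto a TFDS split spec, including slice notation such as "train[0:1%]". A minimal standalone sketch of the dict form, assuming the same TfdsDataSource API used above ("fake:0.0.0" is the test fixture, not a real TFDS dataset):

source = dataset_providers.TfdsDataSource(
    tfds_name="fake:0.0.0",
    # Task-level "validation" reads the first 1% of the TFDS "train" split.
    splits={"validation": "train[0:1%]"},
)
# Only the remapped task-level name is reported.
print(list(source.splits))  # expected: ["validation"]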
Example #3
    def _prepare_sources_and_tasks(self):
        clear_tasks()
        clear_mixtures()
        # Prepare TfdsSource
        # Note we don't use mock.Mock since they fail to pickle.
        fake_tfds_paths = {
            "train": [
                {  # pylint:disable=g-complex-comprehension
                    "filename": "train.tfrecord-%05d-of-00002" % i,
                    "skip": 0,
                    "take": -1
                } for i in range(2)
            ],
            "validation": [{
                "filename": "validation.tfrecord-00000-of-00001",
                "skip": 0,
                "take": -1
            }],
        }

        def _load_shard(shard_instruction, shuffle_files, seed):
            del shuffle_files
            del seed
            fname = shard_instruction["filename"]
            if "train" in fname:
                ds = get_fake_dataset("train")
                if fname.endswith("00000-of-00002"):
                    return ds.take(2)
                else:
                    return ds.skip(2)
            else:
                return get_fake_dataset("validation")

        fake_tfds = FakeLazyTfds(name="fake:0.0.0",
                                 load=get_fake_dataset,
                                 load_shard=_load_shard,
                                 info=FakeTfdsInfo(splits={
                                     "train": None,
                                     "validation": None
                                 }),
                                 files=fake_tfds_paths.get,
                                 size=lambda x: 30 if x == "train" else 10)
        self._tfds_patcher = mock.patch("t5.seqio.utils.LazyTfdsLoader",
                                        new=mock.Mock(return_value=fake_tfds))
        self._tfds_patcher.start()

        # Set up data directory.
        self.test_tmpdir = self.get_tempdir()
        self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
        shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
        for root, dirs, _ in os.walk(self.test_data_dir):
            for d in dirs + [""]:
                os.chmod(os.path.join(root, d), 0o777)

        # Prepare uncached task backed by the TfdsDataSource.
        self.tfds_source = dataset_providers.TfdsDataSource(
            tfds_name="fake:0.0.0", splits=("train", "validation"))
        self.add_task("tfds_task", source=self.tfds_source)

        # Prepare TextLineSource.
        _dump_fake_dataset(
            os.path.join(self.test_data_dir, "train.tsv"),
            _FAKE_DATASET["train"], [2, 1], _dump_examples_to_tsv)
        self.text_line_source = dataset_providers.TextLineDataSource(
            split_to_filepattern={
                "train": os.path.join(self.test_data_dir, "train.tsv*"),
            },
            skip_header_lines=1,
        )
        self.add_task("text_line_task",
                      source=self.text_line_source,
                      preprocessors=(split_tsv_preprocessor, ) +
                      self.DEFAULT_PREPROCESSORS)

        # Prepare TFExampleSource.
        _dump_fake_dataset(os.path.join(self.test_data_dir, "train.tfrecord"),
                           _FAKE_DATASET["train"], [2, 1],
                           _dump_examples_to_tfrecord)
        self.tf_example_source = dataset_providers.TFExampleDataSource(
            split_to_filepattern={
                "train": os.path.join(self.test_data_dir, "train.tfrecord*"),
            },
            feature_description={
                "prefix": tf.io.FixedLenFeature([], tf.string),
                "suffix": tf.io.FixedLenFeature([], tf.string),
            })
        self.add_task("tf_example_task", source=self.tf_example_source)

        # Prepare FunctionDataSource
        self.function_source = dataset_providers.FunctionDataSource(
            dataset_fn=get_fake_dataset, splits=["train", "validation"])
        self.add_task("function_task", source=self.function_source)

        # Prepare Task that is tokenized and preprocessed before caching.
        self.add_task("fully_processed_precache",
                      source=self.function_source,
                      preprocessors=(
                          test_text_preprocessor,
                          preprocessors.tokenize,
                          token_preprocessor_no_sequence_length,
                          dataset_providers.CacheDatasetPlaceholder(),
                      ))

        # Prepare Task that is tokenized after caching.
        self.add_task("tokenized_postcache",
                      source=self.function_source,
                      preprocessors=(
                          test_text_preprocessor,
                          dataset_providers.CacheDatasetPlaceholder(),
                          preprocessors.tokenize,
                          token_preprocessor_no_sequence_length,
                      ))

        # Prepare Task with randomization.
        self.random_task = self.add_task(
            "random_task",
            source=self.function_source,
            preprocessors=(
                test_text_preprocessor,
                dataset_providers.CacheDatasetPlaceholder(),
                preprocessors.tokenize,
                random_token_preprocessor,
            ))

        self.uncached_task = self.add_task("uncached_task",
                                           source=self.tfds_source)

        # Prepare cached task.
        dataset_utils.set_global_cache_dirs([self.test_data_dir])
        self.cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
        _dump_fake_dataset(
            os.path.join(self.cached_task_dir, "train.tfrecord"),
            _FAKE_TOKENIZED_DATASET["train"], [2, 1],
            _dump_examples_to_tfrecord)
        _dump_fake_dataset(
            os.path.join(self.cached_task_dir, "validation.tfrecord"),
            _FAKE_TOKENIZED_DATASET["validation"], [2],
            _dump_examples_to_tfrecord)
        self.cached_task = self.add_task("cached_task",
                                         source=self.tfds_source)

        # Prepare cached plaintext task.
        _dump_fake_dataset(
            os.path.join(self.test_data_dir, "cached_plaintext_task",
                         "train.tfrecord"),
            _FAKE_PLAINTEXT_TOKENIZED_DATASET["train"], [2, 1],
            _dump_examples_to_tfrecord)
        self.cached_plaintext_task = self.add_task(
            "cached_plaintext_task",
            source=self.tfds_source,
            preprocessors=self.DEFAULT_PREPROCESSORS +
            (test_token_preprocessor, ))
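The comment in _prepare_sources_and_tasks notes that mock.Mock objects are avoided because they fail to pickle. FakeLazyTfds and FakeTfdsInfo are therefore presumably lightweight, picklable stand-ins; a minimal sketch of how such fakes could be declared as namedtuples whose fields match the keyword arguments used above (the exact definitions are an assumption, not taken from this source):

import collections

# Hypothetical picklable stand-ins for the lazy TFDS loader and its info
# object; namedtuples pickle cleanly, unlike mock.Mock instances.
FakeLazyTfds = collections.namedtuple(
    "FakeLazyTfds", ["name", "load", "load_shard", "info", "files", "size"])
FakeTfdsInfo = collections.namedtuple("FakeTfdsInfo", ["splits"])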
Example #4
def test_no_tfds_version(self):
    with self.assertRaisesWithLiteralMatch(
            ValueError,
            "TFDS name must contain a version number, got: fake"):
        dataset_providers.TfdsDataSource(tfds_name="fake")
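assertRaisesWithLiteralMatch is an absltest helper that checks the exact error string. If the surrounding suite uses plain pytest instead, a roughly equivalent sketch would look like this (pytest.raises matches by regular expression, so the message is escaped for a literal comparison):

import re

import pytest


def test_no_tfds_version_pytest():
    # A bare dataset name without a ":<version>" suffix should be rejected.
    with pytest.raises(
            ValueError,
            match=re.escape(
                "TFDS name must contain a version number, got: fake")):
        dataset_providers.TfdsDataSource(tfds_name="fake")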