# Code example #1
    def test_value_errors(self):
        """Verifies Task construction-time validation of the preprocessor list.

        Covers two invalid configurations:
          * more than one `CacheDatasetPlaceholder` in the pipeline;
          * a preprocessor with a `sequence_length` argument placed before the
            cache placeholder (sequence length is only known at run time).
        """
        dataset_fn = (lambda split, shuffle_files: tf.data.Dataset.
                      from_tensors(["test"]))
        output_features = {
            "inputs":
            dataset_providers.Feature(test_utils.sentencepiece_vocab())
        }

        # A pipeline may contain at most one cache placeholder.
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "`CacheDatasetPlaceholder` can appear at most once in the "
                "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'."
        ):
            dataset_providers.Task(
                "multiple_cache_placeholders",
                source=dataset_providers.FunctionDataSource(
                    dataset_fn=dataset_fn, splits=["train", "validation"]),
                preprocessors=[
                    test_utils.test_text_preprocessor, preprocessors.tokenize,
                    dataset_providers.CacheDatasetPlaceholder(),
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder()
                ],
                output_features=output_features,
                metric_fns=[])

        # Preprocessors that take `sequence_length` must come after caching.
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "'test_token_preprocessor' has a `sequence_length` argument but occurs "
                "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
                "is not allowed since the sequence length is specified at run time."
        ):
            dataset_providers.Task(
                "sequence_length_pre_cache",
                # Pass `source` by keyword for consistency with the other Task
                # constructions in this file.
                source=dataset_providers.FunctionDataSource(
                    dataset_fn=dataset_fn,
                    splits=["train"],
                ),
                preprocessors=[
                    test_utils.test_text_preprocessor, preprocessors.tokenize,
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder()
                ],
                output_features=output_features,
                metric_fns=[])
    def test_supports_caching(self):
        """Checks `supports_caching` for pipelines with/without a placeholder."""

        def build_task(name, steps):
            # Construct a throwaway Task with the given preprocessing steps.
            return dataset_providers.Task(
                name,
                source=self.function_source,
                output_features=self.DEFAULT_OUTPUT_FEATURES,
                preprocessors=steps)

        # No placeholder anywhere in the pipeline -> caching unsupported.
        self.assertFalse(build_task("nosupports_cache", []).supports_caching)
        self.assertFalse(
            build_task("nosupports_cache",
                       [preprocessors.tokenize]).supports_caching)

        # Any pipeline containing a placeholder supports caching, regardless
        # of its position or whether it is marked required.
        self.assertTrue(
            build_task("supports_cache", [
                preprocessors.tokenize,
                dataset_providers.CacheDatasetPlaceholder()
            ]).supports_caching)
        self.assertTrue(
            build_task("supports_cache", [
                dataset_providers.CacheDatasetPlaceholder(required=True),
                preprocessors.tokenize,
            ]).supports_caching)
        self.assertTrue(
            build_task("supports_cache", [
                dataset_providers.CacheDatasetPlaceholder(),
            ]).supports_caching)
    def test_requires_caching(self):
        """Checks `requires_caching` and the errors raised by `get_dataset`."""
        # A pipeline with no placeholder never requires caching.
        self.assertFalse(
            dataset_providers.Task(
                "nosupports_cache",
                source=self.function_source,
                output_features=self.DEFAULT_OUTPUT_FEATURES,
                preprocessors=[preprocessors.tokenize]).requires_caching)

        # An optional placeholder supports caching but does not require it.
        self.assertFalse(
            dataset_providers.Task(
                "supports_cache",
                source=self.function_source,
                output_features=self.DEFAULT_OUTPUT_FEATURES,
                preprocessors=[
                    preprocessors.tokenize,
                    dataset_providers.CacheDatasetPlaceholder()
                ]).requires_caching)

        # A placeholder with required=True makes caching mandatory.
        task = dataset_providers.Task(
            "requires_cache",
            source=self.function_source,
            output_features=self.DEFAULT_OUTPUT_FEATURES,
            preprocessors=[
                dataset_providers.CacheDatasetPlaceholder(required=True),
                preprocessors.tokenize,
            ])
        self.assertTrue(task.requires_caching)

        # Requesting the uncached dataset is rejected outright.
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "Task 'requires_cache' requires caching, but was called with "
                "`use_cached=False`."):
            task.get_dataset({"inputs": 512, "targets": 512}, use_cached=False)

        # We haven't actually cached the task, so it still fails but with a
        # different error.
        with self.assertRaisesWithLiteralMatch(
                AssertionError,
                "'requires_cache' does not exist in any of the task cache "
                "directories."):
            task.get_dataset({"inputs": 512, "targets": 512}, use_cached=True)
def register_dummy_task(
    task_name: str,
    dataset_fn: Callable[[str, bool], tf.data.Dataset],
    output_feature_names: Sequence[str] = ("inputs", "targets")) -> None:
  """Register a dummy task for GetDatasetTest.

  Args:
    task_name: name under which to register the task.
    dataset_fn: callable taking `(split, shuffle_files)` and returning a
      `tf.data.Dataset`. The second argument is a shuffle flag, so the
      annotation is `Callable[[str, bool], ...]` (was `[str, str]`).
    output_feature_names: names of the task's output features; each is given
      a `Feature` built from the test sentencepiece vocabulary.
  """
  dataset_providers.TaskRegistry.add(
      task_name,
      source=dataset_providers.FunctionDataSource(
          dataset_fn=dataset_fn, splits=["train", "validation"]),
      preprocessors=[
          dataset_providers.CacheDatasetPlaceholder(),
          preprocessors.append_eos_after_trim,
      ],
      output_features={
          feat: dataset_providers.Feature(test_utils.sentencepiece_vocab())
          for feat in output_feature_names
      },
      metric_fns=[])
# Code example #5
    def _prepare_sources_and_tasks(self):
        """Registers one Task per data-source flavor plus cached variants.

        Clears the global Task/Mixture registries, patches `LazyTfdsLoader`
        with an in-memory fake, stages a writable copy of the test data
        directory, then registers tasks backed by TFDS, text-line,
        tf.train.Example, and function sources, along with pre-cache,
        post-cache, randomized, and cached task variants used by the tests.
        """
        clear_tasks()
        clear_mixtures()
        # Prepare TfdsSource
        # Note we don't use mock.Mock since they fail to pickle.
        # Shard layout reported by the fake loader: two train shards, one
        # validation shard.
        fake_tfds_paths = {
            "train": [
                {  # pylint:disable=g-complex-comprehension
                    "filename": "train.tfrecord-%05d-of-00002" % i,
                    "skip": 0,
                    "take": -1
                } for i in range(2)
            ],
            "validation": [{
                "filename": "validation.tfrecord-00000-of-00001",
                "skip": 0,
                "take": -1
            }],
        }

        def _load_shard(shard_instruction, shuffle_files, seed):
            # Serves one shard of the fake dataset: the first train shard
            # holds the first 2 examples, the second shard the remainder;
            # validation has a single shard.
            del shuffle_files
            del seed
            fname = shard_instruction["filename"]
            if "train" in fname:
                ds = get_fake_dataset("train")
                if fname.endswith("00000-of-00002"):
                    return ds.take(2)
                else:
                    return ds.skip(2)
            else:
                return get_fake_dataset("validation")

        fake_tfds = FakeLazyTfds(name="fake:0.0.0",
                                 load=get_fake_dataset,
                                 load_shard=_load_shard,
                                 info=FakeTfdsInfo(splits={
                                     "train": None,
                                     "validation": None
                                 }),
                                 files=fake_tfds_paths.get,
                                 size=lambda x: 30 if x == "train" else 10)
        self._tfds_patcher = mock.patch("t5.seqio.utils.LazyTfdsLoader",
                                        new=mock.Mock(return_value=fake_tfds))
        self._tfds_patcher.start()

        # Set up data directory.
        self.test_tmpdir = self.get_tempdir()
        self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
        shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
        # Make every directory (including each walk root, via the "" entry)
        # world-writable so tests can add cached files.
        for root, dirs, _ in os.walk(self.test_data_dir):
            for d in dirs + [""]:
                os.chmod(os.path.join(root, d), 0o777)

        # Prepare uncached TextLineTask.
        self.tfds_source = dataset_providers.TfdsDataSource(
            tfds_name="fake:0.0.0", splits=("train", "validation"))
        self.add_task("tfds_task", source=self.tfds_source)

        # Prepare TextLineSource.
        _dump_fake_dataset(os.path.join(self.test_data_dir,
                                        "train.tsv"), _FAKE_DATASET["train"],
                           [2, 1], _dump_examples_to_tsv)
        self.text_line_source = dataset_providers.TextLineDataSource(
            split_to_filepattern={
                "train": os.path.join(self.test_data_dir, "train.tsv*"),
            },
            skip_header_lines=1,
        )
        self.add_task("text_line_task",
                      source=self.text_line_source,
                      preprocessors=(split_tsv_preprocessor, ) +
                      self.DEFAULT_PREPROCESSORS)

        # Prepare TFExampleSource.
        _dump_fake_dataset(os.path.join(self.test_data_dir, "train.tfrecord"),
                           _FAKE_DATASET["train"], [2, 1],
                           _dump_examples_to_tfrecord)
        self.tf_example_source = dataset_providers.TFExampleDataSource(
            split_to_filepattern={
                "train": os.path.join(self.test_data_dir, "train.tfrecord*"),
            },
            feature_description={
                "prefix": tf.io.FixedLenFeature([], tf.string),
                "suffix": tf.io.FixedLenFeature([], tf.string),
            })
        self.add_task("tf_example_task", source=self.tf_example_source)

        # Prepare FunctionDataSource
        self.function_source = dataset_providers.FunctionDataSource(
            dataset_fn=get_fake_dataset, splits=["train", "validation"])
        self.add_task("function_task", source=self.function_source)

        # Prepare Task that is tokenized and preprocessed before caching.
        self.add_task("fully_processed_precache",
                      source=self.function_source,
                      preprocessors=(
                          test_text_preprocessor,
                          preprocessors.tokenize,
                          token_preprocessor_no_sequence_length,
                          dataset_providers.CacheDatasetPlaceholder(),
                      ))

        # Prepare Task that is tokenized after caching.
        self.add_task("tokenized_postcache",
                      source=self.function_source,
                      preprocessors=(
                          test_text_preprocessor,
                          dataset_providers.CacheDatasetPlaceholder(),
                          preprocessors.tokenize,
                          token_preprocessor_no_sequence_length,
                      ))

        # Prepare Task with randomization.
        self.random_task = self.add_task(
            "random_task",
            source=self.function_source,
            preprocessors=(
                test_text_preprocessor,
                dataset_providers.CacheDatasetPlaceholder(),
                preprocessors.tokenize,
                random_token_preprocessor,
            ))

        self.uncached_task = self.add_task("uncached_task",
                                           source=self.tfds_source)

        # Prepare cached task.
        # The dumped tfrecords below act as the pre-existing task cache.
        dataset_utils.set_global_cache_dirs([self.test_data_dir])
        self.cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
        _dump_fake_dataset(
            os.path.join(self.cached_task_dir,
                         "train.tfrecord"), _FAKE_TOKENIZED_DATASET["train"],
            [2, 1], _dump_examples_to_tfrecord)
        _dump_fake_dataset(
            os.path.join(self.cached_task_dir, "validation.tfrecord"),
            _FAKE_TOKENIZED_DATASET["validation"], [2],
            _dump_examples_to_tfrecord)
        self.cached_task = self.add_task("cached_task",
                                         source=self.tfds_source)

        # Prepare cached plaintext task.
        _dump_fake_dataset(
            os.path.join(self.test_data_dir, "cached_plaintext_task",
                         "train.tfrecord"),
            _FAKE_PLAINTEXT_TOKENIZED_DATASET["train"], [2, 1],
            _dump_examples_to_tfrecord)
        self.cached_plaintext_task = self.add_task(
            "cached_plaintext_task",
            source=self.tfds_source,
            preprocessors=self.DEFAULT_PREPROCESSORS +
            (test_token_preprocessor, ))
# Code example #6
class FakeTaskTest(absltest.TestCase):
    """TestCase that sets up fake cached and uncached tasks."""

    # Canonical pipeline used by most tasks in these tests: text preprocess
    # -> tokenize -> (optional) cache -> append EOS after trimming.
    DEFAULT_PREPROCESSORS = (test_text_preprocessor, preprocessors.tokenize,
                             dataset_providers.CacheDatasetPlaceholder(),
                             preprocessors.append_eos_after_trim)

    # Both output features share the same test sentencepiece vocabulary.
    DEFAULT_OUTPUT_FEATURES = {
        "inputs": dataset_providers.Feature(sentencepiece_vocab()),
        "targets": dataset_providers.Feature(sentencepiece_vocab())
    }

    def add_task(
            self,
            name,
            source,
            preprocessors=DEFAULT_PREPROCESSORS,  # pylint:disable=redefined-outer-name
            output_features=None,
            **kwargs):
        """Registers a task, filling in default output features when absent."""
        # Fall back to freshly built default features when none (or an empty
        # mapping) is supplied.
        output_features = output_features or {
            "inputs": dataset_providers.Feature(sentencepiece_vocab()),
            "targets": dataset_providers.Feature(sentencepiece_vocab())
        }
        return TaskRegistry.add(name,
                                source=source,
                                output_features=output_features,
                                preprocessors=preprocessors,
                                **kwargs)

    def get_tempdir(self):
        """Returns the path of a fresh temporary directory.

        Accessing `test_tmpdir` fails if absl flags were never parsed (as
        happens under `pytest`); in that case parse them from `sys.argv`
        before creating the directory.
        """
        try:
            _ = flags.FLAGS.test_tmpdir
        except flags.UnparsedFlagAccessError:
            flags.FLAGS(sys.argv)
        return self.create_tempdir().full_path

    def setUp(self):
        """Stages a writable copy of the test data and registers fake tasks."""
        super().setUp()
        self.maxDiff = None  # pylint:disable=invalid-name

        # Copy the checked-in test data into a temp dir we can mutate.
        self.test_tmpdir = self.get_tempdir()
        self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
        shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
        for root, dirs, _ in os.walk(self.test_data_dir):
            # The "" entry chmods `root` itself in addition to each subdir.
            for entry in dirs + [""]:
                os.chmod(os.path.join(root, entry), 0o777)

        self._prepare_sources_and_tasks()

    def _prepare_sources_and_tasks(self):
        """Registers one Task per data-source flavor plus cached variants.

        Clears the global Task/Mixture registries, patches `LazyTfdsLoader`
        with an in-memory fake, stages a writable copy of the test data
        directory, then registers tasks backed by TFDS, text-line,
        tf.train.Example, and function sources, along with pre-cache,
        post-cache, randomized, and cached task variants used by the tests.
        """
        clear_tasks()
        clear_mixtures()
        # Prepare TfdsSource
        # Note we don't use mock.Mock since they fail to pickle.
        # Shard layout reported by the fake loader: two train shards, one
        # validation shard.
        fake_tfds_paths = {
            "train": [
                {  # pylint:disable=g-complex-comprehension
                    "filename": "train.tfrecord-%05d-of-00002" % i,
                    "skip": 0,
                    "take": -1
                } for i in range(2)
            ],
            "validation": [{
                "filename": "validation.tfrecord-00000-of-00001",
                "skip": 0,
                "take": -1
            }],
        }

        def _load_shard(shard_instruction, shuffle_files, seed):
            # Serves one shard of the fake dataset: the first train shard
            # holds the first 2 examples, the second shard the remainder;
            # validation has a single shard.
            del shuffle_files
            del seed
            fname = shard_instruction["filename"]
            if "train" in fname:
                ds = get_fake_dataset("train")
                if fname.endswith("00000-of-00002"):
                    return ds.take(2)
                else:
                    return ds.skip(2)
            else:
                return get_fake_dataset("validation")

        fake_tfds = FakeLazyTfds(name="fake:0.0.0",
                                 load=get_fake_dataset,
                                 load_shard=_load_shard,
                                 info=FakeTfdsInfo(splits={
                                     "train": None,
                                     "validation": None
                                 }),
                                 files=fake_tfds_paths.get,
                                 size=lambda x: 30 if x == "train" else 10)
        self._tfds_patcher = mock.patch("t5.seqio.utils.LazyTfdsLoader",
                                        new=mock.Mock(return_value=fake_tfds))
        self._tfds_patcher.start()

        # Set up data directory.
        self.test_tmpdir = self.get_tempdir()
        self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
        shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
        # Make every directory (including each walk root, via the "" entry)
        # world-writable so tests can add cached files.
        for root, dirs, _ in os.walk(self.test_data_dir):
            for d in dirs + [""]:
                os.chmod(os.path.join(root, d), 0o777)

        # Prepare uncached TextLineTask.
        self.tfds_source = dataset_providers.TfdsDataSource(
            tfds_name="fake:0.0.0", splits=("train", "validation"))
        self.add_task("tfds_task", source=self.tfds_source)

        # Prepare TextLineSource.
        _dump_fake_dataset(os.path.join(self.test_data_dir,
                                        "train.tsv"), _FAKE_DATASET["train"],
                           [2, 1], _dump_examples_to_tsv)
        self.text_line_source = dataset_providers.TextLineDataSource(
            split_to_filepattern={
                "train": os.path.join(self.test_data_dir, "train.tsv*"),
            },
            skip_header_lines=1,
        )
        self.add_task("text_line_task",
                      source=self.text_line_source,
                      preprocessors=(split_tsv_preprocessor, ) +
                      self.DEFAULT_PREPROCESSORS)

        # Prepare TFExampleSource.
        _dump_fake_dataset(os.path.join(self.test_data_dir, "train.tfrecord"),
                           _FAKE_DATASET["train"], [2, 1],
                           _dump_examples_to_tfrecord)
        self.tf_example_source = dataset_providers.TFExampleDataSource(
            split_to_filepattern={
                "train": os.path.join(self.test_data_dir, "train.tfrecord*"),
            },
            feature_description={
                "prefix": tf.io.FixedLenFeature([], tf.string),
                "suffix": tf.io.FixedLenFeature([], tf.string),
            })
        self.add_task("tf_example_task", source=self.tf_example_source)

        # Prepare FunctionDataSource
        self.function_source = dataset_providers.FunctionDataSource(
            dataset_fn=get_fake_dataset, splits=["train", "validation"])
        self.add_task("function_task", source=self.function_source)

        # Prepare Task that is tokenized and preprocessed before caching.
        self.add_task("fully_processed_precache",
                      source=self.function_source,
                      preprocessors=(
                          test_text_preprocessor,
                          preprocessors.tokenize,
                          token_preprocessor_no_sequence_length,
                          dataset_providers.CacheDatasetPlaceholder(),
                      ))

        # Prepare Task that is tokenized after caching.
        self.add_task("tokenized_postcache",
                      source=self.function_source,
                      preprocessors=(
                          test_text_preprocessor,
                          dataset_providers.CacheDatasetPlaceholder(),
                          preprocessors.tokenize,
                          token_preprocessor_no_sequence_length,
                      ))

        # Prepare Task with randomization.
        self.random_task = self.add_task(
            "random_task",
            source=self.function_source,
            preprocessors=(
                test_text_preprocessor,
                dataset_providers.CacheDatasetPlaceholder(),
                preprocessors.tokenize,
                random_token_preprocessor,
            ))

        self.uncached_task = self.add_task("uncached_task",
                                           source=self.tfds_source)

        # Prepare cached task.
        # The dumped tfrecords below act as the pre-existing task cache.
        dataset_utils.set_global_cache_dirs([self.test_data_dir])
        self.cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
        _dump_fake_dataset(
            os.path.join(self.cached_task_dir,
                         "train.tfrecord"), _FAKE_TOKENIZED_DATASET["train"],
            [2, 1], _dump_examples_to_tfrecord)
        _dump_fake_dataset(
            os.path.join(self.cached_task_dir, "validation.tfrecord"),
            _FAKE_TOKENIZED_DATASET["validation"], [2],
            _dump_examples_to_tfrecord)
        self.cached_task = self.add_task("cached_task",
                                         source=self.tfds_source)

        # Prepare cached plaintext task.
        _dump_fake_dataset(
            os.path.join(self.test_data_dir, "cached_plaintext_task",
                         "train.tfrecord"),
            _FAKE_PLAINTEXT_TOKENIZED_DATASET["train"], [2, 1],
            _dump_examples_to_tfrecord)
        self.cached_plaintext_task = self.add_task(
            "cached_plaintext_task",
            source=self.tfds_source,
            preprocessors=self.DEFAULT_PREPROCESSORS +
            (test_token_preprocessor, ))

    def tearDown(self):
        """Resets the global TF random seed and stops the TFDS patcher."""
        super().tearDown()
        tf.random.set_seed(None)
        self._tfds_patcher.stop()

    def verify_task_matches_fake_datasets(  # pylint:disable=dangerous-default-value
            self,
            task_name,
            use_cached,
            token_preprocessed=False,
            splits=("train", "validation"),
            sequence_length=_DEFAULT_SEQUENCE_LENGTH,
            num_shards=None):
        """Assert all splits for both tokenized datasets are correct."""
        task = TaskRegistry.get(task_name)
        for split in splits:
            make_ds = functools.partial(task.get_dataset,
                                        sequence_length,
                                        split,
                                        use_cached=use_cached,
                                        shuffle=False)
            if num_shards:
                # Read every shard and stitch them back together in order.
                shard_datasets = [
                    make_ds(shard_info=dataset_providers.ShardInfo(
                        i, num_shards)) for i in range(num_shards)
                ]
                ds = functools.reduce(lambda a, b: a.concatenate(b),
                                      shard_datasets)
            else:
                ds = make_ds()
            _assert_compare_to_fake_dataset(
                ds,
                split,
                task.output_features,
                sequence_length,
                token_preprocessed=token_preprocessed,
            )