Example #1
    def test_v3_value_errors(self):
        dataset_fn = lambda split, shuffle_files: tf.data.Dataset.from_tensors(
            ["test"])
        output_features = {
            "inputs":
            dataset_providers.Feature(test_utils.sentencepiece_vocab())
        }

        with self.assertRaisesRegex(
                ValueError,
                "`CacheDatasetPlaceholder` can appear at most once in the "
                "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'."
        ):
            dataset_providers.TaskV3(
                "multiple_cache_placeholders",
                source=dataset_providers.FunctionSource(
                    dataset_fn=dataset_fn, splits=["train", "validation"]),
                preprocessors=[
                    test_utils.test_text_preprocessor, preprocessors.tokenize,
                    dataset_providers.CacheDatasetPlaceholder(),
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder()
                ],
                output_features=output_features,
                metric_fns=[])

        with self.assertRaisesRegex(
                ValueError,
                "'test_token_preprocessor' has a `sequence_length` argument but occurs "
                "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
                "is not allowed since the sequence length is specified at run time."
        ):
            dataset_providers.TaskV3(
                "sequence_length_pre_cache",
                dataset_providers.FunctionSource(
                    dataset_fn=dataset_fn,
                    splits=["train"],
                ),
                preprocessors=[
                    test_utils.test_text_preprocessor, preprocessors.tokenize,
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder()
                ],
                output_features=output_features,
                metric_fns=[])
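
For contrast, here is a minimal sketch of a construction that passes both checks above, assuming the same test helpers: a single CacheDatasetPlaceholder, placed before the only preprocessor that takes a sequence_length argument (the task name is hypothetical).

# Sketch only: a valid ordering for the same pipeline. One
# CacheDatasetPlaceholder, with the sequence-length-dependent preprocessor
# placed after it, so neither ValueError above is raised.
dataset_providers.TaskV3(
    "valid_cache_placeholder",
    source=dataset_providers.FunctionSource(
        dataset_fn=dataset_fn, splits=["train", "validation"]),
    preprocessors=[
        test_utils.test_text_preprocessor,
        preprocessors.tokenize,
        dataset_providers.CacheDatasetPlaceholder(),
        test_utils.test_token_preprocessor,
    ],
    output_features=output_features,
    metric_fns=[])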
Example #2
def register_dummy_task(
    task_name: str,
    dataset_fn: Callable[[str, bool], tf.data.Dataset],
    output_feature_names: Sequence[str] = ("inputs", "targets")) -> None:
  """Register a dummy task for GetDatasetTest."""
  dataset_providers.TaskRegistry.add(
      task_name,
      dataset_providers.TaskV3,
      source=dataset_providers.FunctionSource(
          dataset_fn=dataset_fn, splits=["train", "validation"]),
      preprocessors=[dataset_providers.CacheDatasetPlaceholder()],
      output_features={
          feat: dataset_providers.Feature(test_utils.sentencepiece_vocab())
          for feat in output_feature_names
      },
      metric_fns=[])
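
A hedged usage sketch for register_dummy_task: the task name and dataset_fn below are illustrative stand-ins, not taken from the test (the stand-in returns pre-tokenized features, since the preprocessor list contains no tokenizer). The setUp fixture that follows builds the shared cached and uncached test tasks.

# Illustrative only; the name and dataset_fn are assumptions for this sketch.
register_dummy_task(
    "dummy_task",
    dataset_fn=lambda split, shuffle_files: tf.data.Dataset.from_tensors(
        {"inputs": [7, 8, 5, 1], "targets": [3, 9, 1]}))
dummy_task = dataset_providers.TaskRegistry.get("dummy_task")
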
    def setUp(self):
        super().setUp()
        self.maxDiff = None  # pylint:disable=invalid-name

        # Mock TFDS
        # Note: we don't use mock.Mock since mock objects fail to pickle.
        fake_tfds_paths = {
            "train": [
                {  # pylint:disable=g-complex-comprehension
                    "filename": "train.tfrecord-%05d-of-00002" % i,
                    "skip": 0,
                    "take": -1
                } for i in range(2)
            ],
            "validation": [{
                "filename": "validation.tfrecord-00000-of-00001",
                "skip": 0,
                "take": -1
            }],
        }

        def _load_shard(shard_instruction, shuffle_files, seed):
            del shuffle_files
            del seed
            fname = shard_instruction["filename"]
            if "train" in fname:
                if fname.endswith("00000-of-00002"):
                    return get_fake_dataset("train").take(2)
                else:
                    return get_fake_dataset("train").skip(2)
            else:
                return get_fake_dataset("validation")

        fake_tfds = FakeLazyTfds(name="fake:0.0.0",
                                 load=get_fake_dataset,
                                 load_shard=_load_shard,
                                 info=FakeTfdsInfo(splits={
                                     "train": None,
                                     "validation": None
                                 }),
                                 files=fake_tfds_paths.get,
                                 size=lambda x: 30 if x == "train" else 10)
        self._tfds_patcher = mock.patch("t5.data.utils.LazyTfdsLoader",
                                        new=mock.Mock(return_value=fake_tfds))
        self._tfds_patcher.start()

        # Set up data directory.
        self.test_tmpdir = self.get_tempdir()
        self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
        shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
        for root, dirs, _ in os.walk(self.test_data_dir):
            for d in dirs + [""]:
                os.chmod(os.path.join(root, d), 0o777)

        # Register a cached test Task.
        dataset_utils.set_global_cache_dirs([self.test_data_dir])
        clear_tasks()
        add_tfds_task("cached_task",
                      token_preprocessor=test_token_preprocessor)
        add_tfds_task("cached_task_no_token_prep")

        # Prepare cached tasks.
        self.cached_task = TaskRegistry.get("cached_task")
        cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
        _dump_fake_dataset(os.path.join(cached_task_dir, "train.tfrecord"),
                           _FAKE_TOKENIZED_DATASET["train"], [2, 1],
                           _dump_examples_to_tfrecord)
        _dump_fake_dataset(
            os.path.join(cached_task_dir, "validation.tfrecord"),
            _FAKE_TOKENIZED_DATASET["validation"], [2],
            _dump_examples_to_tfrecord)
        shutil.copytree(
            cached_task_dir,
            os.path.join(self.test_data_dir, "cached_task_no_token_prep"))

        # Prepare uncached TfdsTask.
        add_tfds_task("uncached_task",
                      token_preprocessor=test_token_preprocessor)
        add_tfds_task("uncached_task_no_token_prep")
        self.uncached_task = TaskRegistry.get("uncached_task")
        # Prepare uncached, random TfdsTask
        add_tfds_task("uncached_random_task",
                      token_preprocessor=random_token_preprocessor)
        self.uncached_random_task = TaskRegistry.get("uncached_random_task")

        # Prepare uncached TextLineTask.
        _dump_fake_dataset(
            os.path.join(self.test_data_dir, "train.tsv"),
            _FAKE_DATASET["train"], [2, 1], _dump_examples_to_tsv)
        TaskRegistry.add("text_line_task",
                         dataset_providers.TextLineTask,
                         split_to_filepattern={
                             "train":
                             os.path.join(self.test_data_dir, "train.tsv*"),
                         },
                         skip_header_lines=1,
                         text_preprocessor=[
                             _split_tsv_preprocessor, test_text_preprocessor
                         ],
                         output_features=dataset_providers.Feature(
                             sentencepiece_vocab()),
                         metric_fns=[])
        self.text_line_task = TaskRegistry.get("text_line_task")

        # Prepare uncached TFExampleTask
        _dump_fake_dataset(os.path.join(self.test_data_dir, "train.tfrecord"),
                           _FAKE_DATASET["train"], [2, 1],
                           _dump_examples_to_tfrecord)
        TaskRegistry.add(
            "tf_example_task",
            dataset_providers.TFExampleTask,
            split_to_filepattern={
                "train": os.path.join(self.test_data_dir, "train.tfrecord*"),
            },
            feature_description={
                "prefix": tf.io.FixedLenFeature([], tf.string),
                "suffix": tf.io.FixedLenFeature([], tf.string),
            },
            text_preprocessor=[test_text_preprocessor],
            output_features=dataset_providers.Feature(sentencepiece_vocab()),
            metric_fns=[])
        self.tf_example_task = TaskRegistry.get("tf_example_task")

        # Prepare uncached Task.
        def _dataset_fn(split,
                        shuffle_files,
                        filepattern=os.path.join(self.test_data_dir,
                                                 "train.tsv*")):
            del split
            files = tf.data.Dataset.list_files(filepattern,
                                               shuffle=shuffle_files)
            return files.interleave(
                lambda f: tf.data.TextLineDataset(f).skip(1),
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

        TaskRegistry.add("general_task",
                         dataset_providers.Task,
                         dataset_fn=_dataset_fn,
                         splits=["train"],
                         text_preprocessor=[
                             _split_tsv_preprocessor, test_text_preprocessor
                         ],
                         output_features=dataset_providers.Feature(
                             sentencepiece_vocab()),
                         metric_fns=[])
        self.general_task = TaskRegistry.get("general_task")

        # Prepare uncached TaskV3.
        TaskRegistry.add("task_v3",
                         dataset_providers.TaskV3,
                         source=dataset_providers.FunctionSource(
                             dataset_fn=get_fake_dataset,
                             splits=["train", "validation"]),
                         preprocessors=[
                             test_text_preprocessor,
                             preprocessors.tokenize,
                             token_preprocessor_no_sequence_length,
                             dataset_providers.CacheDatasetPlaceholder(),
                         ],
                         output_features={
                             "inputs":
                             dataset_providers.Feature(sentencepiece_vocab()),
                             "targets":
                             dataset_providers.Feature(sentencepiece_vocab()),
                         },
                         metric_fns=[])
        self.task_v3 = TaskRegistry.get("task_v3")

        # Prepare uncached TaskV3 that tokenizes after the cache step.
        TaskRegistry.add("task_v3_tokenized_postcache",
                         dataset_providers.TaskV3,
                         source=dataset_providers.FunctionSource(
                             dataset_fn=get_fake_dataset,
                             splits=["train", "validation"]),
                         preprocessors=[
                             test_text_preprocessor,
                             dataset_providers.CacheDatasetPlaceholder(),
                             preprocessors.tokenize,
                             token_preprocessor_no_sequence_length,
                         ],
                         output_features={
                             "inputs":
                             dataset_providers.Feature(sentencepiece_vocab()),
                             "targets":
                             dataset_providers.Feature(sentencepiece_vocab()),
                         },
                         metric_fns=[])
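
A hedged usage note: once setUp has run, a registered TaskV3 such as "task_v3" would typically be consumed through TaskRegistry.get(...).get_dataset(...). The sequence lengths and split below are illustrative values, not taken from the snippet.

# Illustrative only: materialize the validation split of the registered task.
task = TaskRegistry.get("task_v3")
ds = task.get_dataset(
    sequence_length={"inputs": 13, "targets": 13},
    split="validation",
    shuffle=False)
print(ds.element_spec)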