def test_optional_features(self):
  """Optional output features may be dropped by preprocessing; required may not."""

  def _const_preprocessor(example):
    # Builds a text preprocessor that ignores the incoming dataset and always
    # yields `example` as a single-element dataset.
    return lambda _: tf.data.Dataset.from_tensors(example)

  vocab = test_utils.sentencepiece_vocab()
  output_features = {
      "inputs": dataset_providers.Feature(vocabulary=vocab, required=False),
      "targets": dataset_providers.Feature(vocabulary=vocab, required=True),
  }

  # Dropping the optional "inputs" feature must succeed.
  test_utils.add_task(
      "text_missing_optional_feature",
      test_utils.get_fake_dataset,
      output_features=output_features,
      text_preprocessor=_const_preprocessor({"targets": "a"}))
  TaskRegistry.get_dataset(
      "text_missing_optional_feature", {"targets": 13},
      "train", use_cached=False)

  # Dropping the required "targets" feature must raise.
  test_utils.add_task(
      "text_missing_required_feature",
      test_utils.get_fake_dataset,
      output_features=output_features,
      text_preprocessor=_const_preprocessor({"inputs": "a"}))
  expected_error = (
      "Task dataset is missing expected output feature after preprocessing: "
      "targets")
  with self.assertRaisesRegex(ValueError, expected_error):
    TaskRegistry.get_dataset(
        "text_missing_required_feature", {"inputs": 13},
        "train", use_cached=False)
 def test_no_eos(self):
   """A task whose targets feature disables EOS still matches the fake data."""
   vocab = test_utils.sentencepiece_vocab()
   # "inputs" gets an EOS token appended; "targets" does not.
   feats = {
       "inputs": dataset_providers.Feature(add_eos=True, vocabulary=vocab),
       "targets": dataset_providers.Feature(add_eos=False, vocabulary=vocab),
   }
   test_utils.add_task(
       "task_no_eos", test_utils.get_fake_dataset, output_features=feats)
   no_eos_task = TaskRegistry.get("task_no_eos")
   test_utils.verify_task_matches_fake_datasets(no_eos_task, use_cached=False)
# Example #3
 def test_dtype(self):
     """Per-feature dtype overrides are honored (default remains int32)."""
     vocab = test_utils.sentencepiece_vocab()
     feats = {
         # "inputs" relies on the Feature default dtype (int32).
         "inputs": dataset_providers.Feature(vocabulary=vocab),
         # "targets" explicitly requests int64.
         "targets": dataset_providers.Feature(dtype=tf.int64, vocabulary=vocab),
     }
     test_utils.add_task(
         "task_dtypes", test_utils.get_fake_dataset, output_features=feats)
     task_with_dtypes = TaskRegistry.get("task_dtypes")
     test_utils.verify_task_matches_fake_datasets(
         task_with_dtypes, use_cached=False)
# Example #4
def add_tfds_task(name,
                  tfds_name="fake:0.0.0",
                  text_preprocessor=test_text_preprocessor,
                  token_preprocessor=None,
                  splits=None):
  """Registers a TfdsTask named `name`, backed by the fake TFDS dataset.

  All registered tasks share a sentencepiece-vocab output feature and no
  metric functions.
  """
  shared_feature = dataset_providers.Feature(sentencepiece_vocab())
  TaskRegistry.add(
      name,
      dataset_providers.TfdsTask,
      tfds_name=tfds_name,
      text_preprocessor=text_preprocessor,
      token_preprocessor=token_preprocessor,
      output_features=shared_feature,
      metric_fns=[],
      splits=splits)
# Example #5
def add_task(name,
             dataset_fn,
             text_preprocessor=test_text_preprocessor,
             token_preprocessor=None,
             splits=("train", "validation"),
             **kwargs):
  """Registers a function-backed Task named `name`.

  Supplies a default sentencepiece-vocab output feature only when the caller
  did not pass `output_features` (the guard avoids building the vocabulary
  unnecessarily).
  """
  if "output_features" not in kwargs:
    kwargs["output_features"] = dataset_providers.Feature(
        sentencepiece_vocab())
  TaskRegistry.add(
      name,
      dataset_fn=dataset_fn,
      splits=splits,
      text_preprocessor=text_preprocessor,
      token_preprocessor=token_preprocessor,
      metric_fns=[],
      **kwargs)
# Example #6
def register_dummy_task(
    task_name: str,
    dataset_fn: Callable[[str, str], tf.data.Dataset],
    output_feature_names: Sequence[str] = ("inputs", "targets")) -> None:
  """Register a dummy task for GetDatasetTest."""
  # Each feature gets its own vocab instance, matching the original behavior.
  features = {
      feature_name: dataset_providers.Feature(test_utils.sentencepiece_vocab())
      for feature_name in output_feature_names
  }
  source = dataset_providers.FunctionSource(
      dataset_fn=dataset_fn, splits=["train", "validation"])
  dataset_providers.TaskRegistry.add(
      task_name,
      dataset_providers.TaskV3,
      source=source,
      preprocessors=[dataset_providers.CacheDatasetPlaceholder()],
      output_features=features,
      metric_fns=[])
# Example #7
    def test_v3_value_errors(self):
        """TaskV3 rejects invalid preprocessor-pipeline configurations."""
        make_source_dataset = (
            lambda split, shuffle_files: tf.data.Dataset.from_tensors(["test"])
        )
        features = {
            "inputs":
                dataset_providers.Feature(test_utils.sentencepiece_vocab()),
        }

        # At most one CacheDatasetPlaceholder is permitted per pipeline.
        duplicate_cache_msg = (
            "`CacheDatasetPlaceholder` can appear at most once in the "
            "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'."
        )
        with self.assertRaisesRegex(ValueError, duplicate_cache_msg):
            dataset_providers.TaskV3(
                "multiple_cache_placeholders",
                source=dataset_providers.FunctionSource(
                    dataset_fn=make_source_dataset,
                    splits=["train", "validation"]),
                preprocessors=[
                    test_utils.test_text_preprocessor,
                    preprocessors.tokenize,
                    dataset_providers.CacheDatasetPlaceholder(),
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder(),
                ],
                output_features=features,
                metric_fns=[])

        # A sequence_length-dependent preprocessor may not precede the cache.
        pre_cache_msg = (
            "'test_token_preprocessor' has a `sequence_length` argument but occurs "
            "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
            "is not allowed since the sequence length is specified at run time."
        )
        with self.assertRaisesRegex(ValueError, pre_cache_msg):
            dataset_providers.TaskV3(
                "sequence_length_pre_cache",
                dataset_providers.FunctionSource(
                    dataset_fn=make_source_dataset,
                    splits=["train"],
                ),
                preprocessors=[
                    test_utils.test_text_preprocessor,
                    preprocessors.tokenize,
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder(),
                ],
                output_features=features,
                metric_fns=[])
# Example #8
    def setUp(self):
        """Builds the shared fixture: a mocked TFDS backend plus one cached
        and several uncached Tasks registered in a fresh TaskRegistry.
        """
        super().setUp()
        self.maxDiff = None  # pylint:disable=invalid-name

        # Mock TFDS
        # Note we don't use mock.Mock since they fail to pickle.
        # Two train shards and one validation shard; "skip"/"take" mirror the
        # shard-instruction dicts the real loader hands to load_shard.
        fake_tfds_paths = {
            "train": [
                {  # pylint:disable=g-complex-comprehension
                    "filename": "train.tfrecord-%05d-of-00002" % i,
                    "skip": 0,
                    "take": -1
                } for i in range(2)
            ],
            "validation": [{
                "filename": "validation.tfrecord-00000-of-00001",
                "skip": 0,
                "take": -1
            }],
        }

        def _load_shard(shard_instruction):
            # Maps a shard instruction back onto a slice of the fake dataset:
            # the first train shard serves the first 2 examples, the second
            # shard serves the remainder, and validation is served whole.
            fname = shard_instruction["filename"]
            if "train" in fname:
                if fname.endswith("00000-of-00002"):
                    return get_fake_dataset("train").take(2)
                else:
                    return get_fake_dataset("train").skip(2)
            else:
                return get_fake_dataset("validation")

        fake_tfds = FakeLazyTfds(name="fake:0.0.0",
                                 load=get_fake_dataset,
                                 load_shard=_load_shard,
                                 info=FakeTfdsInfo(splits={
                                     "train": None,
                                     "validation": None
                                 }),
                                 files=fake_tfds_paths.get,
                                 size=lambda x: 30 if x == "train" else 10)
        # Patch the lazy TFDS loader so every Task reads the fake data above.
        # NOTE(review): the patcher is started here; presumably stopped in
        # tearDown — confirm.
        self._tfds_patcher = mock.patch("t5.data.utils.LazyTfdsLoader",
                                        new=mock.Mock(return_value=fake_tfds))
        self._tfds_patcher.start()

        # Set up data directory.
        self.test_tmpdir = self.get_tempdir()
        self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
        shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
        # Make the copied tree writable so tests can dump new files into it.
        for root, dirs, _ in os.walk(self.test_data_dir):
            for d in dirs + [""]:
                os.chmod(os.path.join(root, d), 0o777)

        # Register a cached test Task.
        dataset_utils.set_global_cache_dirs([self.test_data_dir])
        clear_tasks()
        add_tfds_task("cached_task")

        # Prepare cached task.
        self.cached_task = TaskRegistry.get("cached_task")
        cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
        # Shard layout: train is written as shards of [2, 1] examples,
        # validation as a single 2-example shard.
        _dump_fake_dataset(os.path.join(cached_task_dir, "train.tfrecord"),
                           _FAKE_TOKENIZED_DATASET["train"], [2, 1],
                           _dump_examples_to_tfrecord)
        _dump_fake_dataset(
            os.path.join(cached_task_dir, "validation.tfrecord"),
            _FAKE_TOKENIZED_DATASET["validation"], [2],
            _dump_examples_to_tfrecord)

        # Prepare uncached TfdsTask.
        add_tfds_task("uncached_task")
        self.uncached_task = TaskRegistry.get("uncached_task")

        # Prepare uncached, random TfdsTask
        add_tfds_task("uncached_random_task",
                      token_preprocessor=random_token_preprocessor)
        self.uncached_random_task = TaskRegistry.get("uncached_random_task")

        # Prepare uncached TextLineTask.
        _dump_fake_dataset(os.path.join(self.test_data_dir,
                                        "train.tsv"), _FAKE_DATASET["train"],
                           [2, 1], _dump_examples_to_tsv)
        TaskRegistry.add("text_line_task",
                         dataset_providers.TextLineTask,
                         split_to_filepattern={
                             "train":
                             os.path.join(self.test_data_dir, "train.tsv*"),
                         },
                         skip_header_lines=1,
                         text_preprocessor=[
                             _split_tsv_preprocessor, test_text_preprocessor
                         ],
                         output_features=dataset_providers.Feature(
                             sentencepiece_vocab()),
                         metric_fns=[])
        self.text_line_task = TaskRegistry.get("text_line_task")

        # Prepare uncached TFExampleTask
        _dump_fake_dataset(os.path.join(self.test_data_dir, "train.tfrecord"),
                           _FAKE_DATASET["train"], [2, 1],
                           _dump_examples_to_tfrecord)
        TaskRegistry.add(
            "tf_example_task",
            dataset_providers.TFExampleTask,
            split_to_filepattern={
                "train": os.path.join(self.test_data_dir, "train.tfrecord*"),
            },
            feature_description={
                "prefix": tf.io.FixedLenFeature([], tf.string),
                "suffix": tf.io.FixedLenFeature([], tf.string),
            },
            text_preprocessor=[test_text_preprocessor],
            output_features=dataset_providers.Feature(sentencepiece_vocab()),
            metric_fns=[])
        self.tf_example_task = TaskRegistry.get("tf_example_task")

        # Prepare uncached Task.
        def _dataset_fn(split, shuffle_files):
            # Always reads the train TSV regardless of the requested split.
            del split
            del shuffle_files
            filepattern = os.path.join(self.test_data_dir, "train.tsv*")
            return tf.data.TextLineDataset(filepattern)

        TaskRegistry.add("general_task",
                         dataset_providers.Task,
                         dataset_fn=_dataset_fn,
                         splits=["train"],
                         text_preprocessor=[
                             _split_tsv_preprocessor, test_text_preprocessor
                         ],
                         output_features=dataset_providers.Feature(
                             sentencepiece_vocab()),
                         metric_fns=[])
        self.general_task = TaskRegistry.get("general_task")