def test_no_eos(self):
    """Registers a task whose "targets" feature has `add_eos=False` and verifies it.

    "inputs" keeps `add_eos=True`; both features share one sentencepiece
    vocabulary. The task is checked uncached via
    `verify_task_matches_fake_datasets`.
    """
    vocab = test_utils.sentencepiece_vocab()
    features = dict(
        inputs=dataset_providers.Feature(add_eos=True, vocabulary=vocab),
        targets=dataset_providers.Feature(add_eos=False, vocabulary=vocab),
    )
    task_name = "task_no_eos"
    self.add_task(task_name, self.function_source, output_features=features)
    self.verify_task_matches_fake_datasets(task_name, use_cached=False)
Exemplo n.º 2
0
    def test_feature_validation(self):
        """Checks output-feature validation performed after preprocessing.

        "inputs" is optional and "targets" is required. A missing optional
        feature is accepted; a missing required feature, a wrong dtype, or a
        wrong rank must each raise `ValueError` with the exact message below.
        """
        vocab = test_utils.sentencepiece_vocab()
        features = {
            "inputs": dataset_providers.Feature(vocabulary=vocab,
                                                required=False),
            "targets": dataset_providers.Feature(vocabulary=vocab,
                                                 required=True),
        }

        def _run_pipeline(example):
            # Build a task whose only preprocessor replaces the source data
            # with the single `example`, then iterate the dataset fully so
            # validation actually runs.
            task = dataset_providers.Task(
                "feature_validation_task",
                self.function_source,
                output_features=features,
                preprocessors=(
                    lambda _: tf.data.Dataset.from_tensors(example), ),
                metric_fns=[],
            )
            dataset = task.get_dataset({
                "inputs": 13,
                "targets": 13
            },
                                       "train",
                                       use_cached=False)
            for _ in dataset.as_numpy_iterator():
                pass

        # An optional feature may be absent.
        _run_pipeline({"targets": [0]})

        invalid_cases = (
            # Required feature is absent.
            ({"inputs": [0]},
             "Task dataset is missing expected output feature after preprocessing: "
             "targets"),
            # Wrong dtype.
            ({"targets": ["wrong type"]},
             "Task dataset has incorrect type for feature 'targets' after "
             "preprocessing: Got string, expected int32"),
            # Wrong rank.
            ({"targets": 0},
             "Task dataset has incorrect rank for feature 'targets' after "
             "preprocessing: Got 0, expected 1"),
        )
        for example, expected_message in invalid_cases:
            with self.assertRaisesWithLiteralMatch(ValueError,
                                                   expected_message):
                _run_pipeline(example)
Exemplo n.º 3
0
 def _task_from_tensor_slices(name, tensor_slices, label_classes):
   """Builds a `Task` over in-memory tensor slices with a "validation" split.

   Args:
     name: the task name.
     tensor_slices: passed to `tf.data.Dataset.from_tensor_slices`. Each
       example must provide "inputs_lengths", "targets_lengths" and
       "targets_pretokenized" (they are read by the preprocessor below).
     label_classes: forwarded to `_string_label_to_class_id_postprocessor`.

   Returns:
     A `dataset_providers.Task` whose "inputs"/"targets" are `tf.range` of the
     corresponding lengths, and whose output features use mocked vocabularies.
   """
   return dataset_providers.Task(
       name,
       dataset_providers.FunctionDataSource(
           lambda split, shuffle_files:
           tf.data.Dataset.from_tensor_slices(tensor_slices),
           # Must be a one-element tuple: `("validation")` is just the string
           # "validation", which anything iterating `splits` would see as
           # individual characters rather than one split name.
           splits=("validation",)),
       preprocessors=[utils.map_over_dataset(lambda ex: {
           "inputs": tf.range(ex["inputs_lengths"]),
           "targets": tf.range(ex["targets_lengths"]),
           "targets_pretokenized": ex["targets_pretokenized"],
       })],
       postprocess_fn=functools.partial(
           _string_label_to_class_id_postprocessor,
           label_classes=label_classes),
       output_features={"inputs": dataset_providers.Feature(mock.Mock()),
                        "targets": dataset_providers.Feature(mock.Mock())}
   )
Exemplo n.º 4
0
    def add_task(
            self,
            name,
            source,
            preprocessors=DEFAULT_PREPROCESSORS,  # pylint:disable=redefined-outer-name
            output_features=None,
            **kwargs):
        """Registers a task in `TaskRegistry`, supplying default features.

        Args:
          name: name to register the task under.
          source: the task's data source.
          preprocessors: preprocessing pipeline (defaults to
            `DEFAULT_PREPROCESSORS`).
          output_features: mapping of feature name to `Feature`. Any falsy value
            (None or an empty mapping) is replaced by "inputs"/"targets"
            features built on `sentencepiece_vocab()`.
          **kwargs: forwarded unchanged to `TaskRegistry.add`.

        Returns:
          The value returned by `TaskRegistry.add`.
        """
        features = output_features or {
            "inputs": dataset_providers.Feature(sentencepiece_vocab()),
            "targets": dataset_providers.Feature(sentencepiece_vocab()),
        }
        return TaskRegistry.add(
            name,
            source=source,
            preprocessors=preprocessors,
            output_features=features,
            **kwargs)
Exemplo n.º 5
0
    def test_dtype(self):
        """Checks a task whose "targets" feature is declared as int64.

        "inputs" keeps the default dtype (int32), while "targets" uses
        `tf.int64`. An extra preprocessor appended after the defaults casts
        "targets" so that the dataset matches the declared features.
        """
        vocab = test_utils.sentencepiece_vocab()
        features = {
            # Defaults to int32.
            "inputs": dataset_providers.Feature(vocabulary=vocab),
            "targets": dataset_providers.Feature(dtype=tf.int64,
                                                 vocabulary=vocab),
        }

        def _targets_to_int64(example):
            # Cast only "targets"; every other feature passes through as-is.
            return {
                key: tf.cast(value, tf.int64) if key == "targets" else value
                for key, value in example.items()
            }

        cast_step = utils.map_over_dataset(_targets_to_int64)
        self.add_task(
            "task_dtypes",
            self.function_source,
            preprocessors=self.DEFAULT_PREPROCESSORS + (cast_step, ),
            output_features=features)
        self.verify_task_matches_fake_datasets("task_dtypes", use_cached=False)
Exemplo n.º 6
0
def get_mocked_task(
    name: str = "mocked_test",
    predict_metric_fns: Sequence[Callable] = (_sequence_accuracy_metric,),
    score_metric_fns: Sequence[Callable] = ()) -> mock.Mock:
  """Returns a `mock.Mock` configured to stand in for a Task.

  The mock exposes `name`, `predict_metric_fns`, `score_metric_fns`,
  `metric_fns` (predict fns followed by score fns, as fresh lists), an identity
  `postprocess_fn`, and a single "targets" output feature with a mocked
  vocabulary.
  """
  predict_fns = list(predict_metric_fns)
  score_fns = list(score_metric_fns)

  def _identity_postprocess(d, example, is_target):
    # Ignores `example` and `is_target`; returns the decoded value unchanged.
    return d

  task = mock.Mock()
  task.name = name
  task.score_metric_fns = score_fns
  task.predict_metric_fns = predict_fns
  task.metric_fns = predict_fns + score_fns
  task.postprocess_fn = _identity_postprocess
  task.output_features = {"targets": dataset_providers.Feature(mock.Mock())}
  return task
Exemplo n.º 7
0
    def test_value_errors(self):
        """Checks that `Task` rejects invalid cache-placeholder pipelines.

        Two cases must raise `ValueError` at construction time:
        1. A pipeline containing two `CacheDatasetPlaceholder`s.
        2. A preprocessor that takes `sequence_length` placed before the cache
           placeholder.
        """

        def dataset_fn(split, shuffle_files):
            del split, shuffle_files  # Unused.
            return tf.data.Dataset.from_tensors(["test"])

        output_features = {
            "inputs":
            dataset_providers.Feature(test_utils.sentencepiece_vocab())
        }

        duplicate_cache_message = (
            "`CacheDatasetPlaceholder` can appear at most once in the "
            "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'."
        )
        with self.assertRaisesWithLiteralMatch(ValueError,
                                               duplicate_cache_message):
            dataset_providers.Task(
                "multiple_cache_placeholders",
                source=dataset_providers.FunctionDataSource(
                    dataset_fn=dataset_fn, splits=["train", "validation"]),
                preprocessors=[
                    test_utils.test_text_preprocessor,
                    preprocessors.tokenize,
                    dataset_providers.CacheDatasetPlaceholder(),
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder(),
                ],
                output_features=output_features,
                metric_fns=[])

        pre_cache_length_message = (
            "'test_token_preprocessor' has a `sequence_length` argument but occurs "
            "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
            "is not allowed since the sequence length is specified at run time."
        )
        with self.assertRaisesWithLiteralMatch(ValueError,
                                               pre_cache_length_message):
            dataset_providers.Task(
                "sequence_length_pre_cache",
                source=dataset_providers.FunctionDataSource(
                    dataset_fn=dataset_fn, splits=["train"]),
                preprocessors=[
                    test_utils.test_text_preprocessor,
                    preprocessors.tokenize,
                    test_utils.test_token_preprocessor,
                    dataset_providers.CacheDatasetPlaceholder(),
                ],
                output_features=output_features,
                metric_fns=[])
def register_dummy_task(
    task_name: str,
    dataset_fn: Callable[[str, str], tf.data.Dataset],
    output_feature_names: Sequence[str] = ("inputs", "targets")) -> None:
  """Register a dummy task for GetDatasetTest.

  The task reads `dataset_fn` for the "train" and "validation" splits, runs a
  `CacheDatasetPlaceholder` followed by `append_eos_after_trim`, gives every
  name in `output_feature_names` a sentencepiece-vocabulary `Feature`, and has
  no metric functions.
  """
  source = dataset_providers.FunctionDataSource(
      dataset_fn=dataset_fn, splits=["train", "validation"])
  pipeline = [
      dataset_providers.CacheDatasetPlaceholder(),
      preprocessors.append_eos_after_trim,
  ]
  features = {
      feature_name: dataset_providers.Feature(test_utils.sentencepiece_vocab())
      for feature_name in output_feature_names
  }
  dataset_providers.TaskRegistry.add(
      task_name,
      source=source,
      preprocessors=pipeline,
      output_features=features,
      metric_fns=[])
Exemplo n.º 9
0
def register_dummy_task(task_name: str,
                        dataset_fn: Callable[[str, str], tf.data.Dataset],
                        output_feature_names: Sequence[str] = ("inputs",
                                                               "targets"),
                        postprocess_fn=None,
                        metrics_fn=None) -> None:
    """Register a dummy task for GetDatasetTest.

    The task reads `dataset_fn` for the "train" and "validation" splits and
    applies only `append_eos_after_trim`. Every name in
    `output_feature_names` gets a `Feature` backed by a mocked vocabulary.
    `postprocess_fn` and `metrics_fn` are passed through to
    `TaskRegistry.add` as `postprocess_fn` and `metric_fns`.
    """
    source = dataset_providers.FunctionDataSource(
        dataset_fn=dataset_fn, splits=["train", "validation"])
    pipeline = [preprocessors.append_eos_after_trim]
    # Mock the sentencepiece vocabulary (a Mock whose `eos_id` is True).
    features = {
        feature_name: dataset_providers.Feature(mock.Mock(eos_id=True))
        for feature_name in output_feature_names
    }
    dataset_providers.TaskRegistry.add(
        task_name,
        source=source,
        preprocessors=pipeline,
        postprocess_fn=postprocess_fn,
        output_features=features,
        metric_fns=metrics_fn)
Exemplo n.º 10
0
class FakeTaskTest(absltest.TestCase):
    """TestCase that sets up fake cached and uncached tasks.

    `setUp` copies `TEST_DATA_DIR` into a fresh temporary directory, patches
    `t5.seqio.utils.LazyTfdsLoader` to return an in-memory fake, and registers
    TFDS-, text-line-, TFExample- and function-backed tasks (plus cached
    variants) in `TaskRegistry`. `tearDown` stops the patch and clears the
    global TF random seed.
    """

    DEFAULT_PREPROCESSORS = (test_text_preprocessor, preprocessors.tokenize,
                             dataset_providers.CacheDatasetPlaceholder(),
                             preprocessors.append_eos_after_trim)

    DEFAULT_OUTPUT_FEATURES = {
        "inputs": dataset_providers.Feature(sentencepiece_vocab()),
        "targets": dataset_providers.Feature(sentencepiece_vocab())
    }

    def add_task(
            self,
            name,
            source,
            preprocessors=DEFAULT_PREPROCESSORS,  # pylint:disable=redefined-outer-name
            output_features=None,
            **kwargs):
        """Registers a task in `TaskRegistry`.

        Args:
          name: name to register the task under.
          source: the task's data source.
          preprocessors: preprocessing pipeline (defaults to
            `DEFAULT_PREPROCESSORS`).
          output_features: mapping of feature name to `Feature`. Any falsy
            value (None or an empty mapping) is replaced by fresh
            "inputs"/"targets" features built on `sentencepiece_vocab()`.
          **kwargs: forwarded unchanged to `TaskRegistry.add`.

        Returns:
          The value returned by `TaskRegistry.add`.
        """
        if not output_features:
            output_features = {
                "inputs": dataset_providers.Feature(sentencepiece_vocab()),
                "targets": dataset_providers.Feature(sentencepiece_vocab())
            }

        return TaskRegistry.add(name,
                                source=source,
                                preprocessors=preprocessors,
                                output_features=output_features,
                                **kwargs)

    def get_tempdir(self):
        """Returns the path of a newly created temporary directory.

        If flags were never parsed (e.g. when running under `pytest`), reading
        `flags.FLAGS.test_tmpdir` raises `UnparsedFlagAccessError`. In that
        case flags are parsed from `sys.argv` first, presumably because
        `create_tempdir` depends on parsed flags.
        """
        try:
            flags.FLAGS.test_tmpdir  # Probe only; the value is unused.
        except flags.UnparsedFlagAccessError:
            # Need to initialize flags when running `pytest`.
            flags.FLAGS(sys.argv)
        return self.create_tempdir().full_path

    def setUp(self):
        super().setUp()
        self.maxDiff = None  # pylint:disable=invalid-name

        # `_prepare_sources_and_tasks` creates the temporary test data
        # directory (setting `test_tmpdir`/`test_data_dir`) before using it.
        # Doing it here as well only produced an extra copy that was
        # immediately replaced.
        self._prepare_sources_and_tasks()

    def _prepare_sources_and_tasks(self):
        """Clears registries, patches TFDS, writes fake data, registers tasks."""
        clear_tasks()
        clear_mixtures()
        # Prepare TfdsSource
        # Note we don't use mock.Mock since they fail to pickle.
        fake_tfds_paths = {
            "train": [
                {  # pylint:disable=g-complex-comprehension
                    "filename": "train.tfrecord-%05d-of-00002" % i,
                    "skip": 0,
                    "take": -1
                } for i in range(2)
            ],
            "validation": [{
                "filename": "validation.tfrecord-00000-of-00001",
                "skip": 0,
                "take": -1
            }],
        }

        def _load_shard(shard_instruction, shuffle_files, seed):
            # The "train" split is served as two shards: the first two
            # examples, then the rest. "validation" is a single shard.
            del shuffle_files
            del seed
            fname = shard_instruction["filename"]
            if "train" in fname:
                ds = get_fake_dataset("train")
                if fname.endswith("00000-of-00002"):
                    return ds.take(2)
                else:
                    return ds.skip(2)
            else:
                return get_fake_dataset("validation")

        fake_tfds = FakeLazyTfds(name="fake:0.0.0",
                                 load=get_fake_dataset,
                                 load_shard=_load_shard,
                                 info=FakeTfdsInfo(splits={
                                     "train": None,
                                     "validation": None
                                 }),
                                 files=fake_tfds_paths.get,
                                 size=lambda x: 30 if x == "train" else 10)
        self._tfds_patcher = mock.patch("t5.seqio.utils.LazyTfdsLoader",
                                        new=mock.Mock(return_value=fake_tfds))
        self._tfds_patcher.start()

        # Set up data directory. This is the only place it is created.
        self.test_tmpdir = self.get_tempdir()
        self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
        shutil.copytree(TEST_DATA_DIR, self.test_data_dir)
        # Make the copied root ("" entry) and every subdirectory fully
        # permissive, presumably so fake datasets can be written below even if
        # the source tree is read-only.
        for root, dirs, _ in os.walk(self.test_data_dir):
            for d in dirs + [""]:
                os.chmod(os.path.join(root, d), 0o777)

        # Prepare uncached TFDS-backed task.
        self.tfds_source = dataset_providers.TfdsDataSource(
            tfds_name="fake:0.0.0", splits=("train", "validation"))
        self.add_task("tfds_task", source=self.tfds_source)

        # Prepare TextLineSource.
        _dump_fake_dataset(os.path.join(self.test_data_dir,
                                        "train.tsv"), _FAKE_DATASET["train"],
                           [2, 1], _dump_examples_to_tsv)
        self.text_line_source = dataset_providers.TextLineDataSource(
            split_to_filepattern={
                "train": os.path.join(self.test_data_dir, "train.tsv*"),
            },
            skip_header_lines=1,
        )
        self.add_task("text_line_task",
                      source=self.text_line_source,
                      preprocessors=(split_tsv_preprocessor, ) +
                      self.DEFAULT_PREPROCESSORS)

        # Prepare TFExampleSource.
        _dump_fake_dataset(os.path.join(self.test_data_dir, "train.tfrecord"),
                           _FAKE_DATASET["train"], [2, 1],
                           _dump_examples_to_tfrecord)
        self.tf_example_source = dataset_providers.TFExampleDataSource(
            split_to_filepattern={
                "train": os.path.join(self.test_data_dir, "train.tfrecord*"),
            },
            feature_description={
                "prefix": tf.io.FixedLenFeature([], tf.string),
                "suffix": tf.io.FixedLenFeature([], tf.string),
            })
        self.add_task("tf_example_task", source=self.tf_example_source)

        # Prepare FunctionDataSource
        self.function_source = dataset_providers.FunctionDataSource(
            dataset_fn=get_fake_dataset, splits=["train", "validation"])
        self.add_task("function_task", source=self.function_source)

        # Prepare Task that is tokenized and preprocessed before caching.
        self.add_task("fully_processed_precache",
                      source=self.function_source,
                      preprocessors=(
                          test_text_preprocessor,
                          preprocessors.tokenize,
                          token_preprocessor_no_sequence_length,
                          dataset_providers.CacheDatasetPlaceholder(),
                      ))

        # Prepare Task that is tokenized after caching.
        self.add_task("tokenized_postcache",
                      source=self.function_source,
                      preprocessors=(
                          test_text_preprocessor,
                          dataset_providers.CacheDatasetPlaceholder(),
                          preprocessors.tokenize,
                          token_preprocessor_no_sequence_length,
                      ))

        # Prepare Task with randomization.
        self.random_task = self.add_task(
            "random_task",
            source=self.function_source,
            preprocessors=(
                test_text_preprocessor,
                dataset_providers.CacheDatasetPlaceholder(),
                preprocessors.tokenize,
                random_token_preprocessor,
            ))

        self.uncached_task = self.add_task("uncached_task",
                                           source=self.tfds_source)

        # Prepare cached task.
        dataset_utils.set_global_cache_dirs([self.test_data_dir])
        self.cached_task_dir = os.path.join(self.test_data_dir, "cached_task")
        _dump_fake_dataset(
            os.path.join(self.cached_task_dir,
                         "train.tfrecord"), _FAKE_TOKENIZED_DATASET["train"],
            [2, 1], _dump_examples_to_tfrecord)
        _dump_fake_dataset(
            os.path.join(self.cached_task_dir, "validation.tfrecord"),
            _FAKE_TOKENIZED_DATASET["validation"], [2],
            _dump_examples_to_tfrecord)
        self.cached_task = self.add_task("cached_task",
                                         source=self.tfds_source)

        # Prepare cached plaintext task.
        _dump_fake_dataset(
            os.path.join(self.test_data_dir, "cached_plaintext_task",
                         "train.tfrecord"),
            _FAKE_PLAINTEXT_TOKENIZED_DATASET["train"], [2, 1],
            _dump_examples_to_tfrecord)
        self.cached_plaintext_task = self.add_task(
            "cached_plaintext_task",
            source=self.tfds_source,
            preprocessors=self.DEFAULT_PREPROCESSORS +
            (test_token_preprocessor, ))

    def tearDown(self):
        """Stops the TFDS patch and resets the global TF random seed."""
        super().tearDown()
        self._tfds_patcher.stop()
        tf.random.set_seed(None)

    def verify_task_matches_fake_datasets(  # pylint:disable=dangerous-default-value
            self,
            task_name,
            use_cached,
            token_preprocessed=False,
            splits=("train", "validation"),
            sequence_length=_DEFAULT_SEQUENCE_LENGTH,
            num_shards=None):
        """Assert all splits for both tokenized datasets are correct.

        Args:
          task_name: name of a task registered in `TaskRegistry`.
          use_cached: forwarded to `Task.get_dataset`.
          token_preprocessed: forwarded to `_assert_compare_to_fake_dataset`.
          splits: splits to check.
          sequence_length: forwarded to `Task.get_dataset` and to the
            comparison helper.
          num_shards: if truthy, each split is read shard by shard with
            `ShardInfo(i, num_shards)` and the shards are concatenated in
            order before comparison.
        """
        task = TaskRegistry.get(task_name)
        for split in splits:
            get_dataset = functools.partial(task.get_dataset,
                                            sequence_length,
                                            split,
                                            use_cached=use_cached,
                                            shuffle=False)
            if num_shards:
                ds = get_dataset(
                    shard_info=dataset_providers.ShardInfo(0, num_shards))
                for i in range(1, num_shards):
                    ds = ds.concatenate(
                        get_dataset(shard_info=dataset_providers.ShardInfo(
                            i, num_shards)))
            else:
                ds = get_dataset()
            _assert_compare_to_fake_dataset(
                ds,
                split,
                task.output_features,
                sequence_length,
                token_preprocessed=token_preprocessed,
            )