def test_optional_features(self):
    """Optional output features may be dropped; required ones may not."""

    def _constant_preprocessor(value):
      # Ignores the input dataset and emits a single fixed example.
      return lambda _: tf.data.Dataset.from_tensors(value)

    vocab = test_utils.sentencepiece_vocab()
    features = {
        "inputs": dataset_providers.Feature(vocabulary=vocab, required=False),
        "targets": dataset_providers.Feature(vocabulary=vocab, required=True),
    }

    # Omitting the optional "inputs" feature succeeds.
    test_utils.add_task(
        "text_missing_optional_feature",
        test_utils.get_fake_dataset,
        output_features=features,
        text_preprocessor=_constant_preprocessor({"targets": "a"}))
    TaskRegistry.get_dataset(
        "text_missing_optional_feature", {"targets": 13},
        "train", use_cached=False)

    # Omitting the required "targets" feature raises.
    test_utils.add_task(
        "text_missing_required_feature",
        test_utils.get_fake_dataset,
        output_features=features,
        text_preprocessor=_constant_preprocessor({"inputs": "a"}))
    with self.assertRaisesRegex(
        ValueError,
        "Task dataset is missing expected output feature after preprocessing: "
        "targets"):
      TaskRegistry.get_dataset(
          "text_missing_required_feature", {"inputs": 13},
          "train", use_cached=False)
  def test_get_dataset_mix(self):
    """An equal-rate two-task mixture draws from both tasks."""

    def _constant_token_preprocessor(value):
      # Maps every example to fixed length-1 int64 "inputs"/"targets".
      def _apply(ds, **unused):
        return ds.map(lambda _: {
            "targets": tf.constant([value], tf.int64),
            "inputs": tf.constant([value], tf.int64),
        })
      return _apply

    test_utils.add_task(
        "two_task",
        test_utils.get_fake_dataset,
        token_preprocessor=_constant_token_preprocessor(2))
    test_utils.add_task(
        "three_task",
        test_utils.get_fake_dataset,
        token_preprocessor=_constant_token_preprocessor(3))
    MixtureRegistry.add("test_mix4", [("two_task", 1), ("three_task", 1)])

    sequence_length = {"inputs": 2, "targets": 2}
    mix_ds = MixtureRegistry.get("test_mix4").get_dataset(
        sequence_length, "train", seed=13).take(1000)

    # With seed=13 the 1000 draws sum to 2500, i.e. exactly 500 examples
    # sampled from each task (500 * 2 + 500 * 3).
    total = sum(int(item["inputs"][0]) for item in mix_ds.as_numpy_iterator())
    self.assertEqual(total, 2500)
 def test_get_rate_with_callable(self):
   """A mixture rate given as a callable is invoked with the Task object."""
   def rate_fn(t):
     # The callable receives the member Task itself.
     self.assertEqual(t.name, "task4")
     return 42
   test_utils.add_task("task4", test_utils.get_fake_dataset)
   task = TaskRegistry.get("task4")
   MixtureRegistry.add("test_mix5", [("task4", rate_fn)])
   mix = MixtureRegistry.get("test_mix5")
   self.assertEqual(mix.get_rate(task), 42)
  def test_tasks(self):
    """A mixture exposes its member tasks and their per-task rates."""
    for task_name in ("task1", "task2"):
      test_utils.add_task(task_name, test_utils.get_fake_dataset)
    MixtureRegistry.add("test_mix1", [("task1", 1), ("task2", 1)])
    mix = MixtureRegistry.get("test_mix1")
    self.assertEqual(len(mix.tasks), 2)

    # Every member still yields the fake dataset and keeps its rate of 1.
    for task in mix.tasks:
      test_utils.verify_task_matches_fake_datasets(task, use_cached=False)
      self.assertEqual(mix.get_rate(task), 1)
# Beispiel #5 ("Beispiel" is German for "example"; separator from the scraped source)
# 0 (stray artifact from the scrape)
 def test_no_eos(self):
     """EOS appending can be disabled per output feature."""
     features = {
         "inputs": utils.Feature(add_eos=True),
         "targets": utils.Feature(add_eos=False),
     }
     test_utils.add_task(
         "task_no_eos", test_utils.get_fake_dataset, output_features=features)
     test_utils.verify_task_matches_fake_datasets(
         TaskRegistry.get("task_no_eos"), use_cached=False)
# Beispiel #6 (example separator from the scraped source)
# 0 (stray artifact from the scrape)
 def test_no_eos(self):
     """Per-feature add_eos flags also work with an explicit vocabulary."""
     vocab = test_utils.sentencepiece_vocab()
     features = {
         "inputs": utils.Feature(add_eos=True, vocabulary=vocab),
         "targets": utils.Feature(add_eos=False, vocabulary=vocab),
     }
     test_utils.add_task(
         "task_no_eos", test_utils.get_fake_dataset, output_features=features)
     test_utils.verify_task_matches_fake_datasets(
         TaskRegistry.get("task_no_eos"), use_cached=False)
# Beispiel #7 (example separator from the scraped source)
# 0 (stray artifact from the scrape)
 def test_dtype(self):
     """Feature dtype defaults to int32 and may be overridden per feature."""
     vocab = test_utils.sentencepiece_vocab()
     features = {
         # No explicit dtype: Feature defaults to int32.
         "inputs": dataset_providers.Feature(vocabulary=vocab),
         "targets": dataset_providers.Feature(
             dtype=tf.int64, vocabulary=vocab),
     }
     test_utils.add_task(
         "task_dtypes", test_utils.get_fake_dataset, output_features=features)
     test_utils.verify_task_matches_fake_datasets(
         TaskRegistry.get("task_dtypes"), use_cached=False)
 def test_mixture_of_mixtures(self):
   """A mixture may nest another mixture; the nested rate is split evenly."""
   for task_name in ("task_a", "task_b", "task_c"):
     test_utils.add_task(task_name, test_utils.get_fake_dataset)
   MixtureRegistry.add("another_mix", [("task_a", 1), ("task_b", 1)])
   MixtureRegistry.add("supermix", [("another_mix", 1), ("task_c", 1)])
   supermix = MixtureRegistry.get("supermix")
   self.assertEqual([t.name for t in supermix.tasks],
                    ["task_a", "task_b", "task_c"])
   # another_mix's rate of 1 is shared equally between its two tasks.
   self.assertEqual([supermix.get_rate(t) for t in supermix.tasks],
                    [0.5, 0.5, 1])
 def test_dataset_fn(self):
   """A task backed by a plain dataset function yields the fake datasets."""
   test_utils.add_task("fn_task", test_utils.get_fake_dataset)
   test_utils.verify_task_matches_fake_datasets(
       TaskRegistry.get("fn_task"), use_cached=False)
  def test_dataset_fn_signature(self):
    """Validates signature checking of user-supplied dataset functions.

    A dataset function must take exactly the positional args
    ('split', 'shuffle_files'); extra keyword-only/defaulted args are fine.
    NOTE: the function names below appear verbatim in the expected error
    regexes, so they must not be renamed.
    """
    # Good signatures.
    def good_fn(split, shuffle_files):
      del split
      del shuffle_files
    test_utils.add_task("good_fn", good_fn)

    # A default value on shuffle_files is accepted.
    def default_good_fn(split, shuffle_files=False):
      del split
      del shuffle_files
    test_utils.add_task("default_good_fn", default_good_fn)

    # An extra defaulted 'seed' arg is accepted.
    def seed_fn(split, shuffle_files=True, seed=0):
      del split
      del shuffle_files
      del seed
    test_utils.add_task("seed_fn", seed_fn)

    # Any extra arg is accepted as long as it has a default.
    def extra_kwarg_good_fn(split, shuffle_files, unused_kwarg=True):
      del split
      del shuffle_files
    test_utils.add_task("extra_kwarg_good_fn", extra_kwarg_good_fn)

    # Bad signatures.
    # Missing the 'shuffle_files' positional arg.
    with self.assertRaisesRegex(
        ValueError,
        r"'missing_shuff' must have positional args \('split', "
        r"'shuffle_files'\), got: \('split',\)"):
      def missing_shuff(split):
        del split
      test_utils.add_task("fake_task", missing_shuff)

    # Missing the 'split' positional arg.
    with self.assertRaisesRegex(
        ValueError,
        r"'missing_split' must have positional args \('split', "
        r"'shuffle_files'\), got: \('shuffle_files',\)"):
      def missing_split(shuffle_files):
        del shuffle_files
      test_utils.add_task("fake_task", missing_split)

    # An extra positional arg without a default is rejected.
    with self.assertRaisesRegex(
        ValueError,
        r"'extra_pos_arg' may only have positional args \('split', "
        r"'shuffle_files'\), got: \('split', 'shuffle_files', 'unused_arg'\)"):
      def extra_pos_arg(split, shuffle_files, unused_arg):
        del split
        del shuffle_files
      test_utils.add_task("fake_task", extra_pos_arg)
 def test_no_tfds_version(self):
     """A TFDS task name must carry an explicit version suffix.

     Raises (expected): ValueError from add_task for the unversioned name.
     """
     # assertRaisesRegex: the *Regexp alias is deprecated and was removed
     # in Python 3.12; matches the style used elsewhere in this file.
     with self.assertRaisesRegex(
             ValueError,
             "TFDS name must contain a version number, got: fake"):
         test_utils.add_task("fake_task", tfds_name="fake")
 def test_repeat_name(self):
     """Registering a provider under an already-used name raises.

     Assumes a task named 'cached_task' was registered earlier in the
     suite's setup — TODO confirm against the test fixture.
     """
     # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
     # (removed in Python 3.12).
     with self.assertRaisesRegex(
             ValueError,
             "Attempting to register duplicate provider: cached_task"):
         test_utils.add_task("cached_task")
 def test_invalid_name(self):
     """Task names with characters outside the allowed set are rejected."""
     # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
     # (removed in Python 3.12).
     with self.assertRaisesRegex(
             ValueError,
             "Task name 'invalid/name' contains invalid characters. "
             "Must match regex: .*"):
         test_utils.add_task("invalid/name")
 def test_splits(self):
     """Only the splits declared at registration are exposed on the task."""
     test_utils.add_task("task_with_splits", splits=["validation"])
     splits = TaskRegistry.get("task_with_splits").splits
     self.assertIn("validation", splits)
     self.assertNotIn("train", splits)
    def test_invalid_token_preprocessors(self):
        """Token preprocessors must emit correctly typed and shaped features.

        Covers: valid output (extra keys allowed), missing feature, wrong
        dtype, wrong rank, and an unexpected EOS token in the output.
        """

        def _dummy_preprocessor(output):
            # A token preprocessor receives (ds, **kwargs); this one ignores
            # both and emits a single fixed example.
            return lambda _, **unused: tf.data.Dataset.from_tensors(output)

        i64_arr = lambda x: np.array(x, dtype=np.int64)

        def _materialize(task):
            # Force full pipeline execution so validation errors surface.
            list(
                tfds.as_numpy(
                    TaskRegistry.get_dataset(
                        task, {"inputs": 13, "targets": 13},
                        "train", use_cached=False)))

        # Well-formed output; extra keys like "other" are permitted.
        test_utils.add_task(
            "token_prep_ok",
            token_preprocessor=_dummy_preprocessor({
                "inputs": i64_arr([2, 3]),
                "targets": i64_arr([3]),
                "other": "test",
            }))
        _materialize("token_prep_ok")

        # Missing a required output feature.
        # NOTE: assertRaisesRegex replaces the deprecated assertRaisesRegexp
        # alias (removed in Python 3.12) throughout this test.
        test_utils.add_task(
            "token_prep_missing_feature",
            token_preprocessor=_dummy_preprocessor(
                {"inputs": i64_arr([2, 3])}))
        with self.assertRaisesRegex(
                ValueError,
                "Task dataset is missing expected output feature after token "
                "preprocessing: targets"):
            _materialize("token_prep_missing_feature")

        # Wrong dtype (string where int64 is expected).
        test_utils.add_task(
            "token_prep_wrong_type",
            token_preprocessor=_dummy_preprocessor({
                "inputs": "a",
                "targets": i64_arr([3]),
            }))
        with self.assertRaisesRegex(
                ValueError,
                "Task dataset has incorrect type for feature 'inputs' after token "
                "preprocessing: Got string, expected int64"):
            _materialize("token_prep_wrong_type")

        # Wrong rank (scalar where a rank-1 tensor is expected).
        test_utils.add_task(
            "token_prep_wrong_shape",
            token_preprocessor=_dummy_preprocessor({
                "inputs": i64_arr([2, 3]),
                "targets": i64_arr(1),
            }))
        with self.assertRaisesRegex(
                ValueError,
                "Task dataset has incorrect rank for feature 'targets' after token "
                "preprocessing: Got 0, expected 1"):
            _materialize("token_prep_wrong_shape")

        # EOS (token id 1) must not appear in preprocessed tokens.
        test_utils.add_task(
            "token_prep_has_eos",
            token_preprocessor=_dummy_preprocessor({
                "inputs": i64_arr([1, 3]),
                "targets": i64_arr([4]),
            }))
        with self.assertRaisesRegex(
                tf.errors.InvalidArgumentError,
                r".*Feature \\'inputs\\' unexpectedly contains EOS=1 token after token "
                r"preprocessing\..*"):
            _materialize("token_prep_has_eos")
    def test_invalid_text_preprocessors(self):
        """Text preprocessors must emit scalar string output features.

        Covers: valid output (extra keys allowed), missing feature, wrong
        dtype, and wrong rank.
        """

        def _dummy_preprocessor(output):
            # Ignores the input dataset and emits a single fixed example.
            return lambda _: tf.data.Dataset.from_tensors(output)

        def _get_dataset(task):
            # Shared helper: builds the task's dataset with a fixed
            # sequence length; validation errors surface here.
            return TaskRegistry.get_dataset(
                task, {"inputs": 13, "targets": 13},
                "train", use_cached=False)

        # Well-formed output; extra keys like "other" are permitted.
        test_utils.add_task(
            "text_prep_ok",
            text_preprocessor=_dummy_preprocessor({
                "inputs": "a",
                "targets": "b",
                "other": [0],
            }))
        _get_dataset("text_prep_ok")

        # Missing a required output feature.
        # NOTE: assertRaisesRegex replaces the deprecated assertRaisesRegexp
        # alias (removed in Python 3.12) throughout this test.
        test_utils.add_task(
            "text_prep_missing_feature",
            text_preprocessor=_dummy_preprocessor({"inputs": "a"}))
        with self.assertRaisesRegex(
                ValueError,
                "Task dataset is missing expected output feature after text "
                "preprocessing: targets"):
            _get_dataset("text_prep_missing_feature")

        # Wrong dtype (ints where strings are expected).
        test_utils.add_task(
            "text_prep_wrong_type",
            text_preprocessor=_dummy_preprocessor({
                "inputs": 0,
                "targets": 1,
            }))
        with self.assertRaisesRegex(
                ValueError,
                "Task dataset has incorrect type for feature 'inputs' after text "
                "preprocessing: Got int32, expected string"):
            _get_dataset("text_prep_wrong_type")

        # Wrong rank (list of strings where a scalar string is expected).
        test_utils.add_task(
            "text_prep_wrong_shape",
            text_preprocessor=_dummy_preprocessor({
                "inputs": "a",
                "targets": ["a", "b"],
            }))
        with self.assertRaisesRegex(
                ValueError,
                "Task dataset has incorrect rank for feature 'targets' after text "
                "preprocessing: Got 1, expected 0"):
            _get_dataset("text_prep_wrong_shape")