def test_splits(self):
    """Checks that a task exposes exactly the splits it was registered with.

    Two registrations are covered: one passing `splits` as a list of split
    names and one passing a dict that maps a split name to a TFDS slice
    expression. In both cases only "validation" is expected in `task.splits`.
    """
    test_utils.add_tfds_task("task_with_splits", splits=["validation"])
    task = TaskRegistry.get("task_with_splits")
    self.assertSameElements(["validation"], task.splits)

    test_utils.add_tfds_task("task_with_sliced_splits",
                             splits={"validation": "train[0:1%]"})
    # Fetch the task that was just registered. Looking up "task_with_splits"
    # again would only re-check the first registration and leave the
    # dict-of-slices form untested.
    task = TaskRegistry.get("task_with_sliced_splits")
    self.assertSameElements(["validation"], task.splits)
  def test_invalid_token_preprocessors(self):
    """Verifies that malformed token-preprocessor outputs are rejected.

    Registers one task whose token preprocessor emits acceptable features,
    then tasks whose preprocessors emit a missing feature, a wrongly typed
    feature, a wrongly ranked feature, and a feature containing the EOS id.
    Each of the latter must raise the expected error when its dataset is
    built and iterated.
    """
    def _constant_preprocessor(example):
      # Ignores the incoming dataset and any keyword arguments; always yields
      # a dataset built from `example`.
      return lambda _, **unused: tf.data.Dataset.from_tensors(example)

    def _int64(values):
      return np.array(values, dtype=np.int64)

    def _consume(task_name):
      # Build the dataset and drain it so that errors raised while iterating
      # surface too, not only those raised while constructing it.
      dataset = TaskRegistry.get_dataset(
          task_name, {"inputs": 13, "targets": 13}, "train", use_cached=False)
      list(dataset.as_numpy_iterator())

    # Baseline: well-formed features plus an extra one must not raise.
    test_utils.add_tfds_task(
        "token_prep_ok",
        token_preprocessor=_constant_preprocessor({
            "inputs": _int64([2, 3]),
            "targets": _int64([3]),
            "other": "test",
        }))
    _consume("token_prep_ok")

    # (task name, preprocessor output, expected exception, expected regex)
    failure_cases = [
        (
            "token_prep_missing_feature",
            {"inputs": _int64([2, 3])},
            ValueError,
            "Task dataset is missing expected output feature after preprocessing: "
            "targets",
        ),
        (
            "token_prep_wrong_type",
            {"inputs": "a", "targets": _int64([3])},
            ValueError,
            "Task dataset has incorrect type for feature 'inputs' after "
            "preprocessing: Got string, expected int64",
        ),
        (
            "token_prep_wrong_shape",
            {"inputs": _int64([2, 3]), "targets": _int64(1)},
            ValueError,
            "Task dataset has incorrect rank for feature 'targets' after "
            "preprocessing: Got 0, expected 1",
        ),
        (
            "token_prep_has_eos",
            {"inputs": _int64([1, 3]), "targets": _int64([4])},
            tf.errors.InvalidArgumentError,
            r".*Feature \\'inputs\\' unexpectedly contains EOS=1 token after "
            r"preprocessing\..*",
        ),
    ]
    for task_name, bad_output, expected_error, expected_regex in failure_cases:
      test_utils.add_tfds_task(
          task_name, token_preprocessor=_constant_preprocessor(bad_output))
      with self.assertRaisesRegex(expected_error, expected_regex):
        _consume(task_name)
  def test_invalid_text_preprocessors(self):
    """Verifies that malformed text-preprocessor outputs are rejected.

    Registers one task whose text preprocessor emits acceptable features,
    then tasks whose preprocessors emit a missing feature, a wrongly typed
    feature, and a wrongly ranked feature. Each of the latter must raise a
    ValueError with the expected message when its dataset is requested.
    """
    def _constant_preprocessor(example):
      # Ignores the incoming dataset; always yields a dataset built from
      # `example`.
      return lambda _: tf.data.Dataset.from_tensors(example)

    def _load(task_name):
      # Only requests the dataset; it is not iterated.
      return TaskRegistry.get_dataset(
          task_name, {"inputs": 13, "targets": 13},
          "train", use_cached=False)

    # Baseline: string features plus an extra one must not raise.
    test_utils.add_tfds_task(
        "text_prep_ok",
        text_preprocessor=_constant_preprocessor({
            "inputs": "a",
            "targets": "b",
            "other": [0],
        }))
    _load("text_prep_ok")

    # (task name, preprocessor output, expected ValueError regex)
    failure_cases = [
        (
            "text_prep_missing_feature",
            {"inputs": "a"},
            "Task dataset is missing expected output feature after text "
            "preprocessing: targets",
        ),
        (
            "text_prep_wrong_type",
            {"inputs": 0, "targets": 1},
            "Task dataset has incorrect type for feature 'inputs' after text "
            "preprocessing: Got int32, expected string",
        ),
        (
            "text_prep_wrong_shape",
            {"inputs": "a", "targets": ["a", "b"]},
            "Task dataset has incorrect rank for feature 'targets' after text "
            "preprocessing: Got 1, expected 0",
        ),
    ]
    for task_name, bad_output, expected_regex in failure_cases:
      test_utils.add_tfds_task(
          task_name, text_preprocessor=_constant_preprocessor(bad_output))
      with self.assertRaisesRegex(ValueError, expected_regex):
        _load(task_name)
 def test_repeat_name(self):
   """Expects a ValueError when adding a task named "cached_task".

   NOTE(review): presumably "cached_task" is already registered by fixtures
   outside this view; confirm against the test setup.
   """
   duplicate_pattern = "Attempting to register duplicate provider: cached_task"
   with self.assertRaisesRegex(ValueError, duplicate_pattern):
     test_utils.add_tfds_task("cached_task")
 def test_invalid_name(self):
   """Expects a ValueError when a task name contains a "/" character."""
   invalid_name_pattern = (
       "Task name 'invalid/name' contains invalid characters. "
       "Must match regex: .*")
   with self.assertRaisesRegex(ValueError, invalid_name_pattern):
     test_utils.add_tfds_task("invalid/name")
 def test_no_tfds_version(self):
   """Expects a ValueError when `tfds_name` carries no version number."""
   missing_version_pattern = (
       "TFDS name must contain a version number, got: fake")
   with self.assertRaisesRegex(ValueError, missing_version_pattern):
     test_utils.add_tfds_task("fake_task", tfds_name="fake")
예제 #7
0
 def test_splits(self):
     """Checks that a task registered with only "validation" exposes that split and not "train"."""
     task_name = "task_with_splits"
     test_utils.add_tfds_task(task_name, splits=["validation"])
     registered_task = TaskRegistry.get(task_name)
     self.assertIn("validation", registered_task.splits)
     self.assertNotIn("train", registered_task.splits)