def test_splits(self):
  """Tasks expose exactly the splits passed at registration.

  Covers both the list form (`splits=[...]`) and the dict form mapping a
  split name onto a TFDS slice spec.
  """
  test_utils.add_tfds_task("task_with_splits", splits=["validation"])
  task = TaskRegistry.get("task_with_splits")
  self.assertSameElements(["validation"], task.splits)

  test_utils.add_tfds_task(
      "task_with_sliced_splits", splits={"validation": "train[0:1%]"})
  # Fix: previously re-fetched "task_with_splits" here, so the dict-valued
  # (sliced) splits registration was never actually checked.
  task = TaskRegistry.get("task_with_sliced_splits")
  self.assertSameElements(["validation"], task.splits)
def test_invalid_token_preprocessors(self):
  """Token preprocessors must yield rank-1 int64 features without EOS."""

  def _make_preprocessor(output):
    # Builds a token preprocessor that discards its input dataset and
    # yields `output` as a single example.
    def _prep(unused_dataset, **unused_kwargs):
      return tf.data.Dataset.from_tensors(output)
    return _prep

  def _i64(values):
    return np.array(values, dtype=np.int64)

  def _consume(task_name):
    # Fully iterate the dataset so post-preprocessing validation runs.
    ds = TaskRegistry.get_dataset(
        task_name, {"inputs": 13, "targets": 13}, "train", use_cached=False)
    list(ds.as_numpy_iterator())

  # Extra features beyond the expected ones are allowed.
  test_utils.add_tfds_task(
      "token_prep_ok",
      token_preprocessor=_make_preprocessor(
          {"inputs": _i64([2, 3]), "targets": _i64([3]), "other": "test"}))
  _consume("token_prep_ok")

  # A required output feature may not be dropped.
  test_utils.add_tfds_task(
      "token_prep_missing_feature",
      token_preprocessor=_make_preprocessor({"inputs": _i64([2, 3])}))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset is missing expected output feature after preprocessing: "
      "targets"):
    _consume("token_prep_missing_feature")

  # Token features must be int64, not string.
  test_utils.add_tfds_task(
      "token_prep_wrong_type",
      token_preprocessor=_make_preprocessor(
          {"inputs": "a", "targets": _i64([3])}))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset has incorrect type for feature 'inputs' after "
      "preprocessing: Got string, expected int64"):
    _consume("token_prep_wrong_type")

  # A scalar (rank-0) feature is rejected where rank 1 is expected.
  test_utils.add_tfds_task(
      "token_prep_wrong_shape",
      token_preprocessor=_make_preprocessor(
          {"inputs": _i64([2, 3]), "targets": _i64(1)}))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset has incorrect rank for feature 'targets' after "
      "preprocessing: Got 0, expected 1"):
    _consume("token_prep_wrong_shape")

  # EOS (id 1) must not already appear in preprocessed tokens. NOTE: the
  # doubled backslashes are deliberate — TF's error message is expected to
  # contain escaped quotes around the feature name.
  test_utils.add_tfds_task(
      "token_prep_has_eos",
      token_preprocessor=_make_preprocessor(
          {"inputs": _i64([1, 3]), "targets": _i64([4])}))
  with self.assertRaisesRegex(
      tf.errors.InvalidArgumentError,
      r".*Feature \\'inputs\\' unexpectedly contains EOS=1 token after "
      r"preprocessing\..*"):
    _consume("token_prep_has_eos")
def test_invalid_text_preprocessors(self):
  """Text preprocessors must yield scalar string features."""

  def _fixed_output(output):
    # Returns a text preprocessor that ignores its input dataset and
    # emits `output` as a single example.
    return lambda _: tf.data.Dataset.from_tensors(output)

  def _load(task_name):
    return TaskRegistry.get_dataset(
        task_name, {"inputs": 13, "targets": 13}, "train", use_cached=False)

  # Extra features beyond the expected ones are permitted.
  test_utils.add_tfds_task(
      "text_prep_ok",
      text_preprocessor=_fixed_output(
          {"inputs": "a", "targets": "b", "other": [0]}))
  _load("text_prep_ok")

  # A required output feature may not be dropped.
  test_utils.add_tfds_task(
      "text_prep_missing_feature",
      text_preprocessor=_fixed_output({"inputs": "a"}))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset is missing expected output feature after text "
      "preprocessing: targets"):
    _load("text_prep_missing_feature")

  # Text features must be strings, not ints.
  test_utils.add_tfds_task(
      "text_prep_wrong_type",
      text_preprocessor=_fixed_output({"inputs": 0, "targets": 1}))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset has incorrect type for feature 'inputs' after text "
      "preprocessing: Got int32, expected string"):
    _load("text_prep_wrong_type")

  # A rank-1 feature is rejected where a scalar is expected.
  test_utils.add_tfds_task(
      "text_prep_wrong_shape",
      text_preprocessor=_fixed_output(
          {"inputs": "a", "targets": ["a", "b"]}))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset has incorrect rank for feature 'targets' after text "
      "preprocessing: Got 1, expected 0"):
    _load("text_prep_wrong_shape")
def test_repeat_name(self):
  """Registering a task under an already-used name raises ValueError."""
  expected_error = "Attempting to register duplicate provider: cached_task"
  with self.assertRaisesRegex(ValueError, expected_error):
    test_utils.add_tfds_task("cached_task")
def test_invalid_name(self):
  """Task names containing disallowed characters are rejected."""
  bad_name = "invalid/name"
  expected_error = (
      "Task name 'invalid/name' contains invalid characters. "
      "Must match regex: .*")
  with self.assertRaisesRegex(ValueError, expected_error):
    test_utils.add_tfds_task(bad_name)
def test_no_tfds_version(self):
  """Registration fails when the TFDS name omits a version number."""
  unversioned_name = "fake"
  expected_error = "TFDS name must contain a version number, got: fake"
  with self.assertRaisesRegex(ValueError, expected_error):
    test_utils.add_tfds_task("fake_task", tfds_name=unversioned_name)
def test_splits_membership(self):
  """`task.splits` contains registered splits and nothing else.

  Renamed from `test_splits`: a method of that exact name is defined
  earlier in this class, and Python silently keeps only the last
  definition, so one of the two tests never ran. (This assumes the task
  registry is reset between tests, e.g. in setUp — TODO confirm.)
  """
  test_utils.add_tfds_task("task_with_splits", splits=["validation"])
  task = TaskRegistry.get("task_with_splits")
  self.assertIn("validation", task.splits)
  self.assertNotIn("train", task.splits)