def test_optional_features(self):
  """Preprocessing may drop optional features but never required ones."""

  def _constant_preprocessor(value):
    # Replaces the whole dataset with a single constant example.
    return lambda _: tf.data.Dataset.from_tensors(value)

  vocab = test_utils.sentencepiece_vocab()
  features = {
      "inputs": dataset_providers.Feature(vocabulary=vocab, required=False),
      "targets": dataset_providers.Feature(vocabulary=vocab, required=True),
  }

  # Dropping the optional "inputs" feature is accepted.
  test_utils.add_task(
      "text_missing_optional_feature",
      test_utils.get_fake_dataset,
      output_features=features,
      text_preprocessor=_constant_preprocessor({"targets": "a"}))
  TaskRegistry.get_dataset(
      "text_missing_optional_feature", {"targets": 13},
      "train", use_cached=False)

  # Dropping the required "targets" feature must raise.
  test_utils.add_task(
      "text_missing_required_feature",
      test_utils.get_fake_dataset,
      output_features=features,
      text_preprocessor=_constant_preprocessor({"inputs": "a"}))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset is missing expected output feature after preprocessing: "
      "targets"):
    TaskRegistry.get_dataset(
        "text_missing_required_feature", {"inputs": 13},
        "train", use_cached=False)
def test_get_dataset_mix(self):
  """An equal-rate two-task mix samples each task half the time."""

  def _add_constant_task(name, value):
    # Registers a task whose every example has `value` in inputs/targets.
    # pylint:disable=g-long-lambda
    test_utils.add_task(
        name,
        test_utils.get_fake_dataset,
        token_preprocessor=lambda ds, **unused: ds.map(
            lambda _: {
                "targets": tf.constant([value], tf.int64),
                "inputs": tf.constant([value], tf.int64),
            }))
    # pylint:enable=g-long-lambda

  _add_constant_task("two_task", 2)
  _add_constant_task("three_task", 3)
  MixtureRegistry.add("test_mix4", [("two_task", 1), ("three_task", 1)])

  mix_ds = MixtureRegistry.get("test_mix4").get_dataset(
      {"inputs": 2, "targets": 2}, "train", seed=13).take(1000)
  total = sum(int(ex["inputs"][0]) for ex in mix_ds.as_numpy_iterator())
  # With seed=13 the sampler draws exactly 500 examples from each task:
  # 500 * 2 + 500 * 3 == 2500.
  self.assertEqual(total, 2500)
def test_get_rate_with_callable(self):
  """A mixture rate given as a callable is invoked with the task."""

  def rate_fn(task):
    self.assertEqual(task.name, "task4")
    return 42

  test_utils.add_task("task4", test_utils.get_fake_dataset)
  MixtureRegistry.add("test_mix5", [("task4", rate_fn)])
  mixture = MixtureRegistry.get("test_mix5")
  self.assertEqual(mixture.get_rate(TaskRegistry.get("task4")), 42)
def test_tasks(self):
  """A mixture exposes its member tasks and their configured rates."""
  for name in ("task1", "task2"):
    test_utils.add_task(name, test_utils.get_fake_dataset)
  MixtureRegistry.add("test_mix1", [("task1", 1), ("task2", 1)])

  mix = MixtureRegistry.get("test_mix1")
  self.assertEqual(len(mix.tasks), 2)
  for member in mix.tasks:
    test_utils.verify_task_matches_fake_datasets(member, use_cached=False)
    self.assertEqual(mix.get_rate(member), 1)
def test_no_eos_default_vocab(self):
  """Tests `add_eos=False` with features that use the default vocabulary.

  NOTE(review): this method was previously also named `test_no_eos`, the
  same as the variant that passes an explicit vocabulary. Python keeps only
  the last definition in a class body, so this test was silently shadowed
  and never ran. Renamed (method and registered task name) so both run.
  """
  features = {
      "inputs": utils.Feature(add_eos=True),
      "targets": utils.Feature(add_eos=False),
  }
  test_utils.add_task(
      "task_no_eos_default_vocab",
      test_utils.get_fake_dataset,
      output_features=features)
  fn_task = TaskRegistry.get("task_no_eos_default_vocab")
  test_utils.verify_task_matches_fake_datasets(fn_task, use_cached=False)
def test_no_eos(self):
  """EOS is appended only to features configured with `add_eos=True`."""
  vocab = test_utils.sentencepiece_vocab()
  features = {
      "inputs": utils.Feature(add_eos=True, vocabulary=vocab),
      "targets": utils.Feature(add_eos=False, vocabulary=vocab),
  }
  test_utils.add_task(
      "task_no_eos", test_utils.get_fake_dataset, output_features=features)
  test_utils.verify_task_matches_fake_datasets(
      TaskRegistry.get("task_no_eos"), use_cached=False)
def test_dtype(self):
  """Feature dtype defaults to int32 and may be overridden per feature."""
  vocab = test_utils.sentencepiece_vocab()
  features = {
      # No explicit dtype: Feature defaults to int32.
      "inputs": dataset_providers.Feature(vocabulary=vocab),
      "targets": dataset_providers.Feature(dtype=tf.int64, vocabulary=vocab),
  }
  test_utils.add_task(
      "task_dtypes", test_utils.get_fake_dataset, output_features=features)
  test_utils.verify_task_matches_fake_datasets(
      TaskRegistry.get("task_dtypes"), use_cached=False)
def test_mixture_of_mixtures(self):
  """A mixture containing a sub-mixture flattens to the leaf tasks.

  The sub-mixture's unit rate is split across its two member tasks
  (0.5 each), while the directly-included task keeps its full rate.
  """
  for name in ("task_a", "task_b", "task_c"):
    test_utils.add_task(name, test_utils.get_fake_dataset)
  MixtureRegistry.add("another_mix", [("task_a", 1), ("task_b", 1)])
  MixtureRegistry.add("supermix", [("another_mix", 1), ("task_c", 1)])

  supermix = MixtureRegistry.get("supermix")
  self.assertEqual([t.name for t in supermix.tasks],
                   ["task_a", "task_b", "task_c"])
  self.assertEqual([supermix.get_rate(t) for t in supermix.tasks],
                   [0.5, 0.5, 1])
def test_dataset_fn(self):
  """A task registered from a dataset function yields the fake datasets."""
  test_utils.add_task("fn_task", test_utils.get_fake_dataset)
  test_utils.verify_task_matches_fake_datasets(
      TaskRegistry.get("fn_task"), use_cached=False)
def test_dataset_fn_signature(self):
  """Registration validates the dataset function's positional signature.

  Accepted: exactly the positional args ('split', 'shuffle_files'),
  optionally with defaults, a 'seed' arg, or extra keyword-defaulted args.
  Rejected: missing either required arg, or any extra positional arg.
  """
  # --- Accepted signatures. ---
  def good_fn(split, shuffle_files):
    del split, shuffle_files
  test_utils.add_task("good_fn", good_fn)

  def default_good_fn(split, shuffle_files=False):
    del split, shuffle_files
  test_utils.add_task("default_good_fn", default_good_fn)

  def seed_fn(split, shuffle_files=True, seed=0):
    del split, shuffle_files, seed
  test_utils.add_task("seed_fn", seed_fn)

  def extra_kwarg_good_fn(split, shuffle_files, unused_kwarg=True):
    del split, shuffle_files
  test_utils.add_task("extra_kwarg_good_fn", extra_kwarg_good_fn)

  # --- Rejected signatures. ---
  def missing_shuff(split):
    del split
  with self.assertRaisesRegex(
      ValueError,
      r"'missing_shuff' must have positional args \('split', "
      r"'shuffle_files'\), got: \('split',\)"):
    test_utils.add_task("fake_task", missing_shuff)

  def missing_split(shuffle_files):
    del shuffle_files
  with self.assertRaisesRegex(
      ValueError,
      r"'missing_split' must have positional args \('split', "
      r"'shuffle_files'\), got: \('shuffle_files',\)"):
    test_utils.add_task("fake_task", missing_split)

  def extra_pos_arg(split, shuffle_files, unused_arg):
    del split, shuffle_files
  with self.assertRaisesRegex(
      ValueError,
      r"'extra_pos_arg' may only have positional args \('split', "
      r"'shuffle_files'\), got: \('split', 'shuffle_files', 'unused_arg'\)"):
    test_utils.add_task("fake_task", extra_pos_arg)
def test_no_tfds_version(self):
  """A TFDS task name without a version suffix is rejected."""
  # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
  # use assertRaisesRegex, consistent with the other tests in this file.
  with self.assertRaisesRegex(
      ValueError, "TFDS name must contain a version number, got: fake"):
    test_utils.add_task("fake_task", tfds_name="fake")
def test_repeat_name(self):
  """Re-registering an existing task name is rejected."""
  # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
  # use assertRaisesRegex, consistent with the other tests in this file.
  with self.assertRaisesRegex(
      ValueError, "Attempting to register duplicate provider: cached_task"):
    test_utils.add_task("cached_task")
def test_invalid_name(self):
  """Task names containing disallowed characters are rejected."""
  # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
  # use assertRaisesRegex, consistent with the other tests in this file.
  with self.assertRaisesRegex(
      ValueError,
      "Task name 'invalid/name' contains invalid characters. "
      "Must match regex: .*"):
    test_utils.add_task("invalid/name")
def test_splits(self):
  """Only the splits passed at registration are exposed by the task."""
  test_utils.add_task("task_with_splits", splits=["validation"])
  registered = TaskRegistry.get("task_with_splits")
  self.assertIn("validation", registered.splits)
  self.assertNotIn("train", registered.splits)
def test_invalid_token_preprocessors(self):
  """Token-preprocessor outputs are validated for features, dtype, rank, EOS.

  Uses assertRaisesRegex throughout: assertRaisesRegexp is a deprecated
  alias (removed in Python 3.12) and the rest of the file already uses the
  modern name.
  """

  def _dummy_preprocessor(output):
    # Replaces the whole dataset with a single constant example.
    return lambda _, **unused: tf.data.Dataset.from_tensors(output)

  i64_arr = lambda x: np.array(x, dtype=np.int64)

  def _materialize(task):
    # Fully iterates the dataset so eager validation errors fire.
    list(
        tfds.as_numpy(
            TaskRegistry.get_dataset(
                task, {"inputs": 13, "targets": 13},
                "train", use_cached=False)))

  # Valid output: both required features present; extras are allowed.
  test_utils.add_task(
      "token_prep_ok",
      token_preprocessor=_dummy_preprocessor({
          "inputs": i64_arr([2, 3]),
          "targets": i64_arr([3]),
          "other": "test"
      }))
  _materialize("token_prep_ok")

  # Missing a required output feature.
  test_utils.add_task(
      "token_prep_missing_feature",
      token_preprocessor=_dummy_preprocessor({"inputs": i64_arr([2, 3])}))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset is missing expected output feature after token "
      "preprocessing: targets"):
    _materialize("token_prep_missing_feature")

  # Wrong dtype for a feature.
  test_utils.add_task(
      "token_prep_wrong_type",
      token_preprocessor=_dummy_preprocessor({
          "inputs": "a",
          "targets": i64_arr([3])
      }))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset has incorrect type for feature 'inputs' after token "
      "preprocessing: Got string, expected int64"):
    _materialize("token_prep_wrong_type")

  # Wrong rank for a feature (scalar instead of vector).
  test_utils.add_task(
      "token_prep_wrong_shape",
      token_preprocessor=_dummy_preprocessor({
          "inputs": i64_arr([2, 3]),
          "targets": i64_arr(1)
      }))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset has incorrect rank for feature 'targets' after token "
      "preprocessing: Got 0, expected 1"):
    _materialize("token_prep_wrong_shape")

  # EOS (id 1) must not appear in the preprocessor's output.
  test_utils.add_task(
      "token_prep_has_eos",
      token_preprocessor=_dummy_preprocessor({
          "inputs": i64_arr([1, 3]),
          "targets": i64_arr([4])
      }))
  with self.assertRaisesRegex(
      tf.errors.InvalidArgumentError,
      r".*Feature \\'inputs\\' unexpectedly contains EOS=1 token after token "
      r"preprocessing\..*"):
    _materialize("token_prep_has_eos")
def test_invalid_text_preprocessors(self):
  """Text-preprocessor outputs are validated for features, type, and rank.

  Uses assertRaisesRegex throughout: assertRaisesRegexp is a deprecated
  alias (removed in Python 3.12) and the rest of the file already uses the
  modern name.
  """

  def _dummy_preprocessor(output):
    # Replaces the whole dataset with a single constant example.
    return lambda _: tf.data.Dataset.from_tensors(output)

  # Valid output: both required features present; extras are allowed.
  test_utils.add_task(
      "text_prep_ok",
      text_preprocessor=_dummy_preprocessor({
          "inputs": "a",
          "targets": "b",
          "other": [0]
      }))
  TaskRegistry.get_dataset(
      "text_prep_ok", {"inputs": 13, "targets": 13},
      "train", use_cached=False)

  # Missing a required output feature.
  test_utils.add_task(
      "text_prep_missing_feature",
      text_preprocessor=_dummy_preprocessor({"inputs": "a"}))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset is missing expected output feature after text "
      "preprocessing: targets"):
    TaskRegistry.get_dataset(
        "text_prep_missing_feature", {"inputs": 13, "targets": 13},
        "train", use_cached=False)

  # Wrong type: text features must be strings at this stage.
  test_utils.add_task(
      "text_prep_wrong_type",
      text_preprocessor=_dummy_preprocessor({
          "inputs": 0,
          "targets": 1
      }))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset has incorrect type for feature 'inputs' after text "
      "preprocessing: Got int32, expected string"):
    TaskRegistry.get_dataset(
        "text_prep_wrong_type", {"inputs": 13, "targets": 13},
        "train", use_cached=False)

  # Wrong rank: text features must be scalars at this stage.
  test_utils.add_task(
      "text_prep_wrong_shape",
      text_preprocessor=_dummy_preprocessor({
          "inputs": "a",
          "targets": ["a", "b"]
      }))
  with self.assertRaisesRegex(
      ValueError,
      "Task dataset has incorrect rank for feature 'targets' after text "
      "preprocessing: Got 1, expected 0"):
    TaskRegistry.get_dataset(
        "text_prep_wrong_shape", {"inputs": 13, "targets": 13},
        "train", use_cached=False)