def test_get_dataset_cached(self):
  """Cached task matches the fake data; re-verify after adding a token prep."""
  task = self.cached_task
  test_utils.verify_task_matches_fake_datasets(task, use_cached=True)
  # Attach a token preprocessor and verify again via the uncached path.
  task._token_preprocessor = test_utils.test_token_preprocessor
  test_utils.verify_task_matches_fake_datasets(
      task, use_cached=False, token_preprocessed=True)
def test_get_dataset_cached(self):
  """Cached task matches the fake data, with and without token preprocessing."""
  test_utils.verify_task_matches_fake_datasets(
      self.cached_task, use_cached=True, token_preprocessed=True)
  # The registered variant without a token preprocessor must also match.
  no_prep_task = TaskRegistry.get("cached_task_no_token_prep")
  test_utils.verify_task_matches_fake_datasets(
      no_prep_task, use_cached=True, token_preprocessed=False)
def test_tasks(self):
  """A mixture of two equally-weighted tasks exposes both with rate 1."""
  for task_name in ("task1", "task2"):
    test_utils.add_task(task_name, test_utils.get_fake_dataset)
  MixtureRegistry.add("test_mix1", [("task1", 1), ("task2", 1)])
  mixture = MixtureRegistry.get("test_mix1")
  self.assertEqual(len(mixture.tasks), 2)
  # Each member task should reproduce the fake data and keep its unit rate.
  for member in mixture.tasks:
    test_utils.verify_task_matches_fake_datasets(member, use_cached=False)
    self.assertEqual(mixture.get_rate(member), 1)
def test_no_eos(self):
  """A task whose targets omit EOS still matches the fake datasets."""
  # Only inputs get an EOS token appended; targets are left untouched.
  output_features = {
      "inputs": utils.Feature(add_eos=True),
      "targets": utils.Feature(add_eos=False),
  }
  test_utils.add_task(
      "task_no_eos", test_utils.get_fake_dataset, output_features=output_features)
  registered = TaskRegistry.get("task_no_eos")
  test_utils.verify_task_matches_fake_datasets(registered, use_cached=False)
def test_no_eos(self):
  """A task whose targets omit EOS still matches the fake datasets."""
  vocab = test_utils.sentencepiece_vocab()
  # Only inputs get an EOS token appended; targets are left untouched.
  output_features = {
      "inputs": utils.Feature(add_eos=True, vocabulary=vocab),
      "targets": utils.Feature(add_eos=False, vocabulary=vocab),
  }
  test_utils.add_task(
      "task_no_eos", test_utils.get_fake_dataset, output_features=output_features)
  registered = TaskRegistry.get("task_no_eos")
  test_utils.verify_task_matches_fake_datasets(registered, use_cached=False)
def test_sharding(self):
  """Sharded reads (0-2 shards) match the fake data, cached or uncached."""
  for num_shards in range(3):
    # Uncached first, then cached — both must agree with the fake datasets.
    for cached in (False, True):
      test_utils.verify_task_matches_fake_datasets(
          self.cached_task,
          use_cached=cached,
          num_shards=num_shards,
          token_preprocessed=True)
def test_get_dataset_onthefly(self):
  """Uncached task matches the fake data; re-verify after adding a token prep."""
  task = self.uncached_task
  test_utils.verify_task_matches_fake_datasets(task, use_cached=False)
  # Attach a token preprocessor and verify the on-the-fly path again.
  task._token_preprocessor = test_utils.test_token_preprocessor
  test_utils.verify_task_matches_fake_datasets(
      task, use_cached=False, token_preprocessed=True)

  # Swap in a loader that yields more examples (20, via repeat) for later tests.
  def fake_load(s, shuffle_files=False):
    del shuffle_files  # Unused, to mimic TFDS API
    return test_utils.get_fake_dataset(s).repeat().take(20)

  test_utils.add_fake_tfds(
      utils.LazyTfdsLoader("fake:0.0.0")._replace(load=fake_load))
def test_get_dataset_onthefly(self):
  """Uncached task matches the fake data, with and without token preprocessing."""
  test_utils.verify_task_matches_fake_datasets(
      self.uncached_task, use_cached=False, token_preprocessed=True)
  # The registered variant without a token preprocessor must also match.
  no_prep_task = TaskRegistry.get("uncached_task_no_token_prep")
  test_utils.verify_task_matches_fake_datasets(
      no_prep_task, use_cached=False, token_preprocessed=False)

  # Swap the TFDS mock's loader for one yielding more examples (20, via repeat).
  def fake_load(s, shuffle_files=False):
    del shuffle_files  # Unused, to mimic TFDS API
    return test_utils.get_fake_dataset(s).repeat().take(20)

  patched = self._tfds_patcher.new.return_value
  self._tfds_patcher.new.return_value = patched._replace(load=fake_load)
def test_dtype(self):
  """Per-feature dtypes (default int32 vs explicit int64) are respected."""
  vocab = test_utils.sentencepiece_vocab()
  output_features = {
      # Inputs rely on the Feature default dtype (int32).
      "inputs": dataset_providers.Feature(vocabulary=vocab),
      # Targets explicitly request int64.
      "targets": dataset_providers.Feature(dtype=tf.int64, vocabulary=vocab),
  }
  test_utils.add_task(
      "task_dtypes", test_utils.get_fake_dataset, output_features=output_features)
  registered = TaskRegistry.get("task_dtypes")
  test_utils.verify_task_matches_fake_datasets(registered, use_cached=False)
def test_dataset_fn(self):
  """A task registered from a dataset function reproduces the fake data."""
  test_utils.add_task("fn_task", test_utils.get_fake_dataset)
  registered = TaskRegistry.get("fn_task")
  test_utils.verify_task_matches_fake_datasets(registered, use_cached=False)
def test_get_dataset_v3(self):
  """The v3 task reproduces the fake datasets (uncached, token-preprocessed)."""
  test_utils.verify_task_matches_fake_datasets(
      self.task_v3,
      use_cached=False,
      token_preprocessed=True)
def test_tf_example_task(self):
  """The tf.Example-backed task matches the fake data for its train split."""
  test_utils.verify_task_matches_fake_datasets(
      self.tf_example_task,
      use_cached=False,
      splits=["train"])