def test_value_errors(self):
    """Checks the `ValueError`s raised for bad `CacheDatasetPlaceholder` use.

    Two invalid pipelines are built: one with two placeholders, and one where
    a preprocessor taking `sequence_length` precedes the placeholder.
    """
    # Every split yields the same single-element dataset.
    dataset_fn = (
        lambda split, shuffle_files: tf.data.Dataset.from_tensors(["test"]))
    output_features = {
        "inputs": dataset_providers.Feature(test_utils.sentencepiece_vocab())
    }

    # Case 1: the placeholder appears twice in the pipeline.
    duplicate_placeholder_msg = (
        "`CacheDatasetPlaceholder` can appear at most once in the "
        "preprocessing pipeline. Found 2 in 'multiple_cache_placeholders'.")
    with self.assertRaisesWithLiteralMatch(
            ValueError, duplicate_placeholder_msg):
        dataset_providers.Task(
            "multiple_cache_placeholders",
            source=dataset_providers.FunctionDataSource(
                dataset_fn=dataset_fn, splits=["train", "validation"]),
            preprocessors=[
                test_utils.test_text_preprocessor,
                preprocessors.tokenize,
                dataset_providers.CacheDatasetPlaceholder(),
                test_utils.test_token_preprocessor,
                dataset_providers.CacheDatasetPlaceholder(),
            ],
            output_features=output_features,
            metric_fns=[])

    # Case 2: a `sequence_length` preprocessor sits before the placeholder.
    sequence_length_msg = (
        "'test_token_preprocessor' has a `sequence_length` argument but occurs "
        "before `CacheDatasetPlaceholder` in 'sequence_length_pre_cache'. This "
        "is not allowed since the sequence length is specified at run time.")
    with self.assertRaisesWithLiteralMatch(ValueError, sequence_length_msg):
        dataset_providers.Task(
            "sequence_length_pre_cache",
            dataset_providers.FunctionDataSource(
                dataset_fn=dataset_fn,
                splits=["train"],
            ),
            preprocessors=[
                test_utils.test_text_preprocessor,
                preprocessors.tokenize,
                test_utils.test_token_preprocessor,
                dataset_providers.CacheDatasetPlaceholder(),
            ],
            output_features=output_features,
            metric_fns=[])
def __enter__(self):
    """Replaces the task's source with one serving `self.per_split_data`.

    The replacement source is also stored on `self._mock_source`.
    """

    def ds_fn(split, shuffle_files):
        # Shuffling is irrelevant: each split is a single-element dataset.
        del shuffle_files
        return tf.data.Dataset.from_tensors(self.per_split_data[split])

    injected_source = dataset_providers.FunctionDataSource(
        ds_fn, splits=self.per_split_data.keys())
    self._mock_source = injected_source
    self._task._source = injected_source
def test_function_source_signature(self):
    """Checks which `dataset_fn` signatures `FunctionDataSource` accepts."""
    train_only = ("train",)

    # Good signatures: construction must not raise.
    def good_fn(split, shuffle_files):
        del split, shuffle_files

    def default_good_fn(split, shuffle_files=False):
        del split, shuffle_files

    def seed_fn(split, shuffle_files=True, seed=0):
        del split, shuffle_files, seed

    def extra_kwarg_good_fn(split, shuffle_files, unused_kwarg=True):
        del split, shuffle_files

    dataset_providers.FunctionDataSource(good_fn, splits=train_only)
    dataset_providers.FunctionDataSource(default_good_fn, splits=train_only)
    dataset_providers.FunctionDataSource(seed_fn, splits=train_only)
    dataset_providers.FunctionDataSource(extra_kwarg_good_fn, splits=train_only)

    # Bad signatures: the function names below appear in the expected
    # messages, so they must not be renamed.
    def missing_shuff(split):
        del split

    def missing_split(shuffle_files):
        del shuffle_files

    def extra_pos_arg(split, shuffle_files, unused_arg):
        del split, shuffle_files

    with self.assertRaisesWithLiteralMatch(
            ValueError,
            "'missing_shuff' must have positional args ('split', 'shuffle_files'), "
            "got: ('split',)"):
        dataset_providers.FunctionDataSource(missing_shuff, splits=train_only)

    with self.assertRaisesWithLiteralMatch(
            ValueError,
            "'missing_split' must have positional args ('split', 'shuffle_files'), "
            "got: ('shuffle_files',)"):
        dataset_providers.FunctionDataSource(missing_split, splits=train_only)

    with self.assertRaisesWithLiteralMatch(
            ValueError,
            "'extra_pos_arg' may only have positional args ('split', "
            "'shuffle_files'), got: ('split', 'shuffle_files', 'unused_arg')"):
        dataset_providers.FunctionDataSource(extra_pos_arg, splits=train_only)
def _task_from_tensor_slices(name, tensor_slices, label_classes):
    """Builds a `Task` whose data comes straight from in-memory tensor slices.

    Args:
      name: name given to the created task.
      tensor_slices: value handed to `tf.data.Dataset.from_tensor_slices`; the
        preprocessor below reads the "inputs_lengths", "targets_lengths" and
        "targets_pretokenized" fields of each resulting example.
      label_classes: forwarded as the `label_classes` keyword of
        `_string_label_to_class_id_postprocessor`.

    Returns:
      The constructed `dataset_providers.Task`.
    """

    def _to_ranges(ex):
        # Replace the length fields by `tf.range` sequences of that length.
        return {
            "inputs": tf.range(ex["inputs_lengths"]),
            "targets": tf.range(ex["targets_lengths"]),
            "targets_pretokenized": ex["targets_pretokenized"],
        }

    source = dataset_providers.FunctionDataSource(
        lambda split, shuffle_files:
        tf.data.Dataset.from_tensor_slices(tensor_slices),
        # The trailing comma matters: `("validation")` is just the string
        # "validation", not a one-element tuple of split names.
        splits=("validation",))

    return dataset_providers.Task(
        name,
        source,
        preprocessors=[utils.map_over_dataset(_to_ranges)],
        postprocess_fn=functools.partial(
            _string_label_to_class_id_postprocessor,
            label_classes=label_classes),
        output_features={
            "inputs": dataset_providers.Feature(mock.Mock()),
            "targets": dataset_providers.Feature(mock.Mock()),
        })
def register_dummy_task(
    task_name: str,
    dataset_fn: Callable[[str, str], tf.data.Dataset],
    output_feature_names: Sequence[str] = ("inputs", "targets")) -> None:
    """Adds a dummy task to `TaskRegistry` for GetDatasetTest.

    Args:
      task_name: name under which the task is registered.
      dataset_fn: function backing the task's `FunctionDataSource`.
      output_feature_names: names of the output features; each gets a
        `Feature` built from `test_utils.sentencepiece_vocab()`.
    """
    source = dataset_providers.FunctionDataSource(
        dataset_fn=dataset_fn, splits=["train", "validation"])

    output_features = {}
    for feature_name in output_feature_names:
        output_features[feature_name] = dataset_providers.Feature(
            test_utils.sentencepiece_vocab())

    dataset_providers.TaskRegistry.add(
        task_name,
        source=source,
        preprocessors=[
            dataset_providers.CacheDatasetPlaceholder(),
            preprocessors.append_eos_after_trim,
        ],
        output_features=output_features,
        metric_fns=[])
def setUp(self):
    """Resets the registries and builds the fixtures shared by the tests.

    Sets `self.fake_source`, `self.vocabulary`, `self.metrics_fns` and
    `self.preprocessors`.
    """
    super().setUp()
    TaskRegistry.reset()
    MixtureRegistry.reset()

    self.fake_source = dataset_providers.FunctionDataSource(
        lambda split, shuffle_files: tf.data.Dataset.range(2), ['train'])
    self.vocabulary = vocabularies.PassThroughVocabulary(100)
    self.metrics_fns = [lambda targets, predictions: 0]

    def _append_length(key, value, sequence_length):
        # Appends the sequence length configured for `key` to `value`.
        return tf.concat([value, [sequence_length[key]]], 0)

    def fake_preprocessor(ds):
        """Adds one and casts to int32."""
        return ds.map(lambda value: tf.cast(value + 1, tf.int32))

    def fake_preprocessor_of(ds, output_features):
        """Creates output feature dict from scalar input."""
        return ds.map(
            lambda value: {name: [value] for name in output_features})

    def fake_preprocessor_sl(ds, sequence_length):
        """Concatenates the sequence length to each feature."""
        return ds.map(
            lambda ex: {  # pylint:disable=g-long-lambda
                key: _append_length(key, value, sequence_length)
                for key, value in ex.items()
            })

    def fake_preprocessor_sl_of(ds, sequence_length, output_features):
        """Adds the sequence length to each feature with `add_eos` enabled."""
        return ds.map(
            lambda ex: {  # pylint:disable=g-long-lambda
                key: (_append_length(key, value, sequence_length)
                      if output_features[key].add_eos else value)
                for key, value in ex.items()
            })

    self.preprocessors = [
        fake_preprocessor,
        fake_preprocessor_of,
        fake_preprocessor_sl,
        fake_preprocessor_sl_of,
    ]
def register_dummy_task(task_name: str,
                        dataset_fn: Callable[[str, str], tf.data.Dataset],
                        output_feature_names: Sequence[str] = ("inputs",
                                                               "targets"),
                        postprocess_fn=None,
                        metrics_fn=None) -> None:
    """Adds a dummy task to `TaskRegistry` for GetDatasetTest.

    Args:
      task_name: name under which the task is registered.
      dataset_fn: function backing the task's `FunctionDataSource`.
      output_feature_names: names of the output features to declare.
      postprocess_fn: passed through as the task's `postprocess_fn`.
      metrics_fn: passed through as the task's `metric_fns`.
    """
    source = dataset_providers.FunctionDataSource(
        dataset_fn=dataset_fn, splits=["train", "validation"])

    # Mock the sentencepiece vocabulary.
    output_features = {}
    for feature_name in output_feature_names:
        output_features[feature_name] = dataset_providers.Feature(
            mock.Mock(eos_id=True))

    dataset_providers.TaskRegistry.add(
        task_name,
        source=source,
        preprocessors=[preprocessors.append_eos_after_trim],
        postprocess_fn=postprocess_fn,
        output_features=output_features,
        metric_fns=metrics_fn)
def test_data_injection(self):
    """Checks the task's original source is back after `DataInjector` exits."""

    def ds_fn(split, shuffle_files):
        del shuffle_files
        per_split = {'train': {'data': b'not used'}}
        return tf.data.Dataset.from_tensors(per_split[split])

    task_name = 'test_data_injection_task'
    source = dataset_providers.FunctionDataSource(
        dataset_fn=ds_fn, splits=['train'])
    dataset_providers.TaskRegistry.add(
        task_name,
        source=source,
        preprocessors=[],
        output_features={},
        metric_fns=[])

    # Enter and immediately leave the injection context.
    injected = {'train': {'data': b'This data is not used.'}}
    with DataInjector(task_name, injected):
        pass

    registered_task = dataset_providers.TaskRegistry.get(task_name)
    self.assertIs(source, registered_task._source)
def test_fewshot_data_source(self):
    """Exercises `FewshotDataSource`: 0-shot, 3-shot, sharded, missing train."""

    def fake_dataset_fn(split, shuffle_files):
        del shuffle_files
        # 'validation' yields range(0, 2); any other split yields range(3, 5).
        bounds = (0, 2) if split == 'validation' else (3, 5)
        return tf.data.Dataset.range(*bounds)

    def make_source(splits):
        # Function-backed source over `fake_dataset_fn` with the given splits.
        return dataset_providers.FunctionDataSource(
            dataset_fn=fake_dataset_fn, splits=splits)

    # 0 shot
    src = experimental.FewshotDataSource(
        make_source(['train', 'validation']), num_shots=0)
    zero_shot_ds = src.get_dataset('validation')
    assert_dataset(zero_shot_ds, [{'eval': 0}, {'eval': 1}])

    # 3 shot
    src = experimental.FewshotDataSource(
        make_source(['train', 'validation']),
        train_preprocessors=[
            utils.map_over_dataset(lambda x: {'inputs': 0, 'targets': x})
        ],
        num_shots=3)
    three_shot_ds = src.get_dataset('validation')
    assert_dataset(three_shot_ds, [
        {
            'eval': 0,
            'train': {'inputs': [0, 0, 0], 'targets': [3, 4, 3]},
        },
        {
            'eval': 1,
            'train': {'inputs': [0, 0, 0], 'targets': [4, 3, 4]},
        },
    ])

    # 3-shot, sharded.
    first_shard = src.get_dataset('validation', shard_info=ShardInfo(0, 2))
    assert_dataset(first_shard, [
        {
            'eval': 0,
            'train': {'inputs': [0, 0, 0], 'targets': [3, 3, 3]},
        },
    ])
    second_shard = src.get_dataset('validation', shard_info=ShardInfo(1, 2))
    assert_dataset(second_shard, [
        {
            'eval': 1,
            'train': {'inputs': [0, 0, 0], 'targets': [4, 4, 4]},
        },
    ])

    # Missing train
    src = experimental.FewshotDataSource(
        make_source(['validation']), num_shots=3)
    with self.assertRaisesRegex(
            ValueError,
            'Train split \'train\' is not one of the original source splits: '
            r'\(\'validation\',\)'):
        src.get_dataset('validation')
def _prepare_sources_and_tasks(self):
    """Clears registered tasks/mixtures and rebuilds the test fixtures.

    Patches the TFDS loader with a fake, copies the test data into a temp
    directory, then creates the TFDS, text-line, TFExample and function
    sources and registers tasks on them via `self.add_task`.
    """
    clear_tasks()
    clear_mixtures()

    # Prepare TfdsSource
    # Note we don't use mock.Mock since they fail to pickle.
    train_shards = [
        {  # pylint:disable=g-complex-comprehension
            "filename": "train.tfrecord-%05d-of-00002" % i,
            "skip": 0,
            "take": -1
        } for i in range(2)
    ]
    validation_shards = [{
        "filename": "validation.tfrecord-00000-of-00001",
        "skip": 0,
        "take": -1
    }]
    fake_tfds_paths = {
        "train": train_shards,
        "validation": validation_shards,
    }

    def _load_shard(shard_instruction, shuffle_files, seed):
        del shuffle_files
        del seed
        filename = shard_instruction["filename"]
        if "train" not in filename:
            return get_fake_dataset("validation")
        # First train shard gets the first two examples, the other the rest.
        train_ds = get_fake_dataset("train")
        if filename.endswith("00000-of-00002"):
            return train_ds.take(2)
        return train_ds.skip(2)

    fake_tfds = FakeLazyTfds(
        name="fake:0.0.0",
        load=get_fake_dataset,
        load_shard=_load_shard,
        info=FakeTfdsInfo(splits={"train": None, "validation": None}),
        files=fake_tfds_paths.get,
        size=lambda x: 30 if x == "train" else 10)
    self._tfds_patcher = mock.patch(
        "t5.seqio.utils.LazyTfdsLoader",
        new=mock.Mock(return_value=fake_tfds))
    self._tfds_patcher.start()

    # Set up data directory.
    self.test_tmpdir = self.get_tempdir()
    self.test_data_dir = os.path.join(self.test_tmpdir, "test_data")
    data_dir = self.test_data_dir
    shutil.copytree(TEST_DATA_DIR, data_dir)
    # chmod every sub-directory plus (via the empty name) the root itself.
    for root, dirs, _ in os.walk(data_dir):
        for subdir in (*dirs, ""):
            os.chmod(os.path.join(root, subdir), 0o777)

    # TFDS-backed source and a task on top of it.
    self.tfds_source = dataset_providers.TfdsDataSource(
        tfds_name="fake:0.0.0", splits=("train", "validation"))
    self.add_task("tfds_task", source=self.tfds_source)

    # Prepare TextLineSource.
    train_tsv = os.path.join(data_dir, "train.tsv")
    _dump_fake_dataset(
        train_tsv, _FAKE_DATASET["train"], [2, 1], _dump_examples_to_tsv)
    self.text_line_source = dataset_providers.TextLineDataSource(
        split_to_filepattern={
            "train": os.path.join(data_dir, "train.tsv*"),
        },
        skip_header_lines=1,
    )
    self.add_task(
        "text_line_task",
        source=self.text_line_source,
        preprocessors=(split_tsv_preprocessor, *self.DEFAULT_PREPROCESSORS))

    # Prepare TFExampleSource.
    train_tfrecord = os.path.join(data_dir, "train.tfrecord")
    _dump_fake_dataset(
        train_tfrecord, _FAKE_DATASET["train"], [2, 1],
        _dump_examples_to_tfrecord)
    self.tf_example_source = dataset_providers.TFExampleDataSource(
        split_to_filepattern={
            "train": os.path.join(data_dir, "train.tfrecord*"),
        },
        feature_description={
            "prefix": tf.io.FixedLenFeature([], tf.string),
            "suffix": tf.io.FixedLenFeature([], tf.string),
        })
    self.add_task("tf_example_task", source=self.tf_example_source)

    # Prepare FunctionDataSource
    self.function_source = dataset_providers.FunctionDataSource(
        dataset_fn=get_fake_dataset, splits=["train", "validation"])
    self.add_task("function_task", source=self.function_source)

    # Prepare Task that is tokenized and preprocessed before caching.
    self.add_task(
        "fully_processed_precache",
        source=self.function_source,
        preprocessors=(
            test_text_preprocessor,
            preprocessors.tokenize,
            token_preprocessor_no_sequence_length,
            dataset_providers.CacheDatasetPlaceholder(),
        ))

    # Prepare Task that is tokenized after caching.
    self.add_task(
        "tokenized_postcache",
        source=self.function_source,
        preprocessors=(
            test_text_preprocessor,
            dataset_providers.CacheDatasetPlaceholder(),
            preprocessors.tokenize,
            token_preprocessor_no_sequence_length,
        ))

    # Prepare Task with randomization.
    self.random_task = self.add_task(
        "random_task",
        source=self.function_source,
        preprocessors=(
            test_text_preprocessor,
            dataset_providers.CacheDatasetPlaceholder(),
            preprocessors.tokenize,
            random_token_preprocessor,
        ))

    self.uncached_task = self.add_task(
        "uncached_task", source=self.tfds_source)

    # Prepare cached task.
    dataset_utils.set_global_cache_dirs([self.test_data_dir])
    self.cached_task_dir = os.path.join(data_dir, "cached_task")
    _dump_fake_dataset(
        os.path.join(self.cached_task_dir, "train.tfrecord"),
        _FAKE_TOKENIZED_DATASET["train"], [2, 1],
        _dump_examples_to_tfrecord)
    _dump_fake_dataset(
        os.path.join(self.cached_task_dir, "validation.tfrecord"),
        _FAKE_TOKENIZED_DATASET["validation"], [2],
        _dump_examples_to_tfrecord)
    self.cached_task = self.add_task("cached_task", source=self.tfds_source)

    # Prepare cached plaintext task.
    plaintext_train = os.path.join(
        data_dir, "cached_plaintext_task", "train.tfrecord")
    _dump_fake_dataset(
        plaintext_train, _FAKE_PLAINTEXT_TOKENIZED_DATASET["train"], [2, 1],
        _dump_examples_to_tfrecord)
    self.cached_plaintext_task = self.add_task(
        "cached_plaintext_task",
        source=self.tfds_source,
        preprocessors=(*self.DEFAULT_PREPROCESSORS, test_token_preprocessor))