def _task_from_tensor_slices(name, tensor_slices, label_classes):
  """Builds a mock `Task` whose validation split yields `tensor_slices`.

  The dataset source serves `tensor_slices` verbatim; the preprocessor
  expands each example's `*_lengths` fields into `tf.range` token-id stand-ins
  and passes `targets_pretokenized` through. Postprocessing maps string labels
  to class ids via `_string_label_to_class_id_postprocessor`.

  Args:
    name: Name to register the Task under.
    tensor_slices: Dict of tensors passed to
      `tf.data.Dataset.from_tensor_slices`; must contain `inputs_lengths`,
      `targets_lengths`, and `targets_pretokenized` keys.
    label_classes: Sequence of string labels; each label's index becomes its
      class id in postprocessing.

  Returns:
    A `dataset_providers.Task` with mocked vocabularies.
  """
  return dataset_providers.Task(
      name,
      dataset_providers.FunctionDataSource(
          lambda split, shuffle_files: tf.data.Dataset.from_tensor_slices(
              tensor_slices),
          # Bug fix: `("validation")` is just the string "validation" — the
          # trailing comma is required to make this a one-element tuple of
          # split names rather than an iterable of characters.
          splits=("validation",)),
      preprocessors=[
          utils.map_over_dataset(lambda ex: {
              "inputs": tf.range(ex["inputs_lengths"]),
              "targets": tf.range(ex["targets_lengths"]),
              "targets_pretokenized": ex["targets_pretokenized"],
          })
      ],
      postprocess_fn=functools.partial(
          _string_label_to_class_id_postprocessor,
          label_classes=label_classes),
      output_features={
          "inputs": dataset_providers.Feature(mock.Mock()),
          "targets": dataset_providers.Feature(mock.Mock())
      })
def test_dtype(self):
  """Checks feature dtypes: "inputs" keeps the int32 default, "targets" is cast to int64."""
  vocab = test_utils.sentencepiece_vocab()
  output_features = {
      # No explicit dtype: Feature defaults to int32.
      "inputs": dataset_providers.Feature(vocabulary=vocab),
      "targets": dataset_providers.Feature(dtype=tf.int64, vocabulary=vocab),
  }

  def _cast_targets_to_int64(x):
    # Only the "targets" feature is widened; all other features pass through.
    return {
        k: tf.cast(v, tf.int64) if k == "targets" else v for k, v in x.items()
    }

  self.add_task(
      "task_dtypes",
      self.function_source,
      preprocessors=self.DEFAULT_PREPROCESSORS +
      (utils.map_over_dataset(_cast_targets_to_int64),),
      output_features=output_features)
  self.verify_task_matches_fake_datasets("task_dtypes", use_cached=False)
def test_fewshot_data_source(self):
  """Exercises FewshotDataSource: 0-shot, 3-shot, sharded 3-shot, missing train."""

  def fake_dataset_fn(split, shuffle_files):
    # Serves 0..1 for 'validation' and 3..4 for any other split (i.e. 'train').
    del shuffle_files
    return tf.data.Dataset.range(
        *((0, 2) if split == 'validation' else (3, 5)))

  # 0 shot: eval examples come through alone, with no 'train' field attached.
  src = experimental.FewshotDataSource(
      dataset_providers.FunctionDataSource(
          dataset_fn=fake_dataset_fn, splits=['train', 'validation']),
      num_shots=0)
  dataset = src.get_dataset('validation')
  assert_dataset(dataset, [{
      'eval': 0,
  }, {
      'eval': 1
  }])

  # 3 shot: each eval example carries 3 train exemplars, run through the
  # given train_preprocessors (constant 'inputs', identity 'targets').
  src = experimental.FewshotDataSource(
      dataset_providers.FunctionDataSource(
          dataset_fn=fake_dataset_fn, splits=['train', 'validation']),
      train_preprocessors=[
          utils.map_over_dataset(lambda x: {
              'inputs': 0,
              'targets': x
          })
      ],
      num_shots=3)
  dataset = src.get_dataset('validation')
  assert_dataset(dataset, [
      {
          'eval': 0,
          'train': {
              'inputs': [0, 0, 0],
              'targets': [3, 4, 3]
          }
      },
      {
          'eval': 1,
          'train': {
              'inputs': [0, 0, 0],
              'targets': [4, 3, 4]
          }
      },
  ])

  # 3-shot, sharded: each shard draws shots only from its own slice of the
  # train split, so with one train element per shard the targets repeat.
  assert_dataset(
      src.get_dataset('validation', shard_info=ShardInfo(0, 2)), [
          {
              'eval': 0,
              'train': {
                  'inputs': [0, 0, 0],
                  'targets': [3, 3, 3]
              }
          },
      ])
  assert_dataset(
      src.get_dataset('validation', shard_info=ShardInfo(1, 2)), [
          {
              'eval': 1,
              'train': {
                  'inputs': [0, 0, 0],
                  'targets': [4, 4, 4]
              }
          },
      ])

  # Missing train: requesting shots from a source without the 'train' split
  # must raise a ValueError naming the available splits.
  src = experimental.FewshotDataSource(
      dataset_providers.FunctionDataSource(
          dataset_fn=fake_dataset_fn, splits=['validation']),
      num_shots=3)
  with self.assertRaisesRegex(
      ValueError,
      'Train split \'train\' is not one of the original source splits: '
      r'\(\'validation\',\)'):
    dataset = src.get_dataset('validation')