def test_optional_features(self):
    def _dummy_preprocessor(output):
        return lambda _: tf.data.Dataset.from_tensors(output)

    default_vocab = test_utils.sentencepiece_vocab()
    features = {
        "inputs": seqio.Feature(vocabulary=default_vocab, required=False),
        "targets": seqio.Feature(vocabulary=default_vocab, required=True),
    }

    # An optional feature (required=False) may be absent after preprocessing.
    task = self.add_t5_task(
        "task_missing_optional_feature",
        dataset_providers.TfdsTask,
        tfds_name="fake:0.0.0",
        output_features=features,
        text_preprocessor=_dummy_preprocessor({"targets": "a"}))
    task.get_dataset({"targets": 13}, "train", use_cached=False)

    # A missing required feature must raise a ValueError.
    task = self.add_t5_task(
        "task_missing_required_feature",
        dataset_providers.TfdsTask,
        tfds_name="fake:0.0.0",
        output_features=features,
        text_preprocessor=_dummy_preprocessor({"inputs": "a"}))
    with self.assertRaisesRegex(
            ValueError,
            "Task dataset is missing expected output feature after preprocessing: "
            "targets"):
        task.get_dataset({"inputs": 13}, "train", use_cached=False)

def test_no_eos(self):
    # EOS handling is controlled per feature via `add_eos`.
    default_vocab = test_utils.sentencepiece_vocab()
    features = {
        "inputs": seqio.Feature(add_eos=True, vocabulary=default_vocab),
        "targets": seqio.Feature(add_eos=False, vocabulary=default_vocab),
    }
    self.add_t5_task(
        "task_no_eos",
        dataset_providers.TfdsTask,
        tfds_name="fake:0.0.0",
        output_features=features)
    self.verify_task_matches_fake_datasets("task_no_eos", use_cached=False)

def add_t5_task(self,
                name,
                cls,
                text_preprocessor=(test_utils.test_text_preprocessor,),
                output_features=None,
                **kwargs):
    # Registers a task with test defaults and no metric functions.
    output_features = output_features or {
        "inputs": seqio.Feature(test_utils.sentencepiece_vocab()),
        "targets": seqio.Feature(test_utils.sentencepiece_vocab()),
    }
    return TaskRegistry.add(
        name,
        cls,
        text_preprocessor=text_preprocessor,
        metric_fns=[],
        output_features=output_features,
        **kwargs)

import functools

import seqio
import t5
from t5.data import postprocessors
from t5.data import preprocessors
from t5.data.glue_utils import get_glue_metric
from t5.data.glue_utils import get_glue_postprocess_fn
from t5.data.glue_utils import get_glue_text_preprocessor
from t5.data.glue_utils import get_super_glue_metric
from t5.evaluation import metrics
import tensorflow_datasets as tfds

TaskRegistry = t5.data.TaskRegistry
TfdsTask = t5.data.TfdsTask

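# Note: "inputs" is marked `required=False`, so tasks whose preprocessing
# emits only "targets" still pass feature validation.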
DEFAULT_OUTPUT_FEATURES = {
    "inputs":
    seqio.Feature(vocabulary=t5.data.get_default_vocabulary(),
                  add_eos=True,
                  required=False),
    "targets":
    seqio.Feature(vocabulary=t5.data.get_default_vocabulary(), add_eos=True)
}

# ==================================== C4 ======================================
# Final pretraining task used in Raffel et al., 2019.
TaskRegistry.add("c4_v220_span_corruption",
                 TfdsTask,
                 tfds_name="c4/en:2.2.0",
                 text_preprocessor=functools.partial(preprocessors.rekey,
                                                     key_map={
                                                         "inputs": None,
                                                         "targets": "text"
                                                     }),
                 token_preprocessor=preprocessors.span_corruption,
                 output_features=DEFAULT_OUTPUT_FEATURES,
                 metric_fns=[])
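
# A minimal usage sketch (not part of the original file; assumes TFDS data for
# "c4/en:2.2.0" is available locally). Fetch the registered task and request a
# tokenized dataset; the sequence lengths below are illustrative.
c4_task = TaskRegistry.get("c4_v220_span_corruption")
c4_ds = c4_task.get_dataset(
    sequence_length={"inputs": 512, "targets": 114}, split="train")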
Example #5
    def __init__(self,
                 name,
                 dataset_fn,
                 splits,
                 text_preprocessor,
                 metric_fns=None,
                 postprocess_fn=None,
                 token_preprocessor=None,
                 output_features=None,
                 num_input_examples=None,
                 supports_caching=True,
                 shuffle_buffer_size=SHUFFLE_BUFFER_SIZE,
                 source=None):

        if (dataset_fn, source).count(None) != 1:
            raise ValueError(
                "Exactly one of either `dataset_fn` or `source` must be provided."
            )

        if source and (splits or num_input_examples):
            raise ValueError(
                "If `source` is provided, `splits` and `num_input_examples` should "
                "not also be provided to the Task.")

        # Wrap a bare `dataset_fn` in a seqio data source if none was given.
        source = source or seqio.FunctionDataSource(
            dataset_fn=dataset_fn,
            splits=splits,
            num_input_examples=num_input_examples)

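        # Accept either a single preprocessor callable or an iterable of them.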
        if text_preprocessor and not hasattr(text_preprocessor, "__iter__"):
            text_preprocessor = [text_preprocessor]
        if token_preprocessor and not hasattr(token_preprocessor, "__iter__"):
            token_preprocessor = [token_preprocessor]

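        # Assemble the preprocessing pipeline: text preprocessors, tokenization,
        # an optional cache placeholder, token preprocessors, and finally EOS
        # appending after trimming.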
        preprocessors = list(text_preprocessor or [])
        preprocessors.append(seqio.preprocessors.tokenize)
        if supports_caching:
            preprocessors.append(seqio.CacheDatasetPlaceholder())
        preprocessors.extend(token_preprocessor or [])
        preprocessors.append(seqio.preprocessors.append_eos_after_trim)

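        # Normalize `output_features` into a dict mapping feature names to
        # seqio.Feature objects.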
        if hasattr(output_features, "__len__") and not output_features:
            raise ValueError("output_features must be non-empty.")
        if output_features is None:
            output_features = seqio.Feature(utils.get_default_vocabulary())
        if isinstance(output_features, dict):
            pass
        elif isinstance(output_features, seqio.Feature):
            output_features = {
                k: output_features
                for k in _DEFAULT_FEATURE_KEYS
            }
        elif isinstance(output_features, list) and all(
                isinstance(f, str) for f in output_features):
            output_features = {
                k: seqio.Feature(utils.get_default_vocabulary())
                for k in output_features
            }
        else:
            raise ValueError(
                "output_features must be a dict, Feature, list of str, or None"
            )

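        # Compose a sequence of postprocess functions into a single callable.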
        if hasattr(postprocess_fn, "__iter__"):
            postprocess_fns = postprocess_fn

            def postprocess_fn(x, **postprocess_kwargs):  # pylint:disable=function-redefined
                for post_fn in postprocess_fns:
                    x = post_fn(x, **postprocess_kwargs)
                return x

        super().__init__(name=name,
                         source=source,
                         output_features=output_features,
                         preprocessors=preprocessors,
                         postprocess_fn=postprocess_fn,
                         metric_fns=metric_fns,
                         shuffle_buffer_size=shuffle_buffer_size)
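
# A minimal usage sketch (assumptions: the enclosing class is t5.data's Task,
# and `tf`, `functools`, and `preprocessors` are imported as in the snippets
# above). `_toy_dataset_fn` is a hypothetical source function.
def _toy_dataset_fn(split, shuffle_files=False):
    del split, shuffle_files  # one fixed example; arguments are ignored
    return tf.data.Dataset.from_tensors({"text": "hello world"})

toy_task = Task(
    "toy_task",
    dataset_fn=_toy_dataset_fn,
    splits=["train"],
    text_preprocessor=functools.partial(
        preprocessors.rekey, key_map={"inputs": None, "targets": "text"}),
    metric_fns=[])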