def __init__(
      self,
      name,
      tfds_name,
      text_preprocessor,
      metric_fns,
      tfds_data_dir=None,
      splits=None,
      **task_kwargs):
    """TfdsTask constructor.

    Wraps the TFDS arguments in a `seqio.TfdsDataSource` and passes that
    source, together with the remaining arguments, to the parent `Task`
    constructor.

    Args:
      name: string, a unique name for the Task. A ValueError will be raised if
        another task with this name is already registered.
      tfds_name: string, the name and version number of a TFDS dataset,
        optionally with a config.
      text_preprocessor: a function (or list of functions) that (each) takes in
        a tf.data.Dataset of string features and returns a tf.data.Dataset of
        string features. Can be set to None as a no-op. If a list is given, they
        will be executed sequentially.
      metric_fns: list(callable), list of metric functions with the signature
        metric_fn(targets, predictions) to use during evaluation.
      tfds_data_dir: string, an optional path to a specific TFDS data directory
        to use.
      splits: a list(string) of allowable splits to load, a dict mapping
        allowable canonical splits (e.g., 'validation') to TFDS splits or slices
        (e.g., 'train[:1%]'), or None. The default, None, uses all available
        splits from the TFDS dataset info.
      **task_kwargs: dict, additional keyword arguments for the parent `Task`
        class. Must not contain `source`, `dataset_fn` or `splits`, which this
        constructor already passes explicitly (duplicates raise TypeError).
    """
    super().__init__(
        name,
        source=seqio.TfdsDataSource(
            tfds_name=tfds_name, tfds_data_dir=tfds_data_dir, splits=splits),
        text_preprocessor=text_preprocessor,
        metric_fns=metric_fns,
        # Data is read through `source`; no dataset_fn is supplied.
        dataset_fn=None,
        # Split selection was already handed to `source` above.
        splits=None,
        **task_kwargs)
示例#2
0
    "es", "et", "eu", "fa", "fi", "fr", "fy", "ga", "gl", "gu", "he", "hi",
    "hr", "ht", "hu", "hy", "id", "io", "is", "it", "ja", "jv", "ka", "kk",
    "kn", "ko", "ky", "la", "lb", "lmo", "lt", "lv", "mg", "min", "mk", "ml",
    "mn", "mr", "ms", "my", "nds-nl", "ne", "new", "nl", "nn", "no", "oc",
    "pa", "pl", "pms", "pnb", "pt", "ro", "ru", "scn", "sco", "sh", "sk", "sl",
    "sq", "sr", "su", "sv", "sw", "ta", "te", "tg", "th", "tl", "tr", "tt",
    "uk", "ur", "uz", "vi", "vo", "war", "yo", "zh"
]

# =========================== Pretraining Tasks/Mixtures =======================
# mC4
for lang in MC4_LANGS:
    # "-" in a language code (e.g. "nds-nl") becomes "_" in the task name
    # (e.g. "mc4.nds_nl"); the TFDS split names keep the original code.
    seqio.TaskRegistry.add(
        f"mc4.{lang.replace('-', '_')}",
        source=seqio.TfdsDataSource(
            tfds_name="c4/multilingual:3.0.1",
            splits={"train": lang, "validation": f"{lang}-validation"},
        ),
        preprocessors=[
            # "targets" <- raw "text" field; "inputs" is mapped to None.
            functools.partial(
                t5.data.preprocessors.rekey,
                key_map={"inputs": None, "targets": "text"},
            ),
            seqio.preprocessors.tokenize,
            seqio.CacheDatasetPlaceholder(),
            t5.data.preprocessors.span_corruption,
            seqio.preprocessors.append_eos_after_trim,
        ],
        output_features=DEFAULT_OUTPUT_FEATURES,
        metric_fns=[],
    )
示例#3
0
文件: tasks.py 项目: AIRC-KETI/ke-t5
                  add_eos=False,
                  required=False,
                  dtype=tf.int32),
    "targets":
    seqio.Feature(vocabulary=DEFAULT_VOCAB, add_eos=True, dtype=tf.int32)
}

# CLASSIFICATION_OUTPUT_FEATURES = {
#     "inputs": seqio.Feature(
#         vocabulary=DEFAULT_VOCAB, add_eos=False, required=False, dtype=tf.int32)
# }

# ============ KLUE topic classification: Generative ============
seqio.TaskRegistry.add(
    "klue_tc_gen",
    seqio.TfdsDataSource(tfds_name="klue/tc:1.0.0"),
    preprocessors=[
        functools.partial(seqio.preprocessors.rekey,
                          key_map={
                              "id": "guid",
                              "title": "title",
                              "label": "label",
                          }),
        functools.partial(
            base_preproc_for_classification,
            benchmark_name='klue_tc',
            input_keys=['title'],
            label_names=KLUE_META['tc_classes'],
            with_feature_key=True,
        ), seqio.preprocessors.tokenize, seqio.preprocessors.append_eos
    ],
示例#4
0
# evaluation procedure.
# The model is trained to predict all ground-truth answers
# and is only considered correct if it predicts all answers for any one of the
# annotators. As in the official evaluation, we consider questions with fewer
# than two non-null annotations unanswerable (given the context) but because we
# cannot predict unanswerability without the context, we only compute the recall
# metric. Further, because our model does not have access to the oracle context,
# we also normalize predicted and ground-truth answers when comparing them.

# This task uses a portion of the train set for validation.
TaskRegistry.add(
    "natural_questions_nocontext",
    source=seqio.TfdsDataSource(
        tfds_name="natural_questions:0.0.2",
        splits={
            "train": f"train[{NQ_TRAIN_SPLIT_START}:{NQ_TRAIN_SPLIT_END}]",
            "validation": f"train[:{NQ_TRAIN_SPLIT_START}]",
            "test": "validation"
        }),
    preprocessors=[
        preprocessors.natural_questions_nocontext,
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features=DEFAULT_OUTPUT_FEATURES,
    postprocess_fn=postprocessors.natural_questions,
    metric_fns=[
        functools.partial(
            metrics.natural_questions,
            # Train set does not contain multiple annotations.
示例#5
0
# SentencePiece vocabularies built from the custom model files, each with 100
# extra ids.
WMT14_VOCAB_EXTRA_100 = seqio.SentencePieceVocabulary(
    WMT14_CUSTOM_SPM_PATH, extra_ids=100)
EN_VOCAB_EXTRA_100 = seqio.SentencePieceVocabulary(
    EN_VOCAB_SPM_PATH, extra_ids=100)

# "inputs" and "targets" each get their own Feature over the English-only
# vocabulary, both appending EOS.
EN_VOCAB_OUTPUT_FEATURES = {
    feature_name: seqio.Feature(vocabulary=EN_VOCAB_EXTRA_100, add_eos=True)
    for feature_name in ("inputs", "targets")
}

#================================ English only vocab ===========================
for version in ("2.2.0", "2.3.0", "2.3.1"):
    # e.g. "2.3.1" registers "c4_v231_unsupervised_en32k" over TFDS
    # "c4/en:2.3.1".
    TaskRegistry.add(
        f"c4_v{version.replace('.', '')}_unsupervised_en32k",
        source=seqio.TfdsDataSource(tfds_name=f"c4/en:{version}"),
        preprocessors=[
            # "targets" <- raw "text" field; "inputs" is mapped to None.
            functools.partial(
                t5_preprocessors.rekey,
                key_map={"inputs": None, "targets": "text"},
            ),
            seqio.preprocessors.tokenize,
            seqio.CacheDatasetPlaceholder(),
            t5_preprocessors.unsupervised,
            seqio.preprocessors.append_eos_after_trim,
        ],
        output_features=EN_VOCAB_OUTPUT_FEATURES,
        metric_fns=[],
    )

#================================ XSUM =========================================
示例#6
0
from t5.data import glue_utils
from t5.data import postprocessors as t5_postprocessors
from t5.data import preprocessors as t5_preprocessors
from t5.evaluation import metrics as t5_metrics
import tensorflow_datasets as tfds

# Short alias used by the registrations below.
TaskRegistry = seqio.TaskRegistry

# "inputs" and "targets" each get their own Feature over the default
# vocabulary (fetched once per feature), both appending EOS.
DEFAULT_OUTPUT_FEATURES = {
    feature_name: seqio.Feature(
        vocabulary=get_default_vocabulary(), add_eos=True)
    for feature_name in ("inputs", "targets")
}

# ======================== CoS-E Corpus Task ==================================
# CoS-E v0.0.1 with abstractive-explanation postprocessing, scored with the
# e-SNLI metric.
TaskRegistry.add(
    "cos_e_v001",
    source=seqio.TfdsDataSource(tfds_name="cos_e:0.0.1"),
    preprocessors=[
        preprocessors.cos_e,
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features=DEFAULT_OUTPUT_FEATURES,
    postprocess_fn=postprocessors.abstractive_explanations,
    metric_fns=[metrics.esnli_metric],
)

# CoS-E with no explanations, and modified prefixes like e-SNLI.
TaskRegistry.add("cos_e_v001_0_expln_like_esnli",
                 source=seqio.TfdsDataSource(tfds_name="cos_e:0.0.1"),
                 preprocessors=[
                     functools.partial(preprocessors.cos_e,
示例#7
0
        "seqio.SentencePieceVocabulary.sentencepiece_model_file")

    custom_vocab = seqio.SentencePieceVocabulary(sentence_piece_model_path,
                                                 extra_ids)
    return {
        "inputs":
        seqio.Feature(vocabulary=custom_vocab, add_eos=add_eos,
                      required=False),
        "targets":
        seqio.Feature(vocabulary=custom_vocab, add_eos=add_eos)
    }


# C4 c4_v220_span_corruption data with a custom sentencepiece.
TaskRegistry.add(
    "c4_v220_span_corruption_custom_sp",
    source=seqio.TfdsDataSource(tfds_name="c4/en:2.2.0"),
    preprocessors=[
        # "targets" <- raw "text" field; "inputs" is mapped to None.
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    # Features from the custom SentencePiece model, using the helper's
    # default arguments.
    output_features=get_custom_output_features(),
    metric_fns=[],
)

# =================================== GLUE =====================================
示例#8
0
def _register_xnli_task(name, tfds_name, splits, process_fn, output_features):
  """Registers one XNLI-family task with the shared pipeline and accuracy.

  The pipeline is `process_fn`, then tokenization, a cache placeholder and
  EOS appending. A fresh `CacheDatasetPlaceholder` is built for every task.

  Args:
    name: string, the task name to register.
    tfds_name: string, TFDS dataset name and version.
    splits: list(string), TFDS splits the task may load.
    process_fn: dataset-specific preprocessor run before tokenization.
    output_features: dict of output features, passed through unchanged.
  """
  seqio.TaskRegistry.add(
      name,
      source=seqio.TfdsDataSource(tfds_name=tfds_name, splits=splits),
      preprocessors=[
          process_fn,
          seqio.preprocessors.tokenize,
          seqio.CacheDatasetPlaceholder(),
          seqio.preprocessors.append_eos_after_trim,
      ],
      output_features=output_features,
      metric_fns=[metrics.accuracy])


def create_xnli_tasks_and_mixtures(task_prefix, task_suffix, output_features):
  """Helper function to create XNLI tasks and mixtures.

  Registers these tasks:
    * `{prefix}xnli_train{suffix}`: multi_nli train split.
    * `{prefix}xnli_dev_test{suffix}.{lang}`: xnli validation/test for each
      language in XNLI_LANGS.
    * `{prefix}xnli_dev_test{suffix}.all_langs`: the same splits covering
      every language in XNLI_LANGS.
    * `{prefix}xnli_translate_train{suffix}.{lang}`: xtreme_xnli train split
      for every language in XNLI_LANGS except "en".

  It also registers two mixtures, both with default_rate=1.0:
    * `{prefix}xnli_zeroshot{suffix}`: the train task plus all dev/test tasks.
    * `{prefix}xnli_translate_train{suffix}`: the zero-shot tasks plus all
      translate-train tasks.

  Args:
    task_prefix: string prepended to every task/mixture name; None or "" for
      no prefix.
    task_suffix: string appended, after an underscore, to every base name;
      None or "" for no suffix.
    output_features: dict of output features shared by all tasks.
  """
  # Treat None like "" so it never shows up as the literal text "None" in
  # registered names.
  task_prefix = task_prefix or ""
  task_suffix = f"_{task_suffix}" if task_suffix else ""

  train_task = f"{task_prefix}xnli_train{task_suffix}"
  all_langs_task = f"{task_prefix}xnli_dev_test{task_suffix}.all_langs"

  def dev_test_task(lang):
    return f"{task_prefix}xnli_dev_test{task_suffix}.{lang}"

  def translate_train_task(lang):
    return f"{task_prefix}xnli_translate_train{task_suffix}.{lang}"

  def xnli_preprocessor(target_languages):
    # Build a new partial for every task, as the original code did.
    return functools.partial(
        preprocessors.process_xnli, target_languages=target_languages)

  _register_xnli_task(train_task, "multi_nli:1.1.0", ["train"],
                      preprocessors.process_mnli, output_features)
  for xnli_lang in XNLI_LANGS:
    _register_xnli_task(dev_test_task(xnli_lang), "xnli:1.1.0",
                        ["validation", "test"],
                        xnli_preprocessor([xnli_lang]), output_features)
    # No translate-train task is created for English.
    if xnli_lang == "en":
      continue
    _register_xnli_task(translate_train_task(xnli_lang), "xtreme_xnli:1.1.0",
                        ["train"], xnli_preprocessor([xnli_lang]),
                        output_features)
  _register_xnli_task(all_langs_task, "xnli:1.1.0", ["validation", "test"],
                      xnli_preprocessor(XNLI_LANGS), output_features)

  xnli_zeroshot = [train_task, all_langs_task] + [
      dev_test_task(lang) for lang in XNLI_LANGS
  ]
  seqio.MixtureRegistry.add(
      f"{task_prefix}xnli_zeroshot{task_suffix}",
      xnli_zeroshot,
      default_rate=1.0)
  xnli_translate_train = xnli_zeroshot + [
      translate_train_task(lang) for lang in XNLI_LANGS if lang != "en"
  ]
  seqio.MixtureRegistry.add(
      f"{task_prefix}xnli_translate_train{task_suffix}",
      xnli_translate_train,
      default_rate=1.0)
示例#9
0
    "hr", "ht", "hu", "hy", "id", "io", "is", "it", "ja", "jv", "ka", "kk",
    "kn", "ko", "ky", "la", "lb", "lmo", "lt", "lv", "mg", "min", "mk", "ml",
    "mn", "mr", "ms", "my", "nds-nl", "ne", "new", "nl", "nn", "no", "oc",
    "pa", "pl", "pms", "pnb", "pt", "ro", "ru", "scn", "sco", "sh", "sk", "sl",
    "sq", "sr", "su", "sv", "sw", "ta", "te", "tg", "th", "tl", "tr", "tt",
    "uk", "ur", "uz", "vi", "vo", "war", "yo", "zh"
]

# =========================== Pretraining Tasks/Mixtures =======================
# mC4
for lang in MC4_LANGS:
  seqio.TaskRegistry.add(
      "mc4.{}".format(lang.replace("-", "_")),
      source=seqio.TfdsDataSource(
          tfds_name="c4/multilingual:3.0.1",
          splits={
              "train": lang,
              "validation": f"{lang}-validation"
          }),
      preprocessors=[
          functools.partial(
              t5.data.preprocessors.rekey,
              key_map={
                  "inputs": None,
                  "targets": "text"
              }),
          seqio.preprocessors.tokenize,
          seqio.CacheDatasetPlaceholder(),
          t5.data.preprocessors.span_corruption,
          seqio.preprocessors.append_eos_after_trim,
      ],
      output_features=DEFAULT_OUTPUT_FEATURES,


# Both features use T5's default vocabulary (fetched once per feature) and
# append EOS; "inputs" is additionally marked required=False.
DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=t5.data.get_default_vocabulary(),
        add_eos=True,
        required=False,
    ),
    "targets": seqio.Feature(
        vocabulary=t5.data.get_default_vocabulary(),
        add_eos=True,
    ),
}

# ==================================== C4 ======================================
# Final pretraining task used in Raffel et al., 2019.
TaskRegistry.add(
    "c4_v220_span_corruption",
    source=seqio.TfdsDataSource(tfds_name="c4/en:2.2.0"),
    preprocessors=[
        # "targets" <- raw "text" field; "inputs" is mapped to None.
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[],
)