def get_preprocessors(strip_inputs: bool = True,
                      strip_targets: bool = True) -> Sequence[Preprocessor]:
    """Returns BIG-bench preprocessors."""
    preprocessors = []

    if strip_inputs or strip_targets:
        keys_to_strip = []
        if strip_inputs:
            keys_to_strip.append("inputs")
        if strip_targets:
            keys_to_strip.append("targets")
            keys_to_strip.append("answers")

        preprocessors.append(
            get_strip_preprocessor(keys_to_strip=keys_to_strip))

    preprocessors.append(seqio.preprocessors.tokenize)
    preprocessors.append(seqio.CacheDatasetPlaceholder())
    preprocessors.append(seqio.preprocessors.append_eos)

    return preprocessors
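
A minimal sketch of wiring the returned pipeline into a task registration. The
task name and TFDS source are hypothetical; the vocabulary reuses
t5.data.get_default_vocabulary(), as in the later examples.

import seqio
import t5.data

vocab = t5.data.get_default_vocabulary()
seqio.TaskRegistry.add(
    "bigbench_example",  # hypothetical task name
    source=seqio.TfdsDataSource(tfds_name="my_dataset:1.0.0"),  # assumed source
    preprocessors=list(get_preprocessors(strip_inputs=True, strip_targets=True)),
    output_features={
        "inputs": seqio.Feature(vocab, add_eos=False),
        "targets": seqio.Feature(vocab, add_eos=True),
    },
    metric_fns=[],
)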
Example #2
def add_task(dataset_name, subset_name, template_name, task_name=None, split_mapping=None):
    """Registers a generative seqio task and, when the template has labels, a "_score_eval" variant."""
    template = all_templates.get_dataset(dataset_name, subset_name)[template_name]

    task_name = task_name or utils.get_task_name(dataset_name, subset_name, template_name)
    if task_name in CLEAN_EVAL_TASKS:
        metrics = EVAL_METRICS[task_name]
    else:
        metrics = [t5.evaluation.metrics.sequence_accuracy]

    dataset_splits = utils.get_dataset_splits(dataset_name, subset_name)
    split_mapping = split_mapping or {k: k for k in dataset_splits.keys()}

    dataset_fn = functools.partial(
        get_tf_dataset,
        seed=None,
        dataset_name=dataset_name,
        subset_name=subset_name,
        template=template,
        split_mapping=split_mapping,
    )
    data_source = seqio.FunctionDataSource(
        dataset_fn,
        splits=list(split_mapping.keys()),
        num_input_examples={s: dataset_splits[split_mapping[s]].num_examples for s in split_mapping.keys()},
    )
    output_features = {
        "inputs": seqio.Feature(t5.data.get_default_vocabulary(), add_eos=False, dtype=tf.int32),
        "targets": seqio.Feature(t5.data.get_default_vocabulary(), add_eos=True, dtype=tf.int32),
    }
    preprocessors = [
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
        seqio.CacheDatasetPlaceholder(required=False),
    ]

    # Add train and normal eval tasks
    seqio.TaskRegistry.add(
        task_name,
        data_source,
        preprocessors=preprocessors,
        output_features=output_features,
        metric_fns=metrics,
        postprocess_fn=maybe_get_class_id_postprocessor(template),
    )

    # Add rank classification eval task
    labels = get_label_strings(template)
    if labels:
        rank_classification_preprocessor = functools.partial(
            t5.data.preprocessors.rank_classification,
            inputs_fn=lambda ex: tf.fill((len(labels),), ex["inputs"]),
            targets_fn=lambda ex: labels,
            is_correct_fn=lambda ex: tf.equal(labels, tf.strings.strip(ex["targets"])),
            weight_fn=lambda ex: 1.0,
        )
        seqio.TaskRegistry.add(
            task_name + "_score_eval",
            data_source,
            preprocessors=[rank_classification_preprocessor] + preprocessors,
            output_features=output_features,
            metric_fns=[functools.partial(t5.evaluation.metrics.rank_classification, num_classes=len(labels))],
            postprocess_fn=t5.data.postprocessors.rank_classification,
        )
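
A hedged usage sketch: one call registers the generative task and, when the
template defines answer choices, a "_score_eval" rank-classification variant.
The dataset, subset, and template names below are assumptions for illustration.

add_task(
    dataset_name="super_glue",
    subset_name="rte",
    template_name="GPT-3 style",  # assumed promptsource template name
)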
Example #3

  def __init__(self,
               name,
               dataset_fn,
               splits,
               text_preprocessor,
               metric_fns=None,
               postprocess_fn=None,
               token_preprocessor=None,
               output_features=None,
               num_input_examples=None,
               supports_caching=True,
               shuffle_buffer_size=SHUFFLE_BUFFER_SIZE,
               source=None):
    """Constructs a Task from either a `dataset_fn` or a pre-built seqio `source`."""
    if (dataset_fn, source).count(None) != 1:
      raise ValueError(
          "Exactly one of either `dataset_fn` or `source` must be provided.")

    if source and (splits or num_input_examples):
      raise ValueError(
          "If `source` is provided, `splits` and `num_input_examples` should "
          "not also be provided to the Task.")
    source = source or seqio.FunctionDataSource(
        dataset_fn=dataset_fn,
        splits=splits,
        num_input_examples=num_input_examples)

    if text_preprocessor and not hasattr(text_preprocessor, "__iter__"):
      text_preprocessor = [text_preprocessor]
    if token_preprocessor and not hasattr(token_preprocessor, "__iter__"):
      token_preprocessor = [token_preprocessor]

    preprocessors = list(text_preprocessor or [])
    preprocessors.append(seqio.preprocessors.tokenize)
    if supports_caching:
      preprocessors.append(seqio.CacheDatasetPlaceholder())
    preprocessors.extend(token_preprocessor or [])
    preprocessors.append(seqio.preprocessors.append_eos_after_trim)

    if hasattr(output_features, "__len__") and not output_features:
      raise ValueError("output_features must be non-empty.")
    if output_features is None:
      output_features = seqio.Feature(utils.get_default_vocabulary())
    if isinstance(output_features, dict):
      pass
    elif isinstance(output_features, seqio.Feature):
      output_features = {k: output_features for k in _DEFAULT_FEATURE_KEYS}
    elif isinstance(output_features, list) and all(
        isinstance(f, str) for f in output_features):
      output_features = {
          k: seqio.Feature(utils.get_default_vocabulary())
          for k in output_features
      }
    else:
      raise ValueError(
          "output_features must be a dict, Feature, list of str, or None")

    if hasattr(postprocess_fn, "__iter__"):
      postprocess_fns = postprocess_fn

      def postprocess_fn(x, **postprocess_kwargs):  # pylint:disable=function-redefined
        for post_fn in postprocess_fns:
          x = post_fn(x, **postprocess_kwargs)
        return x

    super().__init__(
        name=name,
        source=source,
        output_features=output_features,
        preprocessors=preprocessors,
        postprocess_fn=postprocess_fn,
        metric_fns=metric_fns,
        shuffle_buffer_size=shuffle_buffer_size)
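
A minimal construction sketch, assuming this __init__ belongs to a Task
subclass (called FunctionTask here, a hypothetical name); the toy dataset_fn
stands in for a real data-loading function.

import tensorflow as tf

def toy_dataset_fn(split, shuffle_files=False, seed=None):
    # In-memory stand-in; a real dataset_fn would read files for `split`.
    return tf.data.Dataset.from_tensor_slices(
        {"inputs": ["translate English to German: hello"],
         "targets": ["hallo"]})

task = FunctionTask(  # hypothetical subclass name
    name="toy_task",
    dataset_fn=toy_dataset_fn,
    splits=["train"],
    text_preprocessor=None,
    metric_fns=[],
    num_input_examples={"train": 1},
)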
Example #4
for lang in MC4_LANGS:
    seqio.TaskRegistry.add("mc4.{}".format(lang.replace("-", "_")),
                           source=seqio.TfdsDataSource(
                               tfds_name="c4/multilingual:3.0.1",
                               splits={
                                   "train": lang,
                                   "validation": f"{lang}-validation"
                               }),
                           preprocessors=[
                               functools.partial(t5.data.preprocessors.rekey,
                                                 key_map={
                                                     "inputs": None,
                                                     "targets": "text"
                                                 }),
                               seqio.preprocessors.tokenize,
                               seqio.CacheDatasetPlaceholder(),
                               t5.data.preprocessors.span_corruption,
                               seqio.preprocessors.append_eos_after_trim,
                           ],
                           output_features=DEFAULT_OUTPUT_FEATURES,
                           metric_fns=[])

mc4 = ["mc4.{}".format(lang.replace("-", "_")) for lang in MC4_LANGS]
seqio.MixtureRegistry.add("mc4", mc4, default_rate=DEFAULT_MIX_RATE)

# Wikipedia
for lang in WIKI_LANGS:
    seqio.TaskRegistry.add(
        "mt5_wiki.{}".format(lang.replace("-", "_")),
        source=seqio.TfdsDataSource(
            tfds_name="wikipedia/20200301.{}:1.0.0".format(lang)),
        # Same span-corruption pipeline as the mC4 tasks above.
        preprocessors=[
            functools.partial(t5.data.preprocessors.rekey,
                              key_map={
                                  "inputs": None,
                                  "targets": "text"
                              }),
            seqio.preprocessors.tokenize,
            seqio.CacheDatasetPlaceholder(),
            t5.data.preprocessors.span_corruption,
            seqio.preprocessors.append_eos_after_trim,
        ],
        output_features=DEFAULT_OUTPUT_FEATURES,
        metric_fns=[])
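
Once registered, a mixture is consumed through the standard seqio API; a
hedged sketch, with sequence lengths chosen only for illustration:

dataset = seqio.get_mixture_or_task("mc4").get_dataset(
    sequence_length={"inputs": 1024, "targets": 256},  # assumed lengths
    split="train",
    shuffle=True,
    seed=42,
)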
Example #5
def create_xnli_tasks_and_mixtures(task_prefix, task_suffix, output_features):
  """Helper function to create XNLI tasks and mixtures."""
  if task_suffix:
    task_suffix = "_" + task_suffix
  seqio.TaskRegistry.add(
      f"{task_prefix}xnli_train{task_suffix}",
      source=seqio.TfdsDataSource(
          tfds_name="multi_nli:1.1.0", splits=["train"]),
      preprocessors=[
          preprocessors.process_mnli,
          seqio.preprocessors.tokenize,
          seqio.CacheDatasetPlaceholder(),
          seqio.preprocessors.append_eos_after_trim,
      ],
      output_features=output_features,
      metric_fns=[metrics.accuracy])
  for xnli_lang in XNLI_LANGS:
    seqio.TaskRegistry.add(
        f"{task_prefix}xnli_dev_test{task_suffix}.{xnli_lang}",
        source=seqio.TfdsDataSource(
            tfds_name="xnli:1.1.0", splits=["validation", "test"]),
        preprocessors=[
            functools.partial(
                preprocessors.process_xnli, target_languages=[xnli_lang]),
            seqio.preprocessors.tokenize,
            seqio.CacheDatasetPlaceholder(),
            seqio.preprocessors.append_eos_after_trim,
        ],
        output_features=output_features,
        metric_fns=[metrics.accuracy])
    if xnli_lang == "en":
      continue
    seqio.TaskRegistry.add(
        f"{task_prefix}xnli_translate_train{task_suffix}.{xnli_lang}",
        source=seqio.TfdsDataSource(
            tfds_name="xtreme_xnli:1.1.0", splits=["train"]),
        preprocessors=[
            functools.partial(
                preprocessors.process_xnli, target_languages=[xnli_lang]),
            seqio.preprocessors.tokenize,
            seqio.CacheDatasetPlaceholder(),
            seqio.preprocessors.append_eos_after_trim,
        ],
        output_features=output_features,
        metric_fns=[metrics.accuracy])
  seqio.TaskRegistry.add(
      f"{task_prefix}xnli_dev_test{task_suffix}.all_langs",
      source=seqio.TfdsDataSource(
          tfds_name="xnli:1.1.0", splits=["validation", "test"]),
      preprocessors=[
          functools.partial(
              preprocessors.process_xnli, target_languages=XNLI_LANGS),
          seqio.preprocessors.tokenize,
          seqio.CacheDatasetPlaceholder(),
          seqio.preprocessors.append_eos_after_trim,
      ],
      output_features=output_features,
      metric_fns=[metrics.accuracy])
  xnli_zeroshot = ([
      f"{task_prefix}xnli_train{task_suffix}",
      f"{task_prefix}xnli_dev_test{task_suffix}.all_langs"
  ] + [
      f"{task_prefix}xnli_dev_test{task_suffix}.{lang}" for lang in XNLI_LANGS
  ])
  seqio.MixtureRegistry.add(
      f"{task_prefix}xnli_zeroshot{task_suffix}",
      xnli_zeroshot,
      default_rate=1.0)
  xnli_translate_train = xnli_zeroshot + [
      f"{task_prefix}xnli_translate_train{task_suffix}.{lang}"
      for lang in XNLI_LANGS
      if lang != "en"
  ]
  seqio.MixtureRegistry.add(
      f"{task_prefix}xnli_translate_train{task_suffix}",
      xnli_translate_train,
      default_rate=1.0)
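
A hedged invocation sketch; the prefix and output features are assumptions,
mirroring the feature definitions used elsewhere in these examples.

vocab = t5.data.get_default_vocabulary()
create_xnli_tasks_and_mixtures(
    task_prefix="mt5_",  # assumed prefix
    task_suffix="",
    output_features={
        "inputs": seqio.Feature(vocab, add_eos=True),
        "targets": seqio.Feature(vocab, add_eos=True),
    },
)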