Example #1
def register_seqio_task(
    bigbench_task_name: str,
    bigbench_task_path: str,
    bigbench_task_type: bb.BigBenchTaskType,
    vocab: SeqIOVocabulary,
    num_shots: int,
    bigbench_subtask_name: Optional[str] = None,
    max_examples: Optional[int] = None,
    strip_inputs: bool = True,
    strip_targets: bool = True,
    add_inputs_eos: bool = False,
    add_targets_eos: bool = False,
    json_util: json_utils.JsonUtils = json_utils.get_default_json_utils(),
    min_validation_examples: int = _GLOBAL_MIN_VALIDATION_EXAMPLES,
    additional_metrics: Optional[Sequence[seqio.MetricFnCallable]] = None,
) -> str:
    """Registers a BIG-bench SeqIO Task and returns the Task name."""
    seqio_task_name = get_seqio_name(bigbench_task_name, bigbench_task_type,
                                     vocab, num_shots, bigbench_subtask_name,
                                     max_examples)

    if seqio_task_name in seqio.TaskRegistry.names():
        return seqio_task_name

    additional_metrics = additional_metrics or []
    seqio.TaskRegistry.add(
        seqio_task_name,
        source=seqio.FunctionDataSource(
            bb.get_dataset_fn(
                task_name=bigbench_task_name,
                task_path=bigbench_task_path,
                subtask_name=bigbench_subtask_name,
                num_shots=num_shots,
                bigbench_task_type=bigbench_task_type,
                max_examples=max_examples,
                json_util=json_util,
            ),
            splits=["all", "train", "validation"],
        ),
        preprocessors=bb.get_preprocessors(strip_inputs=strip_inputs,
                                           strip_targets=strip_targets),
        output_features=bb.get_output_features(
            vocab=vocab.vocabulary,
            add_inputs_eos=add_inputs_eos,
            add_targets_eos=add_targets_eos),
        postprocess_fn=bb.get_postprocess_fn(
            task_name=bigbench_task_name,
            task_path=bigbench_task_path,
            subtask_name=bigbench_subtask_name,
            bigbench_task_type=bigbench_task_type,
            json_util=json_util),
        metric_fns=[
            bb.get_metric_fn(task_name=bigbench_task_name,
                             task_path=bigbench_task_path,
                             subtask_name=bigbench_subtask_name,
                             bigbench_task_type=bigbench_task_type,
                             json_util=json_util)
        ] + additional_metrics)
    return seqio_task_name
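For orientation, here is a minimal, hypothetical invocation of this helper. Every concrete value below (task name, path, task type) is a placeholder assumption, not something taken from the code above:

# Hypothetical usage sketch: the task name, path, and enum member below are
# placeholder assumptions, not values from the original source.
my_vocab: SeqIOVocabulary = ...  # assumed to wrap a seqio vocabulary

task_name = register_seqio_task(
    bigbench_task_name="some_task",               # placeholder
    bigbench_task_path="/path/to/task",           # placeholder
    bigbench_task_type=bb.BigBenchTaskType.TEXT,  # assumed enum member
    vocab=my_vocab,
    num_shots=0,
)
# Registration is idempotent: a second call returns the existing task name.
ds = seqio.get_mixture_or_task(task_name).get_dataset(
    sequence_length={"inputs": 512, "targets": 64}, split="validation")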
Example #2
def add_task(dataset_name, subset_name, template_name, task_name=None, split_mapping=None):
    """Registers a SeqIO task for one promptsource template, plus a
    `<task_name>_score_eval` rank-classification variant when the template
    defines label strings."""

    template = all_templates.get_dataset(dataset_name, subset_name)[template_name]

    task_name = task_name or utils.get_task_name(dataset_name, subset_name, template_name)
    if task_name in CLEAN_EVAL_TASKS:
        metrics = EVAL_METRICS[task_name]
    else:
        metrics = [t5.evaluation.metrics.sequence_accuracy]

    dataset_splits = utils.get_dataset_splits(dataset_name, subset_name)
    split_mapping = split_mapping or {k: k for k in dataset_splits.keys()}

    dataset_fn = functools.partial(
        get_tf_dataset,
        seed=None,
        dataset_name=dataset_name,
        subset_name=subset_name,
        template=template,
        split_mapping=split_mapping,
    )
    data_source = seqio.FunctionDataSource(
        dataset_fn,
        splits=list(split_mapping.keys()),
        num_input_examples={s: dataset_splits[split_mapping[s]].num_examples for s in split_mapping.keys()},
    )
    output_features = {
        "inputs": seqio.Feature(t5.data.get_default_vocabulary(), add_eos=False, dtype=tf.int32),
        "targets": seqio.Feature(t5.data.get_default_vocabulary(), add_eos=True, dtype=tf.int32),
    }
    preprocessors = [
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
        seqio.CacheDatasetPlaceholder(required=False),
    ]

    # Add train and normal eval tasks
    seqio.TaskRegistry.add(
        task_name,
        data_source,
        preprocessors=preprocessors,
        output_features=output_features,
        metric_fns=metrics,
        postprocess_fn=maybe_get_class_id_postprocessor(template),
    )

    # Add rank classification eval task
    labels = get_label_strings(template)
    if labels:
        rank_classification_preprocessor = functools.partial(
            t5.data.preprocessors.rank_classification,
            inputs_fn=lambda ex: tf.fill((len(labels),), ex["inputs"]),
            targets_fn=lambda ex: labels,
            is_correct_fn=lambda ex: tf.equal(labels, tf.strings.strip(ex["targets"])),
            weight_fn=lambda ex: 1.0,
        )
        seqio.TaskRegistry.add(
            task_name + "_score_eval",
            data_source,
            preprocessors=[rank_classification_preprocessor] + preprocessors,
            output_features=output_features,
            metric_fns=[functools.partial(t5.evaluation.metrics.rank_classification, num_classes=len(labels))],
            postprocess_fn=t5.data.postprocessors.rank_classification,
        )
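A hedged sketch of how this registrar might be driven; the dataset, subset, and template names are illustrative placeholders, not values from the source:

# Hypothetical invocation: dataset/subset/template names are placeholders.
# One call registers up to two tasks: `<task_name>` for standard training
# and eval, and `<task_name>_score_eval` for rank-classification scoring
# whenever the template defines label strings.
add_task(
    dataset_name="super_glue",      # placeholder
    subset_name="rte",              # placeholder
    template_name="some_template",  # placeholder promptsource template name
)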
Example #3
            # NOTE: this snippet begins mid-call; the enclosing
            # seqio.TaskRegistry.add(...) line with the per-language task
            # name is elided in the source.
            source=seqio.TfdsDataSource(
                tfds_name="xtreme_pawsx/{}:1.0.0".format(lang),
                splits=["train"]),
            preprocessors=[
                text_preprocessor,
                seqio.preprocessors.tokenize,
                seqio.CacheDatasetPlaceholder(),
                seqio.preprocessors.append_eos_after_trim,
            ],
            output_features=DEFAULT_OUTPUT_FEATURES,
            postprocess_fn=postprocess_fn,
            metric_fns=[metrics.accuracy])

seqio.TaskRegistry.add("mt5_pawsx_dev_test.all_langs",
                       source=seqio.FunctionDataSource(
                           dataset_fn=utils.pawsx_all_langs_dataset_fn,
                           splits=["validation", "test"]),
                       preprocessors=[
                           text_preprocessor,
                           seqio.preprocessors.tokenize,
                           seqio.CacheDatasetPlaceholder(),
                           seqio.preprocessors.append_eos_after_trim,
                       ],
                       output_features=DEFAULT_OUTPUT_FEATURES,
                       postprocess_fn=postprocess_fn,
                       metric_fns=[metrics.accuracy])

# PAWSX Zero-Shot
pawsx_eval = [
    "mt5_pawsx_dev_test.{}".format(lang) for lang in utils.PAWSX_LANGS
] + ["mt5_pawsx_dev_test.all_langs"]
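For reference, a toy illustration of the dataset_fn contract that seqio.FunctionDataSource expects: a callable taking a split name and a shuffle flag and returning a tf.data.Dataset. This stands in for utils.pawsx_all_langs_dataset_fn, whose real implementation is not shown here:

import tensorflow as tf

def toy_dataset_fn(split, shuffle_files=False, seed=None):
  # Toy stand-in: the real utils.pawsx_all_langs_dataset_fn presumably
  # combines the per-language PAWS-X splits; this just returns one example.
  del split, shuffle_files, seed  # unused in this illustration
  return tf.data.Dataset.from_tensor_slices(
      {"inputs": ["premise/hypothesis pair"], "targets": ["paraphrase"]})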
Example #4

  # NOTE: this snippet is a constructor taken from inside a class
  # definition; the enclosing `class ...` line is elided in the source.
  def __init__(self,
               name,
               dataset_fn,
               splits,
               text_preprocessor,
               metric_fns=None,
               postprocess_fn=None,
               token_preprocessor=None,
               output_features=None,
               num_input_examples=None,
               supports_caching=True,
               shuffle_buffer_size=SHUFFLE_BUFFER_SIZE,
               source=None):

    if (dataset_fn, source).count(None) != 1:
      raise ValueError(
          "Exactly one of either `dataset_fn` or `source` must be provided.")

    if source and (splits or num_input_examples):
      raise ValueError(
          "If `source` is provided, `splits` and `num_input_examples` should "
          "not also be provided to the Task.")
    # Otherwise wrap `dataset_fn` in a FunctionDataSource.
    source = source or seqio.FunctionDataSource(
        dataset_fn=dataset_fn,
        splits=splits,
        num_input_examples=num_input_examples)

    if text_preprocessor and not hasattr(text_preprocessor, "__iter__"):
      text_preprocessor = [text_preprocessor]
    if token_preprocessor and not hasattr(token_preprocessor, "__iter__"):
      token_preprocessor = [token_preprocessor]

    # Assemble the preprocessing pipeline: text preprocessors run first,
    # then tokenization, an optional cache placeholder, token preprocessors,
    # and finally EOS appending after truncation.
    preprocessors = list(text_preprocessor or [])
    preprocessors.append(seqio.preprocessors.tokenize)
    if supports_caching:
      preprocessors.append(seqio.CacheDatasetPlaceholder())
    preprocessors.extend(token_preprocessor or [])
    preprocessors.append(seqio.preprocessors.append_eos_after_trim)

    if hasattr(output_features, "__len__") and not output_features:
      raise ValueError("output_features must be non-empty.")
    if output_features is None:
      output_features = seqio.Feature(utils.get_default_vocabulary())
    if isinstance(output_features, dict):
      pass
    elif isinstance(output_features, seqio.Feature):
      output_features = {k: output_features for k in _DEFAULT_FEATURE_KEYS}
    elif isinstance(output_features, list) and all(
        isinstance(f, str) for f in output_features):
      output_features = {
          k: seqio.Feature(utils.get_default_vocabulary())
          for k in output_features
      }
    else:
      raise ValueError(
          "output_features must be a dict, Feature, list of str, or None")

    # A sequence of postprocess functions is composed into a single function
    # that applies them left to right.
    if hasattr(postprocess_fn, "__iter__"):
      postprocess_fns = postprocess_fn

      def postprocess_fn(x, **postprocess_kwargs):  # pylint:disable=function-redefined
        for post_fn in postprocess_fns:
          x = post_fn(x, **postprocess_kwargs)
        return x

    super().__init__(
        name=name,
        source=source,
        output_features=output_features,
        preprocessors=preprocessors,
        postprocess_fn=postprocess_fn,
        metric_fns=metric_fns,
        shuffle_buffer_size=shuffle_buffer_size)
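The postprocess_fn branch above folds a sequence of functions into one. The same composition idiom, as a self-contained sketch:

# Standalone illustration of the composition idiom used in the constructor:
# a list of postprocess functions becomes one callable applied left to right.
def compose_postprocessors(fns):
  def composed(x, **kwargs):
    for fn in fns:
      x = fn(x, **kwargs)
    return x
  return composed

strip_then_lower = compose_postprocessors([str.strip, str.lower])
assert strip_then_lower("  YES ") == "yes"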