Esempio n. 1
0
# TFDS builder config for GLUE/MNLI; its `label_classes` provide the label
# names handed to the GLUE text preprocessor below.
mnli_config = tfds.text.glue.Glue.builder_configs["mnli"]
# pylint: disable=protected-access
# MNLI training task "mnli_v002": data comes from TFDS glue/mnli:1.0.0 and is
# run, in order, through the GLUE preprocessor (benchmark_name="nli"),
# tokenization, a cache placeholder, and EOS appending after trimming.
# Metrics and the postprocess function both come from glue_utils.
TaskRegistry.add(
    "mnli_v002",
    source=seqio.TfdsDataSource(
        tfds_name="glue/mnli:1.0.0",
    ),
    preprocessors=[
        functools.partial(
            t5_preprocessors.glue,
            benchmark_name="nli",
            label_names=mnli_config.label_classes,
        ),
        seqio.preprocessors.tokenize,
        # Presumably marks where offline caching splits the pipeline
        # (steps above are cacheable) — confirm against seqio docs.
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    metric_fns=glue_utils.GLUE_METRICS["mnli"],
    output_features=DEFAULT_OUTPUT_FEATURES,
    postprocess_fn=glue_utils.get_glue_postprocess_fn(mnli_config),
)
for mnli_eval_set in ("matched", "mismatched"):
    TaskRegistry.add(
        "mnli_explain_eval_%s_v002" % mnli_eval_set,
        source=seqio.TfdsDataSource(tfds_name="glue/mnli_%s:1.0.0" %
                                    mnli_eval_set),
        preprocessors=[
            functools.partial(t5_preprocessors.glue,
                              benchmark_name="explain nli",
                              label_names=mnli_config.label_classes),
            seqio.preprocessors.tokenize,
            seqio.CacheDatasetPlaceholder(),
            seqio.preprocessors.append_eos_after_trim,
        ],
        metric_fns=[metrics.esnli_metric],
Esempio n. 2
0
                                                         "targets": "text"
                                                     }),
                 token_preprocessor=preprocessors.unsupervised,
                 output_features=DEFAULT_OUTPUT_FEATURES,
                 metric_fns=[])

# =================================== GLUE =====================================
for b in tfds.text.glue.Glue.builder_configs.values():
    # Register one task per GLUE builder config, named
    # "glue_<config name>_v002". Only the "ax" config is limited to the
    # "test" split; all others pass splits=None (TfdsTask's default split
    # handling — presumably every available split; confirm).
    TaskRegistry.add(
        f"glue_{b.name}_v002",
        TfdsTask,
        tfds_name=f"glue/{b.name}:1.0.0",
        text_preprocessor=get_glue_text_preprocessor(b),
        metric_fns=get_glue_metric(b.name),
        output_features=DEFAULT_OUTPUT_FEATURES,
        postprocess_fn=get_glue_postprocess_fn(b),
        splits=(["test"] if b.name == "ax" else None),
    )

# =============================== CNN DailyMail ================================
TaskRegistry.add(
    "cnn_dailymail_v002",
    TfdsTask,
    tfds_name="cnn_dailymail:1.0.0",
    # Summarization setup: the "article" field is passed as the article key
    # and "highlights" as the summary key of preprocessors.summarize.
    text_preprocessor=functools.partial(
        preprocessors.summarize,
        article_key="article",
        summary_key="highlights",
    ),
    # Scored with ROUGE via metrics.rouge.
    metric_fns=[metrics.rouge],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

# ==================================== WMT =====================================
# Format: year, tfds builder config, tfds version
Esempio n. 3
0
      output_features=DEFAULT_OUTPUT_FEATURES,
      metric_fns=[])

# TFDS builder config for GLUE/MNLI; its label_classes are reused as the
# label names for every MNLI task registered below.
mnli_config = tfds.text.glue.Glue.builder_configs["mnli"]
# pylint: disable=protected-access
# NOTE(review): no protected-member access is visible between this disable
# and the matching enable below; the pragma pair may be stale — confirm.
# MNLI training task: GLUE text preprocessor with benchmark_name="nli".
TaskRegistry.add(
    "mnli_v002",
    TfdsTask,
    tfds_name="glue/mnli:1.0.0",
    text_preprocessor=functools.partial(
        t5_preprocessors.glue,
        benchmark_name="nli",
        label_names=mnli_config.label_classes,
    ),
    metric_fns=t5.data.glue_utils.GLUE_METRICS["mnli"],
    output_features=DEFAULT_OUTPUT_FEATURES,
    postprocess_fn=get_glue_postprocess_fn(mnli_config),
)
# "explain nli" tasks over the MNLI matched / mismatched TFDS configs.
# Outputs go through postprocessors.abstractive_explanations and are scored
# with metrics.esnli_metric.
for mnli_eval_set in ("matched", "mismatched"):
  TaskRegistry.add(
      f"mnli_explain_eval_{mnli_eval_set}_v002",
      TfdsTask,
      tfds_name=f"glue/mnli_{mnli_eval_set}:1.0.0",
      text_preprocessor=functools.partial(
          t5_preprocessors.glue,
          benchmark_name="explain nli",
          label_names=mnli_config.label_classes,
      ),
      metric_fns=[metrics.esnli_metric],
      output_features=DEFAULT_OUTPUT_FEATURES,
      postprocess_fn=postprocessors.abstractive_explanations,
  )
# pylint: enable=protected-access