mnli_config = tfds.text.glue.Glue.builder_configs["mnli"] # pylint: disable=protected-access TaskRegistry.add( "mnli_v002", source=seqio.TfdsDataSource(tfds_name="glue/mnli:1.0.0"), preprocessors=[ functools.partial(t5_preprocessors.glue, benchmark_name="nli", label_names=mnli_config.label_classes), seqio.preprocessors.tokenize, seqio.CacheDatasetPlaceholder(), seqio.preprocessors.append_eos_after_trim, ], metric_fns=glue_utils.GLUE_METRICS["mnli"], output_features=DEFAULT_OUTPUT_FEATURES, postprocess_fn=glue_utils.get_glue_postprocess_fn(mnli_config), ) for mnli_eval_set in ("matched", "mismatched"): TaskRegistry.add( "mnli_explain_eval_%s_v002" % mnli_eval_set, source=seqio.TfdsDataSource(tfds_name="glue/mnli_%s:1.0.0" % mnli_eval_set), preprocessors=[ functools.partial(t5_preprocessors.glue, benchmark_name="explain nli", label_names=mnli_config.label_classes), seqio.preprocessors.tokenize, seqio.CacheDatasetPlaceholder(), seqio.preprocessors.append_eos_after_trim, ], metric_fns=[metrics.esnli_metric],
"targets": "text" }),
    token_preprocessor=preprocessors.unsupervised,
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[])
# NOTE(review): the lines above are the truncated end of a call that begins
# outside the visible source; its opening arguments are not shown here.

# =================================== GLUE =====================================
# Register one task per GLUE builder config, named "glue_<config>_v002" and
# reading TFDS "glue/<config>:1.0.0". The text preprocessor, metrics and
# post-processor are all derived from the config via the get_glue_* helpers.
for b in tfds.text.glue.Glue.builder_configs.values():
  TaskRegistry.add(
      "glue_%s_v002" % b.name,
      TfdsTask,
      tfds_name="glue/%s:1.0.0" % b.name,
      text_preprocessor=get_glue_text_preprocessor(b),
      metric_fns=get_glue_metric(b.name),
      output_features=DEFAULT_OUTPUT_FEATURES,
      postprocess_fn=get_glue_postprocess_fn(b),
      # The "ax" config is registered with only its "test" split. Every other
      # config passes None (presumably "use all available splits" -- confirm
      # against TfdsTask).
      splits=["test"] if b.name == "ax" else None,
  )

# =============================== CNN DailyMail ================================
# Summarization task: preprocessors.summarize reads the "article" field as the
# input and "highlights" as the summary. It is scored with ROUGE.
TaskRegistry.add(
    "cnn_dailymail_v002",
    TfdsTask,
    tfds_name="cnn_dailymail:1.0.0",
    text_preprocessor=functools.partial(
        preprocessors.summarize,
        article_key="article",
        summary_key="highlights"),
    metric_fns=[metrics.rouge],
    output_features=DEFAULT_OUTPUT_FEATURES)

# ==================================== WMT =====================================
# Format: year, tfds builder config, tfds version
output_features=DEFAULT_OUTPUT_FEATURES, metric_fns=[])
# NOTE(review): the line above is the truncated end of a call that begins
# outside the visible source; its opening arguments are not shown here.

# Builder config for GLUE MNLI; its label_classes are passed as the label
# names to the MNLI preprocessors below.
# NOTE(review): builder_configs has no leading underscore, so the
# protected-access pragma may be unnecessary -- confirm before removing.
mnli_config = tfds.text.glue.Glue.builder_configs["mnli"]  # pylint: disable=protected-access
# "mnli_v002": reads glue/mnli:1.0.0 as a TfdsTask, preprocessed by the GLUE
# preprocessor with benchmark_name="nli", scored with the GLUE MNLI metrics.
TaskRegistry.add(
    "mnli_v002",
    TfdsTask,
    tfds_name="glue/mnli:1.0.0",
    text_preprocessor=functools.partial(
        t5_preprocessors.glue,
        benchmark_name="nli",
        label_names=mnli_config.label_classes),
    metric_fns=t5.data.glue_utils.GLUE_METRICS["mnli"],
    output_features=DEFAULT_OUTPUT_FEATURES,
    postprocess_fn=get_glue_postprocess_fn(mnli_config),
)
# Explanation-eval tasks: one per MNLI eval set ("matched", "mismatched"),
# each reading glue/mnli_<set>:1.0.0 with benchmark_name="explain nli", scored
# with metrics.esnli_metric and post-processed by
# postprocessors.abstractive_explanations.
for mnli_eval_set in ("matched", "mismatched"):
  TaskRegistry.add(
      "mnli_explain_eval_%s_v002" % mnli_eval_set,
      TfdsTask,
      tfds_name="glue/mnli_%s:1.0.0" % mnli_eval_set,
      text_preprocessor=functools.partial(
          t5_preprocessors.glue,
          benchmark_name="explain nli",
          label_names=mnli_config.label_classes),
      metric_fns=[metrics.esnli_metric],
      output_features=DEFAULT_OUTPUT_FEATURES,
      postprocess_fn=postprocessors.abstractive_explanations,
  )
# pylint: enable=protected-access