# Register the WMT T2T English->German ("ende") translation task.
# NOTE(review): this chunk is whitespace-mangled in the source; it has been
# re-indented for readability, with every token left unchanged.
b = tfds.translate.wmt_t2t.WmtT2tTranslate.builder_configs["de-en"]
TaskRegistry.add(
    "wmt_t2t_ende_v003",
    TfdsTask,
    tfds_name="wmt_t2t_translate/de-en:1.0.0",
    text_preprocessor=functools.partial(
        preprocessors.translate,
        # Presumably language_pair == ("de", "en") for the "de-en" config,
        # making this source=en, target=de — TODO confirm against TFDS.
        source_language=b.language_pair[1],
        target_language=b.language_pair[0],
    ),
    metric_fns=[metrics.bleu],
    output_features=DEFAULT_OUTPUT_FEATURES)

# ================================= SuperGlue ==================================
# Metric functions used for each SuperGLUE config, keyed by TFDS config name.
SUPERGLUE_METRICS = collections.OrderedDict([
    ("boolq", [metrics.accuracy]),
    ("cb", [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]),
    ("copa", [metrics.accuracy]),
    ("multirc", [
        metrics.multirc_f1_over_all_answers,
        metrics.mean_group_metric(metrics.exact_match)
    ]),
    ("record", [metrics.squad]),
    ("rte", [metrics.accuracy]),
    ("wic", [metrics.accuracy]),
    ("axb", []),  # Only test set available.
    ("axg", []),  # Only test set available.
])

# NOTE(review): the body of this loop continues beyond the visible chunk;
# nothing after the `if` header is shown here.
for b in tfds.text.super_glue.SuperGlue.builder_configs.values():
  # We use a simplified version of WSC, defined below
  if "wsc" in b.name:
# NOTE(review): this chunk begins mid-definition — the opening of
# SAFE_EXCLUDE_CRETERIA (sic) is outside the visible region. The whole chunk
# is whitespace-mangled; re-indented for readability, tokens unchanged.
    "non_natural_language",
    "generative_non_true_implausible",
]

# Criteria excluded in addition to the safe ones.
# NOTE(review): "CRETERIA" looks like a pre-existing typo for "CRITERIA";
# left as-is because the name may be referenced elsewhere in the project.
AGGRESSIVE_EXCLUDE_CRETERIA = [
    "generative_non_true_task",
    "nontrivial_choices_hidden",
    "awkward_phrasing",
    "ungrammatical",
] + SAFE_EXCLUDE_CRETERIA

# Metric functions for non-GLUE tasks, keyed by task name
# (for those with do_eval = True).
NON_GLUE_METRICS = {
    "anli": [accuracy],
    "hans": [accuracy],
    "circa_goldstandard1_judgement": [mean_multiclass_f1(num_classes=8), accuracy],
    "circa_goldstandard2_judgement": [mean_multiclass_f1(num_classes=5), accuracy],
    "mc_taco": [accuracy],
    "nq_open": [accuracy],
    "qa_srl": [accuracy],
    "openbookqa": [accuracy],
    "race": [accuracy],
    "social_i_qa": [accuracy],
    "emo": [mean_multiclass_f1(num_classes=4)],
    "xsum": [rouge],
}


# NOTE(review): the body of this function continues beyond the visible chunk;
# the `bool` annotation suggests it flags prompts matching an exclusion
# criterion, but the return statements are not shown — confirm before relying
# on the exact semantics.
def exclude_bad_prompts(prompt: Dict) -> bool:
  for criterion in SAFE_EXCLUDE_CRETERIA:  # or AGGRESSIVE_EXCLUDE_CRETERIA
    if prompt.get(criterion):
def test_multiclass_f1(self):
  """Checks macro-averaged 3-class F1 on a small hand-computed example."""
  targets = [0, 1, 1, 2]
  predictions = [0, 0, 2, 2]
  metric_fn = metrics.mean_multiclass_f1(num_classes=3)
  # Classes 0 and 2 each score F1 = 2/3 and class 1 scores 0, so the macro
  # average is (2/3 + 0 + 2/3) / 3 * 100 = 44.444...
  self.assertDictClose(
      metric_fn(targets, predictions),
      {"mean_3class_f1": 44.44444444444444})
# NOTE(review): this chunk begins mid-statement — the opening
# `TaskRegistry.add(` of this call is outside the visible region. The chunk
# is whitespace-mangled; re-indented for readability, tokens unchanged.
    "wmt_t2t_ende_v003",
    TfdsTask,
    tfds_name="wmt_t2t_translate/de-en:1.0.0",
    text_preprocessor=functools.partial(
        preprocessors.translate,
        # Presumably language_pair == ("de", "en"), making this source=en,
        # target=de (an "ende" task) — TODO confirm against TFDS.
        source_language=b.language_pair[1],
        target_language=b.language_pair[0],
    ),
    metric_fns=[metrics.bleu],
    output_features=DEFAULT_OUTPUT_FEATURES)

# ================================= SuperGlue ==================================
# Metric functions used for each SuperGLUE config, keyed by TFDS config name.
SUPERGLUE_METRICS = collections.OrderedDict([
    ("boolq", [metrics.accuracy]),
    ("cb", [
        metrics.mean_multiclass_f1(num_classes=3),
        metrics.accuracy
    ]),
    ("copa", [metrics.accuracy]),
    ("multirc", [
        metrics.multirc_f1_over_all_answers,
        metrics.mean_group_metric(metrics.exact_match)
    ]),
    ("record", [metrics.squad]),
    ("rte", [metrics.accuracy]),
    ("wic", [metrics.accuracy]),
    ("axb", []),  # Only test set available.
    ("axg", []),  # Only test set available.
])

# NOTE(review): the body of this loop continues beyond the visible chunk.
for b in tfds.text.super_glue.SuperGlue.builder_configs.values():