# Imports assumed for this example (module paths follow the t5 library;
# the exact home of TaskRegistry/TfdsTask varies across t5 versions):
import collections
import functools
import tensorflow_datasets as tfds
from t5.data import preprocessors
from t5.data.utils import TaskRegistry, TfdsTask
from t5.evaluation import metrics

b = tfds.translate.wmt_t2t.WmtT2tTranslate.builder_configs["de-en"]
TaskRegistry.add("wmt_t2t_ende_v003",
                 TfdsTask,
                 tfds_name="wmt_t2t_translate/de-en:1.0.0",
                 text_preprocessor=functools.partial(
                     preprocessors.translate,
                     # language_pair == ("de", "en"), so this is en -> de.
                     source_language=b.language_pair[1],
                     target_language=b.language_pair[0],
                 ),
                 metric_fns=[metrics.bleu],
                 # DEFAULT_OUTPUT_FEATURES is defined elsewhere in the module.
                 output_features=DEFAULT_OUTPUT_FEATURES)
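
A minimal usage sketch for the registered task (hedged: the sequence lengths
and split below are illustrative, not from the source):

task = TaskRegistry.get("wmt_t2t_ende_v003")
ds = task.get_dataset(
    sequence_length={"inputs": 128, "targets": 128},
    split="validation")
for ex in ds.take(1):
    print(ex["inputs"], ex["targets"])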

# ================================= SuperGlue ==================================
SUPERGLUE_METRICS = collections.OrderedDict([
    ("boolq", [metrics.accuracy]),
    ("cb", [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]),
    ("copa", [metrics.accuracy]),
    ("multirc", [
        metrics.multirc_f1_over_all_answers,
        metrics.mean_group_metric(metrics.exact_match)
    ]),
    ("record", [metrics.squad]),
    ("rte", [metrics.accuracy]),
    ("wic", [metrics.accuracy]),
    ("axb", []),  # Only test set available.
    ("axg", []),  # Only test set available.
])
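
Each metric fn in this table maps (targets, predictions) to a dict of named
scores. A hedged sketch with t5.evaluation.metrics.accuracy, which reports a
percentage:

print(metrics.accuracy(["yes", "no"], ["yes", "yes"]))
# {'accuracy': 50.0}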

for b in tfds.text.super_glue.SuperGlue.builder_configs.values():
    # We use a simplified version of WSC, defined below.
    if "wsc" in b.name:
        continue  # assumed completion; the source snippet is truncated here
Example #2
    "non_natural_language",
    "generative_non_true_implausible",
]

AGGRESSIVE_EXCLUDE_CRETERIA = [
    "generative_non_true_task",
    "nontrivial_choices_hidden",
    "awkward_phrasing",
    "ungrammatical",
] + SAFE_EXCLUDE_CRETERIA


# Metric fns assumed imported from t5.evaluation.metrics
# (accuracy, mean_multiclass_f1, rouge).
NON_GLUE_METRICS = {  # for tasks with do_eval = True
    "anli": [accuracy],
    "hans": [accuracy],
    "circa_goldstandard1_judgement": [mean_multiclass_f1(num_classes=8), accuracy],
    "circa_goldstandard2_judgement": [mean_multiclass_f1(num_classes=5), accuracy],
    "mc_taco": [accuracy],
    "nq_open": [accuracy],
    "qa_srl": [accuracy],
    "openbookqa": [accuracy],
    "race": [accuracy],
    "social_i_qa": [accuracy],
    "emo": [mean_multiclass_f1(num_classes=4)],
    "xsum": [rouge],
}
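
A hedged lookup sketch: the dict maps a task name to the metric fns to apply
at eval time (the labels below are hypothetical):

targets, predictions = [0, 1, 2], [0, 1, 1]  # hypothetical class labels
for metric_fn in NON_GLUE_METRICS["circa_goldstandard1_judgement"]:
    print(metric_fn(targets, predictions))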


from typing import Dict

def exclude_bad_prompts(prompt: Dict) -> bool:
    for criterion in SAFE_EXCLUDE_CRETERIA:  # or AGGRESSIVE_EXCLUDE_CRETERIA
        if prompt.get(criterion):
            return True  # assumed completion: True means "exclude this prompt"
    return False
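
Hedged usage: filter a hypothetical list of prompt-metadata dicts, keeping
only the prompts that trip none of the criteria:

good_prompts = [p for p in all_prompts if not exclude_bad_prompts(p)]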

Example #3

def test_multiclass_f1(self):
    self.assertDictClose(
        metrics.mean_multiclass_f1(num_classes=3)([0, 1, 1, 2],
                                                  [0, 0, 2, 2]),
        {"mean_3class_f1": 44.44444444444444})