def test_mean_group_metric(self):
  """mean_group_metric averages the wrapped metric over example groups.

  Group "a": predictions [0, 0] vs. targets [0, 1] -> accuracy 50.
  Group "b": prediction [1] vs. target [0] -> accuracy 0.
  Mean over the two groups: 25.
  """
  grouped_accuracy = metrics.mean_group_metric(metrics.accuracy)
  targets = [
      {"group": "a", "value": 0},
      {"group": "a", "value": 1},
      {"group": "b", "value": 0},
  ]
  predictions = [{"value": 0}, {"value": 0}, {"value": 1}]
  self.assertDictClose(grouped_accuracy(targets, predictions), {"accuracy": 25.})
text_preprocessor=functools.partial( preprocessors.translate, source_language=b.language_pair[1], target_language=b.language_pair[0], ), metric_fns=[metrics.bleu], output_features=DEFAULT_OUTPUT_FEATURES) # ================================= SuperGlue ================================== SUPERGLUE_METRICS = collections.OrderedDict([ ("boolq", [metrics.accuracy]), ("cb", [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]), ("copa", [metrics.accuracy]), ("multirc", [ metrics.multirc_f1_over_all_answers, metrics.mean_group_metric(metrics.exact_match) ]), ("record", [metrics.squad]), ("rte", [metrics.accuracy]), ("wic", [metrics.accuracy]), ("axb", []), # Only test set available. ("axg", []), # Only test set available. ]) for b in tfds.text.super_glue.SuperGlue.builder_configs.values(): # We use a simplified version of WSC, defined below if "wsc" in b.name: continue if b.name == "axb": text_preprocessor = [ functools.partial(preprocessors.rekey,
tfds_name="wmt_t2t_translate/de-en:0.0.1", text_preprocessor=functools.partial( preprocessors.translate, source_language=b.language_pair[1], target_language=b.language_pair[0], ), metric_fns=[metrics.bleu], sentencepiece_model_path=DEFAULT_SPM_PATH) # ================================= SuperGlue ================================== SUPERGLUE_METRICS = collections.OrderedDict([ ("boolq", [metrics.accuracy]), ("cb", [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]), ("copa", [metrics.accuracy]), ("multirc", [ metrics.mean_group_metric(metrics.f1_score_with_invalid), metrics.mean_group_metric(metrics.exact_match) ]), ("record", [metrics.qa]), ("rte", [metrics.accuracy]), ("wic", [metrics.accuracy]), ("axb", []), # Only test set available. ("axg", []), # Only test set available. ]) for b in tfds.text.super_glue.SuperGlue.builder_configs.values(): # We use a simplified version of WSC, defined below if "wsc" in b.name: continue if b.name == "axb": text_preprocessor = [
("mnli_mismatched", [metrics.accuracy]), ("qnli", [metrics.accuracy]), ("rte", [metrics.accuracy]), ("wnli", [metrics.accuracy]), ("ax", []), # Only test set available. ]) def get_glue_metric(task_name): return GLUE_METRICS[task_name] SUPERGLUE_METRICS = collections.OrderedDict([ ("boolq", [metrics.accuracy]), ("cb", [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]), ("copa", [metrics.accuracy]), ("multirc", [ metrics.multirc_f1_over_all_answers, metrics.mean_group_metric(metrics.all_match) ]), ("record", [metrics.deduplicate_metric(metrics.squad)]), ("rte", [metrics.accuracy]), ("wic", [metrics.accuracy]), ("axb", []), # Only test set available. ("axg", []), # Only test set available. ]) def get_super_glue_metric(task_name): return SUPERGLUE_METRICS[task_name]