def load_annotated_prompts() -> List[Dict]:
    """Load the annotated prompt subset CSV and attach metric functions.

    Reads ``dataset_subset_template.csv`` shipped with this package,
    drops rows rejected by ``exclude_bad_prompts``, and, for every row
    flagged for evaluation, stores the appropriate metric function list
    under the ``"metrics"`` key.

    Returns:
        The filtered list of row dicts, some augmented with ``"metrics"``.
    """
    csv_path = pkg_resources.resource_filename(
        __name__, "dataset_subset_template.csv")
    with open(csv_path) as handle:
        rows = list(csv.DictReader(handle))

    tasks = [row for row in rows if exclude_bad_prompts(row)]

    # Attach metric functions to every task selected for evaluation.
    for task in tasks:
        # NOTE(review): CSV fields are strings, so any non-empty value
        # (even "False") is truthy here — this assumes the column is
        # left blank when a task should not be evaluated; confirm
        # against the CSV contents.
        if not task["do_eval"]:
            continue

        name = task["dataset_subset_template"]
        if name.startswith("glue"):
            task["metrics"] = get_glue_metric(name.split("_")[1])
        elif name.startswith("super_glue"):
            subset = name.split("_")[2]
            if subset in ("wsc.fixed", "multirc"):
                # TODO: WSC and MultiRC need special pre/postprocesing
                task["metrics"] = [accuracy]
                continue
            task["metrics"] = get_super_glue_metric(subset)

        # Non-(Super)GLUE datasets: match on the dataset-name prefix.
        # This also runs after the GLUE branches above (no continue),
        # so a matching prefix overrides the metrics set there.
        for dataset_name in NON_GLUE_METRICS:
            if name.startswith(dataset_name):
                task["metrics"] = NON_GLUE_METRICS[dataset_name]

        # Skip rank_classification for now until we actually support it:
        # plugging in answer options and ranking LM probabilities as
        # predictions would be required for prompts whose non-trivial
        # choices are hidden. Generative tasks (rouge) are assumed to be
        # covered by NON_GLUE_METRICS already.

    return tasks
functools.partial(preprocessors.rekey, key_map={ "premise": "sentence1", "hypothesis": "sentence2", "label": "label", "idx": "idx", }), get_glue_text_preprocessor(b) ] else: text_preprocessor = get_glue_text_preprocessor(b) TaskRegistry.add("super_glue_%s_v102" % b.name, TfdsTask, tfds_name="super_glue/%s:1.0.2" % b.name, text_preprocessor=text_preprocessor, metric_fns=get_super_glue_metric(b.name), output_features=DEFAULT_OUTPUT_FEATURES, postprocess_fn=get_glue_postprocess_fn(b), splits=["test"] if b.name in ["axb", "axg"] else None) # ======================== Definite Pronoun Resolution ========================= TaskRegistry.add( "dpr_v001_simple", TfdsTask, tfds_name="definite_pronoun_resolution:1.1.0", text_preprocessor=preprocessors.definite_pronoun_resolution_simple, metric_fns=[metrics.accuracy], output_features=DEFAULT_OUTPUT_FEATURES) # =================================== WSC ====================================== TaskRegistry.add("super_glue_wsc_v102_simple_train",