Example #1
import csv
from typing import Dict, List

import pkg_resources


def load_annotated_prompts() -> List[Dict]:
    """Load the annotated prompt rows from the packaged CSV and attach metric
    functions to each task that is marked for evaluation."""
    annotated_csv_path = pkg_resources.resource_filename(
        __name__, "dataset_subset_template.csv")
    with open(annotated_csv_path) as in_file:
        reader = csv.DictReader(in_file)
        all_tasks = [row for row in reader]

    clean_tasks = list(filter(exclude_bad_prompts, all_tasks))

    # Assign metrics
    non_glue_eval_sets = list(NON_GLUE_METRICS.keys())
    for task in clean_tasks:
        if not task["do_eval"]:
            continue

        full_name = task["dataset_subset_template"]
        if full_name.startswith("glue"):
            subset = full_name.split("_")[1]
            task["metrics"] = get_glue_metric(subset)
        elif full_name.startswith("super_glue"):
            subset = full_name.split("_")[2]
            if subset in ("wsc.fixed", "multirc"):
                # TODO: WSC and MultiRC need special pre/postprocessing
                task["metrics"] = [accuracy]
                continue
            task["metrics"] = get_super_glue_metric(subset)

        for dataset_name in non_glue_eval_sets:
            if full_name.startswith(dataset_name):
                task["metrics"] = NON_GLUE_METRICS[dataset_name]

        # Skip rank_classification for now until we actually support it
        # if task["nontrivial_choices_hidden"]:
        #     # Trick of plugging in the answer options and ranking LM probabilities as predictions.
        #     # Required for all prompts with nontrivial_choices_hidden,
        #     # but could be used for other tasks as well where answer choices are given.
        #     if "metrics" not in task:
        #         task["metrics"] = [rank_classification]
        #     elif rank_classification not in task["metrics"]:
        #         task["metrics"].append(rank_classification)

        # Should already be handled by NON_GLUE_METRICS
        # if task['generative_true_task'] or task['generative_non_true_task']:
        #     task['metrics'] = rouge

    return clean_tasks
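
The helpers referenced above (exclude_bad_prompts, get_glue_metric, get_super_glue_metric, NON_GLUE_METRICS, accuracy) are defined elsewhere in the module. As a rough, self-contained sketch of the prefix-based metric lookup the function performs, with placeholder metric functions, a placeholder mapping, and a made-up task name standing in for the real ones:

from typing import Callable, Dict, List


def accuracy(targets: List[str], predictions: List[str]) -> Dict[str, float]:
    # Placeholder metric: fraction of exact matches.
    correct = sum(t == p for t, p in zip(targets, predictions))
    return {"accuracy": correct / len(targets)}


# Placeholder prefix -> metrics mapping; the real NON_GLUE_METRICS covers many more datasets.
NON_GLUE_METRICS: Dict[str, List[Callable]] = {
    "anli": [accuracy],
    "hellaswag": [accuracy],
}


def metrics_for_task(full_name: str) -> List[Callable]:
    # Mirror the startswith() checks above: the first dataset prefix that
    # matches the task's full name decides which metric functions it gets.
    for dataset_name, metric_fns in NON_GLUE_METRICS.items():
        if full_name.startswith(dataset_name):
            return metric_fns
    return [accuracy]  # fall back to plain accuracy


print(metrics_for_task("anli_r2_can_we_infer"))  # -> [accuracy placeholder]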
Example #2
            functools.partial(preprocessors.rekey,
                              key_map={
                                  "premise": "sentence1",
                                  "hypothesis": "sentence2",
                                  "label": "label",
                                  "idx": "idx",
                              }),
            get_glue_text_preprocessor(b)
        ]
    else:
        text_preprocessor = get_glue_text_preprocessor(b)
    TaskRegistry.add("super_glue_%s_v102" % b.name,
                     TfdsTask,
                     tfds_name="super_glue/%s:1.0.2" % b.name,
                     text_preprocessor=text_preprocessor,
                     metric_fns=get_super_glue_metric(b.name),
                     output_features=DEFAULT_OUTPUT_FEATURES,
                     postprocess_fn=get_glue_postprocess_fn(b),
                     splits=["test"] if b.name in ["axb", "axg"] else None)

# ======================== Definite Pronoun Resolution =========================
TaskRegistry.add(
    "dpr_v001_simple",
    TfdsTask,
    tfds_name="definite_pronoun_resolution:1.1.0",
    text_preprocessor=preprocessors.definite_pronoun_resolution_simple,
    metric_fns=[metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES)
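
# Hypothetical usage sketch, not part of the registration code above: once
# "dpr_v001_simple" is registered it can be fetched from the registry and a
# preprocessed tf.data.Dataset built from it. Assumes the t5 library is
# installed and the underlying TFDS data has been downloaded.
import t5
dpr_task = t5.data.TaskRegistry.get("dpr_v001_simple")
dpr_ds = dpr_task.get_dataset(split="train",
                              sequence_length={"inputs": 128, "targets": 32})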

# =================================== WSC ======================================
TaskRegistry.add("super_glue_wsc_v102_simple_train",