Esempio n. 1
0
                 text_preprocessor=functools.partial(preprocessors.rekey,
                                                     key_map={
                                                         "inputs": None,
                                                         "targets": "text"
                                                     }),
                 token_preprocessor=preprocessors.unsupervised,
                 output_features=DEFAULT_OUTPUT_FEATURES,
                 metric_fns=[])

# =================================== GLUE =====================================
# Register one "glue_<config>_v002" task for every builder config exposed by
# the TFDS GLUE dataset class.
for b in tfds.text.glue.Glue.builder_configs.values():
    _glue_task_name = "glue_%s_v002" % b.name
    _glue_tfds_name = "glue/%s:1.0.0" % b.name

    # Only the "ax" config is pinned to its "test" split; every other config
    # passes splits=None through to the task.
    if b.name == "ax":
        _glue_splits = ["test"]
    else:
        _glue_splits = None

    TaskRegistry.add(
        _glue_task_name,
        TfdsTask,
        tfds_name=_glue_tfds_name,
        text_preprocessor=get_glue_text_preprocessor(b),
        metric_fns=get_glue_metric(b.name),
        output_features=DEFAULT_OUTPUT_FEATURES,
        postprocess_fn=get_glue_postprocess_fn(b),
        splits=_glue_splits,
    )

# =============================== CNN DailyMail ================================
# Summarization task: `preprocessors.summarize` is bound to the "article"
# field as the article key and the "highlights" field as the summary key;
# `metrics.rouge` is the only metric function registered.
TaskRegistry.add(
    "cnn_dailymail_v002",
    TfdsTask,
    tfds_name="cnn_dailymail:1.0.0",
    text_preprocessor=functools.partial(
        preprocessors.summarize,
        article_key="article",
        summary_key="highlights",
    ),
    metric_fns=[metrics.rouge],
    output_features=DEFAULT_OUTPUT_FEATURES,
)
Esempio n. 2
0
# ========================== SuperGLUE English Vocab ===========================
for b in tfds.text.super_glue.SuperGlue.builder_configs.values():
    # We use a simplified version of WSC, defined below
    if "wsc" in b.name:
        continue
    if b.name == "axb":
        text_preprocessor = [
            functools.partial(t5_preprocessors.rekey,
                              key_map={
                                  "premise": "sentence1",
                                  "hypothesis": "sentence2",
                                  "label": "label",
                                  "idx": "idx",
                              }),
            get_glue_text_preprocessor(b)
        ]
    else:
        text_preprocessor = [get_glue_text_preprocessor(b)]
    TaskRegistry.add(
        "super_glue_%s_v102_envocab" % b.name,
        source=seqio.TfdsDataSource(
            tfds_name="super_glue/%s:1.0.2" % b.name,
            splits=["test"] if b.name in ["axb", "axg"] else None),
        preprocessors=text_preprocessor + [
            seqio.preprocessors.tokenize,
            seqio.CacheDatasetPlaceholder(),
            seqio.preprocessors.append_eos_after_trim,
        ],
        metric_fns=get_super_glue_metric(b.name),
        output_features=EN_VOCAB_OUTPUT_FEATURES,