# Imports implied by this excerpt; module paths follow the later examples.
import functools

from t5.data import preprocessors
from t5.data.utils import Feature
from t5.data.utils import get_default_vocabulary
from t5.data.utils import TaskRegistry
from t5.data.utils import TfdsTask

DEFAULT_OUTPUT_FEATURES = {
    "inputs": Feature(vocabulary=get_default_vocabulary(), add_eos=True),
    "targets": Feature(vocabulary=get_default_vocabulary(), add_eos=True)
}

# ==================================== C4 ======================================
# Configurable tasks used for comparisons in Raffel et al., 2019.
_c4_config_suffixes = ["", ".noclean", ".realnewslike", ".webtextlike"]
for config_suffix in _c4_config_suffixes:
    TaskRegistry.add(
        "c4{name}_v020_unsupervised".format(
            name=config_suffix.replace(".", "_")),
        TfdsTask,
        tfds_name="c4/en{config}:2.2.0".format(config=config_suffix),
        text_preprocessor=functools.partial(preprocessors.rekey,
                                            key_map={
                                                "inputs": None,
                                                "targets": "text"
                                            }),
        token_preprocessor=preprocessors.unsupervised,
        output_features=DEFAULT_OUTPUT_FEATURES,
        metric_fns=[])
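
Once registered, a task can be materialized as a tf.data.Dataset. Below is a minimal sketch, assuming the TaskRegistry.get_dataset helper that also appears in Example #4 further down; the task name comes from the loop above and the sequence lengths are illustrative only.

# Hypothetical usage of the registry (sequence lengths are placeholders).
ds = TaskRegistry.get_dataset("c4_v020_unsupervised",
                              sequence_length={"inputs": 512, "targets": 512},
                              split="train",
                              shuffle=True)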

# Final pretraining task used in Raffel et al., 2019.
TaskRegistry.add(
    "c4_v220_span_corruption",
    TfdsTask,
    tfds_name="c4/en:2.2.0".format(config=config_suffix),
    text_preprocessor=functools.partial(preprocessors.rekey,
                                        key_map={
                                            "inputs": None,
                                            "targets": "text"
                                        }),
    token_preprocessor=preprocessors.span_corruption,
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[])

Example #2
DEFAULT_OUTPUT_FEATURES = {
    "inputs": Feature(vocabulary=get_default_vocabulary(), add_eos=True),
    "targets": Feature(vocabulary=get_default_vocabulary(), add_eos=True)
}

# ==================================== C4 ======================================
_c4_config_suffixes = ["", ".noclean", ".realnewslike", ".webtextlike"]
for config_suffix in _c4_config_suffixes:
    TaskRegistry.add(
        "c4{name}_v020_unsupervised".format(
            name=config_suffix.replace(".", "_")),
        TfdsTask,
        tfds_name="c4/en{config}:2.2.0".format(config=config_suffix),
        text_preprocessor=functools.partial(preprocessors.rekey,
                                            key_map={
                                                "inputs": None,
                                                "targets": "text"
                                            }),
        token_preprocessor=preprocessors.unsupervised,
        output_features=DEFAULT_OUTPUT_FEATURES,
        metric_fns=[])

# ================================ Wikipedia ===================================
TaskRegistry.add("wikipedia_20190301.en_v003_unsupervised",
                 TfdsTask,
                 tfds_name="wikipedia/20190301.en:1.0.0",
                 text_preprocessor=functools.partial(preprocessors.rekey,
                                                     key_map={
                                                         "inputs": None,
                                                         "targets": "text"
                                                     }),
                 token_preprocessor=preprocessors.unsupervised,
                 output_features=DEFAULT_OUTPUT_FEATURES,
                 metric_fns=[])

Example #3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import t5.data.tasks  # pylint: disable=unused-import
from t5.data.utils import MixtureRegistry
from t5.data.utils import rate_num_examples
from t5.data.utils import rate_unsupervised
from t5.data.utils import TaskRegistry
import tensorflow_datasets as tfds

# Add single-task "mixture" for each individual task
for task_name in TaskRegistry.names():
    MixtureRegistry.add(task_name, [(task_name, 1.0)])

# We omit WNLI because we train on WSC/DPR simple instead
_glue_tasks = [
    "glue_%s_v002" % b.name
    for b in tfds.text.glue.Glue.builder_configs.values()
    if "wnli" not in b.name
]
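
These task names can then be grouped into a single weighted mixture. A minimal sketch, assuming (name, rate) pairs are accepted as in the single-task loop above and that rate_num_examples (imported above) may serve as a per-task rate; the mixture name here is hypothetical.

# Hypothetical mixture that samples each GLUE task in proportion to
# its number of examples.
MixtureRegistry.add("glue_v002_proportional_example",
                    [(name, rate_num_examples) for name in _glue_tasks])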

_wsc_dpr_tasks = [
    "dpr_v001_simple",
    "super_glue_wsc_v102_simple_train",
    "super_glue_wsc_v102_simple_eval",
]
_super_glue_tasks = _wsc_dpr_tasks + [
    "super_glue_%s_v102" % b.name
    for b in tfds.text.super_glue.SuperGlue.builder_configs.values()
    if "wsc" not in b.name
]

Example #4

import ast
import collections
import csv
import json
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from t5.data.utils import TaskRegistry

# FLAGS, FILE_NAME_MAP, USES_TEXT, and _FAKE_LEN are defined at module
# level in the surrounding script and are not part of this excerpt.


def main(_):
    out_file = os.path.join(
        FLAGS.out_dir,
        "{}.{{extension}}".format(FILE_NAME_MAP[FLAGS.tfds_name]))

    ds = TaskRegistry.get_dataset(FLAGS.task,
                                  _FAKE_LEN,
                                  FLAGS.split,
                                  use_cached=FLAGS.cached,
                                  shuffle=False)
    examples = [{k: v.numpy() for k, v in ex.items()} for ex in ds]

    with tf.io.gfile.GFile(FLAGS.predictions_file) as f:
        prediction_lines = f.readlines()
    if FLAGS.tfds_name == "record":
        # record just uses raw strings
        predictions = [l.strip() for l in prediction_lines]
    else:
        # everything else uses Python code strings
        predictions = [ast.literal_eval(l.strip()) for l in prediction_lines]

    if FLAGS.tfds_name in USES_TEXT:
        if FLAGS.super:
            builder_configs = tfds.text.super_glue.SuperGlue.builder_configs
        else:
            builder_configs = tfds.text.glue.Glue.builder_configs
        label_classes = builder_configs[FLAGS.tfds_name].label_classes
        predictions = [label_classes[p] for p in predictions]
    elif FLAGS.tfds_name in ["boolq", "wic"]:
        predictions = [("false", "true")[p] for p in predictions]
    elif FLAGS.tfds_name == "wsc":
        predictions = [("False", "True")[p] for p in predictions]
    elif FLAGS.tfds_name == "multirc":
        # multirc is so different from the rest that we special-case everything
        rows = collections.defaultdict(lambda: collections.defaultdict(dict))
        predictions = [int(p["value"]) for p in predictions]
        for p, e in zip(predictions, examples):
            e = {
                k: int(e["idx/" + k])
                for k in ["paragraph", "question", "answer"]
            }
            rows[e["paragraph"]][e["question"]][e["answer"]] = p
        with tf.io.gfile.GFile(out_file.format(extension="jsonl"), "w") as f:
            for pidx, passage in rows.items():
                qs = [{
                    "idx": i,
                    "answers": [{
                        "idx": j,
                        "label": q[j]
                    } for j in q]
                } for i, q in passage.items()]
                f.write(
                    json.dumps({
                        "idx": pidx,
                        "passage": {
                            "questions": qs
                        }
                    }) + os.linesep)
        return

    if len(predictions) != len(examples):
        raise ValueError(
            "Number of predictions in {} ({}) != of examples in {} split of {} "
            "({}).".format(
                FLAGS.predictions_file,
                len(predictions),
                FLAGS.split,
                FLAGS.task,
                len(examples),
            ))

    if "record" in FLAGS.task:
        indices = [ex["idx/query"] for ex in examples]
    else:
        indices = [ex.get("idx", None) for ex in examples]

    if FLAGS.super:
        lines = [
            json.dumps({
                "idx": int(i),
                "label": p
            }) + os.linesep for i, p in zip(indices, predictions)
        ]
        with tf.io.gfile.GFile(out_file.format(extension="jsonl"), "w") as f:
            for line in lines:
                f.write(line)
    else:
        with tf.io.gfile.GFile(out_file.format(extension="tsv"),
                               "w") as out_file:
            tsv_writer = csv.writer(out_file, delimiter="\t")
            tsv_writer.writerow(["index", "prediction"])
            tsv_writer.writerows([i, p] for i, p in zip(indices, predictions))
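
The def main(_) signature suggests the script is driven by absl. A minimal sketch of the entry point, assuming absl.app is used; the flag definitions themselves are not shown in the excerpt.

# Hypothetical entry point for the conversion script above.
if __name__ == "__main__":
    from absl import app
    app.run(main)
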
Example #5

import functools

from t5.data import preprocessors
from t5.data.utils import DEFAULT_SPM_PATH
from t5.data.utils import TaskRegistry
from t5.data.utils import TfdsTask
from t5.evaluation import metrics

# ==================================== C4 ======================================
_c4_config_suffixes = ["", ".noclean", ".realnewslike", ".webtextlike"]
for config_suffix in _c4_config_suffixes:
    TaskRegistry.add(
        "c4{name}_v020_unsupervised".format(
            name=config_suffix.replace(".", "_")),
        TfdsTask,
        tfds_name="c4/en{config}:1.0.0".format(config=config_suffix),
        text_preprocessor=functools.partial(preprocessors.rekey,
                                            key_map={
                                                "inputs": None,
                                                "targets": "text"
                                            }),
        token_preprocessor=preprocessors.unsupervised,
        sentencepiece_model_path=DEFAULT_SPM_PATH,
        metric_fns=[])

# ================================ Wikipedia ===================================
TaskRegistry.add(
    "wikipedia_20190301.en_v003_unsupervised",
    TfdsTask,
    # 0.0.4 is identical to 0.0.3 except empty records removed.
    tfds_name="wikipedia/20190301.en:0.0.4",
    text_preprocessor=functools.partial(preprocessors.rekey,
                                        key_map={
                                            "inputs": None,
                                            "targets": "text"
                                        }),
    token_preprocessor=preprocessors.unsupervised,
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    metric_fns=[])