DEFAULT_OUTPUT_FEATURES = { "inputs": Feature(vocabulary=get_default_vocabulary(), add_eos=True), "targets": Feature(vocabulary=get_default_vocabulary(), add_eos=True) } # ==================================== C4 ====================================== # Configurable tasks used for comparisons in Raffel et al., 2019. _c4_config_suffixes = ["", ".noclean", ".realnewslike", ".webtextlike"] for config_suffix in _c4_config_suffixes: TaskRegistry.add( "c4{name}_v020_unsupervised".format( name=config_suffix.replace(".", "_")), TfdsTask, tfds_name="c4/en{config}:2.2.0".format(config=config_suffix), text_preprocessor=functools.partial(preprocessors.rekey, key_map={ "inputs": None, "targets": "text" }), token_preprocessor=preprocessors.unsupervised, output_features=DEFAULT_OUTPUT_FEATURES, metric_fns=[]) # Final pretraining task used in Raffel et al., 2019. TaskRegistry.add( "c4_v220_span_corruption", TfdsTask, tfds_name="c4/en:2.2.0".format(config=config_suffix), text_preprocessor=functools.partial(preprocessors.rekey, key_map={ "inputs": None,
DEFAULT_OUTPUT_FEATURES = { "inputs": Feature(vocabulary=get_default_vocabulary(), add_eos=True), "targets": Feature(vocabulary=get_default_vocabulary(), add_eos=True) } # ==================================== C4 ====================================== _c4_config_suffixes = ["", ".noclean", ".realnewslike", ".webtextlike"] for config_suffix in _c4_config_suffixes: TaskRegistry.add( "c4{name}_v020_unsupervised".format( name=config_suffix.replace(".", "_")), TfdsTask, tfds_name="c4/en{config}:2.2.0".format(config=config_suffix), text_preprocessor=functools.partial(preprocessors.rekey, key_map={ "inputs": None, "targets": "text" }), token_preprocessor=preprocessors.unsupervised, output_features=DEFAULT_OUTPUT_FEATURES, metric_fns=[]) # ================================ Wikipedia =================================== TaskRegistry.add("wikipedia_20190301.en_v003_unsupervised", TfdsTask, tfds_name="wikipedia/20190301.en:1.0.0", text_preprocessor=functools.partial(preprocessors.rekey, key_map={ "inputs": None, "targets": "text"
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

# Importing t5.data.tasks registers all tasks as a side effect, which the
# TaskRegistry loop below depends on.
import t5.data.tasks  # pylint: disable=unused-import
from t5.data.utils import MixtureRegistry
from t5.data.utils import rate_num_examples
from t5.data.utils import rate_unsupervised
from t5.data.utils import TaskRegistry
import tensorflow_datasets as tfds

# Add single-task "mixture" for each individual task so every task can also
# be addressed through the MixtureRegistry with unit rate.
for task_name in TaskRegistry.names():
  MixtureRegistry.add(task_name, [(task_name, 1.0)])

# We omit WNLI because we train on WSC/DPR simple instead.
_glue_tasks = [
    "glue_%s_v002" % b.name
    for b in tfds.text.glue.Glue.builder_configs.values()
    if "wnli" not in b.name
]

# WSC is handled via the simplified DPR-style tasks rather than the raw
# SuperGLUE WSC formulation.
_wsc_dpr_tasks = [
    "dpr_v001_simple",
    "super_glue_wsc_v102_simple_train",
    "super_glue_wsc_v102_simple_eval",
]

# (This list is truncated in this chunk of the file.)
_super_glue_tasks = _wsc_dpr_tasks + [
def main(_): out_file = os.path.join( FLAGS.out_dir, "{}.{{extension}}".format(FILE_NAME_MAP[FLAGS.tfds_name])) ds = TaskRegistry.get_dataset(FLAGS.task, _FAKE_LEN, FLAGS.split, use_cached=FLAGS.cached, shuffle=False) examples = [{k: v.numpy() for k, v in ex.items()} for ex in ds] with tf.io.gfile.GFile(FLAGS.predictions_file) as f: prediction_lines = f.readlines() if FLAGS.tfds_name == "record": # record just uses raw strings predictions = [l.strip() for l in prediction_lines] else: # everything else uses Python code strings predictions = [ast.literal_eval(l.strip()) for l in prediction_lines] if FLAGS.tfds_name in USES_TEXT: if FLAGS.super: builder_configs = tfds.text.super_glue.SuperGlue.builder_configs else: builder_configs = tfds.text.glue.Glue.builder_configs label_classes = builder_configs[FLAGS.tfds_name].label_classes predictions = [label_classes[p] for p in predictions] elif FLAGS.tfds_name in ["boolq", "wic"]: predictions = [("false", "true")[p] for p in predictions] elif FLAGS.tfds_name == "wsc": predictions = [("False", "True")[p] for p in predictions] elif FLAGS.tfds_name == "multirc": # multirc is so different from the rest that we special-case everything rows = collections.defaultdict(lambda: collections.defaultdict(dict)) predictions = [int(p["value"]) for p in predictions] for p, e in zip(predictions, examples): e = { k: int(e["idx/" + k]) for k in ["paragraph", "question", "answer"] } rows[e["paragraph"]][e["question"]][e["answer"]] = p with tf.io.gfile.GFile(out_file.format(extension="jsonl"), "w") as f: for pidx, passage in rows.items(): qs = [{ "idx": i, "answers": [{ "idx": j, "label": q[j] } for j in q] } for i, q in passage.items()] f.write( json.dumps({ "idx": pidx, "passage": { "questions": qs } }) + os.linesep) return if len(predictions) != len(examples): raise ValueError( "Number of predictions in {} ({}) != of examples in {} split of {} " "({}).".format( FLAGS.predictions_file, len(predictions), FLAGS.split, FLAGS.task, len(examples), 
)) if "record" in FLAGS.task: indices = [ex["idx/query"] for ex in examples] else: indices = [ex.get("idx", None) for ex in examples] if FLAGS.super: lines = [ json.dumps({ "idx": int(i), "label": p }) + os.linesep for i, p in zip(indices, predictions) ] with tf.io.gfile.GFile(out_file.format(extension="jsonl"), "w") as f: for line in lines: f.write(line) else: with tf.io.gfile.GFile(out_file.format(extension="tsv"), "w") as out_file: tsv_writer = csv.writer(out_file, delimiter="\t") tsv_writer.writerow(["index", "prediction"]) tsv_writer.writerows([i, p] for i, p in zip(indices, predictions))
from t5.data import preprocessors
from t5.data.utils import DEFAULT_SPM_PATH
from t5.data.utils import TaskRegistry
from t5.data.utils import TfdsTask
from t5.evaluation import metrics

# ==================================== C4 ======================================
# One unsupervised pretraining task per C4 variant (clean default, noclean,
# realnewslike, webtextlike), using the older SentencePiece-path API.
_c4_config_suffixes = ["", ".noclean", ".realnewslike", ".webtextlike"]
for config_suffix in _c4_config_suffixes:
  TaskRegistry.add(
      "c4{name}_v020_unsupervised".format(
          name=config_suffix.replace(".", "_")),
      TfdsTask,
      tfds_name="c4/en{config}:1.0.0".format(config=config_suffix),
      # Drop "inputs" (mapped to None) and pass the raw "text" field through
      # as "targets"; the unsupervised token preprocessor then derives
      # inputs/targets pairs from it.
      text_preprocessor=functools.partial(preprocessors.rekey,
                                          key_map={
                                              "inputs": None,
                                              "targets": "text"
                                          }),
      token_preprocessor=preprocessors.unsupervised,
      sentencepiece_model_path=DEFAULT_SPM_PATH,
      metric_fns=[])

# ================================ Wikipedia ===================================
# Unsupervised pretraining on the 2019-03-01 English Wikipedia dump.
# (This registration is truncated in this chunk of the file.)
TaskRegistry.add(
    "wikipedia_20190301.en_v003_unsupervised",
    TfdsTask,
    # 0.0.4 is identical to 0.0.3 except empty records removed.
    tfds_name="wikipedia/20190301.en:0.0.4",
    text_preprocessor=functools.partial(preprocessors.rekey,
                                        key_map={