def get_preprocessors(strip_inputs: bool = True,
                      strip_targets: bool = True) -> Sequence[Preprocessor]:
    """Builds the ordered preprocessor list used for BIG-bench tasks.

    Args:
        strip_inputs: when truthy, "inputs" is added to the keys handed to the
            strip preprocessor.
        strip_targets: when truthy, "targets" and "answers" are added to the
            keys handed to the strip preprocessor.

    Returns:
        A list containing, in order: the strip preprocessor (only when at
        least one key was selected), `seqio.preprocessors.tokenize`, a
        `seqio.CacheDatasetPlaceholder`, and `seqio.preprocessors.append_eos`.
    """
    keys_to_strip = []
    if strip_inputs:
        keys_to_strip.append("inputs")
    if strip_targets:
        keys_to_strip += ["targets", "answers"]

    pipeline = []
    # The key list is non-empty exactly when one of the two flags was set.
    if keys_to_strip:
        pipeline.append(get_strip_preprocessor(keys_to_strip=keys_to_strip))
    pipeline += [
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos,
    ]
    return pipeline
def add_task(dataset_name, subset_name, template_name, task_name=None,
             split_mapping=None):
    """Registers seqio tasks for one template of a dataset/subset.

    Always registers a task called `task_name`. When `get_label_strings`
    returns a non-empty value for the template, a second task named
    `task_name + "_score_eval"` is registered that runs rank classification
    over those label strings.

    Args:
        dataset_name: dataset identifier passed to the template and split
            lookups.
        subset_name: subset identifier passed to the same lookups.
        template_name: key of the template inside the dataset's templates.
        task_name: optional task name; when falsy it is derived with
            `utils.get_task_name`.
        split_mapping: optional mapping from registered split name to the
            underlying dataset split name; when falsy every dataset split maps
            to itself.
    """
    template = all_templates.get_dataset(dataset_name, subset_name)[template_name]
    if not task_name:
        task_name = utils.get_task_name(dataset_name, subset_name, template_name)

    metrics = (
        EVAL_METRICS[task_name]
        if task_name in CLEAN_EVAL_TASKS
        else [t5.evaluation.metrics.sequence_accuracy]
    )

    dataset_splits = utils.get_dataset_splits(dataset_name, subset_name)
    if not split_mapping:
        split_mapping = {split: split for split in dataset_splits.keys()}

    dataset_fn = functools.partial(
        get_tf_dataset,
        seed=None,
        dataset_name=dataset_name,
        subset_name=subset_name,
        template=template,
        split_mapping=split_mapping,
    )
    example_counts = {
        alias: dataset_splits[origin].num_examples
        for alias, origin in split_mapping.items()
    }
    data_source = seqio.FunctionDataSource(
        dataset_fn,
        splits=list(split_mapping.keys()),
        num_input_examples=example_counts,
    )

    # Only the targets feature gets an EOS token appended.
    output_features = {
        feature_key: seqio.Feature(
            t5.data.get_default_vocabulary(), add_eos=wants_eos, dtype=tf.int32)
        for feature_key, wants_eos in (("inputs", False), ("targets", True))
    }
    preprocessors = [
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
        seqio.CacheDatasetPlaceholder(required=False),
    ]

    # Add train and normal eval tasks
    seqio.TaskRegistry.add(
        task_name,
        data_source,
        preprocessors=preprocessors,
        output_features=output_features,
        metric_fns=metrics,
        postprocess_fn=maybe_get_class_id_postprocessor(template),
    )

    # Add rank classification eval task
    labels = get_label_strings(template)
    if not labels:
        return

    def _inputs_per_label(ex):
        # One copy of the input for every candidate label.
        return tf.fill((len(labels),), ex["inputs"])

    def _candidate_targets(ex):
        return labels

    def _matches_target(ex):
        return tf.equal(labels, tf.strings.strip(ex["targets"]))

    def _unit_weight(ex):
        return 1.0

    rank_classification_preprocessor = functools.partial(
        t5.data.preprocessors.rank_classification,
        inputs_fn=_inputs_per_label,
        targets_fn=_candidate_targets,
        is_correct_fn=_matches_target,
        weight_fn=_unit_weight,
    )
    rank_metric = functools.partial(
        t5.evaluation.metrics.rank_classification, num_classes=len(labels))
    seqio.TaskRegistry.add(
        task_name + "_score_eval",
        data_source,
        preprocessors=[rank_classification_preprocessor, *preprocessors],
        output_features=output_features,
        metric_fns=[rank_metric],
        postprocess_fn=t5.data.postprocessors.rank_classification,
    )
def __init__(self,
             name,
             dataset_fn,
             splits,
             text_preprocessor,
             metric_fns=None,
             postprocess_fn=None,
             token_preprocessor=None,
             output_features=None,
             num_input_examples=None,
             supports_caching=True,
             shuffle_buffer_size=SHUFFLE_BUFFER_SIZE,
             source=None):
    """Translates the legacy-style arguments into the parent constructor call.

    Args:
        name: task name, forwarded unchanged.
        dataset_fn: dataset function wrapped in a `seqio.FunctionDataSource`
            when `source` is not given.
        splits: splits for the function data source; must be falsy when
            `source` is given.
        text_preprocessor: a single preprocessor or an iterable of them,
            placed before tokenization.
        metric_fns: forwarded unchanged.
        postprocess_fn: a single callable, or an iterable of callables that
            are applied one after another.
        token_preprocessor: a single preprocessor or an iterable of them,
            placed after tokenization (and after the cache placeholder, when
            present).
        output_features: a dict, a `seqio.Feature`, a list of str, or None.
        num_input_examples: forwarded to the function data source; must be
            falsy when `source` is given.
        supports_caching: when truthy, a `seqio.CacheDatasetPlaceholder` is
            inserted right after tokenization.
        shuffle_buffer_size: forwarded unchanged.
        source: a ready-made data source, used instead of `dataset_fn`.

    Raises:
        ValueError: if not exactly one of `dataset_fn`/`source` is None, if
            `source` is combined with `splits` or `num_input_examples`, or if
            `output_features` is empty or of an unsupported type.
    """
    if [dataset_fn, source].count(None) != 1:
        raise ValueError(
            "Exactly one of either `dataset_fn` or `source` must be provided.")

    if source and (splits or num_input_examples):
        raise ValueError(
            "If `source` is provided, `splits` and `num_input_examples` should "
            "not also be provided to the Task.")
    if not source:
        source = seqio.FunctionDataSource(
            dataset_fn=dataset_fn,
            splits=splits,
            num_input_examples=num_input_examples)

    # A lone preprocessor (anything without `__iter__`) is wrapped in a list.
    text_steps = text_preprocessor
    if text_steps and not hasattr(text_steps, "__iter__"):
        text_steps = [text_steps]
    token_steps = token_preprocessor
    if token_steps and not hasattr(token_steps, "__iter__"):
        token_steps = [token_steps]

    preprocessors = [*(text_steps or []), seqio.preprocessors.tokenize]
    if supports_caching:
        preprocessors.append(seqio.CacheDatasetPlaceholder())
    preprocessors += [
        *(token_steps or []),
        seqio.preprocessors.append_eos_after_trim,
    ]

    if hasattr(output_features, "__len__") and not output_features:
        raise ValueError("output_features must be non-empty.")
    if output_features is None:
        output_features = seqio.Feature(utils.get_default_vocabulary())

    # Dicts pass through as-is; other accepted forms are expanded to a dict.
    if not isinstance(output_features, dict):
        if isinstance(output_features, seqio.Feature):
            output_features = dict.fromkeys(
                _DEFAULT_FEATURE_KEYS, output_features)
        elif isinstance(output_features, list) and all(
                isinstance(key, str) for key in output_features):
            output_features = {
                key: seqio.Feature(utils.get_default_vocabulary())
                for key in output_features
            }
        else:
            raise ValueError(
                "output_features must be a dict, Feature, list of str, or None")

    if hasattr(postprocess_fn, "__iter__"):
        postprocess_chain = postprocess_fn

        def postprocess_fn(x, **postprocess_kwargs):  # pylint:disable=function-redefined
            # Feed the output of each function into the next one.
            for step in postprocess_chain:
                x = step(x, **postprocess_kwargs)
            return x

    super().__init__(
        name=name,
        source=source,
        output_features=output_features,
        preprocessors=preprocessors,
        postprocess_fn=postprocess_fn,
        metric_fns=metric_fns,
        shuffle_buffer_size=shuffle_buffer_size)
# Register one task per entry of MC4_LANGS, named "mc4.<lang>" with any "-"
# in the language code replaced by "_".
for lang in MC4_LANGS:
    seqio.TaskRegistry.add("mc4.{}".format(lang.replace("-", "_")),
                           source=seqio.TfdsDataSource(
                               tfds_name="c4/multilingual:3.0.1",
                               # The language code itself names the train
                               # split; validation uses "<lang>-validation".
                               splits={
                                   "train": lang,
                                   "validation": f"{lang}-validation"
                               }),
                           preprocessors=[
                               # "targets" is taken from the "text" field;
                               # "inputs" is mapped to None (presumably left
                               # empty by `rekey` -- confirm against its docs).
                               functools.partial(t5.data.preprocessors.rekey, key_map={
                                   "inputs": None,
                                   "targets": "text"
                               }),
                               seqio.preprocessors.tokenize,
                               seqio.CacheDatasetPlaceholder(),
                               # Span corruption is listed after the cache
                               # placeholder.
                               t5.data.preprocessors.span_corruption,
                               seqio.preprocessors.append_eos_after_trim,
                           ],
                           output_features=DEFAULT_OUTPUT_FEATURES,
                           metric_fns=[])

# Mixture "mc4" over every per-language task registered above.
mc4 = ["mc4.{}".format(lang.replace("-", "_")) for lang in MC4_LANGS]
seqio.MixtureRegistry.add("mc4", mc4, default_rate=DEFAULT_MIX_RATE)

# Wikipedia
# One "mt5_wiki.<lang>" task per entry of WIKI_LANGS, using the same "-" to
# "_" renaming as above.
# NOTE(review): the call below continues past the end of the visible chunk.
for lang in WIKI_LANGS:
    seqio.TaskRegistry.add(
        "mt5_wiki.{}".format(lang.replace("-", "_")),
        source=seqio.TfdsDataSource(
            tfds_name="wikipedia/20200301.{}:1.0.0".format(lang)),
def create_xnli_tasks_and_mixtures(task_prefix, task_suffix, output_features):
    """Registers the XNLI tasks and the two mixtures built from them.

    Tasks registered (names are wrapped in `task_prefix` / `task_suffix`):
      * `xnli_train` on the "multi_nli:1.1.0" train split.
      * `xnli_dev_test.<lang>` for every language in XNLI_LANGS, plus
        `xnli_dev_test.all_langs`, on the "xnli:1.1.0" validation and test
        splits.
      * `xnli_translate_train.<lang>` for every language except "en", on the
        "xtreme_xnli:1.1.0" train split.

    Mixtures registered: `xnli_zeroshot` (train + all dev/test tasks) and
    `xnli_translate_train` (the zeroshot tasks plus the translate-train tasks).

    Args:
        task_prefix: string placed in front of every task and mixture name.
        task_suffix: string appended after the base name; when non-empty it is
            joined with a leading underscore.
        output_features: output features handed to every registered task.
    """
    if task_suffix:
        task_suffix = "_" + task_suffix

    def _register(task_name, tfds_name, splits, first_preprocessor):
        # Every XNLI task shares this pipeline apart from its first step.
        seqio.TaskRegistry.add(
            task_name,
            source=seqio.TfdsDataSource(tfds_name=tfds_name, splits=splits),
            preprocessors=[
                first_preprocessor,
                seqio.preprocessors.tokenize,
                seqio.CacheDatasetPlaceholder(),
                seqio.preprocessors.append_eos_after_trim,
            ],
            output_features=output_features,
            metric_fns=[metrics.accuracy])

    def _xnli_preprocessor(languages):
        return functools.partial(
            preprocessors.process_xnli, target_languages=languages)

    _register(
        f"{task_prefix}xnli_train{task_suffix}",
        "multi_nli:1.1.0",
        ["train"],
        preprocessors.process_mnli)

    for xnli_lang in XNLI_LANGS:
        _register(
            f"{task_prefix}xnli_dev_test{task_suffix}.{xnli_lang}",
            "xnli:1.1.0",
            ["validation", "test"],
            _xnli_preprocessor([xnli_lang]))
        # No translate-train task is registered for English.
        if xnli_lang != "en":
            _register(
                f"{task_prefix}xnli_translate_train{task_suffix}.{xnli_lang}",
                "xtreme_xnli:1.1.0",
                ["train"],
                _xnli_preprocessor([xnli_lang]))

    _register(
        f"{task_prefix}xnli_dev_test{task_suffix}.all_langs",
        "xnli:1.1.0",
        ["validation", "test"],
        _xnli_preprocessor(XNLI_LANGS))

    xnli_zeroshot = [
        f"{task_prefix}xnli_train{task_suffix}",
        f"{task_prefix}xnli_dev_test{task_suffix}.all_langs",
    ]
    xnli_zeroshot.extend(
        f"{task_prefix}xnli_dev_test{task_suffix}.{lang}" for lang in XNLI_LANGS)
    seqio.MixtureRegistry.add(
        f"{task_prefix}xnli_zeroshot{task_suffix}",
        xnli_zeroshot,
        default_rate=1.0)

    xnli_translate_train = [
        *xnli_zeroshot,
        *(f"{task_prefix}xnli_translate_train{task_suffix}.{lang}"
          for lang in XNLI_LANGS
          if lang != "en"),
    ]
    seqio.MixtureRegistry.add(
        f"{task_prefix}xnli_translate_train{task_suffix}",
        xnli_translate_train,
        default_rate=1.0)