def register_seqio_task(
    bigbench_task_name: str,
    bigbench_task_path: str,
    bigbench_task_type: bb.BigBenchTaskType,
    vocab: SeqIOVocabulary,
    num_shots: int,
    bigbench_subtask_name: Optional[str] = None,
    max_examples: Optional[int] = None,
    strip_inputs: bool = True,
    strip_targets: bool = True,
    add_inputs_eos: bool = False,
    add_targets_eos: bool = False,
    json_util: json_utils.JsonUtils = json_utils.get_default_json_utils(),
    min_validation_examples: int = _GLOBAL_MIN_VALIDATION_EXAMPLES,
    additional_metrics: Optional[Sequence[seqio.MetricFnCallable]] = None,
) -> str:
  """Registers a BIG-bench SeqIO Task and returns the Task name."""
  seqio_task_name = get_seqio_name(bigbench_task_name, bigbench_task_type,
                                   vocab, num_shots, bigbench_subtask_name,
                                   max_examples)
  if seqio_task_name in seqio.TaskRegistry.names():
    return seqio_task_name

  additional_metrics = additional_metrics or []

  seqio.TaskRegistry.add(
      seqio_task_name,
      source=seqio.FunctionDataSource(
          bb.get_dataset_fn(
              task_name=bigbench_task_name,
              task_path=bigbench_task_path,
              subtask_name=bigbench_subtask_name,
              num_shots=num_shots,
              bigbench_task_type=bigbench_task_type,
              max_examples=max_examples,
              json_util=json_util,
              min_validation_examples=min_validation_examples,
          ),
          splits=["all", "train", "validation"]),
      preprocessors=bb.get_preprocessors(
          strip_inputs=strip_inputs, strip_targets=strip_targets),
      output_features=bb.get_output_features(
          vocab=vocab.vocabulary,
          add_inputs_eos=add_inputs_eos,
          add_targets_eos=add_targets_eos),
      postprocess_fn=bb.get_postprocess_fn(
          task_name=bigbench_task_name,
          task_path=bigbench_task_path,
          subtask_name=bigbench_subtask_name,
          bigbench_task_type=bigbench_task_type,
          json_util=json_util),
      metric_fns=[
          bb.get_metric_fn(
              task_name=bigbench_task_name,
              task_path=bigbench_task_path,
              subtask_name=bigbench_subtask_name,
              bigbench_task_type=bigbench_task_type,
              json_util=json_util)
      ] + additional_metrics)

  return seqio_task_name
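# Usage sketch for register_seqio_task (assumptions: `bb` is the BIG-bench
# SeqIO bridge module in scope above, and the task name, path, TASK_TYPE, and
# VOCAB placeholders below are hypothetical, not confirmed by this snippet).
# Registration is idempotent: calling again with the same arguments returns
# the name of the already-registered task instead of re-adding it.
#
# task_name = register_seqio_task(
#     bigbench_task_name="my_bigbench_task",         # hypothetical
#     bigbench_task_path="/path/to/bigbench/tasks",  # hypothetical
#     bigbench_task_type=TASK_TYPE,  # a bb.BigBenchTaskType member
#     vocab=VOCAB,                   # a SeqIOVocabulary instance
#     num_shots=2,
# )
# ds = seqio.get_mixture_or_task(task_name).get_dataset(
#     sequence_length={"inputs": 512, "targets": 64}, split="validation")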
def add_task(dataset_name,
             subset_name,
             template_name,
             task_name=None,
             split_mapping=None):
  template = all_templates.get_dataset(dataset_name,
                                       subset_name)[template_name]
  task_name = task_name or utils.get_task_name(dataset_name, subset_name,
                                               template_name)

  if task_name in CLEAN_EVAL_TASKS:
    metrics = EVAL_METRICS[task_name]
  else:
    metrics = [t5.evaluation.metrics.sequence_accuracy]

  dataset_splits = utils.get_dataset_splits(dataset_name, subset_name)
  split_mapping = split_mapping or {k: k for k in dataset_splits.keys()}

  dataset_fn = functools.partial(
      get_tf_dataset,
      seed=None,
      dataset_name=dataset_name,
      subset_name=subset_name,
      template=template,
      split_mapping=split_mapping,
  )
  data_source = seqio.FunctionDataSource(
      dataset_fn,
      splits=list(split_mapping.keys()),
      num_input_examples={
          s: dataset_splits[split_mapping[s]].num_examples
          for s in split_mapping.keys()
      },
  )
  output_features = {
      "inputs":
          seqio.Feature(
              t5.data.get_default_vocabulary(), add_eos=False, dtype=tf.int32),
      "targets":
          seqio.Feature(
              t5.data.get_default_vocabulary(), add_eos=True, dtype=tf.int32),
  }
  preprocessors = [
      seqio.preprocessors.tokenize,
      seqio.preprocessors.append_eos,
      seqio.CacheDatasetPlaceholder(required=False),
  ]

  # Add train and normal eval tasks
  seqio.TaskRegistry.add(
      task_name,
      data_source,
      preprocessors=preprocessors,
      output_features=output_features,
      metric_fns=metrics,
      postprocess_fn=maybe_get_class_id_postprocessor(template),
  )

  # Add rank classification eval task
  labels = get_label_strings(template)
  if labels:
    rank_classification_preprocessor = functools.partial(
        t5.data.preprocessors.rank_classification,
        inputs_fn=lambda ex: tf.fill((len(labels),), ex["inputs"]),
        targets_fn=lambda ex: labels,
        is_correct_fn=lambda ex: tf.equal(labels,
                                          tf.strings.strip(ex["targets"])),
        weight_fn=lambda ex: 1.0,
    )
    seqio.TaskRegistry.add(
        task_name + "_score_eval",
        data_source,
        preprocessors=[rank_classification_preprocessor] + preprocessors,
        output_features=output_features,
        metric_fns=[
            functools.partial(
                t5.evaluation.metrics.rank_classification,
                num_classes=len(labels))
        ],
        postprocess_fn=t5.data.postprocessors.rank_classification,
    )
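# Usage sketch for add_task (assumptions: the promptsource `all_templates`
# collection and the helpers above are in scope; the dataset/subset/template
# names below are hypothetical). When the template exposes fixed answer
# choices, a second "<task_name>_score_eval" task is also registered that
# scores every label string against each example via rank classification.
#
# add_task("super_glue", "rte", "some_template")  # hypothetical names
# score_task = utils.get_task_name(
#     "super_glue", "rte", "some_template") + "_score_eval"
# ds = seqio.get_mixture_or_task(score_task).get_dataset(
#     sequence_length={"inputs": 1024, "targets": 256}, split="validation")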
      source=seqio.TfdsDataSource(
          tfds_name="xtreme_pawsx/{}:1.0.0".format(lang), splits=["train"]),
      preprocessors=[
          text_preprocessor,
          seqio.preprocessors.tokenize,
          seqio.CacheDatasetPlaceholder(),
          seqio.preprocessors.append_eos_after_trim,
      ],
      output_features=DEFAULT_OUTPUT_FEATURES,
      postprocess_fn=postprocess_fn,
      metric_fns=[metrics.accuracy])

seqio.TaskRegistry.add(
    "mt5_pawsx_dev_test.all_langs",
    source=seqio.FunctionDataSource(
        dataset_fn=utils.pawsx_all_langs_dataset_fn,
        splits=["validation", "test"]),
    preprocessors=[
        text_preprocessor,
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features=DEFAULT_OUTPUT_FEATURES,
    postprocess_fn=postprocess_fn,
    metric_fns=[metrics.accuracy])

# PAWSX Zero-Shot
pawsx_eval = [
    "mt5_pawsx_dev_test.{}".format(lang) for lang in utils.PAWSX_LANGS
] + ["mt5_pawsx_dev_test.all_langs"]
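# A plausible follow-up (assumption: `pawsx_eval` feeds a zero-shot evaluation
# mixture; the mixture name and rate here are illustrative, not part of the
# snippet above).
#
# seqio.MixtureRegistry.add("mt5_pawsx_zeroshot", pawsx_eval, default_rate=1.0)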
def __init__(self,
             name,
             dataset_fn,
             splits,
             text_preprocessor,
             metric_fns=None,
             postprocess_fn=None,
             token_preprocessor=None,
             output_features=None,
             num_input_examples=None,
             supports_caching=True,
             shuffle_buffer_size=SHUFFLE_BUFFER_SIZE,
             source=None):
  if (dataset_fn, source).count(None) != 1:
    raise ValueError(
        "Exactly one of either `dataset_fn` or `source` must be provided.")

  if source and (splits or num_input_examples):
    raise ValueError(
        "If `source` is provided, `splits` and `num_input_examples` should "
        "not also be provided to the Task.")
  source = source or seqio.FunctionDataSource(
      dataset_fn=dataset_fn,
      splits=splits,
      num_input_examples=num_input_examples)

  # Normalize single preprocessors into lists.
  if text_preprocessor and not hasattr(text_preprocessor, "__iter__"):
    text_preprocessor = [text_preprocessor]
  if token_preprocessor and not hasattr(token_preprocessor, "__iter__"):
    token_preprocessor = [token_preprocessor]

  # Assemble the pipeline: text preprocessing -> tokenization -> (optional
  # cache point) -> token preprocessing -> EOS appending.
  preprocessors = list(text_preprocessor or [])
  preprocessors.append(seqio.preprocessors.tokenize)
  if supports_caching:
    preprocessors.append(seqio.CacheDatasetPlaceholder())
  preprocessors.extend(token_preprocessor or [])
  preprocessors.append(seqio.preprocessors.append_eos_after_trim)

  if hasattr(output_features, "__len__") and not output_features:
    raise ValueError("output_features must be non-empty.")

  # Normalize `output_features` into a dict mapping feature keys to Features.
  if output_features is None:
    output_features = seqio.Feature(utils.get_default_vocabulary())
  if isinstance(output_features, dict):
    pass
  elif isinstance(output_features, seqio.Feature):
    output_features = {k: output_features for k in _DEFAULT_FEATURE_KEYS}
  elif isinstance(output_features, list) and all(
      isinstance(f, str) for f in output_features):
    output_features = {
        k: seqio.Feature(utils.get_default_vocabulary())
        for k in output_features
    }
  else:
    raise ValueError(
        "output_features must be a dict, Feature, list of str, or None")

  # A sequence of postprocess functions is composed into a single callable.
  if hasattr(postprocess_fn, "__iter__"):
    postprocess_fns = postprocess_fn

    def postprocess_fn(x, **postprocess_kwargs):  # pylint:disable=function-redefined
      for post_fn in postprocess_fns:
        x = post_fn(x, **postprocess_kwargs)
      return x

  super().__init__(
      name=name,
      source=source,
      output_features=output_features,
      preprocessors=preprocessors,
      postprocess_fn=postprocess_fn,
      metric_fns=metric_fns,
      shuffle_buffer_size=shuffle_buffer_size)
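# Usage sketch for this constructor (assumptions: the surrounding class is a
# t5-style shim over seqio.Task, here called `Task`, and `my_dataset_fn` is
# hypothetical). `dataset_fn` should accept (split, shuffle_files) and yield
# string "inputs"/"targets" features; passing a list of feature keys expands
# into default-vocabulary Features, per the normalization logic above.
#
# def my_dataset_fn(split, shuffle_files=False):
#   ...  # return a tf.data.Dataset of {"inputs": tf.string, "targets": tf.string}
#
# task = Task(
#     name="my_task",  # hypothetical
#     dataset_fn=my_dataset_fn,
#     splits=["train", "validation"],
#     text_preprocessor=[],
#     metric_fns=[],
#     output_features=["inputs", "targets"],
# )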