def run_eval(
    mixture_or_task_name: str,
    predict_or_score_fn: PredictOrScoreFnCallable,
    checkpoint_steps: Iterable[int],
    dataset_fn: Optional[Callable[
        [t5.data.Task, Mapping[str, int], int, str, Optional[bool]],
        tf.data.Dataset]] = None,
    summary_dir: Optional[str] = None,
    split: Optional[str] = "validation",
    sequence_length: Optional[Mapping[str, int]] = None,
    batch_size: Optional[int] = None):
  """Run evaluation on the given mixture or task.

  Args:
    mixture_or_task_name: str, the name of the Mixture or Task to evaluate
      on. Must be pre-registered in the global `TaskRegistry` or
      `MixtureRegistry`.
    predict_or_score_fn: function that takes in the sequence length,
      checkpoint step, tasks to evaluate, an eval_dataset_fn, a dict mapping
      task names to cached examples, and a dict mapping task names to
      datasets, and returns a list of outputs or a list of scores.
    checkpoint_steps: an iterable of integer checkpoint steps to evaluate on.
    dataset_fn: function that takes a task and returns the dataset associated
      with it. If None, the default mesh_eval_dataset_fn is used.
    summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/eval_{split}.
    split: str, the mixture/task split to evaluate on.
    sequence_length: an integer or a dict from feature-key to integer, giving
      the sequence length to pad or truncate to,
      e.g. {"inputs": 512, "targets": 128}. If None, sequence length is
      automatically computed during eval.
    batch_size: integer, used only to check that expected padding matches the
      targets. If None, the check is skipped.
  """
  vocabulary = model_utils.get_vocabulary(mixture_or_task_name)

  tasks = t5.data.get_subtasks(
      t5.data.get_mixture_or_task(mixture_or_task_name))
  tasks = seqio.evaluation.get_valid_eval_tasks(tasks, split)

  if not tasks:
    logging.info(
        "All provided tasks have metric_fns=[] or no matching splits; "
        "eval is not possible.")
    return

  if not dataset_fn:
    def _get_task_eval_dataset(task, sequence_length, split):
      # TODO(sharannarang): Replace with more general function.
      eval_datasets = mesh_transformer.mesh_eval_dataset_fn(
          sequence_length=sequence_length,
          dataset_split=split,
          mixture_or_task_name=task.name,
      )
      return eval_datasets[0].dataset_fn()

    dataset_fn = _get_task_eval_dataset

  summary_writer = None

  cached_targets, cached_datasets, max_sequence_length = \
      seqio.evaluation.get_targets_and_examples(
          tasks=tasks,
          dataset_fn=functools.partial(
              dataset_fn, split=split, sequence_length=None))

  if summary_dir:
    model_utils.write_targets_and_examples(
        summary_dir, cached_targets, cached_datasets)

  if sequence_length is None:
    logging.info("Setting sequence lengths to %s", max_sequence_length)
    sequence_length = max_sequence_length
  elif (sequence_length["inputs"] < max_sequence_length["inputs"] or
        sequence_length["targets"] < max_sequence_length["targets"]):
    logging.warning(
        "Given sequence lengths are insufficient for some evaluation inputs "
        "or targets. These sequences will be truncated to fit, likely "
        "leading to sub-optimal results. Consider passing `None` for "
        "sequence_length to have them be automatically computed.\n Got: %s, "
        "\n Max Lengths: %s", sequence_length, max_sequence_length)
  elif (sequence_length["inputs"] > max_sequence_length["inputs"] or
        sequence_length["targets"] > max_sequence_length["targets"]):
    logging.warning(
        "Given sequence lengths are longer than necessary for some "
        "evaluation inputs or targets, resulting in wasted computation. "
        "Consider passing `None` for sequence_length to have them be "
        "automatically computed.\n Got: %s,\n Max Lengths: %s",
        sequence_length, max_sequence_length)

  for step in checkpoint_steps:
    logging.info("Evaluating checkpoint step: %d", step)
    outputs = predict_or_score_fn(
        checkpoint_step=step,
        vocabulary=vocabulary,
        tasks=tasks,
        datasets=cached_datasets,
        sequence_length=sequence_length)

    for task in tasks:
      # Extract the portion of decodes corresponding to this dataset.
      dataset = cached_datasets[task.name]
      dataset_size = len(cached_targets[task.name])
      predictions = [
          task.postprocess_fn(d, example=ex)
          for d, ex in zip(outputs[:dataset_size], tfds.as_numpy(dataset))
      ]

      # Remove the used decodes.
      del outputs[:dataset_size]

      if summary_dir:
        predictions_filename = os.path.join(
            summary_dir, "{}_{}_predictions".format(task.name, step))
        model_utils.write_lines_to_file(predictions, predictions_filename)

      with tf.Graph().as_default():
        if summary_dir:
          summary_writer = summary_writer or tf.summary.FileWriter(summary_dir)

        for metric_fn in task.metric_fns:
          if summary_dir:
            summary = tf.Summary()
          targets = cached_targets[task.name]
          metric_result = metric_fn(targets, predictions)
          for metric_name, metric_value in metric_result.items():
            tag = "eval/{}/{}".format(task.name, metric_name)
            logging.info("%s at step %d: %.3f", tag, step, metric_value)
            if summary_dir:
              summary.value.add(tag=tag, simple_value=metric_value)
              summary_writer.add_summary(summary, step)  # pytype: disable=attribute-error

        if summary_dir:
          summary_writer.flush()  # pytype: disable=attribute-error

    # Only padding should remain.
    if batch_size:
      expected_pad = -sum(len(t) for t in cached_targets.values()) % batch_size
      if outputs and len(outputs) != expected_pad:
        raise ValueError("{} padded outputs, {} expected.".format(
            len(outputs), expected_pad))
  def predict(
      self,
      inputs,
      sequence_length,
      batch_size,
      output_file=None,
      vocabulary=None,
      **generate_kwargs,
  ):
    """Predict targets for the given inputs.

    Args:
      inputs: list of str or str, either a list of inputs to feed into the
        model or the path to a text file that contains a single input on each
        line.
      sequence_length: dict of int, a dict mapping feature name to length.
      batch_size: int, the number of padded sequences in each batch.
      output_file: str or None, path to write out predictions or None to skip
        writing.
      vocabulary: t5.data.vocabularies.Vocabulary or dict or None. Either the
        Vocabulary to use for processing inputs and targets, a dict mapping
        "inputs" to a Vocabulary for encoding the inputs and "targets" for
        decoding the predictions, or None (default) to use the default
        SentencePiece vocabulary (as was used in all pre-trained T5 models).
      **generate_kwargs: Additional keyword arguments to pass to
        `transformers.PreTrainedModel.generate()`, for example to change the
        decoding strategy. See the documentation for
        `transformers.PreTrainedModel.generate()` for options.
    """
    if isinstance(inputs, str):
      if not tf.io.gfile.exists(inputs):
        raise ValueError(
            f"A str was provided for `inputs`, but the path {inputs} does not "
            f"exist. If you want the model's output for {inputs}, you should "
            f"feed in inputs=['{inputs}']"
        )
      with tf.io.gfile.GFile(inputs) as f:
        inputs = [l.strip() for l in f]

    if vocabulary is None:
      vocab = t5.data.get_default_vocabulary()
      vocabs = {"inputs": vocab, "targets": vocab}
    elif isinstance(vocabulary, t5.data.vocabularies.Vocabulary):
      vocabs = {"inputs": vocabulary, "targets": vocabulary}
    elif isinstance(vocabulary, dict):
      vocabs = vocabulary
    else:
      raise ValueError("vocabulary must be a dict, a Vocabulary, or None")

    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    dataset = dataset.map(
        lambda x: {"inputs": tf.cast(vocabs["inputs"].encode_tf(x), tf.int64)},
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    dataset = tokens_to_batches(
        dataset, sequence_length, batch_size, ["inputs"]
    )

    predictions = []
    for batch in dataset:
      predicted_tokens = self._model.generate(
          input_ids=self.to_tensor(batch["inputs"]), **generate_kwargs
      )
      predicted_tokens = predicted_tokens.cpu().numpy().tolist()
      predictions.extend(
          [vocabs["targets"].decode(p) for p in predicted_tokens]
      )

    for inp, pred in zip(inputs, predictions):
      logging.info("%s\n  -> %s", inp, pred)

    if output_file is not None:
      utils.write_lines_to_file(predictions, output_file)
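# Example usage (a minimal sketch): assumes this method lives on a
# `HfPyTorchModel` constructed elsewhere in this module from a Hugging Face
# model spec and a model directory; the model spec, paths, and decoding
# options below are hypothetical.
#
#   model = HfPyTorchModel("t5-small", "/tmp/hft5/", torch.device("cpu"))
#   model.predict(
#       inputs=["translate English to German: That is good."],
#       sequence_length={"inputs": 64},
#       batch_size=1,
#       output_file="/tmp/hft5/predictions.txt",
#       max_length=32)  # forwarded to transformers generate()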