def run_eval(mixture_or_task_name: str,
             predict_or_score_fn: PredictOrScoreFnCallable,
             checkpoint_steps: Iterable[int],
             dataset_fn: Optional[Callable[
                 [t5.data.Task, Mapping[str, int], int, str, Optional[bool]],
                 tf.data.Dataset]] = None,
             summary_dir: Optional[str] = None,
             split: Optional[str] = "validation",
             sequence_length: Optional[Mapping[str, int]] = None,
             batch_size: Optional[int] = None):
    """Run evaluation on the given mixture or task.

    Args:
      mixture_or_task_name: str, the name of the Mixture or Task to evaluate
        on. Must be pre-registered in the global `TaskRegistry` or
        `MixtureRegistry`.
      predict_or_score_fn: function called once per checkpoint step with the
        checkpoint step, the vocabulary, the tasks to evaluate, a dict mapping
        task names to cached datasets, and the sequence length; returns a list
        of outputs or a list of scores.
      checkpoint_steps: an iterable of integer checkpoint steps to evaluate on.
      dataset_fn: function that takes a task and returns the dataset associated
        with it. If None, the default mesh_eval_dataset_fn is used.
      summary_dir: str, path to write TensorBoard events file summaries for
        eval. If None, use model_dir/eval_{split}.
      split: str, the mixture/task split to evaluate on.
      sequence_length: an integer or a dict mapping feature key to integer,
        the sequence length(s) to pad or truncate to,
        e.g. {"inputs": 512, "targets": 128}.
        If None, sequence lengths are computed automatically during eval.
      batch_size: integer, used only to verify that the leftover outputs are
        exactly the expected batch padding. If None, the check is skipped.
    """

    vocabulary = model_utils.get_vocabulary(mixture_or_task_name)

    tasks = t5.data.get_subtasks(
        t5.data.get_mixture_or_task(mixture_or_task_name))
    tasks = seqio.evaluation.get_valid_eval_tasks(tasks, split)

    if not tasks:
        logging.info(
            "All provided tasks have metric_fns=[] or no matching splits; "
            "eval is not possible.")
        return

    if not dataset_fn:

        def _get_task_eval_dataset(task, sequence_length, split):
            # TODO(sharannarang): Replace with more general function.
            eval_datasets = mesh_transformer.mesh_eval_dataset_fn(
                sequence_length=sequence_length,
                dataset_split=split,
                mixture_or_task_name=task.name,
            )

            return eval_datasets[0].dataset_fn()

        dataset_fn = _get_task_eval_dataset

    summary_writer = None

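    # Build the cached targets and eval datasets for all tasks once up front;
    # passing sequence_length=None lets seqio also report the maximum sequence
    # lengths observed in the eval data.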
    cached_targets, cached_datasets, max_sequence_length = (
        seqio.evaluation.get_targets_and_examples(
            tasks=tasks,
            dataset_fn=functools.partial(
                dataset_fn, split=split, sequence_length=None)))

    if summary_dir:
        model_utils.write_targets_and_examples(summary_dir, cached_targets,
                                               cached_datasets)

    if sequence_length is None:
        logging.info("Setting sequence lengths to %s", max_sequence_length)
        sequence_length = max_sequence_length
    elif (sequence_length["inputs"] < max_sequence_length["inputs"]
          or sequence_length["targets"] < max_sequence_length["targets"]):
        logging.warning(
            "Given sequence lengths are insufficient for some evaluation inputs "
            "or targets. These sequences will be truncated to fit, likely "
            "leading to sub-optimal results. Consider passing `None` for "
            "sequence_length to have them be automatically computed.\n Got: %s, "
            "\n Max Lengths:%s", sequence_length, max_sequence_length)
    elif (sequence_length["inputs"] > max_sequence_length["inputs"]
          or sequence_length["targets"] > max_sequence_length["targets"]):
        logging.warning(
            "Given sequence lengths are longer than necessary for some "
            "evaluation inputs or targets, resulting in wasted computation. "
            "Consider passing `None` for sequence_length to have them be "
            "automatically computed.\n Got: %s,\n Max Lengths: %s",
            sequence_length, max_sequence_length)

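    # Evaluate every requested checkpoint. `outputs` holds one prediction (or
    # score) per example for all tasks, concatenated in task order, with any
    # batch padding at the end.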
    for step in checkpoint_steps:
        logging.info("Evaluating checkpoint step: %d", step)
        outputs = predict_or_score_fn(checkpoint_step=step,
                                      vocabulary=vocabulary,
                                      tasks=tasks,
                                      datasets=cached_datasets,
                                      sequence_length=sequence_length)

        for task in tasks:
            # Extract the portion of decodes corresponding to this dataset
            dataset = cached_datasets[task.name]
            dataset_size = len(cached_targets[task.name])
            predictions = [
                task.postprocess_fn(d, example=ex) for d, ex in zip(
                    outputs[:dataset_size], tfds.as_numpy(dataset))
            ]

            # Remove the used decodes.
            del outputs[:dataset_size]

            if summary_dir:
                predictions_filename = os.path.join(
                    summary_dir, "{}_{}_predictions".format(task.name, step))
                model_utils.write_lines_to_file(predictions,
                                                predictions_filename)

            with tf.Graph().as_default():
                if summary_dir:
                    summary_writer = summary_writer or tf.summary.FileWriter(
                        summary_dir)

                for metric_fn in task.metric_fns:
                    if summary_dir:
                        summary = tf.Summary()
                    targets = cached_targets[task.name]
                    metric_result = metric_fn(targets, predictions)
                    for metric_name, metric_value in metric_result.items():
                        tag = "eval/{}/{}".format(task.name, metric_name)
                        logging.info("%s at step %d: %.3f", tag, step,
                                     metric_value)
                        if summary_dir:
                            summary.value.add(tag=tag,
                                              simple_value=metric_value)
                            summary_writer.add_summary(summary, step)  # pytype: disable=attribute-error
                if summary_dir:
                    summary_writer.flush()  # pytype: disable=attribute-error

        # Only padding should remain.
        if batch_size:
            expected_pad = -sum(len(t)
                                for t in cached_targets.values()) % batch_size
            if outputs and len(outputs) != expected_pad:
                raise ValueError("{} padded outputs, {} expected.".format(
                    len(outputs), expected_pad))
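
# Minimal usage sketch, assuming "my_task" is already registered in the
# TaskRegistry; `my_predict_fn` is a hypothetical callable matching the
# predict_or_score_fn contract documented above.
def my_predict_fn(checkpoint_step, vocabulary, tasks, datasets,
                  sequence_length):
    # Stub: should return one decoded prediction per example, concatenated
    # across `tasks` in order, followed by any batch padding.
    ...

run_eval(
    mixture_or_task_name="my_task",
    predict_or_score_fn=my_predict_fn,
    checkpoint_steps=[10000, 20000],
    summary_dir="/tmp/eval_summaries",
    split="validation",
    sequence_length={"inputs": 512, "targets": 128},
    batch_size=32,
)
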
Example #2
  def predict(
      self,
      inputs,
      sequence_length,
      batch_size,
      output_file=None,
      vocabulary=None,
      **generate_kwargs,
  ):
    """Evaluate the model on the given Mixture or Task.

    *Note*: If a checkpoint step is provided (i.e. `checkpoint_steps is not
    None`), the model's state will be replaced by the state in those
    checkpoints. If you have not saved your model before calling `eval`, you
    should call `save_checkpoint` before `eval` to avoid losing its parameter
    values and state.

    Args:
      inputs: list of str or str, either a list of inputs to feed into the
        model or the path to a text file that contains a single input on each
        line.
      sequence_length: dict of int, a dict mapping feature name to length.
      batch_size: int, the number of padded sequences in each batch.
      output_file: str or None, path to write out predictions or None to skip
        writing.
      vocabulary: t5.data.vocabularies.Vocabulary or dict or None. Either the
        Vocabulary to use for processing both inputs and targets, a dict
        mapping "inputs" to a Vocabulary for encoding the inputs and "targets"
        to a Vocabulary for decoding the predictions, or None (default) to use
        `t5.data.get_default_vocabulary()` (the SentencePiece vocabulary used
        by all pre-trained T5 models).
      **generate_kwargs: Additional keyword arguments to pass to
        `transformers.PretrainedModel.generate()`, for example to change the
        decoding strategy. See the documentation for
        `transformers.PretrainedModel.generate()` for options.
    """
    if isinstance(inputs, str):
      if not tf.io.gfile.exists(inputs):
        raise ValueError(
            f"A str was provided for `inputs`, but the path {inputs} does not "
            f"exist. If you want the model's output for {inputs}, you should "
            f"feed in inputs=['{inputs}']"
        )
      with tf.io.gfile.GFile(inputs) as f:
        inputs = [l.strip() for l in f]

    if vocabulary is None:
      vocab = t5.data.get_default_vocabulary()
      vocabs = {"inputs": vocab, "targets": vocab}
    elif isinstance(vocabulary, t5.data.vocabularies.Vocabulary):
      vocabs = {"inputs": vocabulary, "targets": vocabulary}
    elif isinstance(vocabulary, dict):
      vocabs = vocabulary
    else:
      raise ValueError("vocabulary must be a dict, a Vocabulary, or None")

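    # Tokenize each input string with the inputs vocabulary, then pack the
    # token ids into fixed-length, batched feature dicts.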
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    dataset = dataset.map(
        lambda x: {"inputs": tf.cast(vocabs["inputs"].encode_tf(x), tf.int64)},
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    dataset = tokens_to_batches(
        dataset, sequence_length, batch_size, ["inputs"]
    )

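    # Generate output token ids batch by batch and decode them back to text
    # with the targets vocabulary.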
    predictions = []
    for batch in dataset:
      predicted_tokens = self._model.generate(
          input_ids=self.to_tensor(batch["inputs"]), **generate_kwargs
      )
      predicted_tokens = predicted_tokens.cpu().numpy().tolist()
      predictions.extend(
          [vocabs["targets"].decode(p) for p in predicted_tokens]
      )

    for inp, pred in zip(inputs, predictions):
      logging.info("%s\n  -> %s", inp, pred)

    if output_file is not None:
      utils.write_lines_to_file(predictions, output_file)
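
# Minimal usage sketch, assuming `model` is an instance of the class this
# method belongs to (e.g. a t5.models.HfPyTorchModel); the input string,
# lengths, and output path are illustrative only.
model.predict(
    inputs=["translate English to German: That is good."],
    sequence_length={"inputs": 64, "targets": 64},
    batch_size=1,
    output_file="/tmp/predictions.txt",
)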