Example 1
    def eval(self,
             mixture_or_task_name,
             checkpoint_steps=None,
             summary_dir=None,
             split="validation",
             eval_with_score=False,
             compute_sequence_length=True):
        """Evaluate the model on the given Mixture or Task.

        Args:
          mixture_or_task_name: str, the name of the Mixture or Task to evaluate
            on. Must be pre-registered in the global `TaskRegistry` or
            `MixtureRegistry`.
          checkpoint_steps: int, list of ints, or None. If an int or list of
            ints, evaluation will be run on the checkpoint files in `model_dir`
            whose global steps are closest to the global steps provided. If
            None, run eval continuously, waiting for new checkpoints. If -1, use
            the latest checkpoint in the model directory.
          summary_dir: str, path to write TensorBoard events file summaries for
            eval. If None, use model_dir/{split}_eval.
          split: str, the mixture/task split to evaluate on.
          eval_with_score: bool, whether to evaluate using log likelihood scores
            of targets instead of decoded predictions.
          compute_sequence_length: bool, whether to automatically compute the
            maximum sequence length to use during eval.
        """
        with gin.unlock_config():
            gin.parse_config_file(_operative_config_path(self._model_dir))

        summary_dir = summary_dir or os.path.join(self._model_dir,
                                                  "{}_eval".format(split))

        checkpoint_steps = utils.get_checkpoints_iterator(
            checkpoint_steps, self._model_dir)

        run_eval(mixture_or_task_name=mixture_or_task_name,
                 predict_or_score_fn=functools.partial(
                     self._predict_or_score_fn,
                     eval_with_score=eval_with_score,
                     split=split),
                 checkpoint_steps=checkpoint_steps,
                 summary_dir=summary_dir,
                 split=split,
                 sequence_length=(None if compute_sequence_length else
                                  self._sequence_length),
                 batch_size=self._batch_size)
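
A minimal usage sketch, assuming `model` is an already-constructed instance of the class that defines this `eval` (its model directory, sequence length, and batch size already configured) and that the task name shown is illustrative and pre-registered in the global `TaskRegistry`:

# Hypothetical usage sketch: `model` is assumed to be an instance of the class
# above, pointed at a model_dir containing checkpoints and an operative gin
# config. The task name is illustrative and must be pre-registered.
model.eval(
    mixture_or_task_name="glue_mrpc_v002",  # any registered Task/Mixture name
    checkpoint_steps=-1,        # -1: evaluate the latest checkpoint
    split="validation",
    eval_with_score=False,      # decode predictions instead of scoring targets
)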
Example 2
  def eval(
      self,
      mixture_or_task_name,
      sequence_length,
      batch_size,
      checkpoint_steps=None,
      summary_dir=None,
      split="validation",
      compute_sequence_length=False,
      **generate_kwargs,
  ):
    """Evaluate the model on the given Mixture or Task.

    *Note*: If a checkpoint step is provided (i.e. `checkpoint_steps is not
    None`), the model's state will be replaced by the state in those
    checkpoints. If you have not already saved your model, call
    `save_checkpoint` before `eval` to avoid losing its parameter values and
    state.

    Args:
      mixture_or_task_name: str, the name of the Mixture or Task to evaluate
        on. Must be pre-registered in the global `t5.data.TaskRegistry` or
        `t5.data.MixtureRegistry`.
      sequence_length: dict of int, a dict mapping feature name to length.
      batch_size: int, the number of padded sequences in each batch.
      checkpoint_steps: int, list of ints, "all", or None. If None, eval the
        model in its current state without loading any checkpoints. If an int
        or list of ints, evaluation will be run on the checkpoint files in
        `model_dir` whose global steps are those provided. If -1, eval on the
        latest checkpoint from the model directory. If "all", evaluate all
        checkpoints in the model directory.
      summary_dir: str, path to write TensorBoard events file summaries for
        eval. If None, use model_dir/{split}_eval.
      split: str, the mixture/task split to evaluate on.
      compute_sequence_length: bool, whether to automatically compute the
        sequence length during eval.
      **generate_kwargs: Additional keyword arguments to pass to
        `transformers.PreTrainedModel.generate()`, for example to change the
        decoding strategy. See the documentation for
        `transformers.PreTrainedModel.generate()` for options.
    """

    def _predict_from_tasks(tasks, vocabulary, checkpoint_step, sequence_length,
                            datasets, **unused_kwargs):

      if isinstance(vocabulary, tuple):
        # An (inputs_vocab, targets_vocab) tuple: decode with the target vocab.
        vocab = vocabulary[1]
      else:
        vocab = vocabulary

      if checkpoint_step != self._step:
        self.load_checkpoint(checkpoint_step)
      self._model.eval()
      outputs = []
      for task in tasks:
        if compute_sequence_length:
          ds = _get_dataset(task.name, sequence_length, split, shuffle=False)
        else:
          ds = datasets[task.name]

        ds = list(tokens_to_batches(
            ds, sequence_length, batch_size, tuple(task.output_features), task))
        for batch in ds:
          predicted_tokens = self._model.generate(
              input_ids=self.to_tensor(batch["inputs"]), **generate_kwargs
          )
          predicted_tokens = predicted_tokens.cpu().numpy().tolist()
          predictions = [vocab.decode(p) for p in predicted_tokens]

          outputs.extend(predictions)

      return outputs

    if checkpoint_steps is None:
      checkpoint_steps = [self._step]
    elif isinstance(checkpoint_steps, int):
      checkpoint_steps = [checkpoint_steps]
    elif checkpoint_steps == "all":
      checkpoint_steps = self.get_all_checkpoint_steps()
    elif not isinstance(checkpoint_steps, (list, tuple)):
      raise ValueError(
          "checkpoint_steps must be None, an int, a list of ints, or 'all'; "
          f"got {checkpoint_steps}"
      )

    summary_dir = summary_dir or os.path.join(self._model_dir, f"{split}_eval")
    tf.io.gfile.makedirs(summary_dir)

    run_eval(
        mixture_or_task_name=mixture_or_task_name,
        predict_or_score_fn=_predict_from_tasks,
        checkpoint_steps=checkpoint_steps,
        dataset_fn=functools.partial(_get_dataset, shuffle=False),
        summary_dir=summary_dir,
        split=split,
        sequence_length=None if compute_sequence_length else sequence_length,
        batch_size=batch_size)
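
Likewise, a minimal usage sketch for this variant, assuming `model` is an instance of the Hugging Face / PyTorch wrapper class that defines this `eval`; the task name, sequence lengths, and batch size are illustrative, and the extra keyword argument is forwarded to `generate()`:

# Hypothetical usage sketch: `model` is assumed to be an instance of the class
# above. Extra keyword arguments such as num_beams are forwarded unchanged to
# the underlying generate() call.
model.eval(
    mixture_or_task_name="glue_mrpc_v002",  # illustrative registered task name
    sequence_length={"inputs": 512, "targets": 128},
    batch_size=16,
    checkpoint_steps="all",     # evaluate every checkpoint in model_dir
    split="validation",
    num_beams=4,                # decoding option passed via **generate_kwargs
)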