예제 #1
0
 def eval(self, mixture_or_task_name, checkpoint_step, summary_dir, split):
     """Evaluate the model on `mixture_or_task_name` at `checkpoint_step`.

     Builds an eval-dataset factory for the given Mixture/Task and delegates
     to `utils.eval_model` with this model's estimator, vocabulary, sequence
     length, batch size, and model directory.
     """
     # Bind only the Mixture/Task name here; utils.eval_model supplies the
     # remaining dataset-fn arguments when it constructs the eval inputs.
     utils.eval_model(
         self._estimator, self._vocabulary, self._sequence_length,
         self._batch_size, split, self._model_dir,
         functools.partial(
             mesh_eval_dataset_fn, mixture_or_task_name=mixture_or_task_name),
         summary_dir, checkpoint_step)
  def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,
           split="validation"):
    """Run evaluation for a registered Mixture or Task.

    Args:
      mixture_or_task_name: str, name of a Task or Mixture registered in the
        global `TaskRegistry` or `MixtureRegistry`.
      checkpoint_steps: int, list of ints, or None. For int(s), evaluate the
        checkpoint(s) in `model_dir` whose global steps are closest to the
        given step(s). None runs eval continuously, waiting for new
        checkpoints. -1 selects the newest checkpoint in the model directory.
      summary_dir: str or None, directory for TensorBoard event summaries;
        when None, model_dir/eval_{split} is used downstream.
      split: str, the mixture/task split to evaluate on.
    """
    # -1 is shorthand for "the newest checkpoint on disk".
    if checkpoint_steps == -1:
      checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
    # A partial binding only the task name; utils.eval_model fills in the
    # remaining dataset arguments.
    eval_dataset_fn = functools.partial(
        t5.models.mesh_transformer.mesh_eval_dataset_fn,
        mixture_or_task_name=mixture_or_task_name)
    vocabulary = t5.models.mesh_transformer.get_vocabulary(mixture_or_task_name)
    # Reload the operative gin config recorded for this model so the
    # estimator is rebuilt with the training-time hyperparameters.
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
    utils.eval_model(self.estimator(vocabulary), vocabulary,
                     self._sequence_length, self.batch_size, split,
                     self._model_dir, eval_dataset_fn, summary_dir,
                     checkpoint_steps)
    def eval(self,
             mixture_or_task_name,
             checkpoint_steps=None,
             summary_dir=None,
             split="validation",
             eval_with_score=False,
             compute_sequence_length=False):
        """Evaluate the model on a registered Mixture or Task.

        Args:
          mixture_or_task_name: str, name of a Task or Mixture registered in
            the global `TaskRegistry` or `MixtureRegistry`.
          checkpoint_steps: int, list of ints, or None. For int(s), evaluate
            the checkpoint(s) in `model_dir` whose global steps are closest
            to the given step(s). None runs eval continuously, waiting for
            new checkpoints. -1 selects the newest checkpoint in the model
            directory.
          summary_dir: str or None, directory for TensorBoard event
            summaries; when None, model_dir/eval_{split} is used downstream.
          split: str, the mixture/task split to evaluate on.
          eval_with_score: bool, if True, evaluate with target log-likelihood
            scores rather than decoded predictions.
          compute_sequence_length: bool, if True, let eval derive the maximum
            sequence length automatically.
        """
        # -1 is shorthand for "the newest checkpoint on disk".
        if checkpoint_steps == -1:
            checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
        vocabulary = _get_vocabulary(mixture_or_task_name)
        # A partial binding only the task name; utils.eval_model fills in the
        # remaining dataset arguments.
        eval_dataset_fn = functools.partial(
            t5.models.mesh_transformer.mesh_eval_dataset_fn,
            mixture_or_task_name=mixture_or_task_name)
        # Reload the operative gin config recorded for this model so the
        # estimator is rebuilt with the training-time hyperparameters.
        with gin.unlock_config():
            gin.parse_config_file(_operative_config_path(self._model_dir))
        # utils.eval_model constructs the estimator itself; scoring vs.
        # decoding mode is fixed here via score_in_predict_mode.
        estimator_fn = functools.partial(
            self.estimator, vocabulary, score_in_predict_mode=eval_with_score)
        if compute_sequence_length:
            # None tells the eval harness to compute the max length itself.
            sequence_length = None
        else:
            sequence_length = self._sequence_length
        utils.eval_model(
            estimator=estimator_fn,
            vocabulary=vocabulary,
            sequence_length=sequence_length,
            batch_size=self.batch_size,
            dataset_split=split,
            model_dir=self._model_dir,
            eval_dataset_fn=eval_dataset_fn,
            eval_summary_dir=summary_dir,
            eval_checkpoint_step=checkpoint_steps,
            eval_with_score=eval_with_score)