def eval(self, mixture_or_task_name, checkpoint_step, summary_dir, split):
  """Run evaluation for the given Mixture or Task at a single checkpoint.

  Args:
    mixture_or_task_name: str, name of the Mixture or Task to evaluate on.
      Must be pre-registered in the global registry.
    checkpoint_step: int, the global step of the checkpoint to evaluate.
    summary_dir: str, directory to write TensorBoard event summaries to.
    split: str, the mixture/task split to evaluate on.
  """
  # Bind the task name now; eval_model supplies the remaining dataset args.
  eval_dataset_fn = functools.partial(
      mesh_eval_dataset_fn, mixture_or_task_name=mixture_or_task_name)
  utils.eval_model(self._estimator, self._vocabulary, self._sequence_length,
                   self._batch_size, split, self._model_dir, eval_dataset_fn,
                   summary_dir, checkpoint_step)
def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,
         split="validation"):
  """Evaluate the model on the given Mixture or Task.

  Args:
    mixture_or_task_name: str, name of the Mixture or Task to evaluate on.
      Must be pre-registered in the global `TaskRegistry` or
      `MixtureRegistry.`
    checkpoint_steps: int, list of ints, or None. An int or list of ints
      selects the checkpoint(s) in `model_dir` whose global steps are
      closest to those provided; None runs eval continuously, waiting for
      new checkpoints; -1 uses the latest checkpoint in the model directory.
    summary_dir: str, directory for TensorBoard eval summaries. If None,
      model_dir/eval_{split} is used.
    split: str, the mixture/task split to evaluate on.
  """
  # -1 is a sentinel for "latest checkpoint in model_dir".
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
  vocab = t5.models.mesh_transformer.get_vocabulary(mixture_or_task_name)
  eval_dataset_fn = functools.partial(
      t5.models.mesh_transformer.mesh_eval_dataset_fn,
      mixture_or_task_name=mixture_or_task_name,
  )
  # Reload the operative gin config so eval uses the training-time settings.
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
  utils.eval_model(self.estimator(vocab), vocab, self._sequence_length,
                   self.batch_size, split, self._model_dir, eval_dataset_fn,
                   summary_dir, checkpoint_steps)
def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,
         split="validation", eval_with_score=False,
         compute_sequence_length=False):
  """Evaluate the model on the given Mixture or Task.

  Args:
    mixture_or_task_name: str, name of the Mixture or Task to evaluate on.
      Must be pre-registered in the global `TaskRegistry` or
      `MixtureRegistry.`
    checkpoint_steps: int, list of ints, or None. An int or list of ints
      selects the checkpoint(s) in `model_dir` whose global steps are
      closest to those provided; None runs eval continuously, waiting for
      new checkpoints; -1 uses the latest checkpoint in the model directory.
    summary_dir: str, directory for TensorBoard eval summaries. If None,
      model_dir/eval_{split} is used.
    split: str, the mixture/task split to evaluate on.
    eval_with_score: bool, if True, evaluate using log-likelihood scores of
      targets instead of decoded predictions.
    compute_sequence_length: bool, if True, automatically compute the
      maximum sequence length to use during eval.
  """
  # -1 is a sentinel for "latest checkpoint in model_dir".
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
  vocab = _get_vocabulary(mixture_or_task_name)
  eval_dataset_fn = functools.partial(
      t5.models.mesh_transformer.mesh_eval_dataset_fn,
      mixture_or_task_name=mixture_or_task_name,
  )
  # Reload the operative gin config so eval uses the training-time settings.
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
  # Pass a factory rather than an estimator instance so eval_model controls
  # when (and in which mode) the estimator is constructed.
  build_estimator = functools.partial(
      self.estimator, vocab, score_in_predict_mode=eval_with_score)
  # None lets eval_model derive the max sequence length from the data.
  seq_len = None if compute_sequence_length else self._sequence_length
  utils.eval_model(
      estimator=build_estimator,
      vocabulary=vocab,
      sequence_length=seq_len,
      batch_size=self.batch_size,
      dataset_split=split,
      model_dir=self._model_dir,
      eval_dataset_fn=eval_dataset_fn,
      eval_summary_dir=summary_dir,
      eval_checkpoint_step=checkpoint_steps,
      eval_with_score=eval_with_score)