def score(self,
          inputs,
          targets,
          scores_file=None,
          checkpoint_steps=-1,
          vocabulary=None):
  """Computes the log-likelihood of each target example.

  Args:
    inputs: optional - a string (filename), or a list of strings (inputs)
    targets: a string (filename), or a list of strings (targets)
    scores_file: str, path to write per-example scores to, one per line.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      score using the checkpoint(s) in `model_dir` whose global steps are
      closest to the steps given. If None, score continuously as new
      checkpoints appear. If -1, use the latest checkpoint in the model
      directory.
    vocabulary: vocabularies.Vocabulary object used for tokenization, or
      None to fall back to the default SentencePieceVocabulary.
  """
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

  # Re-load the operative training config, then flip the estimator into
  # scoring mode (log-likelihoods instead of decoded predictions).
  config_path = _operative_config_path(self._model_dir)
  with gin.unlock_config():
    gin.parse_config_file(config_path)
    gin.bind_parameter("tpu_estimator_model_fn.score_in_predict_mode", True)

  if vocabulary is None:
    vocabulary = t5.data.get_default_vocabulary()

  estimator = self.estimator(vocabulary)
  utils.score_from_strings(estimator, vocabulary, self._model_type,
                           self.batch_size, self._sequence_length,
                           self._model_dir, checkpoint_steps, inputs,
                           targets, scores_file)
# Example #2 (scrape separator; the stray "0" below it was a vote count)
    def score(self,
              inputs,
              targets,
              scores_file=None,
              checkpoint_steps=-1,
              sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
              vocabulary=None):
        """Computes the log-likelihood of each target example.

        Args:
          inputs: optional - a string (filename), or a list of strings
            (inputs)
          targets: a string (filename), or a list of strings (targets)
          scores_file: str, path to write per-example scores to, one per
            line.
          checkpoint_steps: int, list of ints, or None. If an int or list of
            ints, score using the checkpoint(s) in `model_dir` whose global
            steps are closest to the steps given. If None, score
            continuously as new checkpoints appear. If -1, use the latest
            checkpoint in the model directory.
          sentencepiece_model_path: str, path to the SentencePiece model
            file used for decoding. Must match the one used during training.
          vocabulary: vocabularies.Vocabulary object used for tokenization,
            or None to build a SentencePieceVocabulary from
            sentencepiece_model_path.
        """
        if checkpoint_steps == -1:
            checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

        # Restore the gin config that was operative during training.
        with gin.unlock_config():
            gin.parse_config_file(_operative_config_path(self._model_dir))

        if vocabulary is None:
            vocabulary = t5.data.SentencePieceVocabulary(
                sentencepiece_model_path)

        # A string `targets` names a file of targets; anything else is
        # treated as an in-memory list of target strings.
        targets_from_file = isinstance(targets, str)
        if targets_from_file:
            tf.logging.info("scoring targets from file %s" % targets)
        else:
            tf.logging.info("scoring targets from list of strings")

        score_fn = (utils.score_from_files if targets_from_file
                    else utils.score_from_strings)
        score_fn(self.estimator(vocabulary), vocabulary, self._model_type,
                 self.batch_size, self._sequence_length, self._model_dir,
                 checkpoint_steps, inputs, targets, scores_file)
# Example #3 (scrape separator; the stray "0" below it was a vote count)
    def score(self,
              inputs=None,
              targets=None,
              mixture_or_task_name=None,
              mixture_or_task_split=None,
              scores_file=None,
              checkpoint_steps=-1,
              vocabulary=None):
        """Computes the log-likelihood of each target example.

        Exactly one source of examples must be given: either explicit
        `inputs`/`targets`, or a registered Mixture/Task plus a split.

        Args:
          inputs: optional - a string (filename), or a list of strings
            (inputs)
          targets: optional - a string (filename), or a list of strings
            (targets)
          mixture_or_task_name: optional - a string, the name of the Mixture
            or Task to score on. Must be pre-registered in the global
            `TaskRegistry` or `MixtureRegistry.` Cannot be supplied in
            addition to `inputs` and `targets`.
          mixture_or_task_split: optional - a string, the split of the
            Mixture or Task to score on. Must be provided if scoring on a
            Mixture or Task.
          scores_file: optional - a string (filename), to write example
            scores to, one per line.
          checkpoint_steps: int, list of ints, or None. If an int or list of
            ints, score using the checkpoint(s) in `model_dir` whose global
            steps are closest to the steps given. If None, score
            continuously as new checkpoints appear. If -1, use the latest
            checkpoint in the model directory.
          vocabulary: vocabularies.Vocabulary object used for tokenization,
            or None to use the default SentencePieceVocabulary.

        Returns:
          scores: a list of floating point scores matching the dataset
            order.
          targets: a list of scored strings matching the dataset order.
        """
        # The two example sources are mutually exclusive; reject "both"
        # and "neither" alike.
        from_examples = bool(inputs or targets)
        from_dataset = bool(mixture_or_task_name or mixture_or_task_split)
        if from_examples == from_dataset:
            raise ValueError(
                "Either 'inputs' and 'targets' or "
                "'mixture_or_task_name' and 'mixture_or_task_split' must be "
                "specified, but not both.")

        if checkpoint_steps == -1:
            checkpoint_steps = utils.get_latest_checkpoint_from_dir(
                self._model_dir)

        _parse_operative_config(self._model_dir)
        with gin.unlock_config():
            gin.parse_config(self._gin_bindings)

        if vocabulary is None:
            vocabulary = utils.get_vocabulary(mixture_or_task_name)

        estimator = self.estimator(vocabulary, score_in_predict_mode=True)
        write_scores = functools.partial(mtf_utils.save_scores,
                                         scores_filename=scores_file)

        # Arguments shared by both scoring entry points.
        common_kwargs = dict(
            estimator=estimator,
            vocabulary=vocabulary,
            batch_size=self.batch_size,
            sequence_length=self._sequence_length,
            model_dir=self._model_dir,
            eval_checkpoint_step=checkpoint_steps,
            score_postprocess_fn=write_scores)

        if not mixture_or_task_name:
            return mtf_utils.score_from_strings(
                model_type=self._model_type,
                inputs=inputs,
                targets=targets,
                **common_kwargs)

        dataset_fn = functools.partial(
            t5.models.mesh_transformer.mesh_eval_dataset_fn,
            mixture_or_task_name=mixture_or_task_name,
        )
        return mtf_utils.score_from_dataset(
            dataset_split=mixture_or_task_split,
            score_dataset_fn=dataset_fn,
            **common_kwargs)