Example #1
    def finetune(self,
                 mixture_or_task_name,
                 finetune_steps,
                 pretrained_model_dir,
                 pretrained_checkpoint_step=-1,
                 split="train"):
        """Finetunes a model from an existing checkpoint.

        Args:
          mixture_or_task_name: str, the name of the Mixture or Task to
            fine-tune on. Must be pre-registered in the global `TaskRegistry`
            or `MixtureRegistry`.
          finetune_steps: int, the number of additional steps to train for.
          pretrained_model_dir: str, directory with pretrained model
            checkpoints and operative config.
          pretrained_checkpoint_step: int, checkpoint to initialize weights
            from. If -1 (default), use the latest checkpoint from the
            pretrained model directory.
          split: str, the mixture/task split to finetune on.
        """
        if pretrained_checkpoint_step == -1:
            checkpoint_step = utils.get_latest_checkpoint_from_dir(
                pretrained_model_dir)
        else:
            checkpoint_step = pretrained_checkpoint_step
        _parse_operative_config(pretrained_model_dir)

        model_ckpt = "model.ckpt-" + str(checkpoint_step)
        self.train(mixture_or_task_name,
                   checkpoint_step + finetune_steps,
                   init_checkpoint=os.path.join(pretrained_model_dir,
                                                model_ckpt),
                   split=split)
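
A minimal usage sketch for `finetune` (the paths, task name, and step counts below are placeholders; assumes a Task registered in the global `TaskRegistry` and a standard `t5.models.MtfModel`):

import t5.models

# Hypothetical setup: adjust paths, sizes, and hardware to your environment.
model = t5.models.MtfModel(
    model_dir="/tmp/finetuned_model",  # where new checkpoints will be written
    tpu=None,  # run locally; pass a TPU address to train on TPU instead
    sequence_length={"inputs": 512, "targets": 128},
    batch_size=8,
)

# Fine-tune from the public "small" checkpoint for 1000 additional steps.
model.finetune(
    mixture_or_task_name="glue_cola_v002",  # must be in the TaskRegistry
    finetune_steps=1000,
    pretrained_model_dir="gs://t5-data/pretrained_models/small",
)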
Example #2
    def export(self,
               export_dir=None,
               checkpoint_step=-1,
               beam_size=1,
               temperature=1.0,
               keep_top_k=-1,
               vocabulary=None,
               eval_with_score=False):
        """Exports a TensorFlow SavedModel.

        Args:
          export_dir: str, a directory in which to export SavedModels. Will use
            `model_dir` if unspecified.
          checkpoint_step: int, checkpoint to export. If -1 (default), use the
            latest checkpoint from the model directory.
          beam_size: int, a number >= 1 specifying the number of beams to use
            for beam search.
          temperature: float, a value between 0 and 1 (must be 0 if
            beam_size > 1). 0.0 means argmax; 1.0 means sampling according to
            the predicted distribution.
          keep_top_k: int, a value between 1 and the vocabulary size. When
            sampling, only pick tokens from the k most likely; -1 (default)
            applies no restriction.
          vocabulary: vocabularies.Vocabulary object to use for tokenization,
            or None to use the default SentencePieceVocabulary.
          eval_with_score: bool. If True, compute log-likelihood scores of
            targets; if False, run inference to generate outputs.

        Returns:
          The string path to the exported directory.
        """
        if checkpoint_step == -1:
            checkpoint_step = utils.get_latest_checkpoint_from_dir(
                self._model_dir)
        _parse_operative_config(self._model_dir)
        with gin.unlock_config():
            gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
            gin.bind_parameter("Bitransformer.decode.temperature", temperature)
            gin.bind_parameter("Bitransformer.decode.sampling_keep_top_k",
                               keep_top_k)

        if vocabulary is None:
            vocabulary = utils.get_vocabulary()
        model_ckpt = "model.ckpt-" + str(checkpoint_step)
        export_dir = export_dir or self._model_dir
        estimator = self.estimator(vocabulary,
                                   disable_tpu=True,
                                   score_in_predict_mode=eval_with_score)
        return mtf_utils.export_model(estimator,
                                      export_dir,
                                      vocabulary,
                                      self._sequence_length,
                                      self._model_type,
                                      batch_size=self.batch_size,
                                      checkpoint_path=os.path.join(
                                          self._model_dir, model_ckpt),
                                      eval_with_score=eval_with_score)
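
A short sketch of `export` (reuses the hypothetical `model` from the sketch above; the export path is a placeholder):

# Export the latest checkpoint as a SavedModel and print where it landed.
saved_model_path = model.export(
    export_dir="/tmp/exported_model",  # falls back to model_dir if omitted
    checkpoint_step=-1,                # -1 selects the latest checkpoint
    beam_size=1,
    temperature=1.0,                   # sample from the predicted distribution
)
print(saved_model_path)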
Example #3
    def predict(self,
                input_file,
                output_file,
                checkpoint_steps=-1,
                beam_size=1,
                temperature=1.0,
                keep_top_k=-1,
                vocabulary=None):
        """Predicts targets from the given inputs.

        Args:
          input_file: str, path to a text file containing newline-separated
            input prompts to predict from.
          output_file: str, path prefix of the output file to write predictions
            to. Note the checkpoint step will be appended to the given filename.
          checkpoint_steps: int, list of ints, or None. If an int or list of
            ints, inference will be run on the checkpoint files in `model_dir`
            whose global steps are closest to the global steps provided. If
            None, run inference continuously, waiting for new checkpoints. If
            -1, use the latest checkpoint from the model directory.
          beam_size: int, a number >= 1 specifying the number of beams to use
            for beam search.
          temperature: float, a value between 0 and 1 (must be 0 if
            beam_size > 1). 0.0 means argmax; 1.0 means sampling according to
            the predicted distribution.
          keep_top_k: int, a value between 1 and the vocabulary size. When
            sampling, only pick tokens from the k most likely; -1 (default)
            applies no restriction.
          vocabulary: vocabularies.Vocabulary object to use for tokenization,
            or None to use the default SentencePieceVocabulary.
        """
        # TODO(sharannarang): It would be nice to have a function like
        # load_checkpoint that loads the model once and then calls
        # decode_from_file multiple times without having to restore the
        # checkpoint weights again. This would be particularly useful in a
        # colab demo.

        if checkpoint_steps == -1:
            checkpoint_steps = utils.get_latest_checkpoint_from_dir(
                self._model_dir)

        _parse_operative_config(self._model_dir)
        with gin.unlock_config():
            gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
            gin.bind_parameter("Bitransformer.decode.temperature", temperature)
            gin.bind_parameter("Bitransformer.decode.sampling_keep_top_k",
                               keep_top_k)
            gin.bind_parameter("utils.decode_from_file.input_filename",
                               input_file)
            gin.bind_parameter("utils.decode_from_file.output_filename",
                               output_file)

        if vocabulary is None:
            vocabulary = utils.get_vocabulary()
        mtf_utils.infer_model(self.estimator(vocabulary), vocabulary,
                              self._sequence_length, self.batch_size,
                              self._model_type, self._model_dir,
                              checkpoint_steps)
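
A sketch of batch inference with `predict` (file paths are placeholders; reuses the hypothetical `model` above):

# /tmp/inputs.txt holds one prompt per line; predictions are written to
# /tmp/predictions-<step>, since the checkpoint step is appended to the prefix.
model.predict(
    input_file="/tmp/inputs.txt",
    output_file="/tmp/predictions",
    checkpoint_steps=-1,  # use the latest checkpoint
    beam_size=1,
    temperature=0.0,      # deterministic argmax decoding
)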
Example #4
    def score(self,
              inputs=None,
              targets=None,
              mixture_or_task_name=None,
              mixture_or_task_split=None,
              scores_file=None,
              checkpoint_steps=-1,
              vocabulary=None):
        """Computes log-likelihood of target per example in targets.

        Args:
          inputs: optional, a string (filename) or a list of strings (inputs).
          targets: optional, a string (filename) or a list of strings
            (targets).
          mixture_or_task_name: optional, a string, the name of the Mixture or
            Task to score on. Must be pre-registered in the global
            `TaskRegistry` or `MixtureRegistry`. Cannot be supplied in
            addition to `inputs` and `targets`.
          mixture_or_task_split: optional, a string, the split of the Mixture
            or Task to score on. Must be provided if scoring on a Mixture or
            Task.
          scores_file: optional, a string (filename) to write example scores
            to, one per line.
          checkpoint_steps: int, list of ints, or None. If an int or list of
            ints, inference will be run on the checkpoint files in `model_dir`
            whose global steps are closest to the global steps provided. If
            None, run inference continuously, waiting for new checkpoints. If
            -1, use the latest checkpoint from the model directory.
          vocabulary: vocabularies.Vocabulary object to use for tokenization,
            or None to use the default SentencePieceVocabulary.

        Returns:
          scores: a list of floating point scores matching the dataset order.
          targets: a list of scored strings matching the dataset order.
        """
        if bool(inputs or targets) == bool(mixture_or_task_name
                                           or mixture_or_task_split):
            raise ValueError(
                "Either 'inputs' and 'targets' or "
                "'mixture_or_task_name' and 'mixture_or_task_split' must be "
                "specified, but not both.")

        if checkpoint_steps == -1:
            checkpoint_steps = utils.get_latest_checkpoint_from_dir(
                self._model_dir)

        _parse_operative_config(self._model_dir)
        with gin.unlock_config():
            gin.parse_config(self._gin_bindings)

        if vocabulary is None:
            vocabulary = utils.get_vocabulary(mixture_or_task_name)

        estimator = self.estimator(vocabulary, score_in_predict_mode=True)
        score_postprocess_fn = functools.partial(mtf_utils.save_scores,
                                                 scores_filename=scores_file)

        if mixture_or_task_name:
            score_dataset_fn = functools.partial(
                t5.models.mesh_transformer.mesh_eval_dataset_fn,
                mixture_or_task_name=mixture_or_task_name,
            )
            return mtf_utils.score_from_dataset(
                estimator=estimator,
                vocabulary=vocabulary,
                batch_size=self.batch_size,
                sequence_length=self._sequence_length,
                model_dir=self._model_dir,
                eval_checkpoint_step=checkpoint_steps,
                dataset_split=mixture_or_task_split,
                score_dataset_fn=score_dataset_fn,
                score_postprocess_fn=score_postprocess_fn)
        else:
            return mtf_utils.score_from_strings(
                estimator=estimator,
                vocabulary=vocabulary,
                model_type=self._model_type,
                batch_size=self.batch_size,
                sequence_length=self._sequence_length,
                model_dir=self._model_dir,
                eval_checkpoint_step=checkpoint_steps,
                inputs=inputs,
                targets=targets,
                score_postprocess_fn=score_postprocess_fn)
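
A sketch of scoring explicit input/target pairs with `score` (the strings are illustrative; per the docstring above, the call returns the scores and the scored targets):

# Compute the log-likelihood of each target given its input; scores are also
# written to the (placeholder) scores_file, one per line.
scores, targets = model.score(
    inputs=["translate English to German: That is good."],
    targets=["Das ist gut."],
    scores_file="/tmp/scores",
)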