Example #1
    def predict(self,
                input_file,
                output_file,
                checkpoint_steps=-1,
                beam_size=1,
                temperature=1.0,
                keep_top_k=-1,
                vocabulary=None):
        """Predicts targets from the given inputs.

    Args:
      input_file: str, path to a text file containing newline-separated input
        prompts to predict from.
      output_file: str, path prefix of output file to write predictions to. Note
        the checkpoint step will be appended to the given filename.
      checkpoint_steps: int, list of ints, or None. If an int or list of ints,
        inference will be run on the checkpoint files in `model_dir` whose
        global steps are closest to the global steps provided. If None, run
        inference continuously waiting for new checkpoints. If -1, get the
        latest checkpoint from the model directory.
      beam_size: int, a number >= 1 specifying the number of beams to use for
        beam search.
      temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1)
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      keep_top_k: integer, a value between 1 and the vocabulary size. When
        sampling, only pick tokens that are in the k most likely.
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the default SentencePieceVocabulary.
    """
        # TODO(sharannarang): It would be nice to have a function like
        # load_checkpoint that loads the model once so that decode_from_file
        # can be called multiple times without having to restore the checkpoint
        # weights again. This would be particularly useful in a colab demo.

        if checkpoint_steps == -1:
            checkpoint_steps = utils.get_latest_checkpoint_from_dir(
                self._model_dir)

        _parse_operative_config(self._model_dir)
        with gin.unlock_config():
            gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
            gin.bind_parameter("Bitransformer.decode.temperature", temperature)
            gin.bind_parameter("Bitransformer.decode.sampling_keep_top_k",
                               keep_top_k)
            gin.bind_parameter("utils.decode_from_file.input_filename",
                               input_file)
            gin.bind_parameter("utils.decode_from_file.output_filename",
                               output_file)

        if vocabulary is None:
            vocabulary = utils.get_vocabulary()
        mtf_utils.infer_model(self.estimator(vocabulary), vocabulary,
                              self._sequence_length, self.batch_size,
                              self._model_type, self._model_dir,
                              checkpoint_steps)
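
A minimal usage sketch for this method (not part of the source): it assumes a trained t5.models.MtfModel checkpoint; the model directory, file paths, and sizes below are hypothetical placeholders.

import t5.models

# Hypothetical paths and sizes; adjust to your own checkpoint.
model = t5.models.MtfModel(
    model_dir="/tmp/t5-model",
    tpu=None,  # run on CPU/GPU instead of a TPU
    batch_size=8,
    sequence_length={"inputs": 128, "targets": 32},
)

# Decode the latest checkpoint (checkpoint_steps=-1 is the default) with
# greedy decoding; the step number is appended to the output filename.
model.predict(
    input_file="/tmp/prompts.txt",
    output_file="/tmp/predictions",
    temperature=0.0,
)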
Example #2
    def predict(self, checkpoint_step, input_file, output_file):
        """Predicts targets for the prompts in `input_file` using the
        checkpoint at `checkpoint_step` and writes them to `output_file`."""
        # TODO(sharannarang): Add the ability to decode from a collection of
        # strings instead of always requiring an input file.
        # TODO(sharannarang): It would be nice to have a function like
        # load_checkpoint that loads the model once so that decode_from_file
        # can be called multiple times without having to restore the checkpoint
        # weights again. This would be particularly useful in a colab demo.
        utils.infer_model(self._estimator, self._vocabulary,
                          self._sequence_length, self._batch_size,
                          self._model_type, self._model_dir, checkpoint_step,
                          input_file, output_file)
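
A hedged call sketch for this older variant, where the checkpoint step is a required argument. `model`, the paths, and the step number are hypothetical; the surrounding class is not shown in the example.

# Assuming `model` is an instance of the (unshown) class above, decode the
# prompts in input_file using the checkpoint at a hypothetical step.
model.predict(
    checkpoint_step=1000000,
    input_file="/tmp/prompts.txt",
    output_file="/tmp/predictions",
)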
Example #3
    def predict(self,
                input_file,
                output_file,
                checkpoint_steps=-1,
                beam_size=1,
                temperature=1.0,
                sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
                sampling_keep_top_p=1.0):
        """Predicts targets from the given inputs.

    Args:
      input_file: str, path to a text file containing newline-separated input
        prompts to predict from.
      output_file: str, path prefix of output file to write predictions to. Note
        the checkpoint step will be appended to the given filename.
      checkpoint_steps: int, list of ints, or None. If an int or list of ints,
        inference will be run on the checkpoint files in `model_dir` whose
        global steps are closest to the global steps provided. If None, run
        inference continuously waiting for new checkpoints. If -1, get the
        latest checkpoint from the model directory.
      beam_size: int, a number >= 1 specifying the number of beams to use for
        beam search.
      temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1)
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      sentencepiece_model_path: str, path to the SentencePiece model file to use
        for decoding. Must match the one used during training.
    """
        # TODO(sharannarang): It would be nice to have a function like
        # load_checkpoint that loads the model once so that decode_from_file
        # can be called multiple times without having to restore the checkpoint
        # weights again. This would be particularly useful in a colab demo.

        if checkpoint_steps == -1:
            checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

        with gin.unlock_config():
            gin.parse_config_file(_operative_config_path(self._model_dir))
            gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
            gin.bind_parameter("Bitransformer.decode.temperature", temperature)
            gin.bind_parameter("Bitransformer.decode.sampling_keep_top_p",
                               sampling_keep_top_p)

        vocabulary = t5.data.SentencePieceVocabulary(sentencepiece_model_path)
        utils.infer_model(self.estimator(vocabulary), vocabulary,
                          self._sequence_length, self.batch_size,
                          self._model_type, self._model_dir, checkpoint_steps,
                          input_file, output_file)
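
A hedged sketch of calling this variant with nucleus sampling. The paths are hypothetical; t5.data.DEFAULT_SPM_PATH comes from the signature above, and `model` is again an assumed instance of the surrounding class.

# Sample from the latest checkpoint, restricting each decoding step to the
# top-p nucleus of the predicted distribution. beam_size stays at its
# default of 1, as required when temperature > 0.
model.predict(
    input_file="/tmp/prompts.txt",
    output_file="/tmp/predictions",
    checkpoint_steps=-1,
    temperature=1.0,
    sampling_keep_top_p=0.9,
    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
)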