def predict(self, input_file, output_file, checkpoint_steps=-1, beam_size=1,
            temperature=1.0, keep_top_k=-1, vocabulary=None):
  """Predicts targets from the given inputs.

  Args:
    input_file: str, path to a text file containing newline-separated input
      prompts to predict from.
    output_file: str, path prefix of output file to write predictions to.
      Note the checkpoint step will be appended to the given filename.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      inference will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      inference continuously waiting for new checkpoints. If -1, get the
      latest checkpoint from the model directory.
    beam_size: int, a number >= 1 specifying the number of beams to use for
      beam search.
    temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1)
      0.0 means argmax, 1.0 means sample according to predicted distribution.
    keep_top_k: integer, a value between 1 and the vocabulary size. When
      sampling, only pick tokens that are in the k most likely; -1 (the
      default) leaves this binding at its unrestricted sentinel value.
    vocabulary: vocabularies.Vocabulary object to use for tokenization, or
      None to use the default SentencePieceVocabulary.
  """
  # TODO(sharannarang) : It would be nice to have a function like
  # load_checkpoint that loads the model once and then call decode_from_file
  # multiple times without having to restore the checkpoint weights again.
  # This would be particularly useful in colab demo.
  if checkpoint_steps == -1:
    # -1 is a sentinel for "use the newest checkpoint in model_dir".
    checkpoint_steps = utils.get_latest_checkpoint_from_dir(self._model_dir)

  _parse_operative_config(self._model_dir)

  # All decoding knobs are passed to the underlying model through gin;
  # bind them in one pass (dict preserves the original binding order).
  decode_bindings = {
      "Bitransformer.decode.beam_size": beam_size,
      "Bitransformer.decode.temperature": temperature,
      "Bitransformer.decode.sampling_keep_top_k": keep_top_k,
      "utils.decode_from_file.input_filename": input_file,
      "utils.decode_from_file.output_filename": output_file,
  }
  with gin.unlock_config():
    for parameter_name, parameter_value in decode_bindings.items():
      gin.bind_parameter(parameter_name, parameter_value)

  if vocabulary is None:
    vocabulary = utils.get_vocabulary()
  mtf_utils.infer_model(self.estimator(vocabulary), vocabulary,
                        self._sequence_length, self.batch_size,
                        self._model_type, self._model_dir, checkpoint_steps)
def predict(self, checkpoint_step, input_file, output_file): # TODO(sharannarang) : Add the ability to decode from a collection of # strings instead of always requiring an input file. # TODO(sharannarang) : It would be nice to have a function like # load_checkpoint that loads the model once and then call decode_from_file # multiple times without having to restore the checkpoint weights again. # This would be particularly useful in colab demo. utils.infer_model(self._estimator, self._vocabulary, self._sequence_length, self._batch_size, self._model_type, self._model_dir, checkpoint_step, input_file, output_file)
def predict(self, input_file, output_file, checkpoint_steps=-1, beam_size=1,
            temperature=1.0,
            sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
            sampling_keep_top_p=1.0):
  """Predicts targets from the given inputs.

  Args:
    input_file: str, path to a text file containing newline-separated input
      prompts to predict from.
    output_file: str, path prefix of output file to write predictions to.
      Note the checkpoint step will be appended to the given filename.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      inference will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      inference continuously waiting for new checkpoints. If -1, get the
      latest checkpoint from the model directory.
    beam_size: int, a number >= 1 specifying the number of beams to use for
      beam search.
    temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1)
      0.0 means argmax, 1.0 means sample according to predicted distribution.
    sentencepiece_model_path: str, path to the SentencePiece model file to
      use for decoding. Must match the one used during training.
    sampling_keep_top_p: float, a value between 0 and 1, bound to the
      `Bitransformer.decode.sampling_keep_top_p` gin parameter — presumably
      a nucleus (top-p) sampling threshold, with the default 1.0 leaving
      sampling unrestricted; confirm against the mesh_tensorflow decode
      implementation.
  """
  # TODO(sharannarang) : It would be nice to have a function like
  # load_checkpoint that loads the model once and then call decode_from_file
  # multiple times without having to restore the checkpoint weights again.
  # This would be particularly useful in colab demo.
  if checkpoint_steps == -1:
    # -1 is a sentinel for "use the newest checkpoint in model_dir".
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
    gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
    gin.bind_parameter("Bitransformer.decode.temperature", temperature)
    gin.bind_parameter("Bitransformer.decode.sampling_keep_top_p",
                       sampling_keep_top_p)
  vocabulary = t5.data.SentencePieceVocabulary(sentencepiece_model_path)
  utils.infer_model(self.estimator(vocabulary), vocabulary,
                    self._sequence_length, self.batch_size, self._model_type,
                    self._model_dir, checkpoint_steps, input_file, output_file)