def finetune(self, mixture_or_task_name, finetune_steps, pretrained_model_dir,
             pretrained_checkpoint_step=-1, split="train"):
    """Continues training from the weights of an existing checkpoint.

    Args:
      mixture_or_task_name: str, name of the Mixture or Task to finetune on.
        Must already be registered in the global `TaskRegistry` or
        `MixtureRegistry`.
      finetune_steps: int, number of training steps to run beyond the step of
        the initial checkpoint.
      pretrained_model_dir: str, directory holding the pretrained model's
        checkpoints and its operative config.
      pretrained_checkpoint_step: int, step of the checkpoint whose weights
        initialize training. The default of -1 selects the latest checkpoint
        found in `pretrained_model_dir`.
      split: str, the mixture/task split to finetune on.
    """
    # Resolve which pretrained checkpoint step to start from.
    start_step = (
        utils.get_latest_checkpoint_from_dir(pretrained_model_dir)
        if pretrained_checkpoint_step == -1
        else pretrained_checkpoint_step
    )
    # Load the pretrained model's operative config before training.
    _parse_operative_config(pretrained_model_dir)

    init_checkpoint = os.path.join(
        pretrained_model_dir, "model.ckpt-" + str(start_step))
    # `train` receives a total step count, so the finetune budget is added on
    # top of the step the weights were taken from.
    total_steps = start_step + finetune_steps
    self.train(
        mixture_or_task_name,
        total_steps,
        init_checkpoint=init_checkpoint,
        split=split,
    )
def export(self, export_dir=None, checkpoint_step=-1, beam_size=1,
           temperature=1.0, keep_top_k=-1, vocabulary=None,
           eval_with_score=False):
    """Exports a checkpoint of this model as a TensorFlow SavedModel.

    Args:
      export_dir: str, directory to write SavedModels into. Falls back to
        `model_dir` when not given.
      checkpoint_step: int, step of the checkpoint to export. The default of
        -1 selects the latest checkpoint found in `model_dir`.
      beam_size: int, a number >= 1 giving the number of beams used for beam
        search.
      temperature: float, a value between 0 and 1 (must be 0 if
        beam_size > 1). 0.0 means argmax, 1.0 means sample according to the
        predicted distribution.
      keep_top_k: integer, a value between 1 and the vocabulary size. When
        sampling, only tokens among the k most likely are picked. Bound to
        gin as-is (default -1).
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the default SentencePieceVocabulary.
      eval_with_score: If True, compute log-likelihood scores of targets.
        If False, do inference to generate outputs.

    Returns:
      The string path to the exported directory.
    """
    if checkpoint_step == -1:
        checkpoint_step = utils.get_latest_checkpoint_from_dir(self._model_dir)
    _parse_operative_config(self._model_dir)

    # Decoding hyperparameters are injected through gin rather than passed
    # directly to the estimator.
    decode_bindings = (
        ("Bitransformer.decode.beam_size", beam_size),
        ("Bitransformer.decode.temperature", temperature),
        ("Bitransformer.decode.sampling_keep_top_k", keep_top_k),
    )
    with gin.unlock_config():
        for binding_name, binding_value in decode_bindings:
            gin.bind_parameter(binding_name, binding_value)

    if vocabulary is None:
        vocabulary = utils.get_vocabulary()

    checkpoint_path = os.path.join(
        self._model_dir, "model.ckpt-" + str(checkpoint_step))
    target_dir = export_dir or self._model_dir
    # Export happens with the TPU disabled.
    estimator = self.estimator(
        vocabulary, disable_tpu=True, score_in_predict_mode=eval_with_score)
    return mtf_utils.export_model(
        estimator,
        target_dir,
        vocabulary,
        self._sequence_length,
        self._model_type,
        batch_size=self.batch_size,
        checkpoint_path=checkpoint_path,
        eval_with_score=eval_with_score,
    )
def predict(self, input_file, output_file, checkpoint_steps=-1,
            beam_size=1, temperature=1.0, keep_top_k=-1, vocabulary=None):
    """Runs inference on newline-separated prompts read from a file.

    Args:
      input_file: str, path to a text file containing newline-separated input
        prompts to predict from.
      output_file: str, path prefix of the output file to write predictions
        to. The checkpoint step is appended to the given filename.
      checkpoint_steps: int, list of ints, or None. For an int or list of
        ints, inference runs on the checkpoints in `model_dir` whose global
        steps are closest to the ones provided. None runs inference
        continuously, waiting for new checkpoints. -1 selects the latest
        checkpoint in the model directory.
      beam_size: int, a number >= 1 giving the number of beams used for beam
        search.
      temperature: float, a value between 0 and 1 (must be 0 if
        beam_size > 1). 0.0 means argmax, 1.0 means sample according to the
        predicted distribution.
      keep_top_k: integer, a value between 1 and the vocabulary size. When
        sampling, only tokens among the k most likely are picked.
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the default SentencePieceVocabulary.
    """
    # TODO(sharannarang) : It would be nice to have a function like
    # load_checkpoint that loads the model once and then call decode_from_file
    # multiple times without having to restore the checkpoint weights again.
    # This would be particularly useful in colab demo.

    if checkpoint_steps == -1:
        checkpoint_steps = utils.get_latest_checkpoint_from_dir(
            self._model_dir)
    _parse_operative_config(self._model_dir)

    # Decoding settings and the input/output file paths are all delivered to
    # the decoder via gin bindings.
    gin_bindings = (
        ("Bitransformer.decode.beam_size", beam_size),
        ("Bitransformer.decode.temperature", temperature),
        ("Bitransformer.decode.sampling_keep_top_k", keep_top_k),
        ("utils.decode_from_file.input_filename", input_file),
        ("utils.decode_from_file.output_filename", output_file),
    )
    with gin.unlock_config():
        for binding_name, binding_value in gin_bindings:
            gin.bind_parameter(binding_name, binding_value)

    if vocabulary is None:
        vocabulary = utils.get_vocabulary()

    estimator = self.estimator(vocabulary)
    mtf_utils.infer_model(
        estimator,
        vocabulary,
        self._sequence_length,
        self.batch_size,
        self._model_type,
        self._model_dir,
        checkpoint_steps,
    )
def score(self,
          inputs=None,
          targets=None,
          mixture_or_task_name=None,
          mixture_or_task_split=None,
          scores_file=None,
          checkpoint_steps=-1,
          vocabulary=None):
    """Computes log-likelihood of target per example in targets.

    Args:
      inputs: optional - a string (filename), or a list of strings (inputs).
      targets: optional - a string (filename), or a list of strings
        (targets). Required when scoring strings rather than a Mixture/Task.
      mixture_or_task_name: optional - a string, the name of the Mixture or
        Task to score on. Must be pre-registered in the global `TaskRegistry`
        or `MixtureRegistry`. Cannot be supplied in addition to `inputs` and
        `targets`.
      mixture_or_task_split: optional - a string, the split of the Mixture or
        Task to score on. Must be provided if scoring on a Mixture or Task.
      scores_file: optional - a string (filename), to write example scores
        to, one per line.
      checkpoint_steps: int, list of ints, or None. If an int or list of
        ints, inference will be run on the checkpoint files in `model_dir`
        whose global steps are closest to the global steps provided. If None,
        run inference continuously waiting for new checkpoints. If -1, get
        the latest checkpoint from the model directory.
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the default SentencePieceVocabulary.

    Returns:
      scores: a list of floating point scores matching the dataset order.
      targets: a list of scored strings matching the dataset order.

    Raises:
      ValueError: if both or neither of the (inputs, targets) and
        (mixture_or_task_name, mixture_or_task_split) groups are given, if
        scoring on a Mixture/Task without both its name and split, or if
        scoring strings without `targets`.
    """
    score_strings = bool(inputs or targets)
    score_task = bool(mixture_or_task_name or mixture_or_task_split)
    if score_strings == score_task:
        raise ValueError(
            "Either 'inputs' and 'targets' or "
            "'mixture_or_task_name' and 'mixture_or_task_split' must be "
            "specified, but not both.")
    # The check above only guarantees that exactly one group is in use. Also
    # require that group to be complete: otherwise a lone split would be
    # silently routed to string scoring with no strings, a lone name would be
    # scored with split=None, and lone inputs would be scored with no targets.
    if score_task and not (mixture_or_task_name and mixture_or_task_split):
        raise ValueError(
            "Both 'mixture_or_task_name' and 'mixture_or_task_split' must be "
            "specified when scoring on a Mixture or Task.")
    if score_strings and not targets:
        raise ValueError("'targets' must be specified when scoring strings.")

    if checkpoint_steps == -1:
        checkpoint_steps = utils.get_latest_checkpoint_from_dir(
            self._model_dir)

    _parse_operative_config(self._model_dir)
    # Re-apply this model's own gin bindings on top of the operative config.
    with gin.unlock_config():
        gin.parse_config(self._gin_bindings)

    if vocabulary is None:
        vocabulary = utils.get_vocabulary(mixture_or_task_name)

    estimator = self.estimator(vocabulary, score_in_predict_mode=True)
    score_postprocess_fn = functools.partial(
        mtf_utils.save_scores, scores_filename=scores_file)

    if score_task:
        score_dataset_fn = functools.partial(
            t5.models.mesh_transformer.mesh_eval_dataset_fn,
            mixture_or_task_name=mixture_or_task_name,
        )
        return mtf_utils.score_from_dataset(
            estimator=estimator,
            vocabulary=vocabulary,
            batch_size=self.batch_size,
            sequence_length=self._sequence_length,
            model_dir=self._model_dir,
            eval_checkpoint_step=checkpoint_steps,
            dataset_split=mixture_or_task_split,
            score_dataset_fn=score_dataset_fn,
            score_postprocess_fn=score_postprocess_fn)
    return mtf_utils.score_from_strings(
        estimator=estimator,
        vocabulary=vocabulary,
        model_type=self._model_type,
        batch_size=self.batch_size,
        sequence_length=self._sequence_length,
        model_dir=self._model_dir,
        eval_checkpoint_step=checkpoint_steps,
        inputs=inputs,
        targets=targets,
        score_postprocess_fn=score_postprocess_fn)